bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py DELETED
@@ -1,2616 +0,0 @@
1
- # bartz/src/bartz/mcmcstep.py
2
- #
3
- # Copyright (c) 2024-2025, Giacomo Petrillo
4
- #
5
- # This file is part of bartz.
6
- #
7
- # Permission is hereby granted, free of charge, to any person obtaining a copy
8
- # of this software and associated documentation files (the "Software"), to deal
9
- # in the Software without restriction, including without limitation the rights
10
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
- # copies of the Software, and to permit persons to whom the Software is
12
- # furnished to do so, subject to the following conditions:
13
- #
14
- # The above copyright notice and this permission notice shall be included in all
15
- # copies or substantial portions of the Software.
16
- #
17
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
- # SOFTWARE.
24
-
25
- """
26
- Functions that implement the BART posterior MCMC initialization and update step.
27
-
28
- Functions that do MCMC steps operate by taking as input a bart state, and
29
- outputting a new state. The inputs are not modified.
30
-
31
- The entry points are:
32
-
33
- - `State`: The dataclass that represents a BART MCMC state.
34
- - `init`: Creates an initial `State` from data and configurations.
35
- - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
36
- - `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`.
37
- """
38
-
39
- import math
40
- from dataclasses import replace
41
- from functools import cache, partial
42
- from typing import Any, Literal
43
-
44
- import jax
45
- from equinox import Module, field, tree_at
46
- from jax import lax, random
47
- from jax import numpy as jnp
48
- from jax.scipy.special import gammaln, logsumexp
49
- from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt
50
-
51
- from bartz import grove
52
- from bartz.jaxext import (
53
- minimal_unsigned_dtype,
54
- split,
55
- truncated_normal_onesided,
56
- vmap_nodoc,
57
- )
58
-
59
-
60
class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_tree
        The leaf values.
    var_tree
        The decision axes.
    split_tree
        The decision boundaries.
    affluence_tree
        Marks leaves that can be grown.
    max_split
        The maximum split index for each predictor.
    blocked_vars
        Indices of variables that are not used. This shall include at least
        the `i` such that ``max_split[i] == 0``, otherwise behavior is
        undefined.
    p_nonterminal
        The prior probability of each node being nonterminal, conditional on
        its ancestors. Includes the nodes at maximum depth which should be set
        to 0.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_decision_node
        The minimum number of data points in a decision node.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If `None`,
        they are computed with no batching.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree.
    log_likelihood
        The log likelihood ratio.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If `None`, use a uniform distribution.
    theta
        The concentration parameter for the Dirichlet prior on the variable
        distribution `s`. Required only to update `s`.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
        See `step_theta`.
    """

    # tree structures, stored as heap-indexed arrays (one row per tree)
    leaf_tree: Float32[Array, 'num_trees 2**d']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    split_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    # predictor metadata
    max_split: UInt[Array, ' p']
    blocked_vars: UInt[Array, ' k'] | None
    # tree-structure prior and proposal distribution
    p_nonterminal: Float32[Array, ' 2**d']
    p_propose_grow: Float32[Array, ' 2**(d-1)']
    leaf_indices: UInt[Array, 'num_trees n']
    min_points_per_decision_node: Int32[Array, ''] | None
    min_points_per_leaf: Int32[Array, ''] | None
    # static (non-traced) configuration of the sufficient-statistics batching
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    # optional per-tree MH diagnostics (saved only if requested at init)
    log_trans_prior: Float32[Array, ' num_trees'] | None
    log_likelihood: Float32[Array, ' num_trees'] | None
    # proposal/acceptance counters over one full MCMC cycle
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    sigma_mu2: Float32[Array, '']
    # sparsity (variable-selection) parameters; None when the feature is off
    log_s: Float32[Array, ' p'] | None
    theta: Float32[Array, ''] | None
    a: Float32[Array, ''] | None
    b: Float32[Array, ''] | None
    rho: Float32[Array, ''] | None
148
-
149
-
150
class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors.
    y
        The response. If the data type is `bool`, the model is binary regression.
    resid
        The residuals (`y` or `z` minus sum of trees).
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    sigma2
        The error variance. `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the noise
        variance. `None` in binary regression.
    forest
        The sum of trees model.
    """

    # note: X is transposed w.r.t. the usual (n, p) convention
    X: UInt[Array, 'p n']
    y: Float32[Array, ' n'] | Bool[Array, ' n']
    z: None | Float32[Array, ' n']
    offset: Float32[Array, '']
    resid: Float32[Array, ' n']
    # error-variance state and prior; all None in binary regression
    sigma2: Float32[Array, ''] | None
    prec_scale: Float32[Array, ' n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    forest: Forest
190
-
191
-
192
def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_batch_size: int | None | Literal['auto'] = 'auto',
    count_batch_size: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: bool = True,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
    log_s: Float32[Any, ' p'] | None = None,
    theta: float | Float32[Any, ''] | None = None,
    a: float | Float32[Any, ''] | None = None,
    b: float | Float32[Any, ''] | None = None,
    rho: float | Float32[Any, ''] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual convention.
    y
        The response. If the data type is `bool`, the regression model is binary
        regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1 for
        all points, but potentially skipping calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', pick a value based on the device of `y`, or the default
        device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        Whether to check `max_split` for variables without available cutpoints.
        If any are found, they are put into a list of variables to exclude from
        the MCMC. If `False`, no check is performed, but the results may be
        wrong if any variable is blocked. The function is jax-traceable only
        if this is set to `False`.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken into
        account in the Metropolis-Hastings ratio because it would be expensive
        to compute. Grow moves that would violate this constraint are vetoed.
        This parameter is independent of `min_points_per_decision_node` and
        there is no check that they are coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If not specified, use a uniform distribution. If not specified and
        `theta` or `rho`, `a`, `b` are, it's initialized automatically.
    theta
        The concentration parameter for the Dirichlet prior on `s`. Required
        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
        specified, it's initialized automatically.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    # pad with a trailing 0 so that nodes at maximum depth are never nonterminal
    p_nonterminal = jnp.asarray(p_nonterminal)
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    # vmap over nothing, just to replicate one blank tree `num_trees` times
    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        return grove.make_tree(max_depth, dtype)

    y = jnp.asarray(y)
    offset = jnp.asarray(offset)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
            # NOTE(review): message has a doubled space ('set  to'); fix separately
            msg = (
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                ' to `None` for binary regression.'
            )
            raise ValueError(msg)
        sigma2 = None
    else:
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        # start sigma2 at the prior mode-ish ratio beta/alpha
        sigma2 = sigma2_beta / sigma2_alpha

    max_split = jnp.asarray(max_split)

    if filter_splitless_vars:
        # not jax-traceable: the number of blocked variables is data-dependent
        (blocked_vars,) = jnp.nonzero(max_split == 0)
        blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
        # see `fully_used_variables` for the type cast
    else:
        blocked_vars = None

    # check and initialize sparsity parameters
    if not _all_none_or_not_none(rho, a, b):
        msg = 'rho, a, b are not either all `None` or all set'
        raise ValueError(msg)
    if theta is None and rho is not None:
        theta = rho
    if log_s is None and theta is not None:
        # uniform initial distribution over variables (unnormalized log-probs)
        log_s = jnp.zeros(max_split.size)

    return State(
        X=jnp.asarray(X),
        y=y,
        # binary regression: latent z starts at the offset, residuals at 0
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=(
            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
        ),
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_tree=make_forest(max_depth, jnp.float32),
            var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
            split_tree=make_forest(max_depth - 1, max_split.dtype),
            # mark the root (heap index 1) growable iff there are enough points
            affluence_tree=(
                make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(
                    True
                    if min_points_per_decision_node is None
                    else y.size >= min_points_per_decision_node
                )
            ),
            blocked_vars=blocked_vars,
            max_split=max_split,
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            # every datapoint starts in the root leaf (heap index 1)
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
            log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
            log_s=_asarray_or_none(log_s),
            theta=_asarray_or_none(theta),
            rho=_asarray_or_none(rho),
            a=_asarray_or_none(a),
            b=_asarray_or_none(b),
        ),
    )
400
-
401
-
402
- def _all_none_or_not_none(*args):
403
- is_none = [x is None for x in args]
404
- return all(is_none) or not any(is_none)
405
-
406
-
407
- def _asarray_or_none(x):
408
- if x is None:
409
- return None
410
- return jnp.asarray(x)
411
-
412
-
413
- def _choose_suffstat_batch_size(
414
- resid_batch_size, count_batch_size, y, forest_size
415
- ) -> tuple[int | None, ...]:
416
- @cache
417
- def get_platform():
418
- try:
419
- device = y.devices().pop()
420
- except jax.errors.ConcretizationTypeError:
421
- device = jax.devices()[0]
422
- platform = device.platform
423
- if platform not in ('cpu', 'gpu'):
424
- msg = f'Unknown platform: {platform}'
425
- raise KeyError(msg)
426
- return platform
427
-
428
- if resid_batch_size == 'auto':
429
- platform = get_platform()
430
- n = max(1, y.size)
431
- if platform == 'cpu':
432
- resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6
433
- elif platform == 'gpu':
434
- resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3
435
- resid_batch_size = max(1, resid_batch_size)
436
-
437
- if count_batch_size == 'auto':
438
- platform = get_platform()
439
- if platform == 'cpu':
440
- count_batch_size = None
441
- elif platform == 'gpu':
442
- n = max(1, y.size)
443
- count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2
444
- # /4 is good on V100, /2 on L4/T4, still haven't tried A100
445
- max_memory = 2**29
446
- itemsize = 4
447
- min_batch_size = math.ceil(forest_size * itemsize * n / max_memory)
448
- count_batch_size = max(count_batch_size, min_batch_size)
449
- count_batch_size = max(1, count_batch_size)
450
-
451
- return resid_batch_size, count_batch_size
452
-
453
-
454
@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Do one MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.
    """
    keys = split(key)
    is_binary = bart.y.dtype == bool

    if is_binary:
        # temporarily pin the latent-variable variance to 1 for the tree update
        bart = replace(bart, sigma2=jnp.float32(1))

    bart = step_trees(keys.pop(), bart)

    if is_binary:
        # restore the binary-regression invariant (sigma2 is None), then
        # resample the probit latent variable z
        bart = replace(bart, sigma2=None)
        return step_z(keys.pop(), bart)

    # continuous regression: resample the error variance
    return step_sigma(keys.pop(), bart)
481
-
482
-
483
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Forest sampling step of BART MCMC.

    Proposes one grow/prune move per tree, then accepts or rejects them and
    resamples the leaf values.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    keys = split(key)
    proposed = propose_moves(keys.pop(), bart.forest)
    return accept_moves_and_sample_leaves(keys.pop(), bart, proposed)
505
-
506
-
507
class Moves(Module):
    """
    Moves proposed to modify each tree.

    Parameters
    ----------
    allowed
        Whether there is a possible move. If `False`, the other values may not
        make sense. The only case in which a move is marked as allowed but is
        then vetoed is if it does not satisfy `min_points_per_leaf`, which for
        efficiency is implemented post-hoc without changing the rest of the
        MCMC logic.
    grow
        Whether the move is a grow move or a prune move.
    num_growable
        The number of growable leaves in the original tree.
    node
        The index of the leaf to grow or node to prune.
    left
    right
        The indices of the children of 'node'.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        probability that the children of the modified node are terminal. If the
        move is PRUNE, the ratio is inverted. `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        The logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings ratio for the acceptance of the proposed move.
        `None` if not yet computed. If PRUNE, the log-ratio is negated.
    grow_var
        The decision axes of the new rules.
    grow_split
        The decision boundaries of the new rules.
    var_tree
        The updated decision axes of the trees, valid whatever move.
    affluence_tree
        A partially updated `affluence_tree`, marking non-leaf nodes that would
        become leaves if the move was accepted. This mark initially (out of
        `propose_moves`) takes into account if there would be available decision
        rules to grow the leaf, and whether there are enough datapoints in the
        node is marked in `accept_moves_parallel_stage`.
    logu
        The logarithm of a uniform (0, 1] random variable to be used to
        accept the move. It's in (-oo, 0].
    acc
        Whether the move was accepted. `None` if not yet computed.
    to_prune
        Whether the final operation to apply the move is pruning. This indicates
        an accepted prune move or a rejected grow move. `None` if not yet
        computed.
    """

    # one entry per tree for all fields below
    allowed: Bool[Array, ' num_trees']
    grow: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    # MH-ratio pieces: partial_ratio is consumed to produce log_trans_prior_ratio
    partial_ratio: Float32[Array, ' num_trees'] | None
    log_trans_prior_ratio: None | Float32[Array, ' num_trees']
    grow_var: UInt[Array, ' num_trees']
    grow_split: UInt[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    logu: Float32[Array, ' num_trees']
    # filled in by accept_moves_sequential_stage
    acc: None | Bool[Array, ' num_trees']
    to_prune: None | Bool[Array, ' num_trees']
576
-
577
-
578
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Propose moves for all the trees.

    There are two types of moves: GROW (convert a leaf to a decision node and
    add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
    leaf, deleting its children).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees, _ = forest.leaf_tree.shape
    # one key per tree for grow, one per tree for prune, one for the
    # grow-vs-prune choice
    keys = split(key, 1 + 2 * num_trees)

    # compute moves
    grow_moves = propose_grow_moves(
        keys.pop(num_trees),
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    # prune proposals see the affluence tree already updated by grow proposals
    prune_moves = propose_prune_moves(
        keys.pop(num_trees),
        forest.split_tree,
        grow_moves.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    # two uniform draws per tree: one to pick grow vs prune, one (via log1p
    # below) for the acceptance threshold
    u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))

    # choose between grow or prune
    p_grow = jnp.where(
        grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed
    )
    grow = u < p_grow  # use < instead of <= because u is in [0, 1)

    # compute children indices (heap convention: children of i are 2i, 2i+1)
    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    left = node << 1
    right = left + 1

    return Moves(
        allowed=grow_moves.allowed | prune_moves.allowed,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=jnp.where(
            grow, grow_moves.partial_ratio, prune_moves.partial_ratio
        ),
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        # var_tree does not need to be updated if prune
        var_tree=grow_moves.var_tree,
        # affluence_tree is updated for both moves unconditionally, prune last
        affluence_tree=prune_moves.affluence_tree,
        # log of a uniform in (0, 1]: log1p(-x) with x uniform in [0, 1)
        logu=jnp.log1p(-exp1mlogu),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )
654
-
655
-
656
class GrowMoves(Module):
    """
    Represent a proposed grow move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is allowed for proposal.
    num_growable
        The number of leaves that can be proposed for grow.
    node
        The index of the leaf to grow. ``2 ** d`` if there are no growable
        leaves.
    var
    split
        The decision axis and boundary of the new rule.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move.
    var_tree
        The updated decision axes of the tree.
    affluence_tree
        A partially updated `affluence_tree` that marks each new leaf that
        would be produced as `True` if it would have available decision rules.
    """

    # one entry per tree for all fields
    allowed: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    # the proposed new decision rule (axis and cutpoint)
    var: UInt[Array, ' num_trees']
    split: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    # tree arrays with the proposed rule already written in
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
691
-
692
-
693
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on the
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The move is not proposed if each leaf is already at maximum depth, or has
    less datapoints than the requested threshold `min_points_per_decision_node`,
    or it does not have any available decision rules given its ancestors. This
    is marked by setting `allowed` to `False` and `num_growable` to 0.
    """
    # one key each for: leaf choice, variable choice, split choice
    keys = split(key, 3)

    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # sample a decision rule
    var, num_available_var = choose_variable(
        keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
    )
    split_idx, l, r = choose_split(
        keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
    )

    # determine if the new leaves would have available decision rules; if the
    # move is blocked, these values may not make sense
    # a child is growable if another variable remains, or if its half of the
    # cutpoint range [l, r) is non-empty
    left_growable = right_growable = num_available_var > 1
    left_growable |= l < split_idx
    right_growable |= split_idx + 1 < r
    left = leaf_to_grow << 1
    right = left + 1
    affluence_tree = affluence_tree.at[left].set(left_growable)
    affluence_tree = affluence_tree.at[right].set(right_growable)

    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=ratio,
        var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
        affluence_tree=affluence_tree,
    )
783
-
784
-
785
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two leaves
        satisfying the `min_points_per_leaf` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaf nodes that can be grown, i.e., are nonterminal
        and have at least twice `min_points_per_leaf`.
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    is_growable = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(is_growable)
    # zero out the proposal weight of non-growable nodes
    distr = jnp.where(is_growable, p_propose_grow, 0)
    leaf_to_grow, distr_norm = categorical(key, distr)
    # if nothing is growable, use the out-of-bounds sentinel index 2**d
    leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
    # guard against division by zero when all weights are zero
    prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1)
    # count prunable parents as if the chosen leaf had been grown (set to 1)
    is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
    num_prunable = jnp.count_nonzero(is_parent)
    return leaf_to_grow, num_growable, prob_choose, num_prunable
830
-
831
-
832
def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Return a mask indicating the leaf nodes that can be proposed for growth.

    A node qualifies when it is an actual leaf of the tree (not at the bottom
    level, not an unused slot) and is marked as growable in `affluence_tree`.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    This function needs `split_tree` and not just `affluence_tree` because
    `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
    """
    actual_leaves = grove.is_actual_leaf(split_tree)
    return actual_leaves & affluence_tree
859
-
860
-
861
- def categorical(
862
- key: Key[Array, ''], distr: Float32[Array, ' n']
863
- ) -> tuple[Int32[Array, ''], Float32[Array, '']]:
864
- """
865
- Return a random integer from an arbitrary distribution.
866
-
867
- Parameters
868
- ----------
869
- key
870
- A jax random key.
871
- distr
872
- An unnormalized probability distribution.
873
-
874
- Returns
875
- -------
876
- u : Int32[Array, '']
877
- A random integer in the range ``[0, n)``. If all probabilities are zero,
878
- return ``n``.
879
- norm : Float32[Array, '']
880
- The sum of `distr`.
881
-
882
- Notes
883
- -----
884
- This function uses a cumsum instead of the Gumbel trick, so it's ok only
885
- for small ranges with probabilities well greater than 0.
886
- """
887
- ecdf = jnp.cumsum(distr)
888
- u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
889
- return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
890
-
891
-
892
- def choose_variable(
893
- key: Key[Array, ''],
894
- var_tree: UInt[Array, ' 2**(d-1)'],
895
- split_tree: UInt[Array, ' 2**(d-1)'],
896
- max_split: UInt[Array, ' p'],
897
- leaf_index: Int32[Array, ''],
898
- blocked_vars: Int32[Array, ' k'] | None,
899
- log_s: Float32[Array, ' p'] | None,
900
- ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
901
- """
902
- Choose a variable to split on for a new non-terminal node.
903
-
904
- Parameters
905
- ----------
906
- key
907
- A jax random key.
908
- var_tree
909
- The variable indices of the tree.
910
- split_tree
911
- The splitting points of the tree.
912
- max_split
913
- The maximum split index for each variable.
914
- leaf_index
915
- The index of the leaf to grow.
916
- blocked_vars
917
- The indices of the variables that have no available cutpoints. If
918
- `None`, all variables are assumed unblocked.
919
- log_s
920
- The logarithm of the prior probability for choosing a variable. If
921
- `None`, use a uniform distribution.
922
-
923
- Returns
924
- -------
925
- var : Int32[Array, '']
926
- The index of the variable to split on.
927
- num_available_var : Int32[Array, '']
928
- The number of variables with available decision rules `var` was chosen
929
- from.
930
- """
931
- var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
932
- if blocked_vars is not None:
933
- var_to_ignore = jnp.concatenate([var_to_ignore, blocked_vars])
934
-
935
- if log_s is None:
936
- return randint_exclude(key, max_split.size, var_to_ignore)
937
- else:
938
- return categorical_exclude(key, log_s, var_to_ignore)
939
-
940
-
941
- def fully_used_variables(
942
- var_tree: UInt[Array, ' 2**(d-1)'],
943
- split_tree: UInt[Array, ' 2**(d-1)'],
944
- max_split: UInt[Array, ' p'],
945
- leaf_index: Int32[Array, ''],
946
- ) -> UInt[Array, ' d-2']:
947
- """
948
- Find variables in the ancestors of a node that have an empty split range.
949
-
950
- Parameters
951
- ----------
952
- var_tree
953
- The variable indices of the tree.
954
- split_tree
955
- The splitting points of the tree.
956
- max_split
957
- The maximum split index for each variable.
958
- leaf_index
959
- The index of the node, assumed to be valid for `var_tree`.
960
-
961
- Returns
962
- -------
963
- The indices of the variables that have an empty split range.
964
-
965
- Notes
966
- -----
967
- The number of unused variables is not known in advance. Unused values in the
968
- array are filled with `p`. The fill values are not guaranteed to be placed
969
- in any particular order, and variables may appear more than once.
970
- """
971
- var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
972
- split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
973
- l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
974
- num_split = r - l
975
- return jnp.where(num_split == 0, var_to_ignore, max_split.size)
976
- # the type of var_to_ignore is already sufficient to hold max_split.size,
977
- # see ancestor_variables()
978
-
979
-
980
- def ancestor_variables(
981
- var_tree: UInt[Array, ' 2**(d-1)'],
982
- max_split: UInt[Array, ' p'],
983
- node_index: Int32[Array, ''],
984
- ) -> UInt[Array, ' d-2']:
985
- """
986
- Return the list of variables in the ancestors of a node.
987
-
988
- Parameters
989
- ----------
990
- var_tree
991
- The variable indices of the tree.
992
- max_split
993
- The maximum split index for each variable. Used only to get `p`.
994
- node_index
995
- The index of the node, assumed to be valid for `var_tree`.
996
-
997
- Returns
998
- -------
999
- The variable indices of the ancestors of the node.
1000
-
1001
- Notes
1002
- -----
1003
- The ancestors are the nodes going from the root to the parent of the node.
1004
- The number of ancestors is not known at tracing time; unused spots in the
1005
- output array are filled with `p`.
1006
- """
1007
- max_num_ancestors = grove.tree_depth(var_tree) - 1
1008
- ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
1009
- carry = ancestor_vars.size - 1, node_index, ancestor_vars
1010
-
1011
- def loop(carry, _):
1012
- i, index, ancestor_vars = carry
1013
- index >>= 1
1014
- var = var_tree[index]
1015
- var = jnp.where(index, var, max_split.size)
1016
- ancestor_vars = ancestor_vars.at[i].set(var)
1017
- return (i - 1, index, ancestor_vars), None
1018
-
1019
- (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size)
1020
- return ancestor_vars
1021
-
1022
-
1023
- def split_range(
1024
- var_tree: UInt[Array, ' 2**(d-1)'],
1025
- split_tree: UInt[Array, ' 2**(d-1)'],
1026
- max_split: UInt[Array, ' p'],
1027
- node_index: Int32[Array, ''],
1028
- ref_var: Int32[Array, ''],
1029
- ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
1030
- """
1031
- Return the range of allowed splits for a variable at a given node.
1032
-
1033
- Parameters
1034
- ----------
1035
- var_tree
1036
- The variable indices of the tree.
1037
- split_tree
1038
- The splitting points of the tree.
1039
- max_split
1040
- The maximum split index for each variable.
1041
- node_index
1042
- The index of the node, assumed to be valid for `var_tree`.
1043
- ref_var
1044
- The variable for which to measure the split range.
1045
-
1046
- Returns
1047
- -------
1048
- The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
1049
- """
1050
- max_num_ancestors = grove.tree_depth(var_tree) - 1
1051
- initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
1052
- jnp.int32
1053
- )
1054
- carry = jnp.int32(0), initial_r, node_index
1055
-
1056
- def loop(carry, _):
1057
- l, r, index = carry
1058
- right_child = (index & 1).astype(bool)
1059
- index >>= 1
1060
- split = split_tree[index]
1061
- cond = (var_tree[index] == ref_var) & index.astype(bool)
1062
- l = jnp.where(cond & right_child, jnp.maximum(l, split), l)
1063
- r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r)
1064
- return (l, r, index), None
1065
-
1066
- (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors)
1067
- return l + 1, r
1068
-
1069
-
1070
- def randint_exclude(
1071
- key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
1072
- ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
1073
- """
1074
- Return a random integer in a range, excluding some values.
1075
-
1076
- Parameters
1077
- ----------
1078
- key
1079
- A jax random key.
1080
- sup
1081
- The exclusive upper bound of the range.
1082
- exclude
1083
- The values to exclude from the range. Values greater than or equal to
1084
- `sup` are ignored. Values can appear more than once.
1085
-
1086
- Returns
1087
- -------
1088
- u : Int32[Array, '']
1089
- A random integer `u` in the range ``[0, sup)`` such that ``u not in
1090
- exclude``.
1091
- num_allowed : Int32[Array, '']
1092
- The number of integers in the range that were not excluded.
1093
-
1094
- Notes
1095
- -----
1096
- If all values in the range are excluded, return `sup`.
1097
- """
1098
- exclude, num_allowed = _process_exclude(sup, exclude)
1099
- u = random.randint(key, (), 0, num_allowed)
1100
-
1101
- def loop(u, i_excluded):
1102
- return jnp.where(i_excluded <= u, u + 1, u), None
1103
-
1104
- u, _ = lax.scan(loop, u, exclude)
1105
- return u, num_allowed
1106
-
1107
-
1108
- def _process_exclude(sup, exclude):
1109
- exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
1110
- num_allowed = sup - jnp.count_nonzero(exclude < sup)
1111
- return exclude, num_allowed
1112
-
1113
-
1114
- def categorical_exclude(
1115
- key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
1116
- ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
1117
- """
1118
- Draw from a categorical distribution, excluding a set of values.
1119
-
1120
- Parameters
1121
- ----------
1122
- key
1123
- A jax random key.
1124
- logits
1125
- The unnormalized log-probabilities of each category.
1126
- exclude
1127
- The values to exclude from the range [0, k). Values greater than or
1128
- equal to `logits.size` are ignored. Values can appear more than once.
1129
-
1130
- Returns
1131
- -------
1132
- u : Int32[Array, '']
1133
- A random integer in the range ``[0, k)`` such that ``u not in exclude``.
1134
- num_allowed : Int32[Array, '']
1135
- The number of integers in the range that were not excluded.
1136
-
1137
- Notes
1138
- -----
1139
- If all values in the range are excluded, the result is unspecified.
1140
- """
1141
- exclude, num_allowed = _process_exclude(logits.size, exclude)
1142
- kinda_neg_inf = jnp.finfo(logits.dtype).min
1143
- logits = logits.at[exclude].set(kinda_neg_inf)
1144
- u = random.categorical(key, logits)
1145
- return u, num_allowed
1146
-
1147
-
1148
- def choose_split(
1149
- key: Key[Array, ''],
1150
- var: Int32[Array, ''],
1151
- var_tree: UInt[Array, ' 2**(d-1)'],
1152
- split_tree: UInt[Array, ' 2**(d-1)'],
1153
- max_split: UInt[Array, ' p'],
1154
- leaf_index: Int32[Array, ''],
1155
- ) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
1156
- """
1157
- Choose a split point for a new non-terminal node.
1158
-
1159
- Parameters
1160
- ----------
1161
- key
1162
- A jax random key.
1163
- var
1164
- The variable to split on.
1165
- var_tree
1166
- The splitting axes of the tree. Does not need to already contain `var`
1167
- at `leaf_index`.
1168
- split_tree
1169
- The splitting points of the tree.
1170
- max_split
1171
- The maximum split index for each variable.
1172
- leaf_index
1173
- The index of the leaf to grow.
1174
-
1175
- Returns
1176
- -------
1177
- split : Int32[Array, '']
1178
- The cutpoint.
1179
- l : Int32[Array, '']
1180
- r : Int32[Array, '']
1181
- The integer range `split` was drawn from is [l, r).
1182
-
1183
- Notes
1184
- -----
1185
- If `var` is out of bounds, or if the available split range on that variable
1186
- is empty, return 0.
1187
- """
1188
- l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
1189
- return jnp.where(l < r, random.randint(key, (), l, r), 0), l, r
1190
-
1191
-
1192
- def compute_partial_ratio(
1193
- prob_choose: Float32[Array, ''],
1194
- num_prunable: Int32[Array, ''],
1195
- p_nonterminal: Float32[Array, ' 2**d'],
1196
- leaf_to_grow: Int32[Array, ''],
1197
- ) -> Float32[Array, '']:
1198
- """
1199
- Compute the product of the transition and prior ratios of a grow move.
1200
-
1201
- Parameters
1202
- ----------
1203
- prob_choose
1204
- The probability that the leaf had to be chosen amongst the growable
1205
- leaves.
1206
- num_prunable
1207
- The number of leaf parents that could be pruned, after converting the
1208
- leaf to be grown to a non-terminal node.
1209
- p_nonterminal
1210
- The a priori probability of each node being nonterminal conditional on
1211
- its ancestors.
1212
- leaf_to_grow
1213
- The index of the leaf to grow.
1214
-
1215
- Returns
1216
- -------
1217
- The partial transition ratio times the prior ratio.
1218
-
1219
- Notes
1220
- -----
1221
- The transition ratio is P(new tree => old tree) / P(old tree => new tree).
1222
- The "partial" transition ratio returned is missing the factor P(propose
1223
- prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The
1224
- "partial" prior ratio is missing the factor P(children are leaves).
1225
- """
1226
- # the two ratios also contain factors num_available_split *
1227
- # num_available_var * s[var], but they cancel out
1228
-
1229
- # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be
1230
- # computed here because they need the count trees, which are computed in the
1231
- # acceptance phase
1232
-
1233
- prune_allowed = leaf_to_grow != 1
1234
- # prune allowed <---> the initial tree is not a root
1235
- # leaf to grow is root --> the tree can only be a root
1236
- # tree is a root --> the only leaf I can grow is root
1237
- p_grow = jnp.where(prune_allowed, 0.5, 1)
1238
- inv_trans_ratio = p_grow * prob_choose * num_prunable
1239
-
1240
- # .at.get because if leaf_to_grow is out of bounds (move not allowed), this
1241
- # would produce a 0 and then an inf when `complete_ratio` takes the log
1242
- pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
1243
- tree_ratio = pnt / (1 - pnt)
1244
-
1245
- return tree_ratio / jnp.where(inv_trans_ratio, inv_trans_ratio, 1)
1246
-
1247
-
1248
- class PruneMoves(Module):
1249
- """
1250
- Represent a proposed prune move for each tree.
1251
-
1252
- Parameters
1253
- ----------
1254
- allowed
1255
- Whether the move is possible.
1256
- node
1257
- The index of the node to prune. ``2 ** d`` if no node can be pruned.
1258
- partial_ratio
1259
- A factor of the Metropolis-Hastings ratio of the move. It lacks the
1260
- likelihood ratio, the probability of proposing the prune move, and the
1261
- prior probability that the children of the node to prune are leaves.
1262
- This ratio is inverted, and is meant to be inverted back in
1263
- `accept_move_and_sample_leaves`.
1264
- """
1265
-
1266
- allowed: Bool[Array, ' num_trees']
1267
- node: UInt[Array, ' num_trees']
1268
- partial_ratio: Float32[Array, ' num_trees']
1269
- affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
1270
-
1271
-
1272
- @partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
1273
- def propose_prune_moves(
1274
- key: Key[Array, ''],
1275
- split_tree: UInt[Array, ' 2**(d-1)'],
1276
- affluence_tree: Bool[Array, ' 2**(d-1)'],
1277
- p_nonterminal: Float32[Array, ' 2**d'],
1278
- p_propose_grow: Float32[Array, ' 2**(d-1)'],
1279
- ) -> PruneMoves:
1280
- """
1281
- Tree structure prune move proposal of BART MCMC.
1282
-
1283
- Parameters
1284
- ----------
1285
- key
1286
- A jax random key.
1287
- split_tree
1288
- The splitting points of the tree.
1289
- affluence_tree
1290
- Whether each leaf can be grown.
1291
- p_nonterminal
1292
- The a priori probability of a node to be nonterminal conditional on
1293
- the ancestors, including at the maximum depth where it should be zero.
1294
- p_propose_grow
1295
- The unnormalized probability of choosing a leaf to grow.
1296
-
1297
- Returns
1298
- -------
1299
- An object representing the proposed moves.
1300
- """
1301
- node_to_prune, num_prunable, prob_choose, affluence_tree = choose_leaf_parent(
1302
- key, split_tree, affluence_tree, p_propose_grow
1303
- )
1304
-
1305
- ratio = compute_partial_ratio(
1306
- prob_choose, num_prunable, p_nonterminal, node_to_prune
1307
- )
1308
-
1309
- return PruneMoves(
1310
- allowed=split_tree[1].astype(bool), # allowed iff the tree is not a root
1311
- node=node_to_prune,
1312
- partial_ratio=ratio,
1313
- affluence_tree=affluence_tree,
1314
- )
1315
-
1316
-
1317
- def choose_leaf_parent(
1318
- key: Key[Array, ''],
1319
- split_tree: UInt[Array, ' 2**(d-1)'],
1320
- affluence_tree: Bool[Array, ' 2**(d-1)'],
1321
- p_propose_grow: Float32[Array, ' 2**(d-1)'],
1322
- ) -> tuple[
1323
- Int32[Array, ''],
1324
- Int32[Array, ''],
1325
- Float32[Array, ''],
1326
- Bool[Array, 'num_trees 2**(d-1)'],
1327
- ]:
1328
- """
1329
- Pick a non-terminal node with leaf children to prune in a tree.
1330
-
1331
- Parameters
1332
- ----------
1333
- key
1334
- A jax random key.
1335
- split_tree
1336
- The splitting points of the tree.
1337
- affluence_tree
1338
- Whether a leaf has enough points to be grown.
1339
- p_propose_grow
1340
- The unnormalized probability of choosing a leaf to grow.
1341
-
1342
- Returns
1343
- -------
1344
- node_to_prune : Int32[Array, '']
1345
- The index of the node to prune. If ``num_prunable == 0``, return
1346
- ``2 ** d``.
1347
- num_prunable : Int32[Array, '']
1348
- The number of leaf parents that could be pruned.
1349
- prob_choose : Float32[Array, '']
1350
- The (normalized) probability that `choose_leaf` would chose
1351
- `node_to_prune` as leaf to grow, if passed the tree where
1352
- `node_to_prune` had been pruned.
1353
- affluence_tree : Bool[Array, 'num_trees 2**(d-1)']
1354
- A partially updated `affluence_tree`, marking the node to prune as
1355
- growable.
1356
- """
1357
- # sample a node to prune
1358
- is_prunable = grove.is_leaves_parent(split_tree)
1359
- num_prunable = jnp.count_nonzero(is_prunable)
1360
- node_to_prune = randint_masked(key, is_prunable)
1361
- node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
1362
-
1363
- # compute stuff for reverse move
1364
- split_tree = split_tree.at[node_to_prune].set(0)
1365
- affluence_tree = affluence_tree.at[node_to_prune].set(True)
1366
- is_growable_leaf = growable_leaves(split_tree, affluence_tree)
1367
- distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
1368
- prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
1369
- prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)
1370
-
1371
- return node_to_prune, num_prunable, prob_choose, affluence_tree
1372
-
1373
-
1374
- def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
1375
- """
1376
- Return a random integer in a range, including only some values.
1377
-
1378
- Parameters
1379
- ----------
1380
- key
1381
- A jax random key.
1382
- mask
1383
- The mask indicating the allowed values.
1384
-
1385
- Returns
1386
- -------
1387
- A random integer in the range ``[0, n)`` such that ``mask[u] == True``.
1388
-
1389
- Notes
1390
- -----
1391
- If all values in the mask are `False`, return `n`.
1392
- """
1393
- ecdf = jnp.cumsum(mask)
1394
- u = random.randint(key, (), 0, ecdf[-1])
1395
- return jnp.searchsorted(ecdf, u, 'right')
1396
-
1397
-
1398
- def accept_moves_and_sample_leaves(
1399
- key: Key[Array, ''], bart: State, moves: Moves
1400
- ) -> State:
1401
- """
1402
- Accept or reject the proposed moves and sample the new leaf values.
1403
-
1404
- Parameters
1405
- ----------
1406
- key
1407
- A jax random key.
1408
- bart
1409
- A valid BART mcmc state.
1410
- moves
1411
- The proposed moves, see `propose_moves`.
1412
-
1413
- Returns
1414
- -------
1415
- A new (valid) BART mcmc state.
1416
- """
1417
- pso = accept_moves_parallel_stage(key, bart, moves)
1418
- bart, moves = accept_moves_sequential_stage(pso)
1419
- return accept_moves_final_stage(bart, moves)
1420
-
1421
-
1422
- class Counts(Module):
1423
- """
1424
- Number of datapoints in the nodes involved in proposed moves for each tree.
1425
-
1426
- Parameters
1427
- ----------
1428
- left
1429
- Number of datapoints in the left child.
1430
- right
1431
- Number of datapoints in the right child.
1432
- total
1433
- Number of datapoints in the parent (``= left + right``).
1434
- """
1435
-
1436
- left: UInt[Array, ' num_trees']
1437
- right: UInt[Array, ' num_trees']
1438
- total: UInt[Array, ' num_trees']
1439
-
1440
-
1441
- class Precs(Module):
1442
- """
1443
- Likelihood precision scale in the nodes involved in proposed moves for each tree.
1444
-
1445
- The "likelihood precision scale" of a tree node is the sum of the inverse
1446
- squared error scales of the datapoints selected by the node.
1447
-
1448
- Parameters
1449
- ----------
1450
- left
1451
- Likelihood precision scale in the left child.
1452
- right
1453
- Likelihood precision scale in the right child.
1454
- total
1455
- Likelihood precision scale in the parent (``= left + right``).
1456
- """
1457
-
1458
- left: Float32[Array, ' num_trees']
1459
- right: Float32[Array, ' num_trees']
1460
- total: Float32[Array, ' num_trees']
1461
-
1462
-
1463
- class PreLkV(Module):
1464
- """
1465
- Non-sequential terms of the likelihood ratio for each tree.
1466
-
1467
- These terms can be computed in parallel across trees.
1468
-
1469
- Parameters
1470
- ----------
1471
- sigma2_left
1472
- The noise variance in the left child of the leaves grown or pruned by
1473
- the moves.
1474
- sigma2_right
1475
- The noise variance in the right child of the leaves grown or pruned by
1476
- the moves.
1477
- sigma2_total
1478
- The noise variance in the total of the leaves grown or pruned by the
1479
- moves.
1480
- sqrt_term
1481
- The **logarithm** of the square root term of the likelihood ratio.
1482
- """
1483
-
1484
- sigma2_left: Float32[Array, ' num_trees']
1485
- sigma2_right: Float32[Array, ' num_trees']
1486
- sigma2_total: Float32[Array, ' num_trees']
1487
- sqrt_term: Float32[Array, ' num_trees']
1488
-
1489
-
1490
- class PreLk(Module):
1491
- """
1492
- Non-sequential terms of the likelihood ratio shared by all trees.
1493
-
1494
- Parameters
1495
- ----------
1496
- exp_factor
1497
- The factor to multiply the likelihood ratio by, shared by all trees.
1498
- """
1499
-
1500
- exp_factor: Float32[Array, '']
1501
-
1502
-
1503
- class PreLf(Module):
1504
- """
1505
- Pre-computed terms used to sample leaves from their posterior.
1506
-
1507
- These terms can be computed in parallel across trees.
1508
-
1509
- Parameters
1510
- ----------
1511
- mean_factor
1512
- The factor to be multiplied by the sum of the scaled residuals to
1513
- obtain the posterior mean.
1514
- centered_leaves
1515
- The mean-zero normal values to be added to the posterior mean to
1516
- obtain the posterior leaf samples.
1517
- """
1518
-
1519
- mean_factor: Float32[Array, 'num_trees 2**d']
1520
- centered_leaves: Float32[Array, 'num_trees 2**d']
1521
-
1522
-
1523
- class ParallelStageOut(Module):
1524
- """
1525
- The output of `accept_moves_parallel_stage`.
1526
-
1527
- Parameters
1528
- ----------
1529
- bart
1530
- A partially updated BART mcmc state.
1531
- moves
1532
- The proposed moves, with `partial_ratio` set to `None` and
1533
- `log_trans_prior_ratio` set to its final value.
1534
- prec_trees
1535
- The likelihood precision scale in each potential or actual leaf node. If
1536
- there is no precision scale, this is the number of points in each leaf.
1537
- move_counts
1538
- The counts of the number of points in the the nodes modified by the
1539
- moves. If `bart.min_points_per_leaf` is not set and
1540
- `bart.prec_scale` is set, they are not computed.
1541
- move_precs
1542
- The likelihood precision scale in each node modified by the moves. If
1543
- `bart.prec_scale` is not set, this is set to `move_counts`.
1544
- prelkv
1545
- prelk
1546
- prelf
1547
- Objects with pre-computed terms of the likelihood ratios and leaf
1548
- samples.
1549
- """
1550
-
1551
- bart: State
1552
- moves: Moves
1553
- prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
1554
- move_precs: Precs | Counts
1555
- prelkv: PreLkV
1556
- prelk: PreLk
1557
- prelf: PreLf
1558
-
1559
-
1560
- def accept_moves_parallel_stage(
1561
- key: Key[Array, ''], bart: State, moves: Moves
1562
- ) -> ParallelStageOut:
1563
- """
1564
- Pre-compute quantities used to accept moves, in parallel across trees.
1565
-
1566
- Parameters
1567
- ----------
1568
- key : jax.dtypes.prng_key array
1569
- A jax random key.
1570
- bart : dict
1571
- A BART mcmc state.
1572
- moves : dict
1573
- The proposed moves, see `propose_moves`.
1574
-
1575
- Returns
1576
- -------
1577
- An object with all that could be done in parallel.
1578
- """
1579
- # where the move is grow, modify the state like the move was accepted
1580
- bart = replace(
1581
- bart,
1582
- forest=replace(
1583
- bart.forest,
1584
- var_tree=moves.var_tree,
1585
- leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
1586
- leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
1587
- ),
1588
- )
1589
-
1590
- # count number of datapoints per leaf
1591
- if (
1592
- bart.forest.min_points_per_decision_node is not None
1593
- or bart.forest.min_points_per_leaf is not None
1594
- or bart.prec_scale is None
1595
- ):
1596
- count_trees, move_counts = compute_count_trees(
1597
- bart.forest.leaf_indices, moves, bart.forest.count_batch_size
1598
- )
1599
-
1600
- # mark which leaves & potential leaves have enough points to be grown
1601
- if bart.forest.min_points_per_decision_node is not None:
1602
- count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
1603
- moves = replace(
1604
- moves,
1605
- affluence_tree=moves.affluence_tree
1606
- & (count_half_trees >= bart.forest.min_points_per_decision_node),
1607
- )
1608
-
1609
- # copy updated affluence_tree to state
1610
- bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)
1611
-
1612
- # veto grove move if new leaves don't have enough datapoints
1613
- if bart.forest.min_points_per_leaf is not None:
1614
- moves = replace(
1615
- moves,
1616
- allowed=moves.allowed
1617
- & (move_counts.left >= bart.forest.min_points_per_leaf)
1618
- & (move_counts.right >= bart.forest.min_points_per_leaf),
1619
- )
1620
-
1621
- # count number of datapoints per leaf, weighted by error precision scale
1622
- if bart.prec_scale is None:
1623
- prec_trees = count_trees
1624
- move_precs = move_counts
1625
- else:
1626
- prec_trees, move_precs = compute_prec_trees(
1627
- bart.prec_scale,
1628
- bart.forest.leaf_indices,
1629
- moves,
1630
- bart.forest.count_batch_size,
1631
- )
1632
- assert move_precs is not None
1633
-
1634
- # compute some missing information about moves
1635
- moves = complete_ratio(moves, bart.forest.p_nonterminal)
1636
- save_ratios = bart.forest.log_likelihood is not None
1637
- bart = replace(
1638
- bart,
1639
- forest=replace(
1640
- bart.forest,
1641
- grow_prop_count=jnp.sum(moves.grow),
1642
- prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
1643
- log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
1644
- ),
1645
- )
1646
-
1647
- # pre-compute some likelihood ratio & posterior terms
1648
- assert bart.sigma2 is not None # `step` shall temporarily set it to 1
1649
- prelkv, prelk = precompute_likelihood_terms(
1650
- bart.sigma2, bart.forest.sigma_mu2, move_precs
1651
- )
1652
- prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)
1653
-
1654
- return ParallelStageOut(
1655
- bart=bart,
1656
- moves=moves,
1657
- prec_trees=prec_trees,
1658
- move_precs=move_precs,
1659
- prelkv=prelkv,
1660
- prelk=prelk,
1661
- prelf=prelf,
1662
- )
1663
-
1664
-
1665
- @partial(vmap_nodoc, in_axes=(0, 0, None))
1666
- def apply_grow_to_indices(
1667
- moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
1668
- ) -> UInt[Array, 'num_trees n']:
1669
- """
1670
- Update the leaf indices to apply a grow move.
1671
-
1672
- Parameters
1673
- ----------
1674
- moves
1675
- The proposed moves, see `propose_moves`.
1676
- leaf_indices
1677
- The index of the leaf each datapoint falls into.
1678
- X
1679
- The predictors matrix.
1680
-
1681
- Returns
1682
- -------
1683
- The updated leaf indices.
1684
- """
1685
- left_child = moves.node.astype(leaf_indices.dtype) << 1
1686
- go_right = X[moves.grow_var, :] >= moves.grow_split
1687
- tree_size = jnp.array(2 * moves.var_tree.size)
1688
- node_to_update = jnp.where(moves.grow, moves.node, tree_size)
1689
- return jnp.where(
1690
- leaf_indices == node_to_update, left_child + go_right, leaf_indices
1691
- )
1692
-
1693
-
1694
- def compute_count_trees(
1695
- leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
1696
- ) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
1697
- """
1698
- Count the number of datapoints in each leaf.
1699
-
1700
- Parameters
1701
- ----------
1702
- leaf_indices
1703
- The index of the leaf each datapoint falls into, with the deeper version
1704
- of the tree (post-GROW, pre-PRUNE).
1705
- moves
1706
- The proposed moves, see `propose_moves`.
1707
- batch_size
1708
- The data batch size to use for the summation.
1709
-
1710
- Returns
1711
- -------
1712
- count_trees : Int32[Array, 'num_trees 2**d']
1713
- The number of points in each potential or actual leaf node.
1714
- counts : Counts
1715
- The counts of the number of points in the leaves grown or pruned by the
1716
- moves.
1717
- """
1718
- num_trees, tree_size = moves.var_tree.shape
1719
- tree_size *= 2
1720
- tree_indices = jnp.arange(num_trees)
1721
-
1722
- count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
1723
-
1724
- # count datapoints in nodes modified by move
1725
- left = count_trees[tree_indices, moves.left]
1726
- right = count_trees[tree_indices, moves.right]
1727
- counts = Counts(left=left, right=right, total=left + right)
1728
-
1729
- # write count into non-leaf node
1730
- count_trees = count_trees.at[tree_indices, moves.node].set(counts.total)
1731
-
1732
- return count_trees, counts
1733
-
1734
-
1735
- def count_datapoints_per_leaf(
1736
- leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
1737
- ) -> Int32[Array, 'num_trees 2**(d-1)']:
1738
- """
1739
- Count the number of datapoints in each leaf.
1740
-
1741
- Parameters
1742
- ----------
1743
- leaf_indices
1744
- The index of the leaf each datapoint falls into.
1745
- tree_size
1746
- The size of the leaf tree array (2 ** d).
1747
- batch_size
1748
- The data batch size to use for the summation.
1749
-
1750
- Returns
1751
- -------
1752
- The number of points in each leaf node.
1753
- """
1754
- if batch_size is None:
1755
- return _count_scan(leaf_indices, tree_size)
1756
- else:
1757
- return _count_vec(leaf_indices, tree_size, batch_size)
1758
-
1759
-
1760
- def _count_scan(
1761
- leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
1762
- ) -> Int32[Array, 'num_trees {tree_size}']:
1763
- def loop(_, leaf_indices):
1764
- return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
1765
-
1766
- _, count_trees = lax.scan(loop, None, leaf_indices)
1767
- return count_trees
1768
-
1769
-
1770
- def _aggregate_scatter(
1771
- values: Shaped[Array, '*'],
1772
- indices: Integer[Array, '*'],
1773
- size: int,
1774
- dtype: jnp.dtype,
1775
- ) -> Shaped[Array, ' {size}']:
1776
- return jnp.zeros(size, dtype).at[indices].add(values)
1777
-
1778
-
1779
- def _count_vec(
1780
- leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
1781
- ) -> Int32[Array, 'num_trees 2**(d-1)']:
1782
- return _aggregate_batched_alltrees(
1783
- 1, leaf_indices, tree_size, jnp.uint32, batch_size
1784
- )
1785
- # uint16 is super-slow on gpu, don't use it even if n < 2^16
1786
-
1787
-
1788
- def _aggregate_batched_alltrees(
1789
- values: Shaped[Array, '*'],
1790
- indices: UInt[Array, 'num_trees n'],
1791
- size: int,
1792
- dtype: jnp.dtype,
1793
- batch_size: int,
1794
- ) -> Shaped[Array, 'num_trees {size}']:
1795
- num_trees, n = indices.shape
1796
- tree_indices = jnp.arange(num_trees)
1797
- nbatches = n // batch_size + bool(n % batch_size)
1798
- batch_indices = jnp.arange(n) % nbatches
1799
- return (
1800
- jnp.zeros((num_trees, size, nbatches), dtype)
1801
- .at[tree_indices[:, None], indices, batch_indices]
1802
- .add(values)
1803
- .sum(axis=2)
1804
- )
1805
-
1806
-
1807
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    batch_size: int | None,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper
        version of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    num_trees, half_size = moves.var_tree.shape
    tree_size = 2 * half_size  # the leaf tree is twice the size of var_tree
    tree_indices = jnp.arange(num_trees)

    prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)

    # gather the precision of the two children touched by each move
    left_prec = prec_trees[tree_indices, moves.left]
    right_prec = prec_trees[tree_indices, moves.right]
    precs = Precs(left=left_prec, right=right_prec, total=left_prec + right_prec)

    # store the combined precision into the parent (non-leaf) node
    prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total)

    return prec_trees, precs
def prec_per_leaf(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation; `None` selects the
        sequential per-tree scan.

    Returns
    -------
    The likelihood precision scale in each leaf node.
    """
    if batch_size is None:
        return _prec_scan(prec_scale, leaf_indices, tree_size)
    return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
def _prec_scan(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """Per-leaf precision sums, computed one tree at a time with `lax.scan`."""

    def one_tree(carry, indices_one_tree):
        summed = _aggregate_scatter(
            prec_scale, indices_one_tree, tree_size, jnp.float32
        )
        return carry, summed

    _, prec_trees = lax.scan(one_tree, None, leaf_indices)
    return prec_trees
def _prec_vec(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """Per-leaf precision sums for all trees at once, with batched scatter-adds."""
    return _aggregate_batched_alltrees(
        prec_scale, leaf_indices, tree_size, dtype=jnp.float32, batch_size=batch_size
    )
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Complete non-likelihood MH ratio calculation.

    This function adds the probability of choosing a prune move over the grow
    move in the inverse transition, and the a priori probability that the
    children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    num_trees, _ = moves.affluence_tree.shape
    tree_indices = jnp.arange(num_trees)

    def growable_at(nodes):
        # out-of-bounds node indices read back as False (not growable)
        return moves.affluence_tree.at[tree_indices, nodes].get(
            mode='fill', fill_value=False
        )

    left_growable = growable_at(moves.left)
    right_growable = growable_at(moves.right)

    # probability of proposing a prune in the reverse transition
    grow_again_allowed = (
        (moves.num_growable >= 2) | left_growable | right_growable
    )
    p_prune_if_grow = jnp.where(grow_again_allowed, 0.5, 1)
    p_prune_if_prune = jnp.where(moves.num_growable, 0.5, 1)
    p_prune = jnp.where(moves.grow, p_prune_if_grow, p_prune_if_prune)

    # prior probability that both children are terminal
    pt_children = (1 - p_nonterminal[moves.left] * left_growable) * (
        1 - p_nonterminal[moves.right] * right_growable
    )

    return replace(
        moves,
        log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune),
        partial_ratio=None,
    )
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    The value of the leaf to grow is copied to what would be its children if
    the grow move was accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_value = leaf_trees[moves.node]
    # an out-of-bounds index (== leaf_trees.size) makes the write a no-op,
    # so nothing changes for trees whose move is not a grow
    left_target = jnp.where(moves.grow, moves.left, leaf_trees.size)
    right_target = jnp.where(moves.grow, moves.right, leaf_trees.size)
    return (
        leaf_trees.at[left_target].set(parent_value).at[right_target].set(parent_value)
    )
def precompute_likelihood_terms(
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Parameters
    ----------
    sigma2
        The error variance, or the global error variance factor if
        `prec_scale` is set.
    sigma_mu2
        The prior variance of each leaf.
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under keys 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Dictionary with pre-computed terms of the likelihood ratio, one per
        tree.
    prelk : PreLk
        Dictionary with pre-computed terms of the likelihood ratio, shared by
        all trees.
    """
    # marginal variances of the residual sums in each node touched by a move
    var_left = sigma2 + move_precs.left * sigma_mu2
    var_right = sigma2 + move_precs.right * sigma_mu2
    var_total = sigma2 + move_precs.total * sigma_mu2
    prelkv = PreLkV(
        sigma2_left=var_left,
        sigma2_right=var_right,
        sigma2_total=var_total,
        sqrt_term=jnp.log(sigma2 * var_total / (var_left * var_right)) / 2,
    )
    prelk = PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
    return prelkv, prelk
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    sigma2
        The error variance, or the global error variance factor if
        `prec_scale` is set.
    sigma_mu2
        The prior variance of each leaf.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    # posterior variance: 1 / (likelihood precision + prior precision)
    lk_prec = prec_trees / sigma2
    prior_prec = lax.reciprocal(sigma_mu2)
    var_post = lax.reciprocal(lk_prec + prior_prec)
    noise = random.normal(key, prec_trees.shape, sigma2.dtype)
    return PreLf(
        # the posterior mean is resid_tree * mean_factor:
        # | mean = mean_lk * prec_lk * var_post
        # | resid_tree = mean_lk * prec_tree -->
        # | --> mean_lk = resid_tree / prec_tree (kind of)
        # | mean_factor =
        # |     = mean / resid_tree =
        # |     = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
        # |     = 1 / prec_tree * prec_tree / sigma2 * var_post =
        # |     = var_post / sigma2
        mean_factor=var_post / sigma2,
        centered_leaves=noise * jnp.sqrt(var_post),
    )
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """

    def loop(resid, pt):
        # the scan carry is the residuals vector: each tree's accept/reject
        # step reads the residuals left by the previous tree and updates them
        resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            resid,
            SeqStageInAllTrees(
                pso.bart.X,
                pso.bart.forest.resid_batch_size,
                pso.bart.prec_scale,
                pso.bart.forest.log_likelihood is not None,
                pso.prelk,
            ),
            pt,
        )
        return resid, (leaf_tree, acc, to_prune, lkratio)

    # per-tree inputs, scanned over their leading (num_trees) axis
    pts = SeqStageInPerTree(
        pso.bart.forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        pso.bart.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )
    resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(loop, pso.bart.resid, pts)

    bart = replace(
        pso.bart,
        resid=resid,
        forest=replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio),
    )
    moves = replace(pso.moves, acc=acc, to_prune=to_prune)

    return bart, moves
class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    save_ratios
        Whether to save the acceptance ratios.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    X: UInt[Array, 'p n']
    # NOTE(review): static=True marks plain Python configuration values,
    # presumably excluded from pytree flattening — confirm with `Module`
    resid_batch_size: int | None = field(static=True)
    prec_scale: Float32[Array, ' n'] | None
    save_ratios: bool = field(static=True)
    prelk: PreLk
class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    Instances are built with a leading num_trees axis on every field and
    scanned over by `lax.scan` in `accept_moves_sequential_stage`.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    leaf_tree: Float32[Array, ' 2**d']
    prec_tree: Float32[Array, ' 2**d']
    move: Moves
    move_precs: Precs | Counts
    leaf_indices: UInt[Array, ' n']
    prelkv: PreLkV
    prelf: PreLf
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
) -> tuple[
    Float32[Array, ' n'],
    Float32[Array, ' 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale
    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )

    # subtract starting tree from function
    # (resid excludes this tree only after this correction is added)
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move
    resid_left = resid_tree[pt.move.left]
    resid_right = resid_tree[pt.move.right]
    resid_total = resid_left + resid_right
    assert pt.move.node.dtype == jnp.int32
    resid_tree = resid_tree.at[pt.move.node].set(resid_total)

    # compute acceptance ratio
    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    # the precomputed ratio is oriented for a grow move; a prune move uses
    # the reciprocal, i.e. the negated log-ratio
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf
    # prune needed iff (grow rejected) or (prune accepted): acc XOR grow
    to_prune = acc ^ pt.move.grow
    # writes at index leaf_tree.size are out of bounds and hence dropped
    leaf_tree = (
        leaf_tree.at[jnp.where(to_prune, pt.move.left, leaf_tree.size)]
        .set(leaf_tree[pt.move.node])
        .at[jnp.where(to_prune, pt.move.right, leaf_tree.size)]
        .set(leaf_tree[pt.move.node])
    )

    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio
def sum_resid(
    scaled_resid: Float32[Array, ' n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, ' {tree_size}']:
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism.
    """
    if batch_size is None:
        return _aggregate_scatter(scaled_resid, leaf_indices, tree_size, jnp.float32)
    return _aggregate_batched_onetree(
        scaled_resid, leaf_indices, tree_size, jnp.float32, batch_size
    )
- def _aggregate_batched_onetree(
2309
- values: Shaped[Array, '*'],
2310
- indices: Integer[Array, '*'],
2311
- size: int,
2312
- dtype: jnp.dtype,
2313
- batch_size: int,
2314
- ) -> Float32[Array, ' {size}']:
2315
- (n,) = indices.shape
2316
- nbatches = n // batch_size + bool(n % batch_size)
2317
- batch_indices = jnp.arange(n) % nbatches
2318
- return (
2319
- jnp.zeros((size, nbatches), dtype)
2320
- .at[indices, batch_indices]
2321
- .add(values)
2322
- .sum(axis=1)
2323
- )
2324
-
2325
-
2326
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Compute the likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The likelihood ratio P(data | new tree) / P(data | old tree).
    """
    left_term = left_resid * left_resid / prelkv.sigma2_left
    right_term = right_resid * right_resid / prelkv.sigma2_right
    total_term = total_resid * total_resid / prelkv.sigma2_total
    exp_term = prelk.exp_factor * (left_term + right_term - total_term)
    return prelkv.sqrt_term + exp_term
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    forest = bart.forest
    updated_forest = replace(
        forest,
        grow_acc_count=jnp.sum(moves.acc & moves.grow),
        prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
        leaf_indices=apply_moves_to_leaf_indices(forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(forest.split_tree, moves),
    )
    return replace(bart, forest=updated_forest)
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # clearing the lowest bit maps both siblings onto the left child index
    sibling_mask = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_pruned_pair = (leaf_indices & sibling_mask) == moves.left
    parent = moves.node.astype(leaf_indices.dtype)
    return jnp.where(in_pruned_pair & moves.to_prune, parent, leaf_indices)
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # writes at index split_tree.size are out of bounds and hence dropped
    grow_target = jnp.where(moves.grow, moves.node, split_tree.size)
    prune_target = jnp.where(moves.to_prune, moves.node, split_tree.size)
    return (
        split_tree.at[grow_target]
        .set(moves.grow_split.astype(split_tree.dtype))
        .at[prune_target]
        .set(0)
    )
def step_sigma(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the error variance (factor).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `sigma2`.
    """
    resid = bart.resid
    if bart.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * bart.prec_scale
    norm2 = resid @ scaled_resid

    # conjugate inverse-gamma update of the posterior parameters
    alpha = bart.sigma2_alpha + resid.size / 2
    beta = bart.sigma2_beta + norm2 / 2

    # sample sigma2 ~ InvGamma(alpha, beta) as beta / Gamma(alpha)
    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    gamma_draw = random.gamma(key, alpha)
    return replace(bart, sigma2=beta / gamma_draw)
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    assert bart.y.dtype == bool
    # recover the forest prediction (plus offset) from z and the residuals
    fitted = bart.z - bart.resid
    # resample the residual from a one-sided truncated normal; the side is
    # selected by the observed binary outcome
    resid = truncated_normal_onesided(key, (), ~bart.y, -fitted)
    return replace(bart, z=fitted + resid, resid=resid)
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.
    """
    assert bart.forest.theta is not None

    # count how many times each predictor is used in the forest
    p = bart.forest.max_split.size
    varcount = grove.var_histogram(p, bart.forest.var_tree, bart.forest.split_tree)

    # Dirichlet posterior, sampled in log space via loggamma draws
    posterior_alpha = bart.forest.theta / p + varcount
    new_log_s = random.loggamma(key, posterior_alpha)

    return replace(bart, forest=replace(bart.forest, log_s=new_log_s))
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    assert bart.forest.log_s is not None
    assert bart.forest.rho is not None
    assert bart.forest.a is not None
    assert bart.forest.b is not None

    # grid of bin midpoints in (0, 1) for lambda = theta / (theta + rho)
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # work with normalized s
    log_s = bart.forest.log_s - logsumexp(bart.forest.log_s)

    # evaluate the posterior over the grid and sample a grid point
    logp, theta_grid = _log_p_lamda(
        lamda_grid, log_s, bart.forest.rho, bart.forest.a, bart.forest.b
    )
    choice = random.categorical(key, logp)

    return replace(bart, forest=replace(bart.forest, theta=theta_grid[choice]))
- def _log_p_lamda(
2576
- lamda: Float32[Array, ' num_grid'],
2577
- log_s: Float32[Array, ' p'],
2578
- rho: Float32[Array, ''],
2579
- a: Float32[Array, ''],
2580
- b: Float32[Array, ''],
2581
- ) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
2582
- # in the following I use lamda[::-1] == 1 - lamda
2583
- theta = rho * lamda / lamda[::-1]
2584
- p = log_s.size
2585
- return (
2586
- (a - 1) * jnp.log1p(-lamda[::-1]) # log(lambda)
2587
- + (b - 1) * jnp.log1p(-lamda) # log(1 - lambda)
2588
- + gammaln(theta)
2589
- - p * gammaln(theta / p)
2590
- + theta / p * jnp.sum(log_s)
2591
- ), theta
2592
-
2593
-
2594
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    This invokes `step_s`, and then `step_theta` only if the parameters of
    the theta prior are defined.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.

    """
    keys = split(key)
    bart = step_s(keys.pop(), bart)
    # rho set means the Beta prior on theta/(theta+rho) is configured,
    # so theta is a sampled parameter rather than fixed
    if bart.forest.rho is not None:
        bart = step_theta(keys.pop(), bart)
    return bart