bartz 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py CHANGED
@@ -28,26 +28,33 @@ Functions that implement the BART posterior MCMC initialization and update step.
28
28
  Functions that do MCMC steps operate by taking as input a bart state, and
29
29
  outputting a new state. The inputs are not modified.
30
30
 
31
- The main entry points are:
31
+ The entry points are:
32
32
 
33
33
  - `State`: The dataclass that represents a BART MCMC state.
34
34
  - `init`: Creates an initial `State` from data and configurations.
35
35
  - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
36
+ - `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`.
36
37
  """
37
38
 
38
39
  import math
39
40
  from dataclasses import replace
40
41
  from functools import cache, partial
41
- from typing import Any
42
+ from typing import Any, Literal
42
43
 
43
44
  import jax
44
- from equinox import Module, field
45
+ from equinox import Module, field, tree_at
45
46
  from jax import lax, random
46
47
  from jax import numpy as jnp
48
+ from jax.scipy.special import gammaln, logsumexp
47
49
  from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt
48
50
 
49
- from . import grove
50
- from .jaxext import minimal_unsigned_dtype, split, vmap_nodoc
51
+ from bartz import grove
52
+ from bartz.jaxext import (
53
+ minimal_unsigned_dtype,
54
+ split,
55
+ truncated_normal_onesided,
56
+ vmap_nodoc,
57
+ )
51
58
 
52
59
 
53
60
  class Forest(Module):
@@ -56,24 +63,32 @@ class Forest(Module):
56
63
 
57
64
  Parameters
58
65
  ----------
59
- leaf_trees
66
+ leaf_tree
60
67
  The leaf values.
61
- var_trees
68
+ var_tree
62
69
  The decision axes.
63
- split_trees
70
+ split_tree
64
71
  The decision boundaries.
72
+ affluence_tree
73
+ Marks leaves that can be grown.
74
+ max_split
75
+ The maximum split index for each predictor.
76
+ blocked_vars
77
+ Indices of variables that are not used. This shall include at least
78
+ the `i` such that ``max_split[i] == 0``, otherwise behavior is
79
+ undefined.
65
80
  p_nonterminal
66
- The probability of a nonterminal node at each depth, padded with a
67
- zero.
81
+ The prior probability of each node being nonterminal, conditional on
82
+ its ancestors. Includes the nodes at maximum depth which should be set
83
+ to 0.
68
84
  p_propose_grow
69
85
  The unnormalized probability of picking a leaf for a grow proposal.
70
86
  leaf_indices
71
87
  The index of the leaf each datapoint falls into, for each tree.
88
+ min_points_per_decision_node
89
+ The minimum number of data points in a decision node.
72
90
  min_points_per_leaf
73
91
  The minimum number of data points in a leaf node.
74
- affluence_trees
75
- Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
76
- datapoints. If `min_points_per_leaf` is not specified, this is None.
77
92
  resid_batch_size
78
93
  count_batch_size
79
94
  The data batch sizes for computing the sufficient statistics. If `None`,
@@ -91,25 +106,45 @@ class Forest(Module):
91
106
  The number of grow/prune moves accepted during one full MCMC cycle.
92
107
  sigma_mu2
93
108
  The prior variance of a leaf, conditional on the tree structure.
94
- """
95
-
96
- leaf_trees: Float32[Array, 'num_trees 2**d']
97
- var_trees: UInt[Array, 'num_trees 2**(d-1)']
98
- split_trees: UInt[Array, 'num_trees 2**(d-1)']
99
- p_nonterminal: Float32[Array, 'd']
100
- p_propose_grow: Float32[Array, '2**(d-1)']
109
+ log_s
110
+ The logarithm of the prior probability for choosing a variable to split
111
+ along in a decision rule, conditional on the ancestors. Not normalized.
112
+ If `None`, use a uniform distribution.
113
+ theta
114
+ The concentration parameter for the Dirichlet prior on the variable
115
+ distribution `s`. Required only to update `s`.
116
+ a
117
+ b
118
+ rho
119
+ Parameters of the prior on `theta`. Required only to sample `theta`.
120
+ See `step_theta`.
121
+ """
122
+
123
+ leaf_tree: Float32[Array, 'num_trees 2**d']
124
+ var_tree: UInt[Array, 'num_trees 2**(d-1)']
125
+ split_tree: UInt[Array, 'num_trees 2**(d-1)']
126
+ affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
127
+ max_split: UInt[Array, ' p']
128
+ blocked_vars: UInt[Array, ' k'] | None
129
+ p_nonterminal: Float32[Array, ' 2**d']
130
+ p_propose_grow: Float32[Array, ' 2**(d-1)']
101
131
  leaf_indices: UInt[Array, 'num_trees n']
132
+ min_points_per_decision_node: Int32[Array, ''] | None
102
133
  min_points_per_leaf: Int32[Array, ''] | None
103
- affluence_trees: Bool[Array, 'num_trees 2**(d-1)'] | None
104
134
  resid_batch_size: int | None = field(static=True)
105
135
  count_batch_size: int | None = field(static=True)
106
- log_trans_prior: Float32[Array, 'num_trees'] | None
107
- log_likelihood: Float32[Array, 'num_trees'] | None
136
+ log_trans_prior: Float32[Array, ' num_trees'] | None
137
+ log_likelihood: Float32[Array, ' num_trees'] | None
108
138
  grow_prop_count: Int32[Array, '']
109
139
  prune_prop_count: Int32[Array, '']
110
140
  grow_acc_count: Int32[Array, '']
111
141
  prune_acc_count: Int32[Array, '']
112
142
  sigma_mu2: Float32[Array, '']
143
+ log_s: Float32[Array, ' p'] | None
144
+ theta: Float32[Array, ''] | None
145
+ a: Float32[Array, ''] | None
146
+ b: Float32[Array, ''] | None
147
+ rho: Float32[Array, ''] | None
113
148
 
114
149
 
115
150
  class State(Module):
@@ -120,8 +155,6 @@ class State(Module):
120
155
  ----------
121
156
  X
122
157
  The predictors.
123
- max_split
124
- The maximum split index for each predictor.
125
158
  y
126
159
  The response. If the data type is `bool`, the model is binary regression.
127
160
  resid
@@ -145,13 +178,12 @@ class State(Module):
145
178
  """
146
179
 
147
180
  X: UInt[Array, 'p n']
148
- max_split: UInt[Array, 'p']
149
- y: Float32[Array, 'n'] | Bool[Array, 'n']
150
- z: None | Float32[Array, 'n']
181
+ y: Float32[Array, ' n'] | Bool[Array, ' n']
182
+ z: None | Float32[Array, ' n']
151
183
  offset: Float32[Array, '']
152
- resid: Float32[Array, 'n']
184
+ resid: Float32[Array, ' n']
153
185
  sigma2: Float32[Array, ''] | None
154
- prec_scale: Float32[Array, 'n'] | None
186
+ prec_scale: Float32[Array, ' n'] | None
155
187
  sigma2_alpha: Float32[Array, ''] | None
156
188
  sigma2_beta: Float32[Array, ''] | None
157
189
  forest: Forest
@@ -160,19 +192,26 @@ class State(Module):
160
192
  def init(
161
193
  *,
162
194
  X: UInt[Any, 'p n'],
163
- y: Float32[Any, 'n'] | Bool[Any, 'n'],
195
+ y: Float32[Any, ' n'] | Bool[Any, ' n'],
164
196
  offset: float | Float32[Any, ''] = 0.0,
165
- max_split: UInt[Any, 'p'],
197
+ max_split: UInt[Any, ' p'],
166
198
  num_trees: int,
167
- p_nonterminal: Float32[Any, 'd-1'],
199
+ p_nonterminal: Float32[Any, ' d-1'],
168
200
  sigma_mu2: float | Float32[Any, ''],
169
201
  sigma2_alpha: float | Float32[Any, ''] | None = None,
170
202
  sigma2_beta: float | Float32[Any, ''] | None = None,
171
- error_scale: Float32[Any, 'n'] | None = None,
172
- min_points_per_leaf: int | None = None,
173
- resid_batch_size: int | None | str = 'auto',
174
- count_batch_size: int | None | str = 'auto',
203
+ error_scale: Float32[Any, ' n'] | None = None,
204
+ min_points_per_decision_node: int | Integer[Any, ''] | None = None,
205
+ resid_batch_size: int | None | Literal['auto'] = 'auto',
206
+ count_batch_size: int | None | Literal['auto'] = 'auto',
175
207
  save_ratios: bool = False,
208
+ filter_splitless_vars: bool = True,
209
+ min_points_per_leaf: int | Integer[Any, ''] | None = None,
210
+ log_s: Float32[Any, ' p'] | None = None,
211
+ theta: float | Float32[Any, ''] | None = None,
212
+ a: float | Float32[Any, ''] | None = None,
213
+ b: float | Float32[Any, ''] | None = None,
214
+ rho: float | Float32[Any, ''] | None = None,
176
215
  ) -> State:
177
216
  """
178
217
  Make a BART posterior sampling MCMC initial state.
@@ -206,8 +245,9 @@ def init(
206
245
  the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
207
246
  Not supported for binary regression. If not specified, defaults to 1 for
208
247
  all points, but potentially skipping calculations.
209
- min_points_per_leaf
210
- The minimum number of data points in a leaf node. 0 if not specified.
248
+ min_points_per_decision_node
249
+ The minimum number of data points in a decision node. 0 if not
250
+ specified.
211
251
  resid_batch_size
212
252
  count_batch_size
213
253
  The batch sizes, along datapoints, for summing the residuals and
@@ -216,6 +256,33 @@ def init(
216
256
  device.
217
257
  save_ratios
218
258
  Whether to save the Metropolis-Hastings ratios.
259
+ filter_splitless_vars
260
+ Whether to check `max_split` for variables without available cutpoints.
261
+ If any are found, they are put into a list of variables to exclude from
262
+ the MCMC. If `False`, no check is performed, but the results may be
263
+ wrong if any variable is blocked. The function is jax-traceable only
264
+ if this is set to `False`.
265
+ min_points_per_leaf
266
+ The minimum number of datapoints in a leaf node. 0 if not specified.
267
+ Unlike `min_points_per_decision_node`, this constraint is not taken into
268
+ account in the Metropolis-Hastings ratio because it would be expensive
269
+ to compute. Grow moves that would violate this constraint are vetoed.
270
+ This parameter is independent of `min_points_per_decision_node` and
271
+ there is no check that they are coherent. It makes sense to set
272
+ ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
273
+ log_s
274
+ The logarithm of the prior probability for choosing a variable to split
275
+ along in a decision rule, conditional on the ancestors. Not normalized.
276
+ If not specified, use a uniform distribution. If not specified and
277
+ `theta` or `rho`, `a`, `b` are, it's initialized automatically.
278
+ theta
279
+ The concentration parameter for the Dirichlet prior on `s`. Required
280
+ only to update `log_s`. If not specified, and `rho`, `a`, `b` are
281
+ specified, it's initialized automatically.
282
+ a
283
+ b
284
+ rho
285
+ Parameters of the prior on `theta`. Required only to sample `theta`.
219
286
 
220
287
  Returns
221
288
  -------
@@ -225,6 +292,13 @@ def init(
225
292
  ------
226
293
  ValueError
227
294
  If `y` is boolean and arguments unused in binary regression are set.
295
+
296
+ Notes
297
+ -----
298
+ In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
299
+ of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
300
+ child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
301
+ integers in the range ``[0, 1, ..., max_split[i]]``.
228
302
  """
229
303
  p_nonterminal = jnp.asarray(p_nonterminal)
230
304
  p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
@@ -244,22 +318,37 @@ def init(
244
318
  is_binary = y.dtype == bool
245
319
  if is_binary:
246
320
  if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
247
- raise ValueError(
321
+ msg = (
248
322
  'error_scale, sigma2_alpha, and sigma2_beta must be set '
249
323
  ' to `None` for binary regression.'
250
324
  )
325
+ raise ValueError(msg)
251
326
  sigma2 = None
252
327
  else:
253
328
  sigma2_alpha = jnp.asarray(sigma2_alpha)
254
329
  sigma2_beta = jnp.asarray(sigma2_beta)
255
330
  sigma2 = sigma2_beta / sigma2_alpha
256
- # sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
257
- # TODO: I don't like this isfinite check, these functions should be
258
- # low-level and just do the thing. Why was it here?
259
331
 
260
- bart = State(
332
+ max_split = jnp.asarray(max_split)
333
+
334
+ if filter_splitless_vars:
335
+ (blocked_vars,) = jnp.nonzero(max_split == 0)
336
+ blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
337
+ # see `fully_used_variables` for the type cast
338
+ else:
339
+ blocked_vars = None
340
+
341
+ # check and initialize sparsity parameters
342
+ if not _all_none_or_not_none(rho, a, b):
343
+ msg = 'rho, a, b are not either all `None` or all set'
344
+ raise ValueError(msg)
345
+ if theta is None and rho is not None:
346
+ theta = rho
347
+ if log_s is None and theta is not None:
348
+ log_s = jnp.zeros(max_split.size)
349
+
350
+ return State(
261
351
  X=jnp.asarray(X),
262
- max_split=jnp.asarray(max_split),
263
352
  y=y,
264
353
  z=jnp.full(y.shape, offset) if is_binary else None,
265
354
  offset=offset,
@@ -271,41 +360,54 @@ def init(
271
360
  sigma2_alpha=sigma2_alpha,
272
361
  sigma2_beta=sigma2_beta,
273
362
  forest=Forest(
274
- leaf_trees=make_forest(max_depth, jnp.float32),
275
- var_trees=make_forest(
276
- max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)
363
+ leaf_tree=make_forest(max_depth, jnp.float32),
364
+ var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
365
+ split_tree=make_forest(max_depth - 1, max_split.dtype),
366
+ affluence_tree=(
367
+ make_forest(max_depth - 1, bool)
368
+ .at[:, 1]
369
+ .set(
370
+ True
371
+ if min_points_per_decision_node is None
372
+ else y.size >= min_points_per_decision_node
373
+ )
277
374
  ),
278
- split_trees=make_forest(max_depth - 1, max_split.dtype),
375
+ blocked_vars=blocked_vars,
376
+ max_split=max_split,
279
377
  grow_prop_count=jnp.zeros((), int),
280
378
  grow_acc_count=jnp.zeros((), int),
281
379
  prune_prop_count=jnp.zeros((), int),
282
380
  prune_acc_count=jnp.zeros((), int),
283
- p_nonterminal=p_nonterminal,
381
+ p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
284
382
  p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
285
383
  leaf_indices=jnp.ones(
286
384
  (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
287
385
  ),
288
- min_points_per_leaf=(
289
- None
290
- if min_points_per_leaf is None
291
- else jnp.asarray(min_points_per_leaf)
292
- ),
293
- affluence_trees=(
294
- None
295
- if min_points_per_leaf is None
296
- else make_forest(max_depth - 1, bool)
297
- .at[:, 1]
298
- .set(y.size >= 2 * min_points_per_leaf)
299
- ),
386
+ min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
387
+ min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
300
388
  resid_batch_size=resid_batch_size,
301
389
  count_batch_size=count_batch_size,
302
- log_trans_prior=jnp.full(num_trees, jnp.nan) if save_ratios else None,
303
- log_likelihood=jnp.full(num_trees, jnp.nan) if save_ratios else None,
390
+ log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
391
+ log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
304
392
  sigma_mu2=jnp.asarray(sigma_mu2),
393
+ log_s=_asarray_or_none(log_s),
394
+ theta=_asarray_or_none(theta),
395
+ rho=_asarray_or_none(rho),
396
+ a=_asarray_or_none(a),
397
+ b=_asarray_or_none(b),
305
398
  ),
306
399
  )
307
400
 
308
- return bart
401
+
402
+ def _all_none_or_not_none(*args):
403
+ is_none = [x is None for x in args]
404
+ return all(is_none) or not any(is_none)
405
+
406
+
407
+ def _asarray_or_none(x):
408
+ if x is None:
409
+ return None
410
+ return jnp.asarray(x)
309
411
 
310
412
 
311
413
  def _choose_suffstat_batch_size(
@@ -319,16 +421,17 @@ def _choose_suffstat_batch_size(
319
421
  device = jax.devices()[0]
320
422
  platform = device.platform
321
423
  if platform not in ('cpu', 'gpu'):
322
- raise KeyError(f'Unknown platform: {platform}')
424
+ msg = f'Unknown platform: {platform}'
425
+ raise KeyError(msg)
323
426
  return platform
324
427
 
325
428
  if resid_batch_size == 'auto':
326
429
  platform = get_platform()
327
430
  n = max(1, y.size)
328
431
  if platform == 'cpu':
329
- resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
432
+ resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6
330
433
  elif platform == 'gpu':
331
- resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
434
+ resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3
332
435
  resid_batch_size = max(1, resid_batch_size)
333
436
 
334
437
  if count_batch_size == 'auto':
@@ -337,11 +440,11 @@ def _choose_suffstat_batch_size(
337
440
  count_batch_size = None
338
441
  elif platform == 'gpu':
339
442
  n = max(1, y.size)
340
- count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
443
+ count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2
341
444
  # /4 is good on V100, /2 on L4/T4, still haven't tried A100
342
445
  max_memory = 2**29
343
446
  itemsize = 4
344
- min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
447
+ min_batch_size = math.ceil(forest_size * itemsize * n / max_memory)
345
448
  count_batch_size = max(count_batch_size, min_batch_size)
346
449
  count_batch_size = max(1, count_batch_size)
347
450
 
@@ -397,7 +500,7 @@ def step_trees(key: Key[Array, ''], bart: State) -> State:
397
500
  This function zeroes the proposal counters.
398
501
  """
399
502
  keys = split(key)
400
- moves = propose_moves(keys.pop(), bart.forest, bart.max_split)
503
+ moves = propose_moves(keys.pop(), bart.forest)
401
504
  return accept_moves_and_sample_leaves(keys.pop(), bart, moves)
402
505
 
403
506
 
@@ -408,9 +511,11 @@ class Moves(Module):
408
511
  Parameters
409
512
  ----------
410
513
  allowed
411
- Whether the move is possible in the first place. There are additional
412
- constraints that could forbid it, but they are computed at acceptance
413
- time.
514
+ Whether there is a possible move. If `False`, the other values may not
515
+ make sense. The only case in which a move is marked as allowed but is
516
+ then vetoed is if it does not satisfy `min_points_per_leaf`, which for
517
+ efficiency is implemented post-hoc without changing the rest of the
518
+ MCMC logic.
414
519
  grow
415
520
  Whether the move is a grow move or a prune move.
416
521
  num_growable
@@ -421,20 +526,27 @@ class Moves(Module):
421
526
  right
422
527
  The indices of the children of 'node'.
423
528
  partial_ratio
424
- A factor of the Metropolis-Hastings ratio of the move. It lacks
425
- the likelihood ratio and the probability of proposing the prune
426
- move. If the move is PRUNE, the ratio is inverted. `None` once
529
+ A factor of the Metropolis-Hastings ratio of the move. It lacks the
530
+ likelihood ratio, the probability of proposing the prune move, and the
531
+ probability that the children of the modified node are terminal. If the
532
+ move is PRUNE, the ratio is inverted. `None` once
427
533
  `log_trans_prior_ratio` has been computed.
428
534
  log_trans_prior_ratio
429
535
  The logarithm of the product of the transition and prior terms of the
430
536
  Metropolis-Hastings ratio for the acceptance of the proposed move.
431
- `None` if not yet computed.
537
+ `None` if not yet computed. If PRUNE, the log-ratio is negated.
432
538
  grow_var
433
539
  The decision axes of the new rules.
434
540
  grow_split
435
541
  The decision boundaries of the new rules.
436
- var_trees
542
+ var_tree
437
543
  The updated decision axes of the trees, valid whatever move.
544
+ affluence_tree
545
+ A partially updated `affluence_tree`, marking non-leaf nodes that would
546
+ become leaves if the move was accepted. Initially (out of
547
+ `propose_moves`) this mark takes into account whether there would be
548
+ available decision rules to grow the leaf; whether there are enough
549
+ datapoints in the node is marked in `accept_moves_parallel_stage`.
438
550
  logu
439
551
  The logarithm of a uniform (0, 1] random variable to be used to
440
552
  accept the move. It's in (-oo, 0].
@@ -446,25 +558,24 @@ class Moves(Module):
446
558
  computed.
447
559
  """
448
560
 
449
- allowed: Bool[Array, 'num_trees']
450
- grow: Bool[Array, 'num_trees']
451
- num_growable: UInt[Array, 'num_trees']
452
- node: UInt[Array, 'num_trees']
453
- left: UInt[Array, 'num_trees']
454
- right: UInt[Array, 'num_trees']
455
- partial_ratio: Float32[Array, 'num_trees'] | None
456
- log_trans_prior_ratio: None | Float32[Array, 'num_trees']
457
- grow_var: UInt[Array, 'num_trees']
458
- grow_split: UInt[Array, 'num_trees']
459
- var_trees: UInt[Array, 'num_trees 2**(d-1)']
460
- logu: Float32[Array, 'num_trees']
461
- acc: None | Bool[Array, 'num_trees']
462
- to_prune: None | Bool[Array, 'num_trees']
561
+ allowed: Bool[Array, ' num_trees']
562
+ grow: Bool[Array, ' num_trees']
563
+ num_growable: UInt[Array, ' num_trees']
564
+ node: UInt[Array, ' num_trees']
565
+ left: UInt[Array, ' num_trees']
566
+ right: UInt[Array, ' num_trees']
567
+ partial_ratio: Float32[Array, ' num_trees'] | None
568
+ log_trans_prior_ratio: None | Float32[Array, ' num_trees']
569
+ grow_var: UInt[Array, ' num_trees']
570
+ grow_split: UInt[Array, ' num_trees']
571
+ var_tree: UInt[Array, 'num_trees 2**(d-1)']
572
+ affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
573
+ logu: Float32[Array, ' num_trees']
574
+ acc: None | Bool[Array, ' num_trees']
575
+ to_prune: None | Bool[Array, ' num_trees']
463
576
 
464
577
 
465
- def propose_moves(
466
- key: Key[Array, ''], forest: Forest, max_split: UInt[Array, 'p']
467
- ) -> Moves:
578
+ def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
468
579
  """
469
580
  Propose moves for all the trees.
470
581
 
@@ -478,39 +589,40 @@ def propose_moves(
478
589
  A jax random key.
479
590
  forest
480
591
  The `forest` field of a BART MCMC state.
481
- max_split
482
- The maximum split index for each variable, found in `State`.
483
592
 
484
593
  Returns
485
594
  -------
486
595
  The proposed move for each tree.
487
596
  """
488
- num_trees, _ = forest.leaf_trees.shape
597
+ num_trees, _ = forest.leaf_tree.shape
489
598
  keys = split(key, 1 + 2 * num_trees)
490
599
 
491
600
  # compute moves
492
601
  grow_moves = propose_grow_moves(
493
602
  keys.pop(num_trees),
494
- forest.var_trees,
495
- forest.split_trees,
496
- forest.affluence_trees,
497
- max_split,
603
+ forest.var_tree,
604
+ forest.split_tree,
605
+ forest.affluence_tree,
606
+ forest.max_split,
607
+ forest.blocked_vars,
498
608
  forest.p_nonterminal,
499
609
  forest.p_propose_grow,
610
+ forest.log_s,
500
611
  )
501
612
  prune_moves = propose_prune_moves(
502
613
  keys.pop(num_trees),
503
- forest.split_trees,
504
- forest.affluence_trees,
614
+ forest.split_tree,
615
+ grow_moves.affluence_tree,
505
616
  forest.p_nonterminal,
506
617
  forest.p_propose_grow,
507
618
  )
508
619
 
509
- u, logu = random.uniform(keys.pop(), (2, num_trees), jnp.float32)
620
+ u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))
510
621
 
511
622
  # choose between grow or prune
512
- grow_allowed = grow_moves.num_growable.astype(bool)
513
- p_grow = jnp.where(grow_allowed & prune_moves.allowed, 0.5, grow_allowed)
623
+ p_grow = jnp.where(
624
+ grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed
625
+ )
514
626
  grow = u < p_grow # use < instead of <= because u is in [0, 1)
515
627
 
516
628
  # compute children indices
@@ -519,7 +631,7 @@ def propose_moves(
519
631
  right = left + 1
520
632
 
521
633
  return Moves(
522
- allowed=grow | prune_moves.allowed,
634
+ allowed=grow_moves.allowed | prune_moves.allowed,
523
635
  grow=grow,
524
636
  num_growable=grow_moves.num_growable,
525
637
  node=node,
@@ -531,8 +643,11 @@ def propose_moves(
531
643
  log_trans_prior_ratio=None, # will be set in complete_ratio
532
644
  grow_var=grow_moves.var,
533
645
  grow_split=grow_moves.split,
534
- var_trees=grow_moves.var_tree,
535
- logu=jnp.log1p(-logu),
646
+ # var_tree does not need to be updated if prune
647
+ var_tree=grow_moves.var_tree,
648
+ # affluence_tree is updated for both moves unconditionally, prune last
649
+ affluence_tree=prune_moves.affluence_tree,
650
+ logu=jnp.log1p(-exp1mlogu),
536
651
  acc=None, # will be set in accept_moves_sequential_stage
537
652
  to_prune=None, # will be set in accept_moves_sequential_stage
538
653
  )
@@ -544,8 +659,10 @@ class GrowMoves(Module):
544
659
 
545
660
  Parameters
546
661
  ----------
662
+ allowed
663
+ Whether the move is allowed for proposal.
547
664
  num_growable
548
- The number of growable leaves.
665
+ The number of leaves that can be proposed for grow.
549
666
  node
550
667
  The index of the leaf to grow. ``2 ** d`` if there are no growable
551
668
  leaves.
@@ -558,25 +675,32 @@ class GrowMoves(Module):
558
675
  move.
559
676
  var_tree
560
677
  The updated decision axes of the tree.
678
+ affluence_tree
679
+ A partially updated `affluence_tree` that marks each new leaf that
680
+ would be produced as `True` if it would have available decision rules.
561
681
  """
562
682
 
563
- num_growable: UInt[Array, 'num_trees']
564
- node: UInt[Array, 'num_trees']
565
- var: UInt[Array, 'num_trees']
566
- split: UInt[Array, 'num_trees']
567
- partial_ratio: Float32[Array, 'num_trees']
683
+ allowed: Bool[Array, ' num_trees']
684
+ num_growable: UInt[Array, ' num_trees']
685
+ node: UInt[Array, ' num_trees']
686
+ var: UInt[Array, ' num_trees']
687
+ split: UInt[Array, ' num_trees']
688
+ partial_ratio: Float32[Array, ' num_trees']
568
689
  var_tree: UInt[Array, 'num_trees 2**(d-1)']
690
+ affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
569
691
 
570
692
 
571
- @partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
693
+ @partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
572
694
  def propose_grow_moves(
573
- key: Key[Array, ''],
574
- var_tree: UInt[Array, '2**(d-1)'],
575
- split_tree: UInt[Array, '2**(d-1)'],
576
- affluence_tree: Bool[Array, '2**(d-1)'] | None,
577
- max_split: UInt[Array, 'p'],
578
- p_nonterminal: Float32[Array, 'd'],
579
- p_propose_grow: Float32[Array, '2**(d-1)'],
695
+ key: Key[Array, ' num_trees'],
696
+ var_tree: UInt[Array, 'num_trees 2**(d-1)'],
697
+ split_tree: UInt[Array, 'num_trees 2**(d-1)'],
698
+ affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
699
+ max_split: UInt[Array, ' p'],
700
+ blocked_vars: Int32[Array, ' k'] | None,
701
+ p_nonterminal: Float32[Array, ' 2**d'],
702
+ p_propose_grow: Float32[Array, ' 2**(d-1)'],
703
+ log_s: Float32[Array, ' p'] | None,
580
704
  ) -> GrowMoves:
581
705
  """
582
706
  Propose a GROW move for each tree.
@@ -593,13 +717,19 @@ def propose_grow_moves(
593
717
  split_tree
594
718
  The splitting points of the tree.
595
719
  affluence_tree
596
- Whether a leaf has enough points to be grown.
720
+ Whether each leaf has enough points to be grown.
597
721
  max_split
598
722
  The maximum split index for each variable.
723
+ blocked_vars
724
+ The indices of the variables that have no available cutpoints.
599
725
  p_nonterminal
600
- The probability of a nonterminal node at each depth.
726
+ The a priori probability of a node to be nonterminal conditional on the
727
+ ancestors, including at the maximum depth where it should be zero.
601
728
  p_propose_grow
602
729
  The unnormalized probability of choosing a leaf to grow.
730
+ log_s
731
+ Unnormalized log-probability used to choose a variable to split on
732
+ amongst the available ones.
603
733
 
604
734
  Returns
605
735
  -------
@@ -607,16 +737,10 @@ def propose_grow_moves(
607
737
 
608
738
  Notes
609
739
  -----
610
- The move is not proposed if a leaf is already at maximum depth, or if a leaf
611
- has less than twice the requested minimum number of datapoints per leaf.
612
- This is marked by returning `num_growable` set to 0.
613
-
614
- The move is also not be possible if the ancestors of a leaf have
615
- exhausted the possible decision rules that lead to a non-empty selection.
616
- This is marked by returning `var` set to `p` and `split` set to 0. But this
617
- does not block the move from counting as "proposed", even though it is
618
- predictably going to be rejected. This simplifies the MCMC and should not
619
- reduce efficiency if not in unrealistic corner cases.
740
+ The move is not proposed if each leaf is already at maximum depth, or has
741
+ fewer datapoints than the requested threshold `min_points_per_decision_node`,
742
+ or it does not have any available decision rules given its ancestors. This
743
+ is marked by setting `allowed` to `False` and `num_growable` to 0.
620
744
  """
621
745
  keys = split(key, 3)
622
746
 
@@ -624,36 +748,45 @@ def propose_grow_moves(
624
748
  keys.pop(), split_tree, affluence_tree, p_propose_grow
625
749
  )
626
750
 
627
- var = choose_variable(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
628
- var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
751
+ # sample a decision rule
752
+ var, num_available_var = choose_variable(
753
+ keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
754
+ )
755
+ split_idx, l, r = choose_split(
756
+ keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
757
+ )
629
758
 
630
- split_idx = choose_split(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
759
+ # determine if the new leaves would have available decision rules; if the
760
+ # move is blocked, these values may not make sense
761
+ left_growable = right_growable = num_available_var > 1
762
+ left_growable |= l < split_idx
763
+ right_growable |= split_idx + 1 < r
764
+ left = leaf_to_grow << 1
765
+ right = left + 1
766
+ affluence_tree = affluence_tree.at[left].set(left_growable)
767
+ affluence_tree = affluence_tree.at[right].set(right_growable)
631
768
 
632
769
  ratio = compute_partial_ratio(
633
770
  prob_choose, num_prunable, p_nonterminal, leaf_to_grow
634
771
  )
635
772
 
636
773
  return GrowMoves(
774
+ allowed=num_growable > 0,
637
775
  num_growable=num_growable,
638
776
  node=leaf_to_grow,
639
777
  var=var,
640
778
  split=split_idx,
641
779
  partial_ratio=ratio,
642
- var_tree=var_tree,
780
+ var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
781
+ affluence_tree=affluence_tree,
643
782
  )
644
783
 
645
- # TODO it is not clear to me how var=p and split=0 when the move is not
646
- # possible lead to corrent behavior downstream. Like, the move is proposed,
647
- # but then it's a noop? And since it's a noop, it makes no difference if
648
- # it's "accepted" or "rejected", it's like it's always rejected, so who
649
- # cares if the likelihood ratio or a lot of other numbers are wrong? Uhm.
650
-
651
784
 
652
785
  def choose_leaf(
653
786
  key: Key[Array, ''],
654
- split_tree: UInt[Array, '2**(d-1)'],
655
- affluence_tree: Bool[Array, '2**(d-1)'] | None,
656
- p_propose_grow: Float32[Array, '2**(d-1)'],
787
+ split_tree: UInt[Array, ' 2**(d-1)'],
788
+ affluence_tree: Bool[Array, ' 2**(d-1)'],
789
+ p_propose_grow: Float32[Array, ' 2**(d-1)'],
657
790
  ) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
658
791
  """
659
792
  Choose a leaf node to grow in a tree.
@@ -672,16 +805,16 @@ def choose_leaf(
672
805
 
673
806
  Returns
674
807
  -------
675
- leaf_to_grow : int
808
+ leaf_to_grow : Int32[Array, '']
676
809
  The index of the leaf to grow. If ``num_growable == 0``, return
677
810
  ``2 ** d``.
678
- num_growable : int
811
+ num_growable : Int32[Array, '']
679
812
  The number of leaf nodes that can be grown, i.e., are nonterminal
680
- and have at least twice `min_points_per_leaf` if set.
681
- prob_choose : float
813
+ and have at least `min_points_per_decision_node` points.
814
+ prob_choose : Float32[Array, '']
682
815
  The (normalized) probability that this function had to choose that
683
816
  specific leaf, given the arguments.
684
- num_prunable : int
817
+ num_prunable : Int32[Array, '']
685
818
  The number of leaf parents that could be pruned, after converting the
686
819
  selected leaf to a non-terminal node.
687
820
  """
@@ -690,41 +823,43 @@ def choose_leaf(
690
823
  distr = jnp.where(is_growable, p_propose_grow, 0)
691
824
  leaf_to_grow, distr_norm = categorical(key, distr)
692
825
  leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
693
- prob_choose = distr[leaf_to_grow] / distr_norm
826
+ prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1)
694
827
  is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
695
828
  num_prunable = jnp.count_nonzero(is_parent)
696
829
  return leaf_to_grow, num_growable, prob_choose, num_prunable
697
830
 
698
831
 
699
832
  def growable_leaves(
700
- split_tree: UInt[Array, '2**(d-1)'],
701
- affluence_tree: Bool[Array, '2**(d-1)'] | None,
702
- ) -> Bool[Array, '2**(d-1)']:
833
+ split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
834
+ ) -> Bool[Array, ' 2**(d-1)']:
703
835
  """
704
836
  Return a mask indicating the leaf nodes that can be proposed for growth.
705
837
 
706
- The condition is that a leaf is not at the bottom level and has at least two
707
- times the number of minimum points per leaf.
838
+ The condition is that a leaf is not at the bottom level, has available
839
+ decision rules given its ancestors, and has at least
840
+ `min_points_per_decision_node` points.
708
841
 
709
842
  Parameters
710
843
  ----------
711
844
  split_tree
712
845
  The splitting points of the tree.
713
846
  affluence_tree
714
- Whether a leaf has enough points to be grown.
847
+ Marks leaves that can be grown.
715
848
 
716
849
  Returns
717
850
  -------
718
851
  The mask indicating the leaf nodes that can be proposed to grow.
852
+
853
+ Notes
854
+ -----
855
+ This function needs `split_tree` and not just `affluence_tree` because
856
+ `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
719
857
  """
720
- is_growable = grove.is_actual_leaf(split_tree)
721
- if affluence_tree is not None:
722
- is_growable &= affluence_tree
723
- return is_growable
858
+ return grove.is_actual_leaf(split_tree) & affluence_tree
724
859
 
725
860
 
726
861
  def categorical(
727
- key: Key[Array, ''], distr: Float32[Array, 'n']
862
+ key: Key[Array, ''], distr: Float32[Array, ' n']
728
863
  ) -> tuple[Int32[Array, ''], Float32[Array, '']]:
729
864
  """
730
865
  Return a random integer from an arbitrary distribution.
@@ -743,6 +878,11 @@ def categorical(
743
878
  return ``n``.
744
879
  norm : Float32[Array, '']
745
880
  The sum of `distr`.
881
+
882
+ Notes
883
+ -----
884
+ This function uses a cumsum instead of the Gumbel trick, so it's ok only
885
+ for small ranges with probabilities well greater than 0.
746
886
  """
747
887
  ecdf = jnp.cumsum(distr)
748
888
  u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
@@ -751,11 +891,13 @@ def categorical(
751
891
 
752
892
  def choose_variable(
753
893
  key: Key[Array, ''],
754
- var_tree: UInt[Array, '2**(d-1)'],
755
- split_tree: UInt[Array, '2**(d-1)'],
756
- max_split: UInt[Array, 'p'],
894
+ var_tree: UInt[Array, ' 2**(d-1)'],
895
+ split_tree: UInt[Array, ' 2**(d-1)'],
896
+ max_split: UInt[Array, ' p'],
757
897
  leaf_index: Int32[Array, ''],
758
- ) -> Int32[Array, '']:
898
+ blocked_vars: Int32[Array, ' k'] | None,
899
+ log_s: Float32[Array, ' p'] | None,
900
+ ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
759
901
  """
760
902
  Choose a variable to split on for a new non-terminal node.
761
903
 
@@ -771,28 +913,39 @@ def choose_variable(
771
913
  The maximum split index for each variable.
772
914
  leaf_index
773
915
  The index of the leaf to grow.
916
+ blocked_vars
917
+ The indices of the variables that have no available cutpoints. If
918
+ `None`, all variables are assumed unblocked.
919
+ log_s
920
+ The logarithm of the prior probability for choosing a variable. If
921
+ `None`, use a uniform distribution.
774
922
 
775
923
  Returns
776
924
  -------
777
- The index of the variable to split on.
778
-
779
- Notes
780
- -----
781
- The variable is chosen among the variables that have a non-empty range of
782
- allowed splits. If no variable has a non-empty range, return `p`.
925
+ var : Int32[Array, '']
926
+ The index of the variable to split on.
927
+ num_available_var : Int32[Array, '']
928
+ The number of variables with available decision rules `var` was chosen
929
+ from.
783
930
  """
784
931
  var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
785
- return randint_exclude(key, max_split.size, var_to_ignore)
932
+ if blocked_vars is not None:
933
+ var_to_ignore = jnp.concatenate([var_to_ignore, blocked_vars])
934
+
935
+ if log_s is None:
936
+ return randint_exclude(key, max_split.size, var_to_ignore)
937
+ else:
938
+ return categorical_exclude(key, log_s, var_to_ignore)
786
939
 
787
940
 
788
941
  def fully_used_variables(
789
- var_tree: UInt[Array, '2**(d-1)'],
790
- split_tree: UInt[Array, '2**(d-1)'],
791
- max_split: UInt[Array, 'p'],
942
+ var_tree: UInt[Array, ' 2**(d-1)'],
943
+ split_tree: UInt[Array, ' 2**(d-1)'],
944
+ max_split: UInt[Array, ' p'],
792
945
  leaf_index: Int32[Array, ''],
793
- ) -> UInt[Array, 'd-2']:
946
+ ) -> UInt[Array, ' d-2']:
794
947
  """
795
- Return a list of variables that have an empty split range at a given node.
948
+ Find variables in the ancestors of a node that have an empty split range.
796
949
 
797
950
  Parameters
798
951
  ----------
@@ -820,23 +973,25 @@ def fully_used_variables(
820
973
  l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
821
974
  num_split = r - l
822
975
  return jnp.where(num_split == 0, var_to_ignore, max_split.size)
976
+ # the type of var_to_ignore is already sufficient to hold max_split.size,
977
+ # see ancestor_variables()
823
978
 
824
979
 
825
980
  def ancestor_variables(
826
- var_tree: UInt[Array, '2**(d-1)'],
827
- max_split: UInt[Array, 'p'],
981
+ var_tree: UInt[Array, ' 2**(d-1)'],
982
+ max_split: UInt[Array, ' p'],
828
983
  node_index: Int32[Array, ''],
829
- ) -> UInt[Array, 'd-2']:
984
+ ) -> UInt[Array, ' d-2']:
830
985
  """
831
986
  Return the list of variables in the ancestors of a node.
832
987
 
833
988
  Parameters
834
989
  ----------
835
- var_tree : int array (2 ** (d - 1),)
990
+ var_tree
836
991
  The variable indices of the tree.
837
- max_split : int array (p,)
992
+ max_split
838
993
  The maximum split index for each variable. Used only to get `p`.
839
- node_index : int
994
+ node_index
840
995
  The index of the node, assumed to be valid for `var_tree`.
841
996
 
842
997
  Returns
@@ -866,9 +1021,9 @@ def ancestor_variables(
866
1021
 
867
1022
 
868
1023
  def split_range(
869
- var_tree: UInt[Array, '2**(d-1)'],
870
- split_tree: UInt[Array, '2**(d-1)'],
871
- max_split: UInt[Array, 'p'],
1024
+ var_tree: UInt[Array, ' 2**(d-1)'],
1025
+ split_tree: UInt[Array, ' 2**(d-1)'],
1026
+ max_split: UInt[Array, ' p'],
872
1027
  node_index: Int32[Array, ''],
873
1028
  ref_var: Int32[Array, ''],
874
1029
  ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
@@ -890,13 +1045,13 @@ def split_range(
890
1045
 
891
1046
  Returns
892
1047
  -------
893
- The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=0.
1048
+ The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
894
1049
  """
895
1050
  max_num_ancestors = grove.tree_depth(var_tree) - 1
896
1051
  initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
897
1052
  jnp.int32
898
1053
  )
899
- carry = 0, initial_r, node_index
1054
+ carry = jnp.int32(0), initial_r, node_index
900
1055
 
901
1056
  def loop(carry, _):
902
1057
  l, r, index = carry
@@ -913,8 +1068,8 @@ def split_range(
913
1068
 
914
1069
 
915
1070
  def randint_exclude(
916
- key: Key[Array, ''], sup: int, exclude: Integer[Array, 'n']
917
- ) -> Int32[Array, '']:
1071
+ key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
1072
+ ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
918
1073
  """
919
1074
  Return a random integer in a range, excluding some values.
920
1075
 
@@ -930,30 +1085,74 @@ def randint_exclude(
930
1085
 
931
1086
  Returns
932
1087
  -------
933
- A random integer `u` in the range ``[0, sup)`` such that ``u not in exclude``.
1088
+ u : Int32[Array, '']
1089
+ A random integer `u` in the range ``[0, sup)`` such that ``u not in
1090
+ exclude``.
1091
+ num_allowed : Int32[Array, '']
1092
+ The number of integers in the range that were not excluded.
934
1093
 
935
1094
  Notes
936
1095
  -----
937
1096
  If all values in the range are excluded, return `sup`.
938
1097
  """
939
- exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
940
- num_allowed = sup - jnp.count_nonzero(exclude < sup)
1098
+ exclude, num_allowed = _process_exclude(sup, exclude)
941
1099
  u = random.randint(key, (), 0, num_allowed)
942
1100
 
943
- def loop(u, i):
944
- return jnp.where(i <= u, u + 1, u), None
1101
+ def loop(u, i_excluded):
1102
+ return jnp.where(i_excluded <= u, u + 1, u), None
945
1103
 
946
1104
  u, _ = lax.scan(loop, u, exclude)
947
- return u
1105
+ return u, num_allowed
1106
+
1107
+
1108
+ def _process_exclude(sup, exclude):
1109
+ exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
1110
+ num_allowed = sup - jnp.count_nonzero(exclude < sup)
1111
+ return exclude, num_allowed
1112
+
1113
+
1114
+ def categorical_exclude(
1115
+ key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
1116
+ ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
1117
+ """
1118
+ Draw from a categorical distribution, excluding a set of values.
1119
+
1120
+ Parameters
1121
+ ----------
1122
+ key
1123
+ A jax random key.
1124
+ logits
1125
+ The unnormalized log-probabilities of each category.
1126
+ exclude
1127
+ The values to exclude from the range [0, k). Values greater than or
1128
+ equal to `logits.size` are ignored. Values can appear more than once.
1129
+
1130
+ Returns
1131
+ -------
1132
+ u : Int32[Array, '']
1133
+ A random integer in the range ``[0, k)`` such that ``u not in exclude``.
1134
+ num_allowed : Int32[Array, '']
1135
+ The number of integers in the range that were not excluded.
1136
+
1137
+ Notes
1138
+ -----
1139
+ If all values in the range are excluded, the result is unspecified.
1140
+ """
1141
+ exclude, num_allowed = _process_exclude(logits.size, exclude)
1142
+ kinda_neg_inf = jnp.finfo(logits.dtype).min
1143
+ logits = logits.at[exclude].set(kinda_neg_inf)
1144
+ u = random.categorical(key, logits)
1145
+ return u, num_allowed
948
1146
 
949
1147
 
950
1148
  def choose_split(
951
1149
  key: Key[Array, ''],
952
- var_tree: UInt[Array, '2**(d-1)'],
953
- split_tree: UInt[Array, '2**(d-1)'],
954
- max_split: UInt[Array, 'p'],
1150
+ var: Int32[Array, ''],
1151
+ var_tree: UInt[Array, ' 2**(d-1)'],
1152
+ split_tree: UInt[Array, ' 2**(d-1)'],
1153
+ max_split: UInt[Array, ' p'],
955
1154
  leaf_index: Int32[Array, ''],
956
- ) -> Int32[Array, '']:
1155
+ ) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
957
1156
  """
958
1157
  Choose a split point for a new non-terminal node.
959
1158
 
@@ -961,32 +1160,39 @@ def choose_split(
961
1160
  ----------
962
1161
  key
963
1162
  A jax random key.
1163
+ var
1164
+ The variable to split on.
964
1165
  var_tree
965
- The splitting axes of the tree.
1166
+ The splitting axes of the tree. Does not need to already contain `var`
1167
+ at `leaf_index`.
966
1168
  split_tree
967
1169
  The splitting points of the tree.
968
1170
  max_split
969
1171
  The maximum split index for each variable.
970
1172
  leaf_index
971
- The index of the leaf to grow. It is assumed that `var_tree` already
972
- contains the target variable at this index.
1173
+ The index of the leaf to grow.
973
1174
 
974
1175
  Returns
975
1176
  -------
976
- The cutpoint. If ``var_tree[leaf_index]`` is out of bounds, return 0.
1177
+ split : Int32[Array, '']
1178
+ The cutpoint.
1179
+ l : Int32[Array, '']
1180
+ r : Int32[Array, '']
1181
+ The integer range `split` was drawn from is [l, r).
1182
+
1183
+ Notes
1184
+ -----
1185
+ If `var` is out of bounds, or if the available split range on that variable
1186
+ is empty, return 0.
977
1187
  """
978
- var = var_tree[leaf_index]
979
1188
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
980
- return random.randint(key, (), l, r)
981
-
982
- # TODO what happens if leaf_index is out of bounds? And is the value used
983
- # in that case?
1189
+ return jnp.where(l < r, random.randint(key, (), l, r), 0), l, r
984
1190
 
985
1191
 
986
1192
  def compute_partial_ratio(
987
1193
  prob_choose: Float32[Array, ''],
988
1194
  num_prunable: Int32[Array, ''],
989
- p_nonterminal: Float32[Array, 'd'],
1195
+ p_nonterminal: Float32[Array, ' 2**d'],
990
1196
  leaf_to_grow: Int32[Array, ''],
991
1197
  ) -> Float32[Array, '']:
992
1198
  """
@@ -1001,7 +1207,8 @@ def compute_partial_ratio(
1001
1207
  The number of leaf parents that could be pruned, after converting the
1002
1208
  leaf to be grown to a non-terminal node.
1003
1209
  p_nonterminal
1004
- The probability of a nonterminal node at each depth.
1210
+ The a priori probability of each node being nonterminal conditional on
1211
+ its ancestors.
1005
1212
  leaf_to_grow
1006
1213
  The index of the leaf to grow.
1007
1214
 
@@ -1013,29 +1220,29 @@ def compute_partial_ratio(
1013
1220
  -----
1014
1221
  The transition ratio is P(new tree => old tree) / P(old tree => new tree).
1015
1222
  The "partial" transition ratio returned is missing the factor P(propose
1016
- prune) in the numerator. The prior ratio is P(new tree) / P(old tree).
1223
+ prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The
1224
+ "partial" prior ratio is missing the factor P(children are leaves).
1017
1225
  """
1018
1226
  # the two ratios also contain factors num_available_split *
1019
- # num_available_var, but they cancel out
1227
+ # num_available_var * s[var], but they cancel out
1020
1228
 
1021
- # p_prune can't be computed here because it needs the count trees, which are
1022
- # computed in the acceptance phase
1229
+ # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be
1230
+ # computed here because they need the count trees, which are computed in the
1231
+ # acceptance phase
1023
1232
 
1024
1233
  prune_allowed = leaf_to_grow != 1
1025
1234
  # prune allowed <---> the initial tree is not a root
1026
1235
  # leaf to grow is root --> the tree can only be a root
1027
1236
  # tree is a root --> the only leaf I can grow is root
1028
-
1029
1237
  p_grow = jnp.where(prune_allowed, 0.5, 1)
1030
-
1031
1238
  inv_trans_ratio = p_grow * prob_choose * num_prunable
1032
1239
 
1033
- depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
1034
- p_parent = p_nonterminal[depth]
1035
- cp_children = 1 - p_nonterminal[depth + 1]
1036
- tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
1240
+ # .at.get because if leaf_to_grow is out of bounds (move not allowed), this
1241
+ # would produce a 0 and then an inf when `complete_ratio` takes the log
1242
+ pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
1243
+ tree_ratio = pnt / (1 - pnt)
1037
1244
 
1038
- return tree_ratio / inv_trans_ratio
1245
+ return tree_ratio / jnp.where(inv_trans_ratio, inv_trans_ratio, 1)
1039
1246
 
1040
1247
 
1041
1248
  class PruneMoves(Module):
@@ -1049,24 +1256,26 @@ class PruneMoves(Module):
1049
1256
  node
1050
1257
  The index of the node to prune. ``2 ** d`` if no node can be pruned.
1051
1258
  partial_ratio
1052
- A factor of the Metropolis-Hastings ratio of the move. It lacks
1053
- the likelihood ratio and the probability of proposing the prune
1054
- move. This ratio is inverted, and is meant to be inverted back in
1259
+ A factor of the Metropolis-Hastings ratio of the move. It lacks the
1260
+ likelihood ratio, the probability of proposing the prune move, and the
1261
+ prior probability that the children of the node to prune are leaves.
1262
+ This ratio is inverted, and is meant to be inverted back in
1055
1263
  `accept_move_and_sample_leaves`.
1056
1264
  """
1057
1265
 
1058
- allowed: Bool[Array, 'num_trees']
1059
- node: UInt[Array, 'num_trees']
1060
- partial_ratio: Float32[Array, 'num_trees']
1266
+ allowed: Bool[Array, ' num_trees']
1267
+ node: UInt[Array, ' num_trees']
1268
+ partial_ratio: Float32[Array, ' num_trees']
1269
+ affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
1061
1270
 
1062
1271
 
1063
1272
  @partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
1064
1273
  def propose_prune_moves(
1065
1274
  key: Key[Array, ''],
1066
- split_tree: UInt[Array, '2**(d-1)'],
1067
- affluence_tree: Bool[Array, '2**(d-1)'] | None,
1068
- p_nonterminal: Float32[Array, 'd'],
1069
- p_propose_grow: Float32[Array, '2**(d-1)'],
1275
+ split_tree: UInt[Array, ' 2**(d-1)'],
1276
+ affluence_tree: Bool[Array, ' 2**(d-1)'],
1277
+ p_nonterminal: Float32[Array, ' 2**d'],
1278
+ p_propose_grow: Float32[Array, ' 2**(d-1)'],
1070
1279
  ) -> PruneMoves:
1071
1280
  """
1072
1281
  Tree structure prune move proposal of BART MCMC.
@@ -1078,9 +1287,10 @@ def propose_prune_moves(
1078
1287
  split_tree
1079
1288
  The splitting points of the tree.
1080
1289
  affluence_tree
1081
- Whether a leaf has enough points to be grown.
1290
+ Whether each leaf can be grown.
1082
1291
  p_nonterminal
1083
- The probability of a nonterminal node at each depth.
1292
+ The a priori probability of a node to be nonterminal conditional on
1293
+ the ancestors, including at the maximum depth where it should be zero.
1084
1294
  p_propose_grow
1085
1295
  The unnormalized probability of choosing a leaf to grow.
1086
1296
 
@@ -1088,28 +1298,33 @@ def propose_prune_moves(
1088
1298
  -------
1089
1299
  An object representing the proposed moves.
1090
1300
  """
1091
- node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
1301
+ node_to_prune, num_prunable, prob_choose, affluence_tree = choose_leaf_parent(
1092
1302
  key, split_tree, affluence_tree, p_propose_grow
1093
1303
  )
1094
- allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
1095
1304
 
1096
1305
  ratio = compute_partial_ratio(
1097
1306
  prob_choose, num_prunable, p_nonterminal, node_to_prune
1098
1307
  )
1099
1308
 
1100
1309
  return PruneMoves(
1101
- allowed=allowed,
1310
+ allowed=split_tree[1].astype(bool), # allowed iff the tree is not a root
1102
1311
  node=node_to_prune,
1103
1312
  partial_ratio=ratio,
1313
+ affluence_tree=affluence_tree,
1104
1314
  )
1105
1315
 
1106
1316
 
1107
1317
  def choose_leaf_parent(
1108
1318
  key: Key[Array, ''],
1109
- split_tree: UInt[Array, '2**(d-1)'],
1110
- affluence_tree: Bool[Array, '2**(d-1)'] | None,
1111
- p_propose_grow: Float32[Array, '2**(d-1)'],
1112
- ) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, '']]:
1319
+ split_tree: UInt[Array, ' 2**(d-1)'],
1320
+ affluence_tree: Bool[Array, ' 2**(d-1)'],
1321
+ p_propose_grow: Float32[Array, ' 2**(d-1)'],
1322
+ ) -> tuple[
1323
+ Int32[Array, ''],
1324
+ Int32[Array, ''],
1325
+ Float32[Array, ''],
1326
+ Bool[Array, 'num_trees 2**(d-1)'],
1327
+ ]:
1113
1328
  """
1114
1329
  Pick a non-terminal node with leaf children to prune in a tree.
1115
1330
 
@@ -1135,23 +1350,28 @@ def choose_leaf_parent(
1135
1350
  The (normalized) probability that `choose_leaf` would chose
1136
1351
  `node_to_prune` as leaf to grow, if passed the tree where
1137
1352
  `node_to_prune` had been pruned.
1353
+ affluence_tree : Bool[Array, 'num_trees 2**(d-1)']
1354
+ A partially updated `affluence_tree`, marking the node to prune as
1355
+ growable.
1138
1356
  """
1357
+ # sample a node to prune
1139
1358
  is_prunable = grove.is_leaves_parent(split_tree)
1140
1359
  num_prunable = jnp.count_nonzero(is_prunable)
1141
1360
  node_to_prune = randint_masked(key, is_prunable)
1142
1361
  node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
1143
1362
 
1363
+ # compute stuff for reverse move
1144
1364
  split_tree = split_tree.at[node_to_prune].set(0)
1145
- if affluence_tree is not None:
1146
- affluence_tree = affluence_tree.at[node_to_prune].set(True)
1365
+ affluence_tree = affluence_tree.at[node_to_prune].set(True)
1147
1366
  is_growable_leaf = growable_leaves(split_tree, affluence_tree)
1148
- prob_choose = p_propose_grow[node_to_prune]
1149
- prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
1367
+ distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
1368
+ prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
1369
+ prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)
1150
1370
 
1151
- return node_to_prune, num_prunable, prob_choose
1371
+ return node_to_prune, num_prunable, prob_choose, affluence_tree
1152
1372
 
1153
1373
 
1154
- def randint_masked(key: Key[Array, ''], mask: Bool[Array, 'n']) -> Int32[Array, '']:
1374
+ def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
1155
1375
  """
1156
1376
  Return a random integer in a range, including only some values.
1157
1377
 
@@ -1213,9 +1433,9 @@ class Counts(Module):
1213
1433
  Number of datapoints in the parent (``= left + right``).
1214
1434
  """
1215
1435
 
1216
- left: UInt[Array, 'num_trees']
1217
- right: UInt[Array, 'num_trees']
1218
- total: UInt[Array, 'num_trees']
1436
+ left: UInt[Array, ' num_trees']
1437
+ right: UInt[Array, ' num_trees']
1438
+ total: UInt[Array, ' num_trees']
1219
1439
 
1220
1440
 
1221
1441
  class Precs(Module):
@@ -1235,9 +1455,9 @@ class Precs(Module):
1235
1455
  Likelihood precision scale in the parent (``= left + right``).
1236
1456
  """
1237
1457
 
1238
- left: Float32[Array, 'num_trees']
1239
- right: Float32[Array, 'num_trees']
1240
- total: Float32[Array, 'num_trees']
1458
+ left: Float32[Array, ' num_trees']
1459
+ right: Float32[Array, ' num_trees']
1460
+ total: Float32[Array, ' num_trees']
1241
1461
 
1242
1462
 
1243
1463
  class PreLkV(Module):
@@ -1261,10 +1481,10 @@ class PreLkV(Module):
1261
1481
  The **logarithm** of the square root term of the likelihood ratio.
1262
1482
  """
1263
1483
 
1264
- sigma2_left: Float32[Array, 'num_trees']
1265
- sigma2_right: Float32[Array, 'num_trees']
1266
- sigma2_total: Float32[Array, 'num_trees']
1267
- sqrt_term: Float32[Array, 'num_trees']
1484
+ sigma2_left: Float32[Array, ' num_trees']
1485
+ sigma2_right: Float32[Array, ' num_trees']
1486
+ sigma2_total: Float32[Array, ' num_trees']
1487
+ sqrt_term: Float32[Array, ' num_trees']
1268
1488
 
1269
1489
 
1270
1490
  class PreLk(Module):
@@ -1331,7 +1551,6 @@ class ParallelStageOut(Module):
1331
1551
  bart: State
1332
1552
  moves: Moves
1333
1553
  prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
1334
- move_counts: Counts | None
1335
1554
  move_precs: Precs | Counts
1336
1555
  prelkv: PreLkV
1337
1556
  prelk: PreLk
@@ -1342,7 +1561,7 @@ def accept_moves_parallel_stage(
1342
1561
  key: Key[Array, ''], bart: State, moves: Moves
1343
1562
  ) -> ParallelStageOut:
1344
1563
  """
1345
- Pre-computes quantities used to accept moves, in parallel across trees.
1564
+ Pre-compute quantities used to accept moves, in parallel across trees.
1346
1565
 
1347
1566
  Parameters
1348
1567
  ----------
@@ -1362,33 +1581,41 @@ def accept_moves_parallel_stage(
1362
1581
  bart,
1363
1582
  forest=replace(
1364
1583
  bart.forest,
1365
- var_trees=moves.var_trees,
1584
+ var_tree=moves.var_tree,
1366
1585
  leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
1367
- leaf_trees=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_trees, moves),
1586
+ leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
1368
1587
  ),
1369
1588
  )
1370
1589
 
1371
1590
  # count number of datapoints per leaf
1372
- if bart.forest.min_points_per_leaf is not None or bart.prec_scale is None:
1591
+ if (
1592
+ bart.forest.min_points_per_decision_node is not None
1593
+ or bart.forest.min_points_per_leaf is not None
1594
+ or bart.prec_scale is None
1595
+ ):
1373
1596
  count_trees, move_counts = compute_count_trees(
1374
1597
  bart.forest.leaf_indices, moves, bart.forest.count_batch_size
1375
1598
  )
1376
- else:
1377
- # move_counts is passed later to a function, but then is unused under
1378
- # this condition
1379
- move_counts = None
1380
1599
 
1381
- # Check if some nodes can't surely be grown because they don't have enough
1382
- # datapoints. This check is not actually used now, it will be used at the
1383
- # beginning of the next step to propose moves.
1600
+ # mark which leaves & potential leaves have enough points to be grown
1601
+ if bart.forest.min_points_per_decision_node is not None:
1602
+ count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
1603
+ moves = replace(
1604
+ moves,
1605
+ affluence_tree=moves.affluence_tree
1606
+ & (count_half_trees >= bart.forest.min_points_per_decision_node),
1607
+ )
1608
+
1609
+ # copy updated affluence_tree to state
1610
+ bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)
1611
+
1612
+ # veto grow move if new leaves don't have enough datapoints
1384
1613
  if bart.forest.min_points_per_leaf is not None:
1385
- count_half_trees = count_trees[:, : bart.forest.var_trees.shape[1]]
1386
- bart = replace(
1387
- bart,
1388
- forest=replace(
1389
- bart.forest,
1390
- affluence_trees=count_half_trees >= 2 * bart.forest.min_points_per_leaf,
1391
- ),
1614
+ moves = replace(
1615
+ moves,
1616
+ allowed=moves.allowed
1617
+ & (move_counts.left >= bart.forest.min_points_per_leaf)
1618
+ & (move_counts.right >= bart.forest.min_points_per_leaf),
1392
1619
  )
1393
1620
 
1394
1621
  # count number of datapoints per leaf, weighted by error precision scale
@@ -1402,18 +1629,23 @@ def accept_moves_parallel_stage(
1402
1629
  moves,
1403
1630
  bart.forest.count_batch_size,
1404
1631
  )
1632
+ assert move_precs is not None
1405
1633
 
1406
1634
  # compute some missing information about moves
1407
- moves = complete_ratio(moves, move_counts, bart.forest.min_points_per_leaf)
1635
+ moves = complete_ratio(moves, bart.forest.p_nonterminal)
1636
+ save_ratios = bart.forest.log_likelihood is not None
1408
1637
  bart = replace(
1409
1638
  bart,
1410
1639
  forest=replace(
1411
1640
  bart.forest,
1412
1641
  grow_prop_count=jnp.sum(moves.grow),
1413
1642
  prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
1643
+ log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
1414
1644
  ),
1415
1645
  )
1416
1646
 
1647
+ # pre-compute some likelihood ratio & posterior terms
1648
+ assert bart.sigma2 is not None # `step` shall temporarily set it to 1
1417
1649
  prelkv, prelk = precompute_likelihood_terms(
1418
1650
  bart.sigma2, bart.forest.sigma_mu2, move_precs
1419
1651
  )
@@ -1423,7 +1655,6 @@ def accept_moves_parallel_stage(
1423
1655
  bart=bart,
1424
1656
  moves=moves,
1425
1657
  prec_trees=prec_trees,
1426
- move_counts=move_counts,
1427
1658
  move_precs=move_precs,
1428
1659
  prelkv=prelkv,
1429
1660
  prelk=prelk,
@@ -1453,12 +1684,10 @@ def apply_grow_to_indices(
1453
1684
  """
1454
1685
  left_child = moves.node.astype(leaf_indices.dtype) << 1
1455
1686
  go_right = X[moves.grow_var, :] >= moves.grow_split
1456
- tree_size = jnp.array(2 * moves.var_trees.size)
1687
+ tree_size = jnp.array(2 * moves.var_tree.size)
1457
1688
  node_to_update = jnp.where(moves.grow, moves.node, tree_size)
1458
1689
  return jnp.where(
1459
- leaf_indices == node_to_update,
1460
- left_child + go_right,
1461
- leaf_indices,
1690
+ leaf_indices == node_to_update, left_child + go_right, leaf_indices
1462
1691
  )
1463
1692
 
1464
1693
 
@@ -1486,7 +1715,7 @@ def compute_count_trees(
1486
1715
  The counts of the number of points in the leaves grown or pruned by the
1487
1716
  moves.
1488
1717
  """
1489
- num_trees, tree_size = moves.var_trees.shape
1718
+ num_trees, tree_size = moves.var_tree.shape
1490
1719
  tree_size *= 2
1491
1720
  tree_indices = jnp.arange(num_trees)
1492
1721
 
@@ -1543,7 +1772,7 @@ def _aggregate_scatter(
1543
1772
  indices: Integer[Array, '*'],
1544
1773
  size: int,
1545
1774
  dtype: jnp.dtype,
1546
- ) -> Shaped[Array, '{size}']:
1775
+ ) -> Shaped[Array, ' {size}']:
1547
1776
  return jnp.zeros(size, dtype).at[indices].add(values)
1548
1777
 
1549
1778
 
@@ -1576,7 +1805,7 @@ def _aggregate_batched_alltrees(
1576
1805
 
1577
1806
 
1578
1807
  def compute_prec_trees(
1579
- prec_scale: Float32[Array, 'n'],
1808
+ prec_scale: Float32[Array, ' n'],
1580
1809
  leaf_indices: UInt[Array, 'num_trees n'],
1581
1810
  moves: Moves,
1582
1811
  batch_size: int | None,
@@ -1603,7 +1832,7 @@ def compute_prec_trees(
1603
1832
  precs : Precs
1604
1833
  The likelihood precision scale in the nodes involved in the moves.
1605
1834
  """
1606
- num_trees, tree_size = moves.var_trees.shape
1835
+ num_trees, tree_size = moves.var_tree.shape
1607
1836
  tree_size *= 2
1608
1837
  tree_indices = jnp.arange(num_trees)
1609
1838
 
@@ -1621,7 +1850,7 @@ def compute_prec_trees(
1621
1850
 
1622
1851
 
1623
1852
  def prec_per_leaf(
1624
- prec_scale: Float32[Array, 'n'],
1853
+ prec_scale: Float32[Array, ' n'],
1625
1854
  leaf_indices: UInt[Array, 'num_trees n'],
1626
1855
  tree_size: int,
1627
1856
  batch_size: int | None,
@@ -1651,7 +1880,7 @@ def prec_per_leaf(
1651
1880
 
1652
1881
 
1653
1882
  def _prec_scan(
1654
- prec_scale: Float32[Array, 'n'],
1883
+ prec_scale: Float32[Array, ' n'],
1655
1884
  leaf_indices: UInt[Array, 'num_trees n'],
1656
1885
  tree_size: int,
1657
1886
  ) -> Float32[Array, 'num_trees {tree_size}']:
@@ -1665,7 +1894,7 @@ def _prec_scan(
1665
1894
 
1666
1895
 
1667
1896
  def _prec_vec(
1668
- prec_scale: Float32[Array, 'n'],
1897
+ prec_scale: Float32[Array, ' n'],
1669
1898
  leaf_indices: UInt[Array, 'num_trees n'],
1670
1899
  tree_size: int,
1671
1900
  batch_size: int,
@@ -1675,77 +1904,59 @@ def _prec_vec(
1675
1904
  )
1676
1905
 
1677
1906
 
1678
- def complete_ratio(
1679
- moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
1680
- ) -> Moves:
1907
+ def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
1681
1908
  """
1682
1909
  Complete non-likelihood MH ratio calculation.
1683
1910
 
1684
- This function adds the probability of choosing the prune move.
1911
+ This function adds the probability of choosing a prune move over the grow
1912
+ move in the inverse transition, and the a priori probability that the
1913
+ children nodes are leaves.
1685
1914
 
1686
1915
  Parameters
1687
1916
  ----------
1688
1917
  moves
1689
- The proposed moves, see `propose_moves`.
1690
- move_counts
1691
- The counts of the number of points in the the nodes modified by the
1692
- moves.
1693
- min_points_per_leaf
1694
- The minimum number of data points in a leaf node.
1918
+ The proposed moves. Must have already been updated to take into account
1919
+ the thresholds on the number of datapoints per node; this happens in
1920
+ `accept_moves_parallel_stage`.
1921
+ p_nonterminal
1922
+ The a priori probability of each node being nonterminal conditional on
1923
+ its ancestors, including at the maximum depth where it should be zero.
1695
1924
 
1696
1925
  Returns
1697
1926
  -------
1698
1927
  The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
1699
1928
  """
1700
- p_prune = compute_p_prune(moves, move_counts, min_points_per_leaf)
1701
- return replace(
1702
- moves,
1703
- log_trans_prior_ratio=jnp.log(moves.partial_ratio * p_prune),
1704
- partial_ratio=None,
1929
+ # can the leaves be grown?
1930
+ num_trees, _ = moves.affluence_tree.shape
1931
+ tree_indices = jnp.arange(num_trees)
1932
+ left_growable = moves.affluence_tree.at[tree_indices, moves.left].get(
1933
+ mode='fill', fill_value=False
1934
+ )
1935
+ right_growable = moves.affluence_tree.at[tree_indices, moves.right].get(
1936
+ mode='fill', fill_value=False
1705
1937
  )
1706
1938
 
1707
-
1708
- def compute_p_prune(
1709
- moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
1710
- ) -> Float32[Array, 'num_trees']:
1711
- """
1712
- Compute the probability of proposing a prune move for each tree.
1713
-
1714
- Parameters
1715
- ----------
1716
- moves
1717
- The proposed moves, see `propose_moves`.
1718
- move_counts
1719
- The number of datapoints in the proposed children of the leaf to grow.
1720
- Not used if `min_points_per_leaf` is `None`.
1721
- min_points_per_leaf
1722
- The minimum number of data points in a leaf node.
1723
-
1724
- Returns
1725
- -------
1726
- The probability of proposing a prune move.
1727
-
1728
- Notes
1729
- -----
1730
- This probability is computed for going from the state with the deeper tree
1731
- to the one with the shallower one. This means, if grow: after accepting the
1732
- grow move, if prune: right away.
1733
- """
1734
- # calculation in case the move is grow
1939
+ # p_prune if grow
1735
1940
  other_growable_leaves = moves.num_growable >= 2
1736
- new_leaves_growable = moves.node < moves.var_trees.shape[1] // 2
1737
- if min_points_per_leaf is not None:
1738
- assert move_counts is not None
1739
- any_above_threshold = move_counts.left >= 2 * min_points_per_leaf
1740
- any_above_threshold |= move_counts.right >= 2 * min_points_per_leaf
1741
- new_leaves_growable &= any_above_threshold
1742
- grow_again_allowed = other_growable_leaves | new_leaves_growable
1941
+ grow_again_allowed = other_growable_leaves | left_growable | right_growable
1743
1942
  grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
1744
1943
 
1745
- # calculation in case the move is prune
1944
+ # p_prune if prune
1746
1945
  prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)
1747
1946
 
1748
- return jnp.where(moves.grow, grow_p_prune, prune_p_prune)
1947
+ # select p_prune
1948
+ p_prune = jnp.where(moves.grow, grow_p_prune, prune_p_prune)
1949
+
1950
+ # prior probability of both children being terminal
1951
+ pt_left = 1 - p_nonterminal[moves.left] * left_growable
1952
+ pt_right = 1 - p_nonterminal[moves.right] * right_growable
1953
+ pt_children = pt_left * pt_right
1954
+
1955
+ return replace(
1956
+ moves,
1957
+ log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune),
1958
+ partial_ratio=None,
1959
+ )
1749
1960
 
1750
1961
 
1751
1962
  @vmap_nodoc
@@ -1815,9 +2026,7 @@ def precompute_likelihood_terms(
1815
2026
  sigma2_total=sigma2_total,
1816
2027
  sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
1817
2028
  )
1818
- return prelkv, PreLk(
1819
- exp_factor=sigma_mu2 / (2 * sigma2),
1820
- )
2029
+ return prelkv, PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
1821
2030
 
1822
2031
 
1823
2032
  def precompute_leaf_terms(
@@ -1851,14 +2060,14 @@ def precompute_leaf_terms(
1851
2060
  z = random.normal(key, prec_trees.shape, sigma2.dtype)
1852
2061
  return PreLf(
1853
2062
  mean_factor=var_post / sigma2,
1854
- # mean = mean_lk * prec_lk * var_post
1855
- # resid_tree = mean_lk * prec_tree -->
1856
- # --> mean_lk = resid_tree / prec_tree (kind of)
1857
- # mean_factor =
1858
- # = mean / resid_tree =
1859
- # = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
1860
- # = 1 / prec_tree * prec_tree / sigma2 * var_post =
1861
- # = var_post / sigma2
2063
+ # | mean = mean_lk * prec_lk * var_post
2064
+ # | resid_tree = mean_lk * prec_tree -->
2065
+ # | --> mean_lk = resid_tree / prec_tree (kind of)
2066
+ # | mean_factor =
2067
+ # | = mean / resid_tree =
2068
+ # | = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
2069
+ # | = 1 / prec_tree * prec_tree / sigma2 * var_post =
2070
+ # | = var_post / sigma2
1862
2071
  centered_leaves=z * jnp.sqrt(var_post),
1863
2072
  )
1864
2073
 
@@ -1884,42 +2093,34 @@ def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
1884
2093
  """
1885
2094
 
1886
2095
  def loop(resid, pt):
1887
- resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
2096
+ resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
1888
2097
  resid,
1889
2098
  SeqStageInAllTrees(
1890
2099
  pso.bart.X,
1891
2100
  pso.bart.forest.resid_batch_size,
1892
2101
  pso.bart.prec_scale,
1893
- pso.bart.forest.min_points_per_leaf,
1894
2102
  pso.bart.forest.log_likelihood is not None,
1895
2103
  pso.prelk,
1896
2104
  ),
1897
2105
  pt,
1898
2106
  )
1899
- return resid, (leaf_tree, acc, to_prune, ratios)
2107
+ return resid, (leaf_tree, acc, to_prune, lkratio)
1900
2108
 
1901
2109
  pts = SeqStageInPerTree(
1902
- pso.bart.forest.leaf_trees,
2110
+ pso.bart.forest.leaf_tree,
1903
2111
  pso.prec_trees,
1904
2112
  pso.moves,
1905
- pso.move_counts,
1906
2113
  pso.move_precs,
1907
2114
  pso.bart.forest.leaf_indices,
1908
2115
  pso.prelkv,
1909
2116
  pso.prelf,
1910
2117
  )
1911
- resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, pso.bart.resid, pts)
2118
+ resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(loop, pso.bart.resid, pts)
1912
2119
 
1913
- save_ratios = pso.bart.forest.log_likelihood is not None
1914
2120
  bart = replace(
1915
2121
  pso.bart,
1916
2122
  resid=resid,
1917
- forest=replace(
1918
- pso.bart.forest,
1919
- leaf_trees=leaf_trees,
1920
- log_likelihood=ratios['log_likelihood'] if save_ratios else None,
1921
- log_trans_prior=ratios['log_trans_prior'] if save_ratios else None,
1922
- ),
2123
+ forest=replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio),
1923
2124
  )
1924
2125
  moves = replace(pso.moves, acc=acc, to_prune=to_prune)
1925
2126
 
@@ -1928,7 +2129,7 @@ def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
1928
2129
 
1929
2130
  class SeqStageInAllTrees(Module):
1930
2131
  """
1931
- The inputs to `accept_move_and_sample_leaves` that are the same for all trees.
2132
+ The inputs to `accept_move_and_sample_leaves` that are shared by all trees.
1932
2133
 
1933
2134
  Parameters
1934
2135
  ----------
@@ -1939,8 +2140,6 @@ class SeqStageInAllTrees(Module):
1939
2140
  prec_scale
1940
2141
  The scale of the precision of the error on each datapoint. If None, it
1941
2142
  is assumed to be 1.
1942
- min_points_per_leaf
1943
- The minimum number of data points in a leaf node.
1944
2143
  save_ratios
1945
2144
  Whether to save the acceptance ratios.
1946
2145
  prelk
@@ -1949,10 +2148,9 @@ class SeqStageInAllTrees(Module):
1949
2148
  """
1950
2149
 
1951
2150
  X: UInt[Array, 'p n']
1952
- resid_batch_size: int | None
1953
- prec_scale: Float32[Array, 'n'] | None
1954
- min_points_per_leaf: Int32[Array, ''] | None
1955
- save_ratios: bool
2151
+ resid_batch_size: int | None = field(static=True)
2152
+ prec_scale: Float32[Array, ' n'] | None
2153
+ save_ratios: bool = field(static=True)
1956
2154
  prelk: PreLk
1957
2155
 
1958
2156
 
@@ -1968,9 +2166,6 @@ class SeqStageInPerTree(Module):
1968
2166
  The likelihood precision scale in each potential or actual leaf node.
1969
2167
  move
1970
2168
  The proposed move, see `propose_moves`.
1971
- move_counts
1972
- The counts of the number of points in the the nodes modified by the
1973
- moves.
1974
2169
  move_precs
1975
2170
  The likelihood precision scale in each node modified by the moves.
1976
2171
  leaf_indices
@@ -1982,26 +2177,23 @@ class SeqStageInPerTree(Module):
1982
2177
  are specific to the tree.
1983
2178
  """
1984
2179
 
1985
- leaf_tree: Float32[Array, '2**d']
1986
- prec_tree: Float32[Array, '2**d']
2180
+ leaf_tree: Float32[Array, ' 2**d']
2181
+ prec_tree: Float32[Array, ' 2**d']
1987
2182
  move: Moves
1988
- move_counts: Counts | None
1989
2183
  move_precs: Precs | Counts
1990
- leaf_indices: UInt[Array, 'n']
2184
+ leaf_indices: UInt[Array, ' n']
1991
2185
  prelkv: PreLkV
1992
2186
  prelf: PreLf
1993
2187
 
1994
2188
 
1995
2189
  def accept_move_and_sample_leaves(
1996
- resid: Float32[Array, 'n'],
1997
- at: SeqStageInAllTrees,
1998
- pt: SeqStageInPerTree,
2190
+ resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
1999
2191
  ) -> tuple[
2000
- Float32[Array, 'n'],
2001
- Float32[Array, '2**d'],
2192
+ Float32[Array, ' n'],
2193
+ Float32[Array, ' 2**d'],
2002
2194
  Bool[Array, ''],
2003
2195
  Bool[Array, ''],
2004
- dict[str, Float32[Array, '']],
2196
+ Float32[Array, ''] | None,
2005
2197
  ]:
2006
2198
  """
2007
2199
  Accept or reject a proposed move and sample the new leaf values.
@@ -2026,8 +2218,9 @@ def accept_move_and_sample_leaves(
2026
2218
  to_prune : Bool[Array, '']
2027
2219
  Whether, to reflect the acceptance status of the move, the state should
2028
2220
  be updated by pruning the leaves involved in the move.
2029
- ratios : dict[str, Float32[Array, '']]
2030
- The acceptance ratios for the moves. Empty if not to be saved.
2221
+ log_lk_ratio : Float32[Array, ''] | None
2222
+ The logarithm of the likelihood ratio for the move. `None` if not to be
2223
+ saved.
2031
2224
  """
2032
2225
  # sum residuals in each leaf, in tree proposed by grow move
2033
2226
  if at.prec_scale is None:
@@ -2041,17 +2234,12 @@ def accept_move_and_sample_leaves(
2041
2234
  # subtract starting tree from function
2042
2235
  resid_tree += pt.prec_tree * pt.leaf_tree
2043
2236
 
2044
- # get indices of move
2045
- node = pt.move.node
2046
- assert node.dtype == jnp.int32
2047
- left = pt.move.left
2048
- right = pt.move.right
2049
-
2050
2237
  # sum residuals in parent node modified by move
2051
- resid_left = resid_tree[left]
2052
- resid_right = resid_tree[right]
2238
+ resid_left = resid_tree[pt.move.left]
2239
+ resid_right = resid_tree[pt.move.right]
2053
2240
  resid_total = resid_left + resid_right
2054
- resid_tree = resid_tree.at[node].set(resid_total)
2241
+ assert pt.move.node.dtype == jnp.int32
2242
+ resid_tree = resid_tree.at[pt.move.node].set(resid_total)
2055
2243
 
2056
2244
  # compute acceptance ratio
2057
2245
  log_lk_ratio = compute_likelihood_ratio(
@@ -2059,48 +2247,37 @@ def accept_move_and_sample_leaves(
2059
2247
  )
2060
2248
  log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
2061
2249
  log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
2062
- ratios = {}
2063
- if at.save_ratios:
2064
- ratios.update(
2065
- log_trans_prior=pt.move.log_trans_prior_ratio,
2066
- # TODO save log_trans_prior_ratio as a vector outside of this loop,
2067
- # then change the option everywhere to `save_likelihood_ratio`.
2068
- log_likelihood=log_lk_ratio,
2069
- )
2250
+ if not at.save_ratios:
2251
+ log_lk_ratio = None
2070
2252
 
2071
2253
  # determine whether to accept the move
2072
2254
  acc = pt.move.allowed & (pt.move.logu <= log_ratio)
2073
- if at.min_points_per_leaf is not None:
2074
- assert pt.move_counts is not None
2075
- acc &= pt.move_counts.left >= at.min_points_per_leaf
2076
- acc &= pt.move_counts.right >= at.min_points_per_leaf
2077
2255
 
2078
2256
  # compute leaves posterior and sample leaves
2079
- initial_leaf_tree = pt.leaf_tree
2080
2257
  mean_post = resid_tree * pt.prelf.mean_factor
2081
2258
  leaf_tree = mean_post + pt.prelf.centered_leaves
2082
2259
 
2083
2260
  # copy leaves around such that the leaf indices point to the correct leaf
2084
2261
  to_prune = acc ^ pt.move.grow
2085
2262
  leaf_tree = (
2086
- leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
2087
- .set(leaf_tree[node])
2088
- .at[jnp.where(to_prune, right, leaf_tree.size)]
2089
- .set(leaf_tree[node])
2263
+ leaf_tree.at[jnp.where(to_prune, pt.move.left, leaf_tree.size)]
2264
+ .set(leaf_tree[pt.move.node])
2265
+ .at[jnp.where(to_prune, pt.move.right, leaf_tree.size)]
2266
+ .set(leaf_tree[pt.move.node])
2090
2267
  )
2091
2268
 
2092
2269
  # replace old tree with new tree in function values
2093
- resid += (initial_leaf_tree - leaf_tree)[pt.leaf_indices]
2270
+ resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices]
2094
2271
 
2095
- return resid, leaf_tree, acc, to_prune, ratios
2272
+ return resid, leaf_tree, acc, to_prune, log_lk_ratio
2096
2273
 
2097
2274
 
2098
2275
  def sum_resid(
2099
- scaled_resid: Float32[Array, 'n'],
2100
- leaf_indices: UInt[Array, 'n'],
2276
+ scaled_resid: Float32[Array, ' n'],
2277
+ leaf_indices: UInt[Array, ' n'],
2101
2278
  tree_size: int,
2102
2279
  batch_size: int | None,
2103
- ) -> Float32[Array, '{tree_size}']:
2280
+ ) -> Float32[Array, ' {tree_size}']:
2104
2281
  """
2105
2282
  Sum the residuals in each leaf.
2106
2283
 
@@ -2134,7 +2311,7 @@ def _aggregate_batched_onetree(
2134
2311
  size: int,
2135
2312
  dtype: jnp.dtype,
2136
2313
  batch_size: int,
2137
- ) -> Float32[Array, '{size}']:
2314
+ ) -> Float32[Array, ' {size}']:
2138
2315
  (n,) = indices.shape
2139
2316
  nbatches = n // batch_size + bool(n % batch_size)
2140
2317
  batch_indices = jnp.arange(n) % nbatches
@@ -2206,7 +2383,7 @@ def accept_moves_final_stage(bart: State, moves: Moves) -> State:
2206
2383
  grow_acc_count=jnp.sum(moves.acc & moves.grow),
2207
2384
  prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
2208
2385
  leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
2209
- split_trees=apply_moves_to_split_trees(bart.forest.split_trees, moves),
2386
+ split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
2210
2387
  ),
2211
2388
  )
2212
2389
 
@@ -2234,22 +2411,20 @@ def apply_moves_to_leaf_indices(
2234
2411
  mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
2235
2412
  is_child = (leaf_indices & mask) == moves.left
2236
2413
  return jnp.where(
2237
- is_child & moves.to_prune,
2238
- moves.node.astype(leaf_indices.dtype),
2239
- leaf_indices,
2414
+ is_child & moves.to_prune, moves.node.astype(leaf_indices.dtype), leaf_indices
2240
2415
  )
2241
2416
 
2242
2417
 
2243
2418
  @vmap_nodoc
2244
2419
  def apply_moves_to_split_trees(
2245
- split_trees: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
2420
+ split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
2246
2421
  ) -> UInt[Array, 'num_trees 2**(d-1)']:
2247
2422
  """
2248
2423
  Update the split trees to match the accepted move.
2249
2424
 
2250
2425
  Parameters
2251
2426
  ----------
2252
- split_trees
2427
+ split_tree
2253
2428
  The cutpoints of the decision nodes in the initial trees.
2254
2429
  moves
2255
2430
  The proposed moves (see `propose_moves`), as updated by
@@ -2261,21 +2436,9 @@ def apply_moves_to_split_trees(
2261
2436
  """
2262
2437
  assert moves.to_prune is not None
2263
2438
  return (
2264
- split_trees.at[
2265
- jnp.where(
2266
- moves.grow,
2267
- moves.node,
2268
- split_trees.size,
2269
- )
2270
- ]
2271
- .set(moves.grow_split.astype(split_trees.dtype))
2272
- .at[
2273
- jnp.where(
2274
- moves.to_prune,
2275
- moves.node,
2276
- split_trees.size,
2277
- )
2278
- ]
2439
+ split_tree.at[jnp.where(moves.grow, moves.node, split_tree.size)]
2440
+ .set(moves.grow_split.astype(split_tree.dtype))
2441
+ .at[jnp.where(moves.to_prune, moves.node, split_tree.size)]
2279
2442
  .set(0)
2280
2443
  )
2281
2444
 
@@ -2305,6 +2468,8 @@ def step_sigma(key: Key[Array, ''], bart: State) -> State:
2305
2468
  beta = bart.sigma2_beta + norm2 / 2
2306
2469
 
2307
2470
  sample = random.gamma(key, alpha)
2471
+ # random.gamma seems to be slow at compiling, maybe cdf inversion would
2472
+ # be better, but it's not implemented in jax
2308
2473
  return replace(bart, sigma2=beta / sample)
2309
2474
 
2310
2475
 
@@ -2324,12 +2489,128 @@ def step_z(key: Key[Array, ''], bart: State) -> State:
2324
2489
  The updated BART MCMC state.
2325
2490
  """
2326
2491
  trees_plus_offset = bart.z - bart.resid
2327
- lower = jnp.where(bart.y, -trees_plus_offset, -jnp.inf)
2328
- upper = jnp.where(bart.y, jnp.inf, -trees_plus_offset)
2329
- resid = random.truncated_normal(key, lower, upper)
2330
- # TODO jax's implementation of truncated_normal is not good, it just does
2331
- # cdf inversion with erf and erf_inv. I can do better, at least avoiding to
2332
- # compute one of the boundaries, and maybe also flipping and using ndtr
2333
- # instead of erf for numerical stability (open an issue in jax?)
2492
+ assert bart.y.dtype == bool
2493
+ resid = truncated_normal_onesided(key, (), ~bart.y, -trees_plus_offset)
2334
2494
  z = trees_plus_offset + resid
2335
2495
  return replace(bart, z=z, resid=resid)
2496
+
2497
+
2498
+ def step_s(key: Key[Array, ''], bart: State) -> State:
2499
+ """
2500
+ Update `log_s` using Dirichlet sampling.
2501
+
2502
+ The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
2503
+ is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
2504
+ varcount is the count of how many times each variable is used in the
2505
+ current forest.
2506
+
2507
+ Parameters
2508
+ ----------
2509
+ key
2510
+ Random key for sampling.
2511
+ bart
2512
+ The current BART state.
2513
+
2514
+ Returns
2515
+ -------
2516
+ Updated BART state with re-sampled `log_s`.
2517
+
2518
+ """
2519
+ assert bart.forest.theta is not None
2520
+
2521
+ # histogram current variable usage
2522
+ p = bart.forest.max_split.size
2523
+ varcount = grove.var_histogram(p, bart.forest.var_tree, bart.forest.split_tree)
2524
+
2525
+ # sample from Dirichlet posterior
2526
+ alpha = bart.forest.theta / p + varcount
2527
+ log_s = random.loggamma(key, alpha)
2528
+
2529
+ # update forest with new s
2530
+ return replace(bart, forest=replace(bart.forest, log_s=log_s))
2531
+
2532
+
2533
+ def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
2534
+ """
2535
+ Update `theta`.
2536
+
2537
+ The prior is theta / (theta + rho) ~ Beta(a, b).
2538
+
2539
+ Parameters
2540
+ ----------
2541
+ key
2542
+ Random key for sampling.
2543
+ bart
2544
+ The current BART state.
2545
+ num_grid
2546
+ The number of points in the evenly-spaced grid used to sample
2547
+ theta / (theta + rho).
2548
+
2549
+ Returns
2550
+ -------
2551
+ Updated BART state with re-sampled `theta`.
2552
+ """
2553
+ assert bart.forest.log_s is not None
2554
+ assert bart.forest.rho is not None
2555
+ assert bart.forest.a is not None
2556
+ assert bart.forest.b is not None
2557
+
2558
+ # the grid points are the midpoints of num_grid bins in (0, 1)
2559
+ padding = 1 / (2 * num_grid)
2560
+ lamda_grid = jnp.linspace(padding, 1 - padding, num_grid)
2561
+
2562
+ # normalize s
2563
+ log_s = bart.forest.log_s - logsumexp(bart.forest.log_s)
2564
+
2565
+ # sample lambda
2566
+ logp, theta_grid = _log_p_lamda(
2567
+ lamda_grid, log_s, bart.forest.rho, bart.forest.a, bart.forest.b
2568
+ )
2569
+ i = random.categorical(key, logp)
2570
+ theta = theta_grid[i]
2571
+
2572
+ return replace(bart, forest=replace(bart.forest, theta=theta))
2573
+
2574
+
2575
+ def _log_p_lamda(
2576
+ lamda: Float32[Array, ' num_grid'],
2577
+ log_s: Float32[Array, ' p'],
2578
+ rho: Float32[Array, ''],
2579
+ a: Float32[Array, ''],
2580
+ b: Float32[Array, ''],
2581
+ ) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
2582
+ # in the following I use lamda[::-1] == 1 - lamda
2583
+ theta = rho * lamda / lamda[::-1]
2584
+ p = log_s.size
2585
+ return (
2586
+ (a - 1) * jnp.log1p(-lamda[::-1]) # log(lambda)
2587
+ + (b - 1) * jnp.log1p(-lamda) # log(1 - lambda)
2588
+ + gammaln(theta)
2589
+ - p * gammaln(theta / p)
2590
+ + theta / p * jnp.sum(log_s)
2591
+ ), theta
2592
+
2593
+
2594
+ def step_sparse(key: Key[Array, ''], bart: State) -> State:
2595
+ """
2596
+ Update the sparsity parameters.
2597
+
2598
+ This invokes `step_s`, and then `step_theta` only if the parameters of
2599
+ the theta prior are defined.
2600
+
2601
+ Parameters
2602
+ ----------
2603
+ key
2604
+ Random key for sampling.
2605
+ bart
2606
+ The current BART state.
2607
+
2608
+ Returns
2609
+ -------
2610
+ Updated BART state with re-sampled `log_s` and `theta`.
2611
+ """
2612
+ keys = split(key)
2613
+ bart = step_s(keys.pop(), bart)
2614
+ if bart.forest.rho is not None:
2615
+ bart = step_theta(keys.pop(), bart)
2616
+ return bart