bartz 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py CHANGED
@@ -26,220 +26,394 @@
26
26
  Functions that implement the BART posterior MCMC initialization and update step.
27
27
 
28
28
  Functions that do MCMC steps operate by taking as input a bart state, and
29
- outputting a new dictionary with the new state. The input dict/arrays are not
30
- modified.
29
+ outputting a new state. The inputs are not modified.
31
30
 
32
- In general, integer types are chosen to be the minimal types that contain the
33
- range of possible values.
31
+ The entry points are:
32
+
33
+ - `State`: The dataclass that represents a BART MCMC state.
34
+ - `init`: Creates an initial `State` from data and configurations.
35
+ - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
36
+ - `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`.
34
37
  """
35
38
 
36
- import functools
37
39
  import math
40
+ from dataclasses import replace
41
+ from functools import cache, partial
42
+ from typing import Any, Literal
38
43
 
39
44
  import jax
45
+ from equinox import Module, field, tree_at
40
46
  from jax import lax, random
41
47
  from jax import numpy as jnp
48
+ from jax.scipy.special import gammaln, logsumexp
49
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt
50
+
51
+ from bartz import grove
52
+ from bartz.jaxext import (
53
+ minimal_unsigned_dtype,
54
+ split,
55
+ truncated_normal_onesided,
56
+ vmap_nodoc,
57
+ )
42
58
 
43
- from . import grove, jaxext
59
+
60
+ class Forest(Module):
61
+ """
62
+ Represents the MCMC state of a sum of trees.
63
+
64
+ Parameters
65
+ ----------
66
+ leaf_tree
67
+ The leaf values.
68
+ var_tree
69
+ The decision axes.
70
+ split_tree
71
+ The decision boundaries.
72
+ affluence_tree
73
+ Marks leaves that can be grown.
74
+ max_split
75
+ The maximum split index for each predictor.
76
+ blocked_vars
77
+ Indices of variables that are not used. This shall include at least
78
+ the `i` such that ``max_split[i] == 0``, otherwise behavior is
79
+ undefined.
80
+ p_nonterminal
81
+ The prior probability of each node being nonterminal, conditional on
82
+ its ancestors. Includes the nodes at maximum depth which should be set
83
+ to 0.
84
+ p_propose_grow
85
+ The unnormalized probability of picking a leaf for a grow proposal.
86
+ leaf_indices
87
+ The index of the leaf each datapoints falls into, for each tree.
88
+ min_points_per_decision_node
89
+ The minimum number of data points in a decision node.
90
+ min_points_per_leaf
91
+ The minimum number of data points in a leaf node.
92
+ resid_batch_size
93
+ count_batch_size
94
+ The data batch sizes for computing the sufficient statistics. If `None`,
95
+ they are computed with no batching.
96
+ log_trans_prior
97
+ The log transition and prior Metropolis-Hastings ratio for the
98
+ proposed move on each tree.
99
+ log_likelihood
100
+ The log likelihood ratio.
101
+ grow_prop_count
102
+ prune_prop_count
103
+ The number of grow/prune proposals made during one full MCMC cycle.
104
+ grow_acc_count
105
+ prune_acc_count
106
+ The number of grow/prune moves accepted during one full MCMC cycle.
107
+ sigma_mu2
108
+ The prior variance of a leaf, conditional on the tree structure.
109
+ log_s
110
+ The logarithm of the prior probability for choosing a variable to split
111
+ along in a decision rule, conditional on the ancestors. Not normalized.
112
+ If `None`, use a uniform distribution.
113
+ theta
114
+ The concentration parameter for the Dirichlet prior on the variable
115
+ distribution `s`. Required only to update `s`.
116
+ a
117
+ b
118
+ rho
119
+ Parameters of the prior on `theta`. Required only to sample `theta`.
120
+ See `step_theta`.
121
+ """
122
+
123
+ leaf_tree: Float32[Array, 'num_trees 2**d']
124
+ var_tree: UInt[Array, 'num_trees 2**(d-1)']
125
+ split_tree: UInt[Array, 'num_trees 2**(d-1)']
126
+ affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
127
+ max_split: UInt[Array, ' p']
128
+ blocked_vars: UInt[Array, ' k'] | None
129
+ p_nonterminal: Float32[Array, ' 2**d']
130
+ p_propose_grow: Float32[Array, ' 2**(d-1)']
131
+ leaf_indices: UInt[Array, 'num_trees n']
132
+ min_points_per_decision_node: Int32[Array, ''] | None
133
+ min_points_per_leaf: Int32[Array, ''] | None
134
+ resid_batch_size: int | None = field(static=True)
135
+ count_batch_size: int | None = field(static=True)
136
+ log_trans_prior: Float32[Array, ' num_trees'] | None
137
+ log_likelihood: Float32[Array, ' num_trees'] | None
138
+ grow_prop_count: Int32[Array, '']
139
+ prune_prop_count: Int32[Array, '']
140
+ grow_acc_count: Int32[Array, '']
141
+ prune_acc_count: Int32[Array, '']
142
+ sigma_mu2: Float32[Array, '']
143
+ log_s: Float32[Array, ' p'] | None
144
+ theta: Float32[Array, ''] | None
145
+ a: Float32[Array, ''] | None
146
+ b: Float32[Array, ''] | None
147
+ rho: Float32[Array, ''] | None
148
+
149
+
150
+ class State(Module):
151
+ """
152
+ Represents the MCMC state of BART.
153
+
154
+ Parameters
155
+ ----------
156
+ X
157
+ The predictors.
158
+ y
159
+ The response. If the data type is `bool`, the model is binary regression.
160
+ resid
161
+ The residuals (`y` or `z` minus sum of trees).
162
+ z
163
+ The latent variable for binary regression. `None` in continuous
164
+ regression.
165
+ offset
166
+ Constant shift added to the sum of trees.
167
+ sigma2
168
+ The error variance. `None` in binary regression.
169
+ prec_scale
170
+ The scale on the error precision, i.e., ``1 / error_scale ** 2``.
171
+ `None` in binary regression.
172
+ sigma2_alpha
173
+ sigma2_beta
174
+ The shape and scale parameters of the inverse gamma prior on the noise
175
+ variance. `None` in binary regression.
176
+ forest
177
+ The sum of trees model.
178
+ """
179
+
180
+ X: UInt[Array, 'p n']
181
+ y: Float32[Array, ' n'] | Bool[Array, ' n']
182
+ z: None | Float32[Array, ' n']
183
+ offset: Float32[Array, '']
184
+ resid: Float32[Array, ' n']
185
+ sigma2: Float32[Array, ''] | None
186
+ prec_scale: Float32[Array, ' n'] | None
187
+ sigma2_alpha: Float32[Array, ''] | None
188
+ sigma2_beta: Float32[Array, ''] | None
189
+ forest: Forest
44
190
 
45
191
 
46
192
  def init(
47
193
  *,
48
- X,
49
- y,
50
- max_split,
51
- num_trees,
52
- p_nonterminal,
53
- sigma2_alpha,
54
- sigma2_beta,
55
- error_scale=None,
56
- small_float=jnp.float32,
57
- large_float=jnp.float32,
58
- min_points_per_leaf=None,
59
- resid_batch_size='auto',
60
- count_batch_size='auto',
61
- save_ratios=False,
62
- ):
194
+ X: UInt[Any, 'p n'],
195
+ y: Float32[Any, ' n'] | Bool[Any, ' n'],
196
+ offset: float | Float32[Any, ''] = 0.0,
197
+ max_split: UInt[Any, ' p'],
198
+ num_trees: int,
199
+ p_nonterminal: Float32[Any, ' d-1'],
200
+ sigma_mu2: float | Float32[Any, ''],
201
+ sigma2_alpha: float | Float32[Any, ''] | None = None,
202
+ sigma2_beta: float | Float32[Any, ''] | None = None,
203
+ error_scale: Float32[Any, ' n'] | None = None,
204
+ min_points_per_decision_node: int | Integer[Any, ''] | None = None,
205
+ resid_batch_size: int | None | Literal['auto'] = 'auto',
206
+ count_batch_size: int | None | Literal['auto'] = 'auto',
207
+ save_ratios: bool = False,
208
+ filter_splitless_vars: bool = True,
209
+ min_points_per_leaf: int | Integer[Any, ''] | None = None,
210
+ log_s: Float32[Any, ' p'] | None = None,
211
+ theta: float | Float32[Any, ''] | None = None,
212
+ a: float | Float32[Any, ''] | None = None,
213
+ b: float | Float32[Any, ''] | None = None,
214
+ rho: float | Float32[Any, ''] | None = None,
215
+ ) -> State:
63
216
  """
64
217
  Make a BART posterior sampling MCMC initial state.
65
218
 
66
219
  Parameters
67
220
  ----------
68
- X : int array (p, n)
221
+ X
69
222
  The predictors. Note this is trasposed compared to the usual convention.
70
- y : float array (n,)
71
- The response.
72
- max_split : int array (p,)
223
+ y
224
+ The response. If the data type is `bool`, the regression model is binary
225
+ regression with probit.
226
+ offset
227
+ Constant shift added to the sum of trees. 0 if not specified.
228
+ max_split
73
229
  The maximum split index for each variable. All split ranges start at 1.
74
- num_trees : int
230
+ num_trees
75
231
  The number of trees in the forest.
76
- p_nonterminal : float array (d - 1,)
232
+ p_nonterminal
77
233
  The probability of a nonterminal node at each depth. The maximum depth
78
234
  of trees is fixed by the length of this array.
79
- sigma2_alpha : float
80
- The shape parameter of the inverse gamma prior on the error variance.
81
- sigma2_beta : float
82
- The scale parameter of the inverse gamma prior on the error variance.
83
- error_scale : float array (n,), optional
235
+ sigma_mu2
236
+ The prior variance of a leaf, conditional on the tree structure. The
237
+ prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
238
+ prior mean of leaves is always zero.
239
+ sigma2_alpha
240
+ sigma2_beta
241
+ The shape and scale parameters of the inverse gamma prior on the error
242
+ variance. Leave unspecified for binary regression.
243
+ error_scale
84
244
  Each error is scaled by the corresponding factor in `error_scale`, so
85
245
  the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
86
- small_float : dtype, default float32
87
- The dtype for large arrays used in the algorithm.
88
- large_float : dtype, default float32
89
- The dtype for scalars, small arrays, and arrays which require accuracy.
90
- min_points_per_leaf : int, optional
91
- The minimum number of data points in a leaf node. 0 if not specified.
92
- resid_batch_size, count_batch_sizes : int, None, str, default 'auto'
246
+ Not supported for binary regression. If not specified, defaults to 1 for
247
+ all points, but potentially skipping calculations.
248
+ min_points_per_decision_node
249
+ The minimum number of data points in a decision node. 0 if not
250
+ specified.
251
+ resid_batch_size
252
+ count_batch_size
93
253
  The batch sizes, along datapoints, for summing the residuals and
94
254
  counting the number of datapoints in each leaf. `None` for no batching.
95
255
  If 'auto', pick a value based on the device of `y`, or the default
96
256
  device.
97
- save_ratios : bool, default False
257
+ save_ratios
98
258
  Whether to save the Metropolis-Hastings ratios.
259
+ filter_splitless_vars
260
+ Whether to check `max_split` for variables without available cutpoints.
261
+ If any are found, they are put into a list of variables to exclude from
262
+ the MCMC. If `False`, no check is performed, but the results may be
263
+ wrong if any variable is blocked. The function is jax-traceable only
264
+ if this is set to `False`.
265
+ min_points_per_leaf
266
+ The minimum number of datapoints in a leaf node. 0 if not specified.
267
+ Unlike `min_points_per_decision_node`, this constraint is not taken into
268
+ account in the Metropolis-Hastings ratio because it would be expensive
269
+ to compute. Grow moves that would violate this constraint are vetoed.
270
+ This parameter is independent of `min_points_per_decision_node` and
271
+ there is no check that they are coherent. It makes sense to set
272
+ ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
273
+ log_s
274
+ The logarithm of the prior probability for choosing a variable to split
275
+ along in a decision rule, conditional on the ancestors. Not normalized.
276
+ If not specified, use a uniform distribution. If not specified and
277
+ `theta` or `rho`, `a`, `b` are, it's initialized automatically.
278
+ theta
279
+ The concentration parameter for the Dirichlet prior on `s`. Required
280
+ only to update `log_s`. If not specified, and `rho`, `a`, `b` are
281
+ specified, it's initialized automatically.
282
+ a
283
+ b
284
+ rho
285
+ Parameters of the prior on `theta`. Required only to sample `theta`.
99
286
 
100
287
  Returns
101
288
  -------
102
- bart : dict
103
- A dictionary with array values, representing a BART mcmc state. The
104
- keys are:
105
-
106
- 'leaf_trees' : small_float array (num_trees, 2 ** d)
107
- The leaf values.
108
- 'var_trees' : int array (num_trees, 2 ** (d - 1))
109
- The decision axes.
110
- 'split_trees' : int array (num_trees, 2 ** (d - 1))
111
- The decision boundaries.
112
- 'resid' : large_float array (n,)
113
- The residuals (data minus forest value). Large float to avoid
114
- roundoff.
115
- 'sigma2' : large_float
116
- The noise variance.
117
- 'prec_scale' : large_float array (n,) or None
118
- The scale on the error precision, i.e., ``1 / error_scale ** 2``.
119
- 'grow_prop_count', 'prune_prop_count' : int
120
- The number of grow/prune proposals made during one full MCMC cycle.
121
- 'grow_acc_count', 'prune_acc_count' : int
122
- The number of grow/prune moves accepted during one full MCMC cycle.
123
- 'p_nonterminal' : large_float array (d,)
124
- The probability of a nonterminal node at each depth, padded with a
125
- zero.
126
- 'p_propose_grow' : large_float array (2 ** (d - 1),)
127
- The unnormalized probability of picking a leaf for a grow proposal.
128
- 'sigma2_alpha' : large_float
129
- The shape parameter of the inverse gamma prior on the noise variance.
130
- 'sigma2_beta' : large_float
131
- The scale parameter of the inverse gamma prior on the noise variance.
132
- 'max_split' : int array (p,)
133
- The maximum split index for each variable.
134
- 'y' : small_float array (n,)
135
- The response.
136
- 'X' : int array (p, n)
137
- The predictors.
138
- 'leaf_indices' : int array (num_trees, n)
139
- The index of the leaf each datapoints falls into, for each tree.
140
- 'min_points_per_leaf' : int or None
141
- The minimum number of data points in a leaf node.
142
- 'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
143
- Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
144
- datapoints. If `min_points_per_leaf` is not specified, this is None.
145
- 'opt' : LeafDict
146
- A dictionary with config values:
147
-
148
- 'small_float' : dtype
149
- The dtype for large arrays used in the algorithm.
150
- 'large_float' : dtype
151
- The dtype for scalars, small arrays, and arrays which require
152
- accuracy.
153
- 'require_min_points' : bool
154
- Whether the `min_points_per_leaf` parameter is specified.
155
- 'resid_batch_size', 'count_batch_size' : int or None
156
- The data batch sizes for computing the sufficient statistics.
157
- 'ratios' : dict, optional
158
- If `save_ratios` is True, this field is present. It has the fields:
159
-
160
- 'log_trans_prior' : large_float array (num_trees,)
161
- The log transition and prior Metropolis-Hastings ratio for the
162
- proposed move on each tree.
163
- 'log_likelihood' : large_float array (num_trees,)
164
- The log likelihood ratio.
165
- """
166
-
167
- p_nonterminal = jnp.asarray(p_nonterminal, large_float)
289
+ An initialized BART MCMC state.
290
+
291
+ Raises
292
+ ------
293
+ ValueError
294
+ If `y` is boolean and arguments unused in binary regression are set.
295
+
296
+ Notes
297
+ -----
298
+ In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
299
+ of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
300
+ child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
301
+ integers in the range ``[0, 1, ..., max_split[i]]``.
302
+ """
303
+ p_nonterminal = jnp.asarray(p_nonterminal)
168
304
  p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
169
305
  max_depth = p_nonterminal.size
170
306
 
171
- @functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
307
+ @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
172
308
  def make_forest(max_depth, dtype):
173
309
  return grove.make_tree(max_depth, dtype)
174
310
 
175
- small_float = jnp.dtype(small_float)
176
- large_float = jnp.dtype(large_float)
177
- y = jnp.asarray(y, small_float)
311
+ y = jnp.asarray(y)
312
+ offset = jnp.asarray(offset)
313
+
178
314
  resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
179
315
  resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
180
316
  )
181
- sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
182
- sigma2 = jnp.where(
183
- jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1
184
- ) # TODO: I don't like this error check, these functions should be low-level and just do the thing. Why is it here?
185
-
186
- bart = dict(
187
- leaf_trees=make_forest(max_depth, small_float),
188
- var_trees=make_forest(
189
- max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)
190
- ),
191
- split_trees=make_forest(max_depth - 1, max_split.dtype),
192
- resid=jnp.asarray(y, large_float),
317
+
318
+ is_binary = y.dtype == bool
319
+ if is_binary:
320
+ if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
321
+ msg = (
322
+ 'error_scale, sigma2_alpha, and sigma2_beta must be set '
323
+ ' to `None` for binary regression.'
324
+ )
325
+ raise ValueError(msg)
326
+ sigma2 = None
327
+ else:
328
+ sigma2_alpha = jnp.asarray(sigma2_alpha)
329
+ sigma2_beta = jnp.asarray(sigma2_beta)
330
+ sigma2 = sigma2_beta / sigma2_alpha
331
+
332
+ max_split = jnp.asarray(max_split)
333
+
334
+ if filter_splitless_vars:
335
+ (blocked_vars,) = jnp.nonzero(max_split == 0)
336
+ blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
337
+ # see `fully_used_variables` for the type cast
338
+ else:
339
+ blocked_vars = None
340
+
341
+ # check and initialize sparsity parameters
342
+ if not _all_none_or_not_none(rho, a, b):
343
+ msg = 'rho, a, b are not either all `None` or all set'
344
+ raise ValueError(msg)
345
+ if theta is None and rho is not None:
346
+ theta = rho
347
+ if log_s is None and theta is not None:
348
+ log_s = jnp.zeros(max_split.size)
349
+
350
+ return State(
351
+ X=jnp.asarray(X),
352
+ y=y,
353
+ z=jnp.full(y.shape, offset) if is_binary else None,
354
+ offset=offset,
355
+ resid=jnp.zeros(y.shape) if is_binary else y - offset,
193
356
  sigma2=sigma2,
194
357
  prec_scale=(
195
- None
196
- if error_scale is None
197
- else lax.reciprocal(jnp.square(jnp.asarray(error_scale, large_float)))
198
- ),
199
- grow_prop_count=jnp.zeros((), int),
200
- grow_acc_count=jnp.zeros((), int),
201
- prune_prop_count=jnp.zeros((), int),
202
- prune_acc_count=jnp.zeros((), int),
203
- p_nonterminal=p_nonterminal,
204
- p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
205
- sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
206
- sigma2_beta=jnp.asarray(sigma2_beta, large_float),
207
- max_split=jnp.asarray(max_split),
208
- y=y,
209
- X=jnp.asarray(X),
210
- leaf_indices=jnp.ones(
211
- (num_trees, y.size), jaxext.minimal_unsigned_dtype(2**max_depth - 1)
212
- ),
213
- min_points_per_leaf=(
214
- None if min_points_per_leaf is None else jnp.asarray(min_points_per_leaf)
358
+ None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
215
359
  ),
216
- affluence_trees=(
217
- None
218
- if min_points_per_leaf is None
219
- else make_forest(max_depth - 1, bool)
220
- .at[:, 1]
221
- .set(y.size >= 2 * min_points_per_leaf)
222
- ),
223
- opt=jaxext.LeafDict(
224
- small_float=small_float,
225
- large_float=large_float,
226
- require_min_points=min_points_per_leaf is not None,
360
+ sigma2_alpha=sigma2_alpha,
361
+ sigma2_beta=sigma2_beta,
362
+ forest=Forest(
363
+ leaf_tree=make_forest(max_depth, jnp.float32),
364
+ var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
365
+ split_tree=make_forest(max_depth - 1, max_split.dtype),
366
+ affluence_tree=(
367
+ make_forest(max_depth - 1, bool)
368
+ .at[:, 1]
369
+ .set(
370
+ True
371
+ if min_points_per_decision_node is None
372
+ else y.size >= min_points_per_decision_node
373
+ )
374
+ ),
375
+ blocked_vars=blocked_vars,
376
+ max_split=max_split,
377
+ grow_prop_count=jnp.zeros((), int),
378
+ grow_acc_count=jnp.zeros((), int),
379
+ prune_prop_count=jnp.zeros((), int),
380
+ prune_acc_count=jnp.zeros((), int),
381
+ p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
382
+ p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
383
+ leaf_indices=jnp.ones(
384
+ (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
385
+ ),
386
+ min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
387
+ min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
227
388
  resid_batch_size=resid_batch_size,
228
389
  count_batch_size=count_batch_size,
390
+ log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
391
+ log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
392
+ sigma_mu2=jnp.asarray(sigma_mu2),
393
+ log_s=_asarray_or_none(log_s),
394
+ theta=_asarray_or_none(theta),
395
+ rho=_asarray_or_none(rho),
396
+ a=_asarray_or_none(a),
397
+ b=_asarray_or_none(b),
229
398
  ),
230
399
  )
231
400
 
232
- if save_ratios:
233
- bart['ratios'] = dict(
234
- log_trans_prior=jnp.full(num_trees, jnp.nan),
235
- log_likelihood=jnp.full(num_trees, jnp.nan),
236
- )
237
401
 
238
- return bart
402
+ def _all_none_or_not_none(*args):
403
+ is_none = [x is None for x in args]
404
+ return all(is_none) or not any(is_none)
239
405
 
240
406
 
241
- def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
242
- @functools.cache
407
+ def _asarray_or_none(x):
408
+ if x is None:
409
+ return None
410
+ return jnp.asarray(x)
411
+
412
+
413
+ def _choose_suffstat_batch_size(
414
+ resid_batch_size, count_batch_size, y, forest_size
415
+ ) -> tuple[int | None, ...]:
416
+ @cache
243
417
  def get_platform():
244
418
  try:
245
419
  device = y.devices().pop()
@@ -247,16 +421,17 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
247
421
  device = jax.devices()[0]
248
422
  platform = device.platform
249
423
  if platform not in ('cpu', 'gpu'):
250
- raise KeyError(f'Unknown platform: {platform}')
424
+ msg = f'Unknown platform: {platform}'
425
+ raise KeyError(msg)
251
426
  return platform
252
427
 
253
428
  if resid_batch_size == 'auto':
254
429
  platform = get_platform()
255
430
  n = max(1, y.size)
256
431
  if platform == 'cpu':
257
- resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
432
+ resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6
258
433
  elif platform == 'gpu':
259
- resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
434
+ resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3
260
435
  resid_batch_size = max(1, resid_batch_size)
261
436
 
262
437
  if count_batch_size == 'auto':
@@ -265,253 +440,381 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
265
440
  count_batch_size = None
266
441
  elif platform == 'gpu':
267
442
  n = max(1, y.size)
268
- count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
443
+ count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2
269
444
  # /4 is good on V100, /2 on L4/T4, still haven't tried A100
270
445
  max_memory = 2**29
271
446
  itemsize = 4
272
- min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
447
+ min_batch_size = math.ceil(forest_size * itemsize * n / max_memory)
273
448
  count_batch_size = max(count_batch_size, min_batch_size)
274
449
  count_batch_size = max(1, count_batch_size)
275
450
 
276
451
  return resid_batch_size, count_batch_size
277
452
 
278
453
 
279
- def step(key, bart):
454
+ @jax.jit
455
+ def step(key: Key[Array, ''], bart: State) -> State:
280
456
  """
281
- Perform one full MCMC step on a BART state.
457
+ Do one MCMC step.
282
458
 
283
459
  Parameters
284
460
  ----------
285
- key : jax.dtypes.prng_key array
461
+ key
286
462
  A jax random key.
287
- bart : dict
463
+ bart
288
464
  A BART mcmc state, as created by `init`.
289
465
 
290
466
  Returns
291
467
  -------
292
- bart : dict
293
- The new BART mcmc state.
468
+ The new BART mcmc state.
294
469
  """
295
- key, subkey = random.split(key)
296
- bart = sample_trees(subkey, bart)
297
- return sample_sigma(key, bart)
470
+ keys = split(key)
471
+
472
+ if bart.y.dtype == bool: # binary regression
473
+ bart = replace(bart, sigma2=jnp.float32(1))
474
+ bart = step_trees(keys.pop(), bart)
475
+ bart = replace(bart, sigma2=None)
476
+ return step_z(keys.pop(), bart)
298
477
 
478
+ else: # continuous regression
479
+ bart = step_trees(keys.pop(), bart)
480
+ return step_sigma(keys.pop(), bart)
299
481
 
300
- def sample_trees(key, bart):
482
+
483
+ def step_trees(key: Key[Array, ''], bart: State) -> State:
301
484
  """
302
485
  Forest sampling step of BART MCMC.
303
486
 
304
487
  Parameters
305
488
  ----------
306
- key : jax.dtypes.prng_key array
489
+ key
307
490
  A jax random key.
308
- bart : dict
491
+ bart
309
492
  A BART mcmc state, as created by `init`.
310
493
 
311
494
  Returns
312
495
  -------
313
- bart : dict
314
- The new BART mcmc state.
496
+ The new BART mcmc state.
315
497
 
316
498
  Notes
317
499
  -----
318
500
  This function zeroes the proposal counters.
319
501
  """
320
- key, subkey = random.split(key)
321
- moves = sample_moves(subkey, bart)
322
- return accept_moves_and_sample_leaves(key, bart, moves)
502
+ keys = split(key)
503
+ moves = propose_moves(keys.pop(), bart.forest)
504
+ return accept_moves_and_sample_leaves(keys.pop(), bart, moves)
505
+
323
506
 
507
+ class Moves(Module):
508
+ """
509
+ Moves proposed to modify each tree.
324
510
 
325
- def sample_moves(key, bart):
511
+ Parameters
512
+ ----------
513
+ allowed
514
+ Whether there is a possible move. If `False`, the other values may not
515
+ make sense. The only case in which a move is marked as allowed but is
516
+ then vetoed is if it does not satisfy `min_points_per_leaf`, which for
517
+ efficiency is implemented post-hoc without changing the rest of the
518
+ MCMC logic.
519
+ grow
520
+ Whether the move is a grow move or a prune move.
521
+ num_growable
522
+ The number of growable leaves in the original tree.
523
+ node
524
+ The index of the leaf to grow or node to prune.
525
+ left
526
+ right
527
+ The indices of the children of 'node'.
528
+ partial_ratio
529
+ A factor of the Metropolis-Hastings ratio of the move. It lacks the
530
+ likelihood ratio, the probability of proposing the prune move, and the
531
+ probability that the children of the modified node are terminal. If the
532
+ move is PRUNE, the ratio is inverted. `None` once
533
+ `log_trans_prior_ratio` has been computed.
534
+ log_trans_prior_ratio
535
+ The logarithm of the product of the transition and prior terms of the
536
+ Metropolis-Hastings ratio for the acceptance of the proposed move.
537
+ `None` if not yet computed. If PRUNE, the log-ratio is negated.
538
+ grow_var
539
+ The decision axes of the new rules.
540
+ grow_split
541
+ The decision boundaries of the new rules.
542
+ var_tree
543
+ The updated decision axes of the trees, valid whatever move.
544
+ affluence_tree
545
+ A partially updated `affluence_tree`, marking non-leaf nodes that would
546
+ become leaves if the move was accepted. This mark initially (out of
547
+ `propose_moves`) takes into account if there would be available decision
548
+ rules to grow the leaf, and whether there are enough datapoints in the
549
+ node is marked in `accept_moves_parallel_stage`.
550
+ logu
551
+ The logarithm of a uniform (0, 1] random variable to be used to
552
+ accept the move. It's in (-oo, 0].
553
+ acc
554
+ Whether the move was accepted. `None` if not yet computed.
555
+ to_prune
556
+ Whether the final operation to apply the move is pruning. This indicates
557
+ an accepted prune move or a rejected grow move. `None` if not yet
558
+ computed.
559
+ """
560
+
561
+ allowed: Bool[Array, ' num_trees']
562
+ grow: Bool[Array, ' num_trees']
563
+ num_growable: UInt[Array, ' num_trees']
564
+ node: UInt[Array, ' num_trees']
565
+ left: UInt[Array, ' num_trees']
566
+ right: UInt[Array, ' num_trees']
567
+ partial_ratio: Float32[Array, ' num_trees'] | None
568
+ log_trans_prior_ratio: None | Float32[Array, ' num_trees']
569
+ grow_var: UInt[Array, ' num_trees']
570
+ grow_split: UInt[Array, ' num_trees']
571
+ var_tree: UInt[Array, 'num_trees 2**(d-1)']
572
+ affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
573
+ logu: Float32[Array, ' num_trees']
574
+ acc: None | Bool[Array, ' num_trees']
575
+ to_prune: None | Bool[Array, ' num_trees']
576
+
577
+
578
+ def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
326
579
  """
327
580
  Propose moves for all the trees.
328
581
 
582
+ There are two types of moves: GROW (convert a leaf to a decision node and
583
+ add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
584
+ leaf, deleting its children).
585
+
329
586
  Parameters
330
587
  ----------
331
- key : jax.dtypes.prng_key array
588
+ key
332
589
  A jax random key.
333
- bart : dict
334
- BART mcmc state.
590
+ forest
591
+ The `forest` field of a BART MCMC state.
335
592
 
336
593
  Returns
337
594
  -------
338
- moves : dict
339
- A dictionary with fields:
340
-
341
- 'allowed' : bool array (num_trees,)
342
- Whether the move is possible.
343
- 'grow' : bool array (num_trees,)
344
- Whether the move is a grow move or a prune move.
345
- 'num_growable' : int array (num_trees,)
346
- The number of growable leaves in the original tree.
347
- 'node' : int array (num_trees,)
348
- The index of the leaf to grow or node to prune.
349
- 'left', 'right' : int array (num_trees,)
350
- The indices of the children of 'node'.
351
- 'partial_ratio' : float array (num_trees,)
352
- A factor of the Metropolis-Hastings ratio of the move. It lacks
353
- the likelihood ratio and the probability of proposing the prune
354
- move. If the move is Prune, the ratio is inverted.
355
- 'grow_var' : int array (num_trees,)
356
- The decision axes of the new rules.
357
- 'grow_split' : int array (num_trees,)
358
- The decision boundaries of the new rules.
359
- 'var_trees' : int array (num_trees, 2 ** (d - 1))
360
- The updated decision axes of the trees, valid whatever move.
361
- 'logu' : float array (num_trees,)
362
- The logarithm of a uniform (0, 1] random variable to be used to
363
- accept the move. It's in (-oo, 0].
364
- """
365
- ntree = bart['leaf_trees'].shape[0]
366
- key = random.split(key, 1 + ntree)
367
- key, subkey = key[0], key[1:]
595
+ The proposed move for each tree.
596
+ """
597
+ num_trees, _ = forest.leaf_tree.shape
598
+ keys = split(key, 1 + 2 * num_trees)
368
599
 
369
600
  # compute moves
370
- grow_moves, prune_moves = _sample_moves_vmap_trees(
371
- subkey,
372
- bart['var_trees'],
373
- bart['split_trees'],
374
- bart['affluence_trees'],
375
- bart['max_split'],
376
- bart['p_nonterminal'],
377
- bart['p_propose_grow'],
601
+ grow_moves = propose_grow_moves(
602
+ keys.pop(num_trees),
603
+ forest.var_tree,
604
+ forest.split_tree,
605
+ forest.affluence_tree,
606
+ forest.max_split,
607
+ forest.blocked_vars,
608
+ forest.p_nonterminal,
609
+ forest.p_propose_grow,
610
+ forest.log_s,
611
+ )
612
+ prune_moves = propose_prune_moves(
613
+ keys.pop(num_trees),
614
+ forest.split_tree,
615
+ grow_moves.affluence_tree,
616
+ forest.p_nonterminal,
617
+ forest.p_propose_grow,
378
618
  )
379
619
 
380
- u, logu = random.uniform(key, (2, ntree), bart['opt']['large_float'])
620
+ u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))
381
621
 
382
622
  # choose between grow or prune
383
- grow_allowed = grow_moves['num_growable'].astype(bool)
384
- p_grow = jnp.where(grow_allowed & prune_moves['allowed'], 0.5, grow_allowed)
623
+ p_grow = jnp.where(
624
+ grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed
625
+ )
385
626
  grow = u < p_grow # use < instead of <= because u is in [0, 1)
386
627
 
387
628
  # compute children indices
388
- node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
629
+ node = jnp.where(grow, grow_moves.node, prune_moves.node)
389
630
  left = node << 1
390
631
  right = left + 1
391
632
 
392
- return dict(
393
- allowed=grow | prune_moves['allowed'],
633
+ return Moves(
634
+ allowed=grow_moves.allowed | prune_moves.allowed,
394
635
  grow=grow,
395
- num_growable=grow_moves['num_growable'],
636
+ num_growable=grow_moves.num_growable,
396
637
  node=node,
397
638
  left=left,
398
639
  right=right,
399
640
  partial_ratio=jnp.where(
400
- grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']
641
+ grow, grow_moves.partial_ratio, prune_moves.partial_ratio
401
642
  ),
402
- grow_var=grow_moves['var'],
403
- grow_split=grow_moves['split'],
404
- var_trees=grow_moves['var_tree'],
405
- logu=jnp.log1p(-logu),
643
+ log_trans_prior_ratio=None, # will be set in complete_ratio
644
+ grow_var=grow_moves.var,
645
+ grow_split=grow_moves.split,
646
+ # var_tree does not need to be updated if prune
647
+ var_tree=grow_moves.var_tree,
648
+ # affluence_tree is updated for both moves unconditionally, prune last
649
+ affluence_tree=prune_moves.affluence_tree,
650
+ logu=jnp.log1p(-exp1mlogu),
651
+ acc=None, # will be set in accept_moves_sequential_stage
652
+ to_prune=None, # will be set in accept_moves_sequential_stage
406
653
  )
407
654
 
408
655
 
409
- @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
410
- def _sample_moves_vmap_trees(*args):
411
- key, args = args[0], args[1:]
412
- key, key1 = random.split(key)
413
- grow = grow_move(key, *args)
414
- prune = prune_move(key1, *args)
415
- return grow, prune
416
-
417
-
418
- def grow_move(
419
- key, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow
420
- ):
656
+ class GrowMoves(Module):
421
657
  """
422
- Tree structure grow move proposal of BART MCMC.
658
+ Represent a proposed grow move for each tree.
423
659
 
424
- This moves picks a leaf node and converts it to a non-terminal node with
425
- two leaf children. The move is not possible if all the leaves are already at
426
- maximum depth.
660
+ Parameters
661
+ ----------
662
+ allowed
663
+ Whether the move is allowed for proposal.
664
+ num_growable
665
+ The number of leaves that can be proposed for grow.
666
+ node
667
+ The index of the leaf to grow. ``2 ** d`` if there are no growable
668
+ leaves.
669
+ var
670
+ split
671
+ The decision axis and boundary of the new rule.
672
+ partial_ratio
673
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
674
+ the likelihood ratio and the probability of proposing the prune
675
+ move.
676
+ var_tree
677
+ The updated decision axes of the tree.
678
+ affluence_tree
679
+ A partially updated `affluence_tree` that marks each new leaf that
680
+ would be produced as `True` if it would have available decision rules.
681
+ """
682
+
683
+ allowed: Bool[Array, ' num_trees']
684
+ num_growable: UInt[Array, ' num_trees']
685
+ node: UInt[Array, ' num_trees']
686
+ var: UInt[Array, ' num_trees']
687
+ split: UInt[Array, ' num_trees']
688
+ partial_ratio: Float32[Array, ' num_trees']
689
+ var_tree: UInt[Array, 'num_trees 2**(d-1)']
690
+ affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
691
+
692
+
693
+ @partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
694
+ def propose_grow_moves(
695
+ key: Key[Array, ' num_trees'],
696
+ var_tree: UInt[Array, 'num_trees 2**(d-1)'],
697
+ split_tree: UInt[Array, 'num_trees 2**(d-1)'],
698
+ affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
699
+ max_split: UInt[Array, ' p'],
700
+ blocked_vars: Int32[Array, ' k'] | None,
701
+ p_nonterminal: Float32[Array, ' 2**d'],
702
+ p_propose_grow: Float32[Array, ' 2**(d-1)'],
703
+ log_s: Float32[Array, ' p'] | None,
704
+ ) -> GrowMoves:
705
+ """
706
+ Propose a GROW move for each tree.
707
+
708
+ A GROW move picks a leaf node and converts it to a non-terminal node with
709
+ two leaf children.
427
710
 
428
711
  Parameters
429
712
  ----------
430
- var_tree : array (2 ** (d - 1),)
431
- The variable indices of the tree.
432
- split_tree : array (2 ** (d - 1),)
713
+ key
714
+ A jax random key.
715
+ var_tree
716
+ The splitting axes of the tree.
717
+ split_tree
433
718
  The splitting points of the tree.
434
- affluence_tree : bool array (2 ** (d - 1),) or None
435
- Whether a leaf has enough points to be grown.
436
- max_split : array (p,)
719
+ affluence_tree
720
+ Whether each leaf has enough points to be grown.
721
+ max_split
437
722
  The maximum split index for each variable.
438
- p_nonterminal : array (d,)
439
- The probability of a nonterminal node at each depth.
440
- p_propose_grow : array (2 ** (d - 1),)
723
+ blocked_vars
724
+ The indices of the variables that have no available cutpoints.
725
+ p_nonterminal
726
+ The a priori probability of a node to be nonterminal conditional on the
727
+ ancestors, including at the maximum depth where it should be zero.
728
+ p_propose_grow
441
729
  The unnormalized probability of choosing a leaf to grow.
442
- key : jax.dtypes.prng_key array
443
- A jax random key.
730
+ log_s
731
+ Unnormalized log-probability used to choose a variable to split on
732
+ amongst the available ones.
444
733
 
445
734
  Returns
446
735
  -------
447
- grow_move : dict
448
- A dictionary with fields:
449
-
450
- 'num_growable' : int
451
- The number of growable leaves.
452
- 'node' : int
453
- The index of the leaf to grow. ``2 ** d`` if there are no growable
454
- leaves.
455
- 'var', 'split' : int
456
- The decision axis and boundary of the new rule.
457
- 'partial_ratio' : float
458
- A factor of the Metropolis-Hastings ratio of the move. It lacks
459
- the likelihood ratio and the probability of proposing the prune
460
- move.
461
- 'var_tree' : array (2 ** (d - 1),)
462
- The updated decision axes of the tree.
463
- """
464
-
465
- key, key1, key2 = random.split(key, 3)
736
+ An object representing the proposed move.
737
+
738
+ Notes
739
+ -----
740
+ The move is not proposed if each leaf is already at maximum depth, or has
741
+ less datapoints than the requested threshold `min_points_per_decision_node`,
742
+ or it does not have any available decision rules given its ancestors. This
743
+ is marked by setting `allowed` to `False` and `num_growable` to 0.
744
+ """
745
+ keys = split(key, 3)
466
746
 
467
747
  leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
468
- key, split_tree, affluence_tree, p_propose_grow
748
+ keys.pop(), split_tree, affluence_tree, p_propose_grow
469
749
  )
470
750
 
471
- var = choose_variable(key1, var_tree, split_tree, max_split, leaf_to_grow)
472
- var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
751
+ # sample a decision rule
752
+ var, num_available_var = choose_variable(
753
+ keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
754
+ )
755
+ split_idx, l, r = choose_split(
756
+ keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
757
+ )
473
758
 
474
- split = choose_split(key2, var_tree, split_tree, max_split, leaf_to_grow)
759
+ # determine if the new leaves would have available decision rules; if the
760
+ # move is blocked, these values may not make sense
761
+ left_growable = right_growable = num_available_var > 1
762
+ left_growable |= l < split_idx
763
+ right_growable |= split_idx + 1 < r
764
+ left = leaf_to_grow << 1
765
+ right = left + 1
766
+ affluence_tree = affluence_tree.at[left].set(left_growable)
767
+ affluence_tree = affluence_tree.at[right].set(right_growable)
475
768
 
476
769
  ratio = compute_partial_ratio(
477
770
  prob_choose, num_prunable, p_nonterminal, leaf_to_grow
478
771
  )
479
772
 
480
- return dict(
773
+ return GrowMoves(
774
+ allowed=num_growable > 0,
481
775
  num_growable=num_growable,
482
776
  node=leaf_to_grow,
483
777
  var=var,
484
- split=split,
778
+ split=split_idx,
485
779
  partial_ratio=ratio,
486
- var_tree=var_tree,
780
+ var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
781
+ affluence_tree=affluence_tree,
487
782
  )
488
783
 
489
784
 
490
- def choose_leaf(key, split_tree, affluence_tree, p_propose_grow):
785
+ def choose_leaf(
786
+ key: Key[Array, ''],
787
+ split_tree: UInt[Array, ' 2**(d-1)'],
788
+ affluence_tree: Bool[Array, ' 2**(d-1)'],
789
+ p_propose_grow: Float32[Array, ' 2**(d-1)'],
790
+ ) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
491
791
  """
492
792
  Choose a leaf node to grow in a tree.
493
793
 
494
794
  Parameters
495
795
  ----------
496
- split_tree : array (2 ** (d - 1),)
796
+ key
797
+ A jax random key.
798
+ split_tree
497
799
  The splitting points of the tree.
498
- affluence_tree : bool array (2 ** (d - 1),) or None
499
- Whether a leaf has enough points to be grown.
500
- p_propose_grow : array (2 ** (d - 1),)
800
+ affluence_tree
801
+ Whether a leaf has enough points that it could be split into two leaves
802
+ satisfying the `min_points_per_leaf` requirement.
803
+ p_propose_grow
501
804
  The unnormalized probability of choosing a leaf to grow.
502
- key : jax.dtypes.prng_key array
503
- A jax random key.
504
805
 
505
806
  Returns
506
807
  -------
507
- leaf_to_grow : int
808
+ leaf_to_grow : Int32[Array, '']
508
809
  The index of the leaf to grow. If ``num_growable == 0``, return
509
810
  ``2 ** d``.
510
- num_growable : int
511
- The number of leaf nodes that can be grown.
512
- prob_choose : float
513
- The normalized probability of choosing the selected leaf.
514
- num_prunable : int
811
+ num_growable : Int32[Array, '']
812
+ The number of leaf nodes that can be grown, i.e., are nonterminal
813
+ and have at least twice `min_points_per_leaf`.
814
+ prob_choose : Float32[Array, '']
815
+ The (normalized) probability that this function had to choose that
816
+ specific leaf, given the arguments.
817
+ num_prunable : Int32[Array, '']
515
818
  The number of leaf parents that could be pruned, after converting the
516
819
  selected leaf to a non-terminal node.
517
820
  """
@@ -520,145 +823,189 @@ def choose_leaf(key, split_tree, affluence_tree, p_propose_grow):
520
823
  distr = jnp.where(is_growable, p_propose_grow, 0)
521
824
  leaf_to_grow, distr_norm = categorical(key, distr)
522
825
  leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
523
- prob_choose = distr[leaf_to_grow] / distr_norm
826
+ prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1)
524
827
  is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
525
828
  num_prunable = jnp.count_nonzero(is_parent)
526
829
  return leaf_to_grow, num_growable, prob_choose, num_prunable
527
830
 
528
831
 
529
- def growable_leaves(split_tree, affluence_tree):
832
+ def growable_leaves(
833
+ split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
834
+ ) -> Bool[Array, ' 2**(d-1)']:
530
835
  """
531
836
  Return a mask indicating the leaf nodes that can be proposed for growth.
532
837
 
838
+ The condition is that a leaf is not at the bottom level, has available
839
+ decision rules given its ancestors, and has at least
840
+ `min_points_per_decision_node` points.
841
+
533
842
  Parameters
534
843
  ----------
535
- split_tree : array (2 ** (d - 1),)
844
+ split_tree
536
845
  The splitting points of the tree.
537
- affluence_tree : bool array (2 ** (d - 1),) or None
538
- Whether a leaf has enough points to be grown.
846
+ affluence_tree
847
+ Marks leaves that can be grown.
539
848
 
540
849
  Returns
541
850
  -------
542
- is_growable : bool array (2 ** (d - 1),)
543
- The mask indicating the leaf nodes that can be proposed to grow, i.e.,
544
- that are not at the bottom level and have at least two times the number
545
- of minimum points per leaf.
851
+ The mask indicating the leaf nodes that can be proposed to grow.
852
+
853
+ Notes
854
+ -----
855
+ This function needs `split_tree` and not just `affluence_tree` because
856
+ `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
546
857
  """
547
- is_growable = grove.is_actual_leaf(split_tree)
548
- if affluence_tree is not None:
549
- is_growable &= affluence_tree
550
- return is_growable
858
+ return grove.is_actual_leaf(split_tree) & affluence_tree
551
859
 
552
860
 
553
- def categorical(key, distr):
861
+ def categorical(
862
+ key: Key[Array, ''], distr: Float32[Array, ' n']
863
+ ) -> tuple[Int32[Array, ''], Float32[Array, '']]:
554
864
  """
555
865
  Return a random integer from an arbitrary distribution.
556
866
 
557
867
  Parameters
558
868
  ----------
559
- key : jax.dtypes.prng_key array
869
+ key
560
870
  A jax random key.
561
- distr : float array (n,)
871
+ distr
562
872
  An unnormalized probability distribution.
563
873
 
564
874
  Returns
565
875
  -------
566
- u : int
876
+ u : Int32[Array, '']
567
877
  A random integer in the range ``[0, n)``. If all probabilities are zero,
568
878
  return ``n``.
569
- norm : float
879
+ norm : Float32[Array, '']
570
880
  The sum of `distr`.
881
+
882
+ Notes
883
+ -----
884
+ This function uses a cumsum instead of the Gumbel trick, so it's ok only
885
+ for small ranges with probabilities well greater than 0.
571
886
  """
572
887
  ecdf = jnp.cumsum(distr)
573
888
  u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
574
889
  return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
575
890
 
576
891
 
577
- def choose_variable(key, var_tree, split_tree, max_split, leaf_index):
892
+ def choose_variable(
893
+ key: Key[Array, ''],
894
+ var_tree: UInt[Array, ' 2**(d-1)'],
895
+ split_tree: UInt[Array, ' 2**(d-1)'],
896
+ max_split: UInt[Array, ' p'],
897
+ leaf_index: Int32[Array, ''],
898
+ blocked_vars: Int32[Array, ' k'] | None,
899
+ log_s: Float32[Array, ' p'] | None,
900
+ ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
578
901
  """
579
902
  Choose a variable to split on for a new non-terminal node.
580
903
 
581
904
  Parameters
582
905
  ----------
583
- var_tree : int array (2 ** (d - 1),)
906
+ key
907
+ A jax random key.
908
+ var_tree
584
909
  The variable indices of the tree.
585
- split_tree : int array (2 ** (d - 1),)
910
+ split_tree
586
911
  The splitting points of the tree.
587
- max_split : int array (p,)
912
+ max_split
588
913
  The maximum split index for each variable.
589
- leaf_index : int
914
+ leaf_index
590
915
  The index of the leaf to grow.
591
- key : jax.dtypes.prng_key array
592
- A jax random key.
916
+ blocked_vars
917
+ The indices of the variables that have no available cutpoints. If
918
+ `None`, all variables are assumed unblocked.
919
+ log_s
920
+ The logarithm of the prior probability for choosing a variable. If
921
+ `None`, use a uniform distribution.
593
922
 
594
923
  Returns
595
924
  -------
596
- var : int
925
+ var : Int32[Array, '']
597
926
  The index of the variable to split on.
598
-
599
- Notes
600
- -----
601
- The variable is chosen among the variables that have a non-empty range of
602
- allowed splits. If no variable has a non-empty range, return `p`.
927
+ num_available_var : Int32[Array, '']
928
+ The number of variables with available decision rules `var` was chosen
929
+ from.
603
930
  """
604
931
  var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
605
- return randint_exclude(key, max_split.size, var_to_ignore)
932
+ if blocked_vars is not None:
933
+ var_to_ignore = jnp.concatenate([var_to_ignore, blocked_vars])
934
+
935
+ if log_s is None:
936
+ return randint_exclude(key, max_split.size, var_to_ignore)
937
+ else:
938
+ return categorical_exclude(key, log_s, var_to_ignore)
606
939
 
607
940
 
608
- def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
941
+ def fully_used_variables(
942
+ var_tree: UInt[Array, ' 2**(d-1)'],
943
+ split_tree: UInt[Array, ' 2**(d-1)'],
944
+ max_split: UInt[Array, ' p'],
945
+ leaf_index: Int32[Array, ''],
946
+ ) -> UInt[Array, ' d-2']:
609
947
  """
610
- Return a list of variables that have an empty split range at a given node.
948
+ Find variables in the ancestors of a node that have an empty split range.
611
949
 
612
950
  Parameters
613
951
  ----------
614
- var_tree : int array (2 ** (d - 1),)
952
+ var_tree
615
953
  The variable indices of the tree.
616
- split_tree : int array (2 ** (d - 1),)
954
+ split_tree
617
955
  The splitting points of the tree.
618
- max_split : int array (p,)
956
+ max_split
619
957
  The maximum split index for each variable.
620
- leaf_index : int
958
+ leaf_index
621
959
  The index of the node, assumed to be valid for `var_tree`.
622
960
 
623
961
  Returns
624
962
  -------
625
- var_to_ignore : int array (d - 2,)
626
- The indices of the variables that have an empty split range. Since the
627
- number of such variables is not fixed, unused values in the array are
628
- filled with `p`. The fill values are not guaranteed to be placed in any
629
- particular order. Variables may appear more than once.
630
- """
963
+ The indices of the variables that have an empty split range.
631
964
 
965
+ Notes
966
+ -----
967
+ The number of unused variables is not known in advance. Unused values in the
968
+ array are filled with `p`. The fill values are not guaranteed to be placed
969
+ in any particular order, and variables may appear more than once.
970
+ """
632
971
  var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
633
972
  split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
634
973
  l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
635
974
  num_split = r - l
636
975
  return jnp.where(num_split == 0, var_to_ignore, max_split.size)
976
+ # the type of var_to_ignore is already sufficient to hold max_split.size,
977
+ # see ancestor_variables()
637
978
 
638
979
 
639
- def ancestor_variables(var_tree, max_split, node_index):
980
+ def ancestor_variables(
981
+ var_tree: UInt[Array, ' 2**(d-1)'],
982
+ max_split: UInt[Array, ' p'],
983
+ node_index: Int32[Array, ''],
984
+ ) -> UInt[Array, ' d-2']:
640
985
  """
641
986
  Return the list of variables in the ancestors of a node.
642
987
 
643
988
  Parameters
644
989
  ----------
645
- var_tree : int array (2 ** (d - 1),)
990
+ var_tree
646
991
  The variable indices of the tree.
647
- max_split : int array (p,)
992
+ max_split
648
993
  The maximum split index for each variable. Used only to get `p`.
649
- node_index : int
994
+ node_index
650
995
  The index of the node, assumed to be valid for `var_tree`.
651
996
 
652
997
  Returns
653
998
  -------
654
- ancestor_vars : int array (d - 2,)
655
- The variable indices of the ancestors of the node, from the root to
656
- the parent. Unused spots are filled with `p`.
999
+ The variable indices of the ancestors of the node.
1000
+
1001
+ Notes
1002
+ -----
1003
+ The ancestors are the nodes going from the root to the parent of the node.
1004
+ The number of ancestors is not known at tracing time; unused spots in the
1005
+ output array are filled with `p`.
657
1006
  """
658
1007
  max_num_ancestors = grove.tree_depth(var_tree) - 1
659
- ancestor_vars = jnp.zeros(
660
- max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size)
661
- )
1008
+ ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
662
1009
  carry = ancestor_vars.size - 1, node_index, ancestor_vars
663
1010
 
664
1011
  def loop(carry, _):
@@ -673,33 +1020,38 @@ def ancestor_variables(var_tree, max_split, node_index):
673
1020
  return ancestor_vars
674
1021
 
675
1022
 
676
- def split_range(var_tree, split_tree, max_split, node_index, ref_var):
1023
+ def split_range(
1024
+ var_tree: UInt[Array, ' 2**(d-1)'],
1025
+ split_tree: UInt[Array, ' 2**(d-1)'],
1026
+ max_split: UInt[Array, ' p'],
1027
+ node_index: Int32[Array, ''],
1028
+ ref_var: Int32[Array, ''],
1029
+ ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
677
1030
  """
678
1031
  Return the range of allowed splits for a variable at a given node.
679
1032
 
680
1033
  Parameters
681
1034
  ----------
682
- var_tree : int array (2 ** (d - 1),)
1035
+ var_tree
683
1036
  The variable indices of the tree.
684
- split_tree : int array (2 ** (d - 1),)
1037
+ split_tree
685
1038
  The splitting points of the tree.
686
- max_split : int array (p,)
1039
+ max_split
687
1040
  The maximum split index for each variable.
688
- node_index : int
1041
+ node_index
689
1042
  The index of the node, assumed to be valid for `var_tree`.
690
- ref_var : int
1043
+ ref_var
691
1044
  The variable for which to measure the split range.
692
1045
 
693
1046
  Returns
694
1047
  -------
695
- l, r : int
696
- The range of allowed splits is [l, r).
1048
+ The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
697
1049
  """
698
1050
  max_num_ancestors = grove.tree_depth(var_tree) - 1
699
1051
  initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
700
1052
  jnp.int32
701
1053
  )
702
- carry = 0, initial_r, node_index
1054
+ carry = jnp.int32(0), initial_r, node_index
703
1055
 
704
1056
  def loop(carry, _):
705
1057
  l, r, index = carry
@@ -715,259 +1067,501 @@ def split_range(var_tree, split_tree, max_split, node_index, ref_var):
715
1067
  return l + 1, r
716
1068
 
717
1069
 
718
- def randint_exclude(key, sup, exclude):
1070
+ def randint_exclude(
1071
+ key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
1072
+ ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
719
1073
  """
720
1074
  Return a random integer in a range, excluding some values.
721
1075
 
722
1076
  Parameters
723
1077
  ----------
724
- key : jax.dtypes.prng_key array
1078
+ key
725
1079
  A jax random key.
726
- sup : int
1080
+ sup
727
1081
  The exclusive upper bound of the range.
728
- exclude : int array (n,)
1082
+ exclude
729
1083
  The values to exclude from the range. Values greater than or equal to
730
1084
  `sup` are ignored. Values can appear more than once.
731
1085
 
732
1086
  Returns
733
1087
  -------
734
- u : int
735
- A random integer in the range ``[0, sup)``, and which satisfies
736
- ``u not in exclude``. If all values in the range are excluded, return
737
- `sup`.
1088
+ u : Int32[Array, '']
1089
+ A random integer `u` in the range ``[0, sup)`` such that ``u not in
1090
+ exclude``.
1091
+ num_allowed : Int32[Array, '']
1092
+ The number of integers in the range that were not excluded.
1093
+
1094
+ Notes
1095
+ -----
1096
+ If all values in the range are excluded, return `sup`.
738
1097
  """
739
- exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
740
- num_allowed = sup - jnp.count_nonzero(exclude < sup)
1098
+ exclude, num_allowed = _process_exclude(sup, exclude)
741
1099
  u = random.randint(key, (), 0, num_allowed)
742
1100
 
743
- def loop(u, i):
744
- return jnp.where(i <= u, u + 1, u), None
1101
+ def loop(u, i_excluded):
1102
+ return jnp.where(i_excluded <= u, u + 1, u), None
745
1103
 
746
1104
  u, _ = lax.scan(loop, u, exclude)
747
- return u
1105
+ return u, num_allowed
1106
+
1107
+
1108
+ def _process_exclude(sup, exclude):
1109
+ exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
1110
+ num_allowed = sup - jnp.count_nonzero(exclude < sup)
1111
+ return exclude, num_allowed
1112
+
1113
+
1114
+ def categorical_exclude(
1115
+ key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
1116
+ ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
1117
+ """
1118
+ Draw from a categorical distribution, excluding a set of values.
1119
+
1120
+ Parameters
1121
+ ----------
1122
+ key
1123
+ A jax random key.
1124
+ logits
1125
+ The unnormalized log-probabilities of each category.
1126
+ exclude
1127
+ The values to exclude from the range [0, k). Values greater than or
1128
+ equal to `logits.size` are ignored. Values can appear more than once.
1129
+
1130
+ Returns
1131
+ -------
1132
+ u : Int32[Array, '']
1133
+ A random integer in the range ``[0, k)`` such that ``u not in exclude``.
1134
+ num_allowed : Int32[Array, '']
1135
+ The number of integers in the range that were not excluded.
1136
+
1137
+ Notes
1138
+ -----
1139
+ If all values in the range are excluded, the result is unspecified.
1140
+ """
1141
+ exclude, num_allowed = _process_exclude(logits.size, exclude)
1142
+ kinda_neg_inf = jnp.finfo(logits.dtype).min
1143
+ logits = logits.at[exclude].set(kinda_neg_inf)
1144
+ u = random.categorical(key, logits)
1145
+ return u, num_allowed
748
1146
 
749
1147
 
750
- def choose_split(key, var_tree, split_tree, max_split, leaf_index):
1148
+ def choose_split(
1149
+ key: Key[Array, ''],
1150
+ var: Int32[Array, ''],
1151
+ var_tree: UInt[Array, ' 2**(d-1)'],
1152
+ split_tree: UInt[Array, ' 2**(d-1)'],
1153
+ max_split: UInt[Array, ' p'],
1154
+ leaf_index: Int32[Array, ''],
1155
+ ) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
751
1156
  """
752
1157
  Choose a split point for a new non-terminal node.
753
1158
 
754
1159
  Parameters
755
1160
  ----------
756
- var_tree : int array (2 ** (d - 1),)
757
- The variable indices of the tree.
758
- split_tree : int array (2 ** (d - 1),)
1161
+ key
1162
+ A jax random key.
1163
+ var
1164
+ The variable to split on.
1165
+ var_tree
1166
+ The splitting axes of the tree. Does not need to already contain `var`
1167
+ at `leaf_index`.
1168
+ split_tree
759
1169
  The splitting points of the tree.
760
- max_split : int array (p,)
1170
+ max_split
761
1171
  The maximum split index for each variable.
762
- leaf_index : int
763
- The index of the leaf to grow. It is assumed that `var_tree` already
764
- contains the target variable at this index.
765
- key : jax.dtypes.prng_key array
766
- A jax random key.
1172
+ leaf_index
1173
+ The index of the leaf to grow.
767
1174
 
768
1175
  Returns
769
1176
  -------
770
- split : int
771
- The split point.
1177
+ split : Int32[Array, '']
1178
+ The cutpoint.
1179
+ l : Int32[Array, '']
1180
+ r : Int32[Array, '']
1181
+ The integer range `split` was drawn from is [l, r).
1182
+
1183
+ Notes
1184
+ -----
1185
+ If `var` is out of bounds, or if the available split range on that variable
1186
+ is empty, return 0.
772
1187
  """
773
- var = var_tree[leaf_index]
774
1188
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
775
- return random.randint(key, (), l, r)
1189
+ return jnp.where(l < r, random.randint(key, (), l, r), 0), l, r
776
1190
 
777
1191
 
778
- def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
1192
+ def compute_partial_ratio(
1193
+ prob_choose: Float32[Array, ''],
1194
+ num_prunable: Int32[Array, ''],
1195
+ p_nonterminal: Float32[Array, ' 2**d'],
1196
+ leaf_to_grow: Int32[Array, ''],
1197
+ ) -> Float32[Array, '']:
779
1198
  """
780
1199
  Compute the product of the transition and prior ratios of a grow move.
781
1200
 
782
1201
  Parameters
783
1202
  ----------
784
- num_growable : int
785
- The number of leaf nodes that can be grown.
786
- num_prunable : int
1203
+ prob_choose
1204
+ The probability that the leaf had to be chosen amongst the growable
1205
+ leaves.
1206
+ num_prunable
787
1207
  The number of leaf parents that could be pruned, after converting the
788
1208
  leaf to be grown to a non-terminal node.
789
- p_nonterminal : array (d,)
790
- The probability of a nonterminal node at each depth.
791
- leaf_to_grow : int
1209
+ p_nonterminal
1210
+ The a priori probability of each node being nonterminal conditional on
1211
+ its ancestors.
1212
+ leaf_to_grow
792
1213
  The index of the leaf to grow.
793
1214
 
794
1215
  Returns
795
1216
  -------
796
- ratio : float
797
- The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
798
- times the prior ratio P(new tree) / P(old tree), but the transition
799
- ratio is missing the factor P(propose prune) in the numerator.
800
- """
1217
+ The partial transition ratio times the prior ratio.
801
1218
 
1219
+ Notes
1220
+ -----
1221
+ The transition ratio is P(new tree => old tree) / P(old tree => new tree).
1222
+ The "partial" transition ratio returned is missing the factor P(propose
1223
+ prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The
1224
+ "partial" prior ratio is missing the factor P(children are leaves).
1225
+ """
802
1226
  # the two ratios also contain factors num_available_split *
803
- # num_available_var, but they cancel out
1227
+ # num_available_var * s[var], but they cancel out
804
1228
 
805
- # p_prune can't be computed here because it needs the count trees, which are
806
- # computed in the acceptance phase
1229
+ # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be
1230
+ # computed here because they need the count trees, which are computed in the
1231
+ # acceptance phase
807
1232
 
808
1233
  prune_allowed = leaf_to_grow != 1
809
1234
  # prune allowed <---> the initial tree is not a root
810
1235
  # leaf to grow is root --> the tree can only be a root
811
1236
  # tree is a root --> the only leaf I can grow is root
812
-
813
1237
  p_grow = jnp.where(prune_allowed, 0.5, 1)
814
-
815
1238
  inv_trans_ratio = p_grow * prob_choose * num_prunable
816
1239
 
817
- depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
818
- p_parent = p_nonterminal[depth]
819
- cp_children = 1 - p_nonterminal[depth + 1]
820
- tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
1240
+ # .at.get because if leaf_to_grow is out of bounds (move not allowed), this
1241
+ # would produce a 0 and then an inf when `complete_ratio` takes the log
1242
+ pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
1243
+ tree_ratio = pnt / (1 - pnt)
1244
+
1245
+ return tree_ratio / jnp.where(inv_trans_ratio, inv_trans_ratio, 1)
821
1246
 
822
- return tree_ratio / inv_trans_ratio
823
1247
 
1248
+ class PruneMoves(Module):
1249
+ """
1250
+ Represent a proposed prune move for each tree.
824
1251
 
825
- def prune_move(
826
- key, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow
827
- ):
1252
+ Parameters
1253
+ ----------
1254
+ allowed
1255
+ Whether the move is possible.
1256
+ node
1257
+ The index of the node to prune. ``2 ** d`` if no node can be pruned.
1258
+ partial_ratio
1259
+ A factor of the Metropolis-Hastings ratio of the move. It lacks the
1260
+ likelihood ratio, the probability of proposing the prune move, and the
1261
+ prior probability that the children of the node to prune are leaves.
1262
+ This ratio is inverted, and is meant to be inverted back in
1263
+ `accept_move_and_sample_leaves`.
1264
+ """
1265
+
1266
+ allowed: Bool[Array, ' num_trees']
1267
+ node: UInt[Array, ' num_trees']
1268
+ partial_ratio: Float32[Array, ' num_trees']
1269
+ affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
1270
+
1271
+
1272
+ @partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
1273
+ def propose_prune_moves(
1274
+ key: Key[Array, ''],
1275
+ split_tree: UInt[Array, ' 2**(d-1)'],
1276
+ affluence_tree: Bool[Array, ' 2**(d-1)'],
1277
+ p_nonterminal: Float32[Array, ' 2**d'],
1278
+ p_propose_grow: Float32[Array, ' 2**(d-1)'],
1279
+ ) -> PruneMoves:
828
1280
  """
829
1281
  Tree structure prune move proposal of BART MCMC.
830
1282
 
831
1283
  Parameters
832
1284
  ----------
833
- var_tree : int array (2 ** (d - 1),)
834
- The variable indices of the tree.
835
- split_tree : int array (2 ** (d - 1),)
1285
+ key
1286
+ A jax random key.
1287
+ split_tree
836
1288
  The splitting points of the tree.
837
- affluence_tree : bool array (2 ** (d - 1),) or None
838
- Whether a leaf has enough points to be grown.
839
- max_split : int array (p,)
840
- The maximum split index for each variable.
841
- p_nonterminal : float array (d,)
842
- The probability of a nonterminal node at each depth.
843
- p_propose_grow : float array (2 ** (d - 1),)
1289
+ affluence_tree
1290
+ Whether each leaf can be grown.
1291
+ p_nonterminal
1292
+ The a priori probability of a node to be nonterminal conditional on
1293
+ the ancestors, including at the maximum depth where it should be zero.
1294
+ p_propose_grow
844
1295
  The unnormalized probability of choosing a leaf to grow.
845
- key : jax.dtypes.prng_key array
846
- A jax random key.
847
1296
 
848
1297
  Returns
849
1298
  -------
850
- prune_move : dict
851
- A dictionary with fields:
852
-
853
- 'allowed' : bool
854
- Whether the move is possible.
855
- 'node' : int
856
- The index of the node to prune. ``2 ** d`` if no node can be pruned.
857
- 'partial_ratio' : float
858
- A factor of the Metropolis-Hastings ratio of the move. It lacks
859
- the likelihood ratio and the probability of proposing the prune
860
- move. This ratio is inverted.
861
- """
862
- node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
1299
+ An object representing the proposed moves.
1300
+ """
1301
+ node_to_prune, num_prunable, prob_choose, affluence_tree = choose_leaf_parent(
863
1302
  key, split_tree, affluence_tree, p_propose_grow
864
1303
  )
865
- allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
866
1304
 
867
1305
  ratio = compute_partial_ratio(
868
1306
  prob_choose, num_prunable, p_nonterminal, node_to_prune
869
1307
  )
870
1308
 
871
- return dict(
872
- allowed=allowed,
1309
+ return PruneMoves(
1310
+ allowed=split_tree[1].astype(bool), # allowed iff the tree is not a root
873
1311
  node=node_to_prune,
874
- partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
1312
+ partial_ratio=ratio,
1313
+ affluence_tree=affluence_tree,
875
1314
  )
876
1315
 
877
1316
 
878
- def choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow):
1317
+ def choose_leaf_parent(
1318
+ key: Key[Array, ''],
1319
+ split_tree: UInt[Array, ' 2**(d-1)'],
1320
+ affluence_tree: Bool[Array, ' 2**(d-1)'],
1321
+ p_propose_grow: Float32[Array, ' 2**(d-1)'],
1322
+ ) -> tuple[
1323
+ Int32[Array, ''],
1324
+ Int32[Array, ''],
1325
+ Float32[Array, ''],
1326
+ Bool[Array, 'num_trees 2**(d-1)'],
1327
+ ]:
879
1328
  """
880
1329
  Pick a non-terminal node with leaf children to prune in a tree.
881
1330
 
882
1331
  Parameters
883
1332
  ----------
884
- split_tree : array (2 ** (d - 1),)
1333
+ key
1334
+ A jax random key.
1335
+ split_tree
885
1336
  The splitting points of the tree.
886
- affluence_tree : bool array (2 ** (d - 1),) or None
1337
+ affluence_tree
887
1338
  Whether a leaf has enough points to be grown.
888
- p_propose_grow : array (2 ** (d - 1),)
1339
+ p_propose_grow
889
1340
  The unnormalized probability of choosing a leaf to grow.
890
- key : jax.dtypes.prng_key array
891
- A jax random key.
892
1341
 
893
1342
  Returns
894
1343
  -------
895
- node_to_prune : int
1344
+ node_to_prune : Int32[Array, '']
896
1345
  The index of the node to prune. If ``num_prunable == 0``, return
897
1346
  ``2 ** d``.
898
- num_prunable : int
1347
+ num_prunable : Int32[Array, '']
899
1348
  The number of leaf parents that could be pruned.
900
- prob_choose : float
901
- The normalized probability of choosing the node to prune for growth.
902
- """
1349
+ prob_choose : Float32[Array, '']
1350
+ The (normalized) probability that `choose_leaf` would chose
1351
+ `node_to_prune` as leaf to grow, if passed the tree where
1352
+ `node_to_prune` had been pruned.
1353
+ affluence_tree : Bool[Array, 'num_trees 2**(d-1)']
1354
+ A partially updated `affluence_tree`, marking the node to prune as
1355
+ growable.
1356
+ """
1357
+ # sample a node to prune
903
1358
  is_prunable = grove.is_leaves_parent(split_tree)
904
1359
  num_prunable = jnp.count_nonzero(is_prunable)
905
1360
  node_to_prune = randint_masked(key, is_prunable)
906
1361
  node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
907
1362
 
1363
+ # compute stuff for reverse move
908
1364
  split_tree = split_tree.at[node_to_prune].set(0)
909
- affluence_tree = (
910
- None if affluence_tree is None else affluence_tree.at[node_to_prune].set(True)
911
- )
1365
+ affluence_tree = affluence_tree.at[node_to_prune].set(True)
912
1366
  is_growable_leaf = growable_leaves(split_tree, affluence_tree)
913
- prob_choose = p_propose_grow[node_to_prune]
914
- prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
1367
+ distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
1368
+ prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
1369
+ prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)
915
1370
 
916
- return node_to_prune, num_prunable, prob_choose
1371
+ return node_to_prune, num_prunable, prob_choose, affluence_tree
917
1372
 
918
1373
 
919
- def randint_masked(key, mask):
1374
+ def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
920
1375
  """
921
1376
  Return a random integer in a range, including only some values.
922
1377
 
923
1378
  Parameters
924
1379
  ----------
925
- key : jax.dtypes.prng_key array
1380
+ key
926
1381
  A jax random key.
927
- mask : bool array (n,)
1382
+ mask
928
1383
  The mask indicating the allowed values.
929
1384
 
930
1385
  Returns
931
1386
  -------
932
- u : int
933
- A random integer in the range ``[0, n)``, and which satisfies
934
- ``mask[u] == True``. If all values in the mask are `False`, return `n`.
1387
+ A random integer in the range ``[0, n)`` such that ``mask[u] == True``.
1388
+
1389
+ Notes
1390
+ -----
1391
+ If all values in the mask are `False`, return `n`.
935
1392
  """
936
1393
  ecdf = jnp.cumsum(mask)
937
1394
  u = random.randint(key, (), 0, ecdf[-1])
938
1395
  return jnp.searchsorted(ecdf, u, 'right')
939
1396
 
940
1397
 
941
- def accept_moves_and_sample_leaves(key, bart, moves):
1398
+ def accept_moves_and_sample_leaves(
1399
+ key: Key[Array, ''], bart: State, moves: Moves
1400
+ ) -> State:
942
1401
  """
943
1402
  Accept or reject the proposed moves and sample the new leaf values.
944
1403
 
945
1404
  Parameters
946
1405
  ----------
947
- key : jax.dtypes.prng_key array
1406
+ key
948
1407
  A jax random key.
949
- bart : dict
950
- A BART mcmc state.
951
- moves : dict
952
- The proposed moves, see `sample_moves`.
1408
+ bart
1409
+ A valid BART mcmc state.
1410
+ moves
1411
+ The proposed moves, see `propose_moves`.
953
1412
 
954
1413
  Returns
955
1414
  -------
956
- bart : dict
957
- The new BART mcmc state.
1415
+ A new (valid) BART mcmc state.
958
1416
  """
959
- bart, moves, prec_trees, move_counts, move_precs, prelkv, prelk, prelf = (
960
- accept_moves_parallel_stage(key, bart, moves)
961
- )
962
- bart, moves = accept_moves_sequential_stage(
963
- bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
964
- )
1417
+ pso = accept_moves_parallel_stage(key, bart, moves)
1418
+ bart, moves = accept_moves_sequential_stage(pso)
965
1419
  return accept_moves_final_stage(bart, moves)
966
1420
 
967
1421
 
968
- def accept_moves_parallel_stage(key, bart, moves):
1422
+ class Counts(Module):
1423
+ """
1424
+ Number of datapoints in the nodes involved in proposed moves for each tree.
1425
+
1426
+ Parameters
1427
+ ----------
1428
+ left
1429
+ Number of datapoints in the left child.
1430
+ right
1431
+ Number of datapoints in the right child.
1432
+ total
1433
+ Number of datapoints in the parent (``= left + right``).
1434
+ """
1435
+
1436
+ left: UInt[Array, ' num_trees']
1437
+ right: UInt[Array, ' num_trees']
1438
+ total: UInt[Array, ' num_trees']
1439
+
1440
+
1441
+ class Precs(Module):
969
1442
  """
970
- Pre-computes quantities used to accept moves, in parallel across trees.
1443
+ Likelihood precision scale in the nodes involved in proposed moves for each tree.
1444
+
1445
+ The "likelihood precision scale" of a tree node is the sum of the inverse
1446
+ squared error scales of the datapoints selected by the node.
1447
+
1448
+ Parameters
1449
+ ----------
1450
+ left
1451
+ Likelihood precision scale in the left child.
1452
+ right
1453
+ Likelihood precision scale in the right child.
1454
+ total
1455
+ Likelihood precision scale in the parent (``= left + right``).
1456
+ """
1457
+
1458
+ left: Float32[Array, ' num_trees']
1459
+ right: Float32[Array, ' num_trees']
1460
+ total: Float32[Array, ' num_trees']
1461
+
1462
+
1463
+ class PreLkV(Module):
1464
+ """
1465
+ Non-sequential terms of the likelihood ratio for each tree.
1466
+
1467
+ These terms can be computed in parallel across trees.
1468
+
1469
+ Parameters
1470
+ ----------
1471
+ sigma2_left
1472
+ The noise variance in the left child of the leaves grown or pruned by
1473
+ the moves.
1474
+ sigma2_right
1475
+ The noise variance in the right child of the leaves grown or pruned by
1476
+ the moves.
1477
+ sigma2_total
1478
+ The noise variance in the total of the leaves grown or pruned by the
1479
+ moves.
1480
+ sqrt_term
1481
+ The **logarithm** of the square root term of the likelihood ratio.
1482
+ """
1483
+
1484
+ sigma2_left: Float32[Array, ' num_trees']
1485
+ sigma2_right: Float32[Array, ' num_trees']
1486
+ sigma2_total: Float32[Array, ' num_trees']
1487
+ sqrt_term: Float32[Array, ' num_trees']
1488
+
1489
+
1490
+ class PreLk(Module):
1491
+ """
1492
+ Non-sequential terms of the likelihood ratio shared by all trees.
1493
+
1494
+ Parameters
1495
+ ----------
1496
+ exp_factor
1497
+ The factor to multiply the likelihood ratio by, shared by all trees.
1498
+ """
1499
+
1500
+ exp_factor: Float32[Array, '']
1501
+
1502
+
1503
+ class PreLf(Module):
1504
+ """
1505
+ Pre-computed terms used to sample leaves from their posterior.
1506
+
1507
+ These terms can be computed in parallel across trees.
1508
+
1509
+ Parameters
1510
+ ----------
1511
+ mean_factor
1512
+ The factor to be multiplied by the sum of the scaled residuals to
1513
+ obtain the posterior mean.
1514
+ centered_leaves
1515
+ The mean-zero normal values to be added to the posterior mean to
1516
+ obtain the posterior leaf samples.
1517
+ """
1518
+
1519
+ mean_factor: Float32[Array, 'num_trees 2**d']
1520
+ centered_leaves: Float32[Array, 'num_trees 2**d']
1521
+
1522
+
1523
+ class ParallelStageOut(Module):
1524
+ """
1525
+ The output of `accept_moves_parallel_stage`.
1526
+
1527
+ Parameters
1528
+ ----------
1529
+ bart
1530
+ A partially updated BART mcmc state.
1531
+ moves
1532
+ The proposed moves, with `partial_ratio` set to `None` and
1533
+ `log_trans_prior_ratio` set to its final value.
1534
+ prec_trees
1535
+ The likelihood precision scale in each potential or actual leaf node. If
1536
+ there is no precision scale, this is the number of points in each leaf.
1537
+ move_counts
1538
+ The counts of the number of points in the the nodes modified by the
1539
+ moves. If `bart.min_points_per_leaf` is not set and
1540
+ `bart.prec_scale` is set, they are not computed.
1541
+ move_precs
1542
+ The likelihood precision scale in each node modified by the moves. If
1543
+ `bart.prec_scale` is not set, this is set to `move_counts`.
1544
+ prelkv
1545
+ prelk
1546
+ prelf
1547
+ Objects with pre-computed terms of the likelihood ratios and leaf
1548
+ samples.
1549
+ """
1550
+
1551
+ bart: State
1552
+ moves: Moves
1553
+ prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
1554
+ move_precs: Precs | Counts
1555
+ prelkv: PreLkV
1556
+ prelk: PreLk
1557
+ prelf: PreLf
1558
+
1559
+
1560
+ def accept_moves_parallel_stage(
1561
+ key: Key[Array, ''], bart: State, moves: Moves
1562
+ ) -> ParallelStageOut:
1563
+ """
1564
+ Pre-compute quantities used to accept moves, in parallel across trees.
971
1565
 
972
1566
  Parameters
973
1567
  ----------
@@ -976,153 +1570,186 @@ def accept_moves_parallel_stage(key, bart, moves):
976
1570
  bart : dict
977
1571
  A BART mcmc state.
978
1572
  moves : dict
979
- The proposed moves, see `sample_moves`.
1573
+ The proposed moves, see `propose_moves`.
980
1574
 
981
1575
  Returns
982
1576
  -------
983
- bart : dict
984
- A partially updated BART mcmc state.
985
- moves : dict
986
- The proposed moves, with the field 'partial_ratio' replaced
987
- by 'log_trans_prior_ratio'.
988
- prec_trees : float array (num_trees, 2 ** d)
989
- The likelihood precision scale in each potential or actual leaf node. If
990
- there is no precision scale, this is the number of points in each leaf.
991
- move_counts : dict
992
- The counts of the number of points in the the nodes modified by the
993
- moves.
994
- move_precs : dict
995
- The likelihood precision scale in each node modified by the moves.
996
- prelkv, prelk, prelf : dict
997
- Dictionary with pre-computed terms of the likelihood ratios and leaf
998
- samples.
1577
+ An object with all that could be done in parallel.
999
1578
  """
1000
- bart = bart.copy()
1001
-
1002
1579
  # where the move is grow, modify the state like the move was accepted
1003
- bart['var_trees'] = moves['var_trees']
1004
- bart['leaf_indices'] = apply_grow_to_indices(moves, bart['leaf_indices'], bart['X'])
1005
- bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], moves)
1580
+ bart = replace(
1581
+ bart,
1582
+ forest=replace(
1583
+ bart.forest,
1584
+ var_tree=moves.var_tree,
1585
+ leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
1586
+ leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
1587
+ ),
1588
+ )
1006
1589
 
1007
1590
  # count number of datapoints per leaf
1008
- count_trees, move_counts = compute_count_trees(
1009
- bart['leaf_indices'], moves, bart['opt']['count_batch_size']
1010
- )
1011
- if bart['opt']['require_min_points']:
1012
- count_half_trees = count_trees[:, : bart['var_trees'].shape[1]]
1013
- bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
1591
+ if (
1592
+ bart.forest.min_points_per_decision_node is not None
1593
+ or bart.forest.min_points_per_leaf is not None
1594
+ or bart.prec_scale is None
1595
+ ):
1596
+ count_trees, move_counts = compute_count_trees(
1597
+ bart.forest.leaf_indices, moves, bart.forest.count_batch_size
1598
+ )
1599
+
1600
+ # mark which leaves & potential leaves have enough points to be grown
1601
+ if bart.forest.min_points_per_decision_node is not None:
1602
+ count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
1603
+ moves = replace(
1604
+ moves,
1605
+ affluence_tree=moves.affluence_tree
1606
+ & (count_half_trees >= bart.forest.min_points_per_decision_node),
1607
+ )
1608
+
1609
+ # copy updated affluence_tree to state
1610
+ bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)
1611
+
1612
+ # veto grove move if new leaves don't have enough datapoints
1613
+ if bart.forest.min_points_per_leaf is not None:
1614
+ moves = replace(
1615
+ moves,
1616
+ allowed=moves.allowed
1617
+ & (move_counts.left >= bart.forest.min_points_per_leaf)
1618
+ & (move_counts.right >= bart.forest.min_points_per_leaf),
1619
+ )
1014
1620
 
1015
1621
  # count number of datapoints per leaf, weighted by error precision scale
1016
- if bart['prec_scale'] is None:
1622
+ if bart.prec_scale is None:
1017
1623
  prec_trees = count_trees
1018
1624
  move_precs = move_counts
1019
1625
  else:
1020
1626
  prec_trees, move_precs = compute_prec_trees(
1021
- bart['prec_scale'],
1022
- bart['leaf_indices'],
1627
+ bart.prec_scale,
1628
+ bart.forest.leaf_indices,
1023
1629
  moves,
1024
- bart['opt']['count_batch_size'],
1630
+ bart.forest.count_batch_size,
1025
1631
  )
1632
+ assert move_precs is not None
1026
1633
 
1027
1634
  # compute some missing information about moves
1028
- moves = complete_ratio(moves, move_counts, bart['min_points_per_leaf'])
1029
- bart['grow_prop_count'] = jnp.sum(moves['grow'])
1030
- bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~moves['grow'])
1031
-
1032
- prelkv, prelk = precompute_likelihood_terms(bart['sigma2'], move_precs)
1033
- prelf = precompute_leaf_terms(key, prec_trees, bart['sigma2'])
1635
+ moves = complete_ratio(moves, bart.forest.p_nonterminal)
1636
+ save_ratios = bart.forest.log_likelihood is not None
1637
+ bart = replace(
1638
+ bart,
1639
+ forest=replace(
1640
+ bart.forest,
1641
+ grow_prop_count=jnp.sum(moves.grow),
1642
+ prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
1643
+ log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
1644
+ ),
1645
+ )
1034
1646
 
1035
- return bart, moves, prec_trees, move_counts, move_precs, prelkv, prelk, prelf
1647
+ # pre-compute some likelihood ratio & posterior terms
1648
+ assert bart.sigma2 is not None # `step` shall temporarily set it to 1
1649
+ prelkv, prelk = precompute_likelihood_terms(
1650
+ bart.sigma2, bart.forest.sigma_mu2, move_precs
1651
+ )
1652
+ prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)
1653
+
1654
+ return ParallelStageOut(
1655
+ bart=bart,
1656
+ moves=moves,
1657
+ prec_trees=prec_trees,
1658
+ move_precs=move_precs,
1659
+ prelkv=prelkv,
1660
+ prelk=prelk,
1661
+ prelf=prelf,
1662
+ )
1036
1663
 
1037
1664
 
1038
- @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
1039
- def apply_grow_to_indices(moves, leaf_indices, X):
1665
+ @partial(vmap_nodoc, in_axes=(0, 0, None))
1666
+ def apply_grow_to_indices(
1667
+ moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
1668
+ ) -> UInt[Array, 'num_trees n']:
1040
1669
  """
1041
1670
  Update the leaf indices to apply a grow move.
1042
1671
 
1043
1672
  Parameters
1044
1673
  ----------
1045
- moves : dict
1046
- The proposed moves, see `sample_moves`.
1047
- leaf_indices : array (num_trees, n)
1674
+ moves
1675
+ The proposed moves, see `propose_moves`.
1676
+ leaf_indices
1048
1677
  The index of the leaf each datapoint falls into.
1049
- X : array (p, n)
1678
+ X
1050
1679
  The predictors matrix.
1051
1680
 
1052
1681
  Returns
1053
1682
  -------
1054
- grow_leaf_indices : array (num_trees, n)
1055
- The updated leaf indices.
1683
+ The updated leaf indices.
1056
1684
  """
1057
- left_child = moves['node'].astype(leaf_indices.dtype) << 1
1058
- go_right = X[moves['grow_var'], :] >= moves['grow_split']
1059
- tree_size = jnp.array(2 * moves['var_trees'].size)
1060
- node_to_update = jnp.where(moves['grow'], moves['node'], tree_size)
1685
+ left_child = moves.node.astype(leaf_indices.dtype) << 1
1686
+ go_right = X[moves.grow_var, :] >= moves.grow_split
1687
+ tree_size = jnp.array(2 * moves.var_tree.size)
1688
+ node_to_update = jnp.where(moves.grow, moves.node, tree_size)
1061
1689
  return jnp.where(
1062
- leaf_indices == node_to_update,
1063
- left_child + go_right,
1064
- leaf_indices,
1690
+ leaf_indices == node_to_update, left_child + go_right, leaf_indices
1065
1691
  )
1066
1692
 
1067
1693
 
1068
- def compute_count_trees(leaf_indices, moves, batch_size):
1694
+ def compute_count_trees(
1695
+ leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
1696
+ ) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
1069
1697
  """
1070
1698
  Count the number of datapoints in each leaf.
1071
1699
 
1072
1700
  Parameters
1073
1701
  ----------
1074
- leaf_indices : int array (num_trees, n)
1702
+ leaf_indices
1075
1703
  The index of the leaf each datapoint falls into, with the deeper version
1076
1704
  of the tree (post-GROW, pre-PRUNE).
1077
- moves : dict
1078
- The proposed moves, see `sample_moves`.
1079
- batch_size : int or None
1705
+ moves
1706
+ The proposed moves, see `propose_moves`.
1707
+ batch_size
1080
1708
  The data batch size to use for the summation.
1081
1709
 
1082
1710
  Returns
1083
1711
  -------
1084
- count_trees : int array (num_trees, 2 ** (d - 1))
1712
+ count_trees : Int32[Array, 'num_trees 2**d']
1085
1713
  The number of points in each potential or actual leaf node.
1086
- counts : dict
1714
+ counts : Counts
1087
1715
  The counts of the number of points in the leaves grown or pruned by the
1088
- moves, under keys 'left', 'right', and 'total' (left + right).
1716
+ moves.
1089
1717
  """
1090
-
1091
- ntree, tree_size = moves['var_trees'].shape
1718
+ num_trees, tree_size = moves.var_tree.shape
1092
1719
  tree_size *= 2
1093
- tree_indices = jnp.arange(ntree)
1720
+ tree_indices = jnp.arange(num_trees)
1094
1721
 
1095
1722
  count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
1096
1723
 
1097
1724
  # count datapoints in nodes modified by move
1098
- counts = dict()
1099
- counts['left'] = count_trees[tree_indices, moves['left']]
1100
- counts['right'] = count_trees[tree_indices, moves['right']]
1101
- counts['total'] = counts['left'] + counts['right']
1725
+ left = count_trees[tree_indices, moves.left]
1726
+ right = count_trees[tree_indices, moves.right]
1727
+ counts = Counts(left=left, right=right, total=left + right)
1102
1728
 
1103
1729
  # write count into non-leaf node
1104
- count_trees = count_trees.at[tree_indices, moves['node']].set(counts['total'])
1730
+ count_trees = count_trees.at[tree_indices, moves.node].set(counts.total)
1105
1731
 
1106
1732
  return count_trees, counts
1107
1733
 
1108
1734
 
1109
- def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
1735
+ def count_datapoints_per_leaf(
1736
+ leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
1737
+ ) -> Int32[Array, 'num_trees 2**(d-1)']:
1110
1738
  """
1111
1739
  Count the number of datapoints in each leaf.
1112
1740
 
1113
1741
  Parameters
1114
1742
  ----------
1115
- leaf_indices : int array (num_trees, n)
1743
+ leaf_indices
1116
1744
  The index of the leaf each datapoint falls into.
1117
- tree_size : int
1745
+ tree_size
1118
1746
  The size of the leaf tree array (2 ** d).
1119
- batch_size : int or None
1747
+ batch_size
1120
1748
  The data batch size to use for the summation.
1121
1749
 
1122
1750
  Returns
1123
1751
  -------
1124
- count_trees : int array (num_trees, 2 ** (d - 1))
1125
- The number of points in each leaf node.
1752
+ The number of points in each leaf node.
1126
1753
  """
1127
1754
  if batch_size is None:
1128
1755
  return _count_scan(leaf_indices, tree_size)
@@ -1130,7 +1757,9 @@ def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
1130
1757
  return _count_vec(leaf_indices, tree_size, batch_size)
1131
1758
 
1132
1759
 
1133
- def _count_scan(leaf_indices, tree_size):
1760
+ def _count_scan(
1761
+ leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
1762
+ ) -> Int32[Array, 'num_trees {tree_size}']:
1134
1763
  def loop(_, leaf_indices):
1135
1764
  return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
1136
1765
 
@@ -1138,92 +1767,111 @@ def _count_scan(leaf_indices, tree_size):
1138
1767
  return count_trees
1139
1768
 
1140
1769
 
1141
- def _aggregate_scatter(values, indices, size, dtype):
1770
+ def _aggregate_scatter(
1771
+ values: Shaped[Array, '*'],
1772
+ indices: Integer[Array, '*'],
1773
+ size: int,
1774
+ dtype: jnp.dtype,
1775
+ ) -> Shaped[Array, ' {size}']:
1142
1776
  return jnp.zeros(size, dtype).at[indices].add(values)
1143
1777
 
1144
1778
 
1145
- def _count_vec(leaf_indices, tree_size, batch_size):
1779
+ def _count_vec(
1780
+ leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
1781
+ ) -> Int32[Array, 'num_trees 2**(d-1)']:
1146
1782
  return _aggregate_batched_alltrees(
1147
1783
  1, leaf_indices, tree_size, jnp.uint32, batch_size
1148
1784
  )
1149
1785
  # uint16 is super-slow on gpu, don't use it even if n < 2^16
1150
1786
 
1151
1787
 
1152
- def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
1153
- ntree, n = indices.shape
1154
- tree_indices = jnp.arange(ntree)
1788
+ def _aggregate_batched_alltrees(
1789
+ values: Shaped[Array, '*'],
1790
+ indices: UInt[Array, 'num_trees n'],
1791
+ size: int,
1792
+ dtype: jnp.dtype,
1793
+ batch_size: int,
1794
+ ) -> Shaped[Array, 'num_trees {size}']:
1795
+ num_trees, n = indices.shape
1796
+ tree_indices = jnp.arange(num_trees)
1155
1797
  nbatches = n // batch_size + bool(n % batch_size)
1156
1798
  batch_indices = jnp.arange(n) % nbatches
1157
1799
  return (
1158
- jnp.zeros((ntree, size, nbatches), dtype)
1800
+ jnp.zeros((num_trees, size, nbatches), dtype)
1159
1801
  .at[tree_indices[:, None], indices, batch_indices]
1160
1802
  .add(values)
1161
1803
  .sum(axis=2)
1162
1804
  )
1163
1805
 
1164
1806
 
1165
- def compute_prec_trees(prec_scale, leaf_indices, moves, batch_size):
1807
+ def compute_prec_trees(
1808
+ prec_scale: Float32[Array, ' n'],
1809
+ leaf_indices: UInt[Array, 'num_trees n'],
1810
+ moves: Moves,
1811
+ batch_size: int | None,
1812
+ ) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
1166
1813
  """
1167
1814
  Compute the likelihood precision scale in each leaf.
1168
1815
 
1169
1816
  Parameters
1170
1817
  ----------
1171
- prec_scale : float array (n,)
1818
+ prec_scale
1172
1819
  The scale of the precision of the error on each datapoint.
1173
- leaf_indices : int array (num_trees, n)
1820
+ leaf_indices
1174
1821
  The index of the leaf each datapoint falls into, with the deeper version
1175
1822
  of the tree (post-GROW, pre-PRUNE).
1176
- moves : dict
1177
- The proposed moves, see `sample_moves`.
1178
- batch_size : int or None
1823
+ moves
1824
+ The proposed moves, see `propose_moves`.
1825
+ batch_size
1179
1826
  The data batch size to use for the summation.
1180
1827
 
1181
1828
  Returns
1182
1829
  -------
1183
- prec_trees : float array (num_trees, 2 ** (d - 1))
1830
+ prec_trees : Float32[Array, 'num_trees 2**d']
1184
1831
  The likelihood precision scale in each potential or actual leaf node.
1185
- counts : dict
1186
- The likelihood precision scale in the leaves grown or pruned by the
1187
- moves, under keys 'left', 'right', and 'total' (left + right).
1832
+ precs : Precs
1833
+ The likelihood precision scale in the nodes involved in the moves.
1188
1834
  """
1189
-
1190
- ntree, tree_size = moves['var_trees'].shape
1835
+ num_trees, tree_size = moves.var_tree.shape
1191
1836
  tree_size *= 2
1192
- tree_indices = jnp.arange(ntree)
1837
+ tree_indices = jnp.arange(num_trees)
1193
1838
 
1194
1839
  prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)
1195
1840
 
1196
1841
  # prec datapoints in nodes modified by move
1197
- precs = dict()
1198
- precs['left'] = prec_trees[tree_indices, moves['left']]
1199
- precs['right'] = prec_trees[tree_indices, moves['right']]
1200
- precs['total'] = precs['left'] + precs['right']
1842
+ left = prec_trees[tree_indices, moves.left]
1843
+ right = prec_trees[tree_indices, moves.right]
1844
+ precs = Precs(left=left, right=right, total=left + right)
1201
1845
 
1202
1846
  # write prec into non-leaf node
1203
- prec_trees = prec_trees.at[tree_indices, moves['node']].set(precs['total'])
1847
+ prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total)
1204
1848
 
1205
1849
  return prec_trees, precs
1206
1850
 
1207
1851
 
1208
- def prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size):
1852
+ def prec_per_leaf(
1853
+ prec_scale: Float32[Array, ' n'],
1854
+ leaf_indices: UInt[Array, 'num_trees n'],
1855
+ tree_size: int,
1856
+ batch_size: int | None,
1857
+ ) -> Float32[Array, 'num_trees {tree_size}']:
1209
1858
  """
1210
1859
  Compute the likelihood precision scale in each leaf.
1211
1860
 
1212
1861
  Parameters
1213
1862
  ----------
1214
- prec_scale : float array (n,)
1863
+ prec_scale
1215
1864
  The scale of the precision of the error on each datapoint.
1216
- leaf_indices : int array (num_trees, n)
1865
+ leaf_indices
1217
1866
  The index of the leaf each datapoint falls into.
1218
- tree_size : int
1867
+ tree_size
1219
1868
  The size of the leaf tree array (2 ** d).
1220
- batch_size : int or None
1869
+ batch_size
1221
1870
  The data batch size to use for the summation.
1222
1871
 
1223
1872
  Returns
1224
1873
  -------
1225
- prec_trees : int array (num_trees, 2 ** (d - 1))
1226
- The likelihood precision scale in each leaf node.
1874
+ The likelihood precision scale in each leaf node.
1227
1875
  """
1228
1876
  if batch_size is None:
1229
1877
  return _prec_scan(prec_scale, leaf_indices, tree_size)
@@ -1231,432 +1879,439 @@ def prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size):
1231
1879
  return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
1232
1880
 
1233
1881
 
1234
- def _prec_scan(prec_scale, leaf_indices, tree_size):
1882
+ def _prec_scan(
1883
+ prec_scale: Float32[Array, ' n'],
1884
+ leaf_indices: UInt[Array, 'num_trees n'],
1885
+ tree_size: int,
1886
+ ) -> Float32[Array, 'num_trees {tree_size}']:
1235
1887
  def loop(_, leaf_indices):
1236
1888
  return None, _aggregate_scatter(
1237
1889
  prec_scale, leaf_indices, tree_size, jnp.float32
1238
- ) # TODO: use large_float
1890
+ )
1239
1891
 
1240
1892
  _, prec_trees = lax.scan(loop, None, leaf_indices)
1241
1893
  return prec_trees
1242
1894
 
1243
1895
 
1244
- def _prec_vec(prec_scale, leaf_indices, tree_size, batch_size):
1896
+ def _prec_vec(
1897
+ prec_scale: Float32[Array, ' n'],
1898
+ leaf_indices: UInt[Array, 'num_trees n'],
1899
+ tree_size: int,
1900
+ batch_size: int,
1901
+ ) -> Float32[Array, 'num_trees {tree_size}']:
1245
1902
  return _aggregate_batched_alltrees(
1246
1903
  prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
1247
- ) # TODO: use large_float
1904
+ )
1248
1905
 
1249
1906
 
1250
- def complete_ratio(moves, move_counts, min_points_per_leaf):
1907
+ def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
1251
1908
  """
1252
1909
  Complete non-likelihood MH ratio calculation.
1253
1910
 
1254
- This function adds the probability of choosing the prune move.
1911
+ This function adds the probability of choosing a prune move over the grow
1912
+ move in the inverse transition, and the a priori probability that the
1913
+ children nodes are leaves.
1255
1914
 
1256
1915
  Parameters
1257
1916
  ----------
1258
- moves : dict
1259
- The proposed moves, see `sample_moves`.
1260
- move_counts : dict
1261
- The counts of the number of points in the the nodes modified by the
1262
- moves.
1263
- min_points_per_leaf : int or None
1264
- The minimum number of data points in a leaf node.
1917
+ moves
1918
+ The proposed moves. Must have already been updated to keep into account
1919
+ the thresholds on the number of datapoints per node, this happens in
1920
+ `accept_moves_parallel_stage`.
1921
+ p_nonterminal
1922
+ The a priori probability of each node being nonterminal conditional on
1923
+ its ancestors, including at the maximum depth where it should be zero.
1265
1924
 
1266
1925
  Returns
1267
1926
  -------
1268
- moves : dict
1269
- The updated moves, with the field 'partial_ratio' replaced by
1270
- 'log_trans_prior_ratio'.
1927
+ The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
1271
1928
  """
1272
- moves = moves.copy()
1273
- p_prune = compute_p_prune(
1274
- moves, move_counts['left'], move_counts['right'], min_points_per_leaf
1929
+ # can the leaves can be grown?
1930
+ num_trees, _ = moves.affluence_tree.shape
1931
+ tree_indices = jnp.arange(num_trees)
1932
+ left_growable = moves.affluence_tree.at[tree_indices, moves.left].get(
1933
+ mode='fill', fill_value=False
1934
+ )
1935
+ right_growable = moves.affluence_tree.at[tree_indices, moves.right].get(
1936
+ mode='fill', fill_value=False
1275
1937
  )
1276
- moves['log_trans_prior_ratio'] = jnp.log(moves.pop('partial_ratio') * p_prune)
1277
- return moves
1278
-
1279
1938
 
1280
- def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
1281
- """
1282
- Compute the probability of proposing a prune move.
1939
+ # p_prune if grow
1940
+ other_growable_leaves = moves.num_growable >= 2
1941
+ grow_again_allowed = other_growable_leaves | left_growable | right_growable
1942
+ grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
1283
1943
 
1284
- Parameters
1285
- ----------
1286
- moves : dict
1287
- The proposed moves, see `sample_moves`.
1288
- left_count, right_count : int
1289
- The number of datapoints in the proposed children of the leaf to grow.
1290
- min_points_per_leaf : int or None
1291
- The minimum number of data points in a leaf node.
1944
+ # p_prune if prune
1945
+ prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)
1292
1946
 
1293
- Returns
1294
- -------
1295
- p_prune : float
1296
- The probability of proposing a prune move. If grow: after accepting the
1297
- grow move, if prune: right away.
1298
- """
1299
-
1300
- # calculation in case the move is grow
1301
- other_growable_leaves = moves['num_growable'] >= 2
1302
- new_leaves_growable = moves['node'] < moves['var_trees'].shape[1] // 2
1303
- if min_points_per_leaf is not None:
1304
- any_above_threshold = left_count >= 2 * min_points_per_leaf
1305
- any_above_threshold |= right_count >= 2 * min_points_per_leaf
1306
- new_leaves_growable &= any_above_threshold
1307
- grow_again_allowed = other_growable_leaves | new_leaves_growable
1308
- grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
1947
+ # select p_prune
1948
+ p_prune = jnp.where(moves.grow, grow_p_prune, prune_p_prune)
1309
1949
 
1310
- # calculation in case the move is prune
1311
- prune_p_prune = jnp.where(moves['num_growable'], 0.5, 1)
1950
+ # prior probability of both children being terminal
1951
+ pt_left = 1 - p_nonterminal[moves.left] * left_growable
1952
+ pt_right = 1 - p_nonterminal[moves.right] * right_growable
1953
+ pt_children = pt_left * pt_right
1312
1954
 
1313
- return jnp.where(moves['grow'], grow_p_prune, prune_p_prune)
1955
+ return replace(
1956
+ moves,
1957
+ log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune),
1958
+ partial_ratio=None,
1959
+ )
1314
1960
 
1315
1961
 
1316
- @jaxext.vmap_nodoc
1317
- def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
1962
+ @vmap_nodoc
1963
+ def adapt_leaf_trees_to_grow_indices(
1964
+ leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
1965
+ ) -> Float32[Array, 'num_trees 2**d']:
1318
1966
  """
1319
- Modify leaf values such that the indices of the grow moves work on the
1320
- original tree.
1967
+ Modify leaves such that post-grow indices work on the original tree.
1968
+
1969
+ The value of the leaf to grow is copied to what would be its children if the
1970
+ grow move was accepted.
1321
1971
 
1322
1972
  Parameters
1323
1973
  ----------
1324
- leaf_trees : float array (num_trees, 2 ** d)
1974
+ leaf_trees
1325
1975
  The leaf values.
1326
- moves : dict
1327
- The proposed moves, see `sample_moves`.
1976
+ moves
1977
+ The proposed moves, see `propose_moves`.
1328
1978
 
1329
1979
  Returns
1330
1980
  -------
1331
- leaf_trees : float array (num_trees, 2 ** d)
1332
- The modified leaf values. The value of the leaf to grow is copied to
1333
- what would be its children if the grow move was accepted.
1981
+ The modified leaf values.
1334
1982
  """
1335
- values_at_node = leaf_trees[moves['node']]
1983
+ values_at_node = leaf_trees[moves.node]
1336
1984
  return (
1337
- leaf_trees.at[jnp.where(moves['grow'], moves['left'], leaf_trees.size)]
1985
+ leaf_trees.at[jnp.where(moves.grow, moves.left, leaf_trees.size)]
1338
1986
  .set(values_at_node)
1339
- .at[jnp.where(moves['grow'], moves['right'], leaf_trees.size)]
1987
+ .at[jnp.where(moves.grow, moves.right, leaf_trees.size)]
1340
1988
  .set(values_at_node)
1341
1989
  )
1342
1990
 
1343
1991
 
1344
- def precompute_likelihood_terms(sigma2, move_precs):
1992
+ def precompute_likelihood_terms(
1993
+ sigma2: Float32[Array, ''],
1994
+ sigma_mu2: Float32[Array, ''],
1995
+ move_precs: Precs | Counts,
1996
+ ) -> tuple[PreLkV, PreLk]:
1345
1997
  """
1346
1998
  Pre-compute terms used in the likelihood ratio of the acceptance step.
1347
1999
 
1348
2000
  Parameters
1349
2001
  ----------
1350
- sigma2 : float
1351
- The noise variance.
1352
- move_precs : dict
2002
+ sigma2
2003
+ The error variance, or the global error variance factor is `prec_scale`
2004
+ is set.
2005
+ sigma_mu2
2006
+ The prior variance of each leaf.
2007
+ move_precs
1353
2008
  The likelihood precision scale in the leaves grown or pruned by the
1354
2009
  moves, under keys 'left', 'right', and 'total' (left + right).
1355
2010
 
1356
2011
  Returns
1357
2012
  -------
1358
- prelkv : dict
2013
+ prelkv : PreLkV
1359
2014
  Dictionary with pre-computed terms of the likelihood ratio, one per
1360
2015
  tree.
1361
- prelk : dict
2016
+ prelk : PreLk
1362
2017
  Dictionary with pre-computed terms of the likelihood ratio, shared by
1363
2018
  all trees.
1364
2019
  """
1365
- ntree = len(move_precs['total'])
1366
- sigma_mu2 = 1 / ntree
1367
- prelkv = dict()
1368
- prelkv['sigma2_left'] = sigma2 + move_precs['left'] * sigma_mu2
1369
- prelkv['sigma2_right'] = sigma2 + move_precs['right'] * sigma_mu2
1370
- prelkv['sigma2_total'] = sigma2 + move_precs['total'] * sigma_mu2
1371
- prelkv['sqrt_term'] = (
1372
- jnp.log(
1373
- sigma2
1374
- * prelkv['sigma2_total']
1375
- / (prelkv['sigma2_left'] * prelkv['sigma2_right'])
1376
- )
1377
- / 2
1378
- )
1379
- return prelkv, dict(
1380
- exp_factor=sigma_mu2 / (2 * sigma2),
2020
+ sigma2_left = sigma2 + move_precs.left * sigma_mu2
2021
+ sigma2_right = sigma2 + move_precs.right * sigma_mu2
2022
+ sigma2_total = sigma2 + move_precs.total * sigma_mu2
2023
+ prelkv = PreLkV(
2024
+ sigma2_left=sigma2_left,
2025
+ sigma2_right=sigma2_right,
2026
+ sigma2_total=sigma2_total,
2027
+ sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
1381
2028
  )
2029
+ return prelkv, PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
1382
2030
 
1383
2031
 
1384
- def precompute_leaf_terms(key, prec_trees, sigma2):
2032
+ def precompute_leaf_terms(
2033
+ key: Key[Array, ''],
2034
+ prec_trees: Float32[Array, 'num_trees 2**d'],
2035
+ sigma2: Float32[Array, ''],
2036
+ sigma_mu2: Float32[Array, ''],
2037
+ ) -> PreLf:
1385
2038
  """
1386
2039
  Pre-compute terms used to sample leaves from their posterior.
1387
2040
 
1388
2041
  Parameters
1389
2042
  ----------
1390
- key : jax.dtypes.prng_key array
2043
+ key
1391
2044
  A jax random key.
1392
- prec_trees : array (num_trees, 2 ** d)
2045
+ prec_trees
1393
2046
  The likelihood precision scale in each potential or actual leaf node.
1394
- sigma2 : float
1395
- The noise variance.
2047
+ sigma2
2048
+ The error variance, or the global error variance factor if `prec_scale`
2049
+ is set.
2050
+ sigma_mu2
2051
+ The prior variance of each leaf.
1396
2052
 
1397
2053
  Returns
1398
2054
  -------
1399
- prelf : dict
1400
- Dictionary with pre-computed terms of the leaf sampling, with fields:
1401
-
1402
- 'mean_factor' : float array (num_trees, 2 ** d)
1403
- The factor to be multiplied by the sum of the scaled residuals to
1404
- obtain the posterior mean.
1405
- 'centered_leaves' : float array (num_trees, 2 ** d)
1406
- The mean-zero normal values to be added to the posterior mean to
1407
- obtain the posterior leaf samples.
2055
+ Pre-computed terms for leaf sampling.
1408
2056
  """
1409
- ntree = len(prec_trees)
1410
2057
  prec_lk = prec_trees / sigma2
1411
- var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
2058
+ prec_prior = lax.reciprocal(sigma_mu2)
2059
+ var_post = lax.reciprocal(prec_lk + prec_prior)
1412
2060
  z = random.normal(key, prec_trees.shape, sigma2.dtype)
1413
- return dict(
1414
- mean_factor=var_post / sigma2, # = mean_lk * prec_lk * var_post / resid_tree
2061
+ return PreLf(
2062
+ mean_factor=var_post / sigma2,
2063
+ # | mean = mean_lk * prec_lk * var_post
2064
+ # | resid_tree = mean_lk * prec_tree -->
2065
+ # | --> mean_lk = resid_tree / prec_tree (kind of)
2066
+ # | mean_factor =
2067
+ # | = mean / resid_tree =
2068
+ # | = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
2069
+ # | = 1 / prec_tree * prec_tree / sigma2 * var_post =
2070
+ # | = var_post / sigma2
1415
2071
  centered_leaves=z * jnp.sqrt(var_post),
1416
2072
  )
1417
2073
 
1418
2074
 
1419
- def accept_moves_sequential_stage(
1420
- bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
1421
- ):
2075
+ def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
1422
2076
  """
1423
- The part of accepting the moves that has to be done one tree at a time.
2077
+ Accept/reject the moves one tree at a time.
2078
+
2079
+ This is the most performance-sensitive function because it contains all and
2080
+ only the parts of the algorithm that can not be parallelized across trees.
1424
2081
 
1425
2082
  Parameters
1426
2083
  ----------
1427
- bart : dict
1428
- A partially updated BART mcmc state.
1429
- prec_trees : float array (num_trees, 2 ** d)
1430
- The likelihood precision scale in each potential or actual leaf node.
1431
- moves : dict
1432
- The proposed moves, see `sample_moves`.
1433
- move_counts : dict
1434
- The counts of the number of points in the the nodes modified by the
1435
- moves.
1436
- move_precs : dict
1437
- The likelihood precision scale in each node modified by the moves.
1438
- prelkv, prelk, prelf : dict
1439
- Dictionaries with pre-computed terms of the likelihood ratios and leaf
1440
- samples.
2084
+ pso
2085
+ The output of `accept_moves_parallel_stage`.
1441
2086
 
1442
2087
  Returns
1443
2088
  -------
1444
- bart : dict
2089
+ bart : State
1445
2090
  A partially updated BART mcmc state.
1446
- moves : dict
1447
- The proposed moves, with these additional fields:
1448
-
1449
- 'acc' : bool array (num_trees,)
1450
- Whether the move was accepted.
1451
- 'to_prune' : bool array (num_trees,)
1452
- Whether, to reflect the acceptance status of the move, the state
1453
- should be updated by pruning the leaves involved in the move.
1454
- """
1455
- bart = bart.copy()
1456
- moves = moves.copy()
1457
-
1458
- def loop(resid, item):
1459
- resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
1460
- bart['X'],
1461
- len(bart['leaf_trees']),
1462
- bart['opt']['resid_batch_size'],
2091
+ moves : Moves
2092
+ The accepted/rejected moves, with `acc` and `to_prune` set.
2093
+ """
2094
+
2095
+ def loop(resid, pt):
2096
+ resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
1463
2097
  resid,
1464
- bart['prec_scale'],
1465
- bart['min_points_per_leaf'],
1466
- 'ratios' in bart,
1467
- prelk,
1468
- *item,
2098
+ SeqStageInAllTrees(
2099
+ pso.bart.X,
2100
+ pso.bart.forest.resid_batch_size,
2101
+ pso.bart.prec_scale,
2102
+ pso.bart.forest.log_likelihood is not None,
2103
+ pso.prelk,
2104
+ ),
2105
+ pt,
1469
2106
  )
1470
- return resid, (leaf_tree, acc, to_prune, ratios)
1471
-
1472
- items = (
1473
- bart['leaf_trees'],
1474
- prec_trees,
1475
- moves,
1476
- move_counts,
1477
- move_precs,
1478
- bart['leaf_indices'],
1479
- prelkv,
1480
- prelf,
2107
+ return resid, (leaf_tree, acc, to_prune, lkratio)
2108
+
2109
+ pts = SeqStageInPerTree(
2110
+ pso.bart.forest.leaf_tree,
2111
+ pso.prec_trees,
2112
+ pso.moves,
2113
+ pso.move_precs,
2114
+ pso.bart.forest.leaf_indices,
2115
+ pso.prelkv,
2116
+ pso.prelf,
1481
2117
  )
1482
- resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)
2118
+ resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(loop, pso.bart.resid, pts)
1483
2119
 
1484
- bart['resid'] = resid
1485
- bart['leaf_trees'] = leaf_trees
1486
- bart.get('ratios', {}).update(ratios) # noop if there are no ratios
1487
- moves['acc'] = acc
1488
- moves['to_prune'] = to_prune
2120
+ bart = replace(
2121
+ pso.bart,
2122
+ resid=resid,
2123
+ forest=replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio),
2124
+ )
2125
+ moves = replace(pso.moves, acc=acc, to_prune=to_prune)
1489
2126
 
1490
2127
  return bart, moves
1491
2128
 
1492
2129
 
1493
- def accept_move_and_sample_leaves(
1494
- X,
1495
- ntree,
1496
- resid_batch_size,
1497
- resid,
1498
- prec_scale,
1499
- min_points_per_leaf,
1500
- save_ratios,
1501
- prelk,
1502
- leaf_tree,
1503
- prec_tree,
1504
- move,
1505
- move_counts,
1506
- move_precs,
1507
- leaf_indices,
1508
- prelkv,
1509
- prelf,
1510
- ):
2130
+ class SeqStageInAllTrees(Module):
1511
2131
  """
1512
- Accept or reject a proposed move and sample the new leaf values.
2132
+ The inputs to `accept_move_and_sample_leaves` that are shared by all trees.
1513
2133
 
1514
2134
  Parameters
1515
2135
  ----------
1516
- X : int array (p, n)
2136
+ X
1517
2137
  The predictors.
1518
- ntree : int
1519
- The number of trees in the forest.
1520
- resid_batch_size : int, None
2138
+ resid_batch_size
1521
2139
  The batch size for computing the sum of residuals in each leaf.
1522
- resid : float array (n,)
1523
- The residuals (data minus forest value).
1524
- prec_scale : float array (n,) or None
2140
+ prec_scale
1525
2141
  The scale of the precision of the error on each datapoint. If None, it
1526
2142
  is assumed to be 1.
1527
- min_points_per_leaf : int or None
1528
- The minimum number of data points in a leaf node.
1529
- save_ratios : bool
2143
+ save_ratios
1530
2144
  Whether to save the acceptance ratios.
1531
- prelk : dict
2145
+ prelk
1532
2146
  The pre-computed terms of the likelihood ratio which are shared across
1533
2147
  trees.
1534
- leaf_tree : float array (2 ** d,)
2148
+ """
2149
+
2150
+ X: UInt[Array, 'p n']
2151
+ resid_batch_size: int | None = field(static=True)
2152
+ prec_scale: Float32[Array, ' n'] | None
2153
+ save_ratios: bool = field(static=True)
2154
+ prelk: PreLk
2155
+
2156
+
2157
+ class SeqStageInPerTree(Module):
2158
+ """
2159
+ The inputs to `accept_move_and_sample_leaves` that are separate for each tree.
2160
+
2161
+ Parameters
2162
+ ----------
2163
+ leaf_tree
1535
2164
  The leaf values of the tree.
1536
- prec_tree : float array (2 ** d,)
2165
+ prec_tree
1537
2166
  The likelihood precision scale in each potential or actual leaf node.
1538
- move : dict
1539
- The proposed move, see `sample_moves`.
1540
- move_counts : dict
1541
- The counts of the number of points in the the nodes modified by the
1542
- moves.
1543
- move_precs : dict
2167
+ move
2168
+ The proposed move, see `propose_moves`.
2169
+ move_precs
1544
2170
  The likelihood precision scale in each node modified by the moves.
1545
- leaf_indices : int array (n,)
2171
+ leaf_indices
1546
2172
  The leaf indices for the largest version of the tree compatible with
1547
2173
  the move.
1548
- prelkv, prelf : dict
2174
+ prelkv
2175
+ prelf
1549
2176
  The pre-computed terms of the likelihood ratio and leaf sampling which
1550
2177
  are specific to the tree.
2178
+ """
2179
+
2180
+ leaf_tree: Float32[Array, ' 2**d']
2181
+ prec_tree: Float32[Array, ' 2**d']
2182
+ move: Moves
2183
+ move_precs: Precs | Counts
2184
+ leaf_indices: UInt[Array, ' n']
2185
+ prelkv: PreLkV
2186
+ prelf: PreLf
2187
+
2188
+
2189
+ def accept_move_and_sample_leaves(
2190
+ resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
2191
+ ) -> tuple[
2192
+ Float32[Array, ' n'],
2193
+ Float32[Array, ' 2**d'],
2194
+ Bool[Array, ''],
2195
+ Bool[Array, ''],
2196
+ Float32[Array, ''] | None,
2197
+ ]:
2198
+ """
2199
+ Accept or reject a proposed move and sample the new leaf values.
2200
+
2201
+ Parameters
2202
+ ----------
2203
+ resid
2204
+ The residuals (data minus forest value).
2205
+ at
2206
+ The inputs that are the same for all trees.
2207
+ pt
2208
+ The inputs that are separate for each tree.
1551
2209
 
1552
2210
  Returns
1553
2211
  -------
1554
- resid : float array (n,)
2212
+ resid : Float32[Array, 'n']
1555
2213
  The updated residuals (data minus forest value).
1556
- leaf_tree : float array (2 ** d,)
2214
+ leaf_tree : Float32[Array, '2**d']
1557
2215
  The new leaf values of the tree.
1558
- acc : bool
2216
+ acc : Bool[Array, '']
1559
2217
  Whether the move was accepted.
1560
- to_prune : bool
2218
+ to_prune : Bool[Array, '']
1561
2219
  Whether, to reflect the acceptance status of the move, the state should
1562
2220
  be updated by pruning the leaves involved in the move.
1563
- ratios : dict
1564
- The acceptance ratios for the moves. Empty if not to be saved.
2221
+ log_lk_ratio : Float32[Array, ''] | None
2222
+ The logarithm of the likelihood ratio for the move. `None` if not to be
2223
+ saved.
1565
2224
  """
1566
-
1567
2225
  # sum residuals in each leaf, in tree proposed by grow move
1568
- if prec_scale is None:
2226
+ if at.prec_scale is None:
1569
2227
  scaled_resid = resid
1570
2228
  else:
1571
- scaled_resid = resid * prec_scale
1572
- resid_tree = sum_resid(scaled_resid, leaf_indices, leaf_tree.size, resid_batch_size)
2229
+ scaled_resid = resid * at.prec_scale
2230
+ resid_tree = sum_resid(
2231
+ scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
2232
+ )
1573
2233
 
1574
2234
  # subtract starting tree from function
1575
- resid_tree += prec_tree * leaf_tree
1576
-
1577
- # get indices of move
1578
- node = move['node']
1579
- assert node.dtype == jnp.int32
1580
- left = move['left']
1581
- right = move['right']
2235
+ resid_tree += pt.prec_tree * pt.leaf_tree
1582
2236
 
1583
2237
  # sum residuals in parent node modified by move
1584
- resid_left = resid_tree[left]
1585
- resid_right = resid_tree[right]
2238
+ resid_left = resid_tree[pt.move.left]
2239
+ resid_right = resid_tree[pt.move.right]
1586
2240
  resid_total = resid_left + resid_right
1587
- resid_tree = resid_tree.at[node].set(resid_total)
2241
+ assert pt.move.node.dtype == jnp.int32
2242
+ resid_tree = resid_tree.at[pt.move.node].set(resid_total)
1588
2243
 
1589
2244
  # compute acceptance ratio
1590
2245
  log_lk_ratio = compute_likelihood_ratio(
1591
- resid_total, resid_left, resid_right, prelkv, prelk
2246
+ resid_total, resid_left, resid_right, pt.prelkv, at.prelk
1592
2247
  )
1593
- log_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
1594
- log_ratio = jnp.where(move['grow'], log_ratio, -log_ratio)
1595
- ratios = {}
1596
- if save_ratios:
1597
- ratios.update(
1598
- log_trans_prior=move['log_trans_prior_ratio'],
1599
- log_likelihood=log_lk_ratio,
1600
- )
2248
+ log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
2249
+ log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
2250
+ if not at.save_ratios:
2251
+ log_lk_ratio = None
1601
2252
 
1602
2253
  # determine whether to accept the move
1603
- acc = move['allowed'] & (move['logu'] <= log_ratio)
1604
- if min_points_per_leaf is not None:
1605
- acc &= move_counts['left'] >= min_points_per_leaf
1606
- acc &= move_counts['right'] >= min_points_per_leaf
2254
+ acc = pt.move.allowed & (pt.move.logu <= log_ratio)
1607
2255
 
1608
2256
  # compute leaves posterior and sample leaves
1609
- initial_leaf_tree = leaf_tree
1610
- mean_post = resid_tree * prelf['mean_factor']
1611
- leaf_tree = mean_post + prelf['centered_leaves']
2257
+ mean_post = resid_tree * pt.prelf.mean_factor
2258
+ leaf_tree = mean_post + pt.prelf.centered_leaves
1612
2259
 
1613
2260
  # copy leaves around such that the leaf indices point to the correct leaf
1614
- to_prune = acc ^ move['grow']
2261
+ to_prune = acc ^ pt.move.grow
1615
2262
  leaf_tree = (
1616
- leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
1617
- .set(leaf_tree[node])
1618
- .at[jnp.where(to_prune, right, leaf_tree.size)]
1619
- .set(leaf_tree[node])
2263
+ leaf_tree.at[jnp.where(to_prune, pt.move.left, leaf_tree.size)]
2264
+ .set(leaf_tree[pt.move.node])
2265
+ .at[jnp.where(to_prune, pt.move.right, leaf_tree.size)]
2266
+ .set(leaf_tree[pt.move.node])
1620
2267
  )
1621
2268
 
1622
2269
  # replace old tree with new tree in function values
1623
- resid += (initial_leaf_tree - leaf_tree)[leaf_indices]
2270
+ resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices]
1624
2271
 
1625
- return resid, leaf_tree, acc, to_prune, ratios
2272
+ return resid, leaf_tree, acc, to_prune, log_lk_ratio
1626
2273
 
1627
2274
 
1628
- def sum_resid(scaled_resid, leaf_indices, tree_size, batch_size):
2275
+ def sum_resid(
2276
+ scaled_resid: Float32[Array, ' n'],
2277
+ leaf_indices: UInt[Array, ' n'],
2278
+ tree_size: int,
2279
+ batch_size: int | None,
2280
+ ) -> Float32[Array, ' {tree_size}']:
1629
2281
  """
1630
2282
  Sum the residuals in each leaf.
1631
2283
 
1632
2284
  Parameters
1633
2285
  ----------
1634
- scaled_resid : float array (n,)
2286
+ scaled_resid
1635
2287
  The residuals (data minus forest value) multiplied by the error
1636
2288
  precision scale.
1637
- leaf_indices : int array (n,)
2289
+ leaf_indices
1638
2290
  The leaf indices of the tree (in which leaf each data point falls into).
1639
- tree_size : int
2291
+ tree_size
1640
2292
  The size of the tree array (2 ** d).
1641
- batch_size : int, None
2293
+ batch_size
1642
2294
  The data batch size for the aggregation. Batching increases numerical
1643
2295
  accuracy and parallelism.
1644
2296
 
1645
2297
  Returns
1646
2298
  -------
1647
- resid_tree : float array (2 ** d,)
1648
- The sum of the residuals at data points in each leaf.
2299
+ The sum of the residuals at data points in each leaf.
1649
2300
  """
1650
2301
  if batch_size is None:
1651
2302
  aggr_func = _aggregate_scatter
1652
2303
  else:
1653
- aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
1654
- return aggr_func(
1655
- scaled_resid, leaf_indices, tree_size, jnp.float32
1656
- ) # TODO: use large_float
2304
+ aggr_func = partial(_aggregate_batched_onetree, batch_size=batch_size)
2305
+ return aggr_func(scaled_resid, leaf_indices, tree_size, jnp.float32)
1657
2306
 
1658
2307
 
1659
- def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
2308
+ def _aggregate_batched_onetree(
2309
+ values: Shaped[Array, '*'],
2310
+ indices: Integer[Array, '*'],
2311
+ size: int,
2312
+ dtype: jnp.dtype,
2313
+ batch_size: int,
2314
+ ) -> Float32[Array, ' {size}']:
1660
2315
  (n,) = indices.shape
1661
2316
  nbatches = n // batch_size + bool(n % batch_size)
1662
2317
  batch_indices = jnp.arange(n) % nbatches
@@ -1668,153 +2323,294 @@ def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
1668
2323
  )
1669
2324
 
1670
2325
 
1671
- def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
2326
+ def compute_likelihood_ratio(
2327
+ total_resid: Float32[Array, ''],
2328
+ left_resid: Float32[Array, ''],
2329
+ right_resid: Float32[Array, ''],
2330
+ prelkv: PreLkV,
2331
+ prelk: PreLk,
2332
+ ) -> Float32[Array, '']:
1672
2333
  """
1673
2334
  Compute the likelihood ratio of a grow move.
1674
2335
 
1675
2336
  Parameters
1676
2337
  ----------
1677
- total_resid, left_resid, right_resid : float
2338
+ total_resid
2339
+ left_resid
2340
+ right_resid
1678
2341
  The sum of the residuals (scaled by error precision scale) of the
1679
2342
  datapoints falling in the nodes involved in the moves.
1680
- prelkv, prelk : dict
2343
+ prelkv
2344
+ prelk
1681
2345
  The pre-computed terms of the likelihood ratio, see
1682
2346
  `precompute_likelihood_terms`.
1683
2347
 
1684
2348
  Returns
1685
2349
  -------
1686
- ratio : float
1687
- The likelihood ratio P(data | new tree) / P(data | old tree).
2350
+ The likelihood ratio P(data | new tree) / P(data | old tree).
1688
2351
  """
1689
- exp_term = prelk['exp_factor'] * (
1690
- left_resid * left_resid / prelkv['sigma2_left']
1691
- + right_resid * right_resid / prelkv['sigma2_right']
1692
- - total_resid * total_resid / prelkv['sigma2_total']
2352
+ exp_term = prelk.exp_factor * (
2353
+ left_resid * left_resid / prelkv.sigma2_left
2354
+ + right_resid * right_resid / prelkv.sigma2_right
2355
+ - total_resid * total_resid / prelkv.sigma2_total
1693
2356
  )
1694
- return prelkv['sqrt_term'] + exp_term
2357
+ return prelkv.sqrt_term + exp_term
1695
2358
 
1696
2359
 
1697
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state, with acceptance counters, leaf indices,
    and split trees consistent with the accepted moves.
    """
    # count accepted grow and prune moves separately: `moves.grow` marks move
    # type, `moves.acc` marks acceptance
    return replace(
        bart,
        forest=replace(
            bart.forest,
            grow_acc_count=jnp.sum(moves.acc & moves.grow),
            prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
            leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
            split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
        ),
    )
1722
2389
 
1723
2390
 
1724
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # Clearing the lowest bit maps both children of a node to the same value,
    # which is compared against `moves.left` — presumably heap indexing where
    # left = 2k and right = 2k + 1, so the masked index equals the left child.
    mask = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    is_child = (leaf_indices & mask) == moves.left
    # on a prune, datapoints in either child are reassigned to the pruned node
    return jnp.where(
        is_child & moves.to_prune, moves.node.astype(leaf_indices.dtype), leaf_indices
    )
1750
2416
 
1751
2417
 
1752
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # `jnp.where(cond, node, split_tree.size)` picks an out-of-bounds index
    # when the condition is false; jax drops out-of-bounds scatter updates,
    # so each .at[].set() is a no-op unless the move applies.
    return (
        split_tree.at[jnp.where(moves.grow, moves.node, split_tree.size)]
        .set(moves.grow_split.astype(split_tree.dtype))
        .at[jnp.where(moves.to_prune, moves.node, split_tree.size)]
        .set(0)  # a pruned decision node becomes a leaf, marked by cutpoint 0
    )
1788
2444
 
1789
2445
 
1790
def step_sigma(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the error variance (factor).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `sigma2`.
    """
    resid = bart.resid
    # conjugate inverse-gamma update: posterior shape is prior shape + n/2
    alpha = bart.sigma2_alpha + resid.size / 2
    if bart.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * bart.prec_scale
    norm2 = resid @ scaled_resid
    beta = bart.sigma2_beta + norm2 / 2

    # sigma2 ~ InvGamma(alpha, beta), drawn as beta / Gamma(alpha)
    sample = random.gamma(key, alpha)
    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    return replace(bart, sigma2=beta / sample)
2474
+
2475
+
2476
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state, with re-sampled `z` and the matching `resid`.
    """
    # recover the deterministic part of z (forest value plus offset): by
    # definition resid = z - (trees + offset)
    trees_plus_offset = bart.z - bart.resid
    assert bart.y.dtype == bool
    # NOTE(review): assumes truncated_normal_onesided(key, shape, side, bound)
    # samples the residual on the side of -trees_plus_offset selected by ~y,
    # so that z = trees_plus_offset + resid has the sign implied by y —
    # confirm against `bartz.jaxext.truncated_normal_onesided`.
    resid = truncated_normal_onesided(key, (), ~bart.y, -trees_plus_offset)
    z = trees_plus_offset + resid
    return replace(bart, z=z, resid=resid)
2496
+
2497
+
2498
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.
    """
    assert bart.forest.theta is not None

    # histogram current variable usage
    p = bart.forest.max_split.size
    varcount = grove.var_histogram(p, bart.forest.var_tree, bart.forest.split_tree)

    # sample from Dirichlet posterior: loggamma draws are logs of independent
    # Gamma(alpha) variates; the normalization that makes them a Dirichlet
    # sample appears deferred to consumers (see the logsumexp in `step_theta`)
    alpha = bart.forest.theta / p + varcount
    log_s = random.loggamma(key, alpha)

    # update forest with new s
    return replace(bart, forest=replace(bart.forest, log_s=log_s))
2531
+
1819
2532
 
2533
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    assert bart.forest.log_s is not None
    assert bart.forest.rho is not None
    assert bart.forest.a is not None
    assert bart.forest.b is not None

    # the grid points are the midpoints of num_grid bins in (0, 1)
    padding = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(padding, 1 - padding, num_grid)

    # normalize s (log_s is stored unnormalized, see `step_s`)
    log_s = bart.forest.log_s - logsumexp(bart.forest.log_s)

    # sample lambda = theta / (theta + rho) on the grid, then map back to theta
    logp, theta_grid = _log_p_lamda(
        lamda_grid, log_s, bart.forest.rho, bart.forest.a, bart.forest.b
    )
    i = random.categorical(key, logp)
    theta = theta_grid[i]

    return replace(bart, forest=replace(bart.forest, theta=theta))
2573
+
2574
+
2575
+ def _log_p_lamda(
2576
+ lamda: Float32[Array, ' num_grid'],
2577
+ log_s: Float32[Array, ' p'],
2578
+ rho: Float32[Array, ''],
2579
+ a: Float32[Array, ''],
2580
+ b: Float32[Array, ''],
2581
+ ) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
2582
+ # in the following I use lamda[::-1] == 1 - lamda
2583
+ theta = rho * lamda / lamda[::-1]
2584
+ p = log_s.size
2585
+ return (
2586
+ (a - 1) * jnp.log1p(-lamda[::-1]) # log(lambda)
2587
+ + (b - 1) * jnp.log1p(-lamda) # log(1 - lambda)
2588
+ + gammaln(theta)
2589
+ - p * gammaln(theta / p)
2590
+ + theta / p * jnp.sum(log_s)
2591
+ ), theta
2592
+
2593
+
2594
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    This invokes `step_s`, and then `step_theta` only if `rho` is set (i.e.,
    the parameters of the theta prior are defined).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and (possibly) `theta`.
    """
    keys = split(key)
    bart = step_s(keys.pop(), bart)
    # theta is only updated when it has a prior, signalled by rho being set
    if bart.forest.rho is not None:
        bart = step_theta(keys.pop(), bart)
    return bart