bartz 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py CHANGED
@@ -26,220 +26,292 @@
26
26
  Functions that implement the BART posterior MCMC initialization and update step.
27
27
 
28
28
  Functions that do MCMC steps operate by taking as input a bart state, and
29
- outputting a new dictionary with the new state. The input dict/arrays are not
30
- modified.
29
+ outputting a new state. The inputs are not modified.
31
30
 
32
- In general, integer types are chosen to be the minimal types that contain the
33
- range of possible values.
31
+ The main entry points are:
32
+
33
+ - `State`: The dataclass that represents a BART MCMC state.
34
+ - `init`: Creates an initial `State` from data and configurations.
35
+ - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
34
36
  """
35
37
 
36
- import functools
37
38
  import math
39
+ from dataclasses import replace
40
+ from functools import cache, partial
41
+ from typing import Any
38
42
 
39
43
  import jax
44
+ from equinox import Module, field
40
45
  from jax import lax, random
41
46
  from jax import numpy as jnp
47
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt
48
+
49
+ from . import grove
50
+ from .jaxext import minimal_unsigned_dtype, split, vmap_nodoc
51
+
52
+
53
+ class Forest(Module):
54
+ """
55
+ Represents the MCMC state of a sum of trees.
56
+
57
+ Parameters
58
+ ----------
59
+ leaf_trees
60
+ The leaf values.
61
+ var_trees
62
+ The decision axes.
63
+ split_trees
64
+ The decision boundaries.
65
+ p_nonterminal
66
+ The probability of a nonterminal node at each depth, padded with a
67
+ zero.
68
+ p_propose_grow
69
+ The unnormalized probability of picking a leaf for a grow proposal.
70
+ leaf_indices
71
+ The index of the leaf each datapoint falls into, for each tree.
72
+ min_points_per_leaf
73
+ The minimum number of data points in a leaf node.
74
+ affluence_trees
75
+ Whether a non-bottom leaf node contains twice `min_points_per_leaf`
76
+ datapoints. If `min_points_per_leaf` is not specified, this is None.
77
+ resid_batch_size
78
+ count_batch_size
79
+ The data batch sizes for computing the sufficient statistics. If `None`,
80
+ they are computed with no batching.
81
+ log_trans_prior
82
+ The log transition and prior Metropolis-Hastings ratio for the
83
+ proposed move on each tree.
84
+ log_likelihood
85
+ The log likelihood ratio.
86
+ grow_prop_count
87
+ prune_prop_count
88
+ The number of grow/prune proposals made during one full MCMC cycle.
89
+ grow_acc_count
90
+ prune_acc_count
91
+ The number of grow/prune moves accepted during one full MCMC cycle.
92
+ sigma_mu2
93
+ The prior variance of a leaf, conditional on the tree structure.
94
+ """
95
+
96
+ leaf_trees: Float32[Array, 'num_trees 2**d']
97
+ var_trees: UInt[Array, 'num_trees 2**(d-1)']
98
+ split_trees: UInt[Array, 'num_trees 2**(d-1)']
99
+ p_nonterminal: Float32[Array, 'd']
100
+ p_propose_grow: Float32[Array, '2**(d-1)']
101
+ leaf_indices: UInt[Array, 'num_trees n']
102
+ min_points_per_leaf: Int32[Array, ''] | None
103
+ affluence_trees: Bool[Array, 'num_trees 2**(d-1)'] | None
104
+ resid_batch_size: int | None = field(static=True)
105
+ count_batch_size: int | None = field(static=True)
106
+ log_trans_prior: Float32[Array, 'num_trees'] | None
107
+ log_likelihood: Float32[Array, 'num_trees'] | None
108
+ grow_prop_count: Int32[Array, '']
109
+ prune_prop_count: Int32[Array, '']
110
+ grow_acc_count: Int32[Array, '']
111
+ prune_acc_count: Int32[Array, '']
112
+ sigma_mu2: Float32[Array, '']
113
+
114
+
115
+ class State(Module):
116
+ """
117
+ Represents the MCMC state of BART.
42
118
 
43
- from . import grove, jaxext
119
+ Parameters
120
+ ----------
121
+ X
122
+ The predictors.
123
+ max_split
124
+ The maximum split index for each predictor.
125
+ y
126
+ The response. If the data type is `bool`, the model is binary regression.
127
+ resid
128
+ The residuals (`y` or `z` minus sum of trees).
129
+ z
130
+ The latent variable for binary regression. `None` in continuous
131
+ regression.
132
+ offset
133
+ Constant shift added to the sum of trees.
134
+ sigma2
135
+ The error variance. `None` in binary regression.
136
+ prec_scale
137
+ The scale on the error precision, i.e., ``1 / error_scale ** 2``.
138
+ `None` in binary regression.
139
+ sigma2_alpha
140
+ sigma2_beta
141
+ The shape and scale parameters of the inverse gamma prior on the noise
142
+ variance. `None` in binary regression.
143
+ forest
144
+ The sum of trees model.
145
+ """
146
+
147
+ X: UInt[Array, 'p n']
148
+ max_split: UInt[Array, 'p']
149
+ y: Float32[Array, 'n'] | Bool[Array, 'n']
150
+ z: None | Float32[Array, 'n']
151
+ offset: Float32[Array, '']
152
+ resid: Float32[Array, 'n']
153
+ sigma2: Float32[Array, ''] | None
154
+ prec_scale: Float32[Array, 'n'] | None
155
+ sigma2_alpha: Float32[Array, ''] | None
156
+ sigma2_beta: Float32[Array, ''] | None
157
+ forest: Forest
44
158
 
45
159
 
46
160
  def init(
47
161
  *,
48
- X,
49
- y,
50
- max_split,
51
- num_trees,
52
- p_nonterminal,
53
- sigma2_alpha,
54
- sigma2_beta,
55
- error_scale=None,
56
- small_float=jnp.float32,
57
- large_float=jnp.float32,
58
- min_points_per_leaf=None,
59
- resid_batch_size='auto',
60
- count_batch_size='auto',
61
- save_ratios=False,
62
- ):
162
+ X: UInt[Any, 'p n'],
163
+ y: Float32[Any, 'n'] | Bool[Any, 'n'],
164
+ offset: float | Float32[Any, ''] = 0.0,
165
+ max_split: UInt[Any, 'p'],
166
+ num_trees: int,
167
+ p_nonterminal: Float32[Any, 'd-1'],
168
+ sigma_mu2: float | Float32[Any, ''],
169
+ sigma2_alpha: float | Float32[Any, ''] | None = None,
170
+ sigma2_beta: float | Float32[Any, ''] | None = None,
171
+ error_scale: Float32[Any, 'n'] | None = None,
172
+ min_points_per_leaf: int | None = None,
173
+ resid_batch_size: int | None | str = 'auto',
174
+ count_batch_size: int | None | str = 'auto',
175
+ save_ratios: bool = False,
176
+ ) -> State:
63
177
  """
64
178
  Make a BART posterior sampling MCMC initial state.
65
179
 
66
180
  Parameters
67
181
  ----------
68
- X : int array (p, n)
182
+ X
69
183
  The predictors. Note this is transposed compared to the usual convention.
70
- y : float array (n,)
71
- The response.
72
- max_split : int array (p,)
184
+ y
185
+ The response. If the data type is `bool`, the regression model is binary
186
+ regression with probit.
187
+ offset
188
+ Constant shift added to the sum of trees. 0 if not specified.
189
+ max_split
73
190
  The maximum split index for each variable. All split ranges start at 1.
74
- num_trees : int
191
+ num_trees
75
192
  The number of trees in the forest.
76
- p_nonterminal : float array (d - 1,)
193
+ p_nonterminal
77
194
  The probability of a nonterminal node at each depth. The maximum depth
78
195
  of trees is fixed by the length of this array.
79
- sigma2_alpha : float
80
- The shape parameter of the inverse gamma prior on the error variance.
81
- sigma2_beta : float
82
- The scale parameter of the inverse gamma prior on the error variance.
83
- error_scale : float array (n,), optional
196
+ sigma_mu2
197
+ The prior variance of a leaf, conditional on the tree structure. The
198
+ prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
199
+ prior mean of leaves is always zero.
200
+ sigma2_alpha
201
+ sigma2_beta
202
+ The shape and scale parameters of the inverse gamma prior on the error
203
+ variance. Leave unspecified for binary regression.
204
+ error_scale
84
205
  Each error is scaled by the corresponding factor in `error_scale`, so
85
206
  the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
86
- small_float : dtype, default float32
87
- The dtype for large arrays used in the algorithm.
88
- large_float : dtype, default float32
89
- The dtype for scalars, small arrays, and arrays which require accuracy.
90
- min_points_per_leaf : int, optional
207
+ Not supported for binary regression. If not specified, defaults to 1 for
208
+ all points, but potentially skipping calculations.
209
+ min_points_per_leaf
91
210
  The minimum number of data points in a leaf node. 0 if not specified.
92
- resid_batch_size, count_batch_sizes : int, None, str, default 'auto'
211
+ resid_batch_size
212
+ count_batch_size
93
213
  The batch sizes, along datapoints, for summing the residuals and
94
214
  counting the number of datapoints in each leaf. `None` for no batching.
95
215
  If 'auto', pick a value based on the device of `y`, or the default
96
216
  device.
97
- save_ratios : bool, default False
217
+ save_ratios
98
218
  Whether to save the Metropolis-Hastings ratios.
99
219
 
100
220
  Returns
101
221
  -------
102
- bart : dict
103
- A dictionary with array values, representing a BART mcmc state. The
104
- keys are:
105
-
106
- 'leaf_trees' : small_float array (num_trees, 2 ** d)
107
- The leaf values.
108
- 'var_trees' : int array (num_trees, 2 ** (d - 1))
109
- The decision axes.
110
- 'split_trees' : int array (num_trees, 2 ** (d - 1))
111
- The decision boundaries.
112
- 'resid' : large_float array (n,)
113
- The residuals (data minus forest value). Large float to avoid
114
- roundoff.
115
- 'sigma2' : large_float
116
- The noise variance.
117
- 'prec_scale' : large_float array (n,) or None
118
- The scale on the error precision, i.e., ``1 / error_scale ** 2``.
119
- 'grow_prop_count', 'prune_prop_count' : int
120
- The number of grow/prune proposals made during one full MCMC cycle.
121
- 'grow_acc_count', 'prune_acc_count' : int
122
- The number of grow/prune moves accepted during one full MCMC cycle.
123
- 'p_nonterminal' : large_float array (d,)
124
- The probability of a nonterminal node at each depth, padded with a
125
- zero.
126
- 'p_propose_grow' : large_float array (2 ** (d - 1),)
127
- The unnormalized probability of picking a leaf for a grow proposal.
128
- 'sigma2_alpha' : large_float
129
- The shape parameter of the inverse gamma prior on the noise variance.
130
- 'sigma2_beta' : large_float
131
- The scale parameter of the inverse gamma prior on the noise variance.
132
- 'max_split' : int array (p,)
133
- The maximum split index for each variable.
134
- 'y' : small_float array (n,)
135
- The response.
136
- 'X' : int array (p, n)
137
- The predictors.
138
- 'leaf_indices' : int array (num_trees, n)
139
- The index of the leaf each datapoints falls into, for each tree.
140
- 'min_points_per_leaf' : int or None
141
- The minimum number of data points in a leaf node.
142
- 'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
143
- Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
144
- datapoints. If `min_points_per_leaf` is not specified, this is None.
145
- 'opt' : LeafDict
146
- A dictionary with config values:
147
-
148
- 'small_float' : dtype
149
- The dtype for large arrays used in the algorithm.
150
- 'large_float' : dtype
151
- The dtype for scalars, small arrays, and arrays which require
152
- accuracy.
153
- 'require_min_points' : bool
154
- Whether the `min_points_per_leaf` parameter is specified.
155
- 'resid_batch_size', 'count_batch_size' : int or None
156
- The data batch sizes for computing the sufficient statistics.
157
- 'ratios' : dict, optional
158
- If `save_ratios` is True, this field is present. It has the fields:
159
-
160
- 'log_trans_prior' : large_float array (num_trees,)
161
- The log transition and prior Metropolis-Hastings ratio for the
162
- proposed move on each tree.
163
- 'log_likelihood' : large_float array (num_trees,)
164
- The log likelihood ratio.
165
- """
166
-
167
- p_nonterminal = jnp.asarray(p_nonterminal, large_float)
222
+ An initialized BART MCMC state.
223
+
224
+ Raises
225
+ ------
226
+ ValueError
227
+ If `y` is boolean and arguments unused in binary regression are set.
228
+ """
229
+ p_nonterminal = jnp.asarray(p_nonterminal)
168
230
  p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
169
231
  max_depth = p_nonterminal.size
170
232
 
171
- @functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
233
+ @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
172
234
  def make_forest(max_depth, dtype):
173
235
  return grove.make_tree(max_depth, dtype)
174
236
 
175
- small_float = jnp.dtype(small_float)
176
- large_float = jnp.dtype(large_float)
177
- y = jnp.asarray(y, small_float)
237
+ y = jnp.asarray(y)
238
+ offset = jnp.asarray(offset)
239
+
178
240
  resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
179
241
  resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
180
242
  )
181
- sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
182
- sigma2 = jnp.where(
183
- jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1
184
- ) # TODO: I don't like this error check, these functions should be low-level and just do the thing. Why is it here?
185
-
186
- bart = dict(
187
- leaf_trees=make_forest(max_depth, small_float),
188
- var_trees=make_forest(
189
- max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)
190
- ),
191
- split_trees=make_forest(max_depth - 1, max_split.dtype),
192
- resid=jnp.asarray(y, large_float),
193
- sigma2=sigma2,
194
- prec_scale=(
195
- None
196
- if error_scale is None
197
- else lax.reciprocal(jnp.square(jnp.asarray(error_scale, large_float)))
198
- ),
199
- grow_prop_count=jnp.zeros((), int),
200
- grow_acc_count=jnp.zeros((), int),
201
- prune_prop_count=jnp.zeros((), int),
202
- prune_acc_count=jnp.zeros((), int),
203
- p_nonterminal=p_nonterminal,
204
- p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
205
- sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
206
- sigma2_beta=jnp.asarray(sigma2_beta, large_float),
243
+
244
+ is_binary = y.dtype == bool
245
+ if is_binary:
246
+ if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
247
+ raise ValueError(
248
+ 'error_scale, sigma2_alpha, and sigma2_beta must be set '
249
+ ' to `None` for binary regression.'
250
+ )
251
+ sigma2 = None
252
+ else:
253
+ sigma2_alpha = jnp.asarray(sigma2_alpha)
254
+ sigma2_beta = jnp.asarray(sigma2_beta)
255
+ sigma2 = sigma2_beta / sigma2_alpha
256
+ # sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
257
+ # TODO: I don't like this isfinite check, these functions should be
258
+ # low-level and just do the thing. Why was it here?
259
+
260
+ bart = State(
261
+ X=jnp.asarray(X),
207
262
  max_split=jnp.asarray(max_split),
208
263
  y=y,
209
- X=jnp.asarray(X),
210
- leaf_indices=jnp.ones(
211
- (num_trees, y.size), jaxext.minimal_unsigned_dtype(2**max_depth - 1)
212
- ),
213
- min_points_per_leaf=(
214
- None if min_points_per_leaf is None else jnp.asarray(min_points_per_leaf)
215
- ),
216
- affluence_trees=(
217
- None
218
- if min_points_per_leaf is None
219
- else make_forest(max_depth - 1, bool)
220
- .at[:, 1]
221
- .set(y.size >= 2 * min_points_per_leaf)
264
+ z=jnp.full(y.shape, offset) if is_binary else None,
265
+ offset=offset,
266
+ resid=jnp.zeros(y.shape) if is_binary else y - offset,
267
+ sigma2=sigma2,
268
+ prec_scale=(
269
+ None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
222
270
  ),
223
- opt=jaxext.LeafDict(
224
- small_float=small_float,
225
- large_float=large_float,
226
- require_min_points=min_points_per_leaf is not None,
271
+ sigma2_alpha=sigma2_alpha,
272
+ sigma2_beta=sigma2_beta,
273
+ forest=Forest(
274
+ leaf_trees=make_forest(max_depth, jnp.float32),
275
+ var_trees=make_forest(
276
+ max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)
277
+ ),
278
+ split_trees=make_forest(max_depth - 1, max_split.dtype),
279
+ grow_prop_count=jnp.zeros((), int),
280
+ grow_acc_count=jnp.zeros((), int),
281
+ prune_prop_count=jnp.zeros((), int),
282
+ prune_acc_count=jnp.zeros((), int),
283
+ p_nonterminal=p_nonterminal,
284
+ p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
285
+ leaf_indices=jnp.ones(
286
+ (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
287
+ ),
288
+ min_points_per_leaf=(
289
+ None
290
+ if min_points_per_leaf is None
291
+ else jnp.asarray(min_points_per_leaf)
292
+ ),
293
+ affluence_trees=(
294
+ None
295
+ if min_points_per_leaf is None
296
+ else make_forest(max_depth - 1, bool)
297
+ .at[:, 1]
298
+ .set(y.size >= 2 * min_points_per_leaf)
299
+ ),
227
300
  resid_batch_size=resid_batch_size,
228
301
  count_batch_size=count_batch_size,
302
+ log_trans_prior=jnp.full(num_trees, jnp.nan) if save_ratios else None,
303
+ log_likelihood=jnp.full(num_trees, jnp.nan) if save_ratios else None,
304
+ sigma_mu2=jnp.asarray(sigma_mu2),
229
305
  ),
230
306
  )
231
307
 
232
- if save_ratios:
233
- bart['ratios'] = dict(
234
- log_trans_prior=jnp.full(num_trees, jnp.nan),
235
- log_likelihood=jnp.full(num_trees, jnp.nan),
236
- )
237
-
238
308
  return bart
239
309
 
240
310
 
241
- def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
242
- @functools.cache
311
+ def _choose_suffstat_batch_size(
312
+ resid_batch_size, count_batch_size, y, forest_size
313
+ ) -> tuple[int | None, ...]:
314
+ @cache
243
315
  def get_platform():
244
316
  try:
245
317
  device = y.devices().pop()
@@ -276,231 +348,327 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
276
348
  return resid_batch_size, count_batch_size
277
349
 
278
350
 
279
- def step(key, bart):
351
+ @jax.jit
352
+ def step(key: Key[Array, ''], bart: State) -> State:
280
353
  """
281
- Perform one full MCMC step on a BART state.
354
+ Do one MCMC step.
282
355
 
283
356
  Parameters
284
357
  ----------
285
- key : jax.dtypes.prng_key array
358
+ key
286
359
  A jax random key.
287
- bart : dict
360
+ bart
288
361
  A BART mcmc state, as created by `init`.
289
362
 
290
363
  Returns
291
364
  -------
292
- bart : dict
293
- The new BART mcmc state.
365
+ The new BART mcmc state.
294
366
  """
295
- key, subkey = random.split(key)
296
- bart = sample_trees(subkey, bart)
297
- return sample_sigma(key, bart)
367
+ keys = split(key)
368
+
369
+ if bart.y.dtype == bool: # binary regression
370
+ bart = replace(bart, sigma2=jnp.float32(1))
371
+ bart = step_trees(keys.pop(), bart)
372
+ bart = replace(bart, sigma2=None)
373
+ return step_z(keys.pop(), bart)
298
374
 
375
+ else: # continuous regression
376
+ bart = step_trees(keys.pop(), bart)
377
+ return step_sigma(keys.pop(), bart)
299
378
 
300
- def sample_trees(key, bart):
379
+
380
+ def step_trees(key: Key[Array, ''], bart: State) -> State:
301
381
  """
302
382
  Forest sampling step of BART MCMC.
303
383
 
304
384
  Parameters
305
385
  ----------
306
- key : jax.dtypes.prng_key array
386
+ key
307
387
  A jax random key.
308
- bart : dict
388
+ bart
309
389
  A BART mcmc state, as created by `init`.
310
390
 
311
391
  Returns
312
392
  -------
313
- bart : dict
314
- The new BART mcmc state.
393
+ The new BART mcmc state.
315
394
 
316
395
  Notes
317
396
  -----
318
397
  This function zeroes the proposal counters.
319
398
  """
320
- key, subkey = random.split(key)
321
- moves = sample_moves(subkey, bart)
322
- return accept_moves_and_sample_leaves(key, bart, moves)
399
+ keys = split(key)
400
+ moves = propose_moves(keys.pop(), bart.forest, bart.max_split)
401
+ return accept_moves_and_sample_leaves(keys.pop(), bart, moves)
323
402
 
324
403
 
325
- def sample_moves(key, bart):
404
+ class Moves(Module):
405
+ """
406
+ Moves proposed to modify each tree.
407
+
408
+ Parameters
409
+ ----------
410
+ allowed
411
+ Whether the move is possible in the first place. There are additional
412
+ constraints that could forbid it, but they are computed at acceptance
413
+ time.
414
+ grow
415
+ Whether the move is a grow move or a prune move.
416
+ num_growable
417
+ The number of growable leaves in the original tree.
418
+ node
419
+ The index of the leaf to grow or node to prune.
420
+ left
421
+ right
422
+ The indices of the children of 'node'.
423
+ partial_ratio
424
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
425
+ the likelihood ratio and the probability of proposing the prune
426
+ move. If the move is PRUNE, the ratio is inverted. `None` once
427
+ `log_trans_prior_ratio` has been computed.
428
+ log_trans_prior_ratio
429
+ The logarithm of the product of the transition and prior terms of the
430
+ Metropolis-Hastings ratio for the acceptance of the proposed move.
431
+ `None` if not yet computed.
432
+ grow_var
433
+ The decision axes of the new rules.
434
+ grow_split
435
+ The decision boundaries of the new rules.
436
+ var_trees
437
+ The updated decision axes of the trees, valid whatever move.
438
+ logu
439
+ The logarithm of a uniform (0, 1] random variable to be used to
440
+ accept the move. It's in (-oo, 0].
441
+ acc
442
+ Whether the move was accepted. `None` if not yet computed.
443
+ to_prune
444
+ Whether the final operation to apply the move is pruning. This indicates
445
+ an accepted prune move or a rejected grow move. `None` if not yet
446
+ computed.
447
+ """
448
+
449
+ allowed: Bool[Array, 'num_trees']
450
+ grow: Bool[Array, 'num_trees']
451
+ num_growable: UInt[Array, 'num_trees']
452
+ node: UInt[Array, 'num_trees']
453
+ left: UInt[Array, 'num_trees']
454
+ right: UInt[Array, 'num_trees']
455
+ partial_ratio: Float32[Array, 'num_trees'] | None
456
+ log_trans_prior_ratio: None | Float32[Array, 'num_trees']
457
+ grow_var: UInt[Array, 'num_trees']
458
+ grow_split: UInt[Array, 'num_trees']
459
+ var_trees: UInt[Array, 'num_trees 2**(d-1)']
460
+ logu: Float32[Array, 'num_trees']
461
+ acc: None | Bool[Array, 'num_trees']
462
+ to_prune: None | Bool[Array, 'num_trees']
463
+
464
+
465
+ def propose_moves(
466
+ key: Key[Array, ''], forest: Forest, max_split: UInt[Array, 'p']
467
+ ) -> Moves:
326
468
  """
327
469
  Propose moves for all the trees.
328
470
 
471
+ There are two types of moves: GROW (convert a leaf to a decision node and
472
+ add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
473
+ leaf, deleting its children).
474
+
329
475
  Parameters
330
476
  ----------
331
- key : jax.dtypes.prng_key array
477
+ key
332
478
  A jax random key.
333
- bart : dict
334
- BART mcmc state.
479
+ forest
480
+ The `forest` field of a BART MCMC state.
481
+ max_split
482
+ The maximum split index for each variable, found in `State`.
335
483
 
336
484
  Returns
337
485
  -------
338
- moves : dict
339
- A dictionary with fields:
340
-
341
- 'allowed' : bool array (num_trees,)
342
- Whether the move is possible.
343
- 'grow' : bool array (num_trees,)
344
- Whether the move is a grow move or a prune move.
345
- 'num_growable' : int array (num_trees,)
346
- The number of growable leaves in the original tree.
347
- 'node' : int array (num_trees,)
348
- The index of the leaf to grow or node to prune.
349
- 'left', 'right' : int array (num_trees,)
350
- The indices of the children of 'node'.
351
- 'partial_ratio' : float array (num_trees,)
352
- A factor of the Metropolis-Hastings ratio of the move. It lacks
353
- the likelihood ratio and the probability of proposing the prune
354
- move. If the move is Prune, the ratio is inverted.
355
- 'grow_var' : int array (num_trees,)
356
- The decision axes of the new rules.
357
- 'grow_split' : int array (num_trees,)
358
- The decision boundaries of the new rules.
359
- 'var_trees' : int array (num_trees, 2 ** (d - 1))
360
- The updated decision axes of the trees, valid whatever move.
361
- 'logu' : float array (num_trees,)
362
- The logarithm of a uniform (0, 1] random variable to be used to
363
- accept the move. It's in (-oo, 0].
364
- """
365
- ntree = bart['leaf_trees'].shape[0]
366
- key = random.split(key, 1 + ntree)
367
- key, subkey = key[0], key[1:]
486
+ The proposed move for each tree.
487
+ """
488
+ num_trees, _ = forest.leaf_trees.shape
489
+ keys = split(key, 1 + 2 * num_trees)
368
490
 
369
491
  # compute moves
370
- grow_moves, prune_moves = _sample_moves_vmap_trees(
371
- subkey,
372
- bart['var_trees'],
373
- bart['split_trees'],
374
- bart['affluence_trees'],
375
- bart['max_split'],
376
- bart['p_nonterminal'],
377
- bart['p_propose_grow'],
492
+ grow_moves = propose_grow_moves(
493
+ keys.pop(num_trees),
494
+ forest.var_trees,
495
+ forest.split_trees,
496
+ forest.affluence_trees,
497
+ max_split,
498
+ forest.p_nonterminal,
499
+ forest.p_propose_grow,
500
+ )
501
+ prune_moves = propose_prune_moves(
502
+ keys.pop(num_trees),
503
+ forest.split_trees,
504
+ forest.affluence_trees,
505
+ forest.p_nonterminal,
506
+ forest.p_propose_grow,
378
507
  )
379
508
 
380
- u, logu = random.uniform(key, (2, ntree), bart['opt']['large_float'])
509
+ u, logu = random.uniform(keys.pop(), (2, num_trees), jnp.float32)
381
510
 
382
511
  # choose between grow or prune
383
- grow_allowed = grow_moves['num_growable'].astype(bool)
384
- p_grow = jnp.where(grow_allowed & prune_moves['allowed'], 0.5, grow_allowed)
512
+ grow_allowed = grow_moves.num_growable.astype(bool)
513
+ p_grow = jnp.where(grow_allowed & prune_moves.allowed, 0.5, grow_allowed)
385
514
  grow = u < p_grow # use < instead of <= because u is in [0, 1)
386
515
 
387
516
  # compute children indices
388
- node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
517
+ node = jnp.where(grow, grow_moves.node, prune_moves.node)
389
518
  left = node << 1
390
519
  right = left + 1
391
520
 
392
- return dict(
393
- allowed=grow | prune_moves['allowed'],
521
+ return Moves(
522
+ allowed=grow | prune_moves.allowed,
394
523
  grow=grow,
395
- num_growable=grow_moves['num_growable'],
524
+ num_growable=grow_moves.num_growable,
396
525
  node=node,
397
526
  left=left,
398
527
  right=right,
399
528
  partial_ratio=jnp.where(
400
- grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']
529
+ grow, grow_moves.partial_ratio, prune_moves.partial_ratio
401
530
  ),
402
- grow_var=grow_moves['var'],
403
- grow_split=grow_moves['split'],
404
- var_trees=grow_moves['var_tree'],
531
+ log_trans_prior_ratio=None, # will be set in complete_ratio
532
+ grow_var=grow_moves.var,
533
+ grow_split=grow_moves.split,
534
+ var_trees=grow_moves.var_tree,
405
535
  logu=jnp.log1p(-logu),
536
+ acc=None, # will be set in accept_moves_sequential_stage
537
+ to_prune=None, # will be set in accept_moves_sequential_stage
406
538
  )
407
539
 
408
540
 
409
- @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
410
- def _sample_moves_vmap_trees(*args):
411
- key, args = args[0], args[1:]
412
- key, key1 = random.split(key)
413
- grow = grow_move(key, *args)
414
- prune = prune_move(key1, *args)
415
- return grow, prune
416
-
417
-
418
- def grow_move(
419
- key, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow
420
- ):
541
+ class GrowMoves(Module):
421
542
  """
422
- Tree structure grow move proposal of BART MCMC.
543
+ Represent a proposed grow move for each tree.
423
544
 
424
- This moves picks a leaf node and converts it to a non-terminal node with
425
- two leaf children. The move is not possible if all the leaves are already at
426
- maximum depth.
545
+ Parameters
546
+ ----------
547
+ num_growable
548
+ The number of growable leaves.
549
+ node
550
+ The index of the leaf to grow. ``2 ** d`` if there are no growable
551
+ leaves.
552
+ var
553
+ split
554
+ The decision axis and boundary of the new rule.
555
+ partial_ratio
556
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
557
+ the likelihood ratio and the probability of proposing the prune
558
+ move.
559
+ var_tree
560
+ The updated decision axes of the tree.
561
+ """
562
+
563
+ num_growable: UInt[Array, 'num_trees']
564
+ node: UInt[Array, 'num_trees']
565
+ var: UInt[Array, 'num_trees']
566
+ split: UInt[Array, 'num_trees']
567
+ partial_ratio: Float32[Array, 'num_trees']
568
+ var_tree: UInt[Array, 'num_trees 2**(d-1)']
569
+
570
+
571
+ @partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
572
+ def propose_grow_moves(
573
+ key: Key[Array, ''],
574
+ var_tree: UInt[Array, '2**(d-1)'],
575
+ split_tree: UInt[Array, '2**(d-1)'],
576
+ affluence_tree: Bool[Array, '2**(d-1)'] | None,
577
+ max_split: UInt[Array, 'p'],
578
+ p_nonterminal: Float32[Array, 'd'],
579
+ p_propose_grow: Float32[Array, '2**(d-1)'],
580
+ ) -> GrowMoves:
581
+ """
582
+ Propose a GROW move for each tree.
583
+
584
+ A GROW move picks a leaf node and converts it to a non-terminal node with
585
+ two leaf children.
427
586
 
428
587
  Parameters
429
588
  ----------
430
- var_tree : array (2 ** (d - 1),)
431
- The variable indices of the tree.
432
- split_tree : array (2 ** (d - 1),)
589
+ key
590
+ A jax random key.
591
+ var_tree
592
+ The splitting axes of the tree.
593
+ split_tree
433
594
  The splitting points of the tree.
434
- affluence_tree : bool array (2 ** (d - 1),) or None
595
+ affluence_tree
435
596
  Whether a leaf has enough points to be grown.
436
- max_split : array (p,)
597
+ max_split
437
598
  The maximum split index for each variable.
438
- p_nonterminal : array (d,)
599
+ p_nonterminal
439
600
  The probability of a nonterminal node at each depth.
440
- p_propose_grow : array (2 ** (d - 1),)
601
+ p_propose_grow
441
602
  The unnormalized probability of choosing a leaf to grow.
442
- key : jax.dtypes.prng_key array
443
- A jax random key.
444
603
 
445
604
  Returns
446
605
  -------
447
- grow_move : dict
448
- A dictionary with fields:
449
-
450
- 'num_growable' : int
451
- The number of growable leaves.
452
- 'node' : int
453
- The index of the leaf to grow. ``2 ** d`` if there are no growable
454
- leaves.
455
- 'var', 'split' : int
456
- The decision axis and boundary of the new rule.
457
- 'partial_ratio' : float
458
- A factor of the Metropolis-Hastings ratio of the move. It lacks
459
- the likelihood ratio and the probability of proposing the prune
460
- move.
461
- 'var_tree' : array (2 ** (d - 1),)
462
- The updated decision axes of the tree.
463
- """
464
-
465
- key, key1, key2 = random.split(key, 3)
606
+ An object representing the proposed move.
607
+
608
+ Notes
609
+ -----
610
+ The move is not proposed if a leaf is already at maximum depth, or if a leaf
611
+ has less than twice the requested minimum number of datapoints per leaf.
612
+ This is marked by returning `num_growable` set to 0.
613
+
614
+ The move is also not possible if the ancestors of a leaf have
615
+ exhausted the possible decision rules that lead to a non-empty selection.
616
+ This is marked by returning `var` set to `p` and `split` set to 0. But this
617
+ does not block the move from counting as "proposed", even though it is
618
+ predictably going to be rejected. This simplifies the MCMC and should not
619
+ reduce efficiency except in unrealistic corner cases.
620
+ """
621
+ keys = split(key, 3)
466
622
 
467
623
  leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
468
- key, split_tree, affluence_tree, p_propose_grow
624
+ keys.pop(), split_tree, affluence_tree, p_propose_grow
469
625
  )
470
626
 
471
- var = choose_variable(key1, var_tree, split_tree, max_split, leaf_to_grow)
627
+ var = choose_variable(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
472
628
  var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
473
629
 
474
- split = choose_split(key2, var_tree, split_tree, max_split, leaf_to_grow)
630
+ split_idx = choose_split(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
475
631
 
476
632
  ratio = compute_partial_ratio(
477
633
  prob_choose, num_prunable, p_nonterminal, leaf_to_grow
478
634
  )
479
635
 
480
- return dict(
636
+ return GrowMoves(
481
637
  num_growable=num_growable,
482
638
  node=leaf_to_grow,
483
639
  var=var,
484
- split=split,
640
+ split=split_idx,
485
641
  partial_ratio=ratio,
486
642
  var_tree=var_tree,
487
643
  )
488
644
 
645
+ # TODO it is not clear to me how var=p and split=0 when the move is not
646
+ # possible lead to correct behavior downstream. Like, the move is proposed,
647
+ # but then it's a noop? And since it's a noop, it makes no difference if
648
+ # it's "accepted" or "rejected", it's like it's always rejected, so who
649
+ # cares if the likelihood ratio or a lot of other numbers are wrong? Uhm.
650
+
489
651
 
490
- def choose_leaf(key, split_tree, affluence_tree, p_propose_grow):
652
+ def choose_leaf(
653
+ key: Key[Array, ''],
654
+ split_tree: UInt[Array, '2**(d-1)'],
655
+ affluence_tree: Bool[Array, '2**(d-1)'] | None,
656
+ p_propose_grow: Float32[Array, '2**(d-1)'],
657
+ ) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
491
658
  """
492
659
  Choose a leaf node to grow in a tree.
493
660
 
494
661
  Parameters
495
662
  ----------
496
- split_tree : array (2 ** (d - 1),)
663
+ key
664
+ A jax random key.
665
+ split_tree
497
666
  The splitting points of the tree.
498
- affluence_tree : bool array (2 ** (d - 1),) or None
499
- Whether a leaf has enough points to be grown.
500
- p_propose_grow : array (2 ** (d - 1),)
667
+ affluence_tree
668
+ Whether a leaf has enough points that it could be split into two leaves
669
+ satisfying the `min_points_per_leaf` requirement.
670
+ p_propose_grow
501
671
  The unnormalized probability of choosing a leaf to grow.
502
- key : jax.dtypes.prng_key array
503
- A jax random key.
504
672
 
505
673
  Returns
506
674
  -------
@@ -508,9 +676,11 @@ def choose_leaf(key, split_tree, affluence_tree, p_propose_grow):
508
676
  The index of the leaf to grow. If ``num_growable == 0``, return
509
677
  ``2 ** d``.
510
678
  num_growable : int
511
- The number of leaf nodes that can be grown.
679
+ The number of leaf nodes that can be grown, i.e., are nonterminal
680
+ and have at least twice `min_points_per_leaf` if set.
512
681
  prob_choose : float
513
- The normalized probability of choosing the selected leaf.
682
+ The (normalized) probability that this function had to choose that
683
+ specific leaf, given the arguments.
514
684
  num_prunable : int
515
685
  The number of leaf parents that could be pruned, after converting the
516
686
  selected leaf to a non-terminal node.
@@ -526,23 +696,26 @@ def choose_leaf(key, split_tree, affluence_tree, p_propose_grow):
526
696
  return leaf_to_grow, num_growable, prob_choose, num_prunable
527
697
 
528
698
 
529
- def growable_leaves(split_tree, affluence_tree):
699
+ def growable_leaves(
700
+ split_tree: UInt[Array, '2**(d-1)'],
701
+ affluence_tree: Bool[Array, '2**(d-1)'] | None,
702
+ ) -> Bool[Array, '2**(d-1)']:
530
703
  """
531
704
  Return a mask indicating the leaf nodes that can be proposed for growth.
532
705
 
706
+ The condition is that a leaf is not at the bottom level and has at least two
707
+ times the number of minimum points per leaf.
708
+
533
709
  Parameters
534
710
  ----------
535
- split_tree : array (2 ** (d - 1),)
711
+ split_tree
536
712
  The splitting points of the tree.
537
- affluence_tree : bool array (2 ** (d - 1),) or None
713
+ affluence_tree
538
714
  Whether a leaf has enough points to be grown.
539
715
 
540
716
  Returns
541
717
  -------
542
- is_growable : bool array (2 ** (d - 1),)
543
- The mask indicating the leaf nodes that can be proposed to grow, i.e.,
544
- that are not at the bottom level and have at least two times the number
545
- of minimum points per leaf.
718
+ The mask indicating the leaf nodes that can be proposed to grow.
546
719
  """
547
720
  is_growable = grove.is_actual_leaf(split_tree)
548
721
  if affluence_tree is not None:
@@ -550,23 +723,25 @@ def growable_leaves(split_tree, affluence_tree):
550
723
  return is_growable
551
724
 
552
725
 
553
- def categorical(key, distr):
726
+ def categorical(
727
+ key: Key[Array, ''], distr: Float32[Array, 'n']
728
+ ) -> tuple[Int32[Array, ''], Float32[Array, '']]:
554
729
  """
555
730
  Return a random integer from an arbitrary distribution.
556
731
 
557
732
  Parameters
558
733
  ----------
559
- key : jax.dtypes.prng_key array
734
+ key
560
735
  A jax random key.
561
- distr : float array (n,)
736
+ distr
562
737
  An unnormalized probability distribution.
563
738
 
564
739
  Returns
565
740
  -------
566
- u : int
741
+ u : Int32[Array, '']
567
742
  A random integer in the range ``[0, n)``. If all probabilities are zero,
568
743
  return ``n``.
569
- norm : float
744
+ norm : Float32[Array, '']
570
745
  The sum of `distr`.
571
746
  """
572
747
  ecdf = jnp.cumsum(distr)
@@ -574,27 +749,32 @@ def categorical(key, distr):
574
749
  return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
575
750
 
576
751
 
577
- def choose_variable(key, var_tree, split_tree, max_split, leaf_index):
752
+ def choose_variable(
753
+ key: Key[Array, ''],
754
+ var_tree: UInt[Array, '2**(d-1)'],
755
+ split_tree: UInt[Array, '2**(d-1)'],
756
+ max_split: UInt[Array, 'p'],
757
+ leaf_index: Int32[Array, ''],
758
+ ) -> Int32[Array, '']:
578
759
  """
579
760
  Choose a variable to split on for a new non-terminal node.
580
761
 
581
762
  Parameters
582
763
  ----------
583
- var_tree : int array (2 ** (d - 1),)
764
+ key
765
+ A jax random key.
766
+ var_tree
584
767
  The variable indices of the tree.
585
- split_tree : int array (2 ** (d - 1),)
768
+ split_tree
586
769
  The splitting points of the tree.
587
- max_split : int array (p,)
770
+ max_split
588
771
  The maximum split index for each variable.
589
- leaf_index : int
772
+ leaf_index
590
773
  The index of the leaf to grow.
591
- key : jax.dtypes.prng_key array
592
- A jax random key.
593
774
 
594
775
  Returns
595
776
  -------
596
- var : int
597
- The index of the variable to split on.
777
+ The index of the variable to split on.
598
778
 
599
779
  Notes
600
780
  -----
@@ -605,30 +785,36 @@ def choose_variable(key, var_tree, split_tree, max_split, leaf_index):
605
785
  return randint_exclude(key, max_split.size, var_to_ignore)
606
786
 
607
787
 
608
- def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
788
+ def fully_used_variables(
789
+ var_tree: UInt[Array, '2**(d-1)'],
790
+ split_tree: UInt[Array, '2**(d-1)'],
791
+ max_split: UInt[Array, 'p'],
792
+ leaf_index: Int32[Array, ''],
793
+ ) -> UInt[Array, 'd-2']:
609
794
  """
610
795
  Return a list of variables that have an empty split range at a given node.
611
796
 
612
797
  Parameters
613
798
  ----------
614
- var_tree : int array (2 ** (d - 1),)
799
+ var_tree
615
800
  The variable indices of the tree.
616
- split_tree : int array (2 ** (d - 1),)
801
+ split_tree
617
802
  The splitting points of the tree.
618
- max_split : int array (p,)
803
+ max_split
619
804
  The maximum split index for each variable.
620
- leaf_index : int
805
+ leaf_index
621
806
  The index of the node, assumed to be valid for `var_tree`.
622
807
 
623
808
  Returns
624
809
  -------
625
- var_to_ignore : int array (d - 2,)
626
- The indices of the variables that have an empty split range. Since the
627
- number of such variables is not fixed, unused values in the array are
628
- filled with `p`. The fill values are not guaranteed to be placed in any
629
- particular order. Variables may appear more than once.
630
- """
810
+ The indices of the variables that have an empty split range.
631
811
 
812
+ Notes
813
+ -----
814
+ The number of unused variables is not known in advance. Unused values in the
815
+ array are filled with `p`. The fill values are not guaranteed to be placed
816
+ in any particular order, and variables may appear more than once.
817
+ """
632
818
  var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
633
819
  split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
634
820
  l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
@@ -636,7 +822,11 @@ def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
636
822
  return jnp.where(num_split == 0, var_to_ignore, max_split.size)
637
823
 
638
824
 
639
- def ancestor_variables(var_tree, max_split, node_index):
825
+ def ancestor_variables(
826
+ var_tree: UInt[Array, '2**(d-1)'],
827
+ max_split: UInt[Array, 'p'],
828
+ node_index: Int32[Array, ''],
829
+ ) -> UInt[Array, 'd-2']:
640
830
  """
641
831
  Return the list of variables in the ancestors of a node.
642
832
 
@@ -651,14 +841,16 @@ def ancestor_variables(var_tree, max_split, node_index):
651
841
 
652
842
  Returns
653
843
  -------
654
- ancestor_vars : int array (d - 2,)
655
- The variable indices of the ancestors of the node, from the root to
656
- the parent. Unused spots are filled with `p`.
844
+ The variable indices of the ancestors of the node.
845
+
846
+ Notes
847
+ -----
848
+ The ancestors are the nodes going from the root to the parent of the node.
849
+ The number of ancestors is not known at tracing time; unused spots in the
850
+ output array are filled with `p`.
657
851
  """
658
852
  max_num_ancestors = grove.tree_depth(var_tree) - 1
659
- ancestor_vars = jnp.zeros(
660
- max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size)
661
- )
853
+ ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
662
854
  carry = ancestor_vars.size - 1, node_index, ancestor_vars
663
855
 
664
856
  def loop(carry, _):
@@ -673,27 +865,32 @@ def ancestor_variables(var_tree, max_split, node_index):
673
865
  return ancestor_vars
674
866
 
675
867
 
676
- def split_range(var_tree, split_tree, max_split, node_index, ref_var):
868
+ def split_range(
869
+ var_tree: UInt[Array, '2**(d-1)'],
870
+ split_tree: UInt[Array, '2**(d-1)'],
871
+ max_split: UInt[Array, 'p'],
872
+ node_index: Int32[Array, ''],
873
+ ref_var: Int32[Array, ''],
874
+ ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
677
875
  """
678
876
  Return the range of allowed splits for a variable at a given node.
679
877
 
680
878
  Parameters
681
879
  ----------
682
- var_tree : int array (2 ** (d - 1),)
880
+ var_tree
683
881
  The variable indices of the tree.
684
- split_tree : int array (2 ** (d - 1),)
882
+ split_tree
685
883
  The splitting points of the tree.
686
- max_split : int array (p,)
884
+ max_split
687
885
  The maximum split index for each variable.
688
- node_index : int
886
+ node_index
689
887
  The index of the node, assumed to be valid for `var_tree`.
690
- ref_var : int
888
+ ref_var
691
889
  The variable for which to measure the split range.
692
890
 
693
891
  Returns
694
892
  -------
695
- l, r : int
696
- The range of allowed splits is [l, r).
893
+ The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=0.
697
894
  """
698
895
  max_num_ancestors = grove.tree_depth(var_tree) - 1
699
896
  initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
@@ -715,26 +912,29 @@ def split_range(var_tree, split_tree, max_split, node_index, ref_var):
715
912
  return l + 1, r
716
913
 
717
914
 
718
- def randint_exclude(key, sup, exclude):
915
+ def randint_exclude(
916
+ key: Key[Array, ''], sup: int, exclude: Integer[Array, 'n']
917
+ ) -> Int32[Array, '']:
719
918
  """
720
919
  Return a random integer in a range, excluding some values.
721
920
 
722
921
  Parameters
723
922
  ----------
724
- key : jax.dtypes.prng_key array
923
+ key
725
924
  A jax random key.
726
- sup : int
925
+ sup
727
926
  The exclusive upper bound of the range.
728
- exclude : int array (n,)
927
+ exclude
729
928
  The values to exclude from the range. Values greater than or equal to
730
929
  `sup` are ignored. Values can appear more than once.
731
930
 
732
931
  Returns
733
932
  -------
734
- u : int
735
- A random integer in the range ``[0, sup)``, and which satisfies
736
- ``u not in exclude``. If all values in the range are excluded, return
737
- `sup`.
933
+ A random integer `u` in the range ``[0, sup)`` such that ``u not in exclude``.
934
+
935
+ Notes
936
+ -----
937
+ If all values in the range are excluded, return `sup`.
738
938
  """
739
939
  exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
740
940
  num_allowed = sup - jnp.count_nonzero(exclude < sup)
@@ -747,58 +947,74 @@ def randint_exclude(key, sup, exclude):
747
947
  return u
748
948
 
749
949
 
750
- def choose_split(key, var_tree, split_tree, max_split, leaf_index):
950
+ def choose_split(
951
+ key: Key[Array, ''],
952
+ var_tree: UInt[Array, '2**(d-1)'],
953
+ split_tree: UInt[Array, '2**(d-1)'],
954
+ max_split: UInt[Array, 'p'],
955
+ leaf_index: Int32[Array, ''],
956
+ ) -> Int32[Array, '']:
751
957
  """
752
958
  Choose a split point for a new non-terminal node.
753
959
 
754
960
  Parameters
755
961
  ----------
756
- var_tree : int array (2 ** (d - 1),)
757
- The variable indices of the tree.
758
- split_tree : int array (2 ** (d - 1),)
962
+ key
963
+ A jax random key.
964
+ var_tree
965
+ The splitting axes of the tree.
966
+ split_tree
759
967
  The splitting points of the tree.
760
- max_split : int array (p,)
968
+ max_split
761
969
  The maximum split index for each variable.
762
- leaf_index : int
970
+ leaf_index
763
971
  The index of the leaf to grow. It is assumed that `var_tree` already
764
972
  contains the target variable at this index.
765
- key : jax.dtypes.prng_key array
766
- A jax random key.
767
973
 
768
974
  Returns
769
975
  -------
770
- split : int
771
- The split point.
976
+ The cutpoint. If ``var_tree[leaf_index]`` is out of bounds, return 0.
772
977
  """
773
978
  var = var_tree[leaf_index]
774
979
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
775
980
  return random.randint(key, (), l, r)
776
981
 
982
+ # TODO what happens if leaf_index is out of bounds? And is the value used
983
+ # in that case?
777
984
 
778
- def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
985
+
986
+ def compute_partial_ratio(
987
+ prob_choose: Float32[Array, ''],
988
+ num_prunable: Int32[Array, ''],
989
+ p_nonterminal: Float32[Array, 'd'],
990
+ leaf_to_grow: Int32[Array, ''],
991
+ ) -> Float32[Array, '']:
779
992
  """
780
993
  Compute the product of the transition and prior ratios of a grow move.
781
994
 
782
995
  Parameters
783
996
  ----------
784
- num_growable : int
785
- The number of leaf nodes that can be grown.
786
- num_prunable : int
997
+ prob_choose
998
+ The probability that the leaf had to be chosen amongst the growable
999
+ leaves.
1000
+ num_prunable
787
1001
  The number of leaf parents that could be pruned, after converting the
788
1002
  leaf to be grown to a non-terminal node.
789
- p_nonterminal : array (d,)
1003
+ p_nonterminal
790
1004
  The probability of a nonterminal node at each depth.
791
- leaf_to_grow : int
1005
+ leaf_to_grow
792
1006
  The index of the leaf to grow.
793
1007
 
794
1008
  Returns
795
1009
  -------
796
- ratio : float
797
- The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
798
- times the prior ratio P(new tree) / P(old tree), but the transition
799
- ratio is missing the factor P(propose prune) in the numerator.
800
- """
1010
+ The partial transition ratio times the prior ratio.
801
1011
 
1012
+ Notes
1013
+ -----
1014
+ The transition ratio is P(new tree => old tree) / P(old tree => new tree).
1015
+ The "partial" transition ratio returned is missing the factor P(propose
1016
+ prune) in the numerator. The prior ratio is P(new tree) / P(old tree).
1017
+ """
802
1018
  # the two ratios also contain factors num_available_split *
803
1019
  # num_available_var, but they cancel out
804
1020
 
@@ -822,42 +1038,55 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
822
1038
  return tree_ratio / inv_trans_ratio
823
1039
 
824
1040
 
825
- def prune_move(
826
- key, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow
827
- ):
1041
+ class PruneMoves(Module):
1042
+ """
1043
+ Represent a proposed prune move for each tree.
1044
+
1045
+ Parameters
1046
+ ----------
1047
+ allowed
1048
+ Whether the move is possible.
1049
+ node
1050
+ The index of the node to prune. ``2 ** d`` if no node can be pruned.
1051
+ partial_ratio
1052
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
1053
+ the likelihood ratio and the probability of proposing the prune
1054
+ move. This ratio is inverted, and is meant to be inverted back in
1055
+ `accept_move_and_sample_leaves`.
1056
+ """
1057
+
1058
+ allowed: Bool[Array, 'num_trees']
1059
+ node: UInt[Array, 'num_trees']
1060
+ partial_ratio: Float32[Array, 'num_trees']
1061
+
1062
+
1063
+ @partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
1064
+ def propose_prune_moves(
1065
+ key: Key[Array, ''],
1066
+ split_tree: UInt[Array, '2**(d-1)'],
1067
+ affluence_tree: Bool[Array, '2**(d-1)'] | None,
1068
+ p_nonterminal: Float32[Array, 'd'],
1069
+ p_propose_grow: Float32[Array, '2**(d-1)'],
1070
+ ) -> PruneMoves:
828
1071
  """
829
1072
  Tree structure prune move proposal of BART MCMC.
830
1073
 
831
1074
  Parameters
832
1075
  ----------
833
- var_tree : int array (2 ** (d - 1),)
834
- The variable indices of the tree.
835
- split_tree : int array (2 ** (d - 1),)
1076
+ key
1077
+ A jax random key.
1078
+ split_tree
836
1079
  The splitting points of the tree.
837
- affluence_tree : bool array (2 ** (d - 1),) or None
1080
+ affluence_tree
838
1081
  Whether a leaf has enough points to be grown.
839
- max_split : int array (p,)
840
- The maximum split index for each variable.
841
- p_nonterminal : float array (d,)
1082
+ p_nonterminal
842
1083
  The probability of a nonterminal node at each depth.
843
- p_propose_grow : float array (2 ** (d - 1),)
1084
+ p_propose_grow
844
1085
  The unnormalized probability of choosing a leaf to grow.
845
- key : jax.dtypes.prng_key array
846
- A jax random key.
847
1086
 
848
1087
  Returns
849
1088
  -------
850
- prune_move : dict
851
- A dictionary with fields:
852
-
853
- 'allowed' : bool
854
- Whether the move is possible.
855
- 'node' : int
856
- The index of the node to prune. ``2 ** d`` if no node can be pruned.
857
- 'partial_ratio' : float
858
- A factor of the Metropolis-Hastings ratio of the move. It lacks
859
- the likelihood ratio and the probability of proposing the prune
860
- move. This ratio is inverted.
1089
+ An object representing the proposed moves.
861
1090
  """
862
1091
  node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
863
1092
  key, split_tree, affluence_tree, p_propose_grow
@@ -868,37 +1097,44 @@ def prune_move(
868
1097
  prob_choose, num_prunable, p_nonterminal, node_to_prune
869
1098
  )
870
1099
 
871
- return dict(
1100
+ return PruneMoves(
872
1101
  allowed=allowed,
873
1102
  node=node_to_prune,
874
- partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
1103
+ partial_ratio=ratio,
875
1104
  )
876
1105
 
877
1106
 
878
- def choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow):
1107
+ def choose_leaf_parent(
1108
+ key: Key[Array, ''],
1109
+ split_tree: UInt[Array, '2**(d-1)'],
1110
+ affluence_tree: Bool[Array, '2**(d-1)'] | None,
1111
+ p_propose_grow: Float32[Array, '2**(d-1)'],
1112
+ ) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, '']]:
879
1113
  """
880
1114
  Pick a non-terminal node with leaf children to prune in a tree.
881
1115
 
882
1116
  Parameters
883
1117
  ----------
884
- split_tree : array (2 ** (d - 1),)
1118
+ key
1119
+ A jax random key.
1120
+ split_tree
885
1121
  The splitting points of the tree.
886
- affluence_tree : bool array (2 ** (d - 1),) or None
1122
+ affluence_tree
887
1123
  Whether a leaf has enough points to be grown.
888
- p_propose_grow : array (2 ** (d - 1),)
1124
+ p_propose_grow
889
1125
  The unnormalized probability of choosing a leaf to grow.
890
- key : jax.dtypes.prng_key array
891
- A jax random key.
892
1126
 
893
1127
  Returns
894
1128
  -------
895
- node_to_prune : int
1129
+ node_to_prune : Int32[Array, '']
896
1130
  The index of the node to prune. If ``num_prunable == 0``, return
897
1131
  ``2 ** d``.
898
- num_prunable : int
1132
+ num_prunable : Int32[Array, '']
899
1133
  The number of leaf parents that could be pruned.
900
- prob_choose : float
901
- The normalized probability of choosing the node to prune for growth.
1134
+ prob_choose : Float32[Array, '']
1135
+ The (normalized) probability that `choose_leaf` would chose
1136
+ `node_to_prune` as leaf to grow, if passed the tree where
1137
+ `node_to_prune` had been pruned.
902
1138
  """
903
1139
  is_prunable = grove.is_leaves_parent(split_tree)
904
1140
  num_prunable = jnp.count_nonzero(is_prunable)
@@ -906,9 +1142,8 @@ def choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow):
906
1142
  node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
907
1143
 
908
1144
  split_tree = split_tree.at[node_to_prune].set(0)
909
- affluence_tree = (
910
- None if affluence_tree is None else affluence_tree.at[node_to_prune].set(True)
911
- )
1145
+ if affluence_tree is not None:
1146
+ affluence_tree = affluence_tree.at[node_to_prune].set(True)
912
1147
  is_growable_leaf = growable_leaves(split_tree, affluence_tree)
913
1148
  prob_choose = p_propose_grow[node_to_prune]
914
1149
  prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
@@ -916,56 +1151,196 @@ def choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow):
916
1151
  return node_to_prune, num_prunable, prob_choose
917
1152
 
918
1153
 
919
- def randint_masked(key, mask):
1154
+ def randint_masked(key: Key[Array, ''], mask: Bool[Array, 'n']) -> Int32[Array, '']:
920
1155
  """
921
1156
  Return a random integer in a range, including only some values.
922
1157
 
923
1158
  Parameters
924
1159
  ----------
925
- key : jax.dtypes.prng_key array
1160
+ key
926
1161
  A jax random key.
927
- mask : bool array (n,)
1162
+ mask
928
1163
  The mask indicating the allowed values.
929
1164
 
930
1165
  Returns
931
1166
  -------
932
- u : int
933
- A random integer in the range ``[0, n)``, and which satisfies
934
- ``mask[u] == True``. If all values in the mask are `False`, return `n`.
1167
+ A random integer in the range ``[0, n)`` such that ``mask[u] == True``.
1168
+
1169
+ Notes
1170
+ -----
1171
+ If all values in the mask are `False`, return `n`.
935
1172
  """
936
1173
  ecdf = jnp.cumsum(mask)
937
1174
  u = random.randint(key, (), 0, ecdf[-1])
938
1175
  return jnp.searchsorted(ecdf, u, 'right')
939
1176
 
940
1177
 
941
- def accept_moves_and_sample_leaves(key, bart, moves):
1178
+ def accept_moves_and_sample_leaves(
1179
+ key: Key[Array, ''], bart: State, moves: Moves
1180
+ ) -> State:
942
1181
  """
943
1182
  Accept or reject the proposed moves and sample the new leaf values.
944
1183
 
945
1184
  Parameters
946
1185
  ----------
947
- key : jax.dtypes.prng_key array
1186
+ key
948
1187
  A jax random key.
949
- bart : dict
950
- A BART mcmc state.
951
- moves : dict
952
- The proposed moves, see `sample_moves`.
1188
+ bart
1189
+ A valid BART mcmc state.
1190
+ moves
1191
+ The proposed moves, see `propose_moves`.
953
1192
 
954
1193
  Returns
955
1194
  -------
956
- bart : dict
957
- The new BART mcmc state.
1195
+ A new (valid) BART mcmc state.
958
1196
  """
959
- bart, moves, prec_trees, move_counts, move_precs, prelkv, prelk, prelf = (
960
- accept_moves_parallel_stage(key, bart, moves)
961
- )
962
- bart, moves = accept_moves_sequential_stage(
963
- bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
964
- )
1197
+ pso = accept_moves_parallel_stage(key, bart, moves)
1198
+ bart, moves = accept_moves_sequential_stage(pso)
965
1199
  return accept_moves_final_stage(bart, moves)
966
1200
 
967
1201
 
968
- def accept_moves_parallel_stage(key, bart, moves):
1202
+ class Counts(Module):
1203
+ """
1204
+ Number of datapoints in the nodes involved in proposed moves for each tree.
1205
+
1206
+ Parameters
1207
+ ----------
1208
+ left
1209
+ Number of datapoints in the left child.
1210
+ right
1211
+ Number of datapoints in the right child.
1212
+ total
1213
+ Number of datapoints in the parent (``= left + right``).
1214
+ """
1215
+
1216
+ left: UInt[Array, 'num_trees']
1217
+ right: UInt[Array, 'num_trees']
1218
+ total: UInt[Array, 'num_trees']
1219
+
1220
+
1221
+ class Precs(Module):
1222
+ """
1223
+ Likelihood precision scale in the nodes involved in proposed moves for each tree.
1224
+
1225
+ The "likelihood precision scale" of a tree node is the sum of the inverse
1226
+ squared error scales of the datapoints selected by the node.
1227
+
1228
+ Parameters
1229
+ ----------
1230
+ left
1231
+ Likelihood precision scale in the left child.
1232
+ right
1233
+ Likelihood precision scale in the right child.
1234
+ total
1235
+ Likelihood precision scale in the parent (``= left + right``).
1236
+ """
1237
+
1238
+ left: Float32[Array, 'num_trees']
1239
+ right: Float32[Array, 'num_trees']
1240
+ total: Float32[Array, 'num_trees']
1241
+
1242
+
1243
+ class PreLkV(Module):
1244
+ """
1245
+ Non-sequential terms of the likelihood ratio for each tree.
1246
+
1247
+ These terms can be computed in parallel across trees.
1248
+
1249
+ Parameters
1250
+ ----------
1251
+ sigma2_left
1252
+ The noise variance in the left child of the leaves grown or pruned by
1253
+ the moves.
1254
+ sigma2_right
1255
+ The noise variance in the right child of the leaves grown or pruned by
1256
+ the moves.
1257
+ sigma2_total
1258
+ The noise variance in the total of the leaves grown or pruned by the
1259
+ moves.
1260
+ sqrt_term
1261
+ The **logarithm** of the square root term of the likelihood ratio.
1262
+ """
1263
+
1264
+ sigma2_left: Float32[Array, 'num_trees']
1265
+ sigma2_right: Float32[Array, 'num_trees']
1266
+ sigma2_total: Float32[Array, 'num_trees']
1267
+ sqrt_term: Float32[Array, 'num_trees']
1268
+
1269
+
1270
+ class PreLk(Module):
1271
+ """
1272
+ Non-sequential terms of the likelihood ratio shared by all trees.
1273
+
1274
+ Parameters
1275
+ ----------
1276
+ exp_factor
1277
+ The factor to multiply the likelihood ratio by, shared by all trees.
1278
+ """
1279
+
1280
+ exp_factor: Float32[Array, '']
1281
+
1282
+
1283
+ class PreLf(Module):
1284
+ """
1285
+ Pre-computed terms used to sample leaves from their posterior.
1286
+
1287
+ These terms can be computed in parallel across trees.
1288
+
1289
+ Parameters
1290
+ ----------
1291
+ mean_factor
1292
+ The factor to be multiplied by the sum of the scaled residuals to
1293
+ obtain the posterior mean.
1294
+ centered_leaves
1295
+ The mean-zero normal values to be added to the posterior mean to
1296
+ obtain the posterior leaf samples.
1297
+ """
1298
+
1299
+ mean_factor: Float32[Array, 'num_trees 2**d']
1300
+ centered_leaves: Float32[Array, 'num_trees 2**d']
1301
+
1302
+
1303
+ class ParallelStageOut(Module):
1304
+ """
1305
+ The output of `accept_moves_parallel_stage`.
1306
+
1307
+ Parameters
1308
+ ----------
1309
+ bart
1310
+ A partially updated BART mcmc state.
1311
+ moves
1312
+ The proposed moves, with `partial_ratio` set to `None` and
1313
+ `log_trans_prior_ratio` set to its final value.
1314
+ prec_trees
1315
+ The likelihood precision scale in each potential or actual leaf node. If
1316
+ there is no precision scale, this is the number of points in each leaf.
1317
+ move_counts
1318
+ The counts of the number of points in the the nodes modified by the
1319
+ moves. If `bart.min_points_per_leaf` is not set and
1320
+ `bart.prec_scale` is set, they are not computed.
1321
+ move_precs
1322
+ The likelihood precision scale in each node modified by the moves. If
1323
+ `bart.prec_scale` is not set, this is set to `move_counts`.
1324
+ prelkv
1325
+ prelk
1326
+ prelf
1327
+ Objects with pre-computed terms of the likelihood ratios and leaf
1328
+ samples.
1329
+ """
1330
+
1331
+ bart: State
1332
+ moves: Moves
1333
+ prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
1334
+ move_counts: Counts | None
1335
+ move_precs: Precs | Counts
1336
+ prelkv: PreLkV
1337
+ prelk: PreLk
1338
+ prelf: PreLf
1339
+
1340
+
1341
+ def accept_moves_parallel_stage(
1342
+ key: Key[Array, ''], bart: State, moves: Moves
1343
+ ) -> ParallelStageOut:
969
1344
  """
970
1345
  Pre-computes quantities used to accept moves, in parallel across trees.
971
1346
 
@@ -976,88 +1351,110 @@ def accept_moves_parallel_stage(key, bart, moves):
976
1351
  bart : dict
977
1352
  A BART mcmc state.
978
1353
  moves : dict
979
- The proposed moves, see `sample_moves`.
1354
+ The proposed moves, see `propose_moves`.
980
1355
 
981
1356
  Returns
982
1357
  -------
983
- bart : dict
984
- A partially updated BART mcmc state.
985
- moves : dict
986
- The proposed moves, with the field 'partial_ratio' replaced
987
- by 'log_trans_prior_ratio'.
988
- prec_trees : float array (num_trees, 2 ** d)
989
- The likelihood precision scale in each potential or actual leaf node. If
990
- there is no precision scale, this is the number of points in each leaf.
991
- move_counts : dict
992
- The counts of the number of points in the the nodes modified by the
993
- moves.
994
- move_precs : dict
995
- The likelihood precision scale in each node modified by the moves.
996
- prelkv, prelk, prelf : dict
997
- Dictionary with pre-computed terms of the likelihood ratios and leaf
998
- samples.
1358
+ An object with all that could be done in parallel.
999
1359
  """
1000
- bart = bart.copy()
1001
-
1002
1360
  # where the move is grow, modify the state like the move was accepted
1003
- bart['var_trees'] = moves['var_trees']
1004
- bart['leaf_indices'] = apply_grow_to_indices(moves, bart['leaf_indices'], bart['X'])
1005
- bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], moves)
1361
+ bart = replace(
1362
+ bart,
1363
+ forest=replace(
1364
+ bart.forest,
1365
+ var_trees=moves.var_trees,
1366
+ leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
1367
+ leaf_trees=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_trees, moves),
1368
+ ),
1369
+ )
1006
1370
 
1007
1371
  # count number of datapoints per leaf
1008
- count_trees, move_counts = compute_count_trees(
1009
- bart['leaf_indices'], moves, bart['opt']['count_batch_size']
1010
- )
1011
- if bart['opt']['require_min_points']:
1012
- count_half_trees = count_trees[:, : bart['var_trees'].shape[1]]
1013
- bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
1372
+ if bart.forest.min_points_per_leaf is not None or bart.prec_scale is None:
1373
+ count_trees, move_counts = compute_count_trees(
1374
+ bart.forest.leaf_indices, moves, bart.forest.count_batch_size
1375
+ )
1376
+ else:
1377
+ # move_counts is passed later to a function, but then is unused under
1378
+ # this condition
1379
+ move_counts = None
1380
+
1381
+ # Check if some nodes can't surely be grown because they don't have enough
1382
+ # datapoints. This check is not actually used now, it will be used at the
1383
+ # beginning of the next step to propose moves.
1384
+ if bart.forest.min_points_per_leaf is not None:
1385
+ count_half_trees = count_trees[:, : bart.forest.var_trees.shape[1]]
1386
+ bart = replace(
1387
+ bart,
1388
+ forest=replace(
1389
+ bart.forest,
1390
+ affluence_trees=count_half_trees >= 2 * bart.forest.min_points_per_leaf,
1391
+ ),
1392
+ )
1014
1393
 
1015
1394
  # count number of datapoints per leaf, weighted by error precision scale
1016
- if bart['prec_scale'] is None:
1395
+ if bart.prec_scale is None:
1017
1396
  prec_trees = count_trees
1018
1397
  move_precs = move_counts
1019
1398
  else:
1020
1399
  prec_trees, move_precs = compute_prec_trees(
1021
- bart['prec_scale'],
1022
- bart['leaf_indices'],
1400
+ bart.prec_scale,
1401
+ bart.forest.leaf_indices,
1023
1402
  moves,
1024
- bart['opt']['count_batch_size'],
1403
+ bart.forest.count_batch_size,
1025
1404
  )
1026
1405
 
1027
1406
  # compute some missing information about moves
1028
- moves = complete_ratio(moves, move_counts, bart['min_points_per_leaf'])
1029
- bart['grow_prop_count'] = jnp.sum(moves['grow'])
1030
- bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~moves['grow'])
1031
-
1032
- prelkv, prelk = precompute_likelihood_terms(bart['sigma2'], move_precs)
1033
- prelf = precompute_leaf_terms(key, prec_trees, bart['sigma2'])
1407
+ moves = complete_ratio(moves, move_counts, bart.forest.min_points_per_leaf)
1408
+ bart = replace(
1409
+ bart,
1410
+ forest=replace(
1411
+ bart.forest,
1412
+ grow_prop_count=jnp.sum(moves.grow),
1413
+ prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
1414
+ ),
1415
+ )
1034
1416
 
1035
- return bart, moves, prec_trees, move_counts, move_precs, prelkv, prelk, prelf
1417
+ prelkv, prelk = precompute_likelihood_terms(
1418
+ bart.sigma2, bart.forest.sigma_mu2, move_precs
1419
+ )
1420
+ prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)
1421
+
1422
+ return ParallelStageOut(
1423
+ bart=bart,
1424
+ moves=moves,
1425
+ prec_trees=prec_trees,
1426
+ move_counts=move_counts,
1427
+ move_precs=move_precs,
1428
+ prelkv=prelkv,
1429
+ prelk=prelk,
1430
+ prelf=prelf,
1431
+ )
1036
1432
 
1037
1433
 
1038
- @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
1039
- def apply_grow_to_indices(moves, leaf_indices, X):
1434
+ @partial(vmap_nodoc, in_axes=(0, 0, None))
1435
+ def apply_grow_to_indices(
1436
+ moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
1437
+ ) -> UInt[Array, 'num_trees n']:
1040
1438
  """
1041
1439
  Update the leaf indices to apply a grow move.
1042
1440
 
1043
1441
  Parameters
1044
1442
  ----------
1045
- moves : dict
1046
- The proposed moves, see `sample_moves`.
1047
- leaf_indices : array (num_trees, n)
1443
+ moves
1444
+ The proposed moves, see `propose_moves`.
1445
+ leaf_indices
1048
1446
  The index of the leaf each datapoint falls into.
1049
- X : array (p, n)
1447
+ X
1050
1448
  The predictors matrix.
1051
1449
 
1052
1450
  Returns
1053
1451
  -------
1054
- grow_leaf_indices : array (num_trees, n)
1055
- The updated leaf indices.
1452
+ The updated leaf indices.
1056
1453
  """
1057
- left_child = moves['node'].astype(leaf_indices.dtype) << 1
1058
- go_right = X[moves['grow_var'], :] >= moves['grow_split']
1059
- tree_size = jnp.array(2 * moves['var_trees'].size)
1060
- node_to_update = jnp.where(moves['grow'], moves['node'], tree_size)
1454
+ left_child = moves.node.astype(leaf_indices.dtype) << 1
1455
+ go_right = X[moves.grow_var, :] >= moves.grow_split
1456
+ tree_size = jnp.array(2 * moves.var_trees.size)
1457
+ node_to_update = jnp.where(moves.grow, moves.node, tree_size)
1061
1458
  return jnp.where(
1062
1459
  leaf_indices == node_to_update,
1063
1460
  left_child + go_right,
@@ -1065,64 +1462,65 @@ def apply_grow_to_indices(moves, leaf_indices, X):
1065
1462
  )
1066
1463
 
1067
1464
 
1068
- def compute_count_trees(leaf_indices, moves, batch_size):
1465
+ def compute_count_trees(
1466
+ leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
1467
+ ) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
1069
1468
  """
1070
1469
  Count the number of datapoints in each leaf.
1071
1470
 
1072
1471
  Parameters
1073
1472
  ----------
1074
- leaf_indices : int array (num_trees, n)
1473
+ leaf_indices
1075
1474
  The index of the leaf each datapoint falls into, with the deeper version
1076
1475
  of the tree (post-GROW, pre-PRUNE).
1077
- moves : dict
1078
- The proposed moves, see `sample_moves`.
1079
- batch_size : int or None
1476
+ moves
1477
+ The proposed moves, see `propose_moves`.
1478
+ batch_size
1080
1479
  The data batch size to use for the summation.
1081
1480
 
1082
1481
  Returns
1083
1482
  -------
1084
- count_trees : int array (num_trees, 2 ** (d - 1))
1483
+ count_trees : Int32[Array, 'num_trees 2**d']
1085
1484
  The number of points in each potential or actual leaf node.
1086
- counts : dict
1485
+ counts : Counts
1087
1486
  The counts of the number of points in the leaves grown or pruned by the
1088
- moves, under keys 'left', 'right', and 'total' (left + right).
1487
+ moves.
1089
1488
  """
1090
-
1091
- ntree, tree_size = moves['var_trees'].shape
1489
+ num_trees, tree_size = moves.var_trees.shape
1092
1490
  tree_size *= 2
1093
- tree_indices = jnp.arange(ntree)
1491
+ tree_indices = jnp.arange(num_trees)
1094
1492
 
1095
1493
  count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
1096
1494
 
1097
1495
  # count datapoints in nodes modified by move
1098
- counts = dict()
1099
- counts['left'] = count_trees[tree_indices, moves['left']]
1100
- counts['right'] = count_trees[tree_indices, moves['right']]
1101
- counts['total'] = counts['left'] + counts['right']
1496
+ left = count_trees[tree_indices, moves.left]
1497
+ right = count_trees[tree_indices, moves.right]
1498
+ counts = Counts(left=left, right=right, total=left + right)
1102
1499
 
1103
1500
  # write count into non-leaf node
1104
- count_trees = count_trees.at[tree_indices, moves['node']].set(counts['total'])
1501
+ count_trees = count_trees.at[tree_indices, moves.node].set(counts.total)
1105
1502
 
1106
1503
  return count_trees, counts
1107
1504
 
1108
1505
 
1109
- def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
1506
+ def count_datapoints_per_leaf(
1507
+ leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
1508
+ ) -> Int32[Array, 'num_trees 2**(d-1)']:
1110
1509
  """
1111
1510
  Count the number of datapoints in each leaf.
1112
1511
 
1113
1512
  Parameters
1114
1513
  ----------
1115
- leaf_indices : int array (num_trees, n)
1514
+ leaf_indices
1116
1515
  The index of the leaf each datapoint falls into.
1117
- tree_size : int
1516
+ tree_size
1118
1517
  The size of the leaf tree array (2 ** d).
1119
- batch_size : int or None
1518
+ batch_size
1120
1519
  The data batch size to use for the summation.
1121
1520
 
1122
1521
  Returns
1123
1522
  -------
1124
- count_trees : int array (num_trees, 2 ** (d - 1))
1125
- The number of points in each leaf node.
1523
+ The number of points in each leaf node.
1126
1524
  """
1127
1525
  if batch_size is None:
1128
1526
  return _count_scan(leaf_indices, tree_size)
@@ -1130,7 +1528,9 @@ def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
1130
1528
  return _count_vec(leaf_indices, tree_size, batch_size)
1131
1529
 
1132
1530
 
1133
- def _count_scan(leaf_indices, tree_size):
1531
+ def _count_scan(
1532
+ leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
1533
+ ) -> Int32[Array, 'num_trees {tree_size}']:
1134
1534
  def loop(_, leaf_indices):
1135
1535
  return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
1136
1536
 
@@ -1138,92 +1538,111 @@ def _count_scan(leaf_indices, tree_size):
1138
1538
  return count_trees
1139
1539
 
1140
1540
 
1141
- def _aggregate_scatter(values, indices, size, dtype):
1541
+ def _aggregate_scatter(
1542
+ values: Shaped[Array, '*'],
1543
+ indices: Integer[Array, '*'],
1544
+ size: int,
1545
+ dtype: jnp.dtype,
1546
+ ) -> Shaped[Array, '{size}']:
1142
1547
  return jnp.zeros(size, dtype).at[indices].add(values)
1143
1548
 
1144
1549
 
1145
- def _count_vec(leaf_indices, tree_size, batch_size):
1550
+ def _count_vec(
1551
+ leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
1552
+ ) -> Int32[Array, 'num_trees 2**(d-1)']:
1146
1553
  return _aggregate_batched_alltrees(
1147
1554
  1, leaf_indices, tree_size, jnp.uint32, batch_size
1148
1555
  )
1149
1556
  # uint16 is super-slow on gpu, don't use it even if n < 2^16
1150
1557
 
1151
1558
 
1152
- def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
1153
- ntree, n = indices.shape
1154
- tree_indices = jnp.arange(ntree)
1559
+ def _aggregate_batched_alltrees(
1560
+ values: Shaped[Array, '*'],
1561
+ indices: UInt[Array, 'num_trees n'],
1562
+ size: int,
1563
+ dtype: jnp.dtype,
1564
+ batch_size: int,
1565
+ ) -> Shaped[Array, 'num_trees {size}']:
1566
+ num_trees, n = indices.shape
1567
+ tree_indices = jnp.arange(num_trees)
1155
1568
  nbatches = n // batch_size + bool(n % batch_size)
1156
1569
  batch_indices = jnp.arange(n) % nbatches
1157
1570
  return (
1158
- jnp.zeros((ntree, size, nbatches), dtype)
1571
+ jnp.zeros((num_trees, size, nbatches), dtype)
1159
1572
  .at[tree_indices[:, None], indices, batch_indices]
1160
1573
  .add(values)
1161
1574
  .sum(axis=2)
1162
1575
  )
1163
1576
 
1164
1577
 
1165
- def compute_prec_trees(prec_scale, leaf_indices, moves, batch_size):
1578
+ def compute_prec_trees(
1579
+ prec_scale: Float32[Array, 'n'],
1580
+ leaf_indices: UInt[Array, 'num_trees n'],
1581
+ moves: Moves,
1582
+ batch_size: int | None,
1583
+ ) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
1166
1584
  """
1167
1585
  Compute the likelihood precision scale in each leaf.
1168
1586
 
1169
1587
  Parameters
1170
1588
  ----------
1171
- prec_scale : float array (n,)
1589
+ prec_scale
1172
1590
  The scale of the precision of the error on each datapoint.
1173
- leaf_indices : int array (num_trees, n)
1591
+ leaf_indices
1174
1592
  The index of the leaf each datapoint falls into, with the deeper version
1175
1593
  of the tree (post-GROW, pre-PRUNE).
1176
- moves : dict
1177
- The proposed moves, see `sample_moves`.
1178
- batch_size : int or None
1594
+ moves
1595
+ The proposed moves, see `propose_moves`.
1596
+ batch_size
1179
1597
  The data batch size to use for the summation.
1180
1598
 
1181
1599
  Returns
1182
1600
  -------
1183
- prec_trees : float array (num_trees, 2 ** (d - 1))
1601
+ prec_trees : Float32[Array, 'num_trees 2**d']
1184
1602
  The likelihood precision scale in each potential or actual leaf node.
1185
- counts : dict
1186
- The likelihood precision scale in the leaves grown or pruned by the
1187
- moves, under keys 'left', 'right', and 'total' (left + right).
1603
+ precs : Precs
1604
+ The likelihood precision scale in the nodes involved in the moves.
1188
1605
  """
1189
-
1190
- ntree, tree_size = moves['var_trees'].shape
1606
+ num_trees, tree_size = moves.var_trees.shape
1191
1607
  tree_size *= 2
1192
- tree_indices = jnp.arange(ntree)
1608
+ tree_indices = jnp.arange(num_trees)
1193
1609
 
1194
1610
  prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)
1195
1611
 
1196
1612
  # prec datapoints in nodes modified by move
1197
- precs = dict()
1198
- precs['left'] = prec_trees[tree_indices, moves['left']]
1199
- precs['right'] = prec_trees[tree_indices, moves['right']]
1200
- precs['total'] = precs['left'] + precs['right']
1613
+ left = prec_trees[tree_indices, moves.left]
1614
+ right = prec_trees[tree_indices, moves.right]
1615
+ precs = Precs(left=left, right=right, total=left + right)
1201
1616
 
1202
1617
  # write prec into non-leaf node
1203
- prec_trees = prec_trees.at[tree_indices, moves['node']].set(precs['total'])
1618
+ prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total)
1204
1619
 
1205
1620
  return prec_trees, precs
1206
1621
 
1207
1622
 
1208
- def prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size):
1623
+ def prec_per_leaf(
1624
+ prec_scale: Float32[Array, 'n'],
1625
+ leaf_indices: UInt[Array, 'num_trees n'],
1626
+ tree_size: int,
1627
+ batch_size: int | None,
1628
+ ) -> Float32[Array, 'num_trees {tree_size}']:
1209
1629
  """
1210
1630
  Compute the likelihood precision scale in each leaf.
1211
1631
 
1212
1632
  Parameters
1213
1633
  ----------
1214
- prec_scale : float array (n,)
1634
+ prec_scale
1215
1635
  The scale of the precision of the error on each datapoint.
1216
- leaf_indices : int array (num_trees, n)
1636
+ leaf_indices
1217
1637
  The index of the leaf each datapoint falls into.
1218
- tree_size : int
1638
+ tree_size
1219
1639
  The size of the leaf tree array (2 ** d).
1220
- batch_size : int or None
1640
+ batch_size
1221
1641
  The data batch size to use for the summation.
1222
1642
 
1223
1643
  Returns
1224
1644
  -------
1225
- prec_trees : int array (num_trees, 2 ** (d - 1))
1226
- The likelihood precision scale in each leaf node.
1645
+ The likelihood precision scale in each leaf node.
1227
1646
  """
1228
1647
  if batch_size is None:
1229
1648
  return _prec_scan(prec_scale, leaf_indices, tree_size)
@@ -1231,23 +1650,34 @@ def prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size):
1231
1650
  return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
1232
1651
 
1233
1652
 
1234
- def _prec_scan(prec_scale, leaf_indices, tree_size):
1653
+ def _prec_scan(
1654
+ prec_scale: Float32[Array, 'n'],
1655
+ leaf_indices: UInt[Array, 'num_trees n'],
1656
+ tree_size: int,
1657
+ ) -> Float32[Array, 'num_trees {tree_size}']:
1235
1658
  def loop(_, leaf_indices):
1236
1659
  return None, _aggregate_scatter(
1237
1660
  prec_scale, leaf_indices, tree_size, jnp.float32
1238
- ) # TODO: use large_float
1661
+ )
1239
1662
 
1240
1663
  _, prec_trees = lax.scan(loop, None, leaf_indices)
1241
1664
  return prec_trees
1242
1665
 
1243
1666
 
1244
- def _prec_vec(prec_scale, leaf_indices, tree_size, batch_size):
1667
+ def _prec_vec(
1668
+ prec_scale: Float32[Array, 'n'],
1669
+ leaf_indices: UInt[Array, 'num_trees n'],
1670
+ tree_size: int,
1671
+ batch_size: int,
1672
+ ) -> Float32[Array, 'num_trees {tree_size}']:
1245
1673
  return _aggregate_batched_alltrees(
1246
1674
  prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
1247
- ) # TODO: use large_float
1675
+ )
1248
1676
 
1249
1677
 
1250
- def complete_ratio(moves, move_counts, min_points_per_leaf):
1678
+ def complete_ratio(
1679
+ moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
1680
+ ) -> Moves:
1251
1681
  """
1252
1682
  Complete non-likelihood MH ratio calculation.
1253
1683
 
@@ -1255,330 +1685,367 @@ def complete_ratio(moves, move_counts, min_points_per_leaf):
1255
1685
 
1256
1686
  Parameters
1257
1687
  ----------
1258
- moves : dict
1259
- The proposed moves, see `sample_moves`.
1260
- move_counts : dict
1688
+ moves
1689
+ The proposed moves, see `propose_moves`.
1690
+ move_counts
1261
1691
  The counts of the number of points in the the nodes modified by the
1262
1692
  moves.
1263
- min_points_per_leaf : int or None
1693
+ min_points_per_leaf
1264
1694
  The minimum number of data points in a leaf node.
1265
1695
 
1266
1696
  Returns
1267
1697
  -------
1268
- moves : dict
1269
- The updated moves, with the field 'partial_ratio' replaced by
1270
- 'log_trans_prior_ratio'.
1698
+ The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
1271
1699
  """
1272
- moves = moves.copy()
1273
- p_prune = compute_p_prune(
1274
- moves, move_counts['left'], move_counts['right'], min_points_per_leaf
1700
+ p_prune = compute_p_prune(moves, move_counts, min_points_per_leaf)
1701
+ return replace(
1702
+ moves,
1703
+ log_trans_prior_ratio=jnp.log(moves.partial_ratio * p_prune),
1704
+ partial_ratio=None,
1275
1705
  )
1276
- moves['log_trans_prior_ratio'] = jnp.log(moves.pop('partial_ratio') * p_prune)
1277
- return moves
1278
1706
 
1279
1707
 
1280
- def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
1708
+ def compute_p_prune(
1709
+ moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
1710
+ ) -> Float32[Array, 'num_trees']:
1281
1711
  """
1282
- Compute the probability of proposing a prune move.
1712
+ Compute the probability of proposing a prune move for each tree.
1283
1713
 
1284
1714
  Parameters
1285
1715
  ----------
1286
- moves : dict
1287
- The proposed moves, see `sample_moves`.
1288
- left_count, right_count : int
1716
+ moves
1717
+ The proposed moves, see `propose_moves`.
1718
+ move_counts
1289
1719
  The number of datapoints in the proposed children of the leaf to grow.
1290
- min_points_per_leaf : int or None
1720
+ Not used if `min_points_per_leaf` is `None`.
1721
+ min_points_per_leaf
1291
1722
  The minimum number of data points in a leaf node.
1292
1723
 
1293
1724
  Returns
1294
1725
  -------
1295
- p_prune : float
1296
- The probability of proposing a prune move. If grow: after accepting the
1297
- grow move, if prune: right away.
1298
- """
1726
+ The probability of proposing a prune move.
1299
1727
 
1728
+ Notes
1729
+ -----
1730
+ This probability is computed for going from the state with the deeper tree
1731
+ to the one with the shallower one. This means, if grow: after accepting the
1732
+ grow move, if prune: right away.
1733
+ """
1300
1734
  # calculation in case the move is grow
1301
- other_growable_leaves = moves['num_growable'] >= 2
1302
- new_leaves_growable = moves['node'] < moves['var_trees'].shape[1] // 2
1735
+ other_growable_leaves = moves.num_growable >= 2
1736
+ new_leaves_growable = moves.node < moves.var_trees.shape[1] // 2
1303
1737
  if min_points_per_leaf is not None:
1304
- any_above_threshold = left_count >= 2 * min_points_per_leaf
1305
- any_above_threshold |= right_count >= 2 * min_points_per_leaf
1738
+ assert move_counts is not None
1739
+ any_above_threshold = move_counts.left >= 2 * min_points_per_leaf
1740
+ any_above_threshold |= move_counts.right >= 2 * min_points_per_leaf
1306
1741
  new_leaves_growable &= any_above_threshold
1307
1742
  grow_again_allowed = other_growable_leaves | new_leaves_growable
1308
1743
  grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
1309
1744
 
1310
1745
  # calculation in case the move is prune
1311
- prune_p_prune = jnp.where(moves['num_growable'], 0.5, 1)
1746
+ prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)
1312
1747
 
1313
- return jnp.where(moves['grow'], grow_p_prune, prune_p_prune)
1748
+ return jnp.where(moves.grow, grow_p_prune, prune_p_prune)
1314
1749
 
1315
1750
 
1316
- @jaxext.vmap_nodoc
1317
- def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
1751
+ @vmap_nodoc
1752
+ def adapt_leaf_trees_to_grow_indices(
1753
+ leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
1754
+ ) -> Float32[Array, 'num_trees 2**d']:
1318
1755
  """
1319
- Modify leaf values such that the indices of the grow moves work on the
1320
- original tree.
1756
+ Modify leaves such that post-grow indices work on the original tree.
1757
+
1758
+ The value of the leaf to grow is copied to what would be its children if the
1759
+ grow move was accepted.
1321
1760
 
1322
1761
  Parameters
1323
1762
  ----------
1324
- leaf_trees : float array (num_trees, 2 ** d)
1763
+ leaf_trees
1325
1764
  The leaf values.
1326
- moves : dict
1327
- The proposed moves, see `sample_moves`.
1765
+ moves
1766
+ The proposed moves, see `propose_moves`.
1328
1767
 
1329
1768
  Returns
1330
1769
  -------
1331
- leaf_trees : float array (num_trees, 2 ** d)
1332
- The modified leaf values. The value of the leaf to grow is copied to
1333
- what would be its children if the grow move was accepted.
1770
+ The modified leaf values.
1334
1771
  """
1335
- values_at_node = leaf_trees[moves['node']]
1772
+ values_at_node = leaf_trees[moves.node]
1336
1773
  return (
1337
- leaf_trees.at[jnp.where(moves['grow'], moves['left'], leaf_trees.size)]
1774
+ leaf_trees.at[jnp.where(moves.grow, moves.left, leaf_trees.size)]
1338
1775
  .set(values_at_node)
1339
- .at[jnp.where(moves['grow'], moves['right'], leaf_trees.size)]
1776
+ .at[jnp.where(moves.grow, moves.right, leaf_trees.size)]
1340
1777
  .set(values_at_node)
1341
1778
  )
1342
1779
 
1343
1780
 
1344
- def precompute_likelihood_terms(sigma2, move_precs):
1781
+ def precompute_likelihood_terms(
1782
+ sigma2: Float32[Array, ''],
1783
+ sigma_mu2: Float32[Array, ''],
1784
+ move_precs: Precs | Counts,
1785
+ ) -> tuple[PreLkV, PreLk]:
1345
1786
  """
1346
1787
  Pre-compute terms used in the likelihood ratio of the acceptance step.
1347
1788
 
1348
1789
  Parameters
1349
1790
  ----------
1350
- sigma2 : float
1351
- The noise variance.
1352
- move_precs : dict
1791
+ sigma2
1792
+ The error variance, or the global error variance factor is `prec_scale`
1793
+ is set.
1794
+ sigma_mu2
1795
+ The prior variance of each leaf.
1796
+ move_precs
1353
1797
  The likelihood precision scale in the leaves grown or pruned by the
1354
1798
  moves, under keys 'left', 'right', and 'total' (left + right).
1355
1799
 
1356
1800
  Returns
1357
1801
  -------
1358
- prelkv : dict
1802
+ prelkv : PreLkV
1359
1803
  Dictionary with pre-computed terms of the likelihood ratio, one per
1360
1804
  tree.
1361
- prelk : dict
1805
+ prelk : PreLk
1362
1806
  Dictionary with pre-computed terms of the likelihood ratio, shared by
1363
1807
  all trees.
1364
1808
  """
1365
- ntree = len(move_precs['total'])
1366
- sigma_mu2 = 1 / ntree
1367
- prelkv = dict()
1368
- prelkv['sigma2_left'] = sigma2 + move_precs['left'] * sigma_mu2
1369
- prelkv['sigma2_right'] = sigma2 + move_precs['right'] * sigma_mu2
1370
- prelkv['sigma2_total'] = sigma2 + move_precs['total'] * sigma_mu2
1371
- prelkv['sqrt_term'] = (
1372
- jnp.log(
1373
- sigma2
1374
- * prelkv['sigma2_total']
1375
- / (prelkv['sigma2_left'] * prelkv['sigma2_right'])
1376
- )
1377
- / 2
1809
+ sigma2_left = sigma2 + move_precs.left * sigma_mu2
1810
+ sigma2_right = sigma2 + move_precs.right * sigma_mu2
1811
+ sigma2_total = sigma2 + move_precs.total * sigma_mu2
1812
+ prelkv = PreLkV(
1813
+ sigma2_left=sigma2_left,
1814
+ sigma2_right=sigma2_right,
1815
+ sigma2_total=sigma2_total,
1816
+ sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
1378
1817
  )
1379
- return prelkv, dict(
1818
+ return prelkv, PreLk(
1380
1819
  exp_factor=sigma_mu2 / (2 * sigma2),
1381
1820
  )
1382
1821
 
1383
1822
 
1384
- def precompute_leaf_terms(key, prec_trees, sigma2):
1823
+ def precompute_leaf_terms(
1824
+ key: Key[Array, ''],
1825
+ prec_trees: Float32[Array, 'num_trees 2**d'],
1826
+ sigma2: Float32[Array, ''],
1827
+ sigma_mu2: Float32[Array, ''],
1828
+ ) -> PreLf:
1385
1829
  """
1386
1830
  Pre-compute terms used to sample leaves from their posterior.
1387
1831
 
1388
1832
  Parameters
1389
1833
  ----------
1390
- key : jax.dtypes.prng_key array
1834
+ key
1391
1835
  A jax random key.
1392
- prec_trees : array (num_trees, 2 ** d)
1836
+ prec_trees
1393
1837
  The likelihood precision scale in each potential or actual leaf node.
1394
- sigma2 : float
1395
- The noise variance.
1838
+ sigma2
1839
+ The error variance, or the global error variance factor if `prec_scale`
1840
+ is set.
1841
+ sigma_mu2
1842
+ The prior variance of each leaf.
1396
1843
 
1397
1844
  Returns
1398
1845
  -------
1399
- prelf : dict
1400
- Dictionary with pre-computed terms of the leaf sampling, with fields:
1401
-
1402
- 'mean_factor' : float array (num_trees, 2 ** d)
1403
- The factor to be multiplied by the sum of the scaled residuals to
1404
- obtain the posterior mean.
1405
- 'centered_leaves' : float array (num_trees, 2 ** d)
1406
- The mean-zero normal values to be added to the posterior mean to
1407
- obtain the posterior leaf samples.
1846
+ Pre-computed terms for leaf sampling.
1408
1847
  """
1409
- ntree = len(prec_trees)
1410
1848
  prec_lk = prec_trees / sigma2
1411
- var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
1849
+ prec_prior = lax.reciprocal(sigma_mu2)
1850
+ var_post = lax.reciprocal(prec_lk + prec_prior)
1412
1851
  z = random.normal(key, prec_trees.shape, sigma2.dtype)
1413
- return dict(
1414
- mean_factor=var_post / sigma2, # = mean_lk * prec_lk * var_post / resid_tree
1852
+ return PreLf(
1853
+ mean_factor=var_post / sigma2,
1854
+ # mean = mean_lk * prec_lk * var_post
1855
+ # resid_tree = mean_lk * prec_tree -->
1856
+ # --> mean_lk = resid_tree / prec_tree (kind of)
1857
+ # mean_factor =
1858
+ # = mean / resid_tree =
1859
+ # = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
1860
+ # = 1 / prec_tree * prec_tree / sigma2 * var_post =
1861
+ # = var_post / sigma2
1415
1862
  centered_leaves=z * jnp.sqrt(var_post),
1416
1863
  )
1417
1864
 
1418
1865
 
1419
- def accept_moves_sequential_stage(
1420
- bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
1421
- ):
1866
+ def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
1422
1867
  """
1423
- The part of accepting the moves that has to be done one tree at a time.
1868
+ Accept/reject the moves one tree at a time.
1869
+
1870
+ This is the most performance-sensitive function because it contains all and
1871
+ only the parts of the algorithm that can not be parallelized across trees.
1424
1872
 
1425
1873
  Parameters
1426
1874
  ----------
1427
- bart : dict
1428
- A partially updated BART mcmc state.
1429
- prec_trees : float array (num_trees, 2 ** d)
1430
- The likelihood precision scale in each potential or actual leaf node.
1431
- moves : dict
1432
- The proposed moves, see `sample_moves`.
1433
- move_counts : dict
1434
- The counts of the number of points in the the nodes modified by the
1435
- moves.
1436
- move_precs : dict
1437
- The likelihood precision scale in each node modified by the moves.
1438
- prelkv, prelk, prelf : dict
1439
- Dictionaries with pre-computed terms of the likelihood ratios and leaf
1440
- samples.
1875
+ pso
1876
+ The output of `accept_moves_parallel_stage`.
1441
1877
 
1442
1878
  Returns
1443
1879
  -------
1444
- bart : dict
1880
+ bart : State
1445
1881
  A partially updated BART mcmc state.
1446
- moves : dict
1447
- The proposed moves, with these additional fields:
1448
-
1449
- 'acc' : bool array (num_trees,)
1450
- Whether the move was accepted.
1451
- 'to_prune' : bool array (num_trees,)
1452
- Whether, to reflect the acceptance status of the move, the state
1453
- should be updated by pruning the leaves involved in the move.
1882
+ moves : Moves
1883
+ The accepted/rejected moves, with `acc` and `to_prune` set.
1454
1884
  """
1455
- bart = bart.copy()
1456
- moves = moves.copy()
1457
1885
 
1458
- def loop(resid, item):
1886
+ def loop(resid, pt):
1459
1887
  resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
1460
- bart['X'],
1461
- len(bart['leaf_trees']),
1462
- bart['opt']['resid_batch_size'],
1463
1888
  resid,
1464
- bart['prec_scale'],
1465
- bart['min_points_per_leaf'],
1466
- 'ratios' in bart,
1467
- prelk,
1468
- *item,
1889
+ SeqStageInAllTrees(
1890
+ pso.bart.X,
1891
+ pso.bart.forest.resid_batch_size,
1892
+ pso.bart.prec_scale,
1893
+ pso.bart.forest.min_points_per_leaf,
1894
+ pso.bart.forest.log_likelihood is not None,
1895
+ pso.prelk,
1896
+ ),
1897
+ pt,
1469
1898
  )
1470
1899
  return resid, (leaf_tree, acc, to_prune, ratios)
1471
1900
 
1472
- items = (
1473
- bart['leaf_trees'],
1474
- prec_trees,
1475
- moves,
1476
- move_counts,
1477
- move_precs,
1478
- bart['leaf_indices'],
1479
- prelkv,
1480
- prelf,
1901
+ pts = SeqStageInPerTree(
1902
+ pso.bart.forest.leaf_trees,
1903
+ pso.prec_trees,
1904
+ pso.moves,
1905
+ pso.move_counts,
1906
+ pso.move_precs,
1907
+ pso.bart.forest.leaf_indices,
1908
+ pso.prelkv,
1909
+ pso.prelf,
1481
1910
  )
1482
- resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)
1483
-
1484
- bart['resid'] = resid
1485
- bart['leaf_trees'] = leaf_trees
1486
- bart.get('ratios', {}).update(ratios) # noop if there are no ratios
1487
- moves['acc'] = acc
1488
- moves['to_prune'] = to_prune
1911
+ resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, pso.bart.resid, pts)
1912
+
1913
+ save_ratios = pso.bart.forest.log_likelihood is not None
1914
+ bart = replace(
1915
+ pso.bart,
1916
+ resid=resid,
1917
+ forest=replace(
1918
+ pso.bart.forest,
1919
+ leaf_trees=leaf_trees,
1920
+ log_likelihood=ratios['log_likelihood'] if save_ratios else None,
1921
+ log_trans_prior=ratios['log_trans_prior'] if save_ratios else None,
1922
+ ),
1923
+ )
1924
+ moves = replace(pso.moves, acc=acc, to_prune=to_prune)
1489
1925
 
1490
1926
  return bart, moves
1491
1927
 
1492
1928
 
1493
- def accept_move_and_sample_leaves(
1494
- X,
1495
- ntree,
1496
- resid_batch_size,
1497
- resid,
1498
- prec_scale,
1499
- min_points_per_leaf,
1500
- save_ratios,
1501
- prelk,
1502
- leaf_tree,
1503
- prec_tree,
1504
- move,
1505
- move_counts,
1506
- move_precs,
1507
- leaf_indices,
1508
- prelkv,
1509
- prelf,
1510
- ):
1929
+ class SeqStageInAllTrees(Module):
1511
1930
  """
1512
- Accept or reject a proposed move and sample the new leaf values.
1931
+ The inputs to `accept_move_and_sample_leaves` that are the same for all trees.
1513
1932
 
1514
1933
  Parameters
1515
1934
  ----------
1516
- X : int array (p, n)
1935
+ X
1517
1936
  The predictors.
1518
- ntree : int
1519
- The number of trees in the forest.
1520
- resid_batch_size : int, None
1937
+ resid_batch_size
1521
1938
  The batch size for computing the sum of residuals in each leaf.
1522
- resid : float array (n,)
1523
- The residuals (data minus forest value).
1524
- prec_scale : float array (n,) or None
1939
+ prec_scale
1525
1940
  The scale of the precision of the error on each datapoint. If None, it
1526
1941
  is assumed to be 1.
1527
- min_points_per_leaf : int or None
1942
+ min_points_per_leaf
1528
1943
  The minimum number of data points in a leaf node.
1529
- save_ratios : bool
1944
+ save_ratios
1530
1945
  Whether to save the acceptance ratios.
1531
- prelk : dict
1946
+ prelk
1532
1947
  The pre-computed terms of the likelihood ratio which are shared across
1533
1948
  trees.
1534
- leaf_tree : float array (2 ** d,)
1949
+ """
1950
+
1951
+ X: UInt[Array, 'p n']
1952
+ resid_batch_size: int | None
1953
+ prec_scale: Float32[Array, 'n'] | None
1954
+ min_points_per_leaf: Int32[Array, ''] | None
1955
+ save_ratios: bool
1956
+ prelk: PreLk
1957
+
1958
+
1959
+ class SeqStageInPerTree(Module):
1960
+ """
1961
+ The inputs to `accept_move_and_sample_leaves` that are separate for each tree.
1962
+
1963
+ Parameters
1964
+ ----------
1965
+ leaf_tree
1535
1966
  The leaf values of the tree.
1536
- prec_tree : float array (2 ** d,)
1967
+ prec_tree
1537
1968
  The likelihood precision scale in each potential or actual leaf node.
1538
- move : dict
1539
- The proposed move, see `sample_moves`.
1540
- move_counts : dict
1969
+ move
1970
+ The proposed move, see `propose_moves`.
1971
+ move_counts
1541
1972
  The counts of the number of points in the the nodes modified by the
1542
1973
  moves.
1543
- move_precs : dict
1974
+ move_precs
1544
1975
  The likelihood precision scale in each node modified by the moves.
1545
- leaf_indices : int array (n,)
1976
+ leaf_indices
1546
1977
  The leaf indices for the largest version of the tree compatible with
1547
1978
  the move.
1548
- prelkv, prelf : dict
1979
+ prelkv
1980
+ prelf
1549
1981
  The pre-computed terms of the likelihood ratio and leaf sampling which
1550
1982
  are specific to the tree.
1983
+ """
1984
+
1985
+ leaf_tree: Float32[Array, '2**d']
1986
+ prec_tree: Float32[Array, '2**d']
1987
+ move: Moves
1988
+ move_counts: Counts | None
1989
+ move_precs: Precs | Counts
1990
+ leaf_indices: UInt[Array, 'n']
1991
+ prelkv: PreLkV
1992
+ prelf: PreLf
1993
+
1994
+
1995
+ def accept_move_and_sample_leaves(
1996
+ resid: Float32[Array, 'n'],
1997
+ at: SeqStageInAllTrees,
1998
+ pt: SeqStageInPerTree,
1999
+ ) -> tuple[
2000
+ Float32[Array, 'n'],
2001
+ Float32[Array, '2**d'],
2002
+ Bool[Array, ''],
2003
+ Bool[Array, ''],
2004
+ dict[str, Float32[Array, '']],
2005
+ ]:
2006
+ """
2007
+ Accept or reject a proposed move and sample the new leaf values.
2008
+
2009
+ Parameters
2010
+ ----------
2011
+ resid
2012
+ The residuals (data minus forest value).
2013
+ at
2014
+ The inputs that are the same for all trees.
2015
+ pt
2016
+ The inputs that are separate for each tree.
1551
2017
 
1552
2018
  Returns
1553
2019
  -------
1554
- resid : float array (n,)
2020
+ resid : Float32[Array, 'n']
1555
2021
  The updated residuals (data minus forest value).
1556
- leaf_tree : float array (2 ** d,)
2022
+ leaf_tree : Float32[Array, '2**d']
1557
2023
  The new leaf values of the tree.
1558
- acc : bool
2024
+ acc : Bool[Array, '']
1559
2025
  Whether the move was accepted.
1560
- to_prune : bool
2026
+ to_prune : Bool[Array, '']
1561
2027
  Whether, to reflect the acceptance status of the move, the state should
1562
2028
  be updated by pruning the leaves involved in the move.
1563
- ratios : dict
2029
+ ratios : dict[str, Float32[Array, '']]
1564
2030
  The acceptance ratios for the moves. Empty if not to be saved.
1565
2031
  """
1566
-
1567
2032
  # sum residuals in each leaf, in tree proposed by grow move
1568
- if prec_scale is None:
2033
+ if at.prec_scale is None:
1569
2034
  scaled_resid = resid
1570
2035
  else:
1571
- scaled_resid = resid * prec_scale
1572
- resid_tree = sum_resid(scaled_resid, leaf_indices, leaf_tree.size, resid_batch_size)
2036
+ scaled_resid = resid * at.prec_scale
2037
+ resid_tree = sum_resid(
2038
+ scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
2039
+ )
1573
2040
 
1574
2041
  # subtract starting tree from function
1575
- resid_tree += prec_tree * leaf_tree
2042
+ resid_tree += pt.prec_tree * pt.leaf_tree
1576
2043
 
1577
2044
  # get indices of move
1578
- node = move['node']
2045
+ node = pt.move.node
1579
2046
  assert node.dtype == jnp.int32
1580
- left = move['left']
1581
- right = move['right']
2047
+ left = pt.move.left
2048
+ right = pt.move.right
1582
2049
 
1583
2050
  # sum residuals in parent node modified by move
1584
2051
  resid_left = resid_tree[left]
@@ -1588,30 +2055,33 @@ def accept_move_and_sample_leaves(
1588
2055
 
1589
2056
  # compute acceptance ratio
1590
2057
  log_lk_ratio = compute_likelihood_ratio(
1591
- resid_total, resid_left, resid_right, prelkv, prelk
2058
+ resid_total, resid_left, resid_right, pt.prelkv, at.prelk
1592
2059
  )
1593
- log_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
1594
- log_ratio = jnp.where(move['grow'], log_ratio, -log_ratio)
2060
+ log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
2061
+ log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
1595
2062
  ratios = {}
1596
- if save_ratios:
2063
+ if at.save_ratios:
1597
2064
  ratios.update(
1598
- log_trans_prior=move['log_trans_prior_ratio'],
2065
+ log_trans_prior=pt.move.log_trans_prior_ratio,
2066
+ # TODO save log_trans_prior_ratio as a vector outside of this loop,
2067
+ # then change the option everywhere to `save_likelihood_ratio`.
1599
2068
  log_likelihood=log_lk_ratio,
1600
2069
  )
1601
2070
 
1602
2071
  # determine whether to accept the move
1603
- acc = move['allowed'] & (move['logu'] <= log_ratio)
1604
- if min_points_per_leaf is not None:
1605
- acc &= move_counts['left'] >= min_points_per_leaf
1606
- acc &= move_counts['right'] >= min_points_per_leaf
2072
+ acc = pt.move.allowed & (pt.move.logu <= log_ratio)
2073
+ if at.min_points_per_leaf is not None:
2074
+ assert pt.move_counts is not None
2075
+ acc &= pt.move_counts.left >= at.min_points_per_leaf
2076
+ acc &= pt.move_counts.right >= at.min_points_per_leaf
1607
2077
 
1608
2078
  # compute leaves posterior and sample leaves
1609
- initial_leaf_tree = leaf_tree
1610
- mean_post = resid_tree * prelf['mean_factor']
1611
- leaf_tree = mean_post + prelf['centered_leaves']
2079
+ initial_leaf_tree = pt.leaf_tree
2080
+ mean_post = resid_tree * pt.prelf.mean_factor
2081
+ leaf_tree = mean_post + pt.prelf.centered_leaves
1612
2082
 
1613
2083
  # copy leaves around such that the leaf indices point to the correct leaf
1614
- to_prune = acc ^ move['grow']
2084
+ to_prune = acc ^ pt.move.grow
1615
2085
  leaf_tree = (
1616
2086
  leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
1617
2087
  .set(leaf_tree[node])
@@ -1620,43 +2090,51 @@ def accept_move_and_sample_leaves(
1620
2090
  )
1621
2091
 
1622
2092
  # replace old tree with new tree in function values
1623
- resid += (initial_leaf_tree - leaf_tree)[leaf_indices]
2093
+ resid += (initial_leaf_tree - leaf_tree)[pt.leaf_indices]
1624
2094
 
1625
2095
  return resid, leaf_tree, acc, to_prune, ratios
1626
2096
 
1627
2097
 
1628
- def sum_resid(scaled_resid, leaf_indices, tree_size, batch_size):
2098
+ def sum_resid(
2099
+ scaled_resid: Float32[Array, 'n'],
2100
+ leaf_indices: UInt[Array, 'n'],
2101
+ tree_size: int,
2102
+ batch_size: int | None,
2103
+ ) -> Float32[Array, '{tree_size}']:
1629
2104
  """
1630
2105
  Sum the residuals in each leaf.
1631
2106
 
1632
2107
  Parameters
1633
2108
  ----------
1634
- scaled_resid : float array (n,)
2109
+ scaled_resid
1635
2110
  The residuals (data minus forest value) multiplied by the error
1636
2111
  precision scale.
1637
- leaf_indices : int array (n,)
2112
+ leaf_indices
1638
2113
  The leaf indices of the tree (in which leaf each data point falls into).
1639
- tree_size : int
2114
+ tree_size
1640
2115
  The size of the tree array (2 ** d).
1641
- batch_size : int, None
2116
+ batch_size
1642
2117
  The data batch size for the aggregation. Batching increases numerical
1643
2118
  accuracy and parallelism.
1644
2119
 
1645
2120
  Returns
1646
2121
  -------
1647
- resid_tree : float array (2 ** d,)
1648
- The sum of the residuals at data points in each leaf.
2122
+ The sum of the residuals at data points in each leaf.
1649
2123
  """
1650
2124
  if batch_size is None:
1651
2125
  aggr_func = _aggregate_scatter
1652
2126
  else:
1653
- aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
1654
- return aggr_func(
1655
- scaled_resid, leaf_indices, tree_size, jnp.float32
1656
- ) # TODO: use large_float
2127
+ aggr_func = partial(_aggregate_batched_onetree, batch_size=batch_size)
2128
+ return aggr_func(scaled_resid, leaf_indices, tree_size, jnp.float32)
1657
2129
 
1658
2130
 
1659
- def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
2131
+ def _aggregate_batched_onetree(
2132
+ values: Shaped[Array, '*'],
2133
+ indices: Integer[Array, '*'],
2134
+ size: int,
2135
+ dtype: jnp.dtype,
2136
+ batch_size: int,
2137
+ ) -> Float32[Array, '{size}']:
1660
2138
  (n,) = indices.shape
1661
2139
  nbatches = n // batch_size + bool(n % batch_size)
1662
2140
  batch_indices = jnp.arange(n) % nbatches
@@ -1668,118 +2146,133 @@ def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
1668
2146
  )
1669
2147
 
1670
2148
 
1671
- def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
2149
+ def compute_likelihood_ratio(
2150
+ total_resid: Float32[Array, ''],
2151
+ left_resid: Float32[Array, ''],
2152
+ right_resid: Float32[Array, ''],
2153
+ prelkv: PreLkV,
2154
+ prelk: PreLk,
2155
+ ) -> Float32[Array, '']:
1672
2156
  """
1673
2157
  Compute the likelihood ratio of a grow move.
1674
2158
 
1675
2159
  Parameters
1676
2160
  ----------
1677
- total_resid, left_resid, right_resid : float
2161
+ total_resid
2162
+ left_resid
2163
+ right_resid
1678
2164
  The sum of the residuals (scaled by error precision scale) of the
1679
2165
  datapoints falling in the nodes involved in the moves.
1680
- prelkv, prelk : dict
2166
+ prelkv
2167
+ prelk
1681
2168
  The pre-computed terms of the likelihood ratio, see
1682
2169
  `precompute_likelihood_terms`.
1683
2170
 
1684
2171
  Returns
1685
2172
  -------
1686
- ratio : float
1687
- The likelihood ratio P(data | new tree) / P(data | old tree).
2173
+ The likelihood ratio P(data | new tree) / P(data | old tree).
1688
2174
  """
1689
- exp_term = prelk['exp_factor'] * (
1690
- left_resid * left_resid / prelkv['sigma2_left']
1691
- + right_resid * right_resid / prelkv['sigma2_right']
1692
- - total_resid * total_resid / prelkv['sigma2_total']
2175
+ exp_term = prelk.exp_factor * (
2176
+ left_resid * left_resid / prelkv.sigma2_left
2177
+ + right_resid * right_resid / prelkv.sigma2_right
2178
+ - total_resid * total_resid / prelkv.sigma2_total
1693
2179
  )
1694
- return prelkv['sqrt_term'] + exp_term
2180
+ return prelkv.sqrt_term + exp_term
1695
2181
 
1696
2182
 
1697
- def accept_moves_final_stage(bart, moves):
2183
+ def accept_moves_final_stage(bart: State, moves: Moves) -> State:
1698
2184
  """
1699
- The final part of accepting the moves, in parallel across trees.
2185
+ Post-process the mcmc state after accepting/rejecting the moves.
2186
+
2187
+ This function is separate from `accept_moves_sequential_stage` to signal it
2188
+ can work in parallel across trees.
1700
2189
 
1701
2190
  Parameters
1702
2191
  ----------
1703
- bart : dict
2192
+ bart
1704
2193
  A partially updated BART mcmc state.
1705
- counts : dict
1706
- The indicators of proposals and acceptances for grow and prune moves.
1707
- moves : dict
1708
- The proposed moves (see `sample_moves`) as updated by
2194
+ moves
2195
+ The proposed moves (see `propose_moves`) as updated by
1709
2196
  `accept_moves_sequential_stage`.
1710
2197
 
1711
2198
  Returns
1712
2199
  -------
1713
- bart : dict
1714
- The fully updated BART mcmc state.
1715
- """
1716
- bart = bart.copy()
1717
- bart['grow_acc_count'] = jnp.sum(moves['acc'] & moves['grow'])
1718
- bart['prune_acc_count'] = jnp.sum(moves['acc'] & ~moves['grow'])
1719
- bart['leaf_indices'] = apply_moves_to_leaf_indices(bart['leaf_indices'], moves)
1720
- bart['split_trees'] = apply_moves_to_split_trees(bart['split_trees'], moves)
1721
- return bart
2200
+ The fully updated BART mcmc state.
2201
+ """
2202
+ return replace(
2203
+ bart,
2204
+ forest=replace(
2205
+ bart.forest,
2206
+ grow_acc_count=jnp.sum(moves.acc & moves.grow),
2207
+ prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
2208
+ leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
2209
+ split_trees=apply_moves_to_split_trees(bart.forest.split_trees, moves),
2210
+ ),
2211
+ )
1722
2212
 
1723
2213
 
1724
- @jaxext.vmap_nodoc
1725
- def apply_moves_to_leaf_indices(leaf_indices, moves):
2214
+ @vmap_nodoc
2215
+ def apply_moves_to_leaf_indices(
2216
+ leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
2217
+ ) -> UInt[Array, 'num_trees n']:
1726
2218
  """
1727
2219
  Update the leaf indices to match the accepted move.
1728
2220
 
1729
2221
  Parameters
1730
2222
  ----------
1731
- leaf_indices : int array (num_trees, n)
2223
+ leaf_indices
1732
2224
  The index of the leaf each datapoint falls into, if the grow move was
1733
2225
  accepted.
1734
- moves : dict
1735
- The proposed moves (see `sample_moves`), as updated by
2226
+ moves
2227
+ The proposed moves (see `propose_moves`), as updated by
1736
2228
  `accept_moves_sequential_stage`.
1737
2229
 
1738
2230
  Returns
1739
2231
  -------
1740
- leaf_indices : int array (num_trees, n)
1741
- The updated leaf indices.
2232
+ The updated leaf indices.
1742
2233
  """
1743
2234
  mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
1744
- is_child = (leaf_indices & mask) == moves['left']
2235
+ is_child = (leaf_indices & mask) == moves.left
1745
2236
  return jnp.where(
1746
- is_child & moves['to_prune'],
1747
- moves['node'].astype(leaf_indices.dtype),
2237
+ is_child & moves.to_prune,
2238
+ moves.node.astype(leaf_indices.dtype),
1748
2239
  leaf_indices,
1749
2240
  )
1750
2241
 
1751
2242
 
1752
- @jaxext.vmap_nodoc
1753
- def apply_moves_to_split_trees(split_trees, moves):
2243
+ @vmap_nodoc
2244
+ def apply_moves_to_split_trees(
2245
+ split_trees: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
2246
+ ) -> UInt[Array, 'num_trees 2**(d-1)']:
1754
2247
  """
1755
2248
  Update the split trees to match the accepted move.
1756
2249
 
1757
2250
  Parameters
1758
2251
  ----------
1759
- split_trees : int array (num_trees, 2 ** (d - 1))
2252
+ split_trees
1760
2253
  The cutpoints of the decision nodes in the initial trees.
1761
- moves : dict
1762
- The proposed moves (see `sample_moves`), as updated by
2254
+ moves
2255
+ The proposed moves (see `propose_moves`), as updated by
1763
2256
  `accept_moves_sequential_stage`.
1764
2257
 
1765
2258
  Returns
1766
2259
  -------
1767
- split_trees : int array (num_trees, 2 ** (d - 1))
1768
- The updated split trees.
2260
+ The updated split trees.
1769
2261
  """
2262
+ assert moves.to_prune is not None
1770
2263
  return (
1771
2264
  split_trees.at[
1772
2265
  jnp.where(
1773
- moves['grow'],
1774
- moves['node'],
2266
+ moves.grow,
2267
+ moves.node,
1775
2268
  split_trees.size,
1776
2269
  )
1777
2270
  ]
1778
- .set(moves['grow_split'].astype(split_trees.dtype))
2271
+ .set(moves.grow_split.astype(split_trees.dtype))
1779
2272
  .at[
1780
2273
  jnp.where(
1781
- moves['to_prune'],
1782
- moves['node'],
2274
+ moves.to_prune,
2275
+ moves.node,
1783
2276
  split_trees.size,
1784
2277
  )
1785
2278
  ]
@@ -1787,34 +2280,56 @@ def apply_moves_to_split_trees(split_trees, moves):
1787
2280
  )
1788
2281
 
1789
2282
 
1790
- def sample_sigma(key, bart):
2283
+ def step_sigma(key: Key[Array, ''], bart: State) -> State:
1791
2284
  """
1792
- Noise variance sampling step of BART MCMC.
2285
+ MCMC-update the error variance (factor).
1793
2286
 
1794
2287
  Parameters
1795
2288
  ----------
1796
- key : jax.dtypes.prng_key array
2289
+ key
1797
2290
  A jax random key.
1798
- bart : dict
1799
- A BART mcmc state, as created by `init`.
2291
+ bart
2292
+ A BART mcmc state.
1800
2293
 
1801
2294
  Returns
1802
2295
  -------
1803
- bart : dict
1804
- The new BART mcmc state.
2296
+ The new BART mcmc state, with an updated `sigma2`.
1805
2297
  """
1806
- bart = bart.copy()
1807
-
1808
- resid = bart['resid']
1809
- alpha = bart['sigma2_alpha'] + resid.size / 2
1810
- if bart['prec_scale'] is None:
2298
+ resid = bart.resid
2299
+ alpha = bart.sigma2_alpha + resid.size / 2
2300
+ if bart.prec_scale is None:
1811
2301
  scaled_resid = resid
1812
2302
  else:
1813
- scaled_resid = resid * bart['prec_scale']
2303
+ scaled_resid = resid * bart.prec_scale
1814
2304
  norm2 = resid @ scaled_resid
1815
- beta = bart['sigma2_beta'] + norm2 / 2
2305
+ beta = bart.sigma2_beta + norm2 / 2
1816
2306
 
1817
2307
  sample = random.gamma(key, alpha)
1818
- bart['sigma2'] = beta / sample
2308
+ return replace(bart, sigma2=beta / sample)
1819
2309
 
1820
- return bart
2310
+
2311
+ def step_z(key: Key[Array, ''], bart: State) -> State:
2312
+ """
2313
+ MCMC-update the latent variable for binary regression.
2314
+
2315
+ Parameters
2316
+ ----------
2317
+ key
2318
+ A jax random key.
2319
+ bart
2320
+ A BART MCMC state.
2321
+
2322
+ Returns
2323
+ -------
2324
+ The updated BART MCMC state.
2325
+ """
2326
+ trees_plus_offset = bart.z - bart.resid
2327
+ lower = jnp.where(bart.y, -trees_plus_offset, -jnp.inf)
2328
+ upper = jnp.where(bart.y, jnp.inf, -trees_plus_offset)
2329
+ resid = random.truncated_normal(key, lower, upper)
2330
+ # TODO jax's implementation of truncated_normal is not good, it just does
2331
+ # cdf inversion with erf and erf_inv. I can do better, at least avoiding to
2332
+ # compute one of the boundaries, and maybe also flipping and using ndtr
2333
+ # instead of erf for numerical stability (open an issue in jax?)
2334
+ z = trees_plus_offset + resid
2335
+ return replace(bart, z=z, resid=resid)