bartz 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/mcmcstep.py
2
2
  #
3
- # Copyright (c) 2024, Giacomo Petrillo
3
+ # Copyright (c) 2024-2025, Giacomo Petrillo
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -26,199 +26,292 @@
26
26
  Functions that implement the BART posterior MCMC initialization and update step.
27
27
 
28
28
  Functions that do MCMC steps operate by taking as input a bart state, and
29
- outputting a new dictionary with the new state. The input dict/arrays are not
30
- modified.
29
+ outputting a new state. The inputs are not modified.
31
30
 
32
- In general, integer types are chosen to be the minimal types that contain the
33
- range of possible values.
31
+ The main entry points are:
32
+
33
+ - `State`: The dataclass that represents a BART MCMC state.
34
+ - `init`: Creates an initial `State` from data and configurations.
35
+ - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
34
36
  """
35
37
 
36
- import functools
37
38
  import math
39
+ from dataclasses import replace
40
+ from functools import cache, partial
41
+ from typing import Any
38
42
 
39
43
  import jax
40
- from jax import random
44
+ from equinox import Module, field
45
+ from jax import lax, random
41
46
  from jax import numpy as jnp
42
- from jax import lax
47
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt
43
48
 
44
- from . import jaxext
45
49
  from . import grove
50
+ from .jaxext import minimal_unsigned_dtype, split, vmap_nodoc
51
+
46
52
 
47
- def init(*,
48
- X,
49
- y,
50
- max_split,
51
- num_trees,
52
- p_nonterminal,
53
- sigma2_alpha,
54
- sigma2_beta,
55
- small_float=jnp.float32,
56
- large_float=jnp.float32,
57
- min_points_per_leaf=None,
58
- resid_batch_size='auto',
59
- count_batch_size='auto',
60
- save_ratios=False,
61
- ):
53
+ class Forest(Module):
54
+ """
55
+ Represents the MCMC state of a sum of trees.
56
+
57
+ Parameters
58
+ ----------
59
+ leaf_trees
60
+ The leaf values.
61
+ var_trees
62
+ The decision axes.
63
+ split_trees
64
+ The decision boundaries.
65
+ p_nonterminal
66
+ The probability of a nonterminal node at each depth, padded with a
67
+ zero.
68
+ p_propose_grow
69
+ The unnormalized probability of picking a leaf for a grow proposal.
70
+ leaf_indices
71
 + The index of the leaf each datapoint falls into, for each tree.
72
+ min_points_per_leaf
73
+ The minimum number of data points in a leaf node.
74
+ affluence_trees
75
 + Whether a non-bottom leaf node contains twice `min_points_per_leaf`
76
+ datapoints. If `min_points_per_leaf` is not specified, this is None.
77
+ resid_batch_size
78
+ count_batch_size
79
+ The data batch sizes for computing the sufficient statistics. If `None`,
80
+ they are computed with no batching.
81
+ log_trans_prior
82
+ The log transition and prior Metropolis-Hastings ratio for the
83
+ proposed move on each tree.
84
+ log_likelihood
85
+ The log likelihood ratio.
86
+ grow_prop_count
87
+ prune_prop_count
88
+ The number of grow/prune proposals made during one full MCMC cycle.
89
+ grow_acc_count
90
+ prune_acc_count
91
+ The number of grow/prune moves accepted during one full MCMC cycle.
92
+ sigma_mu2
93
+ The prior variance of a leaf, conditional on the tree structure.
94
+ """
95
+
96
+ leaf_trees: Float32[Array, 'num_trees 2**d']
97
+ var_trees: UInt[Array, 'num_trees 2**(d-1)']
98
+ split_trees: UInt[Array, 'num_trees 2**(d-1)']
99
+ p_nonterminal: Float32[Array, 'd']
100
+ p_propose_grow: Float32[Array, '2**(d-1)']
101
+ leaf_indices: UInt[Array, 'num_trees n']
102
+ min_points_per_leaf: Int32[Array, ''] | None
103
+ affluence_trees: Bool[Array, 'num_trees 2**(d-1)'] | None
104
+ resid_batch_size: int | None = field(static=True)
105
+ count_batch_size: int | None = field(static=True)
106
+ log_trans_prior: Float32[Array, 'num_trees'] | None
107
+ log_likelihood: Float32[Array, 'num_trees'] | None
108
+ grow_prop_count: Int32[Array, '']
109
+ prune_prop_count: Int32[Array, '']
110
+ grow_acc_count: Int32[Array, '']
111
+ prune_acc_count: Int32[Array, '']
112
+ sigma_mu2: Float32[Array, '']
113
+
114
+
115
+ class State(Module):
116
+ """
117
+ Represents the MCMC state of BART.
118
+
119
+ Parameters
120
+ ----------
121
+ X
122
+ The predictors.
123
+ max_split
124
+ The maximum split index for each predictor.
125
+ y
126
+ The response. If the data type is `bool`, the model is binary regression.
127
+ resid
128
+ The residuals (`y` or `z` minus sum of trees).
129
+ z
130
+ The latent variable for binary regression. `None` in continuous
131
+ regression.
132
+ offset
133
+ Constant shift added to the sum of trees.
134
+ sigma2
135
+ The error variance. `None` in binary regression.
136
+ prec_scale
137
+ The scale on the error precision, i.e., ``1 / error_scale ** 2``.
138
+ `None` in binary regression.
139
+ sigma2_alpha
140
+ sigma2_beta
141
+ The shape and scale parameters of the inverse gamma prior on the noise
142
+ variance. `None` in binary regression.
143
+ forest
144
+ The sum of trees model.
145
+ """
146
+
147
+ X: UInt[Array, 'p n']
148
+ max_split: UInt[Array, 'p']
149
+ y: Float32[Array, 'n'] | Bool[Array, 'n']
150
+ z: None | Float32[Array, 'n']
151
+ offset: Float32[Array, '']
152
+ resid: Float32[Array, 'n']
153
+ sigma2: Float32[Array, ''] | None
154
+ prec_scale: Float32[Array, 'n'] | None
155
+ sigma2_alpha: Float32[Array, ''] | None
156
+ sigma2_beta: Float32[Array, ''] | None
157
+ forest: Forest
158
+
159
+
160
+ def init(
161
+ *,
162
+ X: UInt[Any, 'p n'],
163
+ y: Float32[Any, 'n'] | Bool[Any, 'n'],
164
+ offset: float | Float32[Any, ''] = 0.0,
165
+ max_split: UInt[Any, 'p'],
166
+ num_trees: int,
167
+ p_nonterminal: Float32[Any, 'd-1'],
168
+ sigma_mu2: float | Float32[Any, ''],
169
+ sigma2_alpha: float | Float32[Any, ''] | None = None,
170
+ sigma2_beta: float | Float32[Any, ''] | None = None,
171
+ error_scale: Float32[Any, 'n'] | None = None,
172
+ min_points_per_leaf: int | None = None,
173
+ resid_batch_size: int | None | str = 'auto',
174
+ count_batch_size: int | None | str = 'auto',
175
+ save_ratios: bool = False,
176
+ ) -> State:
62
177
  """
63
178
  Make a BART posterior sampling MCMC initial state.
64
179
 
65
180
  Parameters
66
181
  ----------
67
- X : int array (p, n)
182
+ X
68
183
  The predictors. Note this is transposed compared to the usual convention.
69
- y : float array (n,)
70
- The response.
71
- max_split : int array (p,)
184
+ y
185
+ The response. If the data type is `bool`, the regression model is binary
186
+ regression with probit.
187
+ offset
188
+ Constant shift added to the sum of trees. 0 if not specified.
189
+ max_split
72
190
  The maximum split index for each variable. All split ranges start at 1.
73
- num_trees : int
191
+ num_trees
74
192
  The number of trees in the forest.
75
- p_nonterminal : float array (d - 1,)
193
+ p_nonterminal
76
194
  The probability of a nonterminal node at each depth. The maximum depth
77
195
  of trees is fixed by the length of this array.
78
- sigma2_alpha : float
79
- The shape parameter of the inverse gamma prior on the noise variance.
80
- sigma2_beta : float
81
- The scale parameter of the inverse gamma prior on the noise variance.
82
- small_float : dtype, default float32
83
- The dtype for large arrays used in the algorithm.
84
- large_float : dtype, default float32
85
- The dtype for scalars, small arrays, and arrays which require accuracy.
86
- min_points_per_leaf : int, optional
196
+ sigma_mu2
197
+ The prior variance of a leaf, conditional on the tree structure. The
198
+ prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
199
+ prior mean of leaves is always zero.
200
+ sigma2_alpha
201
+ sigma2_beta
202
+ The shape and scale parameters of the inverse gamma prior on the error
203
+ variance. Leave unspecified for binary regression.
204
+ error_scale
205
+ Each error is scaled by the corresponding factor in `error_scale`, so
206
+ the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
207
+ Not supported for binary regression. If not specified, defaults to 1 for
208
+ all points, but potentially skipping calculations.
209
+ min_points_per_leaf
87
210
  The minimum number of data points in a leaf node. 0 if not specified.
88
- resid_batch_size, count_batch_sizes : int, None, str, default 'auto'
211
+ resid_batch_size
212
+ count_batch_size
89
213
  The batch sizes, along datapoints, for summing the residuals and
90
214
  counting the number of datapoints in each leaf. `None` for no batching.
91
215
  If 'auto', pick a value based on the device of `y`, or the default
92
216
  device.
93
- save_ratios : bool, default False
217
+ save_ratios
94
218
  Whether to save the Metropolis-Hastings ratios.
95
219
 
96
220
  Returns
97
221
  -------
98
- bart : dict
99
- A dictionary with array values, representing a BART mcmc state. The
100
- keys are:
101
-
102
- 'leaf_trees' : small_float array (num_trees, 2 ** d)
103
- The leaf values.
104
- 'var_trees' : int array (num_trees, 2 ** (d - 1))
105
- The decision axes.
106
- 'split_trees' : int array (num_trees, 2 ** (d - 1))
107
- The decision boundaries.
108
- 'resid' : large_float array (n,)
109
- The residuals (data minus forest value). Large float to avoid
110
- roundoff.
111
- 'sigma2' : large_float
112
- The noise variance.
113
- 'grow_prop_count', 'prune_prop_count' : int
114
- The number of grow/prune proposals made during one full MCMC cycle.
115
- 'grow_acc_count', 'prune_acc_count' : int
116
- The number of grow/prune moves accepted during one full MCMC cycle.
117
- 'p_nonterminal' : large_float array (d,)
118
- The probability of a nonterminal node at each depth, padded with a
119
- zero.
120
- 'p_propose_grow' : large_float array (2 ** (d - 1),)
121
- The unnormalized probability of picking a leaf for a grow proposal.
122
- 'sigma2_alpha' : large_float
123
- The shape parameter of the inverse gamma prior on the noise variance.
124
- 'sigma2_beta' : large_float
125
- The scale parameter of the inverse gamma prior on the noise variance.
126
- 'max_split' : int array (p,)
127
- The maximum split index for each variable.
128
- 'y' : small_float array (n,)
129
- The response.
130
- 'X' : int array (p, n)
131
- The predictors.
132
- 'leaf_indices' : int array (num_trees, n)
133
- The index of the leaf each datapoints falls into, for each tree.
134
- 'min_points_per_leaf' : int or None
135
- The minimum number of data points in a leaf node.
136
- 'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
137
- Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
138
- datapoints. If `min_points_per_leaf` is not specified, this is None.
139
- 'opt' : LeafDict
140
- A dictionary with config values:
141
-
142
- 'small_float' : dtype
143
- The dtype for large arrays used in the algorithm.
144
- 'large_float' : dtype
145
- The dtype for scalars, small arrays, and arrays which require
146
- accuracy.
147
- 'require_min_points' : bool
148
- Whether the `min_points_per_leaf` parameter is specified.
149
- 'resid_batch_size', 'count_batch_size' : int or None
150
- The data batch sizes for computing the sufficient statistics.
151
- 'ratios' : dict, optional
152
- If `save_ratios` is True, this field is present. It has the fields:
153
-
154
- 'log_trans_prior' : large_float array (num_trees,)
155
- The log transition and prior Metropolis-Hastings ratio for the
156
- proposed move on each tree.
157
- 'log_likelihood' : large_float array (num_trees,)
158
- The log likelihood ratio.
159
- """
160
-
161
- p_nonterminal = jnp.asarray(p_nonterminal, large_float)
222
+ An initialized BART MCMC state.
223
+
224
+ Raises
225
+ ------
226
+ ValueError
227
+ If `y` is boolean and arguments unused in binary regression are set.
228
+ """
229
+ p_nonterminal = jnp.asarray(p_nonterminal)
162
230
  p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
163
231
  max_depth = p_nonterminal.size
164
232
 
165
- @functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
233
+ @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
166
234
  def make_forest(max_depth, dtype):
167
235
  return grove.make_tree(max_depth, dtype)
168
236
 
169
- small_float = jnp.dtype(small_float)
170
- large_float = jnp.dtype(large_float)
171
- y = jnp.asarray(y, small_float)
172
- resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, 2 ** max_depth * num_trees)
173
- sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
174
- sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
175
-
176
- bart = dict(
177
- leaf_trees=make_forest(max_depth, small_float),
178
- var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
179
- split_trees=make_forest(max_depth - 1, max_split.dtype),
180
- resid=jnp.asarray(y, large_float),
181
- sigma2=sigma2,
182
- grow_prop_count=jnp.zeros((), int),
183
- grow_acc_count=jnp.zeros((), int),
184
- prune_prop_count=jnp.zeros((), int),
185
- prune_acc_count=jnp.zeros((), int),
186
- p_nonterminal=p_nonterminal,
187
- p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
188
- sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
189
- sigma2_beta=jnp.asarray(sigma2_beta, large_float),
237
+ y = jnp.asarray(y)
238
+ offset = jnp.asarray(offset)
239
+
240
+ resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
241
+ resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
242
+ )
243
+
244
+ is_binary = y.dtype == bool
245
+ if is_binary:
246
+ if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
247
+ raise ValueError(
248
+ 'error_scale, sigma2_alpha, and sigma2_beta must be set '
249
+ ' to `None` for binary regression.'
250
+ )
251
+ sigma2 = None
252
+ else:
253
+ sigma2_alpha = jnp.asarray(sigma2_alpha)
254
+ sigma2_beta = jnp.asarray(sigma2_beta)
255
+ sigma2 = sigma2_beta / sigma2_alpha
256
+ # sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
257
+ # TODO: I don't like this isfinite check, these functions should be
258
+ # low-level and just do the thing. Why was it here?
259
+
260
+ bart = State(
261
+ X=jnp.asarray(X),
190
262
  max_split=jnp.asarray(max_split),
191
263
  y=y,
192
- X=jnp.asarray(X),
193
- leaf_indices=jnp.ones((num_trees, y.size), jaxext.minimal_unsigned_dtype(2 ** max_depth - 1)),
194
- min_points_per_leaf=(
195
- None if min_points_per_leaf is None else
196
- jnp.asarray(min_points_per_leaf)
197
- ),
198
- affluence_trees=(
199
- None if min_points_per_leaf is None else
200
- make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
264
+ z=jnp.full(y.shape, offset) if is_binary else None,
265
+ offset=offset,
266
+ resid=jnp.zeros(y.shape) if is_binary else y - offset,
267
+ sigma2=sigma2,
268
+ prec_scale=(
269
+ None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
201
270
  ),
202
- opt=jaxext.LeafDict(
203
- small_float=small_float,
204
- large_float=large_float,
205
- require_min_points=min_points_per_leaf is not None,
271
+ sigma2_alpha=sigma2_alpha,
272
+ sigma2_beta=sigma2_beta,
273
+ forest=Forest(
274
+ leaf_trees=make_forest(max_depth, jnp.float32),
275
+ var_trees=make_forest(
276
+ max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)
277
+ ),
278
+ split_trees=make_forest(max_depth - 1, max_split.dtype),
279
+ grow_prop_count=jnp.zeros((), int),
280
+ grow_acc_count=jnp.zeros((), int),
281
+ prune_prop_count=jnp.zeros((), int),
282
+ prune_acc_count=jnp.zeros((), int),
283
+ p_nonterminal=p_nonterminal,
284
+ p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
285
+ leaf_indices=jnp.ones(
286
+ (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
287
+ ),
288
+ min_points_per_leaf=(
289
+ None
290
+ if min_points_per_leaf is None
291
+ else jnp.asarray(min_points_per_leaf)
292
+ ),
293
+ affluence_trees=(
294
+ None
295
+ if min_points_per_leaf is None
296
+ else make_forest(max_depth - 1, bool)
297
+ .at[:, 1]
298
+ .set(y.size >= 2 * min_points_per_leaf)
299
+ ),
206
300
  resid_batch_size=resid_batch_size,
207
301
  count_batch_size=count_batch_size,
302
+ log_trans_prior=jnp.full(num_trees, jnp.nan) if save_ratios else None,
303
+ log_likelihood=jnp.full(num_trees, jnp.nan) if save_ratios else None,
304
+ sigma_mu2=jnp.asarray(sigma_mu2),
208
305
  ),
209
306
  )
210
307
 
211
- if save_ratios:
212
- bart['ratios'] = dict(
213
- log_trans_prior=jnp.full(num_trees, jnp.nan),
214
- log_likelihood=jnp.full(num_trees, jnp.nan),
215
- )
216
-
217
308
  return bart
218
309
 
219
- def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
220
310
 
221
- @functools.cache
311
+ def _choose_suffstat_batch_size(
312
+ resid_batch_size, count_batch_size, y, forest_size
313
+ ) -> tuple[int | None, ...]:
314
+ @cache
222
315
  def get_platform():
223
316
  try:
224
317
  device = y.devices().pop()
@@ -233,9 +326,9 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
233
326
  platform = get_platform()
234
327
  n = max(1, y.size)
235
328
  if platform == 'cpu':
236
- resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
329
+ resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
237
330
  elif platform == 'gpu':
238
- resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
331
+ resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
239
332
  resid_batch_size = max(1, resid_batch_size)
240
333
 
241
334
  if count_batch_size == 'auto':
@@ -244,9 +337,9 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
244
337
  count_batch_size = None
245
338
  elif platform == 'gpu':
246
339
  n = max(1, y.size)
247
- count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
248
- # /4 is good on V100, /2 on L4/T4, still haven't tried A100
249
- max_memory = 2 ** 29
340
+ count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
341
+ # /4 is good on V100, /2 on L4/T4, still haven't tried A100
342
+ max_memory = 2**29
250
343
  itemsize = 4
251
344
  min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
252
345
  count_batch_size = max(count_batch_size, min_batch_size)
@@ -254,210 +347,328 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
254
347
 
255
348
  return resid_batch_size, count_batch_size
256
349
 
257
- def step(bart, key):
350
+
351
+ @jax.jit
352
+ def step(key: Key[Array, ''], bart: State) -> State:
258
353
  """
259
- Perform one full MCMC step on a BART state.
354
+ Do one MCMC step.
260
355
 
261
356
  Parameters
262
357
  ----------
263
- bart : dict
264
- A BART mcmc state, as created by `init`.
265
- key : jax.dtypes.prng_key array
358
+ key
266
359
  A jax random key.
360
+ bart
361
+ A BART mcmc state, as created by `init`.
267
362
 
268
363
  Returns
269
364
  -------
270
- bart : dict
271
- The new BART mcmc state.
365
+ The new BART mcmc state.
272
366
  """
273
- key, subkey = random.split(key)
274
- bart = sample_trees(bart, subkey)
275
- return sample_sigma(bart, key)
367
+ keys = split(key)
368
+
369
+ if bart.y.dtype == bool: # binary regression
370
+ bart = replace(bart, sigma2=jnp.float32(1))
371
+ bart = step_trees(keys.pop(), bart)
372
+ bart = replace(bart, sigma2=None)
373
+ return step_z(keys.pop(), bart)
276
374
 
277
- def sample_trees(bart, key):
375
+ else: # continuous regression
376
+ bart = step_trees(keys.pop(), bart)
377
+ return step_sigma(keys.pop(), bart)
378
+
379
+
380
+ def step_trees(key: Key[Array, ''], bart: State) -> State:
278
381
  """
279
382
  Forest sampling step of BART MCMC.
280
383
 
281
384
  Parameters
282
385
  ----------
283
- bart : dict
284
- A BART mcmc state, as created by `init`.
285
- key : jax.dtypes.prng_key array
386
+ key
286
387
  A jax random key.
388
+ bart
389
+ A BART mcmc state, as created by `init`.
287
390
 
288
391
  Returns
289
392
  -------
290
- bart : dict
291
- The new BART mcmc state.
393
+ The new BART mcmc state.
292
394
 
293
395
  Notes
294
396
  -----
295
397
  This function zeroes the proposal counters.
296
398
  """
297
- key, subkey = random.split(key)
298
- moves = sample_moves(bart, subkey)
299
- return accept_moves_and_sample_leaves(bart, moves, key)
399
+ keys = split(key)
400
+ moves = propose_moves(keys.pop(), bart.forest, bart.max_split)
401
+ return accept_moves_and_sample_leaves(keys.pop(), bart, moves)
402
+
403
+
404
+ class Moves(Module):
405
+ """
406
+ Moves proposed to modify each tree.
300
407
 
301
- def sample_moves(bart, key):
408
+ Parameters
409
+ ----------
410
+ allowed
411
+ Whether the move is possible in the first place. There are additional
412
+ constraints that could forbid it, but they are computed at acceptance
413
+ time.
414
+ grow
415
+ Whether the move is a grow move or a prune move.
416
+ num_growable
417
+ The number of growable leaves in the original tree.
418
+ node
419
+ The index of the leaf to grow or node to prune.
420
+ left
421
+ right
422
+ The indices of the children of 'node'.
423
+ partial_ratio
424
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
425
+ the likelihood ratio and the probability of proposing the prune
426
+ move. If the move is PRUNE, the ratio is inverted. `None` once
427
+ `log_trans_prior_ratio` has been computed.
428
+ log_trans_prior_ratio
429
+ The logarithm of the product of the transition and prior terms of the
430
+ Metropolis-Hastings ratio for the acceptance of the proposed move.
431
+ `None` if not yet computed.
432
+ grow_var
433
+ The decision axes of the new rules.
434
+ grow_split
435
+ The decision boundaries of the new rules.
436
+ var_trees
437
+ The updated decision axes of the trees, valid whatever move.
438
+ logu
439
+ The logarithm of a uniform (0, 1] random variable to be used to
440
+ accept the move. It's in (-oo, 0].
441
+ acc
442
+ Whether the move was accepted. `None` if not yet computed.
443
+ to_prune
444
+ Whether the final operation to apply the move is pruning. This indicates
445
+ an accepted prune move or a rejected grow move. `None` if not yet
446
+ computed.
447
+ """
448
+
449
+ allowed: Bool[Array, 'num_trees']
450
+ grow: Bool[Array, 'num_trees']
451
+ num_growable: UInt[Array, 'num_trees']
452
+ node: UInt[Array, 'num_trees']
453
+ left: UInt[Array, 'num_trees']
454
+ right: UInt[Array, 'num_trees']
455
+ partial_ratio: Float32[Array, 'num_trees'] | None
456
+ log_trans_prior_ratio: None | Float32[Array, 'num_trees']
457
+ grow_var: UInt[Array, 'num_trees']
458
+ grow_split: UInt[Array, 'num_trees']
459
+ var_trees: UInt[Array, 'num_trees 2**(d-1)']
460
+ logu: Float32[Array, 'num_trees']
461
+ acc: None | Bool[Array, 'num_trees']
462
+ to_prune: None | Bool[Array, 'num_trees']
463
+
464
+
465
+ def propose_moves(
466
+ key: Key[Array, ''], forest: Forest, max_split: UInt[Array, 'p']
467
+ ) -> Moves:
302
468
  """
303
469
  Propose moves for all the trees.
304
470
 
471
+ There are two types of moves: GROW (convert a leaf to a decision node and
472
+ add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
473
+ leaf, deleting its children).
474
+
305
475
  Parameters
306
476
  ----------
307
- bart : dict
308
- BART mcmc state.
309
- key : jax.dtypes.prng_key array
477
+ key
310
478
  A jax random key.
479
+ forest
480
+ The `forest` field of a BART MCMC state.
481
+ max_split
482
+ The maximum split index for each variable, found in `State`.
311
483
 
312
484
  Returns
313
485
  -------
314
- moves : dict
315
- A dictionary with fields:
316
-
317
- 'allowed' : bool array (num_trees,)
318
- Whether the move is possible.
319
- 'grow' : bool array (num_trees,)
320
- Whether the move is a grow move or a prune move.
321
- 'num_growable' : int array (num_trees,)
322
- The number of growable leaves in the original tree.
323
- 'node' : int array (num_trees,)
324
- The index of the leaf to grow or node to prune.
325
- 'left', 'right' : int array (num_trees,)
326
- The indices of the children of 'node'.
327
- 'partial_ratio' : float array (num_trees,)
328
- A factor of the Metropolis-Hastings ratio of the move. It lacks
329
- the likelihood ratio and the probability of proposing the prune
330
- move. If the move is Prune, the ratio is inverted.
331
- 'grow_var' : int array (num_trees,)
332
- The decision axes of the new rules.
333
- 'grow_split' : int array (num_trees,)
334
- The decision boundaries of the new rules.
335
- 'var_trees' : int array (num_trees, 2 ** (d - 1))
336
- The updated decision axes of the trees, valid whatever move.
337
- 'logu' : float array (num_trees,)
338
- The logarithm of a uniform (0, 1] random variable to be used to
339
- accept the move. It's in (-oo, 0].
340
- """
341
- ntree = bart['leaf_trees'].shape[0]
342
- key = random.split(key, 1 + ntree)
343
- key, subkey = key[0], key[1:]
486
+ The proposed move for each tree.
487
+ """
488
+ num_trees, _ = forest.leaf_trees.shape
489
+ keys = split(key, 1 + 2 * num_trees)
344
490
 
345
491
  # compute moves
346
- grow_moves, prune_moves = _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], subkey)
492
+ grow_moves = propose_grow_moves(
493
+ keys.pop(num_trees),
494
+ forest.var_trees,
495
+ forest.split_trees,
496
+ forest.affluence_trees,
497
+ max_split,
498
+ forest.p_nonterminal,
499
+ forest.p_propose_grow,
500
+ )
501
+ prune_moves = propose_prune_moves(
502
+ keys.pop(num_trees),
503
+ forest.split_trees,
504
+ forest.affluence_trees,
505
+ forest.p_nonterminal,
506
+ forest.p_propose_grow,
507
+ )
347
508
 
348
- u, logu = random.uniform(key, (2, ntree), bart['opt']['large_float'])
509
+ u, logu = random.uniform(keys.pop(), (2, num_trees), jnp.float32)
349
510
 
350
511
  # choose between grow or prune
351
- grow_allowed = grow_moves['num_growable'].astype(bool)
352
- p_grow = jnp.where(grow_allowed & prune_moves['allowed'], 0.5, grow_allowed)
353
- grow = u < p_grow # use < instead of <= because u is in [0, 1)
512
+ grow_allowed = grow_moves.num_growable.astype(bool)
513
+ p_grow = jnp.where(grow_allowed & prune_moves.allowed, 0.5, grow_allowed)
514
+ grow = u < p_grow # use < instead of <= because u is in [0, 1)
354
515
 
355
516
  # compute children indices
356
- node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
517
+ node = jnp.where(grow, grow_moves.node, prune_moves.node)
357
518
  left = node << 1
358
519
  right = left + 1
359
520
 
360
- return dict(
361
- allowed=grow | prune_moves['allowed'],
521
+ return Moves(
522
+ allowed=grow | prune_moves.allowed,
362
523
  grow=grow,
363
- num_growable=grow_moves['num_growable'],
524
+ num_growable=grow_moves.num_growable,
364
525
  node=node,
365
526
  left=left,
366
527
  right=right,
367
- partial_ratio=jnp.where(grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']),
368
- grow_var=grow_moves['var'],
369
- grow_split=grow_moves['split'],
370
- var_trees=grow_moves['var_tree'],
528
+ partial_ratio=jnp.where(
529
+ grow, grow_moves.partial_ratio, prune_moves.partial_ratio
530
+ ),
531
+ log_trans_prior_ratio=None, # will be set in complete_ratio
532
+ grow_var=grow_moves.var,
533
+ grow_split=grow_moves.split,
534
+ var_trees=grow_moves.var_tree,
371
535
  logu=jnp.log1p(-logu),
536
+ acc=None, # will be set in accept_moves_sequential_stage
537
+ to_prune=None, # will be set in accept_moves_sequential_stage
372
538
  )
373
539
 
374
- @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
375
- def _sample_moves_vmap_trees(*args):
376
- args, key = args[:-1], args[-1]
377
- key, key1 = random.split(key)
378
- grow = grow_move(*args, key)
379
- prune = prune_move(*args, key1)
380
- return grow, prune
381
540
 
382
- def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
541
+ class GrowMoves(Module):
383
542
  """
384
- Tree structure grow move proposal of BART MCMC.
543
+ Represent a proposed grow move for each tree.
385
544
 
386
- This moves picks a leaf node and converts it to a non-terminal node with
387
- two leaf children. The move is not possible if all the leaves are already at
388
- maximum depth.
545
+ Parameters
546
+ ----------
547
+ num_growable
548
+ The number of growable leaves.
549
+ node
550
+ The index of the leaf to grow. ``2 ** d`` if there are no growable
551
+ leaves.
552
+ var
553
+ split
554
+ The decision axis and boundary of the new rule.
555
+ partial_ratio
556
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
557
+ the likelihood ratio and the probability of proposing the prune
558
+ move.
559
+ var_tree
560
+ The updated decision axes of the tree.
561
+ """
562
+
563
+ num_growable: UInt[Array, 'num_trees']
564
+ node: UInt[Array, 'num_trees']
565
+ var: UInt[Array, 'num_trees']
566
+ split: UInt[Array, 'num_trees']
567
+ partial_ratio: Float32[Array, 'num_trees']
568
+ var_tree: UInt[Array, 'num_trees 2**(d-1)']
569
+
570
+
571
+ @partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
572
+ def propose_grow_moves(
573
+ key: Key[Array, ''],
574
+ var_tree: UInt[Array, '2**(d-1)'],
575
+ split_tree: UInt[Array, '2**(d-1)'],
576
+ affluence_tree: Bool[Array, '2**(d-1)'] | None,
577
+ max_split: UInt[Array, 'p'],
578
+ p_nonterminal: Float32[Array, 'd'],
579
+ p_propose_grow: Float32[Array, '2**(d-1)'],
580
+ ) -> GrowMoves:
581
+ """
582
+ Propose a GROW move for each tree.
583
+
584
+ A GROW move picks a leaf node and converts it to a non-terminal node with
585
+ two leaf children.
389
586
 
390
587
  Parameters
391
588
  ----------
392
- var_tree : array (2 ** (d - 1),)
393
- The variable indices of the tree.
394
- split_tree : array (2 ** (d - 1),)
589
+ key
590
+ A jax random key.
591
+ var_tree
592
+ The splitting axes of the tree.
593
+ split_tree
395
594
  The splitting points of the tree.
396
- affluence_tree : bool array (2 ** (d - 1),) or None
595
+ affluence_tree
397
596
  Whether a leaf has enough points to be grown.
398
- max_split : array (p,)
597
+ max_split
399
598
  The maximum split index for each variable.
400
- p_nonterminal : array (d,)
599
+ p_nonterminal
401
600
  The probability of a nonterminal node at each depth.
402
- p_propose_grow : array (2 ** (d - 1),)
601
+ p_propose_grow
403
602
  The unnormalized probability of choosing a leaf to grow.
404
- key : jax.dtypes.prng_key array
405
- A jax random key.
406
603
 
407
604
  Returns
408
605
  -------
409
- grow_move : dict
410
- A dictionary with fields:
606
+ An object representing the proposed move.
411
607
 
412
- 'num_growable' : int
413
- The number of growable leaves.
414
- 'node' : int
415
- The index of the leaf to grow. ``2 ** d`` if there are no growable
416
- leaves.
417
- 'var', 'split' : int
418
- The decision axis and boundary of the new rule.
419
- 'partial_ratio' : float
420
- A factor of the Metropolis-Hastings ratio of the move. It lacks
421
- the likelihood ratio and the probability of proposing the prune
422
- move.
423
- 'var_tree' : array (2 ** (d - 1),)
424
- The updated decision axes of the tree.
425
- """
608
+ Notes
609
+ -----
610
+ The move is not proposed if a leaf is already at maximum depth, or if a leaf
611
+ has less than twice the requested minimum number of datapoints per leaf.
612
+ This is marked by returning `num_growable` set to 0.
426
613
 
427
- key, key1, key2 = random.split(key, 3)
614
+ The move is also not be possible if the ancestors of a leaf have
615
+ exhausted the possible decision rules that lead to a non-empty selection.
616
+ This is marked by returning `var` set to `p` and `split` set to 0. But this
617
+ does not block the move from counting as "proposed", even though it is
618
+ predictably going to be rejected. This simplifies the MCMC and should not
619
+ reduce efficiency if not in unrealistic corner cases.
620
+ """
621
+ keys = split(key, 3)
428
622
 
429
- leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(split_tree, affluence_tree, p_propose_grow, key)
623
+ leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
624
+ keys.pop(), split_tree, affluence_tree, p_propose_grow
625
+ )
430
626
 
431
- var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
627
+ var = choose_variable(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
432
628
  var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
433
629
 
434
- split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
630
+ split_idx = choose_split(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
435
631
 
436
- ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow)
632
+ ratio = compute_partial_ratio(
633
+ prob_choose, num_prunable, p_nonterminal, leaf_to_grow
634
+ )
437
635
 
438
- return dict(
636
+ return GrowMoves(
439
637
  num_growable=num_growable,
440
638
  node=leaf_to_grow,
441
639
  var=var,
442
- split=split,
640
+ split=split_idx,
443
641
  partial_ratio=ratio,
444
642
  var_tree=var_tree,
445
643
  )
446
644
 
447
- def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
645
+ # TODO it is not clear to me how var=p and split=0 when the move is not
646
 + # possible lead to correct behavior downstream. Like, the move is proposed,
647
+ # but then it's a noop? And since it's a noop, it makes no difference if
648
+ # it's "accepted" or "rejected", it's like it's always rejected, so who
649
+ # cares if the likelihood ratio or a lot of other numbers are wrong? Uhm.
650
+
651
+
652
+ def choose_leaf(
653
+ key: Key[Array, ''],
654
+ split_tree: UInt[Array, '2**(d-1)'],
655
+ affluence_tree: Bool[Array, '2**(d-1)'] | None,
656
+ p_propose_grow: Float32[Array, '2**(d-1)'],
657
+ ) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
448
658
  """
449
659
  Choose a leaf node to grow in a tree.
450
660
 
451
661
  Parameters
452
662
  ----------
453
- split_tree : array (2 ** (d - 1),)
663
+ key
664
+ A jax random key.
665
+ split_tree
454
666
  The splitting points of the tree.
455
- affluence_tree : bool array (2 ** (d - 1),) or None
456
- Whether a leaf has enough points to be grown.
457
- p_propose_grow : array (2 ** (d - 1),)
667
+ affluence_tree
668
+ Whether a leaf has enough points that it could be split into two leaves
669
+ satisfying the `min_points_per_leaf` requirement.
670
+ p_propose_grow
458
671
  The unnormalized probability of choosing a leaf to grow.
459
- key : jax.dtypes.prng_key array
460
- A jax random key.
461
672
 
462
673
  Returns
463
674
  -------
@@ -465,9 +676,11 @@ def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
465
676
  The index of the leaf to grow. If ``num_growable == 0``, return
466
677
  ``2 ** d``.
467
678
  num_growable : int
468
- The number of leaf nodes that can be grown.
679
+ The number of leaf nodes that can be grown, i.e., are nonterminal
680
+ and have at least twice `min_points_per_leaf` if set.
469
681
  prob_choose : float
470
- The normalized probability of choosing the selected leaf.
682
+ The (normalized) probability that this function had to choose that
683
+ specific leaf, given the arguments.
471
684
  num_prunable : int
472
685
  The number of leaf parents that could be pruned, after converting the
473
686
  selected leaf to a non-terminal node.
@@ -482,71 +695,86 @@ def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
482
695
  num_prunable = jnp.count_nonzero(is_parent)
483
696
  return leaf_to_grow, num_growable, prob_choose, num_prunable
484
697
 
485
- def growable_leaves(split_tree, affluence_tree):
698
+
699
+ def growable_leaves(
700
+ split_tree: UInt[Array, '2**(d-1)'],
701
+ affluence_tree: Bool[Array, '2**(d-1)'] | None,
702
+ ) -> Bool[Array, '2**(d-1)']:
486
703
  """
487
704
  Return a mask indicating the leaf nodes that can be proposed for growth.
488
705
 
706
+ The condition is that a leaf is not at the bottom level and has at least two
707
+ times the number of minimum points per leaf.
708
+
489
709
  Parameters
490
710
  ----------
491
- split_tree : array (2 ** (d - 1),)
711
+ split_tree
492
712
  The splitting points of the tree.
493
- affluence_tree : bool array (2 ** (d - 1),) or None
713
+ affluence_tree
494
714
  Whether a leaf has enough points to be grown.
495
715
 
496
716
  Returns
497
717
  -------
498
- is_growable : bool array (2 ** (d - 1),)
499
- The mask indicating the leaf nodes that can be proposed to grow, i.e.,
500
- that are not at the bottom level and have at least two times the number
501
- of minimum points per leaf.
718
+ The mask indicating the leaf nodes that can be proposed to grow.
502
719
  """
503
720
  is_growable = grove.is_actual_leaf(split_tree)
504
721
  if affluence_tree is not None:
505
722
  is_growable &= affluence_tree
506
723
  return is_growable
507
724
 
508
- def categorical(key, distr):
725
+
726
+ def categorical(
727
+ key: Key[Array, ''], distr: Float32[Array, 'n']
728
+ ) -> tuple[Int32[Array, ''], Float32[Array, '']]:
509
729
  """
510
730
  Return a random integer from an arbitrary distribution.
511
731
 
512
732
  Parameters
513
733
  ----------
514
- key : jax.dtypes.prng_key array
734
+ key
515
735
  A jax random key.
516
- distr : float array (n,)
736
+ distr
517
737
  An unnormalized probability distribution.
518
738
 
519
739
  Returns
520
740
  -------
521
- u : int
741
+ u : Int32[Array, '']
522
742
  A random integer in the range ``[0, n)``. If all probabilities are zero,
523
743
  return ``n``.
744
+ norm : Float32[Array, '']
745
+ The sum of `distr`.
524
746
  """
525
747
  ecdf = jnp.cumsum(distr)
526
748
  u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
527
749
  return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
528
750
 
529
- def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
751
+
752
+ def choose_variable(
753
+ key: Key[Array, ''],
754
+ var_tree: UInt[Array, '2**(d-1)'],
755
+ split_tree: UInt[Array, '2**(d-1)'],
756
+ max_split: UInt[Array, 'p'],
757
+ leaf_index: Int32[Array, ''],
758
+ ) -> Int32[Array, '']:
530
759
  """
531
760
  Choose a variable to split on for a new non-terminal node.
532
761
 
533
762
  Parameters
534
763
  ----------
535
- var_tree : int array (2 ** (d - 1),)
764
+ key
765
+ A jax random key.
766
+ var_tree
536
767
  The variable indices of the tree.
537
- split_tree : int array (2 ** (d - 1),)
768
+ split_tree
538
769
  The splitting points of the tree.
539
- max_split : int array (p,)
770
+ max_split
540
771
  The maximum split index for each variable.
541
- leaf_index : int
772
+ leaf_index
542
773
  The index of the leaf to grow.
543
- key : jax.dtypes.prng_key array
544
- A jax random key.
545
774
 
546
775
  Returns
547
776
  -------
548
- var : int
549
- The index of the variable to split on.
777
+ The index of the variable to split on.
550
778
 
551
779
  Notes
552
780
  -----
@@ -556,37 +784,49 @@ def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
556
784
  var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
557
785
  return randint_exclude(key, max_split.size, var_to_ignore)
558
786
 
559
- def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
787
+
788
+ def fully_used_variables(
789
+ var_tree: UInt[Array, '2**(d-1)'],
790
+ split_tree: UInt[Array, '2**(d-1)'],
791
+ max_split: UInt[Array, 'p'],
792
+ leaf_index: Int32[Array, ''],
793
+ ) -> UInt[Array, 'd-2']:
560
794
  """
561
795
  Return a list of variables that have an empty split range at a given node.
562
796
 
563
797
  Parameters
564
798
  ----------
565
- var_tree : int array (2 ** (d - 1),)
799
+ var_tree
566
800
  The variable indices of the tree.
567
- split_tree : int array (2 ** (d - 1),)
801
+ split_tree
568
802
  The splitting points of the tree.
569
- max_split : int array (p,)
803
+ max_split
570
804
  The maximum split index for each variable.
571
- leaf_index : int
805
+ leaf_index
572
806
  The index of the node, assumed to be valid for `var_tree`.
573
807
 
574
808
  Returns
575
809
  -------
576
- var_to_ignore : int array (d - 2,)
577
- The indices of the variables that have an empty split range. Since the
578
- number of such variables is not fixed, unused values in the array are
579
- filled with `p`. The fill values are not guaranteed to be placed in any
580
- particular order. Variables may appear more than once.
581
- """
810
+ The indices of the variables that have an empty split range.
582
811
 
812
+ Notes
813
+ -----
814
+ The number of unused variables is not known in advance. Unused values in the
815
+ array are filled with `p`. The fill values are not guaranteed to be placed
816
+ in any particular order, and variables may appear more than once.
817
+ """
583
818
  var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
584
819
  split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
585
820
  l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
586
821
  num_split = r - l
587
822
  return jnp.where(num_split == 0, var_to_ignore, max_split.size)
588
823
 
589
- def ancestor_variables(var_tree, max_split, node_index):
824
+
825
+ def ancestor_variables(
826
+ var_tree: UInt[Array, '2**(d-1)'],
827
+ max_split: UInt[Array, 'p'],
828
+ node_index: Int32[Array, ''],
829
+ ) -> UInt[Array, 'd-2']:
590
830
  """
591
831
  Return the list of variables in the ancestors of a node.
592
832
 
@@ -601,13 +841,18 @@ def ancestor_variables(var_tree, max_split, node_index):
601
841
 
602
842
  Returns
603
843
  -------
604
- ancestor_vars : int array (d - 2,)
605
- The variable indices of the ancestors of the node, from the root to
606
- the parent. Unused spots are filled with `p`.
844
+ The variable indices of the ancestors of the node.
845
+
846
+ Notes
847
+ -----
848
+ The ancestors are the nodes going from the root to the parent of the node.
849
+ The number of ancestors is not known at tracing time; unused spots in the
850
+ output array are filled with `p`.
607
851
  """
608
852
  max_num_ancestors = grove.tree_depth(var_tree) - 1
609
- ancestor_vars = jnp.zeros(max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size))
853
+ ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
610
854
  carry = ancestor_vars.size - 1, node_index, ancestor_vars
855
+
611
856
  def loop(carry, _):
612
857
  i, index, ancestor_vars = carry
613
858
  index >>= 1
@@ -615,34 +860,44 @@ def ancestor_variables(var_tree, max_split, node_index):
615
860
  var = jnp.where(index, var, max_split.size)
616
861
  ancestor_vars = ancestor_vars.at[i].set(var)
617
862
  return (i - 1, index, ancestor_vars), None
863
+
618
864
  (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size)
619
865
  return ancestor_vars
620
866
 
621
- def split_range(var_tree, split_tree, max_split, node_index, ref_var):
867
+
868
+ def split_range(
869
+ var_tree: UInt[Array, '2**(d-1)'],
870
+ split_tree: UInt[Array, '2**(d-1)'],
871
+ max_split: UInt[Array, 'p'],
872
+ node_index: Int32[Array, ''],
873
+ ref_var: Int32[Array, ''],
874
+ ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
622
875
  """
623
876
  Return the range of allowed splits for a variable at a given node.
624
877
 
625
878
  Parameters
626
879
  ----------
627
- var_tree : int array (2 ** (d - 1),)
880
+ var_tree
628
881
  The variable indices of the tree.
629
- split_tree : int array (2 ** (d - 1),)
882
+ split_tree
630
883
  The splitting points of the tree.
631
- max_split : int array (p,)
884
+ max_split
632
885
  The maximum split index for each variable.
633
- node_index : int
886
+ node_index
634
887
  The index of the node, assumed to be valid for `var_tree`.
635
- ref_var : int
888
+ ref_var
636
889
  The variable for which to measure the split range.
637
890
 
638
891
  Returns
639
892
  -------
640
- l, r : int
641
- The range of allowed splits is [l, r).
893
+ The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=0.
642
894
  """
643
895
  max_num_ancestors = grove.tree_depth(var_tree) - 1
644
- initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(jnp.int32)
896
+ initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
897
+ jnp.int32
898
+ )
645
899
  carry = 0, initial_r, node_index
900
+
646
901
  def loop(carry, _):
647
902
  l, r, index = carry
648
903
  right_child = (index & 1).astype(bool)
@@ -652,89 +907,114 @@ def split_range(var_tree, split_tree, max_split, node_index, ref_var):
652
907
  l = jnp.where(cond & right_child, jnp.maximum(l, split), l)
653
908
  r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r)
654
909
  return (l, r, index), None
910
+
655
911
  (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors)
656
912
  return l + 1, r
657
913
 
658
- def randint_exclude(key, sup, exclude):
914
+
915
+ def randint_exclude(
916
+ key: Key[Array, ''], sup: int, exclude: Integer[Array, 'n']
917
+ ) -> Int32[Array, '']:
659
918
  """
660
919
  Return a random integer in a range, excluding some values.
661
920
 
662
921
  Parameters
663
922
  ----------
664
- key : jax.dtypes.prng_key array
923
+ key
665
924
  A jax random key.
666
- sup : int
925
+ sup
667
926
  The exclusive upper bound of the range.
668
- exclude : int array (n,)
927
+ exclude
669
928
  The values to exclude from the range. Values greater than or equal to
670
929
  `sup` are ignored. Values can appear more than once.
671
930
 
672
931
  Returns
673
932
  -------
674
- u : int
675
- A random integer in the range ``[0, sup)``, and which satisfies
676
- ``u not in exclude``. If all values in the range are excluded, return
677
- `sup`.
933
+ A random integer `u` in the range ``[0, sup)`` such that ``u not in exclude``.
934
+
935
+ Notes
936
+ -----
937
+ If all values in the range are excluded, return `sup`.
678
938
  """
679
939
  exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
680
940
  num_allowed = sup - jnp.count_nonzero(exclude < sup)
681
941
  u = random.randint(key, (), 0, num_allowed)
942
+
682
943
  def loop(u, i):
683
944
  return jnp.where(i <= u, u + 1, u), None
945
+
684
946
  u, _ = lax.scan(loop, u, exclude)
685
947
  return u
686
948
 
687
- def choose_split(var_tree, split_tree, max_split, leaf_index, key):
949
+
950
+ def choose_split(
951
+ key: Key[Array, ''],
952
+ var_tree: UInt[Array, '2**(d-1)'],
953
+ split_tree: UInt[Array, '2**(d-1)'],
954
+ max_split: UInt[Array, 'p'],
955
+ leaf_index: Int32[Array, ''],
956
+ ) -> Int32[Array, '']:
688
957
  """
689
958
  Choose a split point for a new non-terminal node.
690
959
 
691
960
  Parameters
692
961
  ----------
693
- var_tree : int array (2 ** (d - 1),)
694
- The variable indices of the tree.
695
- split_tree : int array (2 ** (d - 1),)
962
+ key
963
+ A jax random key.
964
+ var_tree
965
+ The splitting axes of the tree.
966
+ split_tree
696
967
  The splitting points of the tree.
697
- max_split : int array (p,)
968
+ max_split
698
969
  The maximum split index for each variable.
699
- leaf_index : int
970
+ leaf_index
700
971
  The index of the leaf to grow. It is assumed that `var_tree` already
701
972
  contains the target variable at this index.
702
- key : jax.dtypes.prng_key array
703
- A jax random key.
704
973
 
705
974
  Returns
706
975
  -------
707
- split : int
708
- The split point.
976
+ The cutpoint. If ``var_tree[leaf_index]`` is out of bounds, return 0.
709
977
  """
710
978
  var = var_tree[leaf_index]
711
979
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
712
980
  return random.randint(key, (), l, r)
713
981
 
714
- def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
982
+ # TODO what happens if leaf_index is out of bounds? And is the value used
983
+ # in that case?
984
+
985
+
986
+ def compute_partial_ratio(
987
+ prob_choose: Float32[Array, ''],
988
+ num_prunable: Int32[Array, ''],
989
+ p_nonterminal: Float32[Array, 'd'],
990
+ leaf_to_grow: Int32[Array, ''],
991
+ ) -> Float32[Array, '']:
715
992
  """
716
993
  Compute the product of the transition and prior ratios of a grow move.
717
994
 
718
995
  Parameters
719
996
  ----------
720
- num_growable : int
721
- The number of leaf nodes that can be grown.
722
- num_prunable : int
997
+ prob_choose
998
+ The probability that the leaf had to be chosen amongst the growable
999
+ leaves.
1000
+ num_prunable
723
1001
  The number of leaf parents that could be pruned, after converting the
724
1002
  leaf to be grown to a non-terminal node.
725
- p_nonterminal : array (d,)
1003
+ p_nonterminal
726
1004
  The probability of a nonterminal node at each depth.
727
- leaf_to_grow : int
1005
+ leaf_to_grow
728
1006
  The index of the leaf to grow.
729
1007
 
730
1008
  Returns
731
1009
  -------
732
- ratio : float
733
- The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
734
- times the prior ratio P(new tree) / P(old tree), but the transition
735
- ratio is missing the factor P(propose prune) in the numerator.
736
- """
1010
+ The partial transition ratio times the prior ratio.
737
1011
 
1012
+ Notes
1013
+ -----
1014
+ The transition ratio is P(new tree => old tree) / P(old tree => new tree).
1015
+ The "partial" transition ratio returned is missing the factor P(propose
1016
+ prune) in the numerator. The prior ratio is P(new tree) / P(old tree).
1017
+ """
738
1018
  # the two ratios also contain factors num_available_split *
739
1019
  # num_available_var, but they cancel out
740
1020
 
@@ -742,9 +1022,9 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
742
1022
  # computed in the acceptance phase
743
1023
 
744
1024
  prune_allowed = leaf_to_grow != 1
745
- # prune allowed <---> the initial tree is not a root
746
- # leaf to grow is root --> the tree can only be a root
747
- # tree is a root --> the only leaf I can grow is root
1025
+ # prune allowed <---> the initial tree is not a root
1026
+ # leaf to grow is root --> the tree can only be a root
1027
+ # tree is a root --> the only leaf I can grow is root
748
1028
 
749
1029
  p_grow = jnp.where(prune_allowed, 0.5, 1)
750
1030
 
@@ -757,76 +1037,104 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
757
1037
 
758
1038
  return tree_ratio / inv_trans_ratio
759
1039
 
760
- def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
1040
+
1041
+ class PruneMoves(Module):
1042
+ """
1043
+ Represent a proposed prune move for each tree.
1044
+
1045
+ Parameters
1046
+ ----------
1047
+ allowed
1048
+ Whether the move is possible.
1049
+ node
1050
+ The index of the node to prune. ``2 ** d`` if no node can be pruned.
1051
+ partial_ratio
1052
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
1053
+ the likelihood ratio and the probability of proposing the prune
1054
+ move. This ratio is inverted, and is meant to be inverted back in
1055
+ `accept_move_and_sample_leaves`.
1056
+ """
1057
+
1058
+ allowed: Bool[Array, 'num_trees']
1059
+ node: UInt[Array, 'num_trees']
1060
+ partial_ratio: Float32[Array, 'num_trees']
1061
+
1062
+
1063
+ @partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
1064
+ def propose_prune_moves(
1065
+ key: Key[Array, ''],
1066
+ split_tree: UInt[Array, '2**(d-1)'],
1067
+ affluence_tree: Bool[Array, '2**(d-1)'] | None,
1068
+ p_nonterminal: Float32[Array, 'd'],
1069
+ p_propose_grow: Float32[Array, '2**(d-1)'],
1070
+ ) -> PruneMoves:
761
1071
  """
762
1072
  Tree structure prune move proposal of BART MCMC.
763
1073
 
764
1074
  Parameters
765
1075
  ----------
766
- var_tree : int array (2 ** (d - 1),)
767
- The variable indices of the tree.
768
- split_tree : int array (2 ** (d - 1),)
1076
+ key
1077
+ A jax random key.
1078
+ split_tree
769
1079
  The splitting points of the tree.
770
- affluence_tree : bool array (2 ** (d - 1),) or None
1080
+ affluence_tree
771
1081
  Whether a leaf has enough points to be grown.
772
- max_split : int array (p,)
773
- The maximum split index for each variable.
774
- p_nonterminal : float array (d,)
1082
+ p_nonterminal
775
1083
  The probability of a nonterminal node at each depth.
776
- p_propose_grow : float array (2 ** (d - 1),)
1084
+ p_propose_grow
777
1085
  The unnormalized probability of choosing a leaf to grow.
778
- key : jax.dtypes.prng_key array
779
- A jax random key.
780
1086
 
781
1087
  Returns
782
1088
  -------
783
- prune_move : dict
784
- A dictionary with fields:
785
-
786
- 'allowed' : bool
787
- Whether the move is possible.
788
- 'node' : int
789
- The index of the node to prune. ``2 ** d`` if no node can be pruned.
790
- 'partial_ratio' : float
791
- A factor of the Metropolis-Hastings ratio of the move. It lacks
792
- the likelihood ratio and the probability of proposing the prune
793
- move. This ratio is inverted.
1089
+ An object representing the proposed moves.
794
1090
  """
795
- node_to_prune, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
796
- allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
1091
+ node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
1092
+ key, split_tree, affluence_tree, p_propose_grow
1093
+ )
1094
+ allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
797
1095
 
798
- ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune)
1096
+ ratio = compute_partial_ratio(
1097
+ prob_choose, num_prunable, p_nonterminal, node_to_prune
1098
+ )
799
1099
 
800
- return dict(
1100
+ return PruneMoves(
801
1101
  allowed=allowed,
802
1102
  node=node_to_prune,
803
- partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
1103
+ partial_ratio=ratio,
804
1104
  )
805
1105
 
806
- def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
1106
+
1107
+ def choose_leaf_parent(
1108
+ key: Key[Array, ''],
1109
+ split_tree: UInt[Array, '2**(d-1)'],
1110
+ affluence_tree: Bool[Array, '2**(d-1)'] | None,
1111
+ p_propose_grow: Float32[Array, '2**(d-1)'],
1112
+ ) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, '']]:
807
1113
  """
808
1114
  Pick a non-terminal node with leaf children to prune in a tree.
809
1115
 
810
1116
  Parameters
811
1117
  ----------
812
- split_tree : array (2 ** (d - 1),)
1118
+ key
1119
+ A jax random key.
1120
+ split_tree
813
1121
  The splitting points of the tree.
814
- affluence_tree : bool array (2 ** (d - 1),) or None
1122
+ affluence_tree
815
1123
  Whether a leaf has enough points to be grown.
816
- p_propose_grow : array (2 ** (d - 1),)
1124
+ p_propose_grow
817
1125
  The unnormalized probability of choosing a leaf to grow.
818
- key : jax.dtypes.prng_key array
819
- A jax random key.
820
1126
 
821
1127
  Returns
822
1128
  -------
823
- node_to_prune : int
1129
+ node_to_prune : Int32[Array, '']
824
1130
  The index of the node to prune. If ``num_prunable == 0``, return
825
1131
  ``2 ** d``.
826
- num_prunable : int
1132
+ num_prunable : Int32[Array, '']
827
1133
  The number of leaf parents that could be pruned.
828
- prob_choose : float
829
- The normalized probability of choosing the node to prune for growth.
1134
+ prob_choose : Float32[Array, '']
1135
+ The (normalized) probability that `choose_leaf` would chose
1136
+ `node_to_prune` as leaf to grow, if passed the tree where
1137
+ `node_to_prune` had been pruned.
830
1138
  """
831
1139
  is_prunable = grove.is_leaves_parent(split_tree)
832
1140
  num_prunable = jnp.count_nonzero(is_prunable)
@@ -834,517 +1142,910 @@ def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
834
1142
  node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
835
1143
 
836
1144
  split_tree = split_tree.at[node_to_prune].set(0)
837
- affluence_tree = (
838
- None if affluence_tree is None else
839
- affluence_tree.at[node_to_prune].set(True)
840
- )
1145
+ if affluence_tree is not None:
1146
+ affluence_tree = affluence_tree.at[node_to_prune].set(True)
841
1147
  is_growable_leaf = growable_leaves(split_tree, affluence_tree)
842
1148
  prob_choose = p_propose_grow[node_to_prune]
843
1149
  prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
844
1150
 
845
1151
  return node_to_prune, num_prunable, prob_choose
846
1152
 
847
- def randint_masked(key, mask):
1153
+
1154
+ def randint_masked(key: Key[Array, ''], mask: Bool[Array, 'n']) -> Int32[Array, '']:
848
1155
  """
849
1156
  Return a random integer in a range, including only some values.
850
1157
 
851
1158
  Parameters
852
1159
  ----------
853
- key : jax.dtypes.prng_key array
1160
+ key
854
1161
  A jax random key.
855
- mask : bool array (n,)
1162
+ mask
856
1163
  The mask indicating the allowed values.
857
1164
 
858
1165
  Returns
859
1166
  -------
860
- u : int
861
- A random integer in the range ``[0, n)``, and which satisfies
862
- ``mask[u] == True``. If all values in the mask are `False`, return `n`.
1167
+ A random integer in the range ``[0, n)`` such that ``mask[u] == True``.
1168
+
1169
+ Notes
1170
+ -----
1171
+ If all values in the mask are `False`, return `n`.
863
1172
  """
864
1173
  ecdf = jnp.cumsum(mask)
865
1174
  u = random.randint(key, (), 0, ecdf[-1])
866
1175
  return jnp.searchsorted(ecdf, u, 'right')
867
1176
 
868
- def accept_moves_and_sample_leaves(bart, moves, key):
1177
+
1178
+ def accept_moves_and_sample_leaves(
1179
+ key: Key[Array, ''], bart: State, moves: Moves
1180
+ ) -> State:
869
1181
  """
870
1182
  Accept or reject the proposed moves and sample the new leaf values.
871
1183
 
872
1184
  Parameters
873
1185
  ----------
874
- bart : dict
875
- A BART mcmc state.
876
- moves : dict
877
- The proposed moves, see `sample_moves`.
878
- key : jax.dtypes.prng_key array
1186
+ key
879
1187
  A jax random key.
1188
+ bart
1189
+ A valid BART mcmc state.
1190
+ moves
1191
+ The proposed moves, see `propose_moves`.
880
1192
 
881
1193
  Returns
882
1194
  -------
883
- bart : dict
884
- The new BART mcmc state.
1195
+ A new (valid) BART mcmc state.
885
1196
  """
886
- bart, moves, count_trees, move_counts, prelkv, prelk, prelf = accept_moves_parallel_stage(bart, moves, key)
887
- bart, moves = accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf)
1197
+ pso = accept_moves_parallel_stage(key, bart, moves)
1198
+ bart, moves = accept_moves_sequential_stage(pso)
888
1199
  return accept_moves_final_stage(bart, moves)
889
1200
 
890
- def accept_moves_parallel_stage(bart, moves, key):
1201
+
1202
+ class Counts(Module):
1203
+ """
1204
+ Number of datapoints in the nodes involved in proposed moves for each tree.
1205
+
1206
+ Parameters
1207
+ ----------
1208
+ left
1209
+ Number of datapoints in the left child.
1210
+ right
1211
+ Number of datapoints in the right child.
1212
+ total
1213
+ Number of datapoints in the parent (``= left + right``).
1214
+ """
1215
+
1216
+ left: UInt[Array, 'num_trees']
1217
+ right: UInt[Array, 'num_trees']
1218
+ total: UInt[Array, 'num_trees']
1219
+
1220
+
1221
+ class Precs(Module):
1222
+ """
1223
+ Likelihood precision scale in the nodes involved in proposed moves for each tree.
1224
+
1225
+ The "likelihood precision scale" of a tree node is the sum of the inverse
1226
+ squared error scales of the datapoints selected by the node.
1227
+
1228
+ Parameters
1229
+ ----------
1230
+ left
1231
+ Likelihood precision scale in the left child.
1232
+ right
1233
+ Likelihood precision scale in the right child.
1234
+ total
1235
+ Likelihood precision scale in the parent (``= left + right``).
1236
+ """
1237
+
1238
+ left: Float32[Array, 'num_trees']
1239
+ right: Float32[Array, 'num_trees']
1240
+ total: Float32[Array, 'num_trees']
1241
+
1242
+
1243
+ class PreLkV(Module):
1244
+ """
1245
+ Non-sequential terms of the likelihood ratio for each tree.
1246
+
1247
+ These terms can be computed in parallel across trees.
1248
+
1249
+ Parameters
1250
+ ----------
1251
+ sigma2_left
1252
+ The noise variance in the left child of the leaves grown or pruned by
1253
+ the moves.
1254
+ sigma2_right
1255
+ The noise variance in the right child of the leaves grown or pruned by
1256
+ the moves.
1257
+ sigma2_total
1258
+ The noise variance in the total of the leaves grown or pruned by the
1259
+ moves.
1260
+ sqrt_term
1261
+ The **logarithm** of the square root term of the likelihood ratio.
1262
+ """
1263
+
1264
+ sigma2_left: Float32[Array, 'num_trees']
1265
+ sigma2_right: Float32[Array, 'num_trees']
1266
+ sigma2_total: Float32[Array, 'num_trees']
1267
+ sqrt_term: Float32[Array, 'num_trees']
1268
+
1269
+
1270
+ class PreLk(Module):
1271
+ """
1272
+ Non-sequential terms of the likelihood ratio shared by all trees.
1273
+
1274
+ Parameters
1275
+ ----------
1276
+ exp_factor
1277
+ The factor to multiply the likelihood ratio by, shared by all trees.
1278
+ """
1279
+
1280
+ exp_factor: Float32[Array, '']
1281
+
1282
+
1283
+ class PreLf(Module):
1284
+ """
1285
+ Pre-computed terms used to sample leaves from their posterior.
1286
+
1287
+ These terms can be computed in parallel across trees.
1288
+
1289
+ Parameters
1290
+ ----------
1291
+ mean_factor
1292
+ The factor to be multiplied by the sum of the scaled residuals to
1293
+ obtain the posterior mean.
1294
+ centered_leaves
1295
+ The mean-zero normal values to be added to the posterior mean to
1296
+ obtain the posterior leaf samples.
1297
+ """
1298
+
1299
+ mean_factor: Float32[Array, 'num_trees 2**d']
1300
+ centered_leaves: Float32[Array, 'num_trees 2**d']
1301
+
1302
+
1303
+ class ParallelStageOut(Module):
1304
+ """
1305
+ The output of `accept_moves_parallel_stage`.
1306
+
1307
+ Parameters
1308
+ ----------
1309
+ bart
1310
+ A partially updated BART mcmc state.
1311
+ moves
1312
+ The proposed moves, with `partial_ratio` set to `None` and
1313
+ `log_trans_prior_ratio` set to its final value.
1314
+ prec_trees
1315
+ The likelihood precision scale in each potential or actual leaf node. If
1316
+ there is no precision scale, this is the number of points in each leaf.
1317
+ move_counts
1318
+ The counts of the number of points in the the nodes modified by the
1319
+ moves. If `bart.min_points_per_leaf` is not set and
1320
+ `bart.prec_scale` is set, they are not computed.
1321
+ move_precs
1322
+ The likelihood precision scale in each node modified by the moves. If
1323
+ `bart.prec_scale` is not set, this is set to `move_counts`.
1324
+ prelkv
1325
+ prelk
1326
+ prelf
1327
+ Objects with pre-computed terms of the likelihood ratios and leaf
1328
+ samples.
1329
+ """
1330
+
1331
+ bart: State
1332
+ moves: Moves
1333
+ prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
1334
+ move_counts: Counts | None
1335
+ move_precs: Precs | Counts
1336
+ prelkv: PreLkV
1337
+ prelk: PreLk
1338
+ prelf: PreLf
1339
+
1340
+
1341
+ def accept_moves_parallel_stage(
1342
+ key: Key[Array, ''], bart: State, moves: Moves
1343
+ ) -> ParallelStageOut:
891
1344
  """
892
1345
  Pre-computes quantities used to accept moves, in parallel across trees.
893
1346
 
894
1347
  Parameters
895
1348
  ----------
1349
+ key : jax.dtypes.prng_key array
1350
+ A jax random key.
896
1351
  bart : dict
897
1352
  A BART mcmc state.
898
1353
  moves : dict
899
- The proposed moves, see `sample_moves`.
900
- key : jax.dtypes.prng_key array
901
- A jax random key.
1354
+ The proposed moves, see `propose_moves`.
902
1355
 
903
1356
  Returns
904
1357
  -------
905
- bart : dict
906
- A partially updated BART mcmc state.
907
- moves : dict
908
- The proposed moves, with the field 'partial_ratio' replaced
909
- by 'log_trans_prior_ratio'.
910
- count_trees : array (num_trees, 2 ** d)
911
- The number of points in each potential or actual leaf node.
912
- move_counts : dict
913
- The counts of the number of points in the the nodes modified by the
914
- moves.
915
- prelkv, prelk, prelf : dict
916
- Dictionary with pre-computed terms of the likelihood ratios and leaf
917
- samples.
1358
+ An object with all that could be done in parallel.
918
1359
  """
919
- bart = bart.copy()
920
-
921
1360
  # where the move is grow, modify the state like the move was accepted
922
- bart['var_trees'] = moves['var_trees']
923
- bart['leaf_indices'] = apply_grow_to_indices(moves, bart['leaf_indices'], bart['X'])
924
- bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], moves)
1361
+ bart = replace(
1362
+ bart,
1363
+ forest=replace(
1364
+ bart.forest,
1365
+ var_trees=moves.var_trees,
1366
+ leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
1367
+ leaf_trees=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_trees, moves),
1368
+ ),
1369
+ )
925
1370
 
926
1371
  # count number of datapoints per leaf
927
- count_trees, move_counts = compute_count_trees(bart['leaf_indices'], moves, bart['opt']['count_batch_size'])
928
- if bart['opt']['require_min_points']:
929
- count_half_trees = count_trees[:, :bart['var_trees'].shape[1]]
930
- bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
1372
+ if bart.forest.min_points_per_leaf is not None or bart.prec_scale is None:
1373
+ count_trees, move_counts = compute_count_trees(
1374
+ bart.forest.leaf_indices, moves, bart.forest.count_batch_size
1375
+ )
1376
+ else:
1377
+ # move_counts is passed later to a function, but then is unused under
1378
+ # this condition
1379
+ move_counts = None
1380
+
1381
+ # Check if some nodes can't surely be grown because they don't have enough
1382
+ # datapoints. This check is not actually used now, it will be used at the
1383
+ # beginning of the next step to propose moves.
1384
+ if bart.forest.min_points_per_leaf is not None:
1385
+ count_half_trees = count_trees[:, : bart.forest.var_trees.shape[1]]
1386
+ bart = replace(
1387
+ bart,
1388
+ forest=replace(
1389
+ bart.forest,
1390
+ affluence_trees=count_half_trees >= 2 * bart.forest.min_points_per_leaf,
1391
+ ),
1392
+ )
1393
+
1394
+ # count number of datapoints per leaf, weighted by error precision scale
1395
+ if bart.prec_scale is None:
1396
+ prec_trees = count_trees
1397
+ move_precs = move_counts
1398
+ else:
1399
+ prec_trees, move_precs = compute_prec_trees(
1400
+ bart.prec_scale,
1401
+ bart.forest.leaf_indices,
1402
+ moves,
1403
+ bart.forest.count_batch_size,
1404
+ )
931
1405
 
932
1406
  # compute some missing information about moves
933
- moves = complete_ratio(moves, move_counts, bart['min_points_per_leaf'])
934
- bart['grow_prop_count'] = jnp.sum(moves['grow'])
935
- bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~moves['grow'])
1407
+ moves = complete_ratio(moves, move_counts, bart.forest.min_points_per_leaf)
1408
+ bart = replace(
1409
+ bart,
1410
+ forest=replace(
1411
+ bart.forest,
1412
+ grow_prop_count=jnp.sum(moves.grow),
1413
+ prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
1414
+ ),
1415
+ )
936
1416
 
937
- prelkv, prelk = precompute_likelihood_terms(count_trees, bart['sigma2'], move_counts)
938
- prelf = precompute_leaf_terms(count_trees, bart['sigma2'], key)
1417
+ prelkv, prelk = precompute_likelihood_terms(
1418
+ bart.sigma2, bart.forest.sigma_mu2, move_precs
1419
+ )
1420
+ prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)
1421
+
1422
+ return ParallelStageOut(
1423
+ bart=bart,
1424
+ moves=moves,
1425
+ prec_trees=prec_trees,
1426
+ move_counts=move_counts,
1427
+ move_precs=move_precs,
1428
+ prelkv=prelkv,
1429
+ prelk=prelk,
1430
+ prelf=prelf,
1431
+ )
939
1432
 
940
- return bart, moves, count_trees, move_counts, prelkv, prelk, prelf
941
1433
 
942
- @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
943
- def apply_grow_to_indices(moves, leaf_indices, X):
1434
+ @partial(vmap_nodoc, in_axes=(0, 0, None))
1435
+ def apply_grow_to_indices(
1436
+ moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
1437
+ ) -> UInt[Array, 'num_trees n']:
944
1438
  """
945
1439
  Update the leaf indices to apply a grow move.
946
1440
 
947
1441
  Parameters
948
1442
  ----------
949
- moves : dict
950
- The proposed moves, see `sample_moves`.
951
- leaf_indices : array (num_trees, n)
1443
+ moves
1444
+ The proposed moves, see `propose_moves`.
1445
+ leaf_indices
952
1446
  The index of the leaf each datapoint falls into.
953
- X : array (p, n)
1447
+ X
954
1448
  The predictors matrix.
955
1449
 
956
1450
  Returns
957
1451
  -------
958
- grow_leaf_indices : array (num_trees, n)
959
- The updated leaf indices.
1452
+ The updated leaf indices.
960
1453
  """
961
- left_child = moves['node'].astype(leaf_indices.dtype) << 1
962
- go_right = X[moves['grow_var'], :] >= moves['grow_split']
963
- tree_size = jnp.array(2 * moves['var_trees'].size)
964
- node_to_update = jnp.where(moves['grow'], moves['node'], tree_size)
1454
+ left_child = moves.node.astype(leaf_indices.dtype) << 1
1455
+ go_right = X[moves.grow_var, :] >= moves.grow_split
1456
+ tree_size = jnp.array(2 * moves.var_trees.size)
1457
+ node_to_update = jnp.where(moves.grow, moves.node, tree_size)
965
1458
  return jnp.where(
966
1459
  leaf_indices == node_to_update,
967
1460
  left_child + go_right,
968
1461
  leaf_indices,
969
1462
  )
970
1463
 
971
- def compute_count_trees(leaf_indices, moves, batch_size):
1464
+
1465
+ def compute_count_trees(
1466
+ leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
1467
+ ) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
972
1468
  """
973
1469
  Count the number of datapoints in each leaf.
974
1470
 
975
1471
  Parameters
976
1472
  ----------
977
- grow_leaf_indices : int array (num_trees, n)
978
- The index of the leaf each datapoint falls into, if the grow move is
979
- accepted.
980
- moves : dict
981
- The proposed moves, see `sample_moves`.
982
- batch_size : int or None
1473
+ leaf_indices
1474
+ The index of the leaf each datapoint falls into, with the deeper version
1475
+ of the tree (post-GROW, pre-PRUNE).
1476
+ moves
1477
+ The proposed moves, see `propose_moves`.
1478
+ batch_size
983
1479
  The data batch size to use for the summation.
984
1480
 
985
1481
  Returns
986
1482
  -------
987
- count_trees : int array (num_trees, 2 ** (d - 1))
1483
+ count_trees : Int32[Array, 'num_trees 2**d']
988
1484
  The number of points in each potential or actual leaf node.
989
- counts : dict
990
- The counts of the number of points in the the nodes modified by the
991
- moves, organized as two dictionaries 'grow' and 'prune', with subfields
992
- 'left', 'right', and 'total'.
1485
+ counts : Counts
1486
+ The counts of the number of points in the leaves grown or pruned by the
1487
+ moves.
993
1488
  """
994
-
995
- ntree, tree_size = moves['var_trees'].shape
1489
+ num_trees, tree_size = moves.var_trees.shape
996
1490
  tree_size *= 2
997
- tree_indices = jnp.arange(ntree)
1491
+ tree_indices = jnp.arange(num_trees)
998
1492
 
999
1493
  count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
1000
1494
 
1001
1495
  # count datapoints in nodes modified by move
1002
- counts = dict()
1003
- counts['left'] = count_trees[tree_indices, moves['left']]
1004
- counts['right'] = count_trees[tree_indices, moves['right']]
1005
- counts['total'] = counts['left'] + counts['right']
1496
+ left = count_trees[tree_indices, moves.left]
1497
+ right = count_trees[tree_indices, moves.right]
1498
+ counts = Counts(left=left, right=right, total=left + right)
1006
1499
 
1007
1500
  # write count into non-leaf node
1008
- count_trees = count_trees.at[tree_indices, moves['node']].set(counts['total'])
1501
+ count_trees = count_trees.at[tree_indices, moves.node].set(counts.total)
1009
1502
 
1010
1503
  return count_trees, counts
1011
1504
 
1012
- def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
1505
+
1506
+ def count_datapoints_per_leaf(
1507
+ leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
1508
+ ) -> Int32[Array, 'num_trees 2**(d-1)']:
1013
1509
  """
1014
1510
  Count the number of datapoints in each leaf.
1015
1511
 
1016
1512
  Parameters
1017
1513
  ----------
1018
- leaf_indices : int array (num_trees, n)
1514
+ leaf_indices
1019
1515
  The index of the leaf each datapoint falls into.
1020
- tree_size : int
1516
+ tree_size
1021
1517
  The size of the leaf tree array (2 ** d).
1022
- batch_size : int or None
1518
+ batch_size
1023
1519
  The data batch size to use for the summation.
1024
1520
 
1025
1521
  Returns
1026
1522
  -------
1027
- count_trees : int array (num_trees, 2 ** (d - 1))
1028
- The number of points in each leaf node.
1523
+ The number of points in each leaf node.
1029
1524
  """
1030
1525
  if batch_size is None:
1031
1526
  return _count_scan(leaf_indices, tree_size)
1032
1527
  else:
1033
1528
  return _count_vec(leaf_indices, tree_size, batch_size)
1034
1529
 
1035
- def _count_scan(leaf_indices, tree_size):
1530
+
1531
+ def _count_scan(
1532
+ leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
1533
+ ) -> Int32[Array, 'num_trees {tree_size}']:
1036
1534
  def loop(_, leaf_indices):
1037
1535
  return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
1536
+
1038
1537
  _, count_trees = lax.scan(loop, None, leaf_indices)
1039
1538
  return count_trees
1040
1539
 
1041
- def _aggregate_scatter(values, indices, size, dtype):
1042
- return (jnp
1043
- .zeros(size, dtype)
1044
- .at[indices]
1045
- .add(values)
1046
- )
1047
1540
 
1048
- def _count_vec(leaf_indices, tree_size, batch_size):
1049
- return _aggregate_batched_alltrees(1, leaf_indices, tree_size, jnp.uint32, batch_size)
1050
- # uint16 is super-slow on gpu, don't use it even if n < 2^16
1541
+ def _aggregate_scatter(
1542
+ values: Shaped[Array, '*'],
1543
+ indices: Integer[Array, '*'],
1544
+ size: int,
1545
+ dtype: jnp.dtype,
1546
+ ) -> Shaped[Array, '{size}']:
1547
+ return jnp.zeros(size, dtype).at[indices].add(values)
1051
1548
 
1052
- def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
1053
- ntree, n = indices.shape
1054
- tree_indices = jnp.arange(ntree)
1549
+
1550
+ def _count_vec(
1551
+ leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
1552
+ ) -> Int32[Array, 'num_trees 2**(d-1)']:
1553
+ return _aggregate_batched_alltrees(
1554
+ 1, leaf_indices, tree_size, jnp.uint32, batch_size
1555
+ )
1556
+ # uint16 is super-slow on gpu, don't use it even if n < 2^16
1557
+
1558
+
1559
+ def _aggregate_batched_alltrees(
1560
+ values: Shaped[Array, '*'],
1561
+ indices: UInt[Array, 'num_trees n'],
1562
+ size: int,
1563
+ dtype: jnp.dtype,
1564
+ batch_size: int,
1565
+ ) -> Shaped[Array, 'num_trees {size}']:
1566
+ num_trees, n = indices.shape
1567
+ tree_indices = jnp.arange(num_trees)
1055
1568
  nbatches = n // batch_size + bool(n % batch_size)
1056
1569
  batch_indices = jnp.arange(n) % nbatches
1057
- return (jnp
1058
- .zeros((ntree, size, nbatches), dtype)
1570
+ return (
1571
+ jnp.zeros((num_trees, size, nbatches), dtype)
1059
1572
  .at[tree_indices[:, None], indices, batch_indices]
1060
1573
  .add(values)
1061
1574
  .sum(axis=2)
1062
1575
  )
1063
1576
 
1064
- def complete_ratio(moves, move_counts, min_points_per_leaf):
1577
+
1578
+ def compute_prec_trees(
1579
+ prec_scale: Float32[Array, 'n'],
1580
+ leaf_indices: UInt[Array, 'num_trees n'],
1581
+ moves: Moves,
1582
+ batch_size: int | None,
1583
+ ) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
1584
+ """
1585
+ Compute the likelihood precision scale in each leaf.
1586
+
1587
+ Parameters
1588
+ ----------
1589
+ prec_scale
1590
+ The scale of the precision of the error on each datapoint.
1591
+ leaf_indices
1592
+ The index of the leaf each datapoint falls into, with the deeper version
1593
+ of the tree (post-GROW, pre-PRUNE).
1594
+ moves
1595
+ The proposed moves, see `propose_moves`.
1596
+ batch_size
1597
+ The data batch size to use for the summation.
1598
+
1599
+ Returns
1600
+ -------
1601
+ prec_trees : Float32[Array, 'num_trees 2**d']
1602
+ The likelihood precision scale in each potential or actual leaf node.
1603
+ precs : Precs
1604
+ The likelihood precision scale in the nodes involved in the moves.
1605
+ """
1606
+ num_trees, tree_size = moves.var_trees.shape
1607
+ tree_size *= 2
1608
+ tree_indices = jnp.arange(num_trees)
1609
+
1610
+ prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)
1611
+
1612
+ # prec datapoints in nodes modified by move
1613
+ left = prec_trees[tree_indices, moves.left]
1614
+ right = prec_trees[tree_indices, moves.right]
1615
+ precs = Precs(left=left, right=right, total=left + right)
1616
+
1617
+ # write prec into non-leaf node
1618
+ prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total)
1619
+
1620
+ return prec_trees, precs
1621
+
1622
+
1623
+ def prec_per_leaf(
1624
+ prec_scale: Float32[Array, 'n'],
1625
+ leaf_indices: UInt[Array, 'num_trees n'],
1626
+ tree_size: int,
1627
+ batch_size: int | None,
1628
+ ) -> Float32[Array, 'num_trees {tree_size}']:
1629
+ """
1630
+ Compute the likelihood precision scale in each leaf.
1631
+
1632
+ Parameters
1633
+ ----------
1634
+ prec_scale
1635
+ The scale of the precision of the error on each datapoint.
1636
+ leaf_indices
1637
+ The index of the leaf each datapoint falls into.
1638
+ tree_size
1639
+ The size of the leaf tree array (2 ** d).
1640
+ batch_size
1641
+ The data batch size to use for the summation.
1642
+
1643
+ Returns
1644
+ -------
1645
+ The likelihood precision scale in each leaf node.
1646
+ """
1647
+ if batch_size is None:
1648
+ return _prec_scan(prec_scale, leaf_indices, tree_size)
1649
+ else:
1650
+ return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
1651
+
1652
+
1653
+ def _prec_scan(
1654
+ prec_scale: Float32[Array, 'n'],
1655
+ leaf_indices: UInt[Array, 'num_trees n'],
1656
+ tree_size: int,
1657
+ ) -> Float32[Array, 'num_trees {tree_size}']:
1658
+ def loop(_, leaf_indices):
1659
+ return None, _aggregate_scatter(
1660
+ prec_scale, leaf_indices, tree_size, jnp.float32
1661
+ )
1662
+
1663
+ _, prec_trees = lax.scan(loop, None, leaf_indices)
1664
+ return prec_trees
1665
+
1666
+
1667
+ def _prec_vec(
1668
+ prec_scale: Float32[Array, 'n'],
1669
+ leaf_indices: UInt[Array, 'num_trees n'],
1670
+ tree_size: int,
1671
+ batch_size: int,
1672
+ ) -> Float32[Array, 'num_trees {tree_size}']:
1673
+ return _aggregate_batched_alltrees(
1674
+ prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
1675
+ )
1676
+
1677
+
1678
+ def complete_ratio(
1679
+ moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
1680
+ ) -> Moves:
1065
1681
  """
1066
1682
  Complete non-likelihood MH ratio calculation.
1067
1683
 
1068
- This functions adds the probability of choosing the prune move.
1684
+ This function adds the probability of choosing the prune move.
1069
1685
 
1070
1686
  Parameters
1071
1687
  ----------
1072
- moves : dict
1073
- The proposed moves, see `sample_moves`.
1074
- move_counts : dict
1688
+ moves
1689
+ The proposed moves, see `propose_moves`.
1690
+ move_counts
1075
1691
  The counts of the number of points in the the nodes modified by the
1076
1692
  moves.
1077
- min_points_per_leaf : int or None
1693
+ min_points_per_leaf
1078
1694
  The minimum number of data points in a leaf node.
1079
1695
 
1080
1696
  Returns
1081
1697
  -------
1082
- moves : dict
1083
- The updated moves, with the field 'partial_ratio' replaced by
1084
- 'log_trans_prior_ratio'.
1698
+ The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
1085
1699
  """
1086
- moves = moves.copy()
1087
- p_prune = compute_p_prune(moves, move_counts['left'], move_counts['right'], min_points_per_leaf)
1088
- moves['log_trans_prior_ratio'] = jnp.log(moves.pop('partial_ratio') * p_prune)
1089
- return moves
1700
+ p_prune = compute_p_prune(moves, move_counts, min_points_per_leaf)
1701
+ return replace(
1702
+ moves,
1703
+ log_trans_prior_ratio=jnp.log(moves.partial_ratio * p_prune),
1704
+ partial_ratio=None,
1705
+ )
1706
+
1090
1707
 
1091
- def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
1708
+ def compute_p_prune(
1709
+ moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
1710
+ ) -> Float32[Array, 'num_trees']:
1092
1711
  """
1093
- Compute the probability of proposing a prune move.
1712
+ Compute the probability of proposing a prune move for each tree.
1094
1713
 
1095
1714
  Parameters
1096
1715
  ----------
1097
- moves : dict
1098
- The proposed moves, see `sample_moves`.
1099
- left_count, right_count : int
1716
+ moves
1717
+ The proposed moves, see `propose_moves`.
1718
+ move_counts
1100
1719
  The number of datapoints in the proposed children of the leaf to grow.
1101
- min_points_per_leaf : int or None
1720
+ Not used if `min_points_per_leaf` is `None`.
1721
+ min_points_per_leaf
1102
1722
  The minimum number of data points in a leaf node.
1103
1723
 
1104
1724
  Returns
1105
1725
  -------
1106
- p_prune : float
1107
- The probability of proposing a prune move. If grow: after accepting the
1108
- grow move, if prune: right away.
1109
- """
1726
+ The probability of proposing a prune move.
1110
1727
 
1728
+ Notes
1729
+ -----
1730
+ This probability is computed for going from the state with the deeper tree
1731
+ to the one with the shallower one. This means, if grow: after accepting the
1732
+ grow move, if prune: right away.
1733
+ """
1111
1734
  # calculation in case the move is grow
1112
- other_growable_leaves = moves['num_growable'] >= 2
1113
- new_leaves_growable = moves['node'] < moves['var_trees'].shape[1] // 2
1735
+ other_growable_leaves = moves.num_growable >= 2
1736
+ new_leaves_growable = moves.node < moves.var_trees.shape[1] // 2
1114
1737
  if min_points_per_leaf is not None:
1115
- any_above_threshold = left_count >= 2 * min_points_per_leaf
1116
- any_above_threshold |= right_count >= 2 * min_points_per_leaf
1738
+ assert move_counts is not None
1739
+ any_above_threshold = move_counts.left >= 2 * min_points_per_leaf
1740
+ any_above_threshold |= move_counts.right >= 2 * min_points_per_leaf
1117
1741
  new_leaves_growable &= any_above_threshold
1118
1742
  grow_again_allowed = other_growable_leaves | new_leaves_growable
1119
1743
  grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
1120
1744
 
1121
1745
  # calculation in case the move is prune
1122
- prune_p_prune = jnp.where(moves['num_growable'], 0.5, 1)
1746
+ prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)
1747
+
1748
+ return jnp.where(moves.grow, grow_p_prune, prune_p_prune)
1123
1749
 
1124
- return jnp.where(moves['grow'], grow_p_prune, prune_p_prune)
1125
1750
 
1126
- @jaxext.vmap_nodoc
1127
- def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
1751
+ @vmap_nodoc
1752
+ def adapt_leaf_trees_to_grow_indices(
1753
+ leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
1754
+ ) -> Float32[Array, 'num_trees 2**d']:
1128
1755
  """
1129
- Modify leaf values such that the indices of the grow moves work on the
1130
- original tree.
1756
+ Modify leaves such that post-grow indices work on the original tree.
1757
+
1758
+ The value of the leaf to grow is copied to what would be its children if the
1759
+ grow move was accepted.
1131
1760
 
1132
1761
  Parameters
1133
1762
  ----------
1134
- leaf_trees : float array (num_trees, 2 ** d)
1763
+ leaf_trees
1135
1764
  The leaf values.
1136
- moves : dict
1137
- The proposed moves, see `sample_moves`.
1765
+ moves
1766
+ The proposed moves, see `propose_moves`.
1138
1767
 
1139
1768
  Returns
1140
1769
  -------
1141
- leaf_trees : float array (num_trees, 2 ** d)
1142
- The modified leaf values. The value of the leaf to grow is copied to
1143
- what would be its children if the grow move was accepted.
1770
+ The modified leaf values.
1144
1771
  """
1145
- values_at_node = leaf_trees[moves['node']]
1146
- return (leaf_trees
1147
- .at[jnp.where(moves['grow'], moves['left'], leaf_trees.size)]
1772
+ values_at_node = leaf_trees[moves.node]
1773
+ return (
1774
+ leaf_trees.at[jnp.where(moves.grow, moves.left, leaf_trees.size)]
1148
1775
  .set(values_at_node)
1149
- .at[jnp.where(moves['grow'], moves['right'], leaf_trees.size)]
1776
+ .at[jnp.where(moves.grow, moves.right, leaf_trees.size)]
1150
1777
  .set(values_at_node)
1151
1778
  )
1152
1779
 
1153
- def precompute_likelihood_terms(count_trees, sigma2, move_counts):
1780
+
1781
+ def precompute_likelihood_terms(
1782
+ sigma2: Float32[Array, ''],
1783
+ sigma_mu2: Float32[Array, ''],
1784
+ move_precs: Precs | Counts,
1785
+ ) -> tuple[PreLkV, PreLk]:
1154
1786
  """
1155
1787
  Pre-compute terms used in the likelihood ratio of the acceptance step.
1156
1788
 
1157
1789
  Parameters
1158
1790
  ----------
1159
- count_trees : array (num_trees, 2 ** d)
1160
- The number of points in each potential or actual leaf node.
1161
- sigma2 : float
1162
- The noise variance.
1163
- move_counts : dict
1164
- The counts of the number of points in the the nodes modified by the
1165
- moves.
1791
+ sigma2
1792
+ The error variance, or the global error variance factor is `prec_scale`
1793
+ is set.
1794
+ sigma_mu2
1795
+ The prior variance of each leaf.
1796
+ move_precs
1797
+ The likelihood precision scale in the leaves grown or pruned by the
1798
+ moves, under keys 'left', 'right', and 'total' (left + right).
1166
1799
 
1167
1800
  Returns
1168
1801
  -------
1169
- prelkv : dict
1802
+ prelkv : PreLkV
1170
1803
  Dictionary with pre-computed terms of the likelihood ratio, one per
1171
1804
  tree.
1172
- prelk : dict
1805
+ prelk : PreLk
1173
1806
  Dictionary with pre-computed terms of the likelihood ratio, shared by
1174
1807
  all trees.
1175
1808
  """
1176
- ntree = len(count_trees)
1177
- sigma_mu2 = 1 / ntree
1178
- prelkv = dict()
1179
- prelkv['sigma2_left'] = sigma2 + move_counts['left'] * sigma_mu2
1180
- prelkv['sigma2_right'] = sigma2 + move_counts['right'] * sigma_mu2
1181
- prelkv['sigma2_total'] = sigma2 + move_counts['total'] * sigma_mu2
1182
- prelkv['sqrt_term'] = jnp.log(
1183
- sigma2 * prelkv['sigma2_total'] /
1184
- (prelkv['sigma2_left'] * prelkv['sigma2_right'])
1185
- ) / 2
1186
- return prelkv, dict(
1809
+ sigma2_left = sigma2 + move_precs.left * sigma_mu2
1810
+ sigma2_right = sigma2 + move_precs.right * sigma_mu2
1811
+ sigma2_total = sigma2 + move_precs.total * sigma_mu2
1812
+ prelkv = PreLkV(
1813
+ sigma2_left=sigma2_left,
1814
+ sigma2_right=sigma2_right,
1815
+ sigma2_total=sigma2_total,
1816
+ sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
1817
+ )
1818
+ return prelkv, PreLk(
1187
1819
  exp_factor=sigma_mu2 / (2 * sigma2),
1188
1820
  )
1189
1821
 
1190
- def precompute_leaf_terms(count_trees, sigma2, key):
1822
+
1823
+ def precompute_leaf_terms(
1824
+ key: Key[Array, ''],
1825
+ prec_trees: Float32[Array, 'num_trees 2**d'],
1826
+ sigma2: Float32[Array, ''],
1827
+ sigma_mu2: Float32[Array, ''],
1828
+ ) -> PreLf:
1191
1829
  """
1192
1830
  Pre-compute terms used to sample leaves from their posterior.
1193
1831
 
1194
1832
  Parameters
1195
1833
  ----------
1196
- count_trees : array (num_trees, 2 ** d)
1197
- The number of points in each potential or actual leaf node.
1198
- sigma2 : float
1199
- The noise variance.
1200
- key : jax.dtypes.prng_key array
1834
+ key
1201
1835
  A jax random key.
1836
+ prec_trees
1837
+ The likelihood precision scale in each potential or actual leaf node.
1838
+ sigma2
1839
+ The error variance, or the global error variance factor if `prec_scale`
1840
+ is set.
1841
+ sigma_mu2
1842
+ The prior variance of each leaf.
1202
1843
 
1203
1844
  Returns
1204
1845
  -------
1205
- prelf : dict
1206
- Dictionary with pre-computed terms of the leaf sampling, with fields:
1207
-
1208
- 'mean_factor' : float array (num_trees, 2 ** d)
1209
- The factor to be multiplied by the sum of residuals to obtain the
1210
- posterior mean.
1211
- 'centered_leaves' : float array (num_trees, 2 ** d)
1212
- The mean-zero normal values to be added to the posterior mean to
1213
- obtain the posterior leaf samples.
1214
- """
1215
- ntree = len(count_trees)
1216
- prec_lk = count_trees / sigma2
1217
- var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
1218
- z = random.normal(key, count_trees.shape, sigma2.dtype)
1219
- return dict(
1220
- mean_factor=var_post / sigma2, # = mean_lk * prec_lk * var_post / resid_tree
1846
+ Pre-computed terms for leaf sampling.
1847
+ """
1848
+ prec_lk = prec_trees / sigma2
1849
+ prec_prior = lax.reciprocal(sigma_mu2)
1850
+ var_post = lax.reciprocal(prec_lk + prec_prior)
1851
+ z = random.normal(key, prec_trees.shape, sigma2.dtype)
1852
+ return PreLf(
1853
+ mean_factor=var_post / sigma2,
1854
+ # mean = mean_lk * prec_lk * var_post
1855
+ # resid_tree = mean_lk * prec_tree -->
1856
+ # --> mean_lk = resid_tree / prec_tree (kind of)
1857
+ # mean_factor =
1858
+ # = mean / resid_tree =
1859
+ # = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
1860
+ # = 1 / prec_tree * prec_tree / sigma2 * var_post =
1861
+ # = var_post / sigma2
1221
1862
  centered_leaves=z * jnp.sqrt(var_post),
1222
1863
  )
1223
1864
 
1224
- def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf):
1865
+
1866
+ def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
1225
1867
  """
1226
- The part of accepting the moves that has to be done one tree at a time.
1868
+ Accept/reject the moves one tree at a time.
1869
+
1870
+ This is the most performance-sensitive function because it contains all and
1871
+ only the parts of the algorithm that can not be parallelized across trees.
1227
1872
 
1228
1873
  Parameters
1229
1874
  ----------
1230
- bart : dict
1231
- A partially updated BART mcmc state.
1232
- count_trees : array (num_trees, 2 ** d)
1233
- The number of points in each potential or actual leaf node.
1234
- moves : dict
1235
- The proposed moves, see `sample_moves`.
1236
- move_counts : dict
1237
- The counts of the number of points in the the nodes modified by the
1238
- moves.
1239
- prelkv, prelk, prelf : dict
1240
- Dictionaries with pre-computed terms of the likelihood ratios and leaf
1241
- samples.
1875
+ pso
1876
+ The output of `accept_moves_parallel_stage`.
1242
1877
 
1243
1878
  Returns
1244
1879
  -------
1245
- bart : dict
1880
+ bart : State
1246
1881
  A partially updated BART mcmc state.
1247
- moves : dict
1248
- The proposed moves, with these additional fields:
1249
-
1250
- 'acc' : bool array (num_trees,)
1251
- Whether the move was accepted.
1252
- 'to_prune' : bool array (num_trees,)
1253
- Whether, to reflect the acceptance status of the move, the state
1254
- should be updated by pruning the leaves involved in the move.
1882
+ moves : Moves
1883
+ The accepted/rejected moves, with `acc` and `to_prune` set.
1255
1884
  """
1256
- bart = bart.copy()
1257
- moves = moves.copy()
1258
1885
 
1259
- def loop(resid, item):
1886
+ def loop(resid, pt):
1260
1887
  resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
1261
- bart['X'],
1262
- len(bart['leaf_trees']),
1263
- bart['opt']['resid_batch_size'],
1264
1888
  resid,
1265
- bart['min_points_per_leaf'],
1266
- 'ratios' in bart,
1267
- prelk,
1268
- *item,
1889
+ SeqStageInAllTrees(
1890
+ pso.bart.X,
1891
+ pso.bart.forest.resid_batch_size,
1892
+ pso.bart.prec_scale,
1893
+ pso.bart.forest.min_points_per_leaf,
1894
+ pso.bart.forest.log_likelihood is not None,
1895
+ pso.prelk,
1896
+ ),
1897
+ pt,
1269
1898
  )
1270
1899
  return resid, (leaf_tree, acc, to_prune, ratios)
1271
1900
 
1272
- items = (
1273
- bart['leaf_trees'], count_trees,
1274
- moves, move_counts,
1275
- bart['leaf_indices'],
1276
- prelkv, prelf,
1901
+ pts = SeqStageInPerTree(
1902
+ pso.bart.forest.leaf_trees,
1903
+ pso.prec_trees,
1904
+ pso.moves,
1905
+ pso.move_counts,
1906
+ pso.move_precs,
1907
+ pso.bart.forest.leaf_indices,
1908
+ pso.prelkv,
1909
+ pso.prelf,
1277
1910
  )
1278
- resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)
1279
-
1280
- bart['resid'] = resid
1281
- bart['leaf_trees'] = leaf_trees
1282
- bart.get('ratios', {}).update(ratios)
1283
- moves['acc'] = acc
1284
- moves['to_prune'] = to_prune
1911
+ resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, pso.bart.resid, pts)
1912
+
1913
+ save_ratios = pso.bart.forest.log_likelihood is not None
1914
+ bart = replace(
1915
+ pso.bart,
1916
+ resid=resid,
1917
+ forest=replace(
1918
+ pso.bart.forest,
1919
+ leaf_trees=leaf_trees,
1920
+ log_likelihood=ratios['log_likelihood'] if save_ratios else None,
1921
+ log_trans_prior=ratios['log_trans_prior'] if save_ratios else None,
1922
+ ),
1923
+ )
1924
+ moves = replace(pso.moves, acc=acc, to_prune=to_prune)
1285
1925
 
1286
1926
  return bart, moves
1287
1927
 
1288
- def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_per_leaf, save_ratios, prelk, leaf_tree, count_tree, move, move_counts, leaf_indices, prelkv, prelf):
1928
+
1929
+ class SeqStageInAllTrees(Module):
1289
1930
  """
1290
- Accept or reject a proposed move and sample the new leaf values.
1931
+ The inputs to `accept_move_and_sample_leaves` that are the same for all trees.
1291
1932
 
1292
1933
  Parameters
1293
1934
  ----------
1294
- X : int array (p, n)
1935
+ X
1295
1936
  The predictors.
1296
- ntree : int
1297
- The number of trees in the forest.
1298
- resid_batch_size : int, None
1937
+ resid_batch_size
1299
1938
  The batch size for computing the sum of residuals in each leaf.
1300
- resid : float array (n,)
1301
- The residuals (data minus forest value).
1302
- min_points_per_leaf : int or None
1939
+ prec_scale
1940
+ The scale of the precision of the error on each datapoint. If None, it
1941
+ is assumed to be 1.
1942
+ min_points_per_leaf
1303
1943
  The minimum number of data points in a leaf node.
1304
- save_ratios : bool
1944
+ save_ratios
1305
1945
  Whether to save the acceptance ratios.
1306
- prelk : dict
1946
+ prelk
1307
1947
  The pre-computed terms of the likelihood ratio which are shared across
1308
1948
  trees.
1309
- leaf_tree : float array (2 ** d,)
1949
+ """
1950
+
1951
+ X: UInt[Array, 'p n']
1952
+ resid_batch_size: int | None
1953
+ prec_scale: Float32[Array, 'n'] | None
1954
+ min_points_per_leaf: Int32[Array, ''] | None
1955
+ save_ratios: bool
1956
+ prelk: PreLk
1957
+
1958
+
1959
+ class SeqStageInPerTree(Module):
1960
+ """
1961
+ The inputs to `accept_move_and_sample_leaves` that are separate for each tree.
1962
+
1963
+ Parameters
1964
+ ----------
1965
+ leaf_tree
1310
1966
  The leaf values of the tree.
1311
- count_tree : int array (2 ** d,)
1312
- The number of datapoints in each leaf.
1313
- move : dict
1314
- The proposed move, see `sample_moves`.
1315
- leaf_indices : int array (n,)
1967
+ prec_tree
1968
+ The likelihood precision scale in each potential or actual leaf node.
1969
+ move
1970
+ The proposed move, see `propose_moves`.
1971
+ move_counts
1972
+ The counts of the number of points in the the nodes modified by the
1973
+ moves.
1974
+ move_precs
1975
+ The likelihood precision scale in each node modified by the moves.
1976
+ leaf_indices
1316
1977
  The leaf indices for the largest version of the tree compatible with
1317
1978
  the move.
1318
- prelkv, prelf : dict
1979
+ prelkv
1980
+ prelf
1319
1981
  The pre-computed terms of the likelihood ratio and leaf sampling which
1320
1982
  are specific to the tree.
1983
+ """
1984
+
1985
+ leaf_tree: Float32[Array, '2**d']
1986
+ prec_tree: Float32[Array, '2**d']
1987
+ move: Moves
1988
+ move_counts: Counts | None
1989
+ move_precs: Precs | Counts
1990
+ leaf_indices: UInt[Array, 'n']
1991
+ prelkv: PreLkV
1992
+ prelf: PreLf
1993
+
1994
+
1995
+ def accept_move_and_sample_leaves(
1996
+ resid: Float32[Array, 'n'],
1997
+ at: SeqStageInAllTrees,
1998
+ pt: SeqStageInPerTree,
1999
+ ) -> tuple[
2000
+ Float32[Array, 'n'],
2001
+ Float32[Array, '2**d'],
2002
+ Bool[Array, ''],
2003
+ Bool[Array, ''],
2004
+ dict[str, Float32[Array, '']],
2005
+ ]:
2006
+ """
2007
+ Accept or reject a proposed move and sample the new leaf values.
2008
+
2009
+ Parameters
2010
+ ----------
2011
+ resid
2012
+ The residuals (data minus forest value).
2013
+ at
2014
+ The inputs that are the same for all trees.
2015
+ pt
2016
+ The inputs that are separate for each tree.
1321
2017
 
1322
2018
  Returns
1323
2019
  -------
1324
- resid : float array (n,)
2020
+ resid : Float32[Array, 'n']
1325
2021
  The updated residuals (data minus forest value).
1326
- leaf_tree : float array (2 ** d,)
2022
+ leaf_tree : Float32[Array, '2**d']
1327
2023
  The new leaf values of the tree.
1328
- acc : bool
2024
+ acc : Bool[Array, '']
1329
2025
  Whether the move was accepted.
1330
- to_prune : bool
2026
+ to_prune : Bool[Array, '']
1331
2027
  Whether, to reflect the acceptance status of the move, the state should
1332
2028
  be updated by pruning the leaves involved in the move.
1333
- ratios : dict
2029
+ ratios : dict[str, Float32[Array, '']]
1334
2030
  The acceptance ratios for the moves. Empty if not to be saved.
1335
2031
  """
1336
-
1337
- # sum residuals and count units per leaf, in tree proposed by grow move
1338
- resid_tree = sum_resid(resid, leaf_indices, leaf_tree.size, resid_batch_size)
2032
+ # sum residuals in each leaf, in tree proposed by grow move
2033
+ if at.prec_scale is None:
2034
+ scaled_resid = resid
2035
+ else:
2036
+ scaled_resid = resid * at.prec_scale
2037
+ resid_tree = sum_resid(
2038
+ scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
2039
+ )
1339
2040
 
1340
2041
  # subtract starting tree from function
1341
- resid_tree += count_tree * leaf_tree
2042
+ resid_tree += pt.prec_tree * pt.leaf_tree
1342
2043
 
1343
2044
  # get indices of move
1344
- node = move['node']
2045
+ node = pt.move.node
1345
2046
  assert node.dtype == jnp.int32
1346
- left = move['left']
1347
- right = move['right']
2047
+ left = pt.move.left
2048
+ right = pt.move.right
1348
2049
 
1349
2050
  # sum residuals in parent node modified by move
1350
2051
  resid_left = resid_tree[left]
@@ -1353,215 +2054,282 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
1353
2054
  resid_tree = resid_tree.at[node].set(resid_total)
1354
2055
 
1355
2056
  # compute acceptance ratio
1356
- log_lk_ratio = compute_likelihood_ratio(resid_total, resid_left, resid_right, prelkv, prelk)
1357
- log_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
1358
- log_ratio = jnp.where(move['grow'], log_ratio, -log_ratio)
2057
+ log_lk_ratio = compute_likelihood_ratio(
2058
+ resid_total, resid_left, resid_right, pt.prelkv, at.prelk
2059
+ )
2060
+ log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
2061
+ log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
1359
2062
  ratios = {}
1360
- if save_ratios:
2063
+ if at.save_ratios:
1361
2064
  ratios.update(
1362
- log_trans_prior=move['log_trans_prior_ratio'],
2065
+ log_trans_prior=pt.move.log_trans_prior_ratio,
2066
+ # TODO save log_trans_prior_ratio as a vector outside of this loop,
2067
+ # then change the option everywhere to `save_likelihood_ratio`.
1363
2068
  log_likelihood=log_lk_ratio,
1364
2069
  )
1365
2070
 
1366
2071
  # determine whether to accept the move
1367
- acc = move['allowed'] & (move['logu'] <= log_ratio)
1368
- if min_points_per_leaf is not None:
1369
- acc &= move_counts['left'] >= min_points_per_leaf
1370
- acc &= move_counts['right'] >= min_points_per_leaf
2072
+ acc = pt.move.allowed & (pt.move.logu <= log_ratio)
2073
+ if at.min_points_per_leaf is not None:
2074
+ assert pt.move_counts is not None
2075
+ acc &= pt.move_counts.left >= at.min_points_per_leaf
2076
+ acc &= pt.move_counts.right >= at.min_points_per_leaf
1371
2077
 
1372
2078
  # compute leaves posterior and sample leaves
1373
- initial_leaf_tree = leaf_tree
1374
- mean_post = resid_tree * prelf['mean_factor']
1375
- leaf_tree = mean_post + prelf['centered_leaves']
1376
-
1377
- # copy leaves around such that the leaf indices select the right leaf
1378
- to_prune = acc ^ move['grow']
1379
- leaf_tree = (leaf_tree
1380
- .at[jnp.where(to_prune, left, leaf_tree.size)]
2079
+ initial_leaf_tree = pt.leaf_tree
2080
+ mean_post = resid_tree * pt.prelf.mean_factor
2081
+ leaf_tree = mean_post + pt.prelf.centered_leaves
2082
+
2083
+ # copy leaves around such that the leaf indices point to the correct leaf
2084
+ to_prune = acc ^ pt.move.grow
2085
+ leaf_tree = (
2086
+ leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
1381
2087
  .set(leaf_tree[node])
1382
2088
  .at[jnp.where(to_prune, right, leaf_tree.size)]
1383
2089
  .set(leaf_tree[node])
1384
2090
  )
1385
2091
 
1386
2092
  # replace old tree with new tree in function values
1387
- resid += (initial_leaf_tree - leaf_tree)[leaf_indices]
2093
+ resid += (initial_leaf_tree - leaf_tree)[pt.leaf_indices]
1388
2094
 
1389
2095
  return resid, leaf_tree, acc, to_prune, ratios
1390
2096
 
1391
- def sum_resid(resid, leaf_indices, tree_size, batch_size):
2097
+
2098
+ def sum_resid(
2099
+ scaled_resid: Float32[Array, 'n'],
2100
+ leaf_indices: UInt[Array, 'n'],
2101
+ tree_size: int,
2102
+ batch_size: int | None,
2103
+ ) -> Float32[Array, '{tree_size}']:
1392
2104
  """
1393
2105
  Sum the residuals in each leaf.
1394
2106
 
1395
2107
  Parameters
1396
2108
  ----------
1397
- resid : float array (n,)
1398
- The residuals (data minus forest value).
1399
- leaf_indices : int array (n,)
2109
+ scaled_resid
2110
+ The residuals (data minus forest value) multiplied by the error
2111
+ precision scale.
2112
+ leaf_indices
1400
2113
  The leaf indices of the tree (in which leaf each data point falls into).
1401
- tree_size : int
2114
+ tree_size
1402
2115
  The size of the tree array (2 ** d).
1403
- batch_size : int, None
2116
+ batch_size
1404
2117
  The data batch size for the aggregation. Batching increases numerical
1405
2118
  accuracy and parallelism.
1406
2119
 
1407
2120
  Returns
1408
2121
  -------
1409
- resid_tree : float array (2 ** d,)
1410
- The sum of the residuals at data points in each leaf.
2122
+ The sum of the residuals at data points in each leaf.
1411
2123
  """
1412
2124
  if batch_size is None:
1413
2125
  aggr_func = _aggregate_scatter
1414
2126
  else:
1415
- aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
1416
- return aggr_func(resid, leaf_indices, tree_size, jnp.float32)
1417
-
1418
- def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
1419
- n, = indices.shape
2127
+ aggr_func = partial(_aggregate_batched_onetree, batch_size=batch_size)
2128
+ return aggr_func(scaled_resid, leaf_indices, tree_size, jnp.float32)
2129
+
2130
+
2131
+ def _aggregate_batched_onetree(
2132
+ values: Shaped[Array, '*'],
2133
+ indices: Integer[Array, '*'],
2134
+ size: int,
2135
+ dtype: jnp.dtype,
2136
+ batch_size: int,
2137
+ ) -> Float32[Array, '{size}']:
2138
+ (n,) = indices.shape
1420
2139
  nbatches = n // batch_size + bool(n % batch_size)
1421
2140
  batch_indices = jnp.arange(n) % nbatches
1422
- return (jnp
1423
- .zeros((size, nbatches), dtype)
2141
+ return (
2142
+ jnp.zeros((size, nbatches), dtype)
1424
2143
  .at[indices, batch_indices]
1425
2144
  .add(values)
1426
2145
  .sum(axis=1)
1427
2146
  )
1428
2147
 
1429
- def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
2148
+
2149
+ def compute_likelihood_ratio(
2150
+ total_resid: Float32[Array, ''],
2151
+ left_resid: Float32[Array, ''],
2152
+ right_resid: Float32[Array, ''],
2153
+ prelkv: PreLkV,
2154
+ prelk: PreLk,
2155
+ ) -> Float32[Array, '']:
1430
2156
  """
1431
2157
  Compute the likelihood ratio of a grow move.
1432
2158
 
1433
2159
  Parameters
1434
2160
  ----------
1435
- total_resid : float
1436
- The sum of the residuals in the leaf to grow.
1437
- left_resid, right_resid : float
1438
- The sum of the residuals in the left/right child of the leaf to grow.
1439
- prelkv, prelk : dict
2161
+ total_resid
2162
+ left_resid
2163
+ right_resid
2164
+ The sum of the residuals (scaled by error precision scale) of the
2165
+ datapoints falling in the nodes involved in the moves.
2166
+ prelkv
2167
+ prelk
1440
2168
  The pre-computed terms of the likelihood ratio, see
1441
2169
  `precompute_likelihood_terms`.
1442
2170
 
1443
2171
  Returns
1444
2172
  -------
1445
- ratio : float
1446
- The likelihood ratio P(data | new tree) / P(data | old tree).
2173
+ The likelihood ratio P(data | new tree) / P(data | old tree).
1447
2174
  """
1448
- exp_term = prelk['exp_factor'] * (
1449
- left_resid * left_resid / prelkv['sigma2_left'] +
1450
- right_resid * right_resid / prelkv['sigma2_right'] -
1451
- total_resid * total_resid / prelkv['sigma2_total']
2175
+ exp_term = prelk.exp_factor * (
2176
+ left_resid * left_resid / prelkv.sigma2_left
2177
+ + right_resid * right_resid / prelkv.sigma2_right
2178
+ - total_resid * total_resid / prelkv.sigma2_total
1452
2179
  )
1453
- return prelkv['sqrt_term'] + exp_term
2180
+ return prelkv.sqrt_term + exp_term
2181
+
1454
2182
 
1455
- def accept_moves_final_stage(bart, moves):
2183
+ def accept_moves_final_stage(bart: State, moves: Moves) -> State:
1456
2184
  """
1457
- The final part of accepting the moves, in parallel across trees.
2185
+ Post-process the mcmc state after accepting/rejecting the moves.
2186
+
2187
+ This function is separate from `accept_moves_sequential_stage` to signal it
2188
+ can work in parallel across trees.
1458
2189
 
1459
2190
  Parameters
1460
2191
  ----------
1461
- bart : dict
2192
+ bart
1462
2193
  A partially updated BART mcmc state.
1463
- counts : dict
1464
- The indicators of proposals and acceptances for grow and prune moves.
1465
- moves : dict
1466
- The proposed moves (see `sample_moves`) as updated by
2194
+ moves
2195
+ The proposed moves (see `propose_moves`) as updated by
1467
2196
  `accept_moves_sequential_stage`.
1468
2197
 
1469
2198
  Returns
1470
2199
  -------
1471
- bart : dict
1472
- The fully updated BART mcmc state.
1473
- """
1474
- bart = bart.copy()
1475
- bart['grow_acc_count'] = jnp.sum(moves['acc'] & moves['grow'])
1476
- bart['prune_acc_count'] = jnp.sum(moves['acc'] & ~moves['grow'])
1477
- bart['leaf_indices'] = apply_moves_to_leaf_indices(bart['leaf_indices'], moves)
1478
- bart['split_trees'] = apply_moves_to_split_trees(bart['split_trees'], moves)
1479
- return bart
2200
+ The fully updated BART mcmc state.
2201
+ """
2202
+ return replace(
2203
+ bart,
2204
+ forest=replace(
2205
+ bart.forest,
2206
+ grow_acc_count=jnp.sum(moves.acc & moves.grow),
2207
+ prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
2208
+ leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
2209
+ split_trees=apply_moves_to_split_trees(bart.forest.split_trees, moves),
2210
+ ),
2211
+ )
2212
+
1480
2213
 
1481
- @jax.vmap
1482
- def apply_moves_to_leaf_indices(leaf_indices, moves):
2214
+ @vmap_nodoc
2215
+ def apply_moves_to_leaf_indices(
2216
+ leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
2217
+ ) -> UInt[Array, 'num_trees n']:
1483
2218
  """
1484
2219
  Update the leaf indices to match the accepted move.
1485
2220
 
1486
2221
  Parameters
1487
2222
  ----------
1488
- leaf_indices : int array (num_trees, n)
2223
+ leaf_indices
1489
2224
  The index of the leaf each datapoint falls into, if the grow move was
1490
2225
  accepted.
1491
- moves : dict
1492
- The proposed moves (see `sample_moves`), as updated by
2226
+ moves
2227
+ The proposed moves (see `propose_moves`), as updated by
1493
2228
  `accept_moves_sequential_stage`.
1494
2229
 
1495
2230
  Returns
1496
2231
  -------
1497
- leaf_indices : int array (num_trees, n)
1498
- The updated leaf indices.
2232
+ The updated leaf indices.
1499
2233
  """
1500
- mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
1501
- is_child = (leaf_indices & mask) == moves['left']
2234
+ mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
2235
+ is_child = (leaf_indices & mask) == moves.left
1502
2236
  return jnp.where(
1503
- is_child & moves['to_prune'],
1504
- moves['node'].astype(leaf_indices.dtype),
2237
+ is_child & moves.to_prune,
2238
+ moves.node.astype(leaf_indices.dtype),
1505
2239
  leaf_indices,
1506
2240
  )
1507
2241
 
1508
- @jax.vmap
1509
- def apply_moves_to_split_trees(split_trees, moves):
2242
+
2243
+ @vmap_nodoc
2244
+ def apply_moves_to_split_trees(
2245
+ split_trees: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
2246
+ ) -> UInt[Array, 'num_trees 2**(d-1)']:
1510
2247
  """
1511
2248
  Update the split trees to match the accepted move.
1512
2249
 
1513
2250
  Parameters
1514
2251
  ----------
1515
- split_trees : int array (num_trees, 2 ** (d - 1))
2252
+ split_trees
1516
2253
  The cutpoints of the decision nodes in the initial trees.
1517
- moves : dict
1518
- The proposed moves (see `sample_moves`), as updated by
2254
+ moves
2255
+ The proposed moves (see `propose_moves`), as updated by
1519
2256
  `accept_moves_sequential_stage`.
1520
2257
 
1521
2258
  Returns
1522
2259
  -------
1523
- split_trees : int array (num_trees, 2 ** (d - 1))
1524
- The updated split trees.
1525
- """
1526
- return (split_trees
1527
- .at[jnp.where(
1528
- moves['grow'],
1529
- moves['node'],
1530
- split_trees.size,
1531
- )]
1532
- .set(moves['grow_split'].astype(split_trees.dtype))
1533
- .at[jnp.where(
1534
- moves['to_prune'],
1535
- moves['node'],
1536
- split_trees.size,
1537
- )]
2260
+ The updated split trees.
2261
+ """
2262
+ assert moves.to_prune is not None
2263
+ return (
2264
+ split_trees.at[
2265
+ jnp.where(
2266
+ moves.grow,
2267
+ moves.node,
2268
+ split_trees.size,
2269
+ )
2270
+ ]
2271
+ .set(moves.grow_split.astype(split_trees.dtype))
2272
+ .at[
2273
+ jnp.where(
2274
+ moves.to_prune,
2275
+ moves.node,
2276
+ split_trees.size,
2277
+ )
2278
+ ]
1538
2279
  .set(0)
1539
2280
  )
1540
2281
 
1541
- def sample_sigma(bart, key):
2282
+
2283
+ def step_sigma(key: Key[Array, ''], bart: State) -> State:
1542
2284
  """
1543
- Noise variance sampling step of BART MCMC.
2285
+ MCMC-update the error variance (factor).
1544
2286
 
1545
2287
  Parameters
1546
2288
  ----------
1547
- bart : dict
1548
- A BART mcmc state, as created by `init`.
1549
- key : jax.dtypes.prng_key array
2289
+ key
1550
2290
  A jax random key.
2291
+ bart
2292
+ A BART mcmc state.
1551
2293
 
1552
2294
  Returns
1553
2295
  -------
1554
- bart : dict
1555
- The new BART mcmc state.
2296
+ The new BART mcmc state, with an updated `sigma2`.
1556
2297
  """
1557
- bart = bart.copy()
1558
-
1559
- resid = bart['resid']
1560
- alpha = bart['sigma2_alpha'] + resid.size / 2
1561
- norm2 = jnp.dot(resid, resid, preferred_element_type=bart['opt']['large_float'])
1562
- beta = bart['sigma2_beta'] + norm2 / 2
2298
+ resid = bart.resid
2299
+ alpha = bart.sigma2_alpha + resid.size / 2
2300
+ if bart.prec_scale is None:
2301
+ scaled_resid = resid
2302
+ else:
2303
+ scaled_resid = resid * bart.prec_scale
2304
+ norm2 = resid @ scaled_resid
2305
+ beta = bart.sigma2_beta + norm2 / 2
1563
2306
 
1564
2307
  sample = random.gamma(key, alpha)
1565
- bart['sigma2'] = beta / sample
2308
+ return replace(bart, sigma2=beta / sample)
1566
2309
 
1567
- return bart
2310
+
2311
+ def step_z(key: Key[Array, ''], bart: State) -> State:
2312
+ """
2313
+ MCMC-update the latent variable for binary regression.
2314
+
2315
+ Parameters
2316
+ ----------
2317
+ key
2318
+ A jax random key.
2319
+ bart
2320
+ A BART MCMC state.
2321
+
2322
+ Returns
2323
+ -------
2324
+ The updated BART MCMC state.
2325
+ """
2326
+ trees_plus_offset = bart.z - bart.resid
2327
+ lower = jnp.where(bart.y, -trees_plus_offset, -jnp.inf)
2328
+ upper = jnp.where(bart.y, jnp.inf, -trees_plus_offset)
2329
+ resid = random.truncated_normal(key, lower, upper)
2330
+ # TODO jax's implementation of truncated_normal is not good, it just does
2331
+ # cdf inversion with erf and erf_inv. I can do better, at least avoiding to
2332
+ # compute one of the boundaries, and maybe also flipping and using ndtr
2333
+ # instead of erf for numerical stability (open an issue in jax?)
2334
+ z = trees_plus_offset + resid
2335
+ return replace(bart, z=z, resid=resid)