bartz 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py CHANGED
@@ -34,16 +34,16 @@ range of possible values.
  """

  import functools
- import math

  import jax
  from jax import random
  from jax import numpy as jnp
  from jax import lax

+ from . import jaxext
  from . import grove

- def make_bart(*,
+ def init(*,
  X,
  y,
  max_split,
@@ -51,9 +51,10 @@ def make_bart(*,
  p_nonterminal,
  sigma2_alpha,
  sigma2_beta,
- small_float_dtype=jnp.float32,
- large_float_dtype=jnp.float32,
+ small_float=jnp.float32,
+ large_float=jnp.float32,
  min_points_per_leaf=None,
+ suffstat_batch_size='auto',
  ):
  """
  Make a BART posterior sampling MCMC initial state.
@@ -75,12 +76,15 @@ def make_bart(*,
  The shape parameter of the inverse gamma prior on the noise variance.
  sigma2_beta : float
  The scale parameter of the inverse gamma prior on the noise variance.
- small_float_dtype : dtype, default float32
+ small_float : dtype, default float32
  The dtype for large arrays used in the algorithm.
- large_float_dtype : dtype, default float32
+ large_float : dtype, default float32
  The dtype for scalars, small arrays, and arrays which require accuracy.
  min_points_per_leaf : int, optional
  The minimum number of data points in a leaf node. 0 if not specified.
+ suffstat_batch_size : int, None, str, default 'auto'
+ The batch size for computing sufficient statistics. `None` for no
+ batching. If 'auto', pick a value based on the device of `y`.

  Returns
  -------
@@ -88,31 +92,31 @@ def make_bart(*,
  A dictionary with array values, representing a BART mcmc state. The
  keys are:

- 'leaf_trees' : int array (num_trees, 2 ** d)
- The leaf values of the trees.
+ 'leaf_trees' : small_float array (num_trees, 2 ** d)
+ The leaf values.
  'var_trees' : int array (num_trees, 2 ** (d - 1))
- The variable indices of the trees. The bottom level is missing since
- it can only contain leaves.
+ The decision axes.
  'split_trees' : int array (num_trees, 2 ** (d - 1))
- The splitting points.
- 'resid' : large_float_dtype array (n,)
+ The decision boundaries.
+ 'resid' : large_float array (n,)
  The residuals (data minus forest value). Large float to avoid
  roundoff.
- 'sigma2' : large_float_dtype
+ 'sigma2' : large_float
  The noise variance.
  'grow_prop_count', 'prune_prop_count' : int
  The number of grow/prune proposals made during one full MCMC cycle.
  'grow_acc_count', 'prune_acc_count' : int
  The number of grow/prune moves accepted during one full MCMC cycle.
- 'p_nonterminal' : large_float_dtype array (d - 1,)
- The probability of a nonterminal node at each depth.
- 'sigma2_alpha' : large_float_dtype
+ 'p_nonterminal' : large_float array (d,)
+ The probability of a nonterminal node at each depth, padded with a
+ zero.
+ 'sigma2_alpha' : large_float
  The shape parameter of the inverse gamma prior on the noise variance.
- 'sigma2_beta' : large_float_dtype
+ 'sigma2_beta' : large_float
  The scale parameter of the inverse gamma prior on the noise variance.
  'max_split' : int array (p,)
  The maximum split index for each variable.
- 'y' : small_float_dtype array (n,)
+ 'y' : small_float array (n,)
  The response.
  'X' : int array (p, n)
  The predictors.
@@ -121,31 +125,49 @@ def make_bart(*,
  'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
  Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
  datapoints. If `min_points_per_leaf` is not specified, this is None.
+ 'opt' : LeafDict
+ A dictionary with config values:
+
+ 'suffstat_batch_size' : int or None
+ The batch size for computing sufficient statistics.
+ 'small_float' : dtype
+ The dtype for large arrays used in the algorithm.
+ 'large_float' : dtype
+ The dtype for scalars, small arrays, and arrays which require
+ accuracy.
+ 'require_min_points' : bool
+ Whether the `min_points_per_leaf` parameter is specified.
  """

- p_nonterminal = jnp.asarray(p_nonterminal, large_float_dtype)
- max_depth = p_nonterminal.size + 1
+ p_nonterminal = jnp.asarray(p_nonterminal, large_float)
+ p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
+ max_depth = p_nonterminal.size

  @functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
  def make_forest(max_depth, dtype):
  return grove.make_tree(max_depth, dtype)

+ small_float = jnp.dtype(small_float)
+ large_float = jnp.dtype(large_float)
+ y = jnp.asarray(y, small_float)
+ suffstat_batch_size = _choose_suffstat_batch_size(suffstat_batch_size, y)
+
  bart = dict(
- leaf_trees=make_forest(max_depth, small_float_dtype),
- var_trees=make_forest(max_depth - 1, grove.minimal_unsigned_dtype(X.shape[0] - 1)),
+ leaf_trees=make_forest(max_depth, small_float),
+ var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
  split_trees=make_forest(max_depth - 1, max_split.dtype),
- resid=jnp.asarray(y, large_float_dtype),
- sigma2=jnp.ones((), large_float_dtype),
+ resid=jnp.asarray(y, large_float),
+ sigma2=jnp.ones((), large_float),
  grow_prop_count=jnp.zeros((), int),
  grow_acc_count=jnp.zeros((), int),
  prune_prop_count=jnp.zeros((), int),
  prune_acc_count=jnp.zeros((), int),
  p_nonterminal=p_nonterminal,
- sigma2_alpha=jnp.asarray(sigma2_alpha, large_float_dtype),
- sigma2_beta=jnp.asarray(sigma2_beta, large_float_dtype),
- max_split=max_split,
- y=jnp.asarray(y, small_float_dtype),
- X=X,
+ sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
+ sigma2_beta=jnp.asarray(sigma2_beta, large_float),
+ max_split=jnp.asarray(max_split),
+ y=y,
+ X=jnp.asarray(X),
  min_points_per_leaf=(
  None if min_points_per_leaf is None else
  jnp.asarray(min_points_per_leaf)
@@ -154,18 +176,40 @@ def make_bart(*,
  None if min_points_per_leaf is None else
  make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
  ),
+ opt=jaxext.LeafDict(
+ suffstat_batch_size=suffstat_batch_size,
+ small_float=small_float,
+ large_float=large_float,
+ require_min_points=min_points_per_leaf is not None,
+ ),
  )

  return bart

- def mcmc_step(bart, key):
+ def _choose_suffstat_batch_size(size, y):
+ if size == 'auto':
+ platform = y.devices().pop().platform
+ if platform == 'cpu':
+ return None
+ # maybe I should batch residuals (not counts) for numerical
+ # accuracy, even if it's slower
+ elif platform == 'gpu':
+ return 128 # 128 is good on A100, and V100 at high n
+ # 512 is good on T4, and V100 at low n
+ else:
+ raise KeyError(f'Unknown platform: {platform}')
+ elif size is not None:
+ return int(size)
+ return size
+
+ def step(bart, key):
  """
  Perform one full MCMC step on a BART state.

  Parameters
  ----------
  bart : dict
- A BART mcmc state, as created by `make_bart`.
+ A BART mcmc state, as created by `init`.
  key : jax.dtypes.prng_key array
  A jax random key.
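For orientation, a minimal usage sketch of the renamed entry points `init` and `step` (not part of the diff). The values are toy numbers, and the `num_trees` keyword is an assumption: the diff skips the signature line between `max_split` and `p_nonterminal`, but the body uses a `num_trees` variable, so some such argument presumably exists.

    from jax import numpy as jnp, random
    from bartz import mcmcstep

    # toy pre-binned data: p predictors, n data points, integer-coded
    p, n, nsplits = 3, 100, 10
    key = random.key(0)
    X = random.randint(random.fold_in(key, 1), (p, n), 0, nsplits).astype(jnp.uint8)
    y = random.normal(random.fold_in(key, 2), (n,))

    state = mcmcstep.init(
        X=X,
        y=y,
        max_split=jnp.full(p, nsplits, jnp.uint8),
        num_trees=20,                      # assumed keyword, elided from the hunks above
        p_nonterminal=jnp.array([0.95, 0.84, 0.28, 0.03]),
        sigma2_alpha=1.0,
        sigma2_beta=1.0,
        min_points_per_leaf=5,
        suffstat_batch_size='auto',
    )

    for i in range(100):                   # one full MCMC cycle per call
        state = mcmcstep.step(state, random.fold_in(key, 3 + i))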
 
@@ -174,19 +218,18 @@ def mcmc_step(bart, key):
  bart : dict
  The new BART mcmc state.
  """
- key1, key2 = random.split(key, 2)
- bart = mcmc_sample_trees(bart, key1)
- bart = mcmc_sample_sigma(bart, key2)
- return bart
+ key, subkey = random.split(key)
+ bart = sample_trees(bart, subkey)
+ return sample_sigma(bart, key)

- def mcmc_sample_trees(bart, key):
+ def sample_trees(bart, key):
  """
  Forest sampling step of BART MCMC.

  Parameters
  ----------
  bart : dict
- A BART mcmc state, as created by `make_bart`.
+ A BART mcmc state, as created by `init`.
  key : jax.dtypes.prng_key array
  A jax random key.

@@ -197,150 +240,52 @@ def mcmc_sample_trees(bart, key):

  Notes
  -----
- This function zeroes the proposal counters.
+ This function zeroes the proposal counters before using them.
  """
  bart = bart.copy()
- for count_var in ['grow_prop_count', 'grow_acc_count', 'prune_prop_count', 'prune_acc_count']:
- bart[count_var] = jnp.zeros_like(bart[count_var])
-
- carry = 0, bart, key
- def loop(carry, _):
- i, bart, key = carry
- key, subkey = random.split(key)
- bart = mcmc_sample_tree(bart, subkey, i)
- return (i + 1, bart, key), None
-
- (_, bart, _), _ = lax.scan(loop, carry, None, len(bart['leaf_trees']))
- return bart
+ key, subkey = random.split(key)
+ grow_moves, prune_moves = sample_moves(bart, subkey)
+ bart['var_trees'] = grow_moves['var_tree']
+ grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
+ return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)

- def mcmc_sample_tree(bart, key, i_tree):
+ def sample_moves(bart, key):
  """
- Single tree sampling step of BART MCMC.
+ Propose moves for all the trees.

  Parameters
  ----------
  bart : dict
- A BART mcmc state, as created by `make_bart`.
+ BART mcmc state.
  key : jax.dtypes.prng_key array
  A jax random key.
- i_tree : int
- The index of the tree to sample.

  Returns
  -------
- bart : dict
- The new BART mcmc state.
- """
- bart = bart.copy()
-
- y_tree = grove.evaluate_tree_vmap_x(
- bart['X'],
- bart['leaf_trees'][i_tree],
- bart['var_trees'][i_tree],
- bart['split_trees'][i_tree],
- bart['resid'].dtype,
- )
- bart['resid'] += y_tree
-
- key1, key2 = random.split(key, 2)
- bart = mcmc_sample_tree_structure(bart, key1, i_tree)
- bart = mcmc_sample_tree_leaves(bart, key2, i_tree)
-
- y_tree = grove.evaluate_tree_vmap_x(
- bart['X'],
- bart['leaf_trees'][i_tree],
- bart['var_trees'][i_tree],
- bart['split_trees'][i_tree],
- bart['resid'].dtype,
- )
- bart['resid'] -= y_tree
-
- return bart
-
- def mcmc_sample_tree_structure(bart, key, i_tree):
+ grow_moves, prune_moves : dict
+ The proposals for grow and prune moves. See `grow_move` and `prune_move`.
  """
- Single tree structure sampling step of BART MCMC.
-
- Parameters
- ----------
- bart : dict
- A BART mcmc state, as created by `make_bart`. The ``'resid'`` field
- shall contain only the residuals w.r.t. the other trees.
- key : jax.dtypes.prng_key array
- A jax random key.
- i_tree : int
- The index of the tree to sample.
-
- Returns
- -------
- bart : dict
- The new BART mcmc state.
- """
- bart = bart.copy()
-
- var_tree = bart['var_trees'][i_tree]
- split_tree = bart['split_trees'][i_tree]
- affluence_tree = (
- None if bart['affluence_trees'] is None else
- bart['affluence_trees'][i_tree]
- )
-
- key1, key2, key3 = random.split(key, 3)
- args = [
- bart['X'],
- var_tree,
- split_tree,
- affluence_tree,
- bart['max_split'],
- bart['p_nonterminal'],
- bart['sigma2'],
- bart['resid'],
- len(bart['var_trees']),
- bart['min_points_per_leaf'],
- key1,
- ]
- grow_var_tree, grow_split_tree, grow_affluence_tree, grow_allowed, grow_ratio = grow_move(*args)
-
- args[-1] = key2
- prune_var_tree, prune_split_tree, prune_affluence_tree, prune_allowed, prune_ratio = prune_move(*args)
-
- u0, u1 = random.uniform(key3, (2,))
-
- p_grow = jnp.where(grow_allowed & prune_allowed, 0.5, grow_allowed)
- try_grow = u0 < p_grow
- try_prune = prune_allowed & ~try_grow
-
- do_grow = try_grow & (u1 < grow_ratio)
- do_prune = try_prune & (u1 < prune_ratio)
-
- var_tree = jnp.where(do_grow, grow_var_tree, var_tree)
- split_tree = jnp.where(do_grow, grow_split_tree, split_tree)
- var_tree = jnp.where(do_prune, prune_var_tree, var_tree)
- split_tree = jnp.where(do_prune, prune_split_tree, split_tree)
-
- bart['var_trees'] = bart['var_trees'].at[i_tree].set(var_tree)
- bart['split_trees'] = bart['split_trees'].at[i_tree].set(split_tree)
-
- if bart['min_points_per_leaf'] is not None:
- affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
- affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
- bart['affluence_trees'] = bart['affluence_trees'].at[i_tree].set(affluence_tree)
-
- bart['grow_prop_count'] += try_grow
- bart['grow_acc_count'] += do_grow
- bart['prune_prop_count'] += try_prune
- bart['prune_acc_count'] += do_prune
-
- return bart
-
- def grow_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, sigma2, resid, n_tree, min_points_per_leaf, key):
+ key = random.split(key, bart['var_trees'].shape[0])
+ return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
+
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, 0))
+ def sample_moves_vmap_trees(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
+ key, key1 = random.split(key)
+ args = var_tree, split_tree, affluence_tree, max_split, p_nonterminal
+ grow = grow_move(*args, key)
+ prune = prune_move(*args, key1)
+ return grow, prune
+
+ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
  """
  Tree structure grow move proposal of BART MCMC.

+ This move picks a leaf node and converts it to a non-terminal node with
+ two leaf children. The move is not possible if all the leaves are already at
+ maximum depth.
+
  Parameters
  ----------
- X : array (p, n)
- The predictors.
  var_tree : array (2 ** (d - 1),)
  The variable indices of the tree.
  split_tree : array (2 ** (d - 1),)
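The refactor above replaces the per-tree `lax.scan` loop with a single batched proposal over all trees. A standalone sketch of that batching pattern (not from the package; `jaxext.vmap_nodoc` is presumably just a docstring-hiding wrapper over `jax.vmap`, so plain `jax.vmap` is used here, with a stand-in per-tree function):

    import jax
    from jax import numpy as jnp, random

    def per_tree(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
        # stand-in for the real per-tree proposal logic
        return var_tree.sum() + random.bernoulli(key)

    # batch over the leading tree axis of the forest arrays,
    # broadcast the shared arrays, and give each tree its own key
    batched = jax.vmap(per_tree, in_axes=(0, 0, 0, None, None, 0))

    num_trees, half_size = 4, 8
    out = batched(
        jnp.zeros((num_trees, half_size), jnp.uint8),  # var_trees
        jnp.zeros((num_trees, half_size), jnp.uint8),  # split_trees
        jnp.ones((num_trees, half_size), bool),        # affluence_trees
        jnp.full(3, 10),                               # max_split, shared
        jnp.array([0.95, 0.8, 0.5, 0.0]),              # p_nonterminal, shared
        random.split(random.key(0), num_trees),        # one key per tree
    )
    print(out.shape)  # (4,)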
@@ -349,82 +294,49 @@ def grow_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal,
  Whether a leaf has enough points to be grown.
  max_split : array (p,)
  The maximum split index for each variable.
- p_nonterminal : array (d - 1,)
+ p_nonterminal : array (d,)
  The probability of a nonterminal node at each depth.
- sigma2 : float
- The noise variance.
- resid : array (n,)
- The residuals (data minus forest value), computed using all trees but
- the tree under consideration.
- n_tree : int
- The number of trees in the forest.
- min_points_per_leaf : int
- The minimum number of data points in a leaf node.
  key : jax.dtypes.prng_key array
  A jax random key.

  Returns
  -------
- var_tree : array (2 ** (d - 1),)
- The new variable indices of the tree.
- split_tree : array (2 ** (d - 1),)
- The new splitting points of the tree.
- affluence_tree : bool array (2 ** (d - 1),) or None
- The new indicator whether a leaf has enough points to be grown.
- allowed : bool
- Whether the move is allowed.
- ratio : float
- The Metropolis-Hastings ratio.
-
- Notes
- -----
- This moves picks a leaf node and converts it to a non-terminal node with
- two leaf children. The move is not possible if all the leaves are already at
- maximum depth.
+ grow_move : dict
+ A dictionary with fields:
+
+ 'allowed' : bool
+ Whether the move is possible.
+ 'node' : int
+ The index of the leaf to grow.
+ 'var_tree' : array (2 ** (d - 1),)
+ The new decision axes of the tree.
+ 'split_tree' : array (2 ** (d - 1),)
+ The new decision boundaries of the tree.
+ 'partial_ratio' : float
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
+ the likelihood ratio and the probability of proposing the prune
+ move.
  """

- key1, key2, key3 = random.split(key, 3)
-
- leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree, key1)
+ key, key1, key2 = random.split(key, 3)

- var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key2)
- var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
-
- split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key3)
- new_split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
+ leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree, key)

- likelihood_ratio, new_affluence_tree = compute_likelihood_ratio(X, var_tree, new_split_tree, resid, sigma2, leaf_to_grow, n_tree, min_points_per_leaf)
+ var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
+ var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))

- trans_tree_ratio = compute_trans_tree_ratio(num_growable, num_prunable, split_tree.size, p_nonterminal, leaf_to_grow, split_tree, new_split_tree, affluence_tree, new_affluence_tree)
-
- ratio = trans_tree_ratio * likelihood_ratio
-
- return var_tree, new_split_tree, new_affluence_tree, allowed, ratio
-
- def growable_leaves(split_tree, affluence_tree):
- """
- Return a mask indicating the leaf nodes that can be proposed for growth.
+ split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
+ split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))

- Parameters
- ----------
- split_tree : array (2 ** (d - 1),)
- The splitting points of the tree.
- affluence_tree : bool array (2 ** (d - 1),) or None
- Whether a leaf has enough points to be grown.
+ ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree)

- Returns
- -------
- is_growable : bool array (2 ** (d - 1),)
- The mask indicating the leaf nodes that can be proposed to grow, i.e.,
- that are not at the bottom level and have at least two times the number
- of minimum points per leaf.
- allowed : bool
- Whether the grow move is allowed, i.e., there are growable leaves.
- """
- is_growable = grove.is_actual_leaf(split_tree)
- if affluence_tree is not None:
- is_growable &= affluence_tree
- return is_growable, jnp.any(is_growable)
+ return dict(
+ allowed=allowed,
+ node=leaf_to_grow,
+ partial_ratio=ratio,
+ var_tree=var_tree,
+ split_tree=split_tree,
+ )

  def choose_leaf(split_tree, affluence_tree, key):
  """
@@ -443,7 +355,7 @@ def choose_leaf(split_tree, affluence_tree, key):
  -------
  leaf_to_grow : int
  The index of the leaf to grow. If ``num_growable == 0``, return
- ``split_tree.size``.
+ ``2 ** d``.
  num_growable : int
  The number of leaf nodes that can be grown.
  num_prunable : int
@@ -454,11 +366,37 @@ def choose_leaf(split_tree, affluence_tree, key):
  """
  is_growable, allowed = growable_leaves(split_tree, affluence_tree)
  leaf_to_grow = randint_masked(key, is_growable)
+ leaf_to_grow = jnp.where(allowed, leaf_to_grow, 2 * split_tree.size)
  num_growable = jnp.count_nonzero(is_growable)
  is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
  num_prunable = jnp.count_nonzero(is_parent)
  return leaf_to_grow, num_growable, num_prunable, allowed

+ def growable_leaves(split_tree, affluence_tree):
+ """
+ Return a mask indicating the leaf nodes that can be proposed for growth.
+
+ Parameters
+ ----------
+ split_tree : array (2 ** (d - 1),)
+ The splitting points of the tree.
+ affluence_tree : bool array (2 ** (d - 1),) or None
+ Whether a leaf has enough points to be grown.
+
+ Returns
+ -------
+ is_growable : bool array (2 ** (d - 1),)
+ The mask indicating the leaf nodes that can be proposed to grow, i.e.,
+ that are not at the bottom level and have at least two times the number
+ of minimum points per leaf.
+ allowed : bool
+ Whether the grow move is allowed, i.e., there are growable leaves.
+ """
+ is_growable = grove.is_actual_leaf(split_tree)
+ if affluence_tree is not None:
+ is_growable &= affluence_tree
+ return is_growable, jnp.any(is_growable)
+
  def randint_masked(key, mask):
  """
  Return a random integer in a range, including only some values.
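The body of `randint_masked` is outside this diff's hunks. A hedched sketch of one standard way to draw an index uniformly from the True entries of a mask, which is what the docstring describes; the package's actual implementation may differ:

    from jax import numpy as jnp, random

    def randint_masked_sketch(key, mask):
        # cumulative count of True entries; assumes at least one True
        ecdf = jnp.cumsum(mask)
        u = random.randint(key, (), 0, ecdf[-1])          # uniform over the count
        return jnp.searchsorted(ecdf, u, side='right')    # position of the (u+1)-th True

    mask = jnp.array([False, True, False, True, True])
    print(randint_masked_sketch(random.key(0), mask))     # one of 1, 3, 4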
@@ -560,7 +498,7 @@ def ancestor_variables(var_tree, max_split, node_index):
  the parent. Unused spots are filled with `p`.
  """
  max_num_ancestors = grove.tree_depth(var_tree) - 1
- ancestor_vars = jnp.zeros(max_num_ancestors, grove.minimal_unsigned_dtype(max_split.size))
+ ancestor_vars = jnp.zeros(max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size))
  carry = ancestor_vars.size - 1, node_index, ancestor_vars
  def loop(carry, _):
  i, index, ancestor_vars = carry
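`minimal_unsigned_dtype` now comes from `jaxext` instead of `grove`; its body is not shown in this diff. From how it is called, it presumably returns the smallest unsigned integer dtype that can hold a given maximum value, roughly:

    from jax import numpy as jnp

    def minimal_unsigned_dtype_sketch(max_value):
        # smallest unsigned integer dtype able to represent max_value (assumed behavior)
        for dtype in (jnp.uint8, jnp.uint16, jnp.uint32):
            if max_value <= jnp.iinfo(dtype).max:
                return dtype
        return jnp.uint64

    assert minimal_unsigned_dtype_sketch(200) == jnp.uint8
    assert minimal_unsigned_dtype_sketch(70_000) == jnp.uint32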
@@ -665,7 +603,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
  return random.randint(key, (), l, r)

- def compute_trans_tree_ratio(num_growable, num_prunable, tree_halfsize, p_nonterminal, leaf_to_grow, initial_split_tree, new_split_tree, initial_affluence_tree, new_affluence_tree):
+ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
  """
  Compute the product of the transition and prior ratios of a grow move.

@@ -676,129 +614,46 @@ def compute_trans_tree_ratio(num_growable, num_prunable, tree_halfsize, p_nonter
  num_prunable : int
  The number of leaf parents that could be pruned, after converting the
  leaf to be grown to a non-terminal node.
- tree_halfsize : int
- Half the length of the tree array, i.e., 2 ** (d - 1).
- p_nonterminal : array (d - 1,)
+ p_nonterminal : array (d,)
  The probability of a nonterminal node at each depth.
  leaf_to_grow : int
  The index of the leaf to grow.
- initial_split_tree : array (2 ** (d - 1),)
- The splitting points of the tree, before the leaf is grown.
  new_split_tree : array (2 ** (d - 1),)
  The splitting points of the tree, after the leaf is grown.
- initial_affluence_tree : bool array (2 ** (d - 1),) or None
- Whether a leaf has enough points to be grown, before the leaf is grown.
- new_affluence_tree : bool array (2 ** (d - 1),) or None
- Whether a leaf has enough points to be grown, after the leaf is grown.

  Returns
  -------
  ratio : float
  The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
- times the prior ratio P(new tree) / P(old tree).
+ times the prior ratio P(new tree) / P(old tree), but the transition
+ ratio is missing the factor P(propose prune) in the numerator.
  """

  # the two ratios also contain factors num_available_split *
  # num_available_var, but they cancel out

- prune_was_allowed = prune_allowed(initial_split_tree)
- p_grow = jnp.where(prune_was_allowed, 0.5, 1)
+ prune_allowed = leaf_to_grow != 1
+ # prune allowed <---> the initial tree is not a root
+ # leaf to grow is root --> the tree can only be a root
+ # tree is a root --> the only leaf I can grow is root

- _, grow_again_allowed = growable_leaves(new_split_tree, new_affluence_tree)
- p_prune = jnp.where(grow_again_allowed, 0.5, 1)
+ p_grow = jnp.where(prune_allowed, 0.5, 1)

- trans_ratio = p_prune * num_growable / (p_grow * num_prunable)
+ trans_ratio = num_growable / (p_grow * num_prunable)

- depth = grove.index_depth(leaf_to_grow, tree_halfsize)
+ depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
  p_parent = p_nonterminal[depth]
- cp_children = 1 - p_nonterminal.at[depth + 1].get(mode='fill', fill_value=0)
+ cp_children = 1 - p_nonterminal[depth + 1]
  tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)

  return trans_ratio * tree_ratio

- def compute_likelihood_ratio(X, var_tree, split_tree, resid, sigma2, new_node, n_tree, min_points_per_leaf):
- """
- Compute the likelihood ratio of a grow move.
-
- Parameters
- ----------
- X : array (p, n)
- The predictors.
- var_tree : array (2 ** (d - 1),)
- The variable indices of the tree, after the grow move.
- split_tree : array (2 ** (d - 1),)
- The splitting points of the tree, after the grow move.
- resid : array (n,)
- The residuals (data minus forest value), for all trees but the one
- under consideration.
- sigma2 : float
- The noise variance.
- new_node : int
- The index of the leaf that has been grown.
- n_tree : int
- The number of trees in the forest.
- min_points_per_leaf : int or None
- The minimum number of data points in a leaf node.
-
- Returns
- -------
- ratio : float
- The likelihood ratio P(data | new tree) / P(data | old tree).
- affluence_tree : bool array (2 ** (d - 1),) or None
- Whether a leaf has enough points to be grown, after the grow move.
- """
-
- resid_tree, count_tree = agg_values(
- X,
- var_tree,
- split_tree,
- resid,
- sigma2.dtype,
- )
-
- left_child = new_node << 1
- right_child = left_child + 1
-
- left_resid = resid_tree[left_child]
- right_resid = resid_tree[right_child]
- total_resid = left_resid + right_resid
-
- left_count = count_tree[left_child]
- right_count = count_tree[right_child]
- total_count = left_count + right_count
-
- sigma_mu2 = 1 / n_tree
- sigma2_left = sigma2 + left_count * sigma_mu2
- sigma2_right = sigma2 + right_count * sigma_mu2
- sigma2_total = sigma2 + total_count * sigma_mu2
-
- sqrt_term = sigma2 * sigma2_total / (sigma2_left * sigma2_right)
-
- exp_term = sigma_mu2 / (2 * sigma2) * (
- left_resid * left_resid / sigma2_left +
- right_resid * right_resid / sigma2_right -
- total_resid * total_resid / sigma2_total
- )
-
- ratio = jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
-
- if min_points_per_leaf is not None:
- ratio = jnp.where(right_count >= min_points_per_leaf, ratio, 0)
- ratio = jnp.where(left_count >= min_points_per_leaf, ratio, 0)
- affluence_tree = count_tree[:count_tree.size // 2] >= 2 * min_points_per_leaf
- else:
- affluence_tree = None
-
- return ratio, affluence_tree
-
- def prune_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, sigma2, resid, n_tree, min_points_per_leaf, key):
+ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
  """
  Tree structure prune move proposal of BART MCMC.

  Parameters
  ----------
- X : array (p, n)
- The predictors.
  var_tree : array (2 ** (d - 1),)
  The variable indices of the tree.
  split_tree : array (2 ** (d - 1),)
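A small worked example of the `compute_partial_ratio` factorization above (not from the package; the numbers describe a tree whose root has two leaf children and whose left child gets grown). It shows how the partial ratio lacks only P(propose prune) and the likelihood ratio, which the acceptance step multiplies back in:

    import jax.numpy as jnp

    p_nonterminal = jnp.array([0.95, 0.5, 0.0])  # per-depth prior, padded with a zero as in init
    num_growable = 2        # both depth-1 leaves can be grown
    num_prunable = 1        # after growing node 2, only node 2 is a leaves-parent
    leaf_to_grow = 2        # heap index of the left child of the root

    p_grow = jnp.where(leaf_to_grow != 1, 0.5, 1)          # prune was also possible, so 0.5
    trans_ratio = num_growable / (p_grow * num_prunable)   # 4.0

    depth = 1                                              # depth of node 2
    p_parent = p_nonterminal[depth]                        # 0.5
    cp_children = 1 - p_nonterminal[depth + 1]             # 1.0, children sit at max depth
    tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)  # 1.0

    partial_ratio = trans_ratio * tree_ratio               # 4.0
    grow_ratio = 0.5 * partial_ratio                       # times P(propose prune) = 0.5
    # the full Metropolis-Hastings ratio further multiplies the likelihood ratio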
@@ -807,50 +662,35 @@ def prune_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal
  Whether a leaf has enough points to be grown.
  max_split : array (p,)
  The maximum split index for each variable.
- p_nonterminal : array (d - 1,)
+ p_nonterminal : array (d,)
  The probability of a nonterminal node at each depth.
- sigma2 : float
- The noise variance.
- resid : array (n,)
- The residuals (data minus forest value), computed using all trees but
- the tree under consideration.
- n_tree : int
- The number of trees in the forest.
- min_points_per_leaf : int
- The minimum number of data points in a leaf node.
  key : jax.dtypes.prng_key array
  A jax random key.

  Returns
  -------
- var_tree : array (2 ** (d - 1),)
- The new variable indices of the tree.
- split_tree : array (2 ** (d - 1),)
- The new splitting points of the tree.
- affluence_tree : bool array (2 ** (d - 1),) or None
- The new indicator whether a leaf has enough points to be grown.
- allowed : bool
- Whether the move is allowed.
- ratio : float
- The Metropolis-Hastings ratio.
+ prune_move : dict
+ A dictionary with fields:
+
+ 'allowed' : bool
+ Whether the move is possible.
+ 'node' : int
+ The index of the node to prune.
+ 'partial_ratio' : float
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
+ the likelihood ratio and the probability of proposing the prune
+ move. This ratio is inverted.
  """
  node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
- allowed = prune_allowed(split_tree)
+ allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root

- new_split_tree = split_tree.at[node_to_prune].set(0)
- # should I clean up var_tree as well? just for debugging. it hasn't given me problems though
+ ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, split_tree)

- likelihood_ratio, _ = compute_likelihood_ratio(X, var_tree, split_tree, resid, sigma2, node_to_prune, n_tree, min_points_per_leaf)
- new_affluence_tree = (
- None if affluence_tree is None else
- affluence_tree.at[node_to_prune].set(True)
+ return dict(
+ allowed=allowed,
+ node=node_to_prune,
+ partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
  )
- trans_tree_ratio = compute_trans_tree_ratio(num_growable, num_prunable, split_tree.size, p_nonterminal, node_to_prune, new_split_tree, split_tree, new_affluence_tree, affluence_tree)
-
- ratio = trans_tree_ratio * likelihood_ratio
- ratio = 1 / ratio # Question: should I use lax.reciprocal for this?
-
- return var_tree, new_split_tree, new_affluence_tree, allowed, ratio

  def choose_leaf_parent(split_tree, affluence_tree, key):
  """
@@ -890,132 +730,363 @@ def choose_leaf_parent(split_tree, affluence_tree, key):

  return node_to_prune, num_prunable, num_growable

- def prune_allowed(split_tree):
+ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key):
  """
- Return whether a prune move is allowed.
+ Accept or reject the proposed moves and sample the new leaf values.

  Parameters
  ----------
- split_tree : array (2 ** (d - 1),)
- The splitting points of the tree.
+ bart : dict
+ A BART mcmc state.
+ grow_moves : dict
+ The proposals for grow moves, batched over the first axis. See
+ `grow_move`.
+ prune_moves : dict
+ The proposals for prune moves, batched over the first axis. See
+ `prune_move`.
+ grow_leaf_indices : int array (num_trees, n)
+ The leaf indices of the trees proposed by the grow move.
+ key : jax.dtypes.prng_key array
+ A jax random key.

  Returns
  -------
- allowed : bool
- Whether a prune move is allowed.
+ bart : dict
+ The new BART mcmc state.
  """
- return split_tree.at[1].get(mode='fill', fill_value=0).astype(bool)
+ bart = bart.copy()
+ def loop(carry, item):
+ resid = carry.pop('resid')
+ resid, carry, trees = accept_move_and_sample_leaves(
+ bart['X'],
+ len(bart['leaf_trees']),
+ bart['opt']['suffstat_batch_size'],
+ resid,
+ bart['sigma2'],
+ bart['min_points_per_leaf'],
+ carry,
+ *item,
+ )
+ carry['resid'] = resid
+ return carry, trees
+ carry = {
+ k: jnp.zeros_like(bart[k]) for k in
+ ['grow_prop_count', 'prune_prop_count', 'grow_acc_count', 'prune_acc_count']
+ }
+ carry['resid'] = bart['resid']
+ items = (
+ bart['leaf_trees'],
+ bart['split_trees'],
+ bart['affluence_trees'],
+ grow_moves,
+ prune_moves,
+ grow_leaf_indices,
+ random.split(key, len(bart['leaf_trees'])),
+ )
+ carry, trees = lax.scan(loop, carry, items)
+ bart.update(carry)
+ bart.update(trees)
+ return bart

- def mcmc_sample_tree_leaves(bart, key, i_tree):
+ def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2, min_points_per_leaf, counts, leaf_tree, split_tree, affluence_tree, grow_move, prune_move, grow_leaf_indices, key):
  """
- Single tree leaves sampling step of BART MCMC.
+ Accept or reject a proposed move and sample the new leaf values.

  Parameters
  ----------
- bart : dict
- A BART mcmc state, as created by `make_bart`. The ``'resid'`` field
- shall contain the residuals only w.r.t. the other trees.
+ X : int array (p, n)
+ The predictors.
+ ntree : int
+ The number of trees in the forest.
+ suffstat_batch_size : int, None
+ The batch size for computing sufficient statistics.
+ resid : float array (n,)
+ The residuals (data minus forest value).
+ sigma2 : float
+ The noise variance.
+ min_points_per_leaf : int or None
+ The minimum number of data points in a leaf node.
+ counts : dict
+ The acceptance counts from the mcmc state dict.
+ leaf_tree : float array (2 ** d,)
+ The leaf values of the tree.
+ split_tree : int array (2 ** (d - 1),)
+ The decision boundaries of the tree.
+ affluence_tree : bool array (2 ** (d - 1),) or None
+ Whether a leaf has enough points to be grown.
+ grow_move : dict
+ The proposal for the grow move. See `grow_move`.
+ prune_move : dict
+ The proposal for the prune move. See `prune_move`.
+ grow_leaf_indices : int array (n,)
+ The leaf indices of the tree proposed by the grow move.
  key : jax.dtypes.prng_key array
  A jax random key.
- i_tree : int
- The index of the tree to sample.

  Returns
  -------
- bart : dict
- The new BART mcmc state.
+ resid : float array (n,)
+ The updated residuals (data minus forest value).
+ counts : dict
+ The updated acceptance counts.
+ trees : dict
+ The updated tree arrays.
  """
- bart = bart.copy()
+
+ # compute leaf indices in starting tree
+ grow_node = grow_move['node']
+ grow_left = grow_node << 1
+ grow_right = grow_left + 1
+ leaf_indices = jnp.where(
+ (grow_leaf_indices == grow_left) | (grow_leaf_indices == grow_right),
+ grow_node,
+ grow_leaf_indices,
+ )

- resid_tree, count_tree = agg_values(
- bart['X'],
- bart['var_trees'][i_tree],
- bart['split_trees'][i_tree],
- bart['resid'],
- bart['sigma2'].dtype,
+ # compute leaf indices in prune tree
+ prune_node = prune_move['node']
+ prune_left = prune_node << 1
+ prune_right = prune_left + 1
+ prune_leaf_indices = jnp.where(
+ (leaf_indices == prune_left) | (leaf_indices == prune_right),
+ prune_node,
+ leaf_indices,
  )

- prec_lk = count_tree / bart['sigma2']
- prec_prior = len(bart['leaf_trees'])
- var_post = 1 / (prec_lk + prec_prior) # lax.reciprocal?
- mean_post = resid_tree / bart['sigma2'] * var_post # = mean_lk * prec_lk * var_post
+ # subtract starting tree from function
+ resid += leaf_tree[leaf_indices]
+
+ # aggregate residuals and count units per leaf
+ grow_resid_tree, grow_count_tree = sufficient_stat(resid, grow_leaf_indices, leaf_tree.size, suffstat_batch_size)
+
+ # compute aggregations in starting tree
+ # I do not zero the children because garbage there does not matter
+ resid_tree = (grow_resid_tree.at[grow_node]
+ .set(grow_resid_tree[grow_left] + grow_resid_tree[grow_right]))
+ count_tree = (grow_count_tree.at[grow_node]
+ .set(grow_count_tree[grow_left] + grow_count_tree[grow_right]))

+ # compute aggregations in prune tree
+ prune_resid_tree = (resid_tree.at[prune_node]
+ .set(resid_tree[prune_left] + resid_tree[prune_right]))
+ prune_count_tree = (count_tree.at[prune_node]
+ .set(count_tree[prune_left] + count_tree[prune_right]))
+
+ # compute affluence trees
+ if min_points_per_leaf is not None:
+ grow_affluence_tree = grow_count_tree[:grow_count_tree.size // 2] >= 2 * min_points_per_leaf
+ prune_affluence_tree = affluence_tree.at[prune_node].set(True)
+
+ # compute probability of proposing prune
+ grow_p_prune = compute_p_prune_back(grow_move['split_tree'], grow_affluence_tree)
+ prune_p_prune = compute_p_prune_back(split_tree, affluence_tree)
+
+ # compute likelihood ratios
+ grow_lk_ratio = compute_likelihood_ratio(grow_resid_tree, grow_count_tree, sigma2, grow_node, ntree, min_points_per_leaf)
+ prune_lk_ratio = compute_likelihood_ratio(resid_tree, count_tree, sigma2, prune_node, ntree, min_points_per_leaf)
+
+ # compute acceptance ratios
+ grow_ratio = grow_p_prune * grow_move['partial_ratio'] * grow_lk_ratio
+ prune_ratio = prune_p_prune * prune_move['partial_ratio'] * prune_lk_ratio
+ prune_ratio = lax.reciprocal(prune_ratio)
+
+ # random coins in [0, 1) for proposal and acceptance
+ key, subkey = random.split(key)
+ u0, u1 = random.uniform(subkey, (2,))
+
+ # determine what move to propose (not proposing anything is an option)
+ p_grow = jnp.where(grow_move['allowed'] & prune_move['allowed'], 0.5, grow_move['allowed'])
+ try_grow = u0 < p_grow
+ try_prune = prune_move['allowed'] & ~try_grow
+
+ # determine whether to accept the move
+ do_grow = try_grow & (u1 < grow_ratio)
+ do_prune = try_prune & (u1 < prune_ratio)
+
+ # pick trees for chosen move
+ trees = {}
+ split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
+ # the prune var tree is equal to the initial one, because I leave garbage values behind
+ split_tree = split_tree.at[prune_node].set(
+ jnp.where(do_prune, 0, split_tree[prune_node]))
+ if min_points_per_leaf is not None:
+ affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
+ affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
+ resid_tree = jnp.where(do_grow, grow_resid_tree, resid_tree)
+ count_tree = jnp.where(do_grow, grow_count_tree, count_tree)
+ resid_tree = jnp.where(do_prune, prune_resid_tree, resid_tree)
+ count_tree = jnp.where(do_prune, prune_count_tree, count_tree)
+
+ # update acceptance counts
+ counts = counts.copy()
+ counts['grow_prop_count'] += try_grow
+ counts['grow_acc_count'] += do_grow
+ counts['prune_prop_count'] += try_prune
+ counts['prune_acc_count'] += do_prune
+
+ # compute leaves posterior
+ prec_lk = count_tree / sigma2
+ var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
+ mean_post = resid_tree / sigma2 * var_post # = mean_lk * prec_lk * var_post
+
+ # sample leaves
  z = random.normal(key, mean_post.shape, mean_post.dtype)
- # TODO maybe use long float here, I guess this part is not a bottleneck
  leaf_tree = mean_post + z * jnp.sqrt(var_post)
- leaf_tree = leaf_tree.at[0].set(0) # this 0 is used by evaluate_tree
- bart['leaf_trees'] = bart['leaf_trees'].at[i_tree].set(leaf_tree)

- return bart
+ # add new tree to function
+ leaf_indices = jnp.where(do_grow, grow_leaf_indices, leaf_indices)
+ leaf_indices = jnp.where(do_prune, prune_leaf_indices, leaf_indices)
+ resid -= leaf_tree[leaf_indices]

- def agg_values(X, var_tree, split_tree, values, acc_dtype):
+ # pack trees
+ trees = {
+ 'leaf_trees': leaf_tree,
+ 'split_trees': split_tree,
+ 'affluence_trees': affluence_tree,
+ }
+
+ return resid, counts, trees
+
+ def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
  """
- Aggregate values at the leaves of a tree.
+ Compute the sufficient statistics for the likelihood ratio of a tree move.

  Parameters
  ----------
- X : array (p, n)
- The predictors.
- var_tree : array (2 ** (d - 1),)
- The variable indices of the tree.
- split_tree : array (2 ** (d - 1),)
- The splitting points of the tree.
- values : array (n,)
- The values to aggregate.
- acc_dtype : dtype
- The dtype of the output.
+ resid : float array (n,)
+ The residuals (data minus forest value).
+ leaf_indices : int array (n,)
+ The leaf indices of the tree (which leaf each data point falls into).
+ tree_size : int
+ The size of the tree array (2 ** d).
+ batch_size : int, None
+ The batch size for the aggregation. Batching increases numerical
+ accuracy and parallelism.

  Returns
  -------
- acc_tree : acc_dtype array (2 ** d,)
- Tree leaves for the tree structure indicated by the arguments, where
- each leaf contains the sum of the `values` whose corresponding `X` fall
- into the leaf.
+ resid_tree : float array (2 ** d,)
+ The sum of the residuals at data points in each leaf.
  count_tree : int array (2 ** d,)
- Tree leaves containing the count of such values.
+ The number of data points in each leaf.
  """
+ if batch_size is None:
+ aggr_func = _aggregate_scatter
+ else:
+ aggr_func = functools.partial(_aggregate_batched, batch_size=batch_size)
+ resid_tree = aggr_func(resid, leaf_indices, tree_size, jnp.float32)
+ count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
+ return resid_tree, count_tree
+
+ def _aggregate_scatter(values, indices, size, dtype):
+ return (jnp
+ .zeros(size, dtype)
+ .at[indices]
+ .add(values)
+ )

- depth = grove.tree_depth(var_tree) + 1
- carry = (
- jnp.zeros(values.size, bool),
- jnp.ones(values.size, grove.minimal_unsigned_dtype(2 * var_tree.size - 1)),
- grove.make_tree(depth, acc_dtype),
- grove.make_tree(depth, grove.minimal_unsigned_dtype(values.size - 1)),
+ def _aggregate_batched(values, indices, size, dtype, batch_size):
+ nbatches = indices.size // batch_size + bool(indices.size % batch_size)
+ batch_indices = jnp.arange(indices.size) // batch_size
+ return (jnp
+ .zeros((nbatches, size), dtype)
+ .at[batch_indices, indices]
+ .add(values)
+ .sum(axis=0)
  )
- unit_index = jnp.arange(values.size, dtype=grove.minimal_unsigned_dtype(values.size - 1))

- def loop(carry, _):
- leaf_found, node_index, acc_tree, count_tree = carry
-
- is_leaf = split_tree.at[node_index].get(mode='fill', fill_value=0) == 0
- leaf_count = is_leaf & ~leaf_found
- leaf_values = jnp.where(leaf_count, values, jnp.array(0, values.dtype))
- acc_tree = acc_tree.at[node_index].add(leaf_values)
- count_tree = count_tree.at[node_index].add(leaf_count)
- leaf_found |= is_leaf
-
- split = split_tree[node_index]
- var = var_tree.at[node_index].get(mode='fill', fill_value=0)
- x = X[var, unit_index]
-
- node_index <<= 1
- node_index += x >= split
- node_index = jnp.where(leaf_found, 0, node_index)
-
- carry = leaf_found, node_index, acc_tree, count_tree
- return carry, None
-
- (_, _, acc_tree, count_tree), _ = lax.scan(loop, carry, None, depth)
- return acc_tree, count_tree
-
- def mcmc_sample_sigma(bart, key):
+ def compute_p_prune_back(new_split_tree, new_affluence_tree):
+ """
+ Compute the probability of proposing a prune move after doing a grow move.
+
+ Parameters
+ ----------
+ new_split_tree : int array (2 ** (d - 1),)
+ The decision boundaries of the tree, after the grow move.
+ new_affluence_tree : bool array (2 ** (d - 1),)
+ Which leaves have enough points to be grown, after the grow move.
+
+ Returns
+ -------
+ p_prune : float
+ The probability of proposing a prune move after the grow move. This is
+ 0.5 if grow is possible again, and 1 if it isn't. It can't be 0 because
+ at least the node just grown can be pruned.
+ """
+ _, grow_again_allowed = growable_leaves(new_split_tree, new_affluence_tree)
+ return jnp.where(grow_again_allowed, 0.5, 1)
+
+ def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_points_per_leaf):
+ """
+ Compute the likelihood ratio of a grow move.
+
+ Parameters
+ ----------
+ resid_tree : float array (2 ** d,)
+ The sum of the residuals at data points in each leaf.
+ count_tree : int array (2 ** d,)
+ The number of data points in each leaf.
+ sigma2 : float
+ The noise variance.
+ node : int
+ The index of the leaf that has been grown.
+ n_tree : int
+ The number of trees in the forest.
+ min_points_per_leaf : int or None
+ The minimum number of data points in a leaf node.
+
+ Returns
+ -------
+ ratio : float
+ The likelihood ratio P(data | new tree) / P(data | old tree).
+
+ Notes
+ -----
+ The ratio is set to 0 if the grow move would create leaves with not enough
+ datapoints per leaf, although this is part of the prior rather than the
+ likelihood.
+ """
+
+ left_child = node << 1
+ right_child = left_child + 1
+
+ left_resid = resid_tree[left_child]
+ right_resid = resid_tree[right_child]
+ total_resid = left_resid + right_resid
+
+ left_count = count_tree[left_child]
+ right_count = count_tree[right_child]
+ total_count = left_count + right_count
+
+ sigma_mu2 = 1 / n_tree
+ sigma2_left = sigma2 + left_count * sigma_mu2
+ sigma2_right = sigma2 + right_count * sigma_mu2
+ sigma2_total = sigma2 + total_count * sigma_mu2
+
+ sqrt_term = sigma2 * sigma2_total / (sigma2_left * sigma2_right)
+
+ exp_term = sigma_mu2 / (2 * sigma2) * (
+ left_resid * left_resid / sigma2_left +
+ right_resid * right_resid / sigma2_right -
+ total_resid * total_resid / sigma2_total
+ )
+
+ ratio = jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
+
+ if min_points_per_leaf is not None:
+ ratio = jnp.where(right_count >= min_points_per_leaf, ratio, 0)
+ ratio = jnp.where(left_count >= min_points_per_leaf, ratio, 0)
+
+ return ratio
+
+ def sample_sigma(bart, key):
  """
  Noise variance sampling step of BART MCMC.

  Parameters
  ----------
  bart : dict
- A BART mcmc state, as created by `make_bart`.
+ A BART mcmc state, as created by `init`.
  key : jax.dtypes.prng_key array
  A jax random key.
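A standalone check (not part of the package) that the batched scatter-add used by `sufficient_stat` above agrees with the plain scatter-add, up to floating-point summation order; this is the equivalence the batching relies on:

    import jax.numpy as jnp
    from jax import random

    def aggregate_scatter(values, indices, size, dtype):
        return jnp.zeros(size, dtype).at[indices].add(values)

    def aggregate_batched(values, indices, size, dtype, batch_size):
        nbatches = indices.size // batch_size + bool(indices.size % batch_size)
        batch_indices = jnp.arange(indices.size) // batch_size
        return (jnp.zeros((nbatches, size), dtype)
                .at[batch_indices, indices]
                .add(values)
                .sum(axis=0))

    key = random.key(0)
    resid = random.normal(key, (1000,))
    leaf_indices = random.randint(random.fold_in(key, 1), (1000,), 0, 16)

    a = aggregate_scatter(resid, leaf_indices, 16, jnp.float32)
    b = aggregate_batched(resid, leaf_indices, 16, jnp.float32, batch_size=128)
    print(jnp.allclose(a, b))  # True up to rounding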
 
@@ -1028,8 +1099,8 @@ def mcmc_sample_sigma(bart, key):

  resid = bart['resid']
  alpha = bart['sigma2_alpha'] + resid.size / 2
- norm = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
- beta = bart['sigma2_beta'] + norm / 2
+ norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
+ beta = bart['sigma2_beta'] + norm2 / 2

  sample = random.gamma(key, alpha)
  bart['sigma2'] = beta / sample
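The `sample_sigma` update above draws the new noise variance as `beta / Gamma(alpha)`, i.e. from an inverse gamma distribution. A quick standalone sanity check of that equivalence with toy values (these are not package defaults):

    import jax.numpy as jnp
    from jax import random

    alpha, beta = 3.0, 2.0
    draws = beta / random.gamma(random.key(0), alpha, (100_000,))

    # InverseGamma(alpha, beta) has mean beta / (alpha - 1) = 1.0 for alpha > 1
    print(draws.mean())  # close to 1.0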