bartz 0.0__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py CHANGED
@@ -41,9 +41,10 @@ from jax import random
41
41
  from jax import numpy as jnp
42
42
  from jax import lax
43
43
 
44
+ from . import jaxext
44
45
  from . import grove
45
46
 
46
- def make_bart(*,
47
+ def init(*,
47
48
  X,
48
49
  y,
49
50
  max_split,
@@ -51,8 +52,8 @@ def make_bart(*,
51
52
  p_nonterminal,
52
53
  sigma2_alpha,
53
54
  sigma2_beta,
54
- small_float_dtype=jnp.float32,
55
- large_float_dtype=jnp.float32,
55
+ small_float=jnp.float32,
56
+ large_float=jnp.float32,
56
57
  min_points_per_leaf=None,
57
58
  ):
58
59
  """
@@ -75,9 +76,9 @@ def make_bart(*,
75
76
  The shape parameter of the inverse gamma prior on the noise variance.
76
77
  sigma2_beta : float
77
78
  The scale parameter of the inverse gamma prior on the noise variance.
78
- small_float_dtype : dtype, default float32
79
+ small_float : dtype, default float32
79
80
  The dtype for large arrays used in the algorithm.
80
- large_float_dtype : dtype, default float32
81
+ large_float : dtype, default float32
81
82
  The dtype for scalars, small arrays, and arrays which require accuracy.
82
83
  min_points_per_leaf : int, optional
83
84
  The minimum number of data points in a leaf node. 0 if not specified.
@@ -88,31 +89,30 @@ def make_bart(*,
88
89
  A dictionary with array values, representing a BART mcmc state. The
89
90
  keys are:
90
91
 
91
- 'leaf_trees' : int array (num_trees, 2 ** d)
92
- The leaf values of the trees.
92
+ 'leaf_trees' : small_float array (num_trees, 2 ** d)
93
+ The leaf values.
93
94
  'var_trees' : int array (num_trees, 2 ** (d - 1))
94
- The variable indices of the trees. The bottom level is missing since
95
- it can only contain leaves.
95
+ The decision axes.
96
96
  'split_trees' : int array (num_trees, 2 ** (d - 1))
97
- The splitting points.
98
- 'resid' : large_float_dtype array (n,)
97
+ The decision boundaries.
98
+ 'resid' : large_float array (n,)
99
99
  The residuals (data minus forest value). Large float to avoid
100
100
  roundoff.
101
- 'sigma2' : large_float_dtype
101
+ 'sigma2' : large_float
102
102
  The noise variance.
103
103
  'grow_prop_count', 'prune_prop_count' : int
104
104
  The number of grow/prune proposals made during one full MCMC cycle.
105
105
  'grow_acc_count', 'prune_acc_count' : int
106
106
  The number of grow/prune moves accepted during one full MCMC cycle.
107
- 'p_nonterminal' : large_float_dtype array (d - 1,)
107
+ 'p_nonterminal' : large_float array (d - 1,)
108
108
  The probability of a nonterminal node at each depth.
109
- 'sigma2_alpha' : large_float_dtype
109
+ 'sigma2_alpha' : large_float
110
110
  The shape parameter of the inverse gamma prior on the noise variance.
111
- 'sigma2_beta' : large_float_dtype
111
+ 'sigma2_beta' : large_float
112
112
  The scale parameter of the inverse gamma prior on the noise variance.
113
113
  'max_split' : int array (p,)
114
114
  The maximum split index for each variable.
115
- 'y' : small_float_dtype array (n,)
115
+ 'y' : small_float array (n,)
116
116
  The response.
117
117
  'X' : int array (p, n)
118
118
  The predictors.
@@ -123,7 +123,7 @@ def make_bart(*,
123
123
  datapoints. If `min_points_per_leaf` is not specified, this is None.
124
124
  """
125
125
 
126
- p_nonterminal = jnp.asarray(p_nonterminal, large_float_dtype)
126
+ p_nonterminal = jnp.asarray(p_nonterminal, large_float)
127
127
  max_depth = p_nonterminal.size + 1
128
128
 
129
129
  @functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
@@ -131,20 +131,20 @@ def make_bart(*,
131
131
  return grove.make_tree(max_depth, dtype)
132
132
 
133
133
  bart = dict(
134
- leaf_trees=make_forest(max_depth, small_float_dtype),
134
+ leaf_trees=make_forest(max_depth, small_float),
135
135
  var_trees=make_forest(max_depth - 1, grove.minimal_unsigned_dtype(X.shape[0] - 1)),
136
136
  split_trees=make_forest(max_depth - 1, max_split.dtype),
137
- resid=jnp.asarray(y, large_float_dtype),
138
- sigma2=jnp.ones((), large_float_dtype),
137
+ resid=jnp.asarray(y, large_float),
138
+ sigma2=jnp.ones((), large_float),
139
139
  grow_prop_count=jnp.zeros((), int),
140
140
  grow_acc_count=jnp.zeros((), int),
141
141
  prune_prop_count=jnp.zeros((), int),
142
142
  prune_acc_count=jnp.zeros((), int),
143
143
  p_nonterminal=p_nonterminal,
144
- sigma2_alpha=jnp.asarray(sigma2_alpha, large_float_dtype),
145
- sigma2_beta=jnp.asarray(sigma2_beta, large_float_dtype),
144
+ sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
145
+ sigma2_beta=jnp.asarray(sigma2_beta, large_float),
146
146
  max_split=max_split,
147
- y=jnp.asarray(y, small_float_dtype),
147
+ y=jnp.asarray(y, small_float),
148
148
  X=X,
149
149
  min_points_per_leaf=(
150
150
  None if min_points_per_leaf is None else
@@ -158,14 +158,14 @@ def make_bart(*,
158
158
 
159
159
  return bart
160
160
 
161
- def mcmc_step(bart, key):
161
+ def step(bart, key):
162
162
  """
163
163
  Perform one full MCMC step on a BART state.
164
164
 
165
165
  Parameters
166
166
  ----------
167
167
  bart : dict
168
- A BART mcmc state, as created by `make_bart`.
168
+ A BART mcmc state, as created by `init`.
169
169
  key : jax.dtypes.prng_key array
170
170
  A jax random key.
171
171
 
@@ -174,19 +174,18 @@ def mcmc_step(bart, key):
174
174
  bart : dict
175
175
  The new BART mcmc state.
176
176
  """
177
- key1, key2 = random.split(key, 2)
178
- bart = mcmc_sample_trees(bart, key1)
179
- bart = mcmc_sample_sigma(bart, key2)
180
- return bart
177
+ key, subkey = random.split(key)
178
+ bart = sample_trees(bart, subkey)
179
+ return sample_sigma(bart, key)
181
180
 
182
- def mcmc_sample_trees(bart, key):
181
+ def sample_trees(bart, key):
183
182
  """
184
183
  Forest sampling step of BART MCMC.
185
184
 
186
185
  Parameters
187
186
  ----------
188
187
  bart : dict
189
- A BART mcmc state, as created by `make_bart`.
188
+ A BART mcmc state, as created by `init`.
190
189
  key : jax.dtypes.prng_key array
191
190
  A jax random key.
192
191
 
@@ -199,148 +198,60 @@ def mcmc_sample_trees(bart, key):
199
198
  -----
200
199
  This function zeroes the proposal counters.
201
200
  """
202
- bart = bart.copy()
203
- for count_var in ['grow_prop_count', 'grow_acc_count', 'prune_prop_count', 'prune_acc_count']:
204
- bart[count_var] = jnp.zeros_like(bart[count_var])
205
-
206
- carry = 0, bart, key
207
- def loop(carry, _):
208
- i, bart, key = carry
209
- key, subkey = random.split(key)
210
- bart = mcmc_sample_tree(bart, subkey, i)
211
- return (i + 1, bart, key), None
212
-
213
- (_, bart, _), _ = lax.scan(loop, carry, None, len(bart['leaf_trees']))
214
- return bart
215
-
216
- def mcmc_sample_tree(bart, key, i_tree):
217
- """
218
- Single tree sampling step of BART MCMC.
219
-
220
- Parameters
221
- ----------
222
- bart : dict
223
- A BART mcmc state, as created by `make_bart`.
224
- key : jax.dtypes.prng_key array
225
- A jax random key.
226
- i_tree : int
227
- The index of the tree to sample.
228
-
229
- Returns
230
- -------
231
- bart : dict
232
- The new BART mcmc state.
233
- """
234
- bart = bart.copy()
235
-
236
- y_tree = grove.evaluate_tree_vmap_x(
237
- bart['X'],
238
- bart['leaf_trees'][i_tree],
239
- bart['var_trees'][i_tree],
240
- bart['split_trees'][i_tree],
241
- bart['resid'].dtype,
242
- )
243
- bart['resid'] += y_tree
244
-
245
- key1, key2 = random.split(key, 2)
246
- bart = mcmc_sample_tree_structure(bart, key1, i_tree)
247
- bart = mcmc_sample_tree_leaves(bart, key2, i_tree)
248
-
249
- y_tree = grove.evaluate_tree_vmap_x(
250
- bart['X'],
251
- bart['leaf_trees'][i_tree],
252
- bart['var_trees'][i_tree],
253
- bart['split_trees'][i_tree],
254
- bart['resid'].dtype,
255
- )
256
- bart['resid'] -= y_tree
257
-
258
- return bart
201
+ key, subkey = random.split(key)
202
+ grow_moves, prune_moves = sample_moves(bart, subkey)
203
+ return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key)
259
204
 
260
- def mcmc_sample_tree_structure(bart, key, i_tree):
205
+ def sample_moves(bart, key):
261
206
  """
262
- Single tree structure sampling step of BART MCMC.
207
+ Propose moves for all the trees.
263
208
 
264
209
  Parameters
265
210
  ----------
266
211
  bart : dict
267
- A BART mcmc state, as created by `make_bart`. The ``'resid'`` field
268
- shall contain only the residuals w.r.t. the other trees.
212
+ BART mcmc state.
269
213
  key : jax.dtypes.prng_key array
270
214
  A jax random key.
271
- i_tree : int
272
- The index of the tree to sample.
273
215
 
274
216
  Returns
275
217
  -------
276
- bart : dict
277
- The new BART mcmc state.
278
- """
279
- bart = bart.copy()
280
-
281
- var_tree = bart['var_trees'][i_tree]
282
- split_tree = bart['split_trees'][i_tree]
283
- affluence_tree = (
284
- None if bart['affluence_trees'] is None else
285
- bart['affluence_trees'][i_tree]
286
- )
287
-
288
- key1, key2, key3 = random.split(key, 3)
289
- args = [
290
- bart['X'],
291
- var_tree,
292
- split_tree,
293
- affluence_tree,
294
- bart['max_split'],
295
- bart['p_nonterminal'],
296
- bart['sigma2'],
297
- bart['resid'],
298
- len(bart['var_trees']),
299
- bart['min_points_per_leaf'],
300
- key1,
301
- ]
302
- grow_var_tree, grow_split_tree, grow_affluence_tree, grow_allowed, grow_ratio = grow_move(*args)
303
-
304
- args[-1] = key2
305
- prune_var_tree, prune_split_tree, prune_affluence_tree, prune_allowed, prune_ratio = prune_move(*args)
306
-
307
- u0, u1 = random.uniform(key3, (2,))
308
-
309
- p_grow = jnp.where(grow_allowed & prune_allowed, 0.5, grow_allowed)
310
- try_grow = u0 < p_grow
311
- try_prune = prune_allowed & ~try_grow
312
-
313
- do_grow = try_grow & (u1 < grow_ratio)
314
- do_prune = try_prune & (u1 < prune_ratio)
315
-
316
- var_tree = jnp.where(do_grow, grow_var_tree, var_tree)
317
- split_tree = jnp.where(do_grow, grow_split_tree, split_tree)
318
- var_tree = jnp.where(do_prune, prune_var_tree, var_tree)
319
- split_tree = jnp.where(do_prune, prune_split_tree, split_tree)
320
-
321
- bart['var_trees'] = bart['var_trees'].at[i_tree].set(var_tree)
322
- bart['split_trees'] = bart['split_trees'].at[i_tree].set(split_tree)
323
-
324
- if bart['min_points_per_leaf'] is not None:
325
- affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
326
- affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
327
- bart['affluence_trees'] = bart['affluence_trees'].at[i_tree].set(affluence_tree)
328
-
329
- bart['grow_prop_count'] += try_grow
330
- bart['grow_acc_count'] += do_grow
331
- bart['prune_prop_count'] += try_prune
332
- bart['prune_acc_count'] += do_prune
333
-
334
- return bart
335
-
336
- def grow_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, sigma2, resid, n_tree, min_points_per_leaf, key):
218
+ grow_moves, prune_moves : dict
219
+ The proposals for grow and prune moves, with these fields:
220
+
221
+ 'allowed' : bool array (num_trees,)
222
+ Whether the move is possible.
223
+ 'node' : int array (num_trees,)
224
+ The index of the leaf to grow or node to prune.
225
+ 'var_tree' : int array (num_trees, 2 ** (d - 1),)
226
+ The new decision axes of the tree.
227
+ 'split_tree' : int array (num_trees, 2 ** (d - 1),)
228
+ The new decision boundaries of the tree.
229
+ 'partial_ratio' : float array (num_trees,)
230
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
231
+ the likelihood ratio, and the probability of proposing the prune
232
+ move. For the prune move, the ratio is inverted.
233
+ """
234
+ key = random.split(key, bart['var_trees'].shape[0])
235
+ return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
236
+
237
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, 0))
238
+ def sample_moves_vmap_trees(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
239
+ key, key1 = random.split(key)
240
+ args = var_tree, split_tree, affluence_tree, max_split, p_nonterminal
241
+ grow = grow_move(*args, key)
242
+ prune = prune_move(*args, key1)
243
+ return grow, prune
244
+
245
+ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
337
246
  """
338
247
  Tree structure grow move proposal of BART MCMC.
339
248
 
249
+ This move picks a leaf node and converts it to a non-terminal node with
250
+ two leaf children. The move is not possible if all the leaves are already at
251
+ maximum depth.
252
+
340
253
  Parameters
341
254
  ----------
342
- X : array (p, n)
343
- The predictors.
344
255
  var_tree : array (2 ** (d - 1),)
345
256
  The variable indices of the tree.
346
257
  split_tree : array (2 ** (d - 1),)
@@ -351,80 +262,47 @@ def grow_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal,
351
262
  The maximum split index for each variable.
352
263
  p_nonterminal : array (d - 1,)
353
264
  The probability of a nonterminal node at each depth.
354
- sigma2 : float
355
- The noise variance.
356
- resid : array (n,)
357
- The residuals (data minus forest value), computed using all trees but
358
- the tree under consideration.
359
- n_tree : int
360
- The number of trees in the forest.
361
- min_points_per_leaf : int
362
- The minimum number of data points in a leaf node.
363
265
  key : jax.dtypes.prng_key array
364
266
  A jax random key.
365
267
 
366
268
  Returns
367
269
  -------
368
- var_tree : array (2 ** (d - 1),)
369
- The new variable indices of the tree.
370
- split_tree : array (2 ** (d - 1),)
371
- The new splitting points of the tree.
372
- affluence_tree : bool array (2 ** (d - 1),) or None
373
- The new indicator whether a leaf has enough points to be grown.
374
- allowed : bool
375
- Whether the move is allowed.
376
- ratio : float
377
- The Metropolis-Hastings ratio.
378
-
379
- Notes
380
- -----
381
- This moves picks a leaf node and converts it to a non-terminal node with
382
- two leaf children. The move is not possible if all the leaves are already at
383
- maximum depth.
384
- """
385
-
386
- key1, key2, key3 = random.split(key, 3)
270
+ grow_move : dict
271
+ A dictionary with fields:
272
+
273
+ 'allowed' : bool
274
+ Whether the move is possible.
275
+ 'node' : int
276
+ The index of the leaf to grow.
277
+ 'var_tree' : array (2 ** (d - 1),)
278
+ The new decision axes of the tree.
279
+ 'split_tree' : array (2 ** (d - 1),)
280
+ The new decision boundaries of the tree.
281
+ 'partial_ratio' : float
282
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
283
+ the likelihood ratio and the probability of proposing the prune
284
+ move.
285
+ """
286
+
287
+ key, key1, key2 = random.split(key, 3)
387
288
 
388
- leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree, key1)
389
-
390
- var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key2)
289
+ leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree, key)
290
+
291
+ var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
391
292
  var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
392
293
 
393
- split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key3)
294
+ split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
394
295
  new_split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
395
296
 
396
- likelihood_ratio, new_affluence_tree = compute_likelihood_ratio(X, var_tree, new_split_tree, resid, sigma2, leaf_to_grow, n_tree, min_points_per_leaf)
397
-
398
- trans_tree_ratio = compute_trans_tree_ratio(num_growable, num_prunable, split_tree.size, p_nonterminal, leaf_to_grow, split_tree, new_split_tree, affluence_tree, new_affluence_tree)
399
-
400
- ratio = trans_tree_ratio * likelihood_ratio
401
-
402
- return var_tree, new_split_tree, new_affluence_tree, allowed, ratio
403
-
404
- def growable_leaves(split_tree, affluence_tree):
405
- """
406
- Return a mask indicating the leaf nodes that can be proposed for growth.
407
-
408
- Parameters
409
- ----------
410
- split_tree : array (2 ** (d - 1),)
411
- The splitting points of the tree.
412
- affluence_tree : bool array (2 ** (d - 1),) or None
413
- Whether a leaf has enough points to be grown.
297
+ ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree, new_split_tree)
414
298
 
415
- Returns
416
- -------
417
- is_growable : bool array (2 ** (d - 1),)
418
- The mask indicating the leaf nodes that can be proposed to grow, i.e.,
419
- that are not at the bottom level and have at least two times the number
420
- of minimum points per leaf.
421
- allowed : bool
422
- Whether the grow move is allowed, i.e., there are growable leaves.
423
- """
424
- is_growable = grove.is_actual_leaf(split_tree)
425
- if affluence_tree is not None:
426
- is_growable &= affluence_tree
427
- return is_growable, jnp.any(is_growable)
299
+ return dict(
300
+ allowed=allowed,
301
+ node=leaf_to_grow,
302
+ var_tree=var_tree,
303
+ split_tree=new_split_tree,
304
+ partial_ratio=ratio,
305
+ )
428
306
 
429
307
  def choose_leaf(split_tree, affluence_tree, key):
430
308
  """
@@ -443,7 +321,7 @@ def choose_leaf(split_tree, affluence_tree, key):
443
321
  -------
444
322
  leaf_to_grow : int
445
323
  The index of the leaf to grow. If ``num_growable == 0``, return
446
- ``split_tree.size``.
324
+ ``2 ** d``.
447
325
  num_growable : int
448
326
  The number of leaf nodes that can be grown.
449
327
  num_prunable : int
@@ -454,11 +332,37 @@ def choose_leaf(split_tree, affluence_tree, key):
454
332
  """
455
333
  is_growable, allowed = growable_leaves(split_tree, affluence_tree)
456
334
  leaf_to_grow = randint_masked(key, is_growable)
335
+ leaf_to_grow = jnp.where(allowed, leaf_to_grow, 2 * split_tree.size)
457
336
  num_growable = jnp.count_nonzero(is_growable)
458
337
  is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
459
338
  num_prunable = jnp.count_nonzero(is_parent)
460
339
  return leaf_to_grow, num_growable, num_prunable, allowed
461
340
 
341
+ def growable_leaves(split_tree, affluence_tree):
342
+ """
343
+ Return a mask indicating the leaf nodes that can be proposed for growth.
344
+
345
+ Parameters
346
+ ----------
347
+ split_tree : array (2 ** (d - 1),)
348
+ The splitting points of the tree.
349
+ affluence_tree : bool array (2 ** (d - 1),) or None
350
+ Whether a leaf has enough points to be grown.
351
+
352
+ Returns
353
+ -------
354
+ is_growable : bool array (2 ** (d - 1),)
355
+ The mask indicating the leaf nodes that can be proposed to grow, i.e.,
356
+ that are not at the bottom level and have at least two times the number
357
+ of minimum points per leaf.
358
+ allowed : bool
359
+ Whether the grow move is allowed, i.e., there are growable leaves.
360
+ """
361
+ is_growable = grove.is_actual_leaf(split_tree)
362
+ if affluence_tree is not None:
363
+ is_growable &= affluence_tree
364
+ return is_growable, jnp.any(is_growable)
365
+
462
366
  def randint_masked(key, mask):
463
367
  """
464
368
  Return a random integer in a range, including only some values.
@@ -665,7 +569,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
665
569
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
666
570
  return random.randint(key, (), l, r)
667
571
 
668
- def compute_trans_tree_ratio(num_growable, num_prunable, tree_halfsize, p_nonterminal, leaf_to_grow, initial_split_tree, new_split_tree, initial_affluence_tree, new_affluence_tree):
572
+ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, initial_split_tree, new_split_tree):
669
573
  """
670
574
  Compute the product of the transition and prior ratios of a grow move.
671
575
 
@@ -676,8 +580,6 @@ def compute_trans_tree_ratio(num_growable, num_prunable, tree_halfsize, p_nonter
676
580
  num_prunable : int
677
581
  The number of leaf parents that could be pruned, after converting the
678
582
  leaf to be grown to a non-terminal node.
679
- tree_halfsize : int
680
- Half the length of the tree array, i.e., 2 ** (d - 1).
681
583
  p_nonterminal : array (d - 1,)
682
584
  The probability of a nonterminal node at each depth.
683
585
  leaf_to_grow : int
@@ -686,16 +588,13 @@ def compute_trans_tree_ratio(num_growable, num_prunable, tree_halfsize, p_nonter
686
588
  The splitting points of the tree, before the leaf is grown.
687
589
  new_split_tree : array (2 ** (d - 1),)
688
590
  The splitting points of the tree, after the leaf is grown.
689
- initial_affluence_tree : bool array (2 ** (d - 1),) or None
690
- Whether a leaf has enough points to be grown, before the leaf is grown.
691
- new_affluence_tree : bool array (2 ** (d - 1),) or None
692
- Whether a leaf has enough points to be grown, after the leaf is grown.
693
591
 
694
592
  Returns
695
593
  -------
696
594
  ratio : float
697
595
  The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
698
- times the prior ratio P(new tree) / P(old tree).
596
+ times the prior ratio P(new tree) / P(old tree), but the transition
597
+ ratio is missing the factor P(propose prune) in the numerator.
699
598
  """
700
599
 
701
600
  # the two ratios also contain factors num_available_split *
@@ -704,101 +603,21 @@ def compute_trans_tree_ratio(num_growable, num_prunable, tree_halfsize, p_nonter
704
603
  prune_was_allowed = prune_allowed(initial_split_tree)
705
604
  p_grow = jnp.where(prune_was_allowed, 0.5, 1)
706
605
 
707
- _, grow_again_allowed = growable_leaves(new_split_tree, new_affluence_tree)
708
- p_prune = jnp.where(grow_again_allowed, 0.5, 1)
709
-
710
- trans_ratio = p_prune * num_growable / (p_grow * num_prunable)
606
+ trans_ratio = num_growable / (p_grow * num_prunable)
711
607
 
712
- depth = grove.index_depth(leaf_to_grow, tree_halfsize)
608
+ depth = grove.tree_depths(initial_split_tree.size)[leaf_to_grow]
713
609
  p_parent = p_nonterminal[depth]
714
610
  cp_children = 1 - p_nonterminal.at[depth + 1].get(mode='fill', fill_value=0)
715
611
  tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
716
612
 
717
613
  return trans_ratio * tree_ratio
718
614
 
719
- def compute_likelihood_ratio(X, var_tree, split_tree, resid, sigma2, new_node, n_tree, min_points_per_leaf):
720
- """
721
- Compute the likelihood ratio of a grow move.
722
-
723
- Parameters
724
- ----------
725
- X : array (p, n)
726
- The predictors.
727
- var_tree : array (2 ** (d - 1),)
728
- The variable indices of the tree, after the grow move.
729
- split_tree : array (2 ** (d - 1),)
730
- The splitting points of the tree, after the grow move.
731
- resid : array (n,)
732
- The residuals (data minus forest value), for all trees but the one
733
- under consideration.
734
- sigma2 : float
735
- The noise variance.
736
- new_node : int
737
- The index of the leaf that has been grown.
738
- n_tree : int
739
- The number of trees in the forest.
740
- min_points_per_leaf : int or None
741
- The minimum number of data points in a leaf node.
742
-
743
- Returns
744
- -------
745
- ratio : float
746
- The likelihood ratio P(data | new tree) / P(data | old tree).
747
- affluence_tree : bool array (2 ** (d - 1),) or None
748
- Whether a leaf has enough points to be grown, after the grow move.
749
- """
750
-
751
- resid_tree, count_tree = agg_values(
752
- X,
753
- var_tree,
754
- split_tree,
755
- resid,
756
- sigma2.dtype,
757
- )
758
-
759
- left_child = new_node << 1
760
- right_child = left_child + 1
761
-
762
- left_resid = resid_tree[left_child]
763
- right_resid = resid_tree[right_child]
764
- total_resid = left_resid + right_resid
765
-
766
- left_count = count_tree[left_child]
767
- right_count = count_tree[right_child]
768
- total_count = left_count + right_count
769
-
770
- sigma_mu2 = 1 / n_tree
771
- sigma2_left = sigma2 + left_count * sigma_mu2
772
- sigma2_right = sigma2 + right_count * sigma_mu2
773
- sigma2_total = sigma2 + total_count * sigma_mu2
774
-
775
- sqrt_term = sigma2 * sigma2_total / (sigma2_left * sigma2_right)
776
-
777
- exp_term = sigma_mu2 / (2 * sigma2) * (
778
- left_resid * left_resid / sigma2_left +
779
- right_resid * right_resid / sigma2_right -
780
- total_resid * total_resid / sigma2_total
781
- )
782
-
783
- ratio = jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
784
-
785
- if min_points_per_leaf is not None:
786
- ratio = jnp.where(right_count >= min_points_per_leaf, ratio, 0)
787
- ratio = jnp.where(left_count >= min_points_per_leaf, ratio, 0)
788
- affluence_tree = count_tree[:count_tree.size // 2] >= 2 * min_points_per_leaf
789
- else:
790
- affluence_tree = None
791
-
792
- return ratio, affluence_tree
793
-
794
- def prune_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, sigma2, resid, n_tree, min_points_per_leaf, key):
615
+ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
795
616
  """
796
617
  Tree structure prune move proposal of BART MCMC.
797
618
 
798
619
  Parameters
799
620
  ----------
800
- X : array (p, n)
801
- The predictors.
802
621
  var_tree : array (2 ** (d - 1),)
803
622
  The variable indices of the tree.
804
623
  split_tree : array (2 ** (d - 1),)
@@ -809,48 +628,41 @@ def prune_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal
809
628
  The maximum split index for each variable.
810
629
  p_nonterminal : array (d - 1,)
811
630
  The probability of a nonterminal node at each depth.
812
- sigma2 : float
813
- The noise variance.
814
- resid : array (n,)
815
- The residuals (data minus forest value), computed using all trees but
816
- the tree under consideration.
817
- n_tree : int
818
- The number of trees in the forest.
819
- min_points_per_leaf : int
820
- The minimum number of data points in a leaf node.
821
631
  key : jax.dtypes.prng_key array
822
632
  A jax random key.
823
633
 
824
634
  Returns
825
635
  -------
826
- var_tree : array (2 ** (d - 1),)
827
- The new variable indices of the tree.
828
- split_tree : array (2 ** (d - 1),)
829
- The new splitting points of the tree.
830
- affluence_tree : bool array (2 ** (d - 1),) or None
831
- The new indicator whether a leaf has enough points to be grown.
832
- allowed : bool
833
- Whether the move is allowed.
834
- ratio : float
835
- The Metropolis-Hastings ratio.
636
+ prune_move : dict
637
+ A dictionary with fields:
638
+
639
+ 'allowed' : bool
640
+ Whether the move is possible.
641
+ 'node' : int
642
+ The index of the node to prune.
643
+ 'var_tree' : array (2 ** (d - 1),)
644
+ The new decision axes of the tree.
645
+ 'split_tree' : array (2 ** (d - 1),)
646
+ The new decision boundaries of the tree.
647
+ 'partial_ratio' : float
648
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
649
+ the likelihood ratio and the probability of proposing the prune
650
+ move. This ratio is inverted.
836
651
  """
837
652
  node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
838
653
  allowed = prune_allowed(split_tree)
839
654
 
840
655
  new_split_tree = split_tree.at[node_to_prune].set(0)
841
- # should I clean up var_tree as well? just for debugging. it hasn't given me problems though
842
656
 
843
- likelihood_ratio, _ = compute_likelihood_ratio(X, var_tree, split_tree, resid, sigma2, node_to_prune, n_tree, min_points_per_leaf)
844
- new_affluence_tree = (
845
- None if affluence_tree is None else
846
- affluence_tree.at[node_to_prune].set(True)
847
- )
848
- trans_tree_ratio = compute_trans_tree_ratio(num_growable, num_prunable, split_tree.size, p_nonterminal, node_to_prune, new_split_tree, split_tree, new_affluence_tree, affluence_tree)
849
-
850
- ratio = trans_tree_ratio * likelihood_ratio
851
- ratio = 1 / ratio # Question: should I use lax.reciprocal for this?
657
+ ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, new_split_tree, split_tree)
852
658
 
853
- return var_tree, new_split_tree, new_affluence_tree, allowed, ratio
659
+ return dict(
660
+ allowed=allowed,
661
+ node=node_to_prune,
662
+ var_tree=var_tree,
663
+ split_tree=new_split_tree,
664
+ partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
665
+ )
854
666
 
855
667
  def choose_leaf_parent(split_tree, affluence_tree, key):
856
668
  """
@@ -906,116 +718,256 @@ def prune_allowed(split_tree):
906
718
  """
907
719
  return split_tree.at[1].get(mode='fill', fill_value=0).astype(bool)
908
720
 
909
- def mcmc_sample_tree_leaves(bart, key, i_tree):
910
- """
911
- Single tree leaves sampling step of BART MCMC.
912
-
913
- Parameters
914
- ----------
915
- bart : dict
916
- A BART mcmc state, as created by `make_bart`. The ``'resid'`` field
917
- shall contain the residuals only w.r.t. the other trees.
918
- key : jax.dtypes.prng_key array
919
- A jax random key.
920
- i_tree : int
921
- The index of the tree to sample.
922
-
923
- Returns
924
- -------
925
- bart : dict
926
- The new BART mcmc state.
927
- """
721
+ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
928
722
  bart = bart.copy()
723
+ def loop(carry, item):
724
+ resid = carry.pop('resid')
725
+ resid, carry, trees = accept_move_and_sample_leaves(
726
+ bart['X'],
727
+ len(bart['leaf_trees']),
728
+ resid,
729
+ bart['sigma2'],
730
+ bart['min_points_per_leaf'],
731
+ carry,
732
+ *item,
733
+ )
734
+ carry['resid'] = resid
735
+ return carry, trees
736
+ carry = {
737
+ k: jnp.zeros_like(bart[k]) for k in
738
+ ['grow_prop_count', 'prune_prop_count', 'grow_acc_count', 'prune_acc_count']
739
+ }
740
+ carry['resid'] = bart['resid']
741
+ items = (
742
+ bart['leaf_trees'],
743
+ bart['var_trees'],
744
+ bart['split_trees'],
745
+ bart['affluence_trees'],
746
+ grow_moves,
747
+ prune_moves,
748
+ random.split(key, len(bart['leaf_trees'])),
749
+ )
750
+ carry, trees = lax.scan(loop, carry, items)
751
+ bart.update(carry)
752
+ bart.update(trees)
753
+ return bart
754
+
755
+ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf, counts, leaf_tree, var_tree, split_tree, affluence_tree, grow_move, prune_move, key):
756
+
757
+ # compute leaf indices according to grow move tree
758
+ traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
759
+ grow_leaf_indices = traverse_tree(X, grow_move['var_tree'], grow_move['split_tree'])
760
+
761
+ # compute leaf indices in starting tree
762
+ grow_node = grow_move['node']
763
+ grow_left = grow_node << 1
764
+ grow_right = grow_left + 1
765
+ leaf_indices = jnp.where(
766
+ (grow_leaf_indices == grow_left) | (grow_leaf_indices == grow_right),
767
+ grow_node,
768
+ grow_leaf_indices,
769
+ )
929
770
 
930
- resid_tree, count_tree = agg_values(
931
- bart['X'],
932
- bart['var_trees'][i_tree],
933
- bart['split_trees'][i_tree],
934
- bart['resid'],
935
- bart['sigma2'].dtype,
771
+ # compute leaf indices in prune tree
772
+ prune_node = prune_move['node']
773
+ prune_left = prune_node << 1
774
+ prune_right = prune_left + 1
775
+ prune_leaf_indices = jnp.where(
776
+ (leaf_indices == prune_left) | (leaf_indices == prune_right),
777
+ prune_node,
778
+ leaf_indices,
936
779
  )
937
780
 
938
- prec_lk = count_tree / bart['sigma2']
939
- prec_prior = len(bart['leaf_trees'])
940
- var_post = 1 / (prec_lk + prec_prior) # lax.reciprocal?
941
- mean_post = resid_tree / bart['sigma2'] * var_post # = mean_lk * prec_lk * var_post
781
+ # subtract starting tree from function
782
+ resid += leaf_tree[leaf_indices]
783
+
784
+ # aggregate residuals and count units per leaf
785
+ grow_resid_tree = jnp.zeros_like(leaf_tree, sigma2.dtype)
786
+ grow_resid_tree = grow_resid_tree.at[grow_leaf_indices].add(resid)
787
+ grow_count_tree = jnp.zeros_like(leaf_tree, grove.minimal_unsigned_dtype(resid.size))
788
+ grow_count_tree = grow_count_tree.at[grow_leaf_indices].add(1)
789
+
790
+ # compute aggregations in starting tree
791
+ # I do not zero the children because garbage there does not matter
792
+ resid_tree = (grow_resid_tree.at[grow_node]
793
+ .set(grow_resid_tree[grow_left] + grow_resid_tree[grow_right]))
794
+ count_tree = (grow_count_tree.at[grow_node]
795
+ .set(grow_count_tree[grow_left] + grow_count_tree[grow_right]))
796
+
797
+ # compute aggregations in prune tree
798
+ prune_resid_tree = (resid_tree.at[prune_node]
799
+ .set(resid_tree[prune_left] + resid_tree[prune_right]))
800
+ prune_count_tree = (count_tree.at[prune_node]
801
+ .set(count_tree[prune_left] + count_tree[prune_right]))
802
+
803
+ # compute affluence trees
804
+ if min_points_per_leaf is not None:
805
+ grow_affluence_tree = grow_count_tree[:grow_count_tree.size // 2] >= 2 * min_points_per_leaf
806
+ prune_affluence_tree = affluence_tree.at[prune_node].set(True)
807
+
808
+ # compute probability of proposing prune
809
+ grow_p_prune = compute_p_prune_back(grow_move['split_tree'], grow_affluence_tree)
810
+ prune_p_prune = compute_p_prune_back(split_tree, affluence_tree)
811
+
812
+ # compute likelihood ratios
813
+ grow_lk_ratio = compute_likelihood_ratio(grow_resid_tree, grow_count_tree, sigma2, grow_node, ntree, min_points_per_leaf)
814
+ prune_lk_ratio = compute_likelihood_ratio(resid_tree, count_tree, sigma2, prune_node, ntree, min_points_per_leaf)
815
+
816
+ # compute acceptance ratios
817
+ grow_ratio = grow_p_prune * grow_move['partial_ratio'] * grow_lk_ratio
818
+ prune_ratio = prune_p_prune * prune_move['partial_ratio'] * prune_lk_ratio
819
+ prune_ratio = lax.reciprocal(prune_ratio)
820
+
821
+ # random coins in [0, 1) for proposal and acceptance
822
+ key, subkey = random.split(key)
823
+ u0, u1 = random.uniform(subkey, (2,))
824
+
825
+ # determine what move to propose (not proposing anything is an option)
826
+ p_grow = jnp.where(grow_move['allowed'] & prune_move['allowed'], 0.5, grow_move['allowed'])
827
+ try_grow = u0 < p_grow
828
+ try_prune = prune_move['allowed'] & ~try_grow
829
+
830
+ # determine whether to accept the move
831
+ do_grow = try_grow & (u1 < grow_ratio)
832
+ do_prune = try_prune & (u1 < prune_ratio)
942
833
 
834
+ # pick trees for chosen move
835
+ trees = {}
836
+ var_tree = jnp.where(do_grow, grow_move['var_tree'], var_tree)
837
+ split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
838
+ var_tree = jnp.where(do_prune, prune_move['var_tree'], var_tree)
839
+ split_tree = jnp.where(do_prune, prune_move['split_tree'], split_tree)
840
+ if min_points_per_leaf is not None:
841
+ affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
842
+ affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
843
+ resid_tree = jnp.where(do_grow, grow_resid_tree, resid_tree)
844
+ count_tree = jnp.where(do_grow, grow_count_tree, count_tree)
845
+ resid_tree = jnp.where(do_prune, prune_resid_tree, resid_tree)
846
+ count_tree = jnp.where(do_prune, prune_count_tree, count_tree)
847
+
848
+ # update acceptance counts
849
+ counts = counts.copy()
850
+ counts['grow_prop_count'] += try_grow
851
+ counts['grow_acc_count'] += do_grow
852
+ counts['prune_prop_count'] += try_prune
853
+ counts['prune_acc_count'] += do_prune
854
+
855
+ # compute leaves posterior
856
+ prec_lk = count_tree / sigma2
857
+ var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
858
+ mean_post = resid_tree / sigma2 * var_post # = mean_lk * prec_lk * var_post
859
+
860
+ # sample leaves
943
861
  z = random.normal(key, mean_post.shape, mean_post.dtype)
944
- # TODO maybe use long float here, I guess this part is not a bottleneck
945
862
  leaf_tree = mean_post + z * jnp.sqrt(var_post)
946
- leaf_tree = leaf_tree.at[0].set(0) # this 0 is used by evaluate_tree
947
- bart['leaf_trees'] = bart['leaf_trees'].at[i_tree].set(leaf_tree)
948
863
 
949
- return bart
864
+ # add new tree to function
865
+ leaf_indices = jnp.where(do_grow, grow_leaf_indices, leaf_indices)
866
+ leaf_indices = jnp.where(do_prune, prune_leaf_indices, leaf_indices)
867
+ resid -= leaf_tree[leaf_indices]
950
868
 
951
- def agg_values(X, var_tree, split_tree, values, acc_dtype):
869
+ # pack trees
870
+ trees = {
871
+ 'leaf_trees': leaf_tree,
872
+ 'var_trees': var_tree,
873
+ 'split_trees': split_tree,
874
+ 'affluence_trees': affluence_tree,
875
+ }
876
+
877
+ return resid, counts, trees
878
+
879
+ def compute_p_prune_back(new_split_tree, new_affluence_tree):
952
880
  """
953
- Aggregate values at the leaves of a tree.
881
+ Compute the probability of proposing a prune move after doing a grow move.
954
882
 
955
883
  Parameters
956
884
  ----------
957
- X : array (p, n)
958
- The predictors.
959
- var_tree : array (2 ** (d - 1),)
960
- The variable indices of the tree.
961
- split_tree : array (2 ** (d - 1),)
962
- The splitting points of the tree.
963
- values : array (n,)
964
- The values to aggregate.
965
- acc_dtype : dtype
966
- The dtype of the output.
885
+ new_split_tree : int array (2 ** (d - 1),)
886
+ The decision boundaries of the tree, after the grow move.
887
+ new_affluence_tree : bool array (2 ** (d - 1),)
888
+ Which leaves have enough points to be grown, after the grow move.
967
889
 
968
890
  Returns
969
891
  -------
970
- acc_tree : acc_dtype array (2 ** d,)
971
- Tree leaves for the tree structure indicated by the arguments, where
972
- each leaf contains the sum of the `values` whose corresponding `X` fall
973
- into the leaf.
892
+ p_prune : float
893
+ The probability of proposing a prune move after the grow move. This is
894
+ 0.5 if grow is possible again, and 1 if it isn't. It can't be 0 because
895
+ at least the node just grown can be pruned.
896
+ """
897
+ _, grow_again_allowed = growable_leaves(new_split_tree, new_affluence_tree)
898
+ return jnp.where(grow_again_allowed, 0.5, 1)
899
+
900
+ def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_points_per_leaf):
901
+ """
902
+ Compute the likelihood ratio of a grow move.
903
+
904
+ Parameters
905
+ ----------
906
+ resid_tree : float array (2 ** d,)
907
+ The sum of the residuals at data points in each leaf.
974
908
  count_tree : int array (2 ** d,)
975
- Tree leaves containing the count of such values.
909
+ The number of data points in each leaf.
910
+ sigma2 : float
911
+ The noise variance.
912
+ node : int
913
+ The index of the leaf that has been grown.
914
+ n_tree : int
915
+ The number of trees in the forest.
916
+ min_points_per_leaf : int or None
917
+ The minimum number of data points in a leaf node.
918
+
919
+ Returns
920
+ -------
921
+ ratio : float
922
+ The likelihood ratio P(data | new tree) / P(data | old tree).
923
+
924
+ Notes
925
+ -----
926
+ The ratio is set to 0 if the grow move would create leaves with fewer than
927
+ `min_points_per_leaf` data points, although this constraint is part of the
928
+ prior rather than of the likelihood.
976
929
  """
977
930
 
978
- depth = grove.tree_depth(var_tree) + 1
979
- carry = (
980
- jnp.zeros(values.size, bool),
981
- jnp.ones(values.size, grove.minimal_unsigned_dtype(2 * var_tree.size - 1)),
982
- grove.make_tree(depth, acc_dtype),
983
- grove.make_tree(depth, grove.minimal_unsigned_dtype(values.size - 1)),
931
+ left_child = node << 1
932
+ right_child = left_child + 1
933
+
934
+ left_resid = resid_tree[left_child]
935
+ right_resid = resid_tree[right_child]
936
+ total_resid = left_resid + right_resid
937
+
938
+ left_count = count_tree[left_child]
939
+ right_count = count_tree[right_child]
940
+ total_count = left_count + right_count
941
+
942
+ sigma_mu2 = 1 / n_tree
943
+ sigma2_left = sigma2 + left_count * sigma_mu2
944
+ sigma2_right = sigma2 + right_count * sigma_mu2
945
+ sigma2_total = sigma2 + total_count * sigma_mu2
946
+
947
+ sqrt_term = sigma2 * sigma2_total / (sigma2_left * sigma2_right)
948
+
949
+ exp_term = sigma_mu2 / (2 * sigma2) * (
950
+ left_resid * left_resid / sigma2_left +
951
+ right_resid * right_resid / sigma2_right -
952
+ total_resid * total_resid / sigma2_total
984
953
  )
985
- unit_index = jnp.arange(values.size, dtype=grove.minimal_unsigned_dtype(values.size - 1))
986
954
 
987
- def loop(carry, _):
988
- leaf_found, node_index, acc_tree, count_tree = carry
989
-
990
- is_leaf = split_tree.at[node_index].get(mode='fill', fill_value=0) == 0
991
- leaf_count = is_leaf & ~leaf_found
992
- leaf_values = jnp.where(leaf_count, values, jnp.array(0, values.dtype))
993
- acc_tree = acc_tree.at[node_index].add(leaf_values)
994
- count_tree = count_tree.at[node_index].add(leaf_count)
995
- leaf_found |= is_leaf
996
-
997
- split = split_tree[node_index]
998
- var = var_tree.at[node_index].get(mode='fill', fill_value=0)
999
- x = X[var, unit_index]
1000
-
1001
- node_index <<= 1
1002
- node_index += x >= split
1003
- node_index = jnp.where(leaf_found, 0, node_index)
1004
-
1005
- carry = leaf_found, node_index, acc_tree, count_tree
1006
- return carry, None
1007
-
1008
- (_, _, acc_tree, count_tree), _ = lax.scan(loop, carry, None, depth)
1009
- return acc_tree, count_tree
1010
-
1011
- def mcmc_sample_sigma(bart, key):
955
+ ratio = jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
956
+
957
+ if min_points_per_leaf is not None:
958
+ ratio = jnp.where(right_count >= min_points_per_leaf, ratio, 0)
959
+ ratio = jnp.where(left_count >= min_points_per_leaf, ratio, 0)
960
+
961
+ return ratio
962
+
963
+ def sample_sigma(bart, key):
1012
964
  """
1013
965
  Noise variance sampling step of BART MCMC.
1014
966
 
1015
967
  Parameters
1016
968
  ----------
1017
969
  bart : dict
1018
- A BART mcmc state, as created by `make_bart`.
970
+ A BART mcmc state, as created by `init`.
1019
971
  key : jax.dtypes.prng_key array
1020
972
  A jax random key.
1021
973
 
@@ -1028,8 +980,8 @@ def mcmc_sample_sigma(bart, key):
1028
980
 
1029
981
  resid = bart['resid']
1030
982
  alpha = bart['sigma2_alpha'] + resid.size / 2
1031
- norm = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
1032
- beta = bart['sigma2_beta'] + norm / 2
983
+ norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
984
+ beta = bart['sigma2_beta'] + norm2 / 2
1033
985
 
1034
986
  sample = random.gamma(key, alpha)
1035
987
  bart['sigma2'] = beta / sample