bartz 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/__init__.py +1 -1
- bartz/_version.py +1 -0
- bartz/debug.py +5 -19
- bartz/grove.py +71 -118
- bartz/interface.py +6 -15
- bartz/mcmcloop.py +12 -6
- bartz/mcmcstep.py +379 -427
- {bartz-0.0.1.dist-info → bartz-0.1.0.dist-info}/METADATA +1 -1
- bartz-0.1.0.dist-info/RECORD +13 -0
- bartz-0.0.1.dist-info/RECORD +0 -12
- {bartz-0.0.1.dist-info → bartz-0.1.0.dist-info}/LICENSE +0 -0
- {bartz-0.0.1.dist-info → bartz-0.1.0.dist-info}/WHEEL +0 -0
bartz/mcmcstep.py
CHANGED
|
@@ -41,9 +41,10 @@ from jax import random
|
|
|
41
41
|
from jax import numpy as jnp
|
|
42
42
|
from jax import lax
|
|
43
43
|
|
|
44
|
+
from . import jaxext
|
|
44
45
|
from . import grove
|
|
45
46
|
|
|
46
|
-
def
|
|
47
|
+
def init(*,
|
|
47
48
|
X,
|
|
48
49
|
y,
|
|
49
50
|
max_split,
|
|
@@ -51,8 +52,8 @@ def make_bart(*,
|
|
|
51
52
|
p_nonterminal,
|
|
52
53
|
sigma2_alpha,
|
|
53
54
|
sigma2_beta,
|
|
54
|
-
|
|
55
|
-
|
|
55
|
+
small_float=jnp.float32,
|
|
56
|
+
large_float=jnp.float32,
|
|
56
57
|
min_points_per_leaf=None,
|
|
57
58
|
):
|
|
58
59
|
"""
|
|
@@ -75,9 +76,9 @@ def make_bart(*,
|
|
|
75
76
|
The shape parameter of the inverse gamma prior on the noise variance.
|
|
76
77
|
sigma2_beta : float
|
|
77
78
|
The scale parameter of the inverse gamma prior on the noise variance.
|
|
78
|
-
|
|
79
|
+
small_float : dtype, default float32
|
|
79
80
|
The dtype for large arrays used in the algorithm.
|
|
80
|
-
|
|
81
|
+
large_float : dtype, default float32
|
|
81
82
|
The dtype for scalars, small arrays, and arrays which require accuracy.
|
|
82
83
|
min_points_per_leaf : int, optional
|
|
83
84
|
The minimum number of data points in a leaf node. 0 if not specified.
|
|
@@ -88,31 +89,30 @@ def make_bart(*,
|
|
|
88
89
|
A dictionary with array values, representing a BART mcmc state. The
|
|
89
90
|
keys are:
|
|
90
91
|
|
|
91
|
-
'leaf_trees' :
|
|
92
|
-
The leaf values
|
|
92
|
+
'leaf_trees' : small_float array (num_trees, 2 ** d)
|
|
93
|
+
The leaf values.
|
|
93
94
|
'var_trees' : int array (num_trees, 2 ** (d - 1))
|
|
94
|
-
The
|
|
95
|
-
it can only contain leaves.
|
|
95
|
+
The decision axes.
|
|
96
96
|
'split_trees' : int array (num_trees, 2 ** (d - 1))
|
|
97
|
-
The
|
|
98
|
-
'resid' :
|
|
97
|
+
The decision boundaries.
|
|
98
|
+
'resid' : large_float array (n,)
|
|
99
99
|
The residuals (data minus forest value). Large float to avoid
|
|
100
100
|
roundoff.
|
|
101
|
-
'sigma2' :
|
|
101
|
+
'sigma2' : large_float
|
|
102
102
|
The noise variance.
|
|
103
103
|
'grow_prop_count', 'prune_prop_count' : int
|
|
104
104
|
The number of grow/prune proposals made during one full MCMC cycle.
|
|
105
105
|
'grow_acc_count', 'prune_acc_count' : int
|
|
106
106
|
The number of grow/prune moves accepted during one full MCMC cycle.
|
|
107
|
-
'p_nonterminal' :
|
|
107
|
+
'p_nonterminal' : large_float array (d - 1,)
|
|
108
108
|
The probability of a nonterminal node at each depth.
|
|
109
|
-
'sigma2_alpha' :
|
|
109
|
+
'sigma2_alpha' : large_float
|
|
110
110
|
The shape parameter of the inverse gamma prior on the noise variance.
|
|
111
|
-
'sigma2_beta' :
|
|
111
|
+
'sigma2_beta' : large_float
|
|
112
112
|
The scale parameter of the inverse gamma prior on the noise variance.
|
|
113
113
|
'max_split' : int array (p,)
|
|
114
114
|
The maximum split index for each variable.
|
|
115
|
-
'y' :
|
|
115
|
+
'y' : small_float array (n,)
|
|
116
116
|
The response.
|
|
117
117
|
'X' : int array (p, n)
|
|
118
118
|
The predictors.
|
|
@@ -123,7 +123,7 @@ def make_bart(*,
|
|
|
123
123
|
datapoints. If `min_points_per_leaf` is not specified, this is None.
|
|
124
124
|
"""
|
|
125
125
|
|
|
126
|
-
p_nonterminal = jnp.asarray(p_nonterminal,
|
|
126
|
+
p_nonterminal = jnp.asarray(p_nonterminal, large_float)
|
|
127
127
|
max_depth = p_nonterminal.size + 1
|
|
128
128
|
|
|
129
129
|
@functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
|
|
@@ -131,20 +131,20 @@ def make_bart(*,
|
|
|
131
131
|
return grove.make_tree(max_depth, dtype)
|
|
132
132
|
|
|
133
133
|
bart = dict(
|
|
134
|
-
leaf_trees=make_forest(max_depth,
|
|
134
|
+
leaf_trees=make_forest(max_depth, small_float),
|
|
135
135
|
var_trees=make_forest(max_depth - 1, grove.minimal_unsigned_dtype(X.shape[0] - 1)),
|
|
136
136
|
split_trees=make_forest(max_depth - 1, max_split.dtype),
|
|
137
|
-
resid=jnp.asarray(y,
|
|
138
|
-
sigma2=jnp.ones((),
|
|
137
|
+
resid=jnp.asarray(y, large_float),
|
|
138
|
+
sigma2=jnp.ones((), large_float),
|
|
139
139
|
grow_prop_count=jnp.zeros((), int),
|
|
140
140
|
grow_acc_count=jnp.zeros((), int),
|
|
141
141
|
prune_prop_count=jnp.zeros((), int),
|
|
142
142
|
prune_acc_count=jnp.zeros((), int),
|
|
143
143
|
p_nonterminal=p_nonterminal,
|
|
144
|
-
sigma2_alpha=jnp.asarray(sigma2_alpha,
|
|
145
|
-
sigma2_beta=jnp.asarray(sigma2_beta,
|
|
144
|
+
sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
|
|
145
|
+
sigma2_beta=jnp.asarray(sigma2_beta, large_float),
|
|
146
146
|
max_split=max_split,
|
|
147
|
-
y=jnp.asarray(y,
|
|
147
|
+
y=jnp.asarray(y, small_float),
|
|
148
148
|
X=X,
|
|
149
149
|
min_points_per_leaf=(
|
|
150
150
|
None if min_points_per_leaf is None else
|
|
@@ -158,14 +158,14 @@ def make_bart(*,
|
|
|
158
158
|
|
|
159
159
|
return bart
|
|
160
160
|
|
|
161
|
-
def
|
|
161
|
+
def step(bart, key):
|
|
162
162
|
"""
|
|
163
163
|
Perform one full MCMC step on a BART state.
|
|
164
164
|
|
|
165
165
|
Parameters
|
|
166
166
|
----------
|
|
167
167
|
bart : dict
|
|
168
|
-
A BART mcmc state, as created by `
|
|
168
|
+
A BART mcmc state, as created by `init`.
|
|
169
169
|
key : jax.dtypes.prng_key array
|
|
170
170
|
A jax random key.
|
|
171
171
|
|
|
@@ -174,19 +174,18 @@ def mcmc_step(bart, key):
|
|
|
174
174
|
bart : dict
|
|
175
175
|
The new BART mcmc state.
|
|
176
176
|
"""
|
|
177
|
-
|
|
178
|
-
bart =
|
|
179
|
-
|
|
180
|
-
return bart
|
|
177
|
+
key, subkey = random.split(key)
|
|
178
|
+
bart = sample_trees(bart, subkey)
|
|
179
|
+
return sample_sigma(bart, key)
|
|
181
180
|
|
|
182
|
-
def
|
|
181
|
+
def sample_trees(bart, key):
|
|
183
182
|
"""
|
|
184
183
|
Forest sampling step of BART MCMC.
|
|
185
184
|
|
|
186
185
|
Parameters
|
|
187
186
|
----------
|
|
188
187
|
bart : dict
|
|
189
|
-
A BART mcmc state, as created by `
|
|
188
|
+
A BART mcmc state, as created by `init`.
|
|
190
189
|
key : jax.dtypes.prng_key array
|
|
191
190
|
A jax random key.
|
|
192
191
|
|
|
@@ -199,148 +198,60 @@ def mcmc_sample_trees(bart, key):
|
|
|
199
198
|
-----
|
|
200
199
|
This function zeroes the proposal counters.
|
|
201
200
|
"""
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
carry = 0, bart, key
|
|
207
|
-
def loop(carry, _):
|
|
208
|
-
i, bart, key = carry
|
|
209
|
-
key, subkey = random.split(key)
|
|
210
|
-
bart = mcmc_sample_tree(bart, subkey, i)
|
|
211
|
-
return (i + 1, bart, key), None
|
|
212
|
-
|
|
213
|
-
(_, bart, _), _ = lax.scan(loop, carry, None, len(bart['leaf_trees']))
|
|
214
|
-
return bart
|
|
215
|
-
|
|
216
|
-
def mcmc_sample_tree(bart, key, i_tree):
|
|
217
|
-
"""
|
|
218
|
-
Single tree sampling step of BART MCMC.
|
|
219
|
-
|
|
220
|
-
Parameters
|
|
221
|
-
----------
|
|
222
|
-
bart : dict
|
|
223
|
-
A BART mcmc state, as created by `make_bart`.
|
|
224
|
-
key : jax.dtypes.prng_key array
|
|
225
|
-
A jax random key.
|
|
226
|
-
i_tree : int
|
|
227
|
-
The index of the tree to sample.
|
|
228
|
-
|
|
229
|
-
Returns
|
|
230
|
-
-------
|
|
231
|
-
bart : dict
|
|
232
|
-
The new BART mcmc state.
|
|
233
|
-
"""
|
|
234
|
-
bart = bart.copy()
|
|
235
|
-
|
|
236
|
-
y_tree = grove.evaluate_tree_vmap_x(
|
|
237
|
-
bart['X'],
|
|
238
|
-
bart['leaf_trees'][i_tree],
|
|
239
|
-
bart['var_trees'][i_tree],
|
|
240
|
-
bart['split_trees'][i_tree],
|
|
241
|
-
bart['resid'].dtype,
|
|
242
|
-
)
|
|
243
|
-
bart['resid'] += y_tree
|
|
244
|
-
|
|
245
|
-
key1, key2 = random.split(key, 2)
|
|
246
|
-
bart = mcmc_sample_tree_structure(bart, key1, i_tree)
|
|
247
|
-
bart = mcmc_sample_tree_leaves(bart, key2, i_tree)
|
|
248
|
-
|
|
249
|
-
y_tree = grove.evaluate_tree_vmap_x(
|
|
250
|
-
bart['X'],
|
|
251
|
-
bart['leaf_trees'][i_tree],
|
|
252
|
-
bart['var_trees'][i_tree],
|
|
253
|
-
bart['split_trees'][i_tree],
|
|
254
|
-
bart['resid'].dtype,
|
|
255
|
-
)
|
|
256
|
-
bart['resid'] -= y_tree
|
|
257
|
-
|
|
258
|
-
return bart
|
|
201
|
+
key, subkey = random.split(key)
|
|
202
|
+
grow_moves, prune_moves = sample_moves(bart, subkey)
|
|
203
|
+
return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key)
|
|
259
204
|
|
|
260
|
-
def
|
|
205
|
+
def sample_moves(bart, key):
|
|
261
206
|
"""
|
|
262
|
-
|
|
207
|
+
Propose moves for all the trees.
|
|
263
208
|
|
|
264
209
|
Parameters
|
|
265
210
|
----------
|
|
266
211
|
bart : dict
|
|
267
|
-
|
|
268
|
-
shall contain only the residuals w.r.t. the other trees.
|
|
212
|
+
BART mcmc state.
|
|
269
213
|
key : jax.dtypes.prng_key array
|
|
270
214
|
A jax random key.
|
|
271
|
-
i_tree : int
|
|
272
|
-
The index of the tree to sample.
|
|
273
215
|
|
|
274
216
|
Returns
|
|
275
217
|
-------
|
|
276
|
-
|
|
277
|
-
The
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
args[-1] = key2
|
|
305
|
-
prune_var_tree, prune_split_tree, prune_affluence_tree, prune_allowed, prune_ratio = prune_move(*args)
|
|
306
|
-
|
|
307
|
-
u0, u1 = random.uniform(key3, (2,))
|
|
308
|
-
|
|
309
|
-
p_grow = jnp.where(grow_allowed & prune_allowed, 0.5, grow_allowed)
|
|
310
|
-
try_grow = u0 < p_grow
|
|
311
|
-
try_prune = prune_allowed & ~try_grow
|
|
312
|
-
|
|
313
|
-
do_grow = try_grow & (u1 < grow_ratio)
|
|
314
|
-
do_prune = try_prune & (u1 < prune_ratio)
|
|
315
|
-
|
|
316
|
-
var_tree = jnp.where(do_grow, grow_var_tree, var_tree)
|
|
317
|
-
split_tree = jnp.where(do_grow, grow_split_tree, split_tree)
|
|
318
|
-
var_tree = jnp.where(do_prune, prune_var_tree, var_tree)
|
|
319
|
-
split_tree = jnp.where(do_prune, prune_split_tree, split_tree)
|
|
320
|
-
|
|
321
|
-
bart['var_trees'] = bart['var_trees'].at[i_tree].set(var_tree)
|
|
322
|
-
bart['split_trees'] = bart['split_trees'].at[i_tree].set(split_tree)
|
|
323
|
-
|
|
324
|
-
if bart['min_points_per_leaf'] is not None:
|
|
325
|
-
affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
|
|
326
|
-
affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
|
|
327
|
-
bart['affluence_trees'] = bart['affluence_trees'].at[i_tree].set(affluence_tree)
|
|
328
|
-
|
|
329
|
-
bart['grow_prop_count'] += try_grow
|
|
330
|
-
bart['grow_acc_count'] += do_grow
|
|
331
|
-
bart['prune_prop_count'] += try_prune
|
|
332
|
-
bart['prune_acc_count'] += do_prune
|
|
333
|
-
|
|
334
|
-
return bart
|
|
335
|
-
|
|
336
|
-
def grow_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, sigma2, resid, n_tree, min_points_per_leaf, key):
|
|
218
|
+
grow_moves, prune_moves : dict
|
|
219
|
+
The proposals for grow and prune moves, with these fields:
|
|
220
|
+
|
|
221
|
+
'allowed' : bool array (num_trees,)
|
|
222
|
+
Whether the move is possible.
|
|
223
|
+
'node' : int array (num_trees,)
|
|
224
|
+
The index of the leaf to grow or node to prune.
|
|
225
|
+
'var_tree' : int array (num_trees, 2 ** (d - 1),)
|
|
226
|
+
The new decision axes of the tree.
|
|
227
|
+
'split_tree' : int array (num_trees, 2 ** (d - 1),)
|
|
228
|
+
The new decision boundaries of the tree.
|
|
229
|
+
'partial_ratio' : float array (num_trees,)
|
|
230
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
231
|
+
the likelihood ratio, and the probability of proposing the prune
|
|
232
|
+
move. For the prune move, the ratio is inverted.
|
|
233
|
+
"""
|
|
234
|
+
key = random.split(key, bart['var_trees'].shape[0])
|
|
235
|
+
return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
|
|
236
|
+
|
|
237
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, 0))
|
|
238
|
+
def sample_moves_vmap_trees(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
239
|
+
key, key1 = random.split(key)
|
|
240
|
+
args = var_tree, split_tree, affluence_tree, max_split, p_nonterminal
|
|
241
|
+
grow = grow_move(*args, key)
|
|
242
|
+
prune = prune_move(*args, key1)
|
|
243
|
+
return grow, prune
|
|
244
|
+
|
|
245
|
+
def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
337
246
|
"""
|
|
338
247
|
Tree structure grow move proposal of BART MCMC.
|
|
339
248
|
|
|
249
|
+
This move picks a leaf node and converts it to a non-terminal node with
|
|
250
|
+
two leaf children. The move is not possible if all the leaves are already at
|
|
251
|
+
maximum depth.
|
|
252
|
+
|
|
340
253
|
Parameters
|
|
341
254
|
----------
|
|
342
|
-
X : array (p, n)
|
|
343
|
-
The predictors.
|
|
344
255
|
var_tree : array (2 ** (d - 1),)
|
|
345
256
|
The variable indices of the tree.
|
|
346
257
|
split_tree : array (2 ** (d - 1),)
|
|
@@ -351,80 +262,47 @@ def grow_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal,
|
|
|
351
262
|
The maximum split index for each variable.
|
|
352
263
|
p_nonterminal : array (d - 1,)
|
|
353
264
|
The probability of a nonterminal node at each depth.
|
|
354
|
-
sigma2 : float
|
|
355
|
-
The noise variance.
|
|
356
|
-
resid : array (n,)
|
|
357
|
-
The residuals (data minus forest value), computed using all trees but
|
|
358
|
-
the tree under consideration.
|
|
359
|
-
n_tree : int
|
|
360
|
-
The number of trees in the forest.
|
|
361
|
-
min_points_per_leaf : int
|
|
362
|
-
The minimum number of data points in a leaf node.
|
|
363
265
|
key : jax.dtypes.prng_key array
|
|
364
266
|
A jax random key.
|
|
365
267
|
|
|
366
268
|
Returns
|
|
367
269
|
-------
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
key1, key2, key3 = random.split(key, 3)
|
|
270
|
+
grow_move : dict
|
|
271
|
+
A dictionary with fields:
|
|
272
|
+
|
|
273
|
+
'allowed' : bool
|
|
274
|
+
Whether the move is possible.
|
|
275
|
+
'node' : int
|
|
276
|
+
The index of the leaf to grow.
|
|
277
|
+
'var_tree' : array (2 ** (d - 1),)
|
|
278
|
+
The new decision axes of the tree.
|
|
279
|
+
'split_tree' : array (2 ** (d - 1),)
|
|
280
|
+
The new decision boundaries of the tree.
|
|
281
|
+
'partial_ratio' : float
|
|
282
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
283
|
+
the likelihood ratio and the probability of proposing the prune
|
|
284
|
+
move.
|
|
285
|
+
"""
|
|
286
|
+
|
|
287
|
+
key, key1, key2 = random.split(key, 3)
|
|
387
288
|
|
|
388
|
-
leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree,
|
|
389
|
-
|
|
390
|
-
var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow,
|
|
289
|
+
leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree, key)
|
|
290
|
+
|
|
291
|
+
var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
|
|
391
292
|
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
392
293
|
|
|
393
|
-
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow,
|
|
294
|
+
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
|
|
394
295
|
new_split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
|
|
395
296
|
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
trans_tree_ratio = compute_trans_tree_ratio(num_growable, num_prunable, split_tree.size, p_nonterminal, leaf_to_grow, split_tree, new_split_tree, affluence_tree, new_affluence_tree)
|
|
399
|
-
|
|
400
|
-
ratio = trans_tree_ratio * likelihood_ratio
|
|
401
|
-
|
|
402
|
-
return var_tree, new_split_tree, new_affluence_tree, allowed, ratio
|
|
403
|
-
|
|
404
|
-
def growable_leaves(split_tree, affluence_tree):
|
|
405
|
-
"""
|
|
406
|
-
Return a mask indicating the leaf nodes that can be proposed for growth.
|
|
407
|
-
|
|
408
|
-
Parameters
|
|
409
|
-
----------
|
|
410
|
-
split_tree : array (2 ** (d - 1),)
|
|
411
|
-
The splitting points of the tree.
|
|
412
|
-
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
413
|
-
Whether a leaf has enough points to be grown.
|
|
297
|
+
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree, new_split_tree)
|
|
414
298
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
Whether the grow move is allowed, i.e., there are growable leaves.
|
|
423
|
-
"""
|
|
424
|
-
is_growable = grove.is_actual_leaf(split_tree)
|
|
425
|
-
if affluence_tree is not None:
|
|
426
|
-
is_growable &= affluence_tree
|
|
427
|
-
return is_growable, jnp.any(is_growable)
|
|
299
|
+
return dict(
|
|
300
|
+
allowed=allowed,
|
|
301
|
+
node=leaf_to_grow,
|
|
302
|
+
var_tree=var_tree,
|
|
303
|
+
split_tree=new_split_tree,
|
|
304
|
+
partial_ratio=ratio,
|
|
305
|
+
)
|
|
428
306
|
|
|
429
307
|
def choose_leaf(split_tree, affluence_tree, key):
|
|
430
308
|
"""
|
|
@@ -443,7 +321,7 @@ def choose_leaf(split_tree, affluence_tree, key):
|
|
|
443
321
|
-------
|
|
444
322
|
leaf_to_grow : int
|
|
445
323
|
The index of the leaf to grow. If ``num_growable == 0``, return
|
|
446
|
-
``
|
|
324
|
+
``2 ** d``.
|
|
447
325
|
num_growable : int
|
|
448
326
|
The number of leaf nodes that can be grown.
|
|
449
327
|
num_prunable : int
|
|
@@ -454,11 +332,37 @@ def choose_leaf(split_tree, affluence_tree, key):
|
|
|
454
332
|
"""
|
|
455
333
|
is_growable, allowed = growable_leaves(split_tree, affluence_tree)
|
|
456
334
|
leaf_to_grow = randint_masked(key, is_growable)
|
|
335
|
+
leaf_to_grow = jnp.where(allowed, leaf_to_grow, 2 * split_tree.size)
|
|
457
336
|
num_growable = jnp.count_nonzero(is_growable)
|
|
458
337
|
is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
|
|
459
338
|
num_prunable = jnp.count_nonzero(is_parent)
|
|
460
339
|
return leaf_to_grow, num_growable, num_prunable, allowed
|
|
461
340
|
|
|
341
|
+
def growable_leaves(split_tree, affluence_tree):
|
|
342
|
+
"""
|
|
343
|
+
Return a mask indicating the leaf nodes that can be proposed for growth.
|
|
344
|
+
|
|
345
|
+
Parameters
|
|
346
|
+
----------
|
|
347
|
+
split_tree : array (2 ** (d - 1),)
|
|
348
|
+
The splitting points of the tree.
|
|
349
|
+
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
350
|
+
Whether a leaf has enough points to be grown.
|
|
351
|
+
|
|
352
|
+
Returns
|
|
353
|
+
-------
|
|
354
|
+
is_growable : bool array (2 ** (d - 1),)
|
|
355
|
+
The mask indicating the leaf nodes that can be proposed to grow, i.e.,
|
|
356
|
+
that are not at the bottom level and have at least two times the number
|
|
357
|
+
of minimum points per leaf.
|
|
358
|
+
allowed : bool
|
|
359
|
+
Whether the grow move is allowed, i.e., there are growable leaves.
|
|
360
|
+
"""
|
|
361
|
+
is_growable = grove.is_actual_leaf(split_tree)
|
|
362
|
+
if affluence_tree is not None:
|
|
363
|
+
is_growable &= affluence_tree
|
|
364
|
+
return is_growable, jnp.any(is_growable)
|
|
365
|
+
|
|
462
366
|
def randint_masked(key, mask):
|
|
463
367
|
"""
|
|
464
368
|
Return a random integer in a range, including only some values.
|
|
@@ -665,7 +569,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
|
|
|
665
569
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
666
570
|
return random.randint(key, (), l, r)
|
|
667
571
|
|
|
668
|
-
def
|
|
572
|
+
def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, initial_split_tree, new_split_tree):
|
|
669
573
|
"""
|
|
670
574
|
Compute the product of the transition and prior ratios of a grow move.
|
|
671
575
|
|
|
@@ -676,8 +580,6 @@ def compute_trans_tree_ratio(num_growable, num_prunable, tree_halfsize, p_nonter
|
|
|
676
580
|
num_prunable : int
|
|
677
581
|
The number of leaf parents that could be pruned, after converting the
|
|
678
582
|
leaf to be grown to a non-terminal node.
|
|
679
|
-
tree_halfsize : int
|
|
680
|
-
Half the length of the tree array, i.e., 2 ** (d - 1).
|
|
681
583
|
p_nonterminal : array (d - 1,)
|
|
682
584
|
The probability of a nonterminal node at each depth.
|
|
683
585
|
leaf_to_grow : int
|
|
@@ -686,16 +588,13 @@ def compute_trans_tree_ratio(num_growable, num_prunable, tree_halfsize, p_nonter
|
|
|
686
588
|
The splitting points of the tree, before the leaf is grown.
|
|
687
589
|
new_split_tree : array (2 ** (d - 1),)
|
|
688
590
|
The splitting points of the tree, after the leaf is grown.
|
|
689
|
-
initial_affluence_tree : bool array (2 ** (d - 1),) or None
|
|
690
|
-
Whether a leaf has enough points to be grown, before the leaf is grown.
|
|
691
|
-
new_affluence_tree : bool array (2 ** (d - 1),) or None
|
|
692
|
-
Whether a leaf has enough points to be grown, after the leaf is grown.
|
|
693
591
|
|
|
694
592
|
Returns
|
|
695
593
|
-------
|
|
696
594
|
ratio : float
|
|
697
595
|
The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
|
|
698
|
-
times the prior ratio P(new tree) / P(old tree)
|
|
596
|
+
times the prior ratio P(new tree) / P(old tree), but the transition
|
|
597
|
+
ratio is missing the factor P(propose prune) in the numerator.
|
|
699
598
|
"""
|
|
700
599
|
|
|
701
600
|
# the two ratios also contain factors num_available_split *
|
|
@@ -704,101 +603,21 @@ def compute_trans_tree_ratio(num_growable, num_prunable, tree_halfsize, p_nonter
|
|
|
704
603
|
prune_was_allowed = prune_allowed(initial_split_tree)
|
|
705
604
|
p_grow = jnp.where(prune_was_allowed, 0.5, 1)
|
|
706
605
|
|
|
707
|
-
|
|
708
|
-
p_prune = jnp.where(grow_again_allowed, 0.5, 1)
|
|
709
|
-
|
|
710
|
-
trans_ratio = p_prune * num_growable / (p_grow * num_prunable)
|
|
606
|
+
trans_ratio = num_growable / (p_grow * num_prunable)
|
|
711
607
|
|
|
712
|
-
depth = grove.
|
|
608
|
+
depth = grove.tree_depths(initial_split_tree.size)[leaf_to_grow]
|
|
713
609
|
p_parent = p_nonterminal[depth]
|
|
714
610
|
cp_children = 1 - p_nonterminal.at[depth + 1].get(mode='fill', fill_value=0)
|
|
715
611
|
tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
|
|
716
612
|
|
|
717
613
|
return trans_ratio * tree_ratio
|
|
718
614
|
|
|
719
|
-
def
|
|
720
|
-
"""
|
|
721
|
-
Compute the likelihood ratio of a grow move.
|
|
722
|
-
|
|
723
|
-
Parameters
|
|
724
|
-
----------
|
|
725
|
-
X : array (p, n)
|
|
726
|
-
The predictors.
|
|
727
|
-
var_tree : array (2 ** (d - 1),)
|
|
728
|
-
The variable indices of the tree, after the grow move.
|
|
729
|
-
split_tree : array (2 ** (d - 1),)
|
|
730
|
-
The splitting points of the tree, after the grow move.
|
|
731
|
-
resid : array (n,)
|
|
732
|
-
The residuals (data minus forest value), for all trees but the one
|
|
733
|
-
under consideration.
|
|
734
|
-
sigma2 : float
|
|
735
|
-
The noise variance.
|
|
736
|
-
new_node : int
|
|
737
|
-
The index of the leaf that has been grown.
|
|
738
|
-
n_tree : int
|
|
739
|
-
The number of trees in the forest.
|
|
740
|
-
min_points_per_leaf : int or None
|
|
741
|
-
The minimum number of data points in a leaf node.
|
|
742
|
-
|
|
743
|
-
Returns
|
|
744
|
-
-------
|
|
745
|
-
ratio : float
|
|
746
|
-
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
747
|
-
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
748
|
-
Whether a leaf has enough points to be grown, after the grow move.
|
|
749
|
-
"""
|
|
750
|
-
|
|
751
|
-
resid_tree, count_tree = agg_values(
|
|
752
|
-
X,
|
|
753
|
-
var_tree,
|
|
754
|
-
split_tree,
|
|
755
|
-
resid,
|
|
756
|
-
sigma2.dtype,
|
|
757
|
-
)
|
|
758
|
-
|
|
759
|
-
left_child = new_node << 1
|
|
760
|
-
right_child = left_child + 1
|
|
761
|
-
|
|
762
|
-
left_resid = resid_tree[left_child]
|
|
763
|
-
right_resid = resid_tree[right_child]
|
|
764
|
-
total_resid = left_resid + right_resid
|
|
765
|
-
|
|
766
|
-
left_count = count_tree[left_child]
|
|
767
|
-
right_count = count_tree[right_child]
|
|
768
|
-
total_count = left_count + right_count
|
|
769
|
-
|
|
770
|
-
sigma_mu2 = 1 / n_tree
|
|
771
|
-
sigma2_left = sigma2 + left_count * sigma_mu2
|
|
772
|
-
sigma2_right = sigma2 + right_count * sigma_mu2
|
|
773
|
-
sigma2_total = sigma2 + total_count * sigma_mu2
|
|
774
|
-
|
|
775
|
-
sqrt_term = sigma2 * sigma2_total / (sigma2_left * sigma2_right)
|
|
776
|
-
|
|
777
|
-
exp_term = sigma_mu2 / (2 * sigma2) * (
|
|
778
|
-
left_resid * left_resid / sigma2_left +
|
|
779
|
-
right_resid * right_resid / sigma2_right -
|
|
780
|
-
total_resid * total_resid / sigma2_total
|
|
781
|
-
)
|
|
782
|
-
|
|
783
|
-
ratio = jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
|
|
784
|
-
|
|
785
|
-
if min_points_per_leaf is not None:
|
|
786
|
-
ratio = jnp.where(right_count >= min_points_per_leaf, ratio, 0)
|
|
787
|
-
ratio = jnp.where(left_count >= min_points_per_leaf, ratio, 0)
|
|
788
|
-
affluence_tree = count_tree[:count_tree.size // 2] >= 2 * min_points_per_leaf
|
|
789
|
-
else:
|
|
790
|
-
affluence_tree = None
|
|
791
|
-
|
|
792
|
-
return ratio, affluence_tree
|
|
793
|
-
|
|
794
|
-
def prune_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, sigma2, resid, n_tree, min_points_per_leaf, key):
|
|
615
|
+
def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
795
616
|
"""
|
|
796
617
|
Tree structure prune move proposal of BART MCMC.
|
|
797
618
|
|
|
798
619
|
Parameters
|
|
799
620
|
----------
|
|
800
|
-
X : array (p, n)
|
|
801
|
-
The predictors.
|
|
802
621
|
var_tree : array (2 ** (d - 1),)
|
|
803
622
|
The variable indices of the tree.
|
|
804
623
|
split_tree : array (2 ** (d - 1),)
|
|
@@ -809,48 +628,41 @@ def prune_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal
|
|
|
809
628
|
The maximum split index for each variable.
|
|
810
629
|
p_nonterminal : array (d - 1,)
|
|
811
630
|
The probability of a nonterminal node at each depth.
|
|
812
|
-
sigma2 : float
|
|
813
|
-
The noise variance.
|
|
814
|
-
resid : array (n,)
|
|
815
|
-
The residuals (data minus forest value), computed using all trees but
|
|
816
|
-
the tree under consideration.
|
|
817
|
-
n_tree : int
|
|
818
|
-
The number of trees in the forest.
|
|
819
|
-
min_points_per_leaf : int
|
|
820
|
-
The minimum number of data points in a leaf node.
|
|
821
631
|
key : jax.dtypes.prng_key array
|
|
822
632
|
A jax random key.
|
|
823
633
|
|
|
824
634
|
Returns
|
|
825
635
|
-------
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
636
|
+
prune_move : dict
|
|
637
|
+
A dictionary with fields:
|
|
638
|
+
|
|
639
|
+
'allowed' : bool
|
|
640
|
+
Whether the move is possible.
|
|
641
|
+
'node' : int
|
|
642
|
+
The index of the node to prune.
|
|
643
|
+
'var_tree' : array (2 ** (d - 1),)
|
|
644
|
+
The new decision axes of the tree.
|
|
645
|
+
'split_tree' : array (2 ** (d - 1),)
|
|
646
|
+
The new decision boundaries of the tree.
|
|
647
|
+
'partial_ratio' : float
|
|
648
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
649
|
+
the likelihood ratio and the probability of proposing the prune
|
|
650
|
+
move. This ratio is inverted.
|
|
836
651
|
"""
|
|
837
652
|
node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
|
|
838
653
|
allowed = prune_allowed(split_tree)
|
|
839
654
|
|
|
840
655
|
new_split_tree = split_tree.at[node_to_prune].set(0)
|
|
841
|
-
# should I clean up var_tree as well? just for debugging. it hasn't given me problems though
|
|
842
656
|
|
|
843
|
-
|
|
844
|
-
new_affluence_tree = (
|
|
845
|
-
None if affluence_tree is None else
|
|
846
|
-
affluence_tree.at[node_to_prune].set(True)
|
|
847
|
-
)
|
|
848
|
-
trans_tree_ratio = compute_trans_tree_ratio(num_growable, num_prunable, split_tree.size, p_nonterminal, node_to_prune, new_split_tree, split_tree, new_affluence_tree, affluence_tree)
|
|
849
|
-
|
|
850
|
-
ratio = trans_tree_ratio * likelihood_ratio
|
|
851
|
-
ratio = 1 / ratio # Question: should I use lax.reciprocal for this?
|
|
657
|
+
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, new_split_tree, split_tree)
|
|
852
658
|
|
|
853
|
-
return
|
|
659
|
+
return dict(
|
|
660
|
+
allowed=allowed,
|
|
661
|
+
node=node_to_prune,
|
|
662
|
+
var_tree=var_tree,
|
|
663
|
+
split_tree=new_split_tree,
|
|
664
|
+
partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
|
|
665
|
+
)
|
|
854
666
|
|
|
855
667
|
def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
856
668
|
"""
|
|
@@ -906,116 +718,256 @@ def prune_allowed(split_tree):
|
|
|
906
718
|
"""
|
|
907
719
|
return split_tree.at[1].get(mode='fill', fill_value=0).astype(bool)
|
|
908
720
|
|
|
909
|
-
def
|
|
910
|
-
"""
|
|
911
|
-
Single tree leaves sampling step of BART MCMC.
|
|
912
|
-
|
|
913
|
-
Parameters
|
|
914
|
-
----------
|
|
915
|
-
bart : dict
|
|
916
|
-
A BART mcmc state, as created by `make_bart`. The ``'resid'`` field
|
|
917
|
-
shall contain the residuals only w.r.t. the other trees.
|
|
918
|
-
key : jax.dtypes.prng_key array
|
|
919
|
-
A jax random key.
|
|
920
|
-
i_tree : int
|
|
921
|
-
The index of the tree to sample.
|
|
922
|
-
|
|
923
|
-
Returns
|
|
924
|
-
-------
|
|
925
|
-
bart : dict
|
|
926
|
-
The new BART mcmc state.
|
|
927
|
-
"""
|
|
721
|
+
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
928
722
|
bart = bart.copy()
|
|
723
|
+
def loop(carry, item):
|
|
724
|
+
resid = carry.pop('resid')
|
|
725
|
+
resid, carry, trees = accept_move_and_sample_leaves(
|
|
726
|
+
bart['X'],
|
|
727
|
+
len(bart['leaf_trees']),
|
|
728
|
+
resid,
|
|
729
|
+
bart['sigma2'],
|
|
730
|
+
bart['min_points_per_leaf'],
|
|
731
|
+
carry,
|
|
732
|
+
*item,
|
|
733
|
+
)
|
|
734
|
+
carry['resid'] = resid
|
|
735
|
+
return carry, trees
|
|
736
|
+
carry = {
|
|
737
|
+
k: jnp.zeros_like(bart[k]) for k in
|
|
738
|
+
['grow_prop_count', 'prune_prop_count', 'grow_acc_count', 'prune_acc_count']
|
|
739
|
+
}
|
|
740
|
+
carry['resid'] = bart['resid']
|
|
741
|
+
items = (
|
|
742
|
+
bart['leaf_trees'],
|
|
743
|
+
bart['var_trees'],
|
|
744
|
+
bart['split_trees'],
|
|
745
|
+
bart['affluence_trees'],
|
|
746
|
+
grow_moves,
|
|
747
|
+
prune_moves,
|
|
748
|
+
random.split(key, len(bart['leaf_trees'])),
|
|
749
|
+
)
|
|
750
|
+
carry, trees = lax.scan(loop, carry, items)
|
|
751
|
+
bart.update(carry)
|
|
752
|
+
bart.update(trees)
|
|
753
|
+
return bart
|
|
754
|
+
|
|
755
|
+
def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf, counts, leaf_tree, var_tree, split_tree, affluence_tree, grow_move, prune_move, key):
    """
    Accept or reject the move proposed for a single tree and sample its leaves.

    Parameters
    ----------
    X : array
        The predictors. `grove.traverse_tree` is mapped over axis 1
        (presumably the axis of the data points).
    ntree : int
        The number of trees in the forest. It is used as the prior precision
        of the leaves.
    resid : array
        The residuals. The contribution of the starting tree is added back at
        the beginning, and the contribution of the updated tree is subtracted
        at the end.
    sigma2 : float
        The noise variance.
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node, None for no
        constraint.
    counts : dict
        The counters ``'grow_prop_count'``, ``'grow_acc_count'``,
        ``'prune_prop_count'``, ``'prune_acc_count'``. The input dictionary is
        not modified.
    leaf_tree : array
        The leaf values of the starting tree.
    var_tree : array
        The decision axes of the starting tree.
    split_tree : array
        The decision boundaries of the starting tree.
    affluence_tree : bool array or None
        Which leaves have enough points to be grown. It is updated only if
        `min_points_per_leaf` is not None.
    grow_move : dict
        The proposed grow move. The fields read here are ``'node'``,
        ``'var_tree'``, ``'split_tree'``, ``'partial_ratio'``, ``'allowed'``.
    prune_move : dict
        The proposed prune move, with the same fields as `grow_move`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    resid : array
        The residuals with the updated tree subtracted.
    counts : dict
        A copy of the input counters, incremented by the proposal and
        acceptance flags of this tree.
    trees : dict
        The updated tree, with keys ``'leaf_trees'``, ``'var_trees'``,
        ``'split_trees'``, ``'affluence_trees'``.
    """

    # compute leaf indices according to grow move tree
    traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
    grow_leaf_indices = traverse_tree(X, grow_move['var_tree'], grow_move['split_tree'])

    # compute leaf indices in starting tree
    grow_node = grow_move['node']
    grow_left = grow_node << 1
    grow_right = grow_left + 1
    leaf_indices = jnp.where(
        (grow_leaf_indices == grow_left) | (grow_leaf_indices == grow_right),
        grow_node,
        grow_leaf_indices,
    )

    # compute leaf indices in prune tree
    prune_node = prune_move['node']
    prune_left = prune_node << 1
    prune_right = prune_left + 1
    prune_leaf_indices = jnp.where(
        (leaf_indices == prune_left) | (leaf_indices == prune_right),
        prune_node,
        leaf_indices,
    )

    # subtract starting tree from function
    # (explicit rebinding instead of `+=`, to never write into the caller's
    # buffer if a mutable array is passed in)
    resid = resid + leaf_tree[leaf_indices]

    # aggregate residuals and count units per leaf
    grow_resid_tree = jnp.zeros_like(leaf_tree, sigma2.dtype)
    grow_resid_tree = grow_resid_tree.at[grow_leaf_indices].add(resid)
    grow_count_tree = jnp.zeros_like(leaf_tree, grove.minimal_unsigned_dtype(resid.size))
    grow_count_tree = grow_count_tree.at[grow_leaf_indices].add(1)

    # compute aggregations in starting tree
    # I do not zero the children because garbage there does not matter
    resid_tree = (grow_resid_tree.at[grow_node]
        .set(grow_resid_tree[grow_left] + grow_resid_tree[grow_right]))
    count_tree = (grow_count_tree.at[grow_node]
        .set(grow_count_tree[grow_left] + grow_count_tree[grow_right]))

    # compute aggregations in prune tree
    prune_resid_tree = (resid_tree.at[prune_node]
        .set(resid_tree[prune_left] + resid_tree[prune_right]))
    prune_count_tree = (count_tree.at[prune_node]
        .set(count_tree[prune_left] + count_tree[prune_right]))

    # compute affluence trees
    if min_points_per_leaf is None:
        # Without a constraint there is no affluence information. The names
        # must be bound anyway because grow_affluence_tree is read below
        # unconditionally. None mirrors what is already passed for the
        # starting tree in this case (assumes compute_p_prune_back accepts
        # None -- TODO confirm).
        grow_affluence_tree = None
        prune_affluence_tree = None
    else:
        grow_affluence_tree = grow_count_tree[:grow_count_tree.size // 2] >= 2 * min_points_per_leaf
        prune_affluence_tree = affluence_tree.at[prune_node].set(True)

    # compute probability of proposing prune
    grow_p_prune = compute_p_prune_back(grow_move['split_tree'], grow_affluence_tree)
    prune_p_prune = compute_p_prune_back(split_tree, affluence_tree)

    # compute likelihood ratios
    grow_lk_ratio = compute_likelihood_ratio(grow_resid_tree, grow_count_tree, sigma2, grow_node, ntree, min_points_per_leaf)
    prune_lk_ratio = compute_likelihood_ratio(resid_tree, count_tree, sigma2, prune_node, ntree, min_points_per_leaf)

    # compute acceptance ratios
    # (the prune ratio is the reciprocal of the ratio of the reverse grow move)
    grow_ratio = grow_p_prune * grow_move['partial_ratio'] * grow_lk_ratio
    prune_ratio = prune_p_prune * prune_move['partial_ratio'] * prune_lk_ratio
    prune_ratio = lax.reciprocal(prune_ratio)

    # random coins in [0, 1) for proposal and acceptance
    key, subkey = random.split(key)
    u0, u1 = random.uniform(subkey, (2,))

    # determine what move to propose (not proposing anything is an option)
    p_grow = jnp.where(grow_move['allowed'] & prune_move['allowed'], 0.5, grow_move['allowed'])
    try_grow = u0 < p_grow
    try_prune = prune_move['allowed'] & ~try_grow

    # determine whether to accept the move
    do_grow = try_grow & (u1 < grow_ratio)
    do_prune = try_prune & (u1 < prune_ratio)

    # pick trees for chosen move
    var_tree = jnp.where(do_grow, grow_move['var_tree'], var_tree)
    split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
    var_tree = jnp.where(do_prune, prune_move['var_tree'], var_tree)
    split_tree = jnp.where(do_prune, prune_move['split_tree'], split_tree)
    if min_points_per_leaf is not None:
        affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
        affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
    resid_tree = jnp.where(do_grow, grow_resid_tree, resid_tree)
    count_tree = jnp.where(do_grow, grow_count_tree, count_tree)
    resid_tree = jnp.where(do_prune, prune_resid_tree, resid_tree)
    count_tree = jnp.where(do_prune, prune_count_tree, count_tree)

    # update acceptance counts
    counts = counts.copy()
    counts['grow_prop_count'] += try_grow
    counts['grow_acc_count'] += do_grow
    counts['prune_prop_count'] += try_prune
    counts['prune_acc_count'] += do_prune

    # compute leaves posterior
    prec_lk = count_tree / sigma2
    var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
    mean_post = resid_tree / sigma2 * var_post # = mean_lk * prec_lk * var_post

    # sample leaves
    z = random.normal(key, mean_post.shape, mean_post.dtype)
    leaf_tree = mean_post + z * jnp.sqrt(var_post)

    # add new tree to function
    leaf_indices = jnp.where(do_grow, grow_leaf_indices, leaf_indices)
    leaf_indices = jnp.where(do_prune, prune_leaf_indices, leaf_indices)
    resid = resid - leaf_tree[leaf_indices]

    # pack trees
    trees = {
        'leaf_trees': leaf_tree,
        'var_trees': var_tree,
        'split_trees': split_tree,
        'affluence_trees': affluence_tree,
    }

    return resid, counts, trees
|
|
878
|
+
|
|
879
|
+
def compute_p_prune_back(new_split_tree, new_affluence_tree):
    """
    Probability of proposing a prune move on the tree obtained after a grow.

    Parameters
    ----------
    new_split_tree : int array (2 ** (d - 1),)
        The decision boundaries of the tree, after the grow move.
    new_affluence_tree : bool array (2 ** (d - 1),)
        Which leaves have enough points to be grown, after the grow move.

    Returns
    -------
    p_prune : float
        0.5 if the grown tree can be grown again, 1 otherwise. It is never 0
        because at least the node just grown can be pruned.
    """
    # only the second output of growable_leaves, the flag telling whether a
    # grow move is possible, is needed here
    can_grow_again = growable_leaves(new_split_tree, new_affluence_tree)[1]
    p_prune = jnp.where(can_grow_again, 0.5, 1)
    return p_prune
|
|
899
|
+
|
|
900
|
+
def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_points_per_leaf):
    """
    Compute the likelihood ratio of a grow move.

    Parameters
    ----------
    resid_tree : float array (2 ** d,)
        The sum of the residuals at data points in each leaf.
    count_tree : int array (2 ** d,)
        The number of data points in each leaf.
    sigma2 : float
        The noise variance.
    node : int
        The index of the leaf that has been grown.
    n_tree : int
        The number of trees in the forest.
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node.

    Returns
    -------
    ratio : float
        The likelihood ratio P(data | new tree) / P(data | old tree).

    Notes
    -----
    The ratio is forced to 0 when either child of the grown node would hold
    fewer than `min_points_per_leaf` data points, although that constraint
    belongs to the prior rather than to the likelihood.
    """

    # the children of `node` sit at indices 2 * node and 2 * node + 1
    child_l = node << 1
    child_r = child_l + 1

    # residual sums: the two children, and their union (the node before growing)
    r_l = resid_tree[child_l]
    r_r = resid_tree[child_r]
    r_tot = r_l + r_r

    # number of points, same layout
    n_l = count_tree[child_l]
    n_r = count_tree[child_r]
    n_tot = n_l + n_r

    # the prior variance of a leaf value is 1 / n_tree
    prior_var = 1 / n_tree
    var_l = sigma2 + n_l * prior_var
    var_r = sigma2 + n_r * prior_var
    var_tot = sigma2 + n_tot * prior_var

    # factor that goes under the square root
    under_sqrt = sigma2 * var_tot / (var_l * var_r)

    # argument of the exponential
    in_exp = prior_var / (2 * sigma2) * (
        r_l * r_l / var_l +
        r_r * r_r / var_r -
        r_tot * r_tot / var_tot
    )

    ratio = jnp.sqrt(under_sqrt) * jnp.exp(in_exp)

    if min_points_per_leaf is None:
        return ratio

    # veto the move if either child is too small
    both_large_enough = (n_l >= min_points_per_leaf) & (n_r >= min_points_per_leaf)
    return jnp.where(both_large_enough, ratio, 0)
|
|
962
|
+
|
|
963
|
+
def sample_sigma(bart, key):
|
|
1012
964
|
"""
|
|
1013
965
|
Noise variance sampling step of BART MCMC.
|
|
1014
966
|
|
|
1015
967
|
Parameters
|
|
1016
968
|
----------
|
|
1017
969
|
bart : dict
|
|
1018
|
-
A BART mcmc state, as created by `
|
|
970
|
+
A BART mcmc state, as created by `init`.
|
|
1019
971
|
key : jax.dtypes.prng_key array
|
|
1020
972
|
A jax random key.
|
|
1021
973
|
|
|
@@ -1028,8 +980,8 @@ def mcmc_sample_sigma(bart, key):
|
|
|
1028
980
|
|
|
1029
981
|
resid = bart['resid']
|
|
1030
982
|
alpha = bart['sigma2_alpha'] + resid.size / 2
|
|
1031
|
-
|
|
1032
|
-
beta = bart['sigma2_beta'] +
|
|
983
|
+
norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
|
|
984
|
+
beta = bart['sigma2_beta'] + norm2 / 2
|
|
1033
985
|
|
|
1034
986
|
sample = random.gamma(key, alpha)
|
|
1035
987
|
bart['sigma2'] = beta / sample
|