bartz 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/{interface.py → BART.py} +10 -18
- bartz/__init__.py +7 -2
- bartz/_version.py +1 -0
- bartz/debug.py +9 -22
- bartz/grove.py +73 -120
- bartz/jaxext.py +261 -5
- bartz/mcmcloop.py +27 -13
- bartz/mcmcstep.py +510 -439
- bartz/prepcovars.py +25 -30
- {bartz-0.0.1.dist-info → bartz-0.2.0.dist-info}/METADATA +7 -1
- bartz-0.2.0.dist-info/RECORD +13 -0
- bartz-0.0.1.dist-info/RECORD +0 -12
- {bartz-0.0.1.dist-info → bartz-0.2.0.dist-info}/LICENSE +0 -0
- {bartz-0.0.1.dist-info → bartz-0.2.0.dist-info}/WHEEL +0 -0
bartz/mcmcstep.py
CHANGED
|
@@ -34,16 +34,16 @@ range of possible values.
|
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
36
|
import functools
|
|
37
|
-
import math
|
|
38
37
|
|
|
39
38
|
import jax
|
|
40
39
|
from jax import random
|
|
41
40
|
from jax import numpy as jnp
|
|
42
41
|
from jax import lax
|
|
43
42
|
|
|
43
|
+
from . import jaxext
|
|
44
44
|
from . import grove
|
|
45
45
|
|
|
46
|
-
def
|
|
46
|
+
def init(*,
|
|
47
47
|
X,
|
|
48
48
|
y,
|
|
49
49
|
max_split,
|
|
@@ -51,9 +51,10 @@ def make_bart(*,
|
|
|
51
51
|
p_nonterminal,
|
|
52
52
|
sigma2_alpha,
|
|
53
53
|
sigma2_beta,
|
|
54
|
-
|
|
55
|
-
|
|
54
|
+
small_float=jnp.float32,
|
|
55
|
+
large_float=jnp.float32,
|
|
56
56
|
min_points_per_leaf=None,
|
|
57
|
+
suffstat_batch_size='auto',
|
|
57
58
|
):
|
|
58
59
|
"""
|
|
59
60
|
Make a BART posterior sampling MCMC initial state.
|
|
@@ -75,12 +76,15 @@ def make_bart(*,
|
|
|
75
76
|
The shape parameter of the inverse gamma prior on the noise variance.
|
|
76
77
|
sigma2_beta : float
|
|
77
78
|
The scale parameter of the inverse gamma prior on the noise variance.
|
|
78
|
-
|
|
79
|
+
small_float : dtype, default float32
|
|
79
80
|
The dtype for large arrays used in the algorithm.
|
|
80
|
-
|
|
81
|
+
large_float : dtype, default float32
|
|
81
82
|
The dtype for scalars, small arrays, and arrays which require accuracy.
|
|
82
83
|
min_points_per_leaf : int, optional
|
|
83
84
|
The minimum number of data points in a leaf node. 0 if not specified.
|
|
85
|
+
suffstat_batch_size : int, None, str, default 'auto'
|
|
86
|
+
The batch size for computing sufficient statistics. `None` for no
|
|
87
|
+
batching. If 'auto', pick a value based on the device of `y`.
|
|
84
88
|
|
|
85
89
|
Returns
|
|
86
90
|
-------
|
|
@@ -88,31 +92,31 @@ def make_bart(*,
|
|
|
88
92
|
A dictionary with array values, representing a BART mcmc state. The
|
|
89
93
|
keys are:
|
|
90
94
|
|
|
91
|
-
'leaf_trees' :
|
|
92
|
-
The leaf values
|
|
95
|
+
'leaf_trees' : small_float array (num_trees, 2 ** d)
|
|
96
|
+
The leaf values.
|
|
93
97
|
'var_trees' : int array (num_trees, 2 ** (d - 1))
|
|
94
|
-
The
|
|
95
|
-
it can only contain leaves.
|
|
98
|
+
The decision axes.
|
|
96
99
|
'split_trees' : int array (num_trees, 2 ** (d - 1))
|
|
97
|
-
The
|
|
98
|
-
'resid' :
|
|
100
|
+
The decision boundaries.
|
|
101
|
+
'resid' : large_float array (n,)
|
|
99
102
|
The residuals (data minus forest value). Large float to avoid
|
|
100
103
|
roundoff.
|
|
101
|
-
'sigma2' :
|
|
104
|
+
'sigma2' : large_float
|
|
102
105
|
The noise variance.
|
|
103
106
|
'grow_prop_count', 'prune_prop_count' : int
|
|
104
107
|
The number of grow/prune proposals made during one full MCMC cycle.
|
|
105
108
|
'grow_acc_count', 'prune_acc_count' : int
|
|
106
109
|
The number of grow/prune moves accepted during one full MCMC cycle.
|
|
107
|
-
'p_nonterminal' :
|
|
108
|
-
The probability of a nonterminal node at each depth
|
|
109
|
-
|
|
110
|
+
'p_nonterminal' : large_float array (d,)
|
|
111
|
+
The probability of a nonterminal node at each depth, padded with a
|
|
112
|
+
zero.
|
|
113
|
+
'sigma2_alpha' : large_float
|
|
110
114
|
The shape parameter of the inverse gamma prior on the noise variance.
|
|
111
|
-
'sigma2_beta' :
|
|
115
|
+
'sigma2_beta' : large_float
|
|
112
116
|
The scale parameter of the inverse gamma prior on the noise variance.
|
|
113
117
|
'max_split' : int array (p,)
|
|
114
118
|
The maximum split index for each variable.
|
|
115
|
-
'y' :
|
|
119
|
+
'y' : small_float array (n,)
|
|
116
120
|
The response.
|
|
117
121
|
'X' : int array (p, n)
|
|
118
122
|
The predictors.
|
|
@@ -121,31 +125,49 @@ def make_bart(*,
|
|
|
121
125
|
'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
|
|
122
126
|
Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
|
|
123
127
|
datapoints. If `min_points_per_leaf` is not specified, this is None.
|
|
128
|
+
'opt' : LeafDict
|
|
129
|
+
A dictionary with config values:
|
|
130
|
+
|
|
131
|
+
'suffstat_batch_size' : int or None
|
|
132
|
+
The batch size for computing sufficient statistics.
|
|
133
|
+
'small_float' : dtype
|
|
134
|
+
The dtype for large arrays used in the algorithm.
|
|
135
|
+
'large_float' : dtype
|
|
136
|
+
The dtype for scalars, small arrays, and arrays which require
|
|
137
|
+
accuracy.
|
|
138
|
+
'require_min_points' : bool
|
|
139
|
+
Whether the `min_points_per_leaf` parameter is specified.
|
|
124
140
|
"""
|
|
125
141
|
|
|
126
|
-
p_nonterminal = jnp.asarray(p_nonterminal,
|
|
127
|
-
|
|
142
|
+
p_nonterminal = jnp.asarray(p_nonterminal, large_float)
|
|
143
|
+
p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
|
|
144
|
+
max_depth = p_nonterminal.size
|
|
128
145
|
|
|
129
146
|
@functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
|
|
130
147
|
def make_forest(max_depth, dtype):
|
|
131
148
|
return grove.make_tree(max_depth, dtype)
|
|
132
149
|
|
|
150
|
+
small_float = jnp.dtype(small_float)
|
|
151
|
+
large_float = jnp.dtype(large_float)
|
|
152
|
+
y = jnp.asarray(y, small_float)
|
|
153
|
+
suffstat_batch_size = _choose_suffstat_batch_size(suffstat_batch_size, y)
|
|
154
|
+
|
|
133
155
|
bart = dict(
|
|
134
|
-
leaf_trees=make_forest(max_depth,
|
|
135
|
-
var_trees=make_forest(max_depth - 1,
|
|
156
|
+
leaf_trees=make_forest(max_depth, small_float),
|
|
157
|
+
var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
|
|
136
158
|
split_trees=make_forest(max_depth - 1, max_split.dtype),
|
|
137
|
-
resid=jnp.asarray(y,
|
|
138
|
-
sigma2=jnp.ones((),
|
|
159
|
+
resid=jnp.asarray(y, large_float),
|
|
160
|
+
sigma2=jnp.ones((), large_float),
|
|
139
161
|
grow_prop_count=jnp.zeros((), int),
|
|
140
162
|
grow_acc_count=jnp.zeros((), int),
|
|
141
163
|
prune_prop_count=jnp.zeros((), int),
|
|
142
164
|
prune_acc_count=jnp.zeros((), int),
|
|
143
165
|
p_nonterminal=p_nonterminal,
|
|
144
|
-
sigma2_alpha=jnp.asarray(sigma2_alpha,
|
|
145
|
-
sigma2_beta=jnp.asarray(sigma2_beta,
|
|
146
|
-
max_split=max_split,
|
|
147
|
-
y=
|
|
148
|
-
X=X,
|
|
166
|
+
sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
|
|
167
|
+
sigma2_beta=jnp.asarray(sigma2_beta, large_float),
|
|
168
|
+
max_split=jnp.asarray(max_split),
|
|
169
|
+
y=y,
|
|
170
|
+
X=jnp.asarray(X),
|
|
149
171
|
min_points_per_leaf=(
|
|
150
172
|
None if min_points_per_leaf is None else
|
|
151
173
|
jnp.asarray(min_points_per_leaf)
|
|
@@ -154,18 +176,40 @@ def make_bart(*,
|
|
|
154
176
|
None if min_points_per_leaf is None else
|
|
155
177
|
make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
|
|
156
178
|
),
|
|
179
|
+
opt=jaxext.LeafDict(
|
|
180
|
+
suffstat_batch_size=suffstat_batch_size,
|
|
181
|
+
small_float=small_float,
|
|
182
|
+
large_float=large_float,
|
|
183
|
+
require_min_points=min_points_per_leaf is not None,
|
|
184
|
+
),
|
|
157
185
|
)
|
|
158
186
|
|
|
159
187
|
return bart
|
|
160
188
|
|
|
161
|
-
def
|
|
189
|
+
def _choose_suffstat_batch_size(size, y):
|
|
190
|
+
if size == 'auto':
|
|
191
|
+
platform = y.devices().pop().platform
|
|
192
|
+
if platform == 'cpu':
|
|
193
|
+
return None
|
|
194
|
+
# maybe I should batch residuals (not counts) for numerical
|
|
195
|
+
# accuracy, even if it's slower
|
|
196
|
+
elif platform == 'gpu':
|
|
197
|
+
return 128 # 128 is good on A100, and V100 at high n
|
|
198
|
+
# 512 is good on T4, and V100 at low n
|
|
199
|
+
else:
|
|
200
|
+
raise KeyError(f'Unknown platform: {platform}')
|
|
201
|
+
elif size is not None:
|
|
202
|
+
return int(size)
|
|
203
|
+
return size
|
|
204
|
+
|
|
205
|
+
def step(bart, key):
|
|
162
206
|
"""
|
|
163
207
|
Perform one full MCMC step on a BART state.
|
|
164
208
|
|
|
165
209
|
Parameters
|
|
166
210
|
----------
|
|
167
211
|
bart : dict
|
|
168
|
-
A BART mcmc state, as created by `
|
|
212
|
+
A BART mcmc state, as created by `init`.
|
|
169
213
|
key : jax.dtypes.prng_key array
|
|
170
214
|
A jax random key.
|
|
171
215
|
|
|
@@ -174,19 +218,18 @@ def mcmc_step(bart, key):
|
|
|
174
218
|
bart : dict
|
|
175
219
|
The new BART mcmc state.
|
|
176
220
|
"""
|
|
177
|
-
|
|
178
|
-
bart =
|
|
179
|
-
|
|
180
|
-
return bart
|
|
221
|
+
key, subkey = random.split(key)
|
|
222
|
+
bart = sample_trees(bart, subkey)
|
|
223
|
+
return sample_sigma(bart, key)
|
|
181
224
|
|
|
182
|
-
def
|
|
225
|
+
def sample_trees(bart, key):
|
|
183
226
|
"""
|
|
184
227
|
Forest sampling step of BART MCMC.
|
|
185
228
|
|
|
186
229
|
Parameters
|
|
187
230
|
----------
|
|
188
231
|
bart : dict
|
|
189
|
-
A BART mcmc state, as created by `
|
|
232
|
+
A BART mcmc state, as created by `init`.
|
|
190
233
|
key : jax.dtypes.prng_key array
|
|
191
234
|
A jax random key.
|
|
192
235
|
|
|
@@ -197,150 +240,52 @@ def mcmc_sample_trees(bart, key):
|
|
|
197
240
|
|
|
198
241
|
Notes
|
|
199
242
|
-----
|
|
200
|
-
This function zeroes the proposal counters.
|
|
243
|
+
This function zeroes the proposal counters before using them.
|
|
201
244
|
"""
|
|
202
245
|
bart = bart.copy()
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
i, bart, key = carry
|
|
209
|
-
key, subkey = random.split(key)
|
|
210
|
-
bart = mcmc_sample_tree(bart, subkey, i)
|
|
211
|
-
return (i + 1, bart, key), None
|
|
212
|
-
|
|
213
|
-
(_, bart, _), _ = lax.scan(loop, carry, None, len(bart['leaf_trees']))
|
|
214
|
-
return bart
|
|
246
|
+
key, subkey = random.split(key)
|
|
247
|
+
grow_moves, prune_moves = sample_moves(bart, subkey)
|
|
248
|
+
bart['var_trees'] = grow_moves['var_tree']
|
|
249
|
+
grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
|
|
250
|
+
return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
|
|
215
251
|
|
|
216
|
-
def
|
|
252
|
+
def sample_moves(bart, key):
|
|
217
253
|
"""
|
|
218
|
-
|
|
254
|
+
Propose moves for all the trees.
|
|
219
255
|
|
|
220
256
|
Parameters
|
|
221
257
|
----------
|
|
222
258
|
bart : dict
|
|
223
|
-
|
|
259
|
+
BART mcmc state.
|
|
224
260
|
key : jax.dtypes.prng_key array
|
|
225
261
|
A jax random key.
|
|
226
|
-
i_tree : int
|
|
227
|
-
The index of the tree to sample.
|
|
228
262
|
|
|
229
263
|
Returns
|
|
230
264
|
-------
|
|
231
|
-
|
|
232
|
-
The
|
|
233
|
-
"""
|
|
234
|
-
bart = bart.copy()
|
|
235
|
-
|
|
236
|
-
y_tree = grove.evaluate_tree_vmap_x(
|
|
237
|
-
bart['X'],
|
|
238
|
-
bart['leaf_trees'][i_tree],
|
|
239
|
-
bart['var_trees'][i_tree],
|
|
240
|
-
bart['split_trees'][i_tree],
|
|
241
|
-
bart['resid'].dtype,
|
|
242
|
-
)
|
|
243
|
-
bart['resid'] += y_tree
|
|
244
|
-
|
|
245
|
-
key1, key2 = random.split(key, 2)
|
|
246
|
-
bart = mcmc_sample_tree_structure(bart, key1, i_tree)
|
|
247
|
-
bart = mcmc_sample_tree_leaves(bart, key2, i_tree)
|
|
248
|
-
|
|
249
|
-
y_tree = grove.evaluate_tree_vmap_x(
|
|
250
|
-
bart['X'],
|
|
251
|
-
bart['leaf_trees'][i_tree],
|
|
252
|
-
bart['var_trees'][i_tree],
|
|
253
|
-
bart['split_trees'][i_tree],
|
|
254
|
-
bart['resid'].dtype,
|
|
255
|
-
)
|
|
256
|
-
bart['resid'] -= y_tree
|
|
257
|
-
|
|
258
|
-
return bart
|
|
259
|
-
|
|
260
|
-
def mcmc_sample_tree_structure(bart, key, i_tree):
|
|
265
|
+
grow_moves, prune_moves : dict
|
|
266
|
+
The proposals for grow and prune moves. See `grow_move` and `prune_move`.
|
|
261
267
|
"""
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
Returns
|
|
275
|
-
-------
|
|
276
|
-
bart : dict
|
|
277
|
-
The new BART mcmc state.
|
|
278
|
-
"""
|
|
279
|
-
bart = bart.copy()
|
|
280
|
-
|
|
281
|
-
var_tree = bart['var_trees'][i_tree]
|
|
282
|
-
split_tree = bart['split_trees'][i_tree]
|
|
283
|
-
affluence_tree = (
|
|
284
|
-
None if bart['affluence_trees'] is None else
|
|
285
|
-
bart['affluence_trees'][i_tree]
|
|
286
|
-
)
|
|
287
|
-
|
|
288
|
-
key1, key2, key3 = random.split(key, 3)
|
|
289
|
-
args = [
|
|
290
|
-
bart['X'],
|
|
291
|
-
var_tree,
|
|
292
|
-
split_tree,
|
|
293
|
-
affluence_tree,
|
|
294
|
-
bart['max_split'],
|
|
295
|
-
bart['p_nonterminal'],
|
|
296
|
-
bart['sigma2'],
|
|
297
|
-
bart['resid'],
|
|
298
|
-
len(bart['var_trees']),
|
|
299
|
-
bart['min_points_per_leaf'],
|
|
300
|
-
key1,
|
|
301
|
-
]
|
|
302
|
-
grow_var_tree, grow_split_tree, grow_affluence_tree, grow_allowed, grow_ratio = grow_move(*args)
|
|
303
|
-
|
|
304
|
-
args[-1] = key2
|
|
305
|
-
prune_var_tree, prune_split_tree, prune_affluence_tree, prune_allowed, prune_ratio = prune_move(*args)
|
|
306
|
-
|
|
307
|
-
u0, u1 = random.uniform(key3, (2,))
|
|
308
|
-
|
|
309
|
-
p_grow = jnp.where(grow_allowed & prune_allowed, 0.5, grow_allowed)
|
|
310
|
-
try_grow = u0 < p_grow
|
|
311
|
-
try_prune = prune_allowed & ~try_grow
|
|
312
|
-
|
|
313
|
-
do_grow = try_grow & (u1 < grow_ratio)
|
|
314
|
-
do_prune = try_prune & (u1 < prune_ratio)
|
|
315
|
-
|
|
316
|
-
var_tree = jnp.where(do_grow, grow_var_tree, var_tree)
|
|
317
|
-
split_tree = jnp.where(do_grow, grow_split_tree, split_tree)
|
|
318
|
-
var_tree = jnp.where(do_prune, prune_var_tree, var_tree)
|
|
319
|
-
split_tree = jnp.where(do_prune, prune_split_tree, split_tree)
|
|
320
|
-
|
|
321
|
-
bart['var_trees'] = bart['var_trees'].at[i_tree].set(var_tree)
|
|
322
|
-
bart['split_trees'] = bart['split_trees'].at[i_tree].set(split_tree)
|
|
323
|
-
|
|
324
|
-
if bart['min_points_per_leaf'] is not None:
|
|
325
|
-
affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
|
|
326
|
-
affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
|
|
327
|
-
bart['affluence_trees'] = bart['affluence_trees'].at[i_tree].set(affluence_tree)
|
|
328
|
-
|
|
329
|
-
bart['grow_prop_count'] += try_grow
|
|
330
|
-
bart['grow_acc_count'] += do_grow
|
|
331
|
-
bart['prune_prop_count'] += try_prune
|
|
332
|
-
bart['prune_acc_count'] += do_prune
|
|
333
|
-
|
|
334
|
-
return bart
|
|
335
|
-
|
|
336
|
-
def grow_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, sigma2, resid, n_tree, min_points_per_leaf, key):
|
|
268
|
+
key = random.split(key, bart['var_trees'].shape[0])
|
|
269
|
+
return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
|
|
270
|
+
|
|
271
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, 0))
|
|
272
|
+
def sample_moves_vmap_trees(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
273
|
+
key, key1 = random.split(key)
|
|
274
|
+
args = var_tree, split_tree, affluence_tree, max_split, p_nonterminal
|
|
275
|
+
grow = grow_move(*args, key)
|
|
276
|
+
prune = prune_move(*args, key1)
|
|
277
|
+
return grow, prune
|
|
278
|
+
|
|
279
|
+
def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
337
280
|
"""
|
|
338
281
|
Tree structure grow move proposal of BART MCMC.
|
|
339
282
|
|
|
283
|
+
This moves picks a leaf node and converts it to a non-terminal node with
|
|
284
|
+
two leaf children. The move is not possible if all the leaves are already at
|
|
285
|
+
maximum depth.
|
|
286
|
+
|
|
340
287
|
Parameters
|
|
341
288
|
----------
|
|
342
|
-
X : array (p, n)
|
|
343
|
-
The predictors.
|
|
344
289
|
var_tree : array (2 ** (d - 1),)
|
|
345
290
|
The variable indices of the tree.
|
|
346
291
|
split_tree : array (2 ** (d - 1),)
|
|
@@ -349,82 +294,49 @@ def grow_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal,
|
|
|
349
294
|
Whether a leaf has enough points to be grown.
|
|
350
295
|
max_split : array (p,)
|
|
351
296
|
The maximum split index for each variable.
|
|
352
|
-
p_nonterminal : array (d
|
|
297
|
+
p_nonterminal : array (d,)
|
|
353
298
|
The probability of a nonterminal node at each depth.
|
|
354
|
-
sigma2 : float
|
|
355
|
-
The noise variance.
|
|
356
|
-
resid : array (n,)
|
|
357
|
-
The residuals (data minus forest value), computed using all trees but
|
|
358
|
-
the tree under consideration.
|
|
359
|
-
n_tree : int
|
|
360
|
-
The number of trees in the forest.
|
|
361
|
-
min_points_per_leaf : int
|
|
362
|
-
The minimum number of data points in a leaf node.
|
|
363
299
|
key : jax.dtypes.prng_key array
|
|
364
300
|
A jax random key.
|
|
365
301
|
|
|
366
302
|
Returns
|
|
367
303
|
-------
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
maximum depth.
|
|
304
|
+
grow_move : dict
|
|
305
|
+
A dictionary with fields:
|
|
306
|
+
|
|
307
|
+
'allowed' : bool
|
|
308
|
+
Whether the move is possible.
|
|
309
|
+
'node' : int
|
|
310
|
+
The index of the leaf to grow.
|
|
311
|
+
'var_tree' : array (2 ** (d - 1),)
|
|
312
|
+
The new decision axes of the tree.
|
|
313
|
+
'split_tree' : array (2 ** (d - 1),)
|
|
314
|
+
The new decision boundaries of the tree.
|
|
315
|
+
'partial_ratio' : float
|
|
316
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
317
|
+
the likelihood ratio and the probability of proposing the prune
|
|
318
|
+
move.
|
|
384
319
|
"""
|
|
385
320
|
|
|
386
|
-
key1, key2
|
|
387
|
-
|
|
388
|
-
leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree, key1)
|
|
321
|
+
key, key1, key2 = random.split(key, 3)
|
|
389
322
|
|
|
390
|
-
|
|
391
|
-
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
392
|
-
|
|
393
|
-
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key3)
|
|
394
|
-
new_split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
|
|
323
|
+
leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree, key)
|
|
395
324
|
|
|
396
|
-
|
|
325
|
+
var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
|
|
326
|
+
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
397
327
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
ratio = trans_tree_ratio * likelihood_ratio
|
|
401
|
-
|
|
402
|
-
return var_tree, new_split_tree, new_affluence_tree, allowed, ratio
|
|
403
|
-
|
|
404
|
-
def growable_leaves(split_tree, affluence_tree):
|
|
405
|
-
"""
|
|
406
|
-
Return a mask indicating the leaf nodes that can be proposed for growth.
|
|
328
|
+
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
|
|
329
|
+
split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
|
|
407
330
|
|
|
408
|
-
|
|
409
|
-
----------
|
|
410
|
-
split_tree : array (2 ** (d - 1),)
|
|
411
|
-
The splitting points of the tree.
|
|
412
|
-
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
413
|
-
Whether a leaf has enough points to be grown.
|
|
331
|
+
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
|
|
414
332
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
Whether the grow move is allowed, i.e., there are growable leaves.
|
|
423
|
-
"""
|
|
424
|
-
is_growable = grove.is_actual_leaf(split_tree)
|
|
425
|
-
if affluence_tree is not None:
|
|
426
|
-
is_growable &= affluence_tree
|
|
427
|
-
return is_growable, jnp.any(is_growable)
|
|
333
|
+
return dict(
|
|
334
|
+
allowed=allowed,
|
|
335
|
+
node=leaf_to_grow,
|
|
336
|
+
partial_ratio=ratio,
|
|
337
|
+
var_tree=var_tree,
|
|
338
|
+
split_tree=split_tree,
|
|
339
|
+
)
|
|
428
340
|
|
|
429
341
|
def choose_leaf(split_tree, affluence_tree, key):
|
|
430
342
|
"""
|
|
@@ -443,7 +355,7 @@ def choose_leaf(split_tree, affluence_tree, key):
|
|
|
443
355
|
-------
|
|
444
356
|
leaf_to_grow : int
|
|
445
357
|
The index of the leaf to grow. If ``num_growable == 0``, return
|
|
446
|
-
``
|
|
358
|
+
``2 ** d``.
|
|
447
359
|
num_growable : int
|
|
448
360
|
The number of leaf nodes that can be grown.
|
|
449
361
|
num_prunable : int
|
|
@@ -454,11 +366,37 @@ def choose_leaf(split_tree, affluence_tree, key):
|
|
|
454
366
|
"""
|
|
455
367
|
is_growable, allowed = growable_leaves(split_tree, affluence_tree)
|
|
456
368
|
leaf_to_grow = randint_masked(key, is_growable)
|
|
369
|
+
leaf_to_grow = jnp.where(allowed, leaf_to_grow, 2 * split_tree.size)
|
|
457
370
|
num_growable = jnp.count_nonzero(is_growable)
|
|
458
371
|
is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
|
|
459
372
|
num_prunable = jnp.count_nonzero(is_parent)
|
|
460
373
|
return leaf_to_grow, num_growable, num_prunable, allowed
|
|
461
374
|
|
|
375
|
+
def growable_leaves(split_tree, affluence_tree):
|
|
376
|
+
"""
|
|
377
|
+
Return a mask indicating the leaf nodes that can be proposed for growth.
|
|
378
|
+
|
|
379
|
+
Parameters
|
|
380
|
+
----------
|
|
381
|
+
split_tree : array (2 ** (d - 1),)
|
|
382
|
+
The splitting points of the tree.
|
|
383
|
+
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
384
|
+
Whether a leaf has enough points to be grown.
|
|
385
|
+
|
|
386
|
+
Returns
|
|
387
|
+
-------
|
|
388
|
+
is_growable : bool array (2 ** (d - 1),)
|
|
389
|
+
The mask indicating the leaf nodes that can be proposed to grow, i.e.,
|
|
390
|
+
that are not at the bottom level and have at least two times the number
|
|
391
|
+
of minimum points per leaf.
|
|
392
|
+
allowed : bool
|
|
393
|
+
Whether the grow move is allowed, i.e., there are growable leaves.
|
|
394
|
+
"""
|
|
395
|
+
is_growable = grove.is_actual_leaf(split_tree)
|
|
396
|
+
if affluence_tree is not None:
|
|
397
|
+
is_growable &= affluence_tree
|
|
398
|
+
return is_growable, jnp.any(is_growable)
|
|
399
|
+
|
|
462
400
|
def randint_masked(key, mask):
|
|
463
401
|
"""
|
|
464
402
|
Return a random integer in a range, including only some values.
|
|
@@ -560,7 +498,7 @@ def ancestor_variables(var_tree, max_split, node_index):
|
|
|
560
498
|
the parent. Unused spots are filled with `p`.
|
|
561
499
|
"""
|
|
562
500
|
max_num_ancestors = grove.tree_depth(var_tree) - 1
|
|
563
|
-
ancestor_vars = jnp.zeros(max_num_ancestors,
|
|
501
|
+
ancestor_vars = jnp.zeros(max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size))
|
|
564
502
|
carry = ancestor_vars.size - 1, node_index, ancestor_vars
|
|
565
503
|
def loop(carry, _):
|
|
566
504
|
i, index, ancestor_vars = carry
|
|
@@ -665,7 +603,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
|
|
|
665
603
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
666
604
|
return random.randint(key, (), l, r)
|
|
667
605
|
|
|
668
|
-
def
|
|
606
|
+
def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
|
|
669
607
|
"""
|
|
670
608
|
Compute the product of the transition and prior ratios of a grow move.
|
|
671
609
|
|
|
@@ -676,129 +614,46 @@ def compute_trans_tree_ratio(num_growable, num_prunable, tree_halfsize, p_nonter
|
|
|
676
614
|
num_prunable : int
|
|
677
615
|
The number of leaf parents that could be pruned, after converting the
|
|
678
616
|
leaf to be grown to a non-terminal node.
|
|
679
|
-
|
|
680
|
-
Half the length of the tree array, i.e., 2 ** (d - 1).
|
|
681
|
-
p_nonterminal : array (d - 1,)
|
|
617
|
+
p_nonterminal : array (d,)
|
|
682
618
|
The probability of a nonterminal node at each depth.
|
|
683
619
|
leaf_to_grow : int
|
|
684
620
|
The index of the leaf to grow.
|
|
685
|
-
initial_split_tree : array (2 ** (d - 1),)
|
|
686
|
-
The splitting points of the tree, before the leaf is grown.
|
|
687
621
|
new_split_tree : array (2 ** (d - 1),)
|
|
688
622
|
The splitting points of the tree, after the leaf is grown.
|
|
689
|
-
initial_affluence_tree : bool array (2 ** (d - 1),) or None
|
|
690
|
-
Whether a leaf has enough points to be grown, before the leaf is grown.
|
|
691
|
-
new_affluence_tree : bool array (2 ** (d - 1),) or None
|
|
692
|
-
Whether a leaf has enough points to be grown, after the leaf is grown.
|
|
693
623
|
|
|
694
624
|
Returns
|
|
695
625
|
-------
|
|
696
626
|
ratio : float
|
|
697
627
|
The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
|
|
698
|
-
times the prior ratio P(new tree) / P(old tree)
|
|
628
|
+
times the prior ratio P(new tree) / P(old tree), but the transition
|
|
629
|
+
ratio is missing the factor P(propose prune) in the numerator.
|
|
699
630
|
"""
|
|
700
631
|
|
|
701
632
|
# the two ratios also contain factors num_available_split *
|
|
702
633
|
# num_available_var, but they cancel out
|
|
703
634
|
|
|
704
|
-
|
|
705
|
-
|
|
635
|
+
prune_allowed = leaf_to_grow != 1
|
|
636
|
+
# prune allowed <---> the initial tree is not a root
|
|
637
|
+
# leaf to grow is root --> the tree can only be a root
|
|
638
|
+
# tree is a root --> the only leaf I can grow is root
|
|
706
639
|
|
|
707
|
-
|
|
708
|
-
p_prune = jnp.where(grow_again_allowed, 0.5, 1)
|
|
640
|
+
p_grow = jnp.where(prune_allowed, 0.5, 1)
|
|
709
641
|
|
|
710
|
-
trans_ratio =
|
|
642
|
+
trans_ratio = num_growable / (p_grow * num_prunable)
|
|
711
643
|
|
|
712
|
-
depth = grove.
|
|
644
|
+
depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
|
|
713
645
|
p_parent = p_nonterminal[depth]
|
|
714
|
-
cp_children = 1 - p_nonterminal
|
|
646
|
+
cp_children = 1 - p_nonterminal[depth + 1]
|
|
715
647
|
tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
|
|
716
648
|
|
|
717
649
|
return trans_ratio * tree_ratio
|
|
718
650
|
|
|
719
|
-
def
|
|
720
|
-
"""
|
|
721
|
-
Compute the likelihood ratio of a grow move.
|
|
722
|
-
|
|
723
|
-
Parameters
|
|
724
|
-
----------
|
|
725
|
-
X : array (p, n)
|
|
726
|
-
The predictors.
|
|
727
|
-
var_tree : array (2 ** (d - 1),)
|
|
728
|
-
The variable indices of the tree, after the grow move.
|
|
729
|
-
split_tree : array (2 ** (d - 1),)
|
|
730
|
-
The splitting points of the tree, after the grow move.
|
|
731
|
-
resid : array (n,)
|
|
732
|
-
The residuals (data minus forest value), for all trees but the one
|
|
733
|
-
under consideration.
|
|
734
|
-
sigma2 : float
|
|
735
|
-
The noise variance.
|
|
736
|
-
new_node : int
|
|
737
|
-
The index of the leaf that has been grown.
|
|
738
|
-
n_tree : int
|
|
739
|
-
The number of trees in the forest.
|
|
740
|
-
min_points_per_leaf : int or None
|
|
741
|
-
The minimum number of data points in a leaf node.
|
|
742
|
-
|
|
743
|
-
Returns
|
|
744
|
-
-------
|
|
745
|
-
ratio : float
|
|
746
|
-
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
747
|
-
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
748
|
-
Whether a leaf has enough points to be grown, after the grow move.
|
|
749
|
-
"""
|
|
750
|
-
|
|
751
|
-
resid_tree, count_tree = agg_values(
|
|
752
|
-
X,
|
|
753
|
-
var_tree,
|
|
754
|
-
split_tree,
|
|
755
|
-
resid,
|
|
756
|
-
sigma2.dtype,
|
|
757
|
-
)
|
|
758
|
-
|
|
759
|
-
left_child = new_node << 1
|
|
760
|
-
right_child = left_child + 1
|
|
761
|
-
|
|
762
|
-
left_resid = resid_tree[left_child]
|
|
763
|
-
right_resid = resid_tree[right_child]
|
|
764
|
-
total_resid = left_resid + right_resid
|
|
765
|
-
|
|
766
|
-
left_count = count_tree[left_child]
|
|
767
|
-
right_count = count_tree[right_child]
|
|
768
|
-
total_count = left_count + right_count
|
|
769
|
-
|
|
770
|
-
sigma_mu2 = 1 / n_tree
|
|
771
|
-
sigma2_left = sigma2 + left_count * sigma_mu2
|
|
772
|
-
sigma2_right = sigma2 + right_count * sigma_mu2
|
|
773
|
-
sigma2_total = sigma2 + total_count * sigma_mu2
|
|
774
|
-
|
|
775
|
-
sqrt_term = sigma2 * sigma2_total / (sigma2_left * sigma2_right)
|
|
776
|
-
|
|
777
|
-
exp_term = sigma_mu2 / (2 * sigma2) * (
|
|
778
|
-
left_resid * left_resid / sigma2_left +
|
|
779
|
-
right_resid * right_resid / sigma2_right -
|
|
780
|
-
total_resid * total_resid / sigma2_total
|
|
781
|
-
)
|
|
782
|
-
|
|
783
|
-
ratio = jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
|
|
784
|
-
|
|
785
|
-
if min_points_per_leaf is not None:
|
|
786
|
-
ratio = jnp.where(right_count >= min_points_per_leaf, ratio, 0)
|
|
787
|
-
ratio = jnp.where(left_count >= min_points_per_leaf, ratio, 0)
|
|
788
|
-
affluence_tree = count_tree[:count_tree.size // 2] >= 2 * min_points_per_leaf
|
|
789
|
-
else:
|
|
790
|
-
affluence_tree = None
|
|
791
|
-
|
|
792
|
-
return ratio, affluence_tree
|
|
793
|
-
|
|
794
|
-
def prune_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, sigma2, resid, n_tree, min_points_per_leaf, key):
|
|
651
|
+
def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
795
652
|
"""
|
|
796
653
|
Tree structure prune move proposal of BART MCMC.
|
|
797
654
|
|
|
798
655
|
Parameters
|
|
799
656
|
----------
|
|
800
|
-
X : array (p, n)
|
|
801
|
-
The predictors.
|
|
802
657
|
var_tree : array (2 ** (d - 1),)
|
|
803
658
|
The variable indices of the tree.
|
|
804
659
|
split_tree : array (2 ** (d - 1),)
|
|
@@ -807,50 +662,35 @@ def prune_move(X, var_tree, split_tree, affluence_tree, max_split, p_nonterminal
|
|
|
807
662
|
Whether a leaf has enough points to be grown.
|
|
808
663
|
max_split : array (p,)
|
|
809
664
|
The maximum split index for each variable.
|
|
810
|
-
p_nonterminal : array (d
|
|
665
|
+
p_nonterminal : array (d,)
|
|
811
666
|
The probability of a nonterminal node at each depth.
|
|
812
|
-
sigma2 : float
|
|
813
|
-
The noise variance.
|
|
814
|
-
resid : array (n,)
|
|
815
|
-
The residuals (data minus forest value), computed using all trees but
|
|
816
|
-
the tree under consideration.
|
|
817
|
-
n_tree : int
|
|
818
|
-
The number of trees in the forest.
|
|
819
|
-
min_points_per_leaf : int
|
|
820
|
-
The minimum number of data points in a leaf node.
|
|
821
667
|
key : jax.dtypes.prng_key array
|
|
822
668
|
A jax random key.
|
|
823
669
|
|
|
824
670
|
Returns
|
|
825
671
|
-------
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
672
|
+
prune_move : dict
|
|
673
|
+
A dictionary with fields:
|
|
674
|
+
|
|
675
|
+
'allowed' : bool
|
|
676
|
+
Whether the move is possible.
|
|
677
|
+
'node' : int
|
|
678
|
+
The index of the node to prune.
|
|
679
|
+
'partial_ratio' : float
|
|
680
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
681
|
+
the likelihood ratio and the probability of proposing the prune
|
|
682
|
+
move. This ratio is inverted.
|
|
836
683
|
"""
|
|
837
684
|
node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
|
|
838
|
-
allowed =
|
|
685
|
+
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
839
686
|
|
|
840
|
-
|
|
841
|
-
# should I clean up var_tree as well? just for debugging. it hasn't given me problems though
|
|
687
|
+
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, split_tree)
|
|
842
688
|
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
689
|
+
return dict(
|
|
690
|
+
allowed=allowed,
|
|
691
|
+
node=node_to_prune,
|
|
692
|
+
partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
|
|
847
693
|
)
|
|
848
|
-
trans_tree_ratio = compute_trans_tree_ratio(num_growable, num_prunable, split_tree.size, p_nonterminal, node_to_prune, new_split_tree, split_tree, new_affluence_tree, affluence_tree)
|
|
849
|
-
|
|
850
|
-
ratio = trans_tree_ratio * likelihood_ratio
|
|
851
|
-
ratio = 1 / ratio # Question: should I use lax.reciprocal for this?
|
|
852
|
-
|
|
853
|
-
return var_tree, new_split_tree, new_affluence_tree, allowed, ratio
|
|
854
694
|
|
|
855
695
|
def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
856
696
|
"""
|
|
@@ -890,132 +730,363 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
|
890
730
|
|
|
891
731
|
return node_to_prune, num_prunable, num_growable
|
|
892
732
|
|
|
893
|
-
def
|
|
733
|
+
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key):
|
|
894
734
|
"""
|
|
895
|
-
|
|
735
|
+
Accept or reject the proposed moves and sample the new leaf values.
|
|
896
736
|
|
|
897
737
|
Parameters
|
|
898
738
|
----------
|
|
899
|
-
|
|
900
|
-
|
|
739
|
+
bart : dict
|
|
740
|
+
A BART mcmc state.
|
|
741
|
+
grow_moves : dict
|
|
742
|
+
The proposals for grow moves, batched over the first axis. See
|
|
743
|
+
`grow_move`.
|
|
744
|
+
prune_moves : dict
|
|
745
|
+
The proposals for prune moves, batched over the first axis. See
|
|
746
|
+
`prune_move`.
|
|
747
|
+
grow_leaf_indices : int array (num_trees, n)
|
|
748
|
+
The leaf indices of the trees proposed by the grow move.
|
|
749
|
+
key : jax.dtypes.prng_key array
|
|
750
|
+
A jax random key.
|
|
901
751
|
|
|
902
752
|
Returns
|
|
903
753
|
-------
|
|
904
|
-
|
|
905
|
-
|
|
754
|
+
bart : dict
|
|
755
|
+
The new BART mcmc state.
|
|
906
756
|
"""
|
|
907
|
-
|
|
757
|
+
bart = bart.copy()
|
|
758
|
+
def loop(carry, item):
|
|
759
|
+
resid = carry.pop('resid')
|
|
760
|
+
resid, carry, trees = accept_move_and_sample_leaves(
|
|
761
|
+
bart['X'],
|
|
762
|
+
len(bart['leaf_trees']),
|
|
763
|
+
bart['opt']['suffstat_batch_size'],
|
|
764
|
+
resid,
|
|
765
|
+
bart['sigma2'],
|
|
766
|
+
bart['min_points_per_leaf'],
|
|
767
|
+
carry,
|
|
768
|
+
*item,
|
|
769
|
+
)
|
|
770
|
+
carry['resid'] = resid
|
|
771
|
+
return carry, trees
|
|
772
|
+
carry = {
|
|
773
|
+
k: jnp.zeros_like(bart[k]) for k in
|
|
774
|
+
['grow_prop_count', 'prune_prop_count', 'grow_acc_count', 'prune_acc_count']
|
|
775
|
+
}
|
|
776
|
+
carry['resid'] = bart['resid']
|
|
777
|
+
items = (
|
|
778
|
+
bart['leaf_trees'],
|
|
779
|
+
bart['split_trees'],
|
|
780
|
+
bart['affluence_trees'],
|
|
781
|
+
grow_moves,
|
|
782
|
+
prune_moves,
|
|
783
|
+
grow_leaf_indices,
|
|
784
|
+
random.split(key, len(bart['leaf_trees'])),
|
|
785
|
+
)
|
|
786
|
+
carry, trees = lax.scan(loop, carry, items)
|
|
787
|
+
bart.update(carry)
|
|
788
|
+
bart.update(trees)
|
|
789
|
+
return bart
|
|
908
790
|
|
|
909
|
-
def
|
|
791
|
+
def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2, min_points_per_leaf, counts, leaf_tree, split_tree, affluence_tree, grow_move, prune_move, grow_leaf_indices, key):
|
|
910
792
|
"""
|
|
911
|
-
|
|
793
|
+
Accept or reject a proposed move and sample the new leaf values.
|
|
912
794
|
|
|
913
795
|
Parameters
|
|
914
796
|
----------
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
797
|
+
X : int array (p, n)
|
|
798
|
+
The predictors.
|
|
799
|
+
ntree : int
|
|
800
|
+
The number of trees in the forest.
|
|
801
|
+
suffstat_batch_size : int, None
|
|
802
|
+
The batch size for computing sufficient statistics.
|
|
803
|
+
resid : float array (n,)
|
|
804
|
+
The residuals (data minus forest value).
|
|
805
|
+
sigma2 : float
|
|
806
|
+
The noise variance.
|
|
807
|
+
min_points_per_leaf : int or None
|
|
808
|
+
The minimum number of data points in a leaf node.
|
|
809
|
+
counts : dict
|
|
810
|
+
The acceptance counts from the mcmc state dict.
|
|
811
|
+
leaf_tree : float array (2 ** d,)
|
|
812
|
+
The leaf values of the tree.
|
|
813
|
+
split_tree : int array (2 ** (d - 1),)
|
|
814
|
+
The decision boundaries of the tree.
|
|
815
|
+
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
816
|
+
Whether a leaf has enough points to be grown.
|
|
817
|
+
grow_move : dict
|
|
818
|
+
The proposal for the grow move. See `grow_move`.
|
|
819
|
+
prune_move : dict
|
|
820
|
+
The proposal for the prune move. See `prune_move`.
|
|
821
|
+
grow_leaf_indices : int array (n,)
|
|
822
|
+
The leaf indices of the tree proposed by the grow move.
|
|
918
823
|
key : jax.dtypes.prng_key array
|
|
919
824
|
A jax random key.
|
|
920
|
-
i_tree : int
|
|
921
|
-
The index of the tree to sample.
|
|
922
825
|
|
|
923
826
|
Returns
|
|
924
827
|
-------
|
|
925
|
-
|
|
926
|
-
The
|
|
828
|
+
resid : float array (n,)
|
|
829
|
+
The updated residuals (data minus forest value).
|
|
830
|
+
counts : dict
|
|
831
|
+
The updated acceptance counts.
|
|
832
|
+
trees : dict
|
|
833
|
+
The updated tree arrays.
|
|
927
834
|
"""
|
|
928
|
-
|
|
835
|
+
|
|
836
|
+
# compute leaf indices in starting tree
|
|
837
|
+
grow_node = grow_move['node']
|
|
838
|
+
grow_left = grow_node << 1
|
|
839
|
+
grow_right = grow_left + 1
|
|
840
|
+
leaf_indices = jnp.where(
|
|
841
|
+
(grow_leaf_indices == grow_left) | (grow_leaf_indices == grow_right),
|
|
842
|
+
grow_node,
|
|
843
|
+
grow_leaf_indices,
|
|
844
|
+
)
|
|
929
845
|
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
846
|
+
# compute leaf indices in prune tree
|
|
847
|
+
prune_node = prune_move['node']
|
|
848
|
+
prune_left = prune_node << 1
|
|
849
|
+
prune_right = prune_left + 1
|
|
850
|
+
prune_leaf_indices = jnp.where(
|
|
851
|
+
(leaf_indices == prune_left) | (leaf_indices == prune_right),
|
|
852
|
+
prune_node,
|
|
853
|
+
leaf_indices,
|
|
936
854
|
)
|
|
937
855
|
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
856
|
+
# subtract starting tree from function
|
|
857
|
+
resid += leaf_tree[leaf_indices]
|
|
858
|
+
|
|
859
|
+
# aggregate residuals and count units per leaf
|
|
860
|
+
grow_resid_tree, grow_count_tree = sufficient_stat(resid, grow_leaf_indices, leaf_tree.size, suffstat_batch_size)
|
|
861
|
+
|
|
862
|
+
# compute aggregations in starting tree
|
|
863
|
+
# I do not zero the children because garbage there does not matter
|
|
864
|
+
resid_tree = (grow_resid_tree.at[grow_node]
|
|
865
|
+
.set(grow_resid_tree[grow_left] + grow_resid_tree[grow_right]))
|
|
866
|
+
count_tree = (grow_count_tree.at[grow_node]
|
|
867
|
+
.set(grow_count_tree[grow_left] + grow_count_tree[grow_right]))
|
|
942
868
|
|
|
869
|
+
# compute aggregations in prune tree
|
|
870
|
+
prune_resid_tree = (resid_tree.at[prune_node]
|
|
871
|
+
.set(resid_tree[prune_left] + resid_tree[prune_right]))
|
|
872
|
+
prune_count_tree = (count_tree.at[prune_node]
|
|
873
|
+
.set(count_tree[prune_left] + count_tree[prune_right]))
|
|
874
|
+
|
|
875
|
+
# compute affluence trees
|
|
876
|
+
if min_points_per_leaf is not None:
|
|
877
|
+
grow_affluence_tree = grow_count_tree[:grow_count_tree.size // 2] >= 2 * min_points_per_leaf
|
|
878
|
+
prune_affluence_tree = affluence_tree.at[prune_node].set(True)
|
|
879
|
+
|
|
880
|
+
# compute probability of proposing prune
|
|
881
|
+
grow_p_prune = compute_p_prune_back(grow_move['split_tree'], grow_affluence_tree)
|
|
882
|
+
prune_p_prune = compute_p_prune_back(split_tree, affluence_tree)
|
|
883
|
+
|
|
884
|
+
# compute likelihood ratios
|
|
885
|
+
grow_lk_ratio = compute_likelihood_ratio(grow_resid_tree, grow_count_tree, sigma2, grow_node, ntree, min_points_per_leaf)
|
|
886
|
+
prune_lk_ratio = compute_likelihood_ratio(resid_tree, count_tree, sigma2, prune_node, ntree, min_points_per_leaf)
|
|
887
|
+
|
|
888
|
+
# compute acceptance ratios
|
|
889
|
+
grow_ratio = grow_p_prune * grow_move['partial_ratio'] * grow_lk_ratio
|
|
890
|
+
prune_ratio = prune_p_prune * prune_move['partial_ratio'] * prune_lk_ratio
|
|
891
|
+
prune_ratio = lax.reciprocal(prune_ratio)
|
|
892
|
+
|
|
893
|
+
# random coins in [0, 1) for proposal and acceptance
|
|
894
|
+
key, subkey = random.split(key)
|
|
895
|
+
u0, u1 = random.uniform(subkey, (2,))
|
|
896
|
+
|
|
897
|
+
# determine what move to propose (not proposing anything is an option)
|
|
898
|
+
p_grow = jnp.where(grow_move['allowed'] & prune_move['allowed'], 0.5, grow_move['allowed'])
|
|
899
|
+
try_grow = u0 < p_grow
|
|
900
|
+
try_prune = prune_move['allowed'] & ~try_grow
|
|
901
|
+
|
|
902
|
+
# determine whether to accept the move
|
|
903
|
+
do_grow = try_grow & (u1 < grow_ratio)
|
|
904
|
+
do_prune = try_prune & (u1 < prune_ratio)
|
|
905
|
+
|
|
906
|
+
# pick trees for chosen move
|
|
907
|
+
trees = {}
|
|
908
|
+
split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
|
|
909
|
+
# the prune var tree is equal to the initial one, because I leave garbage values behind
|
|
910
|
+
split_tree = split_tree.at[prune_node].set(
|
|
911
|
+
jnp.where(do_prune, 0, split_tree[prune_node]))
|
|
912
|
+
if min_points_per_leaf is not None:
|
|
913
|
+
affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
|
|
914
|
+
affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
|
|
915
|
+
resid_tree = jnp.where(do_grow, grow_resid_tree, resid_tree)
|
|
916
|
+
count_tree = jnp.where(do_grow, grow_count_tree, count_tree)
|
|
917
|
+
resid_tree = jnp.where(do_prune, prune_resid_tree, resid_tree)
|
|
918
|
+
count_tree = jnp.where(do_prune, prune_count_tree, count_tree)
|
|
919
|
+
|
|
920
|
+
# update acceptance counts
|
|
921
|
+
counts = counts.copy()
|
|
922
|
+
counts['grow_prop_count'] += try_grow
|
|
923
|
+
counts['grow_acc_count'] += do_grow
|
|
924
|
+
counts['prune_prop_count'] += try_prune
|
|
925
|
+
counts['prune_acc_count'] += do_prune
|
|
926
|
+
|
|
927
|
+
# compute leaves posterior
|
|
928
|
+
prec_lk = count_tree / sigma2
|
|
929
|
+
var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
|
|
930
|
+
mean_post = resid_tree / sigma2 * var_post # = mean_lk * prec_lk * var_post
|
|
931
|
+
|
|
932
|
+
# sample leaves
|
|
943
933
|
z = random.normal(key, mean_post.shape, mean_post.dtype)
|
|
944
|
-
# TODO maybe use long float here, I guess this part is not a bottleneck
|
|
945
934
|
leaf_tree = mean_post + z * jnp.sqrt(var_post)
|
|
946
|
-
leaf_tree = leaf_tree.at[0].set(0) # this 0 is used by evaluate_tree
|
|
947
|
-
bart['leaf_trees'] = bart['leaf_trees'].at[i_tree].set(leaf_tree)
|
|
948
935
|
|
|
949
|
-
|
|
936
|
+
# add new tree to function
|
|
937
|
+
leaf_indices = jnp.where(do_grow, grow_leaf_indices, leaf_indices)
|
|
938
|
+
leaf_indices = jnp.where(do_prune, prune_leaf_indices, leaf_indices)
|
|
939
|
+
resid -= leaf_tree[leaf_indices]
|
|
950
940
|
|
|
951
|
-
|
|
941
|
+
# pack trees
|
|
942
|
+
trees = {
|
|
943
|
+
'leaf_trees': leaf_tree,
|
|
944
|
+
'split_trees': split_tree,
|
|
945
|
+
'affluence_trees': affluence_tree,
|
|
946
|
+
}
|
|
947
|
+
|
|
948
|
+
return resid, counts, trees
|
|
949
|
+
|
|
950
|
+
def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
|
|
952
951
|
"""
|
|
953
|
-
|
|
952
|
+
Compute the sufficient statistics for the likelihood ratio of a tree move.
|
|
954
953
|
|
|
955
954
|
Parameters
|
|
956
955
|
----------
|
|
957
|
-
|
|
958
|
-
The
|
|
959
|
-
|
|
960
|
-
The
|
|
961
|
-
|
|
962
|
-
The
|
|
963
|
-
|
|
964
|
-
The
|
|
965
|
-
|
|
966
|
-
The dtype of the output.
|
|
956
|
+
resid : float array (n,)
|
|
957
|
+
The residuals (data minus forest value).
|
|
958
|
+
leaf_indices : int array (n,)
|
|
959
|
+
The leaf indices of the tree (in which leaf each data point falls into).
|
|
960
|
+
tree_size : int
|
|
961
|
+
The size of the tree array (2 ** d).
|
|
962
|
+
batch_size : int, None
|
|
963
|
+
The batch size for the aggregation. Batching increases numerical
|
|
964
|
+
accuracy and parallelism.
|
|
967
965
|
|
|
968
966
|
Returns
|
|
969
967
|
-------
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
each leaf contains the sum of the `values` whose corresponding `X` fall
|
|
973
|
-
into the leaf.
|
|
968
|
+
resid_tree : float array (2 ** d,)
|
|
969
|
+
The sum of the residuals at data points in each leaf.
|
|
974
970
|
count_tree : int array (2 ** d,)
|
|
975
|
-
|
|
971
|
+
The number of data points in each leaf.
|
|
976
972
|
"""
|
|
973
|
+
if batch_size is None:
|
|
974
|
+
aggr_func = _aggregate_scatter
|
|
975
|
+
else:
|
|
976
|
+
aggr_func = functools.partial(_aggregate_batched, batch_size=batch_size)
|
|
977
|
+
resid_tree = aggr_func(resid, leaf_indices, tree_size, jnp.float32)
|
|
978
|
+
count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
|
|
979
|
+
return resid_tree, count_tree
|
|
980
|
+
|
|
981
|
+
def _aggregate_scatter(values, indices, size, dtype):
|
|
982
|
+
return (jnp
|
|
983
|
+
.zeros(size, dtype)
|
|
984
|
+
.at[indices]
|
|
985
|
+
.add(values)
|
|
986
|
+
)
|
|
977
987
|
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
988
|
+
def _aggregate_batched(values, indices, size, dtype, batch_size):
|
|
989
|
+
nbatches = indices.size // batch_size + bool(indices.size % batch_size)
|
|
990
|
+
batch_indices = jnp.arange(indices.size) // batch_size
|
|
991
|
+
return (jnp
|
|
992
|
+
.zeros((nbatches, size), dtype)
|
|
993
|
+
.at[batch_indices, indices]
|
|
994
|
+
.add(values)
|
|
995
|
+
.sum(axis=0)
|
|
984
996
|
)
|
|
985
|
-
unit_index = jnp.arange(values.size, dtype=grove.minimal_unsigned_dtype(values.size - 1))
|
|
986
997
|
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
998
|
+
def compute_p_prune_back(new_split_tree, new_affluence_tree):
|
|
999
|
+
"""
|
|
1000
|
+
Compute the probability of proposing a prune move after doing a grow move.
|
|
1001
|
+
|
|
1002
|
+
Parameters
|
|
1003
|
+
----------
|
|
1004
|
+
new_split_tree : int array (2 ** (d - 1),)
|
|
1005
|
+
The decision boundaries of the tree, after the grow move.
|
|
1006
|
+
new_affluence_tree : bool array (2 ** (d - 1),)
|
|
1007
|
+
Which leaves have enough points to be grown, after the grow move.
|
|
1008
|
+
|
|
1009
|
+
Returns
|
|
1010
|
+
-------
|
|
1011
|
+
p_prune : float
|
|
1012
|
+
The probability of proposing a prune move after the grow move. This is
|
|
1013
|
+
0.5 if grow is possible again, and 1 if it isn't. It can't be 0 because
|
|
1014
|
+
at least the node just grown can be pruned.
|
|
1015
|
+
"""
|
|
1016
|
+
_, grow_again_allowed = growable_leaves(new_split_tree, new_affluence_tree)
|
|
1017
|
+
return jnp.where(grow_again_allowed, 0.5, 1)
|
|
1018
|
+
|
|
1019
|
+
def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_points_per_leaf):
|
|
1020
|
+
"""
|
|
1021
|
+
Compute the likelihood ratio of a grow move.
|
|
1022
|
+
|
|
1023
|
+
Parameters
|
|
1024
|
+
----------
|
|
1025
|
+
resid_tree : float array (2 ** d,)
|
|
1026
|
+
The sum of the residuals at data points in each leaf.
|
|
1027
|
+
count_tree : int array (2 ** d,)
|
|
1028
|
+
The number of data points in each leaf.
|
|
1029
|
+
sigma2 : float
|
|
1030
|
+
The noise variance.
|
|
1031
|
+
node : int
|
|
1032
|
+
The index of the leaf that has been grown.
|
|
1033
|
+
n_tree : int
|
|
1034
|
+
The number of trees in the forest.
|
|
1035
|
+
min_points_per_leaf : int or None
|
|
1036
|
+
The minimum number of data points in a leaf node.
|
|
1037
|
+
|
|
1038
|
+
Returns
|
|
1039
|
+
-------
|
|
1040
|
+
ratio : float
|
|
1041
|
+
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
1042
|
+
|
|
1043
|
+
Notes
|
|
1044
|
+
-----
|
|
1045
|
+
The ratio is set to 0 if the grow move would create leaves with not enough
|
|
1046
|
+
datapoints per leaf, although this is part of the prior rather than the
|
|
1047
|
+
likelihood.
|
|
1048
|
+
"""
|
|
1049
|
+
|
|
1050
|
+
left_child = node << 1
|
|
1051
|
+
right_child = left_child + 1
|
|
1052
|
+
|
|
1053
|
+
left_resid = resid_tree[left_child]
|
|
1054
|
+
right_resid = resid_tree[right_child]
|
|
1055
|
+
total_resid = left_resid + right_resid
|
|
1056
|
+
|
|
1057
|
+
left_count = count_tree[left_child]
|
|
1058
|
+
right_count = count_tree[right_child]
|
|
1059
|
+
total_count = left_count + right_count
|
|
1060
|
+
|
|
1061
|
+
sigma_mu2 = 1 / n_tree
|
|
1062
|
+
sigma2_left = sigma2 + left_count * sigma_mu2
|
|
1063
|
+
sigma2_right = sigma2 + right_count * sigma_mu2
|
|
1064
|
+
sigma2_total = sigma2 + total_count * sigma_mu2
|
|
1065
|
+
|
|
1066
|
+
sqrt_term = sigma2 * sigma2_total / (sigma2_left * sigma2_right)
|
|
1067
|
+
|
|
1068
|
+
exp_term = sigma_mu2 / (2 * sigma2) * (
|
|
1069
|
+
left_resid * left_resid / sigma2_left +
|
|
1070
|
+
right_resid * right_resid / sigma2_right -
|
|
1071
|
+
total_resid * total_resid / sigma2_total
|
|
1072
|
+
)
|
|
1073
|
+
|
|
1074
|
+
ratio = jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
|
|
1075
|
+
|
|
1076
|
+
if min_points_per_leaf is not None:
|
|
1077
|
+
ratio = jnp.where(right_count >= min_points_per_leaf, ratio, 0)
|
|
1078
|
+
ratio = jnp.where(left_count >= min_points_per_leaf, ratio, 0)
|
|
1079
|
+
|
|
1080
|
+
return ratio
|
|
1081
|
+
|
|
1082
|
+
def sample_sigma(bart, key):
|
|
1012
1083
|
"""
|
|
1013
1084
|
Noise variance sampling step of BART MCMC.
|
|
1014
1085
|
|
|
1015
1086
|
Parameters
|
|
1016
1087
|
----------
|
|
1017
1088
|
bart : dict
|
|
1018
|
-
A BART mcmc state, as created by `
|
|
1089
|
+
A BART mcmc state, as created by `init`.
|
|
1019
1090
|
key : jax.dtypes.prng_key array
|
|
1020
1091
|
A jax random key.
|
|
1021
1092
|
|
|
@@ -1028,8 +1099,8 @@ def mcmc_sample_sigma(bart, key):
|
|
|
1028
1099
|
|
|
1029
1100
|
resid = bart['resid']
|
|
1030
1101
|
alpha = bart['sigma2_alpha'] + resid.size / 2
|
|
1031
|
-
|
|
1032
|
-
beta = bart['sigma2_beta'] +
|
|
1102
|
+
norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
|
|
1103
|
+
beta = bart['sigma2_beta'] + norm2 / 2
|
|
1033
1104
|
|
|
1034
1105
|
sample = random.gamma(key, alpha)
|
|
1035
1106
|
bart['sigma2'] = beta / sample
|