bartz 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +43 -18
- bartz/_version.py +1 -1
- bartz/grove.py +19 -14
- bartz/jaxext.py +48 -21
- bartz/mcmcloop.py +13 -15
- bartz/mcmcstep.py +795 -344
- bartz/prepcovars.py +43 -13
- bartz-0.4.0.dist-info/METADATA +77 -0
- bartz-0.4.0.dist-info/RECORD +13 -0
- bartz-0.2.1.dist-info/METADATA +0 -32
- bartz-0.2.1.dist-info/RECORD +0 -13
- {bartz-0.2.1.dist-info → bartz-0.4.0.dist-info}/LICENSE +0 -0
- {bartz-0.2.1.dist-info → bartz-0.4.0.dist-info}/WHEEL +0 -0
bartz/mcmcstep.py
CHANGED
|
@@ -10,10 +10,10 @@
|
|
|
10
10
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
11
|
# copies of the Software, and to permit persons to whom the Software is
|
|
12
12
|
# furnished to do so, subject to the following conditions:
|
|
13
|
-
#
|
|
13
|
+
#
|
|
14
14
|
# The above copyright notice and this permission notice shall be included in all
|
|
15
15
|
# copies or substantial portions of the Software.
|
|
16
|
-
#
|
|
16
|
+
#
|
|
17
17
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
18
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
19
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
@@ -34,6 +34,7 @@ range of possible values.
|
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
36
|
import functools
|
|
37
|
+
import math
|
|
37
38
|
|
|
38
39
|
import jax
|
|
39
40
|
from jax import random
|
|
@@ -54,7 +55,9 @@ def init(*,
|
|
|
54
55
|
small_float=jnp.float32,
|
|
55
56
|
large_float=jnp.float32,
|
|
56
57
|
min_points_per_leaf=None,
|
|
57
|
-
|
|
58
|
+
resid_batch_size='auto',
|
|
59
|
+
count_batch_size='auto',
|
|
60
|
+
save_ratios=False,
|
|
58
61
|
):
|
|
59
62
|
"""
|
|
60
63
|
Make a BART posterior sampling MCMC initial state.
|
|
@@ -82,10 +85,13 @@ def init(*,
|
|
|
82
85
|
The dtype for scalars, small arrays, and arrays which require accuracy.
|
|
83
86
|
min_points_per_leaf : int, optional
|
|
84
87
|
The minimum number of data points in a leaf node. 0 if not specified.
|
|
85
|
-
|
|
86
|
-
The batch
|
|
87
|
-
|
|
88
|
-
|
|
88
|
+
resid_batch_size, count_batch_size : int, None, str, default 'auto'
|
|
89
|
+
The batch sizes, along datapoints, for summing the residuals and
|
|
90
|
+
counting the number of datapoints in each leaf. `None` for no batching.
|
|
91
|
+
If 'auto', pick a value based on the device of `y`, or the default
|
|
92
|
+
device.
|
|
93
|
+
save_ratios : bool, default False
|
|
94
|
+
Whether to save the Metropolis-Hastings ratios.
|
|
89
95
|
|
|
90
96
|
Returns
|
|
91
97
|
-------
|
|
@@ -111,6 +117,8 @@ def init(*,
|
|
|
111
117
|
'p_nonterminal' : large_float array (d,)
|
|
112
118
|
The probability of a nonterminal node at each depth, padded with a
|
|
113
119
|
zero.
|
|
120
|
+
'p_propose_grow' : large_float array (2 ** (d - 1),)
|
|
121
|
+
The unnormalized probability of picking a leaf for a grow proposal.
|
|
114
122
|
'sigma2_alpha' : large_float
|
|
115
123
|
The shape parameter of the inverse gamma prior on the noise variance.
|
|
116
124
|
'sigma2_beta' : large_float
|
|
@@ -121,6 +129,8 @@ def init(*,
|
|
|
121
129
|
The response.
|
|
122
130
|
'X' : int array (p, n)
|
|
123
131
|
The predictors.
|
|
132
|
+
'leaf_indices' : int array (num_trees, n)
|
|
133
|
+
The index of the leaf each datapoint falls into, for each tree.
|
|
124
134
|
'min_points_per_leaf' : int or None
|
|
125
135
|
The minimum number of data points in a leaf node.
|
|
126
136
|
'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
|
|
@@ -129,8 +139,6 @@ def init(*,
|
|
|
129
139
|
'opt' : LeafDict
|
|
130
140
|
A dictionary with config values:
|
|
131
141
|
|
|
132
|
-
'suffstat_batch_size' : int or None
|
|
133
|
-
The batch size for computing sufficient statistics.
|
|
134
142
|
'small_float' : dtype
|
|
135
143
|
The dtype for large arrays used in the algorithm.
|
|
136
144
|
'large_float' : dtype
|
|
@@ -138,6 +146,16 @@ def init(*,
|
|
|
138
146
|
accuracy.
|
|
139
147
|
'require_min_points' : bool
|
|
140
148
|
Whether the `min_points_per_leaf` parameter is specified.
|
|
149
|
+
'resid_batch_size', 'count_batch_size' : int or None
|
|
150
|
+
The data batch sizes for computing the sufficient statistics.
|
|
151
|
+
'ratios' : dict, optional
|
|
152
|
+
If `save_ratios` is True, this field is present. It has the fields:
|
|
153
|
+
|
|
154
|
+
'log_trans_prior' : large_float array (num_trees,)
|
|
155
|
+
The log transition and prior Metropolis-Hastings ratio for the
|
|
156
|
+
proposed move on each tree.
|
|
157
|
+
'log_likelihood' : large_float array (num_trees,)
|
|
158
|
+
The log likelihood ratio.
|
|
141
159
|
"""
|
|
142
160
|
|
|
143
161
|
p_nonterminal = jnp.asarray(p_nonterminal, large_float)
|
|
@@ -151,24 +169,28 @@ def init(*,
|
|
|
151
169
|
small_float = jnp.dtype(small_float)
|
|
152
170
|
large_float = jnp.dtype(large_float)
|
|
153
171
|
y = jnp.asarray(y, small_float)
|
|
154
|
-
|
|
172
|
+
resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, 2 ** max_depth * num_trees)
|
|
173
|
+
sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
|
|
174
|
+
sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
|
|
155
175
|
|
|
156
176
|
bart = dict(
|
|
157
177
|
leaf_trees=make_forest(max_depth, small_float),
|
|
158
178
|
var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
|
|
159
179
|
split_trees=make_forest(max_depth - 1, max_split.dtype),
|
|
160
180
|
resid=jnp.asarray(y, large_float),
|
|
161
|
-
sigma2=
|
|
181
|
+
sigma2=sigma2,
|
|
162
182
|
grow_prop_count=jnp.zeros((), int),
|
|
163
183
|
grow_acc_count=jnp.zeros((), int),
|
|
164
184
|
prune_prop_count=jnp.zeros((), int),
|
|
165
185
|
prune_acc_count=jnp.zeros((), int),
|
|
166
186
|
p_nonterminal=p_nonterminal,
|
|
187
|
+
p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
|
|
167
188
|
sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
|
|
168
189
|
sigma2_beta=jnp.asarray(sigma2_beta, large_float),
|
|
169
190
|
max_split=jnp.asarray(max_split),
|
|
170
191
|
y=y,
|
|
171
192
|
X=jnp.asarray(X),
|
|
193
|
+
leaf_indices=jnp.ones((num_trees, y.size), jaxext.minimal_unsigned_dtype(2 ** max_depth - 1)),
|
|
172
194
|
min_points_per_leaf=(
|
|
173
195
|
None if min_points_per_leaf is None else
|
|
174
196
|
jnp.asarray(min_points_per_leaf)
|
|
@@ -178,37 +200,59 @@ def init(*,
|
|
|
178
200
|
make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
|
|
179
201
|
),
|
|
180
202
|
opt=jaxext.LeafDict(
|
|
181
|
-
suffstat_batch_size=suffstat_batch_size,
|
|
182
203
|
small_float=small_float,
|
|
183
204
|
large_float=large_float,
|
|
184
205
|
require_min_points=min_points_per_leaf is not None,
|
|
206
|
+
resid_batch_size=resid_batch_size,
|
|
207
|
+
count_batch_size=count_batch_size,
|
|
185
208
|
),
|
|
186
209
|
)
|
|
187
210
|
|
|
211
|
+
if save_ratios:
|
|
212
|
+
bart['ratios'] = dict(
|
|
213
|
+
log_trans_prior=jnp.full(num_trees, jnp.nan),
|
|
214
|
+
log_likelihood=jnp.full(num_trees, jnp.nan),
|
|
215
|
+
)
|
|
216
|
+
|
|
188
217
|
return bart
|
|
189
218
|
|
|
190
|
-
def _choose_suffstat_batch_size(
|
|
191
|
-
|
|
219
|
+
def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
|
|
220
|
+
|
|
221
|
+
@functools.cache
|
|
222
|
+
def get_platform():
|
|
192
223
|
try:
|
|
193
224
|
device = y.devices().pop()
|
|
194
225
|
except jax.errors.ConcretizationTypeError:
|
|
195
226
|
device = jax.devices()[0]
|
|
196
227
|
platform = device.platform
|
|
228
|
+
if platform not in ('cpu', 'gpu'):
|
|
229
|
+
raise KeyError(f'Unknown platform: {platform}')
|
|
230
|
+
return platform
|
|
197
231
|
|
|
232
|
+
if resid_batch_size == 'auto':
|
|
233
|
+
platform = get_platform()
|
|
234
|
+
n = max(1, y.size)
|
|
198
235
|
if platform == 'cpu':
|
|
199
|
-
|
|
200
|
-
# maybe I should batch residuals (not counts) for numerical
|
|
201
|
-
# accuracy, even if it's slower
|
|
236
|
+
resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
|
|
202
237
|
elif platform == 'gpu':
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
238
|
+
resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
|
|
239
|
+
resid_batch_size = max(1, resid_batch_size)
|
|
240
|
+
|
|
241
|
+
if count_batch_size == 'auto':
|
|
242
|
+
platform = get_platform()
|
|
243
|
+
if platform == 'cpu':
|
|
244
|
+
count_batch_size = None
|
|
245
|
+
elif platform == 'gpu':
|
|
246
|
+
n = max(1, y.size)
|
|
247
|
+
count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
|
|
248
|
+
# /4 is good on V100, /2 on L4/T4, still haven't tried A100
|
|
249
|
+
max_memory = 2 ** 29
|
|
250
|
+
itemsize = 4
|
|
251
|
+
min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
|
|
252
|
+
count_batch_size = max(count_batch_size, min_batch_size)
|
|
253
|
+
count_batch_size = max(1, count_batch_size)
|
|
254
|
+
|
|
255
|
+
return resid_batch_size, count_batch_size
|
|
212
256
|
|
|
213
257
|
def step(bart, key):
|
|
214
258
|
"""
|
|
@@ -248,14 +292,11 @@ def sample_trees(bart, key):
|
|
|
248
292
|
|
|
249
293
|
Notes
|
|
250
294
|
-----
|
|
251
|
-
This function zeroes the proposal counters
|
|
295
|
+
This function zeroes the proposal counters.
|
|
252
296
|
"""
|
|
253
|
-
bart = bart.copy()
|
|
254
297
|
key, subkey = random.split(key)
|
|
255
|
-
|
|
256
|
-
bart
|
|
257
|
-
grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
|
|
258
|
-
return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
|
|
298
|
+
moves = sample_moves(bart, subkey)
|
|
299
|
+
return accept_moves_and_sample_leaves(bart, moves, key)
|
|
259
300
|
|
|
260
301
|
def sample_moves(bart, key):
|
|
261
302
|
"""
|
|
@@ -270,21 +311,75 @@ def sample_moves(bart, key):
|
|
|
270
311
|
|
|
271
312
|
Returns
|
|
272
313
|
-------
|
|
273
|
-
|
|
274
|
-
|
|
314
|
+
moves : dict
|
|
315
|
+
A dictionary with fields:
|
|
316
|
+
|
|
317
|
+
'allowed' : bool array (num_trees,)
|
|
318
|
+
Whether the move is possible.
|
|
319
|
+
'grow' : bool array (num_trees,)
|
|
320
|
+
Whether the move is a grow move or a prune move.
|
|
321
|
+
'num_growable' : int array (num_trees,)
|
|
322
|
+
The number of growable leaves in the original tree.
|
|
323
|
+
'node' : int array (num_trees,)
|
|
324
|
+
The index of the leaf to grow or node to prune.
|
|
325
|
+
'left', 'right' : int array (num_trees,)
|
|
326
|
+
The indices of the children of 'node'.
|
|
327
|
+
'partial_ratio' : float array (num_trees,)
|
|
328
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
329
|
+
the likelihood ratio and the probability of proposing the prune
|
|
330
|
+
move. If the move is Prune, the ratio is inverted.
|
|
331
|
+
'grow_var' : int array (num_trees,)
|
|
332
|
+
The decision axes of the new rules.
|
|
333
|
+
'grow_split' : int array (num_trees,)
|
|
334
|
+
The decision boundaries of the new rules.
|
|
335
|
+
'var_trees' : int array (num_trees, 2 ** (d - 1))
|
|
336
|
+
The updated decision axes of the trees, valid whatever move.
|
|
337
|
+
'logu' : float array (num_trees,)
|
|
338
|
+
The logarithm of a uniform (0, 1] random variable to be used to
|
|
339
|
+
accept the move. It's in (-oo, 0].
|
|
275
340
|
"""
|
|
276
|
-
|
|
277
|
-
|
|
341
|
+
ntree = bart['leaf_trees'].shape[0]
|
|
342
|
+
key = random.split(key, 1 + ntree)
|
|
343
|
+
key, subkey = key[0], key[1:]
|
|
344
|
+
|
|
345
|
+
# compute moves
|
|
346
|
+
grow_moves, prune_moves = _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], subkey)
|
|
347
|
+
|
|
348
|
+
u, logu = random.uniform(key, (2, ntree), bart['opt']['large_float'])
|
|
278
349
|
|
|
279
|
-
|
|
280
|
-
|
|
350
|
+
# choose between grow or prune
|
|
351
|
+
grow_allowed = grow_moves['num_growable'].astype(bool)
|
|
352
|
+
p_grow = jnp.where(grow_allowed & prune_moves['allowed'], 0.5, grow_allowed)
|
|
353
|
+
grow = u < p_grow # use < instead of <= because u is in [0, 1)
|
|
354
|
+
|
|
355
|
+
# compute children indices
|
|
356
|
+
node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
|
|
357
|
+
left = node << 1
|
|
358
|
+
right = left + 1
|
|
359
|
+
|
|
360
|
+
return dict(
|
|
361
|
+
allowed=grow | prune_moves['allowed'],
|
|
362
|
+
grow=grow,
|
|
363
|
+
num_growable=grow_moves['num_growable'],
|
|
364
|
+
node=node,
|
|
365
|
+
left=left,
|
|
366
|
+
right=right,
|
|
367
|
+
partial_ratio=jnp.where(grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']),
|
|
368
|
+
grow_var=grow_moves['var'],
|
|
369
|
+
grow_split=grow_moves['split'],
|
|
370
|
+
var_trees=grow_moves['var_tree'],
|
|
371
|
+
logu=jnp.log1p(-logu),
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
|
|
375
|
+
def _sample_moves_vmap_trees(*args):
|
|
376
|
+
args, key = args[:-1], args[-1]
|
|
281
377
|
key, key1 = random.split(key)
|
|
282
|
-
args = var_tree, split_tree, affluence_tree, max_split, p_nonterminal
|
|
283
378
|
grow = grow_move(*args, key)
|
|
284
379
|
prune = prune_move(*args, key1)
|
|
285
380
|
return grow, prune
|
|
286
381
|
|
|
287
|
-
def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
382
|
+
def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
|
|
288
383
|
"""
|
|
289
384
|
Tree structure grow move proposal of BART MCMC.
|
|
290
385
|
|
|
@@ -304,6 +399,8 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
|
|
|
304
399
|
The maximum split index for each variable.
|
|
305
400
|
p_nonterminal : array (d,)
|
|
306
401
|
The probability of a nonterminal node at each depth.
|
|
402
|
+
p_propose_grow : array (2 ** (d - 1),)
|
|
403
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
307
404
|
key : jax.dtypes.prng_key array
|
|
308
405
|
A jax random key.
|
|
309
406
|
|
|
@@ -312,41 +409,42 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
|
|
|
312
409
|
grow_move : dict
|
|
313
410
|
A dictionary with fields:
|
|
314
411
|
|
|
315
|
-
'
|
|
316
|
-
|
|
412
|
+
'num_growable' : int
|
|
413
|
+
The number of growable leaves.
|
|
317
414
|
'node' : int
|
|
318
|
-
The index of the leaf to grow.
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
The new decision boundaries of the tree.
|
|
415
|
+
The index of the leaf to grow. ``2 ** d`` if there are no growable
|
|
416
|
+
leaves.
|
|
417
|
+
'var', 'split' : int
|
|
418
|
+
The decision axis and boundary of the new rule.
|
|
323
419
|
'partial_ratio' : float
|
|
324
420
|
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
325
421
|
the likelihood ratio and the probability of proposing the prune
|
|
326
422
|
move.
|
|
423
|
+
'var_tree' : array (2 ** (d - 1),)
|
|
424
|
+
The updated decision axes of the tree.
|
|
327
425
|
"""
|
|
328
426
|
|
|
329
427
|
key, key1, key2 = random.split(key, 3)
|
|
330
|
-
|
|
331
|
-
leaf_to_grow, num_growable,
|
|
428
|
+
|
|
429
|
+
leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(split_tree, affluence_tree, p_propose_grow, key)
|
|
332
430
|
|
|
333
431
|
var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
|
|
334
432
|
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
335
|
-
|
|
433
|
+
|
|
336
434
|
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
|
|
337
|
-
split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
|
|
338
435
|
|
|
339
|
-
ratio = compute_partial_ratio(
|
|
436
|
+
ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow)
|
|
340
437
|
|
|
341
438
|
return dict(
|
|
342
|
-
|
|
439
|
+
num_growable=num_growable,
|
|
343
440
|
node=leaf_to_grow,
|
|
441
|
+
var=var,
|
|
442
|
+
split=split,
|
|
344
443
|
partial_ratio=ratio,
|
|
345
444
|
var_tree=var_tree,
|
|
346
|
-
split_tree=split_tree,
|
|
347
445
|
)
|
|
348
446
|
|
|
349
|
-
def choose_leaf(split_tree, affluence_tree, key):
|
|
447
|
+
def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
|
|
350
448
|
"""
|
|
351
449
|
Choose a leaf node to grow in a tree.
|
|
352
450
|
|
|
@@ -356,6 +454,8 @@ def choose_leaf(split_tree, affluence_tree, key):
|
|
|
356
454
|
The splitting points of the tree.
|
|
357
455
|
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
358
456
|
Whether a leaf has enough points to be grown.
|
|
457
|
+
p_propose_grow : array (2 ** (d - 1),)
|
|
458
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
359
459
|
key : jax.dtypes.prng_key array
|
|
360
460
|
A jax random key.
|
|
361
461
|
|
|
@@ -366,19 +466,21 @@ def choose_leaf(split_tree, affluence_tree, key):
|
|
|
366
466
|
``2 ** d``.
|
|
367
467
|
num_growable : int
|
|
368
468
|
The number of leaf nodes that can be grown.
|
|
469
|
+
prob_choose : float
|
|
470
|
+
The normalized probability of choosing the selected leaf.
|
|
369
471
|
num_prunable : int
|
|
370
472
|
The number of leaf parents that could be pruned, after converting the
|
|
371
473
|
selected leaf to a non-terminal node.
|
|
372
|
-
allowed : bool
|
|
373
|
-
Whether the grow move is allowed.
|
|
374
474
|
"""
|
|
375
|
-
is_growable
|
|
376
|
-
leaf_to_grow = randint_masked(key, is_growable)
|
|
377
|
-
leaf_to_grow = jnp.where(allowed, leaf_to_grow, 2 * split_tree.size)
|
|
475
|
+
is_growable = growable_leaves(split_tree, affluence_tree)
|
|
378
476
|
num_growable = jnp.count_nonzero(is_growable)
|
|
477
|
+
distr = jnp.where(is_growable, p_propose_grow, 0)
|
|
478
|
+
leaf_to_grow, distr_norm = categorical(key, distr)
|
|
479
|
+
leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
|
|
480
|
+
prob_choose = distr[leaf_to_grow] / distr_norm
|
|
379
481
|
is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
|
|
380
482
|
num_prunable = jnp.count_nonzero(is_parent)
|
|
381
|
-
return leaf_to_grow, num_growable,
|
|
483
|
+
return leaf_to_grow, num_growable, prob_choose, num_prunable
|
|
382
484
|
|
|
383
485
|
def growable_leaves(split_tree, affluence_tree):
|
|
384
486
|
"""
|
|
@@ -397,34 +499,32 @@ def growable_leaves(split_tree, affluence_tree):
|
|
|
397
499
|
The mask indicating the leaf nodes that can be proposed to grow, i.e.,
|
|
398
500
|
that are not at the bottom level and have at least two times the number
|
|
399
501
|
of minimum points per leaf.
|
|
400
|
-
allowed : bool
|
|
401
|
-
Whether the grow move is allowed, i.e., there are growable leaves.
|
|
402
502
|
"""
|
|
403
503
|
is_growable = grove.is_actual_leaf(split_tree)
|
|
404
504
|
if affluence_tree is not None:
|
|
405
505
|
is_growable &= affluence_tree
|
|
406
|
-
return is_growable
|
|
506
|
+
return is_growable
|
|
407
507
|
|
|
408
|
-
def
|
|
508
|
+
def categorical(key, distr):
|
|
409
509
|
"""
|
|
410
|
-
Return a random integer
|
|
510
|
+
Return a random integer from an arbitrary distribution.
|
|
411
511
|
|
|
412
512
|
Parameters
|
|
413
513
|
----------
|
|
414
514
|
key : jax.dtypes.prng_key array
|
|
415
515
|
A jax random key.
|
|
416
|
-
|
|
417
|
-
|
|
516
|
+
distr : float array (n,)
|
|
517
|
+
An unnormalized probability distribution.
|
|
418
518
|
|
|
419
519
|
Returns
|
|
420
520
|
-------
|
|
421
521
|
u : int
|
|
422
|
-
A random integer in the range ``[0, n)
|
|
423
|
-
|
|
522
|
+
A random integer in the range ``[0, n)``. If all probabilities are zero,
|
|
523
|
+
return ``n``.
|
|
424
524
|
"""
|
|
425
|
-
ecdf = jnp.cumsum(
|
|
426
|
-
u = random.
|
|
427
|
-
return jnp.searchsorted(ecdf, u, 'right')
|
|
525
|
+
ecdf = jnp.cumsum(distr)
|
|
526
|
+
u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
|
|
527
|
+
return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
|
|
428
528
|
|
|
429
529
|
def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
|
|
430
530
|
"""
|
|
@@ -479,7 +579,7 @@ def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
|
|
|
479
579
|
filled with `p`. The fill values are not guaranteed to be placed in any
|
|
480
580
|
particular order. Variables may appear more than once.
|
|
481
581
|
"""
|
|
482
|
-
|
|
582
|
+
|
|
483
583
|
var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
|
|
484
584
|
split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
|
|
485
585
|
l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
|
|
@@ -611,7 +711,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
|
|
|
611
711
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
612
712
|
return random.randint(key, (), l, r)
|
|
613
713
|
|
|
614
|
-
def compute_partial_ratio(
|
|
714
|
+
def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
|
|
615
715
|
"""
|
|
616
716
|
Compute the product of the transition and prior ratios of a grow move.
|
|
617
717
|
|
|
@@ -626,8 +726,6 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
626
726
|
The probability of a nonterminal node at each depth.
|
|
627
727
|
leaf_to_grow : int
|
|
628
728
|
The index of the leaf to grow.
|
|
629
|
-
new_split_tree : array (2 ** (d - 1),)
|
|
630
|
-
The splitting points of the tree, after the leaf is grown.
|
|
631
729
|
|
|
632
730
|
Returns
|
|
633
731
|
-------
|
|
@@ -640,6 +738,9 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
640
738
|
# the two ratios also contain factors num_available_split *
|
|
641
739
|
# num_available_var, but they cancel out
|
|
642
740
|
|
|
741
|
+
# p_prune can't be computed here because it needs the count trees, which are
|
|
742
|
+
# computed in the acceptance phase
|
|
743
|
+
|
|
643
744
|
prune_allowed = leaf_to_grow != 1
|
|
644
745
|
# prune allowed <---> the initial tree is not a root
|
|
645
746
|
# leaf to grow is root --> the tree can only be a root
|
|
@@ -647,31 +748,33 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
647
748
|
|
|
648
749
|
p_grow = jnp.where(prune_allowed, 0.5, 1)
|
|
649
750
|
|
|
650
|
-
|
|
751
|
+
inv_trans_ratio = p_grow * prob_choose * num_prunable
|
|
651
752
|
|
|
652
|
-
depth = grove.tree_depths(
|
|
753
|
+
depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
|
|
653
754
|
p_parent = p_nonterminal[depth]
|
|
654
755
|
cp_children = 1 - p_nonterminal[depth + 1]
|
|
655
756
|
tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
|
|
656
757
|
|
|
657
|
-
return
|
|
758
|
+
return tree_ratio / inv_trans_ratio
|
|
658
759
|
|
|
659
|
-
def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
760
|
+
def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
|
|
660
761
|
"""
|
|
661
762
|
Tree structure prune move proposal of BART MCMC.
|
|
662
763
|
|
|
663
764
|
Parameters
|
|
664
765
|
----------
|
|
665
|
-
var_tree : array (2 ** (d - 1),)
|
|
766
|
+
var_tree : int array (2 ** (d - 1),)
|
|
666
767
|
The variable indices of the tree.
|
|
667
|
-
split_tree : array (2 ** (d - 1),)
|
|
768
|
+
split_tree : int array (2 ** (d - 1),)
|
|
668
769
|
The splitting points of the tree.
|
|
669
770
|
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
670
771
|
Whether a leaf has enough points to be grown.
|
|
671
|
-
max_split : array (p,)
|
|
772
|
+
max_split : int array (p,)
|
|
672
773
|
The maximum split index for each variable.
|
|
673
|
-
p_nonterminal : array (d,)
|
|
774
|
+
p_nonterminal : float array (d,)
|
|
674
775
|
The probability of a nonterminal node at each depth.
|
|
776
|
+
p_propose_grow : float array (2 ** (d - 1),)
|
|
777
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
675
778
|
key : jax.dtypes.prng_key array
|
|
676
779
|
A jax random key.
|
|
677
780
|
|
|
@@ -683,16 +786,16 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
|
|
|
683
786
|
'allowed' : bool
|
|
684
787
|
Whether the move is possible.
|
|
685
788
|
'node' : int
|
|
686
|
-
The index of the node to prune.
|
|
789
|
+
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
687
790
|
'partial_ratio' : float
|
|
688
791
|
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
689
792
|
the likelihood ratio and the probability of proposing the prune
|
|
690
793
|
move. This ratio is inverted.
|
|
691
794
|
"""
|
|
692
|
-
node_to_prune, num_prunable,
|
|
795
|
+
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
|
|
693
796
|
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
694
797
|
|
|
695
|
-
ratio = compute_partial_ratio(
|
|
798
|
+
ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune)
|
|
696
799
|
|
|
697
800
|
return dict(
|
|
698
801
|
allowed=allowed,
|
|
@@ -700,7 +803,7 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
|
|
|
700
803
|
partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
|
|
701
804
|
)
|
|
702
805
|
|
|
703
|
-
def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
806
|
+
def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
|
|
704
807
|
"""
|
|
705
808
|
Pick a non-terminal node with leaf children to prune in a tree.
|
|
706
809
|
|
|
@@ -710,6 +813,8 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
|
710
813
|
The splitting points of the tree.
|
|
711
814
|
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
712
815
|
Whether a leaf has enough points to be grown.
|
|
816
|
+
p_propose_grow : array (2 ** (d - 1),)
|
|
817
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
713
818
|
key : jax.dtypes.prng_key array
|
|
714
819
|
A jax random key.
|
|
715
820
|
|
|
@@ -717,28 +822,50 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
|
717
822
|
-------
|
|
718
823
|
node_to_prune : int
|
|
719
824
|
The index of the node to prune. If ``num_prunable == 0``, return
|
|
720
|
-
``
|
|
825
|
+
``2 ** d``.
|
|
721
826
|
num_prunable : int
|
|
722
827
|
The number of leaf parents that could be pruned.
|
|
723
|
-
|
|
724
|
-
The
|
|
725
|
-
node.
|
|
828
|
+
prob_choose : float
|
|
829
|
+
The normalized probability that a grow move on the pruned tree would choose the node to prune as the leaf to grow.
|
|
726
830
|
"""
|
|
727
831
|
is_prunable = grove.is_leaves_parent(split_tree)
|
|
728
|
-
node_to_prune = randint_masked(key, is_prunable)
|
|
729
832
|
num_prunable = jnp.count_nonzero(is_prunable)
|
|
833
|
+
node_to_prune = randint_masked(key, is_prunable)
|
|
834
|
+
node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
|
|
730
835
|
|
|
731
|
-
|
|
732
|
-
|
|
836
|
+
split_tree = split_tree.at[node_to_prune].set(0)
|
|
837
|
+
affluence_tree = (
|
|
733
838
|
None if affluence_tree is None else
|
|
734
839
|
affluence_tree.at[node_to_prune].set(True)
|
|
735
840
|
)
|
|
736
|
-
is_growable_leaf
|
|
737
|
-
|
|
841
|
+
is_growable_leaf = growable_leaves(split_tree, affluence_tree)
|
|
842
|
+
prob_choose = p_propose_grow[node_to_prune]
|
|
843
|
+
prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
|
|
738
844
|
|
|
739
|
-
return node_to_prune, num_prunable,
|
|
845
|
+
return node_to_prune, num_prunable, prob_choose
|
|
740
846
|
|
|
741
|
-
def
|
|
847
|
+
def randint_masked(key, mask):
|
|
848
|
+
"""
|
|
849
|
+
Return a random integer in a range, including only some values.
|
|
850
|
+
|
|
851
|
+
Parameters
|
|
852
|
+
----------
|
|
853
|
+
key : jax.dtypes.prng_key array
|
|
854
|
+
A jax random key.
|
|
855
|
+
mask : bool array (n,)
|
|
856
|
+
The mask indicating the allowed values.
|
|
857
|
+
|
|
858
|
+
Returns
|
|
859
|
+
-------
|
|
860
|
+
u : int
|
|
861
|
+
A random integer in the range ``[0, n)``, and which satisfies
|
|
862
|
+
``mask[u] == True``. If all values in the mask are `False`, return `n`.
|
|
863
|
+
"""
|
|
864
|
+
ecdf = jnp.cumsum(mask)
|
|
865
|
+
u = random.randint(key, (), 0, ecdf[-1])
|
|
866
|
+
return jnp.searchsorted(ecdf, u, 'right')
|
|
867
|
+
|
|
868
|
+
def accept_moves_and_sample_leaves(bart, moves, key):
|
|
742
869
|
"""
|
|
743
870
|
Accept or reject the proposed moves and sample the new leaf values.
|
|
744
871
|
|
|
@@ -746,14 +873,8 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
|
|
|
746
873
|
----------
|
|
747
874
|
bart : dict
|
|
748
875
|
A BART mcmc state.
|
|
749
|
-
|
|
750
|
-
The
|
|
751
|
-
`grow_move`.
|
|
752
|
-
prune_moves : dict
|
|
753
|
-
The proposals for prune moves, batched over the first axis. See
|
|
754
|
-
`prune_move`.
|
|
755
|
-
grow_leaf_indices : int array (num_trees, n)
|
|
756
|
-
The leaf indices of the trees proposed by the grow move.
|
|
876
|
+
moves : dict
|
|
877
|
+
The proposed moves, see `sample_moves`.
|
|
757
878
|
key : jax.dtypes.prng_key array
|
|
758
879
|
A jax random key.
|
|
759
880
|
|
|
@@ -762,41 +883,409 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
|
|
|
762
883
|
bart : dict
|
|
763
884
|
The new BART mcmc state.
|
|
764
885
|
"""
|
|
886
|
+
bart, moves, count_trees, move_counts, prelkv, prelk, prelf = accept_moves_parallel_stage(bart, moves, key)
|
|
887
|
+
bart, moves = accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf)
|
|
888
|
+
return accept_moves_final_stage(bart, moves)
|
|
889
|
+
|
|
890
|
+
def accept_moves_parallel_stage(bart, moves, key):
|
|
891
|
+
"""
|
|
892
|
+
Pre-computes quantities used to accept moves, in parallel across trees.
|
|
893
|
+
|
|
894
|
+
Parameters
|
|
895
|
+
----------
|
|
896
|
+
bart : dict
|
|
897
|
+
A BART mcmc state.
|
|
898
|
+
moves : dict
|
|
899
|
+
The proposed moves, see `sample_moves`.
|
|
900
|
+
key : jax.dtypes.prng_key array
|
|
901
|
+
A jax random key.
|
|
902
|
+
|
|
903
|
+
Returns
|
|
904
|
+
-------
|
|
905
|
+
bart : dict
|
|
906
|
+
A partially updated BART mcmc state.
|
|
907
|
+
moves : dict
|
|
908
|
+
The proposed moves, with the field 'partial_ratio' replaced
|
|
909
|
+
by 'log_trans_prior_ratio'.
|
|
910
|
+
count_trees : array (num_trees, 2 ** d)
|
|
911
|
+
The number of points in each potential or actual leaf node.
|
|
912
|
+
move_counts : dict
|
|
913
|
+
The counts of the number of points in the nodes modified by the
|
|
914
|
+
moves.
|
|
915
|
+
prelkv, prelk, prelf : dict
|
|
916
|
+
Dictionary with pre-computed terms of the likelihood ratios and leaf
|
|
917
|
+
samples.
|
|
918
|
+
"""
|
|
919
|
+
bart = bart.copy()
|
|
920
|
+
|
|
921
|
+
# where the move is grow, modify the state like the move was accepted
|
|
922
|
+
bart['var_trees'] = moves['var_trees']
|
|
923
|
+
bart['leaf_indices'] = apply_grow_to_indices(moves, bart['leaf_indices'], bart['X'])
|
|
924
|
+
bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], moves)
|
|
925
|
+
|
|
926
|
+
# count number of datapoints per leaf
|
|
927
|
+
count_trees, move_counts = compute_count_trees(bart['leaf_indices'], moves, bart['opt']['count_batch_size'])
|
|
928
|
+
if bart['opt']['require_min_points']:
|
|
929
|
+
count_half_trees = count_trees[:, :bart['var_trees'].shape[1]]
|
|
930
|
+
bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
|
|
931
|
+
|
|
932
|
+
# compute some missing information about moves
|
|
933
|
+
moves = complete_ratio(moves, move_counts, bart['min_points_per_leaf'])
|
|
934
|
+
bart['grow_prop_count'] = jnp.sum(moves['grow'])
|
|
935
|
+
bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~moves['grow'])
|
|
936
|
+
|
|
937
|
+
prelkv, prelk = precompute_likelihood_terms(count_trees, bart['sigma2'], move_counts)
|
|
938
|
+
prelf = precompute_leaf_terms(count_trees, bart['sigma2'], key)
|
|
939
|
+
|
|
940
|
+
return bart, moves, count_trees, move_counts, prelkv, prelk, prelf
|
|
941
|
+
|
|
942
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(moves, leaf_indices, X):
    """
    Reassign datapoints to the children of a leaf as if it had been grown.

    Parameters
    ----------
    moves : dict
        The proposed moves, see `sample_moves`.
    leaf_indices : array (num_trees, n)
        The index of the leaf each datapoint falls into.
    X : array (p, n)
        The predictors matrix.

    Returns
    -------
    grow_leaf_indices : array (num_trees, n)
        The updated leaf indices.
    """
    index_dtype = leaf_indices.dtype

    # Where the move is not a grow, compare against a sentinel equal to twice
    # the size of the per-tree `var_trees` array, which is presumably never a
    # valid leaf index, so that no datapoint is reassigned.
    sentinel = jnp.array(2 * moves['var_trees'].size)
    target_leaf = jnp.where(moves['grow'], moves['node'], sentinel)
    in_target_leaf = leaf_indices == target_leaf

    # The children of node i are stored at 2i (left) and 2i + 1 (right); a
    # datapoint goes right when its predictor is at or above the split.
    goes_right = X[moves['grow_var'], :] >= moves['grow_split']
    child_index = (moves['node'].astype(index_dtype) << 1) + goes_right

    return jnp.where(in_target_leaf, child_index, leaf_indices)
|
|
970
|
+
|
|
971
|
+
def compute_count_trees(leaf_indices, moves, batch_size):
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into, if the grow move is
        accepted.
    moves : dict
        The proposed moves, see `sample_moves`.
    batch_size : int or None
        The data batch size to use for the summation.

    Returns
    -------
    count_trees : int array (num_trees, tree_size)
        The number of points in each potential or actual leaf node, where
        `tree_size` is twice the second dimension of ``moves['var_trees']``.
        The entry of the node targeted by each move is overwritten with the
        total count of its two children.
    counts : dict
        The counts of the number of points in the nodes modified by the
        moves, with fields 'left', 'right', and 'total', each with one entry
        per tree.
    """
    ntree, half_tree_size = moves['var_trees'].shape
    tree_size = 2 * half_tree_size
    tree_indices = jnp.arange(ntree)

    count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)

    # count datapoints in the children involved in each move
    left = count_trees[tree_indices, moves['left']]
    right = count_trees[tree_indices, moves['right']]
    counts = {'left': left, 'right': right, 'total': left + right}

    # the parent is not a leaf in `leaf_indices`, so its count is filled in
    # from its children
    count_trees = count_trees.at[tree_indices, moves['node']].set(counts['total'])

    return count_trees, counts
|
|
1011
|
+
|
|
1012
|
+
def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
    """
    Count how many datapoints fall into each leaf of each tree.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into.
    tree_size : int
        The size of the leaf tree array (2 ** d).
    batch_size : int or None
        The data batch size to use for the summation. If None, the trees are
        processed one at a time without batching.

    Returns
    -------
    count_trees : int array (num_trees, tree_size)
        The number of points in each leaf node.
    """
    if batch_size is not None:
        return _count_vec(leaf_indices, tree_size, batch_size)
    return _count_scan(leaf_indices, tree_size)
|
|
1034
|
+
|
|
1035
|
+
def _count_scan(leaf_indices, tree_size):
    """Count datapoints per leaf by scanning over the trees one at a time."""

    def count_one_tree(carry, tree_leaf_indices):
        # no state is carried between trees; only the stacked output matters
        tree_counts = _aggregate_scatter(1, tree_leaf_indices, tree_size, jnp.uint32)
        return carry, tree_counts

    _, count_trees = lax.scan(count_one_tree, None, leaf_indices)
    return count_trees
|
|
1040
|
+
|
|
1041
|
+
def _aggregate_scatter(values, indices, size, dtype):
|
|
1042
|
+
return (jnp
|
|
1043
|
+
.zeros(size, dtype)
|
|
1044
|
+
.at[indices]
|
|
1045
|
+
.add(values)
|
|
1046
|
+
)
|
|
1047
|
+
|
|
1048
|
+
def _count_vec(leaf_indices, tree_size, batch_size):
    """Count datapoints per leaf for all trees at once, batching over data."""
    # uint16 is super-slow on gpu, don't use it even if n < 2^16
    counter_dtype = jnp.uint32
    return _aggregate_batched_alltrees(1, leaf_indices, tree_size, counter_dtype, batch_size)
|
|
1051
|
+
|
|
1052
|
+
def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
|
|
1053
|
+
ntree, n = indices.shape
|
|
1054
|
+
tree_indices = jnp.arange(ntree)
|
|
1055
|
+
nbatches = n // batch_size + bool(n % batch_size)
|
|
1056
|
+
batch_indices = jnp.arange(n) % nbatches
|
|
1057
|
+
return (jnp
|
|
1058
|
+
.zeros((ntree, size, nbatches), dtype)
|
|
1059
|
+
.at[tree_indices[:, None], indices, batch_indices]
|
|
1060
|
+
.add(values)
|
|
1061
|
+
.sum(axis=2)
|
|
1062
|
+
)
|
|
1063
|
+
|
|
1064
|
+
def complete_ratio(moves, move_counts, min_points_per_leaf):
    """
    Complete the non-likelihood part of the MH ratio calculation.

    This function multiplies the partial ratio by the probability of choosing
    the prune move and takes the logarithm.

    Parameters
    ----------
    moves : dict
        The proposed moves, see `sample_moves`.
    move_counts : dict
        The counts of the number of points in the nodes modified by the
        moves.
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node.

    Returns
    -------
    moves : dict
        The updated moves, with the field 'partial_ratio' replaced by
        'log_trans_prior_ratio'.
    """
    p_prune = compute_p_prune(
        moves,
        move_counts['left'],
        move_counts['right'],
        min_points_per_leaf,
    )
    # build a new dict so that the input is left untouched
    completed = {name: value for name, value in moves.items() if name != 'partial_ratio'}
    completed['log_trans_prior_ratio'] = jnp.log(moves['partial_ratio'] * p_prune)
    return completed
|
|
1090
|
+
|
|
1091
|
+
def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
    """
    Compute the probability of proposing a prune move.

    Parameters
    ----------
    moves : dict
        The proposed moves, see `sample_moves`.
    left_count, right_count : int
        The number of datapoints in the proposed children of the leaf to grow.
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node.

    Returns
    -------
    p_prune : float
        The probability of proposing a prune move. If grow: after accepting the
        grow move, if prune: right away.
    """
    num_growable = moves['num_growable']

    # Case 1: the move is a grow. After growing, a further grow is possible if
    # some other leaf was growable, or if the new children are growable.
    has_other_growable = num_growable >= 2
    children_growable = moves['node'] < moves['var_trees'].shape[1] // 2
    if min_points_per_leaf is not None:
        threshold = 2 * min_points_per_leaf
        enough_points = (left_count >= threshold) | (right_count >= threshold)
        children_growable = children_growable & enough_points
    can_grow_after = has_other_growable | children_growable
    p_if_grow = jnp.where(can_grow_after, 0.5, 1)

    # Case 2: the move is a prune. Prune competes with grow only if there is
    # at least one growable leaf.
    p_if_prune = jnp.where(num_growable, 0.5, 1)

    return jnp.where(moves['grow'], p_if_grow, p_if_prune)
|
|
1125
|
+
|
|
1126
|
+
@jaxext.vmap_nodoc
def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
    """
    Modify leaf values such that the indices of the grow moves work on the
    original tree.

    Parameters
    ----------
    leaf_trees : float array (num_trees, 2 ** d)
        The leaf values.
    moves : dict
        The proposed moves, see `sample_moves`.

    Returns
    -------
    leaf_trees : float array (num_trees, 2 ** d)
        The modified leaf values. The value of the leaf to grow is copied to
        what would be its children if the grow move was accepted.
    """
    parent_value = leaf_trees[moves['node']]

    # For non-grow moves the target index is one past the end of the array;
    # this relies on JAX skipping out-of-bounds scatter updates.
    past_the_end = leaf_trees.size
    left_target = jnp.where(moves['grow'], moves['left'], past_the_end)
    right_target = jnp.where(moves['grow'], moves['right'], past_the_end)

    with_left = leaf_trees.at[left_target].set(parent_value)
    return with_left.at[right_target].set(parent_value)
|
|
1152
|
+
|
|
1153
|
+
def precompute_likelihood_terms(count_trees, sigma2, move_counts):
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Parameters
    ----------
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node. Only its
        length (the number of trees) is used.
    sigma2 : float
        The noise variance.
    move_counts : dict
        The counts of the number of points in the nodes modified by the
        moves.

    Returns
    -------
    prelkv : dict
        Dictionary with pre-computed terms of the likelihood ratio, one per
        tree.
    prelk : dict
        Dictionary with pre-computed terms of the likelihood ratio, shared by
        all trees.
    """
    sigma_mu2 = 1 / len(count_trees)

    def scaled_variance(count):
        return sigma2 + count * sigma_mu2

    sigma2_left = scaled_variance(move_counts['left'])
    sigma2_right = scaled_variance(move_counts['right'])
    sigma2_total = scaled_variance(move_counts['total'])

    numerator = sigma2 * sigma2_total
    denominator = sigma2_left * sigma2_right

    prelkv = {
        'sigma2_left': sigma2_left,
        'sigma2_right': sigma2_right,
        'sigma2_total': sigma2_total,
        # logarithm of the square root of numerator / denominator
        'sqrt_term': jnp.log(numerator / denominator) / 2,
    }
    prelk = {'exp_factor': sigma_mu2 / (2 * sigma2)}
    return prelkv, prelk
|
|
1189
|
+
|
|
1190
|
+
def precompute_leaf_terms(count_trees, sigma2, key):
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node.
    sigma2 : float
        The noise variance.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    prelf : dict
        Dictionary with pre-computed terms of the leaf sampling, with fields:

        'mean_factor' : float array (num_trees, 2 ** d)
            The factor to be multiplied by the sum of residuals to obtain the
            posterior mean.
        'centered_leaves' : float array (num_trees, 2 ** d)
            The mean-zero normal values to be added to the posterior mean to
            obtain the posterior leaf samples.
    """
    # the number of trees plays the role of the prior precision
    prior_precision = len(count_trees)
    likelihood_precision = count_trees / sigma2
    posterior_variance = lax.reciprocal(likelihood_precision + prior_precision)

    standard_normals = random.normal(key, count_trees.shape, sigma2.dtype)

    return {
        # = mean_lk * prec_lk * var_post / resid_tree
        'mean_factor': posterior_variance / sigma2,
        'centered_leaves': standard_normals * jnp.sqrt(posterior_variance),
    }
|
|
1223
|
+
|
|
1224
|
+
def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf):
    """
    The part of accepting the moves that has to be done one tree at a time.

    Parameters
    ----------
    bart : dict
        A partially updated BART mcmc state.
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node.
    moves : dict
        The proposed moves, see `sample_moves`.
    move_counts : dict
        The counts of the number of points in the nodes modified by the
        moves.
    prelkv, prelk, prelf : dict
        Dictionaries with pre-computed terms of the likelihood ratios and leaf
        samples.

    Returns
    -------
    bart : dict
        A partially updated BART mcmc state. The fields 'resid' and
        'leaf_trees' are replaced; if the input state has a 'ratios' field,
        it is replaced by a new dict merged with the ratios of this step.
    moves : dict
        The proposed moves, with these additional fields:

        'acc' : bool array (num_trees,)
            Whether the move was accepted.
        'to_prune' : bool array (num_trees,)
            Whether, to reflect the acceptance status of the move, the state
            should be updated by pruning the leaves involved in the move.
    """
    bart = bart.copy()
    moves = moves.copy()
    save_ratios = 'ratios' in bart

    def loop(resid, item):
        resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
            bart['X'],
            len(bart['leaf_trees']),
            bart['opt']['resid_batch_size'],
            resid,
            bart['min_points_per_leaf'],
            save_ratios,
            prelk,
            *item,
        )
        return resid, (leaf_tree, acc, to_prune, ratios)

    # per-tree inputs, in the positional order expected after `prelk` by
    # `accept_move_and_sample_leaves`
    items = (
        bart['leaf_trees'], count_trees,
        moves, move_counts,
        bart['leaf_indices'],
        prelkv, prelf,
    )
    # the residuals are the carry: each tree sees the residuals as updated by
    # the trees before it
    resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)

    bart['resid'] = resid
    bart['leaf_trees'] = leaf_trees
    if save_ratios:
        # `bart` is only a shallow copy: build a new dict instead of updating
        # in place, so that the caller's state is not mutated
        bart['ratios'] = {**bart['ratios'], **ratios}
    moves['acc'] = acc
    moves['to_prune'] = to_prune

    return bart, moves
|
|
798
1287
|
|
|
799
|
-
def accept_move_and_sample_leaves(X, ntree,
|
|
1288
|
+
def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_per_leaf, save_ratios, prelk, leaf_tree, count_tree, move, move_counts, leaf_indices, prelkv, prelf):
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    X : array
        The predictors. Not used in the body of this function.
    ntree : int
        The number of trees in the forest. Not used in the body of this
        function.
    resid_batch_size : int, None
        The batch size for computing the sum of residuals in each leaf.
    resid : float array (n,)
        The residuals (data minus forest value).
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node.
    save_ratios : bool
        Whether to save the acceptance ratios.
    prelk : dict
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    leaf_tree : float array (2 ** d,)
        The leaf values of the tree.
    count_tree : int array (2 ** d,)
        The number of datapoints in each leaf.
    move : dict
        The proposed move, see `sample_moves`.
    move_counts : dict
        The counts of the number of points in the nodes modified by the move.
    leaf_indices : int array (n,)
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv, prelf : dict
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.

    Returns
    -------
    resid : float array (n,)
        The updated residuals (data minus forest value).
    leaf_tree : float array (2 ** d,)
        The new leaf values of the tree.
    acc : bool
        Whether the move was accepted.
    to_prune : bool
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    ratios : dict
        The acceptance ratios for the moves. Empty if not to be saved.
    """
    tree_size = leaf_tree.size
    old_leaf_tree = leaf_tree

    # sum residuals per leaf, in the tree proposed by the grow move
    resid_sums = sum_resid(resid, leaf_indices, tree_size, resid_batch_size)

    # subtract starting tree from function: since resid is data minus forest,
    # adding back count * leaf value removes this tree's contribution
    resid_sums = resid_sums + count_tree * leaf_tree

    # indices of the nodes involved in the move
    node = move['node']
    assert node.dtype == jnp.int32
    left, right = move['left'], move['right']

    # residual sums in the children, and in the parent as their total
    sum_left = resid_sums[left]
    sum_right = resid_sums[right]
    sum_total = sum_left + sum_right
    resid_sums = resid_sums.at[node].set(sum_total)

    # log acceptance ratio; it is computed for a grow move, and negated if the
    # move is a prune
    log_lk_ratio = compute_likelihood_ratio(sum_total, sum_left, sum_right, prelkv, prelk)
    log_grow_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
    log_ratio = jnp.where(move['grow'], log_grow_ratio, -log_grow_ratio)
    ratios = {}
    if save_ratios:
        ratios['log_trans_prior'] = move['log_trans_prior_ratio']
        ratios['log_likelihood'] = log_lk_ratio

    # determine whether to accept the move
    acc = move['allowed'] & (move['logu'] <= log_ratio)
    if min_points_per_leaf is not None:
        acc = acc & (move_counts['left'] >= min_points_per_leaf)
        acc = acc & (move_counts['right'] >= min_points_per_leaf)

    # posterior mean plus pre-sampled centered noise gives the new leaves
    new_leaf_tree = resid_sums * prelf['mean_factor'] + prelf['centered_leaves']

    # copy leaves around such that the leaf indices select the right leaf: if
    # the tree must end up pruned (rejected grow or accepted prune), the
    # children take the value of the parent. The out-of-range target index
    # relies on JAX skipping out-of-bounds scatter updates.
    to_prune = acc ^ move['grow']
    parent_value = new_leaf_tree[node]
    left_target = jnp.where(to_prune, left, tree_size)
    right_target = jnp.where(to_prune, right, tree_size)
    new_leaf_tree = (new_leaf_tree
        .at[left_target]
        .set(parent_value)
        .at[right_target]
        .set(parent_value)
    )

    # replace old tree with new tree in function values
    resid = resid + (old_leaf_tree - new_leaf_tree)[leaf_indices]

    return resid, new_leaf_tree, acc, to_prune, ratios
|
|
957
1390
|
|
|
958
|
-
def
|
|
1391
|
+
def sum_resid(resid, leaf_indices, tree_size, batch_size):
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    resid : float array (n,)
        The residuals to sum.
    leaf_indices : int array (n,)
        The index of the leaf each datapoint falls into.
    tree_size : int
        The size of the tree array (2 ** d).
    batch_size : int, None
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism.

    Returns
    -------
    resid_tree : float array (2 ** d,)
        The sum of the residuals at data points in each leaf.
    """
    # the sums are always accumulated in float32
    if batch_size is not None:
        return _aggregate_batched_onetree(
            resid, leaf_indices, tree_size, jnp.float32, batch_size=batch_size)
    return _aggregate_scatter(resid, leaf_indices, tree_size, jnp.float32)
|
|
988
1417
|
|
|
989
|
-
def
|
|
1418
|
+
def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
|
|
1419
|
+
n, = indices.shape
|
|
1420
|
+
nbatches = n // batch_size + bool(n % batch_size)
|
|
1421
|
+
batch_indices = jnp.arange(n) % nbatches
|
|
990
1422
|
return (jnp
|
|
991
|
-
.zeros(size, dtype)
|
|
992
|
-
.at[indices]
|
|
1423
|
+
.zeros((size, nbatches), dtype)
|
|
1424
|
+
.at[indices, batch_indices]
|
|
993
1425
|
.add(values)
|
|
1426
|
+
.sum(axis=1)
|
|
994
1427
|
)
|
|
995
1428
|
|
|
996
|
-
def
|
|
997
|
-
nbatches = indices.size // batch_size + bool(indices.size % batch_size)
|
|
998
|
-
batch_indices = jnp.arange(indices.size) // batch_size
|
|
999
|
-
return (jnp
|
|
1000
|
-
.zeros((nbatches, size), dtype)
|
|
1001
|
-
.at[batch_indices, indices]
|
|
1002
|
-
.add(values)
|
|
1003
|
-
.sum(axis=0)
|
|
1004
|
-
)
|
|
1005
|
-
|
|
1006
|
-
def compute_p_prune_back(new_split_tree, new_affluence_tree):
|
|
1429
|
+
def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
    """
    Compute the logarithm of the likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid : float
        The sum of the residuals in the leaf to grow.
    left_resid, right_resid : float
        The sum of the residuals in the left/right child of the leaf to grow.
    prelkv, prelk : dict
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    ratio : float
        The logarithm of the likelihood ratio
        P(data | new tree) / P(data | old tree).
    """
    left_term = left_resid * left_resid / prelkv['sigma2_left']
    right_term = right_resid * right_resid / prelkv['sigma2_right']
    total_term = total_resid * total_resid / prelkv['sigma2_total']
    exp_term = prelk['exp_factor'] * (left_term + right_term - total_term)
    # 'sqrt_term' is already a logarithm, so the two terms are added
    return prelkv['sqrt_term'] + exp_term
|
|
1026
1454
|
|
|
1027
|
-
def
|
|
1455
|
+
def accept_moves_final_stage(bart, moves):
    """
    The final part of accepting the moves, in parallel across trees.

    Parameters
    ----------
    bart : dict
        A partially updated BART mcmc state.
    moves : dict
        The proposed moves (see `sample_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    bart : dict
        The fully updated BART mcmc state. The fields 'grow_acc_count',
        'prune_acc_count', 'leaf_indices' and 'split_trees' are replaced.
    """
    bart = bart.copy()
    accepted = moves['acc']
    is_grow = moves['grow']

    # count accepted moves, separately by kind
    bart['grow_acc_count'] = jnp.sum(accepted & is_grow)
    bart['prune_acc_count'] = jnp.sum(accepted & ~is_grow)

    # make the tree representation reflect the outcome of the moves
    bart['leaf_indices'] = apply_moves_to_leaf_indices(bart['leaf_indices'], moves)
    bart['split_trees'] = apply_moves_to_split_trees(bart['split_trees'], moves)
    return bart
|
|
1057
1480
|
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
right_resid = resid_tree[right_child]
|
|
1063
|
-
total_resid = left_resid + right_resid
|
|
1064
|
-
|
|
1065
|
-
left_count = count_tree[left_child]
|
|
1066
|
-
right_count = count_tree[right_child]
|
|
1067
|
-
total_count = left_count + right_count
|
|
1068
|
-
|
|
1069
|
-
sigma_mu2 = 1 / n_tree
|
|
1070
|
-
sigma2_left = sigma2 + left_count * sigma_mu2
|
|
1071
|
-
sigma2_right = sigma2 + right_count * sigma_mu2
|
|
1072
|
-
sigma2_total = sigma2 + total_count * sigma_mu2
|
|
1481
|
+
@jax.vmap
def apply_moves_to_leaf_indices(leaf_indices, moves):
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves : dict
        The proposed moves (see `sample_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    leaf_indices : int array (num_trees, n)
        The updated leaf indices.
    """
    index_dtype = leaf_indices.dtype

    # Clearing the lowest bit maps both children (2i and 2i + 1) to the left
    # child index 2i, so one comparison detects either child.
    clear_lowest_bit = jnp.bitwise_not(jnp.array(1, index_dtype))  # ...1111111110
    in_moved_children = (leaf_indices & clear_lowest_bit) == moves['left']

    # where the tree must end up pruned, send those datapoints to the parent
    send_to_parent = in_moved_children & moves['to_prune']
    parent = moves['node'].astype(index_dtype)
    return jnp.where(send_to_parent, parent, leaf_indices)
|
|
1081
1507
|
|
|
1082
|
-
|
|
1508
|
+
@jax.vmap
def apply_moves_to_split_trees(split_trees, moves):
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_trees : int array (num_trees, 2 ** (d - 1))
        The cutpoints of the decision nodes in the initial trees.
    moves : dict
        The proposed moves (see `sample_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    split_trees : int array (num_trees, 2 ** (d - 1))
        The updated split trees.
    """
    # An index one past the end turns the corresponding update into a no-op;
    # this relies on JAX skipping out-of-bounds scatter updates.
    past_the_end = split_trees.size
    grown_node = jnp.where(moves['grow'], moves['node'], past_the_end)
    pruned_node = jnp.where(moves['to_prune'], moves['node'], past_the_end)

    # First write the cutpoint of every grow move, then reset to 0 the nodes
    # to prune; the order matters because a rejected grow has both flags set.
    new_cutpoint = moves['grow_split'].astype(split_trees.dtype)
    with_grow = split_trees.at[grown_node].set(new_cutpoint)
    return with_grow.at[pruned_node].set(0)
|
|
1089
1540
|
|
|
1090
1541
|
def sample_sigma(bart, key):
|
|
1091
1542
|
"""
|
|
@@ -1107,7 +1558,7 @@ def sample_sigma(bart, key):
|
|
|
1107
1558
|
|
|
1108
1559
|
resid = bart['resid']
|
|
1109
1560
|
alpha = bart['sigma2_alpha'] + resid.size / 2
|
|
1110
|
-
norm2 = jnp.dot(resid, resid, preferred_element_type=bart['
|
|
1561
|
+
norm2 = jnp.dot(resid, resid, preferred_element_type=bart['opt']['large_float'])
|
|
1111
1562
|
beta = bart['sigma2_beta'] + norm2 / 2
|
|
1112
1563
|
|
|
1113
1564
|
sample = random.gamma(key, alpha)
|