bartz 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +43 -18
- bartz/_version.py +1 -1
- bartz/grove.py +19 -14
- bartz/jaxext.py +48 -21
- bartz/mcmcloop.py +13 -15
- bartz/mcmcstep.py +687 -297
- bartz/prepcovars.py +43 -13
- bartz-0.3.0.dist-info/METADATA +77 -0
- bartz-0.3.0.dist-info/RECORD +13 -0
- bartz-0.2.0.dist-info/METADATA +0 -32
- bartz-0.2.0.dist-info/RECORD +0 -13
- {bartz-0.2.0.dist-info → bartz-0.3.0.dist-info}/LICENSE +0 -0
- {bartz-0.2.0.dist-info → bartz-0.3.0.dist-info}/WHEEL +0 -0
bartz/mcmcstep.py
CHANGED
|
@@ -10,10 +10,10 @@
|
|
|
10
10
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
11
|
# copies of the Software, and to permit persons to whom the Software is
|
|
12
12
|
# furnished to do so, subject to the following conditions:
|
|
13
|
-
#
|
|
13
|
+
#
|
|
14
14
|
# The above copyright notice and this permission notice shall be included in all
|
|
15
15
|
# copies or substantial portions of the Software.
|
|
16
|
-
#
|
|
16
|
+
#
|
|
17
17
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
18
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
19
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
@@ -34,6 +34,7 @@ range of possible values.
|
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
36
|
import functools
|
|
37
|
+
import math
|
|
37
38
|
|
|
38
39
|
import jax
|
|
39
40
|
from jax import random
|
|
@@ -54,7 +55,9 @@ def init(*,
|
|
|
54
55
|
small_float=jnp.float32,
|
|
55
56
|
large_float=jnp.float32,
|
|
56
57
|
min_points_per_leaf=None,
|
|
57
|
-
|
|
58
|
+
resid_batch_size='auto',
|
|
59
|
+
count_batch_size='auto',
|
|
60
|
+
save_ratios=False,
|
|
58
61
|
):
|
|
59
62
|
"""
|
|
60
63
|
Make a BART posterior sampling MCMC initial state.
|
|
@@ -82,9 +85,13 @@ def init(*,
|
|
|
82
85
|
The dtype for scalars, small arrays, and arrays which require accuracy.
|
|
83
86
|
min_points_per_leaf : int, optional
|
|
84
87
|
The minimum number of data points in a leaf node. 0 if not specified.
|
|
85
|
-
|
|
86
|
-
The batch
|
|
87
|
-
|
|
88
|
+
resid_batch_size, count_batch_sizes : int, None, str, default 'auto'
|
|
89
|
+
The batch sizes, along datapoints, for summing the residuals and
|
|
90
|
+
counting the number of datapoints in each leaf. `None` for no batching.
|
|
91
|
+
If 'auto', pick a value based on the device of `y`, or the default
|
|
92
|
+
device.
|
|
93
|
+
save_ratios : bool, default False
|
|
94
|
+
Whether to save the Metropolis-Hastings ratios.
|
|
88
95
|
|
|
89
96
|
Returns
|
|
90
97
|
-------
|
|
@@ -110,6 +117,8 @@ def init(*,
|
|
|
110
117
|
'p_nonterminal' : large_float array (d,)
|
|
111
118
|
The probability of a nonterminal node at each depth, padded with a
|
|
112
119
|
zero.
|
|
120
|
+
'p_propose_grow' : large_float array (2 ** (d - 1),)
|
|
121
|
+
The unnormalized probability of picking a leaf for a grow proposal.
|
|
113
122
|
'sigma2_alpha' : large_float
|
|
114
123
|
The shape parameter of the inverse gamma prior on the noise variance.
|
|
115
124
|
'sigma2_beta' : large_float
|
|
@@ -120,6 +129,8 @@ def init(*,
|
|
|
120
129
|
The response.
|
|
121
130
|
'X' : int array (p, n)
|
|
122
131
|
The predictors.
|
|
132
|
+
'leaf_indices' : int array (num_trees, n)
|
|
133
|
+
The index of the leaf each datapoints falls into, for each tree.
|
|
123
134
|
'min_points_per_leaf' : int or None
|
|
124
135
|
The minimum number of data points in a leaf node.
|
|
125
136
|
'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
|
|
@@ -128,8 +139,6 @@ def init(*,
|
|
|
128
139
|
'opt' : LeafDict
|
|
129
140
|
A dictionary with config values:
|
|
130
141
|
|
|
131
|
-
'suffstat_batch_size' : int or None
|
|
132
|
-
The batch size for computing sufficient statistics.
|
|
133
142
|
'small_float' : dtype
|
|
134
143
|
The dtype for large arrays used in the algorithm.
|
|
135
144
|
'large_float' : dtype
|
|
@@ -137,6 +146,8 @@ def init(*,
|
|
|
137
146
|
accuracy.
|
|
138
147
|
'require_min_points' : bool
|
|
139
148
|
Whether the `min_points_per_leaf` parameter is specified.
|
|
149
|
+
'resid_batch_size', 'count_batch_size' : int or None
|
|
150
|
+
The data batch sizes for computing the sufficient statistics.
|
|
140
151
|
"""
|
|
141
152
|
|
|
142
153
|
p_nonterminal = jnp.asarray(p_nonterminal, large_float)
|
|
@@ -150,24 +161,28 @@ def init(*,
|
|
|
150
161
|
small_float = jnp.dtype(small_float)
|
|
151
162
|
large_float = jnp.dtype(large_float)
|
|
152
163
|
y = jnp.asarray(y, small_float)
|
|
153
|
-
|
|
164
|
+
resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y)
|
|
165
|
+
sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
|
|
166
|
+
sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
|
|
154
167
|
|
|
155
168
|
bart = dict(
|
|
156
169
|
leaf_trees=make_forest(max_depth, small_float),
|
|
157
170
|
var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
|
|
158
171
|
split_trees=make_forest(max_depth - 1, max_split.dtype),
|
|
159
172
|
resid=jnp.asarray(y, large_float),
|
|
160
|
-
sigma2=
|
|
173
|
+
sigma2=sigma2,
|
|
161
174
|
grow_prop_count=jnp.zeros((), int),
|
|
162
175
|
grow_acc_count=jnp.zeros((), int),
|
|
163
176
|
prune_prop_count=jnp.zeros((), int),
|
|
164
177
|
prune_acc_count=jnp.zeros((), int),
|
|
165
178
|
p_nonterminal=p_nonterminal,
|
|
179
|
+
p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
|
|
166
180
|
sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
|
|
167
181
|
sigma2_beta=jnp.asarray(sigma2_beta, large_float),
|
|
168
182
|
max_split=jnp.asarray(max_split),
|
|
169
183
|
y=y,
|
|
170
184
|
X=jnp.asarray(X),
|
|
185
|
+
leaf_indices=jnp.ones((num_trees, y.size), jaxext.minimal_unsigned_dtype(2 ** max_depth - 1)),
|
|
171
186
|
min_points_per_leaf=(
|
|
172
187
|
None if min_points_per_leaf is None else
|
|
173
188
|
jnp.asarray(min_points_per_leaf)
|
|
@@ -177,30 +192,61 @@ def init(*,
|
|
|
177
192
|
make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
|
|
178
193
|
),
|
|
179
194
|
opt=jaxext.LeafDict(
|
|
180
|
-
suffstat_batch_size=suffstat_batch_size,
|
|
181
195
|
small_float=small_float,
|
|
182
196
|
large_float=large_float,
|
|
183
197
|
require_min_points=min_points_per_leaf is not None,
|
|
198
|
+
resid_batch_size=resid_batch_size,
|
|
199
|
+
count_batch_size=count_batch_size,
|
|
184
200
|
),
|
|
185
201
|
)
|
|
186
202
|
|
|
203
|
+
if save_ratios:
|
|
204
|
+
bart['ratios'] = dict(
|
|
205
|
+
grow=dict(
|
|
206
|
+
trans_prior=jnp.full(num_trees, jnp.nan),
|
|
207
|
+
likelihood=jnp.full(num_trees, jnp.nan),
|
|
208
|
+
),
|
|
209
|
+
prune=dict(
|
|
210
|
+
trans_prior=jnp.full(num_trees, jnp.nan),
|
|
211
|
+
likelihood=jnp.full(num_trees, jnp.nan),
|
|
212
|
+
),
|
|
213
|
+
)
|
|
214
|
+
|
|
187
215
|
return bart
|
|
188
216
|
|
|
189
|
-
def _choose_suffstat_batch_size(
|
|
190
|
-
|
|
191
|
-
|
|
217
|
+
def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y):
|
|
218
|
+
|
|
219
|
+
@functools.cache
|
|
220
|
+
def get_platform():
|
|
221
|
+
try:
|
|
222
|
+
device = y.devices().pop()
|
|
223
|
+
except jax.errors.ConcretizationTypeError:
|
|
224
|
+
device = jax.devices()[0]
|
|
225
|
+
platform = device.platform
|
|
226
|
+
if platform not in ('cpu', 'gpu'):
|
|
227
|
+
raise KeyError(f'Unknown platform: {platform}')
|
|
228
|
+
return platform
|
|
229
|
+
|
|
230
|
+
if resid_batch_size == 'auto':
|
|
231
|
+
platform = get_platform()
|
|
232
|
+
n = max(1, y.size)
|
|
192
233
|
if platform == 'cpu':
|
|
193
|
-
|
|
194
|
-
# maybe I should batch residuals (not counts) for numerical
|
|
195
|
-
# accuracy, even if it's slower
|
|
234
|
+
resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
|
|
196
235
|
elif platform == 'gpu':
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
236
|
+
resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
|
|
237
|
+
resid_batch_size = max(1, resid_batch_size)
|
|
238
|
+
|
|
239
|
+
if count_batch_size == 'auto':
|
|
240
|
+
platform = get_platform()
|
|
241
|
+
if platform == 'cpu':
|
|
242
|
+
count_batch_size = None
|
|
243
|
+
elif platform == 'gpu':
|
|
244
|
+
n = max(1, y.size)
|
|
245
|
+
count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
|
|
246
|
+
# /4 is good on V100, /2 on L4/T4, still haven't tried A100
|
|
247
|
+
count_batch_size = max(1, count_batch_size)
|
|
248
|
+
|
|
249
|
+
return resid_batch_size, count_batch_size
|
|
204
250
|
|
|
205
251
|
def step(bart, key):
|
|
206
252
|
"""
|
|
@@ -240,14 +286,11 @@ def sample_trees(bart, key):
|
|
|
240
286
|
|
|
241
287
|
Notes
|
|
242
288
|
-----
|
|
243
|
-
This function zeroes the proposal counters
|
|
289
|
+
This function zeroes the proposal counters.
|
|
244
290
|
"""
|
|
245
|
-
bart = bart.copy()
|
|
246
291
|
key, subkey = random.split(key)
|
|
247
292
|
grow_moves, prune_moves = sample_moves(bart, subkey)
|
|
248
|
-
bart
|
|
249
|
-
grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
|
|
250
|
-
return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
|
|
293
|
+
return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key)
|
|
251
294
|
|
|
252
295
|
def sample_moves(bart, key):
|
|
253
296
|
"""
|
|
@@ -266,17 +309,17 @@ def sample_moves(bart, key):
|
|
|
266
309
|
The proposals for grow and prune moves. See `grow_move` and `prune_move`.
|
|
267
310
|
"""
|
|
268
311
|
key = random.split(key, bart['var_trees'].shape[0])
|
|
269
|
-
return
|
|
312
|
+
return _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], key)
|
|
270
313
|
|
|
271
|
-
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, 0))
|
|
272
|
-
def
|
|
314
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
|
|
315
|
+
def _sample_moves_vmap_trees(*args):
|
|
316
|
+
args, key = args[:-1], args[-1]
|
|
273
317
|
key, key1 = random.split(key)
|
|
274
|
-
args = var_tree, split_tree, affluence_tree, max_split, p_nonterminal
|
|
275
318
|
grow = grow_move(*args, key)
|
|
276
319
|
prune = prune_move(*args, key1)
|
|
277
320
|
return grow, prune
|
|
278
321
|
|
|
279
|
-
def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
322
|
+
def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
|
|
280
323
|
"""
|
|
281
324
|
Tree structure grow move proposal of BART MCMC.
|
|
282
325
|
|
|
@@ -296,6 +339,8 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
|
|
|
296
339
|
The maximum split index for each variable.
|
|
297
340
|
p_nonterminal : array (d,)
|
|
298
341
|
The probability of a nonterminal node at each depth.
|
|
342
|
+
p_propose_grow : array (2 ** (d - 1),)
|
|
343
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
299
344
|
key : jax.dtypes.prng_key array
|
|
300
345
|
A jax random key.
|
|
301
346
|
|
|
@@ -304,41 +349,49 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
|
|
|
304
349
|
grow_move : dict
|
|
305
350
|
A dictionary with fields:
|
|
306
351
|
|
|
307
|
-
'
|
|
308
|
-
|
|
352
|
+
'num_growable' : int
|
|
353
|
+
The number of growable leaves.
|
|
309
354
|
'node' : int
|
|
310
|
-
The index of the leaf to grow.
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
355
|
+
The index of the leaf to grow. ``2 ** d`` if there are no growable
|
|
356
|
+
leaves.
|
|
357
|
+
'left', 'right' : int
|
|
358
|
+
The indices of the children of 'node'.
|
|
359
|
+
'var', 'split' : int
|
|
360
|
+
The decision axis and boundary of the new rule.
|
|
315
361
|
'partial_ratio' : float
|
|
316
362
|
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
317
363
|
the likelihood ratio and the probability of proposing the prune
|
|
318
364
|
move.
|
|
365
|
+
'var_tree', 'split_tree' : array (2 ** (d - 1),)
|
|
366
|
+
The updated decision axes and boundaries of the tree.
|
|
319
367
|
"""
|
|
320
368
|
|
|
321
369
|
key, key1, key2 = random.split(key, 3)
|
|
322
|
-
|
|
323
|
-
leaf_to_grow, num_growable,
|
|
370
|
+
|
|
371
|
+
leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(split_tree, affluence_tree, p_propose_grow, key)
|
|
324
372
|
|
|
325
373
|
var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
|
|
326
374
|
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
327
|
-
|
|
375
|
+
|
|
328
376
|
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
|
|
329
377
|
split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
|
|
330
378
|
|
|
331
|
-
ratio = compute_partial_ratio(
|
|
379
|
+
ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
|
|
332
380
|
|
|
381
|
+
left = leaf_to_grow << 1
|
|
333
382
|
return dict(
|
|
334
|
-
|
|
383
|
+
num_growable=num_growable,
|
|
335
384
|
node=leaf_to_grow,
|
|
385
|
+
left=left,
|
|
386
|
+
right=left + 1,
|
|
387
|
+
var=var,
|
|
388
|
+
split=split,
|
|
336
389
|
partial_ratio=ratio,
|
|
337
390
|
var_tree=var_tree,
|
|
338
391
|
split_tree=split_tree,
|
|
339
392
|
)
|
|
340
393
|
|
|
341
|
-
def choose_leaf(split_tree, affluence_tree, key):
|
|
394
|
+
def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
|
|
342
395
|
"""
|
|
343
396
|
Choose a leaf node to grow in a tree.
|
|
344
397
|
|
|
@@ -348,6 +401,8 @@ def choose_leaf(split_tree, affluence_tree, key):
|
|
|
348
401
|
The splitting points of the tree.
|
|
349
402
|
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
350
403
|
Whether a leaf has enough points to be grown.
|
|
404
|
+
p_propose_grow : array (2 ** (d - 1),)
|
|
405
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
351
406
|
key : jax.dtypes.prng_key array
|
|
352
407
|
A jax random key.
|
|
353
408
|
|
|
@@ -358,19 +413,21 @@ def choose_leaf(split_tree, affluence_tree, key):
|
|
|
358
413
|
``2 ** d``.
|
|
359
414
|
num_growable : int
|
|
360
415
|
The number of leaf nodes that can be grown.
|
|
416
|
+
prob_choose : float
|
|
417
|
+
The normalized probability of choosing the selected leaf.
|
|
361
418
|
num_prunable : int
|
|
362
419
|
The number of leaf parents that could be pruned, after converting the
|
|
363
420
|
selected leaf to a non-terminal node.
|
|
364
|
-
allowed : bool
|
|
365
|
-
Whether the grow move is allowed.
|
|
366
421
|
"""
|
|
367
|
-
is_growable
|
|
368
|
-
leaf_to_grow = randint_masked(key, is_growable)
|
|
369
|
-
leaf_to_grow = jnp.where(allowed, leaf_to_grow, 2 * split_tree.size)
|
|
422
|
+
is_growable = growable_leaves(split_tree, affluence_tree)
|
|
370
423
|
num_growable = jnp.count_nonzero(is_growable)
|
|
424
|
+
distr = jnp.where(is_growable, p_propose_grow, 0)
|
|
425
|
+
leaf_to_grow, distr_norm = categorical(key, distr)
|
|
426
|
+
leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
|
|
427
|
+
prob_choose = distr[leaf_to_grow] / distr_norm
|
|
371
428
|
is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
|
|
372
429
|
num_prunable = jnp.count_nonzero(is_parent)
|
|
373
|
-
return leaf_to_grow, num_growable,
|
|
430
|
+
return leaf_to_grow, num_growable, prob_choose, num_prunable
|
|
374
431
|
|
|
375
432
|
def growable_leaves(split_tree, affluence_tree):
|
|
376
433
|
"""
|
|
@@ -389,34 +446,32 @@ def growable_leaves(split_tree, affluence_tree):
|
|
|
389
446
|
The mask indicating the leaf nodes that can be proposed to grow, i.e.,
|
|
390
447
|
that are not at the bottom level and have at least two times the number
|
|
391
448
|
of minimum points per leaf.
|
|
392
|
-
allowed : bool
|
|
393
|
-
Whether the grow move is allowed, i.e., there are growable leaves.
|
|
394
449
|
"""
|
|
395
450
|
is_growable = grove.is_actual_leaf(split_tree)
|
|
396
451
|
if affluence_tree is not None:
|
|
397
452
|
is_growable &= affluence_tree
|
|
398
|
-
return is_growable
|
|
453
|
+
return is_growable
|
|
399
454
|
|
|
400
|
-
def
|
|
455
|
+
def categorical(key, distr):
|
|
401
456
|
"""
|
|
402
|
-
Return a random integer
|
|
457
|
+
Return a random integer from an arbitrary distribution.
|
|
403
458
|
|
|
404
459
|
Parameters
|
|
405
460
|
----------
|
|
406
461
|
key : jax.dtypes.prng_key array
|
|
407
462
|
A jax random key.
|
|
408
|
-
|
|
409
|
-
|
|
463
|
+
distr : float array (n,)
|
|
464
|
+
An unnormalized probability distribution.
|
|
410
465
|
|
|
411
466
|
Returns
|
|
412
467
|
-------
|
|
413
468
|
u : int
|
|
414
|
-
A random integer in the range ``[0, n)
|
|
415
|
-
|
|
469
|
+
A random integer in the range ``[0, n)``. If all probabilities are zero,
|
|
470
|
+
return ``n``.
|
|
416
471
|
"""
|
|
417
|
-
ecdf = jnp.cumsum(
|
|
418
|
-
u = random.
|
|
419
|
-
return jnp.searchsorted(ecdf, u, 'right')
|
|
472
|
+
ecdf = jnp.cumsum(distr)
|
|
473
|
+
u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
|
|
474
|
+
return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
|
|
420
475
|
|
|
421
476
|
def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
|
|
422
477
|
"""
|
|
@@ -471,7 +526,7 @@ def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
|
|
|
471
526
|
filled with `p`. The fill values are not guaranteed to be placed in any
|
|
472
527
|
particular order. Variables may appear more than once.
|
|
473
528
|
"""
|
|
474
|
-
|
|
529
|
+
|
|
475
530
|
var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
|
|
476
531
|
split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
|
|
477
532
|
l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
|
|
@@ -603,7 +658,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
|
|
|
603
658
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
604
659
|
return random.randint(key, (), l, r)
|
|
605
660
|
|
|
606
|
-
def compute_partial_ratio(
|
|
661
|
+
def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
|
|
607
662
|
"""
|
|
608
663
|
Compute the product of the transition and prior ratios of a grow move.
|
|
609
664
|
|
|
@@ -632,6 +687,9 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
632
687
|
# the two ratios also contain factors num_available_split *
|
|
633
688
|
# num_available_var, but they cancel out
|
|
634
689
|
|
|
690
|
+
# p_prune can't be computed here because it needs the count trees, which are
|
|
691
|
+
# computed in the acceptance phase
|
|
692
|
+
|
|
635
693
|
prune_allowed = leaf_to_grow != 1
|
|
636
694
|
# prune allowed <---> the initial tree is not a root
|
|
637
695
|
# leaf to grow is root --> the tree can only be a root
|
|
@@ -639,31 +697,33 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
639
697
|
|
|
640
698
|
p_grow = jnp.where(prune_allowed, 0.5, 1)
|
|
641
699
|
|
|
642
|
-
|
|
700
|
+
inv_trans_ratio = p_grow * prob_choose * num_prunable
|
|
643
701
|
|
|
644
702
|
depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
|
|
645
703
|
p_parent = p_nonterminal[depth]
|
|
646
704
|
cp_children = 1 - p_nonterminal[depth + 1]
|
|
647
705
|
tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
|
|
648
706
|
|
|
649
|
-
return
|
|
707
|
+
return tree_ratio / inv_trans_ratio
|
|
650
708
|
|
|
651
|
-
def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
709
|
+
def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
|
|
652
710
|
"""
|
|
653
711
|
Tree structure prune move proposal of BART MCMC.
|
|
654
712
|
|
|
655
713
|
Parameters
|
|
656
714
|
----------
|
|
657
|
-
var_tree : array (2 ** (d - 1),)
|
|
715
|
+
var_tree : int array (2 ** (d - 1),)
|
|
658
716
|
The variable indices of the tree.
|
|
659
|
-
split_tree : array (2 ** (d - 1),)
|
|
717
|
+
split_tree : int array (2 ** (d - 1),)
|
|
660
718
|
The splitting points of the tree.
|
|
661
719
|
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
662
720
|
Whether a leaf has enough points to be grown.
|
|
663
|
-
max_split : array (p,)
|
|
721
|
+
max_split : int array (p,)
|
|
664
722
|
The maximum split index for each variable.
|
|
665
|
-
p_nonterminal : array (d,)
|
|
723
|
+
p_nonterminal : float array (d,)
|
|
666
724
|
The probability of a nonterminal node at each depth.
|
|
725
|
+
p_propose_grow : float array (2 ** (d - 1),)
|
|
726
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
667
727
|
key : jax.dtypes.prng_key array
|
|
668
728
|
A jax random key.
|
|
669
729
|
|
|
@@ -675,24 +735,29 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
|
|
|
675
735
|
'allowed' : bool
|
|
676
736
|
Whether the move is possible.
|
|
677
737
|
'node' : int
|
|
678
|
-
The index of the node to prune.
|
|
738
|
+
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
739
|
+
'left', 'right' : int
|
|
740
|
+
The indices of the children of 'node'.
|
|
679
741
|
'partial_ratio' : float
|
|
680
742
|
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
681
743
|
the likelihood ratio and the probability of proposing the prune
|
|
682
744
|
move. This ratio is inverted.
|
|
683
745
|
"""
|
|
684
|
-
node_to_prune, num_prunable,
|
|
746
|
+
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
|
|
685
747
|
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
686
748
|
|
|
687
|
-
ratio = compute_partial_ratio(
|
|
749
|
+
ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune, split_tree)
|
|
688
750
|
|
|
751
|
+
left = node_to_prune << 1
|
|
689
752
|
return dict(
|
|
690
753
|
allowed=allowed,
|
|
691
754
|
node=node_to_prune,
|
|
755
|
+
left=left,
|
|
756
|
+
right=left + 1,
|
|
692
757
|
partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
|
|
693
758
|
)
|
|
694
759
|
|
|
695
|
-
def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
760
|
+
def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
|
|
696
761
|
"""
|
|
697
762
|
Pick a non-terminal node with leaf children to prune in a tree.
|
|
698
763
|
|
|
@@ -702,6 +767,8 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
|
702
767
|
The splitting points of the tree.
|
|
703
768
|
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
704
769
|
Whether a leaf has enough points to be grown.
|
|
770
|
+
p_propose_grow : array (2 ** (d - 1),)
|
|
771
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
705
772
|
key : jax.dtypes.prng_key array
|
|
706
773
|
A jax random key.
|
|
707
774
|
|
|
@@ -709,28 +776,50 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
|
709
776
|
-------
|
|
710
777
|
node_to_prune : int
|
|
711
778
|
The index of the node to prune. If ``num_prunable == 0``, return
|
|
712
|
-
``
|
|
779
|
+
``2 ** d``.
|
|
713
780
|
num_prunable : int
|
|
714
781
|
The number of leaf parents that could be pruned.
|
|
715
|
-
|
|
716
|
-
The
|
|
717
|
-
node.
|
|
782
|
+
prob_choose : float
|
|
783
|
+
The normalized probability of choosing the node to prune for growth.
|
|
718
784
|
"""
|
|
719
785
|
is_prunable = grove.is_leaves_parent(split_tree)
|
|
720
|
-
node_to_prune = randint_masked(key, is_prunable)
|
|
721
786
|
num_prunable = jnp.count_nonzero(is_prunable)
|
|
787
|
+
node_to_prune = randint_masked(key, is_prunable)
|
|
788
|
+
node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
|
|
722
789
|
|
|
723
|
-
|
|
724
|
-
|
|
790
|
+
split_tree = split_tree.at[node_to_prune].set(0)
|
|
791
|
+
affluence_tree = (
|
|
725
792
|
None if affluence_tree is None else
|
|
726
793
|
affluence_tree.at[node_to_prune].set(True)
|
|
727
794
|
)
|
|
728
|
-
is_growable_leaf
|
|
729
|
-
|
|
795
|
+
is_growable_leaf = growable_leaves(split_tree, affluence_tree)
|
|
796
|
+
prob_choose = p_propose_grow[node_to_prune]
|
|
797
|
+
prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
|
|
798
|
+
|
|
799
|
+
return node_to_prune, num_prunable, prob_choose
|
|
800
|
+
|
|
801
|
+
def randint_masked(key, mask):
|
|
802
|
+
"""
|
|
803
|
+
Return a random integer in a range, including only some values.
|
|
730
804
|
|
|
731
|
-
|
|
805
|
+
Parameters
|
|
806
|
+
----------
|
|
807
|
+
key : jax.dtypes.prng_key array
|
|
808
|
+
A jax random key.
|
|
809
|
+
mask : bool array (n,)
|
|
810
|
+
The mask indicating the allowed values.
|
|
811
|
+
|
|
812
|
+
Returns
|
|
813
|
+
-------
|
|
814
|
+
u : int
|
|
815
|
+
A random integer in the range ``[0, n)``, and which satisfies
|
|
816
|
+
``mask[u] == True``. If all values in the mask are `False`, return `n`.
|
|
817
|
+
"""
|
|
818
|
+
ecdf = jnp.cumsum(mask)
|
|
819
|
+
u = random.randint(key, (), 0, ecdf[-1])
|
|
820
|
+
return jnp.searchsorted(ecdf, u, 'right')
|
|
732
821
|
|
|
733
|
-
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves,
|
|
822
|
+
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
734
823
|
"""
|
|
735
824
|
Accept or reject the proposed moves and sample the new leaf values.
|
|
736
825
|
|
|
@@ -744,8 +833,6 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
|
|
|
744
833
|
prune_moves : dict
|
|
745
834
|
The proposals for prune moves, batched over the first axis. See
|
|
746
835
|
`prune_move`.
|
|
747
|
-
grow_leaf_indices : int array (num_trees, n)
|
|
748
|
-
The leaf indices of the trees proposed by the grow move.
|
|
749
836
|
key : jax.dtypes.prng_key array
|
|
750
837
|
A jax random key.
|
|
751
838
|
|
|
@@ -754,41 +841,339 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
|
|
|
754
841
|
bart : dict
|
|
755
842
|
The new BART mcmc state.
|
|
756
843
|
"""
|
|
844
|
+
bart, grow_moves, prune_moves, count_trees, move_counts, u, z = accept_moves_parallel_stage(bart, grow_moves, prune_moves, key)
|
|
845
|
+
bart, counts = accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, move_counts, u, z)
|
|
846
|
+
return accept_moves_final_stage(bart, counts, grow_moves, prune_moves)
|
|
847
|
+
|
|
848
|
+
def accept_moves_parallel_stage(bart, grow_moves, prune_moves, key):
|
|
849
|
+
"""
|
|
850
|
+
Pre-computes quantities used to accept moves, in parallel across trees.
|
|
851
|
+
|
|
852
|
+
Parameters
|
|
853
|
+
----------
|
|
854
|
+
bart : dict
|
|
855
|
+
A BART mcmc state.
|
|
856
|
+
grow_moves, prune_moves : dict
|
|
857
|
+
The proposals for the moves, batched over the first axis. See
|
|
858
|
+
`grow_move` and `prune_move`.
|
|
859
|
+
key : jax.dtypes.prng_key array
|
|
860
|
+
A jax random key.
|
|
861
|
+
|
|
862
|
+
Returns
|
|
863
|
+
-------
|
|
864
|
+
bart : dict
|
|
865
|
+
A partially updated BART mcmc state.
|
|
866
|
+
grow_moves, prune_moves : dict
|
|
867
|
+
The proposals for the moves, with the field 'partial_ratio' replaced
|
|
868
|
+
by 'trans_prior_ratio'.
|
|
869
|
+
count_trees : array (num_trees, 2 ** (d - 1))
|
|
870
|
+
The number of points in each potential or actual leaf node.
|
|
871
|
+
move_counts : dict
|
|
872
|
+
The counts of the number of points in the the nodes modified by the
|
|
873
|
+
moves.
|
|
874
|
+
u : float array (num_trees, 2)
|
|
875
|
+
Random uniform values used to accept the moves.
|
|
876
|
+
z : float array (num_trees, 2 ** d)
|
|
877
|
+
Random standard normal values used to sample the new leaf values.
|
|
878
|
+
"""
|
|
879
|
+
bart = bart.copy()
|
|
880
|
+
|
|
881
|
+
bart['var_trees'] = grow_moves['var_tree']
|
|
882
|
+
# Since var_tree can contain garbage, I can set the var of leaf to be
|
|
883
|
+
# grown irrespectively of what move I'm gonna accept in the end.
|
|
884
|
+
|
|
885
|
+
bart['leaf_indices'] = apply_grow_to_indices(grow_moves, bart['leaf_indices'], bart['X'])
|
|
886
|
+
|
|
887
|
+
count_trees, move_counts = compute_count_trees(bart['leaf_indices'], grow_moves, prune_moves, bart['opt']['count_batch_size'])
|
|
888
|
+
|
|
889
|
+
grow_moves, prune_moves = complete_ratio(grow_moves, prune_moves, move_counts, bart['min_points_per_leaf'])
|
|
890
|
+
|
|
891
|
+
if bart['opt']['require_min_points']:
|
|
892
|
+
count_half_trees = count_trees[:, :grow_moves['split_tree'].shape[1]]
|
|
893
|
+
bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
|
|
894
|
+
|
|
895
|
+
bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], grow_moves)
|
|
896
|
+
|
|
897
|
+
key, subkey = random.split(key)
|
|
898
|
+
u = random.uniform(subkey, (len(bart['leaf_trees']), 2), bart['opt']['large_float'])
|
|
899
|
+
z = random.normal(key, bart['leaf_trees'].shape, bart['opt']['large_float'])
|
|
900
|
+
|
|
901
|
+
return bart, grow_moves, prune_moves, count_trees, move_counts, u, z
|
|
902
|
+
|
|
903
|
+
def apply_grow_to_indices(grow_moves, leaf_indices, X):
|
|
904
|
+
"""
|
|
905
|
+
Update the leaf indices to apply a grow move.
|
|
906
|
+
|
|
907
|
+
Parameters
|
|
908
|
+
----------
|
|
909
|
+
grow_moves : dict
|
|
910
|
+
The proposals for grow moves. See `grow_move`.
|
|
911
|
+
leaf_indices : array (num_trees, n)
|
|
912
|
+
The index of the leaf each datapoint falls into.
|
|
913
|
+
X : array (p, n)
|
|
914
|
+
The predictors matrix.
|
|
915
|
+
|
|
916
|
+
Returns
|
|
917
|
+
-------
|
|
918
|
+
grow_leaf_indices : array (num_trees, n)
|
|
919
|
+
The updated leaf indices.
|
|
920
|
+
"""
|
|
921
|
+
left_child = grow_moves['node'].astype(leaf_indices.dtype) << 1
|
|
922
|
+
go_right = X[grow_moves['var'], :] >= grow_moves['split'][:, None]
|
|
923
|
+
tree_size = jnp.array(2 * grow_moves['split_tree'].shape[1])
|
|
924
|
+
node_to_update = jnp.where(grow_moves['num_growable'], grow_moves['node'], tree_size)
|
|
925
|
+
return jnp.where(
|
|
926
|
+
leaf_indices == node_to_update[:, None],
|
|
927
|
+
left_child[:, None] + go_right,
|
|
928
|
+
leaf_indices,
|
|
929
|
+
)
|
|
930
|
+
|
|
931
|
+
def compute_count_trees(grow_leaf_indices, grow_moves, prune_moves, batch_size):
|
|
932
|
+
"""
|
|
933
|
+
Count the number of datapoints in each leaf.
|
|
934
|
+
|
|
935
|
+
Parameters
|
|
936
|
+
----------
|
|
937
|
+
grow_leaf_indices : int array (num_trees, n)
|
|
938
|
+
The index of the leaf each datapoint falls into, if the grow move is
|
|
939
|
+
accepted.
|
|
940
|
+
grow_moves, prune_moves : dict
|
|
941
|
+
The proposals for the moves. See `grow_move` and `prune_move`.
|
|
942
|
+
batch_size : int or None
|
|
943
|
+
The data batch size to use for the summation.
|
|
944
|
+
|
|
945
|
+
Returns
|
|
946
|
+
-------
|
|
947
|
+
count_trees : int array (num_trees, 2 ** (d - 1))
|
|
948
|
+
The number of points in each potential or actual leaf node.
|
|
949
|
+
counts : dict
|
|
950
|
+
The counts of the number of points in the the nodes modified by the
|
|
951
|
+
moves, organized as two dictionaries 'grow' and 'prune', with subfields
|
|
952
|
+
'left', 'right', and 'total'.
|
|
953
|
+
"""
|
|
954
|
+
|
|
955
|
+
ntree, tree_size = grow_moves['split_tree'].shape
|
|
956
|
+
tree_size *= 2
|
|
957
|
+
counts = dict(grow=dict(), prune=dict())
|
|
958
|
+
tree_indices = jnp.arange(ntree)
|
|
959
|
+
|
|
960
|
+
count_trees = count_datapoints_per_leaf(grow_leaf_indices, tree_size, batch_size)
|
|
961
|
+
|
|
962
|
+
# count datapoints in leaf to grow
|
|
963
|
+
counts['grow']['left'] = count_trees[tree_indices, grow_moves['left']]
|
|
964
|
+
counts['grow']['right'] = count_trees[tree_indices, grow_moves['right']]
|
|
965
|
+
counts['grow']['total'] = counts['grow']['left'] + counts['grow']['right']
|
|
966
|
+
count_trees = count_trees.at[tree_indices, grow_moves['node']].set(counts['grow']['total'])
|
|
967
|
+
|
|
968
|
+
# count datapoints in node to prune
|
|
969
|
+
counts['prune']['left'] = count_trees[tree_indices, prune_moves['left']]
|
|
970
|
+
counts['prune']['right'] = count_trees[tree_indices, prune_moves['right']]
|
|
971
|
+
counts['prune']['total'] = counts['prune']['left'] + counts['prune']['right']
|
|
972
|
+
count_trees = count_trees.at[tree_indices, prune_moves['node']].set(counts['prune']['total'])
|
|
973
|
+
|
|
974
|
+
return count_trees, counts
|
|
975
|
+
|
|
976
|
+
def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
|
|
977
|
+
"""
|
|
978
|
+
Count the number of datapoints in each leaf.
|
|
979
|
+
|
|
980
|
+
Parameters
|
|
981
|
+
----------
|
|
982
|
+
leaf_indices : int array (num_trees, n)
|
|
983
|
+
The index of the leaf each datapoint falls into.
|
|
984
|
+
tree_size : int
|
|
985
|
+
The size of the leaf tree array (2 ** d).
|
|
986
|
+
batch_size : int or None
|
|
987
|
+
The data batch size to use for the summation.
|
|
988
|
+
|
|
989
|
+
Returns
|
|
990
|
+
-------
|
|
991
|
+
count_trees : int array (num_trees, 2 ** (d - 1))
|
|
992
|
+
The number of points in each leaf node.
|
|
993
|
+
"""
|
|
994
|
+
if batch_size is None:
|
|
995
|
+
return _count_scan(leaf_indices, tree_size)
|
|
996
|
+
else:
|
|
997
|
+
return _count_vec(leaf_indices, tree_size, batch_size)
|
|
998
|
+
|
|
999
|
+
def _count_scan(leaf_indices, tree_size):
|
|
1000
|
+
def loop(_, leaf_indices):
|
|
1001
|
+
return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
|
|
1002
|
+
_, count_trees = lax.scan(loop, None, leaf_indices)
|
|
1003
|
+
return count_trees
|
|
1004
|
+
|
|
1005
|
+
def _aggregate_scatter(values, indices, size, dtype):
|
|
1006
|
+
return (jnp
|
|
1007
|
+
.zeros(size, dtype)
|
|
1008
|
+
.at[indices]
|
|
1009
|
+
.add(values)
|
|
1010
|
+
)
|
|
1011
|
+
|
|
1012
|
+
def _count_vec(leaf_indices, tree_size, batch_size):
|
|
1013
|
+
return _aggregate_batched_alltrees(1, leaf_indices, tree_size, jnp.uint32, batch_size)
|
|
1014
|
+
# uint16 is super-slow on gpu, don't use it even if n < 2^16
|
|
1015
|
+
|
|
1016
|
+
def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
|
|
1017
|
+
ntree, n = indices.shape
|
|
1018
|
+
tree_indices = jnp.arange(ntree)
|
|
1019
|
+
nbatches = n // batch_size + bool(n % batch_size)
|
|
1020
|
+
batch_indices = jnp.arange(n) % nbatches
|
|
1021
|
+
return (jnp
|
|
1022
|
+
.zeros((ntree, size, nbatches), dtype)
|
|
1023
|
+
.at[tree_indices[:, None], indices, batch_indices]
|
|
1024
|
+
.add(values)
|
|
1025
|
+
.sum(axis=2)
|
|
1026
|
+
)
|
|
1027
|
+
|
|
1028
|
+
def complete_ratio(grow_moves, prune_moves, move_counts, min_points_per_leaf):
|
|
1029
|
+
"""
|
|
1030
|
+
Complete non-likelihood MH ratio calculation.
|
|
1031
|
+
|
|
1032
|
+
This functions adds the probability of choosing the prune move.
|
|
1033
|
+
|
|
1034
|
+
Parameters
|
|
1035
|
+
----------
|
|
1036
|
+
grow_moves, prune_moves : dict
|
|
1037
|
+
The proposals for the moves. See `grow_move` and `prune_move`.
|
|
1038
|
+
move_counts : dict
|
|
1039
|
+
The counts of the number of points in the the nodes modified by the
|
|
1040
|
+
moves.
|
|
1041
|
+
min_points_per_leaf : int or None
|
|
1042
|
+
The minimum number of data points in a leaf node.
|
|
1043
|
+
|
|
1044
|
+
Returns
|
|
1045
|
+
-------
|
|
1046
|
+
grow_moves, prune_moves : dict
|
|
1047
|
+
The proposals for the moves, with the field 'partial_ratio' replaced
|
|
1048
|
+
by 'trans_prior_ratio'.
|
|
1049
|
+
"""
|
|
1050
|
+
grow_moves = grow_moves.copy()
|
|
1051
|
+
prune_moves = prune_moves.copy()
|
|
1052
|
+
compute_p_prune_vec = jax.vmap(compute_p_prune, in_axes=(0, 0, 0, None))
|
|
1053
|
+
grow_p_prune, prune_p_prune = compute_p_prune_vec(grow_moves, move_counts['grow']['left'], move_counts['grow']['right'], min_points_per_leaf)
|
|
1054
|
+
grow_moves['trans_prior_ratio'] = grow_moves.pop('partial_ratio') * grow_p_prune
|
|
1055
|
+
prune_moves['trans_prior_ratio'] = prune_moves.pop('partial_ratio') * prune_p_prune
|
|
1056
|
+
return grow_moves, prune_moves
|
|
1057
|
+
|
|
1058
|
+
def compute_p_prune(grow_move, grow_left_count, grow_right_count, min_points_per_leaf):
|
|
1059
|
+
"""
|
|
1060
|
+
Compute the probability of proposing a prune move.
|
|
1061
|
+
|
|
1062
|
+
Parameters
|
|
1063
|
+
----------
|
|
1064
|
+
grow_move : dict
|
|
1065
|
+
The proposal for the grow move, see `grow_move`.
|
|
1066
|
+
grow_left_count, grow_right_count : int
|
|
1067
|
+
The number of datapoints in the proposed children of the leaf to grow.
|
|
1068
|
+
min_points_per_leaf : int or None
|
|
1069
|
+
The minimum number of data points in a leaf node.
|
|
1070
|
+
|
|
1071
|
+
Returns
|
|
1072
|
+
-------
|
|
1073
|
+
grow_p_prune : float
|
|
1074
|
+
The probability of proposing a prune move, after accepting the grow
|
|
1075
|
+
move.
|
|
1076
|
+
prune_p_prune : float
|
|
1077
|
+
The probability of proposing the prune move.
|
|
1078
|
+
"""
|
|
1079
|
+
other_growable_leaves = grow_move['num_growable'] >= 2
|
|
1080
|
+
new_leaves_growable = grow_move['node'] < grow_move['split_tree'].size // 2
|
|
1081
|
+
if min_points_per_leaf is not None:
|
|
1082
|
+
any_above_threshold = grow_left_count >= 2 * min_points_per_leaf
|
|
1083
|
+
any_above_threshold |= grow_right_count >= 2 * min_points_per_leaf
|
|
1084
|
+
new_leaves_growable &= any_above_threshold
|
|
1085
|
+
grow_again_allowed = other_growable_leaves | new_leaves_growable
|
|
1086
|
+
grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
|
|
1087
|
+
prune_p_prune = jnp.where(grow_move['num_growable'], 0.5, 1)
|
|
1088
|
+
return grow_p_prune, prune_p_prune
|
|
1089
|
+
|
|
1090
|
+
def adapt_leaf_trees_to_grow_indices(leaf_trees, grow_moves):
|
|
1091
|
+
"""
|
|
1092
|
+
Modify leaf values such that the indices of the grow move work on the
|
|
1093
|
+
original tree.
|
|
1094
|
+
|
|
1095
|
+
Parameters
|
|
1096
|
+
----------
|
|
1097
|
+
leaf_trees : float array (num_trees, 2 ** d)
|
|
1098
|
+
The leaf values.
|
|
1099
|
+
grow_moves : dict
|
|
1100
|
+
The proposals for grow moves. See `grow_move`.
|
|
1101
|
+
|
|
1102
|
+
Returns
|
|
1103
|
+
-------
|
|
1104
|
+
leaf_trees : float array (num_trees, 2 ** d)
|
|
1105
|
+
The modified leaf values. The value of the leaf to grow is copied to
|
|
1106
|
+
what would be its children if the grow move was accepted.
|
|
1107
|
+
"""
|
|
1108
|
+
ntree, _ = leaf_trees.shape
|
|
1109
|
+
tree_indices = jnp.arange(ntree)
|
|
1110
|
+
values_at_node = leaf_trees[tree_indices, grow_moves['node']]
|
|
1111
|
+
return (leaf_trees
|
|
1112
|
+
.at[tree_indices, grow_moves['left']]
|
|
1113
|
+
.set(values_at_node)
|
|
1114
|
+
.at[tree_indices, grow_moves['right']]
|
|
1115
|
+
.set(values_at_node)
|
|
1116
|
+
)
|
|
1117
|
+
|
|
1118
|
+
def accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, move_counts, u, z):
|
|
1119
|
+
"""
|
|
1120
|
+
The part of accepting the moves that has to be done one tree at a time.
|
|
1121
|
+
|
|
1122
|
+
Parameters
|
|
1123
|
+
----------
|
|
1124
|
+
bart : dict
|
|
1125
|
+
A partially updated BART mcmc state.
|
|
1126
|
+
count_trees : array (num_trees, 2 ** (d - 1))
|
|
1127
|
+
The number of points in each potential or actual leaf node.
|
|
1128
|
+
grow_moves, prune_moves : dict
|
|
1129
|
+
The proposals for the moves, with completed ratios. See `grow_move` and
|
|
1130
|
+
`prune_move`.
|
|
1131
|
+
move_counts : dict
|
|
1132
|
+
The counts of the number of points in the the nodes modified by the
|
|
1133
|
+
moves.
|
|
1134
|
+
u : float array (num_trees, 2)
|
|
1135
|
+
Random uniform values used to for proposal and accept decisions.
|
|
1136
|
+
z : float array (num_trees, 2 ** d)
|
|
1137
|
+
Random standard normal values used to sample the new leaf values.
|
|
1138
|
+
|
|
1139
|
+
Returns
|
|
1140
|
+
-------
|
|
1141
|
+
bart : dict
|
|
1142
|
+
A partially updated BART mcmc state.
|
|
1143
|
+
counts : dict
|
|
1144
|
+
The indicators of proposals and acceptances for grow and prune moves.
|
|
1145
|
+
"""
|
|
757
1146
|
bart = bart.copy()
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
resid,
|
|
1147
|
+
|
|
1148
|
+
def loop(resid, item):
|
|
1149
|
+
resid, leaf_tree, split_tree, counts, ratios = accept_move_and_sample_leaves(
|
|
761
1150
|
bart['X'],
|
|
762
1151
|
len(bart['leaf_trees']),
|
|
763
|
-
bart['opt']['
|
|
1152
|
+
bart['opt']['resid_batch_size'],
|
|
764
1153
|
resid,
|
|
765
1154
|
bart['sigma2'],
|
|
766
1155
|
bart['min_points_per_leaf'],
|
|
767
|
-
|
|
1156
|
+
'ratios' in bart,
|
|
768
1157
|
*item,
|
|
769
1158
|
)
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
carry = {
|
|
773
|
-
k: jnp.zeros_like(bart[k]) for k in
|
|
774
|
-
['grow_prop_count', 'prune_prop_count', 'grow_acc_count', 'prune_acc_count']
|
|
775
|
-
}
|
|
776
|
-
carry['resid'] = bart['resid']
|
|
1159
|
+
return resid, (leaf_tree, split_tree, counts, ratios)
|
|
1160
|
+
|
|
777
1161
|
items = (
|
|
778
|
-
bart['leaf_trees'],
|
|
779
|
-
|
|
780
|
-
bart['
|
|
781
|
-
|
|
782
|
-
prune_moves,
|
|
783
|
-
grow_leaf_indices,
|
|
784
|
-
random.split(key, len(bart['leaf_trees'])),
|
|
1162
|
+
bart['leaf_trees'], count_trees,
|
|
1163
|
+
grow_moves, prune_moves, move_counts,
|
|
1164
|
+
bart['leaf_indices'],
|
|
1165
|
+
u, z,
|
|
785
1166
|
)
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
bart
|
|
789
|
-
|
|
1167
|
+
resid, (leaf_trees, split_trees, counts, ratios) = lax.scan(loop, bart['resid'], items)
|
|
1168
|
+
|
|
1169
|
+
bart['resid'] = resid
|
|
1170
|
+
bart['leaf_trees'] = leaf_trees
|
|
1171
|
+
bart['split_trees'] = split_trees
|
|
1172
|
+
bart.get('ratios', {}).update(ratios)
|
|
790
1173
|
|
|
791
|
-
|
|
1174
|
+
return bart, counts
|
|
1175
|
+
|
|
1176
|
+
def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, sigma2, min_points_per_leaf, save_ratios, leaf_tree, count_tree, grow_move, prune_move, move_counts, grow_leaf_indices, u, z):
|
|
792
1177
|
"""
|
|
793
1178
|
Accept or reject a proposed move and sample the new leaf values.
|
|
794
1179
|
|
|
@@ -798,158 +1183,157 @@ def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2,
|
|
|
798
1183
|
The predictors.
|
|
799
1184
|
ntree : int
|
|
800
1185
|
The number of trees in the forest.
|
|
801
|
-
|
|
802
|
-
The batch size for computing
|
|
1186
|
+
resid_batch_size : int, None
|
|
1187
|
+
The batch size for computing the sum of residuals in each leaf.
|
|
803
1188
|
resid : float array (n,)
|
|
804
1189
|
The residuals (data minus forest value).
|
|
805
1190
|
sigma2 : float
|
|
806
1191
|
The noise variance.
|
|
807
1192
|
min_points_per_leaf : int or None
|
|
808
1193
|
The minimum number of data points in a leaf node.
|
|
809
|
-
|
|
810
|
-
|
|
1194
|
+
save_ratios : bool
|
|
1195
|
+
Whether to save the acceptance ratios.
|
|
811
1196
|
leaf_tree : float array (2 ** d,)
|
|
812
1197
|
The leaf values of the tree.
|
|
813
|
-
|
|
814
|
-
The
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
The proposal for the grow move. See `grow_move`.
|
|
819
|
-
prune_move : dict
|
|
820
|
-
The proposal for the prune move. See `prune_move`.
|
|
1198
|
+
count_tree : int array (2 ** d,)
|
|
1199
|
+
The number of datapoints in each leaf.
|
|
1200
|
+
grow_move, prune_move : dict
|
|
1201
|
+
The proposals for the moves, with completed ratios. See `grow_move` and
|
|
1202
|
+
`prune_move`.
|
|
821
1203
|
grow_leaf_indices : int array (n,)
|
|
822
1204
|
The leaf indices of the tree proposed by the grow move.
|
|
823
|
-
|
|
824
|
-
|
|
1205
|
+
u : float array (2,)
|
|
1206
|
+
Two uniform random values in [0, 1).
|
|
1207
|
+
z : float array (2 ** d,)
|
|
1208
|
+
Standard normal random values.
|
|
825
1209
|
|
|
826
1210
|
Returns
|
|
827
1211
|
-------
|
|
828
1212
|
resid : float array (n,)
|
|
829
1213
|
The updated residuals (data minus forest value).
|
|
1214
|
+
leaf_tree : float array (2 ** d,)
|
|
1215
|
+
The new leaf values of the tree.
|
|
1216
|
+
split_tree : int array (2 ** (d - 1),)
|
|
1217
|
+
The updated decision boundaries of the tree.
|
|
830
1218
|
counts : dict
|
|
831
|
-
The
|
|
832
|
-
|
|
833
|
-
The
|
|
1219
|
+
The indicators of proposals and acceptances for grow and prune moves.
|
|
1220
|
+
ratios : dict
|
|
1221
|
+
The acceptance ratios for the moves. Empty if not to be saved.
|
|
834
1222
|
"""
|
|
835
|
-
|
|
836
|
-
# compute leaf indices in starting tree
|
|
837
|
-
grow_node = grow_move['node']
|
|
838
|
-
grow_left = grow_node << 1
|
|
839
|
-
grow_right = grow_left + 1
|
|
840
|
-
leaf_indices = jnp.where(
|
|
841
|
-
(grow_leaf_indices == grow_left) | (grow_leaf_indices == grow_right),
|
|
842
|
-
grow_node,
|
|
843
|
-
grow_leaf_indices,
|
|
844
|
-
)
|
|
845
1223
|
|
|
846
|
-
#
|
|
847
|
-
|
|
848
|
-
prune_left = prune_node << 1
|
|
849
|
-
prune_right = prune_left + 1
|
|
850
|
-
prune_leaf_indices = jnp.where(
|
|
851
|
-
(leaf_indices == prune_left) | (leaf_indices == prune_right),
|
|
852
|
-
prune_node,
|
|
853
|
-
leaf_indices,
|
|
854
|
-
)
|
|
1224
|
+
# sum residuals and count units per leaf, in tree proposed by grow move
|
|
1225
|
+
resid_tree = sum_resid(resid, grow_leaf_indices, leaf_tree.size, resid_batch_size)
|
|
855
1226
|
|
|
856
1227
|
# subtract starting tree from function
|
|
857
|
-
|
|
1228
|
+
resid_tree += count_tree * leaf_tree
|
|
858
1229
|
|
|
859
|
-
#
|
|
860
|
-
|
|
1230
|
+
# get indices of grow move
|
|
1231
|
+
grow_node = grow_move['node']
|
|
1232
|
+
assert grow_node.dtype == jnp.int32
|
|
1233
|
+
grow_left = grow_move['left']
|
|
1234
|
+
grow_right = grow_move['right']
|
|
861
1235
|
|
|
862
|
-
#
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
.set(grow_count_tree[grow_left] + grow_count_tree[grow_right]))
|
|
1236
|
+
# sum residuals in leaf to grow
|
|
1237
|
+
grow_resid_left = resid_tree[grow_left]
|
|
1238
|
+
grow_resid_right = resid_tree[grow_right]
|
|
1239
|
+
grow_resid_total = grow_resid_left + grow_resid_right
|
|
1240
|
+
resid_tree = resid_tree.at[grow_node].set(grow_resid_total)
|
|
868
1241
|
|
|
869
|
-
#
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
1242
|
+
# get indices of prune move
|
|
1243
|
+
prune_node = prune_move['node']
|
|
1244
|
+
assert prune_node.dtype == jnp.int32
|
|
1245
|
+
prune_left = prune_move['left']
|
|
1246
|
+
prune_right = prune_move['right']
|
|
874
1247
|
|
|
875
|
-
#
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
1248
|
+
# sum residuals in node to prune
|
|
1249
|
+
prune_resid_left = resid_tree[prune_left]
|
|
1250
|
+
prune_resid_right = resid_tree[prune_right]
|
|
1251
|
+
prune_resid_total = prune_resid_left + prune_resid_right
|
|
1252
|
+
resid_tree = resid_tree.at[prune_node].set(prune_resid_total)
|
|
879
1253
|
|
|
880
|
-
#
|
|
881
|
-
|
|
882
|
-
prune_p_prune = compute_p_prune_back(split_tree, affluence_tree)
|
|
1254
|
+
# Now resid_tree and count_tree contain correct values whatever move is
|
|
1255
|
+
# accepted.
|
|
883
1256
|
|
|
884
1257
|
# compute likelihood ratios
|
|
885
|
-
grow_lk_ratio = compute_likelihood_ratio(
|
|
886
|
-
prune_lk_ratio = compute_likelihood_ratio(
|
|
1258
|
+
grow_lk_ratio = compute_likelihood_ratio(grow_resid_total, grow_resid_left, grow_resid_right, move_counts['grow']['total'], move_counts['grow']['left'], move_counts['grow']['right'], sigma2, ntree)
|
|
1259
|
+
prune_lk_ratio = compute_likelihood_ratio(prune_resid_total, prune_resid_left, prune_resid_right, move_counts['prune']['total'], move_counts['prune']['left'], move_counts['prune']['right'], sigma2, ntree)
|
|
887
1260
|
|
|
888
1261
|
# compute acceptance ratios
|
|
889
|
-
grow_ratio =
|
|
890
|
-
|
|
1262
|
+
grow_ratio = grow_move['trans_prior_ratio'] * grow_lk_ratio
|
|
1263
|
+
if min_points_per_leaf is not None:
|
|
1264
|
+
grow_ratio = jnp.where(move_counts['grow']['left'] >= min_points_per_leaf, grow_ratio, 0)
|
|
1265
|
+
grow_ratio = jnp.where(move_counts['grow']['right'] >= min_points_per_leaf, grow_ratio, 0)
|
|
1266
|
+
prune_ratio = prune_move['trans_prior_ratio'] * prune_lk_ratio
|
|
891
1267
|
prune_ratio = lax.reciprocal(prune_ratio)
|
|
892
1268
|
|
|
893
|
-
#
|
|
894
|
-
|
|
895
|
-
|
|
1269
|
+
# save acceptance ratios
|
|
1270
|
+
ratios = {}
|
|
1271
|
+
if save_ratios:
|
|
1272
|
+
ratios.update(
|
|
1273
|
+
grow=dict(
|
|
1274
|
+
trans_prior=grow_move['trans_prior_ratio'],
|
|
1275
|
+
likelihood=grow_lk_ratio,
|
|
1276
|
+
),
|
|
1277
|
+
prune=dict(
|
|
1278
|
+
trans_prior=lax.reciprocal(prune_move['trans_prior_ratio']),
|
|
1279
|
+
likelihood=lax.reciprocal(prune_lk_ratio),
|
|
1280
|
+
),
|
|
1281
|
+
)
|
|
896
1282
|
|
|
897
1283
|
# determine what move to propose (not proposing anything is an option)
|
|
898
|
-
|
|
899
|
-
|
|
1284
|
+
grow_allowed = grow_move['num_growable'].astype(bool)
|
|
1285
|
+
p_grow = jnp.where(grow_allowed & prune_move['allowed'], 0.5, grow_allowed)
|
|
1286
|
+
try_grow = u[0] < p_grow # use < instead of <= because coins are in [0, 1)
|
|
900
1287
|
try_prune = prune_move['allowed'] & ~try_grow
|
|
901
1288
|
|
|
902
1289
|
# determine whether to accept the move
|
|
903
|
-
do_grow = try_grow & (
|
|
904
|
-
do_prune = try_prune & (
|
|
905
|
-
|
|
906
|
-
# pick
|
|
907
|
-
|
|
908
|
-
split_tree = jnp.where(do_grow,
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
resid_tree = jnp.where(do_grow, grow_resid_tree, resid_tree)
|
|
916
|
-
count_tree = jnp.where(do_grow, grow_count_tree, count_tree)
|
|
917
|
-
resid_tree = jnp.where(do_prune, prune_resid_tree, resid_tree)
|
|
918
|
-
count_tree = jnp.where(do_prune, prune_count_tree, count_tree)
|
|
919
|
-
|
|
920
|
-
# update acceptance counts
|
|
921
|
-
counts = counts.copy()
|
|
922
|
-
counts['grow_prop_count'] += try_grow
|
|
923
|
-
counts['grow_acc_count'] += do_grow
|
|
924
|
-
counts['prune_prop_count'] += try_prune
|
|
925
|
-
counts['prune_acc_count'] += do_prune
|
|
926
|
-
|
|
927
|
-
# compute leaves posterior
|
|
928
|
-
prec_lk = count_tree / sigma2
|
|
1290
|
+
do_grow = try_grow & (u[1] < grow_ratio)
|
|
1291
|
+
do_prune = try_prune & (u[1] < prune_ratio)
|
|
1292
|
+
|
|
1293
|
+
# pick split tree for chosen move
|
|
1294
|
+
split_tree = grow_move['split_tree']
|
|
1295
|
+
split_tree = split_tree.at[jnp.where(do_grow, split_tree.size, grow_node)].set(0)
|
|
1296
|
+
split_tree = split_tree.at[jnp.where(do_prune, prune_node, split_tree.size)].set(0)
|
|
1297
|
+
# I can leave garbage in var_tree, resid_tree, count_tree
|
|
1298
|
+
|
|
1299
|
+
# compute leaves posterior and sample leaves
|
|
1300
|
+
inv_sigma2 = lax.reciprocal(sigma2)
|
|
1301
|
+
prec_lk = count_tree * inv_sigma2
|
|
929
1302
|
var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
|
|
930
|
-
mean_post = resid_tree
|
|
931
|
-
|
|
932
|
-
# sample leaves
|
|
933
|
-
z = random.normal(key, mean_post.shape, mean_post.dtype)
|
|
1303
|
+
mean_post = resid_tree * inv_sigma2 * var_post # = mean_lk * prec_lk * var_post
|
|
1304
|
+
initial_leaf_tree = leaf_tree
|
|
934
1305
|
leaf_tree = mean_post + z * jnp.sqrt(var_post)
|
|
935
1306
|
|
|
936
|
-
#
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
1307
|
+
# copy leaves around such that the grow leaf indices select the right leaf
|
|
1308
|
+
leaf_tree = (leaf_tree
|
|
1309
|
+
.at[jnp.where(do_prune, prune_left, leaf_tree.size)]
|
|
1310
|
+
.set(leaf_tree[prune_node])
|
|
1311
|
+
.at[jnp.where(do_prune, prune_right, leaf_tree.size)]
|
|
1312
|
+
.set(leaf_tree[prune_node])
|
|
1313
|
+
)
|
|
1314
|
+
leaf_tree = (leaf_tree
|
|
1315
|
+
.at[jnp.where(do_grow, leaf_tree.size, grow_left)]
|
|
1316
|
+
.set(leaf_tree[grow_node])
|
|
1317
|
+
.at[jnp.where(do_grow, leaf_tree.size, grow_right)]
|
|
1318
|
+
.set(leaf_tree[grow_node])
|
|
1319
|
+
)
|
|
1320
|
+
|
|
1321
|
+
# replace old tree with new tree in function values
|
|
1322
|
+
resid += (initial_leaf_tree - leaf_tree)[grow_leaf_indices]
|
|
940
1323
|
|
|
941
|
-
# pack
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
1324
|
+
# pack proposal and acceptance indicators
|
|
1325
|
+
counts = dict(
|
|
1326
|
+
grow_prop_count=try_grow,
|
|
1327
|
+
grow_acc_count=do_grow,
|
|
1328
|
+
prune_prop_count=try_prune,
|
|
1329
|
+
prune_acc_count=do_prune,
|
|
1330
|
+
)
|
|
947
1331
|
|
|
948
|
-
return resid, counts,
|
|
1332
|
+
return resid, leaf_tree, split_tree, counts, ratios
|
|
949
1333
|
|
|
950
|
-
def
|
|
1334
|
+
def sum_resid(resid, leaf_indices, tree_size, batch_size):
|
|
951
1335
|
"""
|
|
952
|
-
|
|
1336
|
+
Sum the residuals in each leaf.
|
|
953
1337
|
|
|
954
1338
|
Parameters
|
|
955
1339
|
----------
|
|
@@ -960,104 +1344,56 @@ def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
|
|
|
960
1344
|
tree_size : int
|
|
961
1345
|
The size of the tree array (2 ** d).
|
|
962
1346
|
batch_size : int, None
|
|
963
|
-
The batch size for the aggregation. Batching increases numerical
|
|
1347
|
+
The data batch size for the aggregation. Batching increases numerical
|
|
964
1348
|
accuracy and parallelism.
|
|
965
1349
|
|
|
966
1350
|
Returns
|
|
967
1351
|
-------
|
|
968
1352
|
resid_tree : float array (2 ** d,)
|
|
969
1353
|
The sum of the residuals at data points in each leaf.
|
|
970
|
-
count_tree : int array (2 ** d,)
|
|
971
|
-
The number of data points in each leaf.
|
|
972
1354
|
"""
|
|
973
1355
|
if batch_size is None:
|
|
974
1356
|
aggr_func = _aggregate_scatter
|
|
975
1357
|
else:
|
|
976
|
-
aggr_func = functools.partial(
|
|
977
|
-
|
|
978
|
-
count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
|
|
979
|
-
return resid_tree, count_tree
|
|
980
|
-
|
|
981
|
-
def _aggregate_scatter(values, indices, size, dtype):
|
|
982
|
-
return (jnp
|
|
983
|
-
.zeros(size, dtype)
|
|
984
|
-
.at[indices]
|
|
985
|
-
.add(values)
|
|
986
|
-
)
|
|
1358
|
+
aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
|
|
1359
|
+
return aggr_func(resid, leaf_indices, tree_size, jnp.float32)
|
|
987
1360
|
|
|
988
|
-
def
|
|
989
|
-
|
|
990
|
-
|
|
1361
|
+
def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
|
|
1362
|
+
n, = indices.shape
|
|
1363
|
+
nbatches = n // batch_size + bool(n % batch_size)
|
|
1364
|
+
batch_indices = jnp.arange(n) % nbatches
|
|
991
1365
|
return (jnp
|
|
992
|
-
.zeros((
|
|
993
|
-
.at[
|
|
1366
|
+
.zeros((size, nbatches), dtype)
|
|
1367
|
+
.at[indices, batch_indices]
|
|
994
1368
|
.add(values)
|
|
995
|
-
.sum(axis=
|
|
1369
|
+
.sum(axis=1)
|
|
996
1370
|
)
|
|
997
1371
|
|
|
998
|
-
def
|
|
999
|
-
"""
|
|
1000
|
-
Compute the probability of proposing a prune move after doing a grow move.
|
|
1001
|
-
|
|
1002
|
-
Parameters
|
|
1003
|
-
----------
|
|
1004
|
-
new_split_tree : int array (2 ** (d - 1),)
|
|
1005
|
-
The decision boundaries of the tree, after the grow move.
|
|
1006
|
-
new_affluence_tree : bool array (2 ** (d - 1),)
|
|
1007
|
-
Which leaves have enough points to be grown, after the grow move.
|
|
1008
|
-
|
|
1009
|
-
Returns
|
|
1010
|
-
-------
|
|
1011
|
-
p_prune : float
|
|
1012
|
-
The probability of proposing a prune move after the grow move. This is
|
|
1013
|
-
0.5 if grow is possible again, and 1 if it isn't. It can't be 0 because
|
|
1014
|
-
at least the node just grown can be pruned.
|
|
1015
|
-
"""
|
|
1016
|
-
_, grow_again_allowed = growable_leaves(new_split_tree, new_affluence_tree)
|
|
1017
|
-
return jnp.where(grow_again_allowed, 0.5, 1)
|
|
1018
|
-
|
|
1019
|
-
def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_points_per_leaf):
|
|
1372
|
+
def compute_likelihood_ratio(total_resid, left_resid, right_resid, total_count, left_count, right_count, sigma2, n_tree):
|
|
1020
1373
|
"""
|
|
1021
1374
|
Compute the likelihood ratio of a grow move.
|
|
1022
1375
|
|
|
1023
1376
|
Parameters
|
|
1024
1377
|
----------
|
|
1025
|
-
|
|
1026
|
-
The sum of the residuals
|
|
1027
|
-
|
|
1028
|
-
The
|
|
1378
|
+
total_resid : float
|
|
1379
|
+
The sum of the residuals in the leaf to grow.
|
|
1380
|
+
left_resid, right_resid : float
|
|
1381
|
+
The sum of the residuals in the left/right child of the leaf to grow.
|
|
1382
|
+
total_count : int
|
|
1383
|
+
The number of datapoints in the leaf to grow.
|
|
1384
|
+
left_count, right_count : int
|
|
1385
|
+
The number of datapoints in the left/right child of the leaf to grow.
|
|
1029
1386
|
sigma2 : float
|
|
1030
1387
|
The noise variance.
|
|
1031
|
-
node : int
|
|
1032
|
-
The index of the leaf that has been grown.
|
|
1033
1388
|
n_tree : int
|
|
1034
1389
|
The number of trees in the forest.
|
|
1035
|
-
min_points_per_leaf : int or None
|
|
1036
|
-
The minimum number of data points in a leaf node.
|
|
1037
1390
|
|
|
1038
1391
|
Returns
|
|
1039
1392
|
-------
|
|
1040
1393
|
ratio : float
|
|
1041
1394
|
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
1042
|
-
|
|
1043
|
-
Notes
|
|
1044
|
-
-----
|
|
1045
|
-
The ratio is set to 0 if the grow move would create leaves with not enough
|
|
1046
|
-
datapoints per leaf, although this is part of the prior rather than the
|
|
1047
|
-
likelihood.
|
|
1048
1395
|
"""
|
|
1049
1396
|
|
|
1050
|
-
left_child = node << 1
|
|
1051
|
-
right_child = left_child + 1
|
|
1052
|
-
|
|
1053
|
-
left_resid = resid_tree[left_child]
|
|
1054
|
-
right_resid = resid_tree[right_child]
|
|
1055
|
-
total_resid = left_resid + right_resid
|
|
1056
|
-
|
|
1057
|
-
left_count = count_tree[left_child]
|
|
1058
|
-
right_count = count_tree[right_child]
|
|
1059
|
-
total_count = left_count + right_count
|
|
1060
|
-
|
|
1061
1397
|
sigma_mu2 = 1 / n_tree
|
|
1062
1398
|
sigma2_left = sigma2 + left_count * sigma_mu2
|
|
1063
1399
|
sigma2_right = sigma2 + right_count * sigma_mu2
|
|
@@ -1071,13 +1407,67 @@ def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_p
|
|
|
1071
1407
|
total_resid * total_resid / sigma2_total
|
|
1072
1408
|
)
|
|
1073
1409
|
|
|
1074
|
-
|
|
1410
|
+
return jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
|
|
1075
1411
|
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1412
|
+
def accept_moves_final_stage(bart, counts, grow_moves, prune_moves):
|
|
1413
|
+
"""
|
|
1414
|
+
The final part of accepting the moves, in parallel across trees.
|
|
1415
|
+
|
|
1416
|
+
Parameters
|
|
1417
|
+
----------
|
|
1418
|
+
bart : dict
|
|
1419
|
+
A partially updated BART mcmc state.
|
|
1420
|
+
counts : dict
|
|
1421
|
+
The indicators of proposals and acceptances for grow and prune moves.
|
|
1422
|
+
grow_moves, prune_moves : dict
|
|
1423
|
+
The proposals for the moves. See `grow_move` and `prune_move`.
|
|
1424
|
+
|
|
1425
|
+
Returns
|
|
1426
|
+
-------
|
|
1427
|
+
bart : dict
|
|
1428
|
+
The fully updated BART mcmc state.
|
|
1429
|
+
"""
|
|
1430
|
+
bart = bart.copy()
|
|
1431
|
+
|
|
1432
|
+
for k, v in counts.items():
|
|
1433
|
+
bart[k] = jnp.sum(v, axis=0)
|
|
1079
1434
|
|
|
1080
|
-
|
|
1435
|
+
bart['leaf_indices'] = apply_moves_to_indices(bart['leaf_indices'], counts, grow_moves, prune_moves)
|
|
1436
|
+
|
|
1437
|
+
return bart
|
|
1438
|
+
|
|
1439
|
+
def apply_moves_to_indices(leaf_indices, counts, grow_moves, prune_moves):
|
|
1440
|
+
"""
|
|
1441
|
+
Update the leaf indices to match the accepted move.
|
|
1442
|
+
|
|
1443
|
+
Parameters
|
|
1444
|
+
----------
|
|
1445
|
+
leaf_indices : int array (num_trees, n)
|
|
1446
|
+
The index of the leaf each datapoint falls into, if the grow move was
|
|
1447
|
+
accepted.
|
|
1448
|
+
counts : dict
|
|
1449
|
+
The indicators of proposals and acceptances for grow and prune moves.
|
|
1450
|
+
grow_moves, prune_moves : dict
|
|
1451
|
+
The proposals for the moves. See `grow_move` and `prune_move`.
|
|
1452
|
+
|
|
1453
|
+
Returns
|
|
1454
|
+
-------
|
|
1455
|
+
leaf_indices : int array (num_trees, n)
|
|
1456
|
+
The updated leaf indices.
|
|
1457
|
+
"""
|
|
1458
|
+
mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
|
|
1459
|
+
cond = (leaf_indices & mask) == grow_moves['left'][:, None]
|
|
1460
|
+
leaf_indices = jnp.where(
|
|
1461
|
+
cond & ~counts['grow_acc_count'][:, None],
|
|
1462
|
+
grow_moves['node'][:, None].astype(leaf_indices.dtype),
|
|
1463
|
+
leaf_indices,
|
|
1464
|
+
)
|
|
1465
|
+
cond = (leaf_indices & mask) == prune_moves['left'][:, None]
|
|
1466
|
+
return jnp.where(
|
|
1467
|
+
cond & counts['prune_acc_count'][:, None],
|
|
1468
|
+
prune_moves['node'][:, None].astype(leaf_indices.dtype),
|
|
1469
|
+
leaf_indices,
|
|
1470
|
+
)
|
|
1081
1471
|
|
|
1082
1472
|
def sample_sigma(bart, key):
|
|
1083
1473
|
"""
|
|
@@ -1099,7 +1489,7 @@ def sample_sigma(bart, key):
|
|
|
1099
1489
|
|
|
1100
1490
|
resid = bart['resid']
|
|
1101
1491
|
alpha = bart['sigma2_alpha'] + resid.size / 2
|
|
1102
|
-
norm2 = jnp.dot(resid, resid, preferred_element_type=bart['
|
|
1492
|
+
norm2 = jnp.dot(resid, resid, preferred_element_type=bart['opt']['large_float'])
|
|
1103
1493
|
beta = bart['sigma2_beta'] + norm2 / 2
|
|
1104
1494
|
|
|
1105
1495
|
sample = random.gamma(key, alpha)
|