bartz 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +43 -18
- bartz/_version.py +1 -1
- bartz/grove.py +19 -14
- bartz/jaxext.py +48 -21
- bartz/mcmcloop.py +13 -15
- bartz/mcmcstep.py +681 -299
- bartz/prepcovars.py +43 -13
- bartz-0.3.0.dist-info/METADATA +77 -0
- bartz-0.3.0.dist-info/RECORD +13 -0
- bartz-0.2.1.dist-info/METADATA +0 -32
- bartz-0.2.1.dist-info/RECORD +0 -13
- {bartz-0.2.1.dist-info → bartz-0.3.0.dist-info}/LICENSE +0 -0
- {bartz-0.2.1.dist-info → bartz-0.3.0.dist-info}/WHEEL +0 -0
bartz/mcmcstep.py
CHANGED
|
@@ -10,10 +10,10 @@
|
|
|
10
10
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
11
|
# copies of the Software, and to permit persons to whom the Software is
|
|
12
12
|
# furnished to do so, subject to the following conditions:
|
|
13
|
-
#
|
|
13
|
+
#
|
|
14
14
|
# The above copyright notice and this permission notice shall be included in all
|
|
15
15
|
# copies or substantial portions of the Software.
|
|
16
|
-
#
|
|
16
|
+
#
|
|
17
17
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
18
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
19
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
@@ -34,6 +34,7 @@ range of possible values.
|
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
36
|
import functools
|
|
37
|
+
import math
|
|
37
38
|
|
|
38
39
|
import jax
|
|
39
40
|
from jax import random
|
|
@@ -54,7 +55,9 @@ def init(*,
|
|
|
54
55
|
small_float=jnp.float32,
|
|
55
56
|
large_float=jnp.float32,
|
|
56
57
|
min_points_per_leaf=None,
|
|
57
|
-
|
|
58
|
+
resid_batch_size='auto',
|
|
59
|
+
count_batch_size='auto',
|
|
60
|
+
save_ratios=False,
|
|
58
61
|
):
|
|
59
62
|
"""
|
|
60
63
|
Make a BART posterior sampling MCMC initial state.
|
|
@@ -82,10 +85,13 @@ def init(*,
|
|
|
82
85
|
The dtype for scalars, small arrays, and arrays which require accuracy.
|
|
83
86
|
min_points_per_leaf : int, optional
|
|
84
87
|
The minimum number of data points in a leaf node. 0 if not specified.
|
|
85
|
-
|
|
86
|
-
The batch
|
|
87
|
-
|
|
88
|
-
|
|
88
|
+
resid_batch_size, count_batch_size : int, None, str, default 'auto'
|
|
89
|
+
The batch sizes, along datapoints, for summing the residuals and
|
|
90
|
+
counting the number of datapoints in each leaf. `None` for no batching.
|
|
91
|
+
If 'auto', pick a value based on the device of `y`, or the default
|
|
92
|
+
device.
|
|
93
|
+
save_ratios : bool, default False
|
|
94
|
+
Whether to save the Metropolis-Hastings ratios.
|
|
89
95
|
|
|
90
96
|
Returns
|
|
91
97
|
-------
|
|
@@ -111,6 +117,8 @@ def init(*,
|
|
|
111
117
|
'p_nonterminal' : large_float array (d,)
|
|
112
118
|
The probability of a nonterminal node at each depth, padded with a
|
|
113
119
|
zero.
|
|
120
|
+
'p_propose_grow' : large_float array (2 ** (d - 1),)
|
|
121
|
+
The unnormalized probability of picking a leaf for a grow proposal.
|
|
114
122
|
'sigma2_alpha' : large_float
|
|
115
123
|
The shape parameter of the inverse gamma prior on the noise variance.
|
|
116
124
|
'sigma2_beta' : large_float
|
|
@@ -121,6 +129,8 @@ def init(*,
|
|
|
121
129
|
The response.
|
|
122
130
|
'X' : int array (p, n)
|
|
123
131
|
The predictors.
|
|
132
|
+
'leaf_indices' : int array (num_trees, n)
|
|
133
|
+
The index of the leaf each datapoint falls into, for each tree.
|
|
124
134
|
'min_points_per_leaf' : int or None
|
|
125
135
|
The minimum number of data points in a leaf node.
|
|
126
136
|
'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
|
|
@@ -129,8 +139,6 @@ def init(*,
|
|
|
129
139
|
'opt' : LeafDict
|
|
130
140
|
A dictionary with config values:
|
|
131
141
|
|
|
132
|
-
'suffstat_batch_size' : int or None
|
|
133
|
-
The batch size for computing sufficient statistics.
|
|
134
142
|
'small_float' : dtype
|
|
135
143
|
The dtype for large arrays used in the algorithm.
|
|
136
144
|
'large_float' : dtype
|
|
@@ -138,6 +146,8 @@ def init(*,
|
|
|
138
146
|
accuracy.
|
|
139
147
|
'require_min_points' : bool
|
|
140
148
|
Whether the `min_points_per_leaf` parameter is specified.
|
|
149
|
+
'resid_batch_size', 'count_batch_size' : int or None
|
|
150
|
+
The data batch sizes for computing the sufficient statistics.
|
|
141
151
|
"""
|
|
142
152
|
|
|
143
153
|
p_nonterminal = jnp.asarray(p_nonterminal, large_float)
|
|
@@ -151,24 +161,28 @@ def init(*,
|
|
|
151
161
|
small_float = jnp.dtype(small_float)
|
|
152
162
|
large_float = jnp.dtype(large_float)
|
|
153
163
|
y = jnp.asarray(y, small_float)
|
|
154
|
-
|
|
164
|
+
resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y)
|
|
165
|
+
sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
|
|
166
|
+
sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
|
|
155
167
|
|
|
156
168
|
bart = dict(
|
|
157
169
|
leaf_trees=make_forest(max_depth, small_float),
|
|
158
170
|
var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
|
|
159
171
|
split_trees=make_forest(max_depth - 1, max_split.dtype),
|
|
160
172
|
resid=jnp.asarray(y, large_float),
|
|
161
|
-
sigma2=
|
|
173
|
+
sigma2=sigma2,
|
|
162
174
|
grow_prop_count=jnp.zeros((), int),
|
|
163
175
|
grow_acc_count=jnp.zeros((), int),
|
|
164
176
|
prune_prop_count=jnp.zeros((), int),
|
|
165
177
|
prune_acc_count=jnp.zeros((), int),
|
|
166
178
|
p_nonterminal=p_nonterminal,
|
|
179
|
+
p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
|
|
167
180
|
sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
|
|
168
181
|
sigma2_beta=jnp.asarray(sigma2_beta, large_float),
|
|
169
182
|
max_split=jnp.asarray(max_split),
|
|
170
183
|
y=y,
|
|
171
184
|
X=jnp.asarray(X),
|
|
185
|
+
leaf_indices=jnp.ones((num_trees, y.size), jaxext.minimal_unsigned_dtype(2 ** max_depth - 1)),
|
|
172
186
|
min_points_per_leaf=(
|
|
173
187
|
None if min_points_per_leaf is None else
|
|
174
188
|
jnp.asarray(min_points_per_leaf)
|
|
@@ -178,37 +192,61 @@ def init(*,
|
|
|
178
192
|
make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
|
|
179
193
|
),
|
|
180
194
|
opt=jaxext.LeafDict(
|
|
181
|
-
suffstat_batch_size=suffstat_batch_size,
|
|
182
195
|
small_float=small_float,
|
|
183
196
|
large_float=large_float,
|
|
184
197
|
require_min_points=min_points_per_leaf is not None,
|
|
198
|
+
resid_batch_size=resid_batch_size,
|
|
199
|
+
count_batch_size=count_batch_size,
|
|
185
200
|
),
|
|
186
201
|
)
|
|
187
202
|
|
|
203
|
+
if save_ratios:
|
|
204
|
+
bart['ratios'] = dict(
|
|
205
|
+
grow=dict(
|
|
206
|
+
trans_prior=jnp.full(num_trees, jnp.nan),
|
|
207
|
+
likelihood=jnp.full(num_trees, jnp.nan),
|
|
208
|
+
),
|
|
209
|
+
prune=dict(
|
|
210
|
+
trans_prior=jnp.full(num_trees, jnp.nan),
|
|
211
|
+
likelihood=jnp.full(num_trees, jnp.nan),
|
|
212
|
+
),
|
|
213
|
+
)
|
|
214
|
+
|
|
188
215
|
return bart
|
|
189
216
|
|
|
190
|
-
def _choose_suffstat_batch_size(
|
|
191
|
-
|
|
217
|
+
def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y):
|
|
218
|
+
|
|
219
|
+
@functools.cache
|
|
220
|
+
def get_platform():
|
|
192
221
|
try:
|
|
193
222
|
device = y.devices().pop()
|
|
194
223
|
except jax.errors.ConcretizationTypeError:
|
|
195
224
|
device = jax.devices()[0]
|
|
196
225
|
platform = device.platform
|
|
226
|
+
if platform not in ('cpu', 'gpu'):
|
|
227
|
+
raise KeyError(f'Unknown platform: {platform}')
|
|
228
|
+
return platform
|
|
197
229
|
|
|
230
|
+
if resid_batch_size == 'auto':
|
|
231
|
+
platform = get_platform()
|
|
232
|
+
n = max(1, y.size)
|
|
198
233
|
if platform == 'cpu':
|
|
199
|
-
|
|
200
|
-
# maybe I should batch residuals (not counts) for numerical
|
|
201
|
-
# accuracy, even if it's slower
|
|
234
|
+
resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
|
|
202
235
|
elif platform == 'gpu':
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
236
|
+
resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
|
|
237
|
+
resid_batch_size = max(1, resid_batch_size)
|
|
238
|
+
|
|
239
|
+
if count_batch_size == 'auto':
|
|
240
|
+
platform = get_platform()
|
|
241
|
+
if platform == 'cpu':
|
|
242
|
+
count_batch_size = None
|
|
243
|
+
elif platform == 'gpu':
|
|
244
|
+
n = max(1, y.size)
|
|
245
|
+
count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
|
|
246
|
+
# /4 is good on V100, /2 on L4/T4, still haven't tried A100
|
|
247
|
+
count_batch_size = max(1, count_batch_size)
|
|
248
|
+
|
|
249
|
+
return resid_batch_size, count_batch_size
|
|
212
250
|
|
|
213
251
|
def step(bart, key):
|
|
214
252
|
"""
|
|
@@ -248,14 +286,11 @@ def sample_trees(bart, key):
|
|
|
248
286
|
|
|
249
287
|
Notes
|
|
250
288
|
-----
|
|
251
|
-
This function zeroes the proposal counters
|
|
289
|
+
This function zeroes the proposal counters.
|
|
252
290
|
"""
|
|
253
|
-
bart = bart.copy()
|
|
254
291
|
key, subkey = random.split(key)
|
|
255
292
|
grow_moves, prune_moves = sample_moves(bart, subkey)
|
|
256
|
-
bart
|
|
257
|
-
grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
|
|
258
|
-
return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
|
|
293
|
+
return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key)
|
|
259
294
|
|
|
260
295
|
def sample_moves(bart, key):
|
|
261
296
|
"""
|
|
@@ -274,17 +309,17 @@ def sample_moves(bart, key):
|
|
|
274
309
|
The proposals for grow and prune moves. See `grow_move` and `prune_move`.
|
|
275
310
|
"""
|
|
276
311
|
key = random.split(key, bart['var_trees'].shape[0])
|
|
277
|
-
return
|
|
312
|
+
return _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], key)
|
|
278
313
|
|
|
279
|
-
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, 0))
|
|
280
|
-
def
|
|
314
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
|
|
315
|
+
def _sample_moves_vmap_trees(*args):
|
|
316
|
+
args, key = args[:-1], args[-1]
|
|
281
317
|
key, key1 = random.split(key)
|
|
282
|
-
args = var_tree, split_tree, affluence_tree, max_split, p_nonterminal
|
|
283
318
|
grow = grow_move(*args, key)
|
|
284
319
|
prune = prune_move(*args, key1)
|
|
285
320
|
return grow, prune
|
|
286
321
|
|
|
287
|
-
def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
322
|
+
def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
|
|
288
323
|
"""
|
|
289
324
|
Tree structure grow move proposal of BART MCMC.
|
|
290
325
|
|
|
@@ -304,6 +339,8 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
|
|
|
304
339
|
The maximum split index for each variable.
|
|
305
340
|
p_nonterminal : array (d,)
|
|
306
341
|
The probability of a nonterminal node at each depth.
|
|
342
|
+
p_propose_grow : array (2 ** (d - 1),)
|
|
343
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
307
344
|
key : jax.dtypes.prng_key array
|
|
308
345
|
A jax random key.
|
|
309
346
|
|
|
@@ -312,41 +349,49 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
|
|
|
312
349
|
grow_move : dict
|
|
313
350
|
A dictionary with fields:
|
|
314
351
|
|
|
315
|
-
'
|
|
316
|
-
|
|
352
|
+
'num_growable' : int
|
|
353
|
+
The number of growable leaves.
|
|
317
354
|
'node' : int
|
|
318
|
-
The index of the leaf to grow.
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
355
|
+
The index of the leaf to grow. ``2 ** d`` if there are no growable
|
|
356
|
+
leaves.
|
|
357
|
+
'left', 'right' : int
|
|
358
|
+
The indices of the children of 'node'.
|
|
359
|
+
'var', 'split' : int
|
|
360
|
+
The decision axis and boundary of the new rule.
|
|
323
361
|
'partial_ratio' : float
|
|
324
362
|
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
325
363
|
the likelihood ratio and the probability of proposing the prune
|
|
326
364
|
move.
|
|
365
|
+
'var_tree', 'split_tree' : array (2 ** (d - 1),)
|
|
366
|
+
The updated decision axes and boundaries of the tree.
|
|
327
367
|
"""
|
|
328
368
|
|
|
329
369
|
key, key1, key2 = random.split(key, 3)
|
|
330
|
-
|
|
331
|
-
leaf_to_grow, num_growable,
|
|
370
|
+
|
|
371
|
+
leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(split_tree, affluence_tree, p_propose_grow, key)
|
|
332
372
|
|
|
333
373
|
var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
|
|
334
374
|
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
335
|
-
|
|
375
|
+
|
|
336
376
|
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
|
|
337
377
|
split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
|
|
338
378
|
|
|
339
|
-
ratio = compute_partial_ratio(
|
|
379
|
+
ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
|
|
340
380
|
|
|
381
|
+
left = leaf_to_grow << 1
|
|
341
382
|
return dict(
|
|
342
|
-
|
|
383
|
+
num_growable=num_growable,
|
|
343
384
|
node=leaf_to_grow,
|
|
385
|
+
left=left,
|
|
386
|
+
right=left + 1,
|
|
387
|
+
var=var,
|
|
388
|
+
split=split,
|
|
344
389
|
partial_ratio=ratio,
|
|
345
390
|
var_tree=var_tree,
|
|
346
391
|
split_tree=split_tree,
|
|
347
392
|
)
|
|
348
393
|
|
|
349
|
-
def choose_leaf(split_tree, affluence_tree, key):
|
|
394
|
+
def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
|
|
350
395
|
"""
|
|
351
396
|
Choose a leaf node to grow in a tree.
|
|
352
397
|
|
|
@@ -356,6 +401,8 @@ def choose_leaf(split_tree, affluence_tree, key):
|
|
|
356
401
|
The splitting points of the tree.
|
|
357
402
|
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
358
403
|
Whether a leaf has enough points to be grown.
|
|
404
|
+
p_propose_grow : array (2 ** (d - 1),)
|
|
405
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
359
406
|
key : jax.dtypes.prng_key array
|
|
360
407
|
A jax random key.
|
|
361
408
|
|
|
@@ -366,19 +413,21 @@ def choose_leaf(split_tree, affluence_tree, key):
|
|
|
366
413
|
``2 ** d``.
|
|
367
414
|
num_growable : int
|
|
368
415
|
The number of leaf nodes that can be grown.
|
|
416
|
+
prob_choose : float
|
|
417
|
+
The normalized probability of choosing the selected leaf.
|
|
369
418
|
num_prunable : int
|
|
370
419
|
The number of leaf parents that could be pruned, after converting the
|
|
371
420
|
selected leaf to a non-terminal node.
|
|
372
|
-
allowed : bool
|
|
373
|
-
Whether the grow move is allowed.
|
|
374
421
|
"""
|
|
375
|
-
is_growable
|
|
376
|
-
leaf_to_grow = randint_masked(key, is_growable)
|
|
377
|
-
leaf_to_grow = jnp.where(allowed, leaf_to_grow, 2 * split_tree.size)
|
|
422
|
+
is_growable = growable_leaves(split_tree, affluence_tree)
|
|
378
423
|
num_growable = jnp.count_nonzero(is_growable)
|
|
424
|
+
distr = jnp.where(is_growable, p_propose_grow, 0)
|
|
425
|
+
leaf_to_grow, distr_norm = categorical(key, distr)
|
|
426
|
+
leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
|
|
427
|
+
prob_choose = distr[leaf_to_grow] / distr_norm
|
|
379
428
|
is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
|
|
380
429
|
num_prunable = jnp.count_nonzero(is_parent)
|
|
381
|
-
return leaf_to_grow, num_growable,
|
|
430
|
+
return leaf_to_grow, num_growable, prob_choose, num_prunable
|
|
382
431
|
|
|
383
432
|
def growable_leaves(split_tree, affluence_tree):
|
|
384
433
|
"""
|
|
@@ -397,34 +446,32 @@ def growable_leaves(split_tree, affluence_tree):
|
|
|
397
446
|
The mask indicating the leaf nodes that can be proposed to grow, i.e.,
|
|
398
447
|
that are not at the bottom level and have at least two times the number
|
|
399
448
|
of minimum points per leaf.
|
|
400
|
-
allowed : bool
|
|
401
|
-
Whether the grow move is allowed, i.e., there are growable leaves.
|
|
402
449
|
"""
|
|
403
450
|
is_growable = grove.is_actual_leaf(split_tree)
|
|
404
451
|
if affluence_tree is not None:
|
|
405
452
|
is_growable &= affluence_tree
|
|
406
|
-
return is_growable
|
|
453
|
+
return is_growable
|
|
407
454
|
|
|
408
|
-
def
|
|
455
|
+
def categorical(key, distr):
|
|
409
456
|
"""
|
|
410
|
-
Return a random integer
|
|
457
|
+
Return a random integer from an arbitrary distribution.
|
|
411
458
|
|
|
412
459
|
Parameters
|
|
413
460
|
----------
|
|
414
461
|
key : jax.dtypes.prng_key array
|
|
415
462
|
A jax random key.
|
|
416
|
-
|
|
417
|
-
|
|
463
|
+
distr : float array (n,)
|
|
464
|
+
An unnormalized probability distribution.
|
|
418
465
|
|
|
419
466
|
Returns
|
|
420
467
|
-------
|
|
421
468
|
u : int
|
|
422
|
-
A random integer in the range ``[0, n)
|
|
423
|
-
|
|
469
|
+
A random integer in the range ``[0, n)``. If all probabilities are zero,
|
|
470
|
+
return ``n``.
|
|
424
471
|
"""
|
|
425
|
-
ecdf = jnp.cumsum(
|
|
426
|
-
u = random.
|
|
427
|
-
return jnp.searchsorted(ecdf, u, 'right')
|
|
472
|
+
ecdf = jnp.cumsum(distr)
|
|
473
|
+
u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
|
|
474
|
+
return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
|
|
428
475
|
|
|
429
476
|
def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
|
|
430
477
|
"""
|
|
@@ -479,7 +526,7 @@ def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
|
|
|
479
526
|
filled with `p`. The fill values are not guaranteed to be placed in any
|
|
480
527
|
particular order. Variables may appear more than once.
|
|
481
528
|
"""
|
|
482
|
-
|
|
529
|
+
|
|
483
530
|
var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
|
|
484
531
|
split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
|
|
485
532
|
l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
|
|
@@ -611,7 +658,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
|
|
|
611
658
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
612
659
|
return random.randint(key, (), l, r)
|
|
613
660
|
|
|
614
|
-
def compute_partial_ratio(
|
|
661
|
+
def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
|
|
615
662
|
"""
|
|
616
663
|
Compute the product of the transition and prior ratios of a grow move.
|
|
617
664
|
|
|
@@ -640,6 +687,9 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
640
687
|
# the two ratios also contain factors num_available_split *
|
|
641
688
|
# num_available_var, but they cancel out
|
|
642
689
|
|
|
690
|
+
# p_prune can't be computed here because it needs the count trees, which are
|
|
691
|
+
# computed in the acceptance phase
|
|
692
|
+
|
|
643
693
|
prune_allowed = leaf_to_grow != 1
|
|
644
694
|
# prune allowed <---> the initial tree is not a root
|
|
645
695
|
# leaf to grow is root --> the tree can only be a root
|
|
@@ -647,31 +697,33 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
647
697
|
|
|
648
698
|
p_grow = jnp.where(prune_allowed, 0.5, 1)
|
|
649
699
|
|
|
650
|
-
|
|
700
|
+
inv_trans_ratio = p_grow * prob_choose * num_prunable
|
|
651
701
|
|
|
652
702
|
depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
|
|
653
703
|
p_parent = p_nonterminal[depth]
|
|
654
704
|
cp_children = 1 - p_nonterminal[depth + 1]
|
|
655
705
|
tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
|
|
656
706
|
|
|
657
|
-
return
|
|
707
|
+
return tree_ratio / inv_trans_ratio
|
|
658
708
|
|
|
659
|
-
def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
|
|
709
|
+
def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
|
|
660
710
|
"""
|
|
661
711
|
Tree structure prune move proposal of BART MCMC.
|
|
662
712
|
|
|
663
713
|
Parameters
|
|
664
714
|
----------
|
|
665
|
-
var_tree : array (2 ** (d - 1),)
|
|
715
|
+
var_tree : int array (2 ** (d - 1),)
|
|
666
716
|
The variable indices of the tree.
|
|
667
|
-
split_tree : array (2 ** (d - 1),)
|
|
717
|
+
split_tree : int array (2 ** (d - 1),)
|
|
668
718
|
The splitting points of the tree.
|
|
669
719
|
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
670
720
|
Whether a leaf has enough points to be grown.
|
|
671
|
-
max_split : array (p,)
|
|
721
|
+
max_split : int array (p,)
|
|
672
722
|
The maximum split index for each variable.
|
|
673
|
-
p_nonterminal : array (d,)
|
|
723
|
+
p_nonterminal : float array (d,)
|
|
674
724
|
The probability of a nonterminal node at each depth.
|
|
725
|
+
p_propose_grow : float array (2 ** (d - 1),)
|
|
726
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
675
727
|
key : jax.dtypes.prng_key array
|
|
676
728
|
A jax random key.
|
|
677
729
|
|
|
@@ -683,24 +735,29 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
|
|
|
683
735
|
'allowed' : bool
|
|
684
736
|
Whether the move is possible.
|
|
685
737
|
'node' : int
|
|
686
|
-
The index of the node to prune.
|
|
738
|
+
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
739
|
+
'left', 'right' : int
|
|
740
|
+
The indices of the children of 'node'.
|
|
687
741
|
'partial_ratio' : float
|
|
688
742
|
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
689
743
|
the likelihood ratio and the probability of proposing the prune
|
|
690
744
|
move. This ratio is inverted.
|
|
691
745
|
"""
|
|
692
|
-
node_to_prune, num_prunable,
|
|
746
|
+
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
|
|
693
747
|
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
694
748
|
|
|
695
|
-
ratio = compute_partial_ratio(
|
|
749
|
+
ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune, split_tree)
|
|
696
750
|
|
|
751
|
+
left = node_to_prune << 1
|
|
697
752
|
return dict(
|
|
698
753
|
allowed=allowed,
|
|
699
754
|
node=node_to_prune,
|
|
755
|
+
left=left,
|
|
756
|
+
right=left + 1,
|
|
700
757
|
partial_ratio=ratio, # it is inverted in accept_moves_and_sample_leaves
|
|
701
758
|
)
|
|
702
759
|
|
|
703
|
-
def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
760
|
+
def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
|
|
704
761
|
"""
|
|
705
762
|
Pick a non-terminal node with leaf children to prune in a tree.
|
|
706
763
|
|
|
@@ -710,6 +767,8 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
|
710
767
|
The splitting points of the tree.
|
|
711
768
|
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
712
769
|
Whether a leaf has enough points to be grown.
|
|
770
|
+
p_propose_grow : array (2 ** (d - 1),)
|
|
771
|
+
The unnormalized probability of choosing a leaf to grow.
|
|
713
772
|
key : jax.dtypes.prng_key array
|
|
714
773
|
A jax random key.
|
|
715
774
|
|
|
@@ -717,28 +776,50 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
|
717
776
|
-------
|
|
718
777
|
node_to_prune : int
|
|
719
778
|
The index of the node to prune. If ``num_prunable == 0``, return
|
|
720
|
-
``
|
|
779
|
+
``2 ** d``.
|
|
721
780
|
num_prunable : int
|
|
722
781
|
The number of leaf parents that could be pruned.
|
|
723
|
-
|
|
724
|
-
The
|
|
725
|
-
node.
|
|
782
|
+
prob_choose : float
|
|
783
|
+
The normalized probability that a grow move on the pruned tree would choose this node.
|
|
726
784
|
"""
|
|
727
785
|
is_prunable = grove.is_leaves_parent(split_tree)
|
|
728
|
-
node_to_prune = randint_masked(key, is_prunable)
|
|
729
786
|
num_prunable = jnp.count_nonzero(is_prunable)
|
|
787
|
+
node_to_prune = randint_masked(key, is_prunable)
|
|
788
|
+
node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
|
|
730
789
|
|
|
731
|
-
|
|
732
|
-
|
|
790
|
+
split_tree = split_tree.at[node_to_prune].set(0)
|
|
791
|
+
affluence_tree = (
|
|
733
792
|
None if affluence_tree is None else
|
|
734
793
|
affluence_tree.at[node_to_prune].set(True)
|
|
735
794
|
)
|
|
736
|
-
is_growable_leaf
|
|
737
|
-
|
|
795
|
+
is_growable_leaf = growable_leaves(split_tree, affluence_tree)
|
|
796
|
+
prob_choose = p_propose_grow[node_to_prune]
|
|
797
|
+
prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
|
|
798
|
+
|
|
799
|
+
return node_to_prune, num_prunable, prob_choose
|
|
800
|
+
|
|
801
|
+
def randint_masked(key, mask):
|
|
802
|
+
"""
|
|
803
|
+
Return a random integer in a range, including only some values.
|
|
804
|
+
|
|
805
|
+
Parameters
|
|
806
|
+
----------
|
|
807
|
+
key : jax.dtypes.prng_key array
|
|
808
|
+
A jax random key.
|
|
809
|
+
mask : bool array (n,)
|
|
810
|
+
The mask indicating the allowed values.
|
|
738
811
|
|
|
739
|
-
|
|
812
|
+
Returns
|
|
813
|
+
-------
|
|
814
|
+
u : int
|
|
815
|
+
A random integer in the range ``[0, n)``, and which satisfies
|
|
816
|
+
``mask[u] == True``. If all values in the mask are `False`, return `n`.
|
|
817
|
+
"""
|
|
818
|
+
ecdf = jnp.cumsum(mask)
|
|
819
|
+
u = random.randint(key, (), 0, ecdf[-1])
|
|
820
|
+
return jnp.searchsorted(ecdf, u, 'right')
|
|
740
821
|
|
|
741
|
-
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves,
|
|
822
|
+
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
742
823
|
"""
|
|
743
824
|
Accept or reject the proposed moves and sample the new leaf values.
|
|
744
825
|
|
|
@@ -752,8 +833,6 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
|
|
|
752
833
|
prune_moves : dict
|
|
753
834
|
The proposals for prune moves, batched over the first axis. See
|
|
754
835
|
`prune_move`.
|
|
755
|
-
grow_leaf_indices : int array (num_trees, n)
|
|
756
|
-
The leaf indices of the trees proposed by the grow move.
|
|
757
836
|
key : jax.dtypes.prng_key array
|
|
758
837
|
A jax random key.
|
|
759
838
|
|
|
@@ -762,41 +841,339 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
|
|
|
762
841
|
bart : dict
|
|
763
842
|
The new BART mcmc state.
|
|
764
843
|
"""
|
|
844
|
+
bart, grow_moves, prune_moves, count_trees, move_counts, u, z = accept_moves_parallel_stage(bart, grow_moves, prune_moves, key)
|
|
845
|
+
bart, counts = accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, move_counts, u, z)
|
|
846
|
+
return accept_moves_final_stage(bart, counts, grow_moves, prune_moves)
|
|
847
|
+
|
|
848
|
+
def accept_moves_parallel_stage(bart, grow_moves, prune_moves, key):
|
|
849
|
+
"""
|
|
850
|
+
Pre-computes quantities used to accept moves, in parallel across trees.
|
|
851
|
+
|
|
852
|
+
Parameters
|
|
853
|
+
----------
|
|
854
|
+
bart : dict
|
|
855
|
+
A BART mcmc state.
|
|
856
|
+
grow_moves, prune_moves : dict
|
|
857
|
+
The proposals for the moves, batched over the first axis. See
|
|
858
|
+
`grow_move` and `prune_move`.
|
|
859
|
+
key : jax.dtypes.prng_key array
|
|
860
|
+
A jax random key.
|
|
861
|
+
|
|
862
|
+
Returns
|
|
863
|
+
-------
|
|
864
|
+
bart : dict
|
|
865
|
+
A partially updated BART mcmc state.
|
|
866
|
+
grow_moves, prune_moves : dict
|
|
867
|
+
The proposals for the moves, with the field 'partial_ratio' replaced
|
|
868
|
+
by 'trans_prior_ratio'.
|
|
869
|
+
count_trees : array (num_trees, 2 ** d)
|
|
870
|
+
The number of points in each potential or actual leaf node.
|
|
871
|
+
move_counts : dict
|
|
872
|
+
The counts of the number of points in the nodes modified by the
|
|
873
|
+
moves.
|
|
874
|
+
u : float array (num_trees, 2)
|
|
875
|
+
Random uniform values used to accept the moves.
|
|
876
|
+
z : float array (num_trees, 2 ** d)
|
|
877
|
+
Random standard normal values used to sample the new leaf values.
|
|
878
|
+
"""
|
|
765
879
|
bart = bart.copy()
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
880
|
+
|
|
881
|
+
bart['var_trees'] = grow_moves['var_tree']
|
|
882
|
+
# Since var_tree can contain garbage, I can set the var of leaf to be
|
|
883
|
+
# grown irrespectively of what move I'm gonna accept in the end.
|
|
884
|
+
|
|
885
|
+
bart['leaf_indices'] = apply_grow_to_indices(grow_moves, bart['leaf_indices'], bart['X'])
|
|
886
|
+
|
|
887
|
+
count_trees, move_counts = compute_count_trees(bart['leaf_indices'], grow_moves, prune_moves, bart['opt']['count_batch_size'])
|
|
888
|
+
|
|
889
|
+
grow_moves, prune_moves = complete_ratio(grow_moves, prune_moves, move_counts, bart['min_points_per_leaf'])
|
|
890
|
+
|
|
891
|
+
if bart['opt']['require_min_points']:
|
|
892
|
+
count_half_trees = count_trees[:, :grow_moves['split_tree'].shape[1]]
|
|
893
|
+
bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
|
|
894
|
+
|
|
895
|
+
bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], grow_moves)
|
|
896
|
+
|
|
897
|
+
key, subkey = random.split(key)
|
|
898
|
+
u = random.uniform(subkey, (len(bart['leaf_trees']), 2), bart['opt']['large_float'])
|
|
899
|
+
z = random.normal(key, bart['leaf_trees'].shape, bart['opt']['large_float'])
|
|
900
|
+
|
|
901
|
+
return bart, grow_moves, prune_moves, count_trees, move_counts, u, z
|
|
902
|
+
|
|
903
|
+
def apply_grow_to_indices(grow_moves, leaf_indices, X):
|
|
904
|
+
"""
|
|
905
|
+
Update the leaf indices to apply a grow move.
|
|
906
|
+
|
|
907
|
+
Parameters
|
|
908
|
+
----------
|
|
909
|
+
grow_moves : dict
|
|
910
|
+
The proposals for grow moves. See `grow_move`.
|
|
911
|
+
leaf_indices : array (num_trees, n)
|
|
912
|
+
The index of the leaf each datapoint falls into.
|
|
913
|
+
X : array (p, n)
|
|
914
|
+
The predictors matrix.
|
|
915
|
+
|
|
916
|
+
Returns
|
|
917
|
+
-------
|
|
918
|
+
grow_leaf_indices : array (num_trees, n)
|
|
919
|
+
The updated leaf indices.
|
|
920
|
+
"""
|
|
921
|
+
left_child = grow_moves['node'].astype(leaf_indices.dtype) << 1
|
|
922
|
+
go_right = X[grow_moves['var'], :] >= grow_moves['split'][:, None]
|
|
923
|
+
tree_size = jnp.array(2 * grow_moves['split_tree'].shape[1])
|
|
924
|
+
node_to_update = jnp.where(grow_moves['num_growable'], grow_moves['node'], tree_size)
|
|
925
|
+
return jnp.where(
|
|
926
|
+
leaf_indices == node_to_update[:, None],
|
|
927
|
+
left_child[:, None] + go_right,
|
|
928
|
+
leaf_indices,
|
|
929
|
+
)
|
|
930
|
+
|
|
931
|
+
def compute_count_trees(grow_leaf_indices, grow_moves, prune_moves, batch_size):
|
|
932
|
+
"""
|
|
933
|
+
Count the number of datapoints in each leaf.
|
|
934
|
+
|
|
935
|
+
Parameters
|
|
936
|
+
----------
|
|
937
|
+
grow_leaf_indices : int array (num_trees, n)
|
|
938
|
+
The index of the leaf each datapoint falls into, if the grow move is
|
|
939
|
+
accepted.
|
|
940
|
+
grow_moves, prune_moves : dict
|
|
941
|
+
The proposals for the moves. See `grow_move` and `prune_move`.
|
|
942
|
+
batch_size : int or None
|
|
943
|
+
The data batch size to use for the summation.
|
|
944
|
+
|
|
945
|
+
Returns
|
|
946
|
+
-------
|
|
947
|
+
count_trees : int array (num_trees, 2 ** d)
|
|
948
|
+
The number of points in each potential or actual leaf node.
|
|
949
|
+
counts : dict
|
|
950
|
+
The counts of the number of points in the nodes modified by the
|
|
951
|
+
moves, organized as two dictionaries 'grow' and 'prune', with subfields
|
|
952
|
+
'left', 'right', and 'total'.
|
|
953
|
+
"""
|
|
954
|
+
|
|
955
|
+
ntree, tree_size = grow_moves['split_tree'].shape
|
|
956
|
+
tree_size *= 2
|
|
957
|
+
counts = dict(grow=dict(), prune=dict())
|
|
958
|
+
tree_indices = jnp.arange(ntree)
|
|
959
|
+
|
|
960
|
+
count_trees = count_datapoints_per_leaf(grow_leaf_indices, tree_size, batch_size)
|
|
961
|
+
|
|
962
|
+
# count datapoints in leaf to grow
|
|
963
|
+
counts['grow']['left'] = count_trees[tree_indices, grow_moves['left']]
|
|
964
|
+
counts['grow']['right'] = count_trees[tree_indices, grow_moves['right']]
|
|
965
|
+
counts['grow']['total'] = counts['grow']['left'] + counts['grow']['right']
|
|
966
|
+
count_trees = count_trees.at[tree_indices, grow_moves['node']].set(counts['grow']['total'])
|
|
967
|
+
|
|
968
|
+
# count datapoints in node to prune
|
|
969
|
+
counts['prune']['left'] = count_trees[tree_indices, prune_moves['left']]
|
|
970
|
+
counts['prune']['right'] = count_trees[tree_indices, prune_moves['right']]
|
|
971
|
+
counts['prune']['total'] = counts['prune']['left'] + counts['prune']['right']
|
|
972
|
+
count_trees = count_trees.at[tree_indices, prune_moves['node']].set(counts['prune']['total'])
|
|
973
|
+
|
|
974
|
+
return count_trees, counts
|
|
975
|
+
|
|
976
|
+
def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
|
|
977
|
+
"""
|
|
978
|
+
Count the number of datapoints in each leaf.
|
|
979
|
+
|
|
980
|
+
Parameters
|
|
981
|
+
----------
|
|
982
|
+
leaf_indices : int array (num_trees, n)
|
|
983
|
+
The index of the leaf each datapoint falls into.
|
|
984
|
+
tree_size : int
|
|
985
|
+
The size of the leaf tree array (2 ** d).
|
|
986
|
+
batch_size : int or None
|
|
987
|
+
The data batch size to use for the summation.
|
|
988
|
+
|
|
989
|
+
Returns
|
|
990
|
+
-------
|
|
991
|
+
count_trees : int array (num_trees, 2 ** d)
|
|
992
|
+
The number of points in each leaf node.
|
|
993
|
+
"""
|
|
994
|
+
if batch_size is None:
|
|
995
|
+
return _count_scan(leaf_indices, tree_size)
|
|
996
|
+
else:
|
|
997
|
+
return _count_vec(leaf_indices, tree_size, batch_size)
|
|
998
|
+
|
|
999
|
+
def _count_scan(leaf_indices, tree_size):
|
|
1000
|
+
def loop(_, leaf_indices):
|
|
1001
|
+
return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
|
|
1002
|
+
_, count_trees = lax.scan(loop, None, leaf_indices)
|
|
1003
|
+
return count_trees
|
|
1004
|
+
|
|
1005
|
+
def _aggregate_scatter(values, indices, size, dtype):
|
|
1006
|
+
return (jnp
|
|
1007
|
+
.zeros(size, dtype)
|
|
1008
|
+
.at[indices]
|
|
1009
|
+
.add(values)
|
|
1010
|
+
)
|
|
1011
|
+
|
|
1012
|
+
def _count_vec(leaf_indices, tree_size, batch_size):
|
|
1013
|
+
return _aggregate_batched_alltrees(1, leaf_indices, tree_size, jnp.uint32, batch_size)
|
|
1014
|
+
# uint16 is super-slow on gpu, don't use it even if n < 2^16
|
|
1015
|
+
|
|
1016
|
+
def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
|
|
1017
|
+
ntree, n = indices.shape
|
|
1018
|
+
tree_indices = jnp.arange(ntree)
|
|
1019
|
+
nbatches = n // batch_size + bool(n % batch_size)
|
|
1020
|
+
batch_indices = jnp.arange(n) % nbatches
|
|
1021
|
+
return (jnp
|
|
1022
|
+
.zeros((ntree, size, nbatches), dtype)
|
|
1023
|
+
.at[tree_indices[:, None], indices, batch_indices]
|
|
1024
|
+
.add(values)
|
|
1025
|
+
.sum(axis=2)
|
|
1026
|
+
)
|
|
1027
|
+
|
|
1028
|
+
def complete_ratio(grow_moves, prune_moves, move_counts, min_points_per_leaf):
|
|
1029
|
+
"""
|
|
1030
|
+
Complete non-likelihood MH ratio calculation.
|
|
1031
|
+
|
|
1032
|
+
This function adds the probability of choosing the prune move.
|
|
1033
|
+
|
|
1034
|
+
Parameters
|
|
1035
|
+
----------
|
|
1036
|
+
grow_moves, prune_moves : dict
|
|
1037
|
+
The proposals for the moves. See `grow_move` and `prune_move`.
|
|
1038
|
+
move_counts : dict
|
|
1039
|
+
The counts of the number of points in the nodes modified by the
|
|
1040
|
+
moves.
|
|
1041
|
+
min_points_per_leaf : int or None
|
|
1042
|
+
The minimum number of data points in a leaf node.
|
|
1043
|
+
|
|
1044
|
+
Returns
|
|
1045
|
+
-------
|
|
1046
|
+
grow_moves, prune_moves : dict
|
|
1047
|
+
The proposals for the moves, with the field 'partial_ratio' replaced
|
|
1048
|
+
by 'trans_prior_ratio'.
|
|
1049
|
+
"""
|
|
1050
|
+
grow_moves = grow_moves.copy()
|
|
1051
|
+
prune_moves = prune_moves.copy()
|
|
1052
|
+
compute_p_prune_vec = jax.vmap(compute_p_prune, in_axes=(0, 0, 0, None))
|
|
1053
|
+
grow_p_prune, prune_p_prune = compute_p_prune_vec(grow_moves, move_counts['grow']['left'], move_counts['grow']['right'], min_points_per_leaf)
|
|
1054
|
+
grow_moves['trans_prior_ratio'] = grow_moves.pop('partial_ratio') * grow_p_prune
|
|
1055
|
+
prune_moves['trans_prior_ratio'] = prune_moves.pop('partial_ratio') * prune_p_prune
|
|
1056
|
+
return grow_moves, prune_moves
|
|
1057
|
+
|
|
1058
|
+
def compute_p_prune(grow_move, grow_left_count, grow_right_count, min_points_per_leaf):
|
|
1059
|
+
"""
|
|
1060
|
+
Compute the probability of proposing a prune move.
|
|
1061
|
+
|
|
1062
|
+
Parameters
|
|
1063
|
+
----------
|
|
1064
|
+
grow_move : dict
|
|
1065
|
+
The proposal for the grow move, see `grow_move`.
|
|
1066
|
+
grow_left_count, grow_right_count : int
|
|
1067
|
+
The number of datapoints in the proposed children of the leaf to grow.
|
|
1068
|
+
min_points_per_leaf : int or None
|
|
1069
|
+
The minimum number of data points in a leaf node.
|
|
1070
|
+
|
|
1071
|
+
Returns
|
|
1072
|
+
-------
|
|
1073
|
+
grow_p_prune : float
|
|
1074
|
+
The probability of proposing a prune move, after accepting the grow
|
|
1075
|
+
move.
|
|
1076
|
+
prune_p_prune : float
|
|
1077
|
+
The probability of proposing the prune move.
|
|
1078
|
+
"""
|
|
1079
|
+
other_growable_leaves = grow_move['num_growable'] >= 2
|
|
1080
|
+
new_leaves_growable = grow_move['node'] < grow_move['split_tree'].size // 2
|
|
1081
|
+
if min_points_per_leaf is not None:
|
|
1082
|
+
any_above_threshold = grow_left_count >= 2 * min_points_per_leaf
|
|
1083
|
+
any_above_threshold |= grow_right_count >= 2 * min_points_per_leaf
|
|
1084
|
+
new_leaves_growable &= any_above_threshold
|
|
1085
|
+
grow_again_allowed = other_growable_leaves | new_leaves_growable
|
|
1086
|
+
grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
|
|
1087
|
+
prune_p_prune = jnp.where(grow_move['num_growable'], 0.5, 1)
|
|
1088
|
+
return grow_p_prune, prune_p_prune
|
|
1089
|
+
|
|
1090
|
+
def adapt_leaf_trees_to_grow_indices(leaf_trees, grow_moves):
|
|
1091
|
+
"""
|
|
1092
|
+
Modify leaf values such that the indices of the grow move work on the
|
|
1093
|
+
original tree.
|
|
1094
|
+
|
|
1095
|
+
Parameters
|
|
1096
|
+
----------
|
|
1097
|
+
leaf_trees : float array (num_trees, 2 ** d)
|
|
1098
|
+
The leaf values.
|
|
1099
|
+
grow_moves : dict
|
|
1100
|
+
The proposals for grow moves. See `grow_move`.
|
|
1101
|
+
|
|
1102
|
+
Returns
|
|
1103
|
+
-------
|
|
1104
|
+
leaf_trees : float array (num_trees, 2 ** d)
|
|
1105
|
+
The modified leaf values. The value of the leaf to grow is copied to
|
|
1106
|
+
what would be its children if the grow move was accepted.
|
|
1107
|
+
"""
|
|
1108
|
+
ntree, _ = leaf_trees.shape
|
|
1109
|
+
tree_indices = jnp.arange(ntree)
|
|
1110
|
+
values_at_node = leaf_trees[tree_indices, grow_moves['node']]
|
|
1111
|
+
return (leaf_trees
|
|
1112
|
+
.at[tree_indices, grow_moves['left']]
|
|
1113
|
+
.set(values_at_node)
|
|
1114
|
+
.at[tree_indices, grow_moves['right']]
|
|
1115
|
+
.set(values_at_node)
|
|
1116
|
+
)
|
|
1117
|
+
|
|
1118
|
+
def accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, move_counts, u, z):
|
|
1119
|
+
"""
|
|
1120
|
+
The part of accepting the moves that has to be done one tree at a time.
|
|
1121
|
+
|
|
1122
|
+
Parameters
|
|
1123
|
+
----------
|
|
1124
|
+
bart : dict
|
|
1125
|
+
A partially updated BART mcmc state.
|
|
1126
|
+
count_trees : array (num_trees, 2 ** d)
|
|
1127
|
+
The number of points in each potential or actual leaf node.
|
|
1128
|
+
grow_moves, prune_moves : dict
|
|
1129
|
+
The proposals for the moves, with completed ratios. See `grow_move` and
|
|
1130
|
+
`prune_move`.
|
|
1131
|
+
move_counts : dict
|
|
1132
|
+
The counts of the number of points in the nodes modified by the
|
|
1133
|
+
moves.
|
|
1134
|
+
u : float array (num_trees, 2)
|
|
1135
|
+
Random uniform values used for proposal and accept decisions.
|
|
1136
|
+
z : float array (num_trees, 2 ** d)
|
|
1137
|
+
Random standard normal values used to sample the new leaf values.
|
|
1138
|
+
|
|
1139
|
+
Returns
|
|
1140
|
+
-------
|
|
1141
|
+
bart : dict
|
|
1142
|
+
A partially updated BART mcmc state.
|
|
1143
|
+
counts : dict
|
|
1144
|
+
The indicators of proposals and acceptances for grow and prune moves.
|
|
1145
|
+
"""
|
|
1146
|
+
bart = bart.copy()
|
|
1147
|
+
|
|
1148
|
+
def loop(resid, item):
|
|
1149
|
+
resid, leaf_tree, split_tree, counts, ratios = accept_move_and_sample_leaves(
|
|
769
1150
|
bart['X'],
|
|
770
1151
|
len(bart['leaf_trees']),
|
|
771
|
-
bart['opt']['
|
|
1152
|
+
bart['opt']['resid_batch_size'],
|
|
772
1153
|
resid,
|
|
773
1154
|
bart['sigma2'],
|
|
774
1155
|
bart['min_points_per_leaf'],
|
|
775
|
-
|
|
1156
|
+
'ratios' in bart,
|
|
776
1157
|
*item,
|
|
777
1158
|
)
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
carry = {
|
|
781
|
-
k: jnp.zeros_like(bart[k]) for k in
|
|
782
|
-
['grow_prop_count', 'prune_prop_count', 'grow_acc_count', 'prune_acc_count']
|
|
783
|
-
}
|
|
784
|
-
carry['resid'] = bart['resid']
|
|
1159
|
+
return resid, (leaf_tree, split_tree, counts, ratios)
|
|
1160
|
+
|
|
785
1161
|
items = (
|
|
786
|
-
bart['leaf_trees'],
|
|
787
|
-
|
|
788
|
-
bart['
|
|
789
|
-
|
|
790
|
-
prune_moves,
|
|
791
|
-
grow_leaf_indices,
|
|
792
|
-
random.split(key, len(bart['leaf_trees'])),
|
|
1162
|
+
bart['leaf_trees'], count_trees,
|
|
1163
|
+
grow_moves, prune_moves, move_counts,
|
|
1164
|
+
bart['leaf_indices'],
|
|
1165
|
+
u, z,
|
|
793
1166
|
)
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
bart
|
|
797
|
-
|
|
1167
|
+
resid, (leaf_trees, split_trees, counts, ratios) = lax.scan(loop, bart['resid'], items)
|
|
1168
|
+
|
|
1169
|
+
bart['resid'] = resid
|
|
1170
|
+
bart['leaf_trees'] = leaf_trees
|
|
1171
|
+
bart['split_trees'] = split_trees
|
|
1172
|
+
bart.get('ratios', {}).update(ratios)
|
|
798
1173
|
|
|
799
|
-
|
|
1174
|
+
return bart, counts
|
|
1175
|
+
|
|
1176
|
+
def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, sigma2, min_points_per_leaf, save_ratios, leaf_tree, count_tree, grow_move, prune_move, move_counts, grow_leaf_indices, u, z):
|
|
800
1177
|
"""
|
|
801
1178
|
Accept or reject a proposed move and sample the new leaf values.
|
|
802
1179
|
|
|
@@ -806,158 +1183,157 @@ def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2,
|
|
|
806
1183
|
The predictors.
|
|
807
1184
|
ntree : int
|
|
808
1185
|
The number of trees in the forest.
|
|
809
|
-
|
|
810
|
-
The batch size for computing
|
|
1186
|
+
resid_batch_size : int, None
|
|
1187
|
+
The batch size for computing the sum of residuals in each leaf.
|
|
811
1188
|
resid : float array (n,)
|
|
812
1189
|
The residuals (data minus forest value).
|
|
813
1190
|
sigma2 : float
|
|
814
1191
|
The noise variance.
|
|
815
1192
|
min_points_per_leaf : int or None
|
|
816
1193
|
The minimum number of data points in a leaf node.
|
|
817
|
-
|
|
818
|
-
|
|
1194
|
+
save_ratios : bool
|
|
1195
|
+
Whether to save the acceptance ratios.
|
|
819
1196
|
leaf_tree : float array (2 ** d,)
|
|
820
1197
|
The leaf values of the tree.
|
|
821
|
-
|
|
822
|
-
The
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
The proposal for the grow move. See `grow_move`.
|
|
827
|
-
prune_move : dict
|
|
828
|
-
The proposal for the prune move. See `prune_move`.
|
|
1198
|
+
count_tree : int array (2 ** d,)
|
|
1199
|
+
The number of datapoints in each leaf.
|
|
1200
|
+
grow_move, prune_move : dict
|
|
1201
|
+
The proposals for the moves, with completed ratios. See `grow_move` and
|
|
1202
|
+
`prune_move`.
|
|
829
1203
|
grow_leaf_indices : int array (n,)
|
|
830
1204
|
The leaf indices of the tree proposed by the grow move.
|
|
831
|
-
|
|
832
|
-
|
|
1205
|
+
u : float array (2,)
|
|
1206
|
+
Two uniform random values in [0, 1).
|
|
1207
|
+
z : float array (2 ** d,)
|
|
1208
|
+
Standard normal random values.
|
|
833
1209
|
|
|
834
1210
|
Returns
|
|
835
1211
|
-------
|
|
836
1212
|
resid : float array (n,)
|
|
837
1213
|
The updated residuals (data minus forest value).
|
|
1214
|
+
leaf_tree : float array (2 ** d,)
|
|
1215
|
+
The new leaf values of the tree.
|
|
1216
|
+
split_tree : int array (2 ** (d - 1),)
|
|
1217
|
+
The updated decision boundaries of the tree.
|
|
838
1218
|
counts : dict
|
|
839
|
-
The
|
|
840
|
-
|
|
841
|
-
The
|
|
1219
|
+
The indicators of proposals and acceptances for grow and prune moves.
|
|
1220
|
+
ratios : dict
|
|
1221
|
+
The acceptance ratios for the moves. Empty if not to be saved.
|
|
842
1222
|
"""
|
|
843
|
-
|
|
844
|
-
# compute leaf indices in starting tree
|
|
845
|
-
grow_node = grow_move['node']
|
|
846
|
-
grow_left = grow_node << 1
|
|
847
|
-
grow_right = grow_left + 1
|
|
848
|
-
leaf_indices = jnp.where(
|
|
849
|
-
(grow_leaf_indices == grow_left) | (grow_leaf_indices == grow_right),
|
|
850
|
-
grow_node,
|
|
851
|
-
grow_leaf_indices,
|
|
852
|
-
)
|
|
853
1223
|
|
|
854
|
-
#
|
|
855
|
-
|
|
856
|
-
prune_left = prune_node << 1
|
|
857
|
-
prune_right = prune_left + 1
|
|
858
|
-
prune_leaf_indices = jnp.where(
|
|
859
|
-
(leaf_indices == prune_left) | (leaf_indices == prune_right),
|
|
860
|
-
prune_node,
|
|
861
|
-
leaf_indices,
|
|
862
|
-
)
|
|
1224
|
+
# sum residuals per leaf, in tree proposed by grow move
|
|
1225
|
+
resid_tree = sum_resid(resid, grow_leaf_indices, leaf_tree.size, resid_batch_size)
|
|
863
1226
|
|
|
864
1227
|
# subtract starting tree from function
|
|
865
|
-
|
|
1228
|
+
resid_tree += count_tree * leaf_tree
|
|
866
1229
|
|
|
867
|
-
#
|
|
868
|
-
|
|
1230
|
+
# get indices of grow move
|
|
1231
|
+
grow_node = grow_move['node']
|
|
1232
|
+
assert grow_node.dtype == jnp.int32
|
|
1233
|
+
grow_left = grow_move['left']
|
|
1234
|
+
grow_right = grow_move['right']
|
|
869
1235
|
|
|
870
|
-
#
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
.set(grow_count_tree[grow_left] + grow_count_tree[grow_right]))
|
|
1236
|
+
# sum residuals in leaf to grow
|
|
1237
|
+
grow_resid_left = resid_tree[grow_left]
|
|
1238
|
+
grow_resid_right = resid_tree[grow_right]
|
|
1239
|
+
grow_resid_total = grow_resid_left + grow_resid_right
|
|
1240
|
+
resid_tree = resid_tree.at[grow_node].set(grow_resid_total)
|
|
876
1241
|
|
|
877
|
-
#
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
1242
|
+
# get indices of prune move
|
|
1243
|
+
prune_node = prune_move['node']
|
|
1244
|
+
assert prune_node.dtype == jnp.int32
|
|
1245
|
+
prune_left = prune_move['left']
|
|
1246
|
+
prune_right = prune_move['right']
|
|
882
1247
|
|
|
883
|
-
#
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
1248
|
+
# sum residuals in node to prune
|
|
1249
|
+
prune_resid_left = resid_tree[prune_left]
|
|
1250
|
+
prune_resid_right = resid_tree[prune_right]
|
|
1251
|
+
prune_resid_total = prune_resid_left + prune_resid_right
|
|
1252
|
+
resid_tree = resid_tree.at[prune_node].set(prune_resid_total)
|
|
887
1253
|
|
|
888
|
-
#
|
|
889
|
-
|
|
890
|
-
prune_p_prune = compute_p_prune_back(split_tree, affluence_tree)
|
|
1254
|
+
# Now resid_tree and count_tree contain correct values whatever move is
|
|
1255
|
+
# accepted.
|
|
891
1256
|
|
|
892
1257
|
# compute likelihood ratios
|
|
893
|
-
grow_lk_ratio = compute_likelihood_ratio(
|
|
894
|
-
prune_lk_ratio = compute_likelihood_ratio(
|
|
1258
|
+
grow_lk_ratio = compute_likelihood_ratio(grow_resid_total, grow_resid_left, grow_resid_right, move_counts['grow']['total'], move_counts['grow']['left'], move_counts['grow']['right'], sigma2, ntree)
|
|
1259
|
+
prune_lk_ratio = compute_likelihood_ratio(prune_resid_total, prune_resid_left, prune_resid_right, move_counts['prune']['total'], move_counts['prune']['left'], move_counts['prune']['right'], sigma2, ntree)
|
|
895
1260
|
|
|
896
1261
|
# compute acceptance ratios
|
|
897
|
-
grow_ratio =
|
|
898
|
-
|
|
1262
|
+
grow_ratio = grow_move['trans_prior_ratio'] * grow_lk_ratio
|
|
1263
|
+
if min_points_per_leaf is not None:
|
|
1264
|
+
grow_ratio = jnp.where(move_counts['grow']['left'] >= min_points_per_leaf, grow_ratio, 0)
|
|
1265
|
+
grow_ratio = jnp.where(move_counts['grow']['right'] >= min_points_per_leaf, grow_ratio, 0)
|
|
1266
|
+
prune_ratio = prune_move['trans_prior_ratio'] * prune_lk_ratio
|
|
899
1267
|
prune_ratio = lax.reciprocal(prune_ratio)
|
|
900
1268
|
|
|
901
|
-
#
|
|
902
|
-
|
|
903
|
-
|
|
1269
|
+
# save acceptance ratios
|
|
1270
|
+
ratios = {}
|
|
1271
|
+
if save_ratios:
|
|
1272
|
+
ratios.update(
|
|
1273
|
+
grow=dict(
|
|
1274
|
+
trans_prior=grow_move['trans_prior_ratio'],
|
|
1275
|
+
likelihood=grow_lk_ratio,
|
|
1276
|
+
),
|
|
1277
|
+
prune=dict(
|
|
1278
|
+
trans_prior=lax.reciprocal(prune_move['trans_prior_ratio']),
|
|
1279
|
+
likelihood=lax.reciprocal(prune_lk_ratio),
|
|
1280
|
+
),
|
|
1281
|
+
)
|
|
904
1282
|
|
|
905
1283
|
# determine what move to propose (not proposing anything is an option)
|
|
906
|
-
|
|
907
|
-
|
|
1284
|
+
grow_allowed = grow_move['num_growable'].astype(bool)
|
|
1285
|
+
p_grow = jnp.where(grow_allowed & prune_move['allowed'], 0.5, grow_allowed)
|
|
1286
|
+
try_grow = u[0] < p_grow # use < instead of <= because coins are in [0, 1)
|
|
908
1287
|
try_prune = prune_move['allowed'] & ~try_grow
|
|
909
1288
|
|
|
910
1289
|
# determine whether to accept the move
|
|
911
|
-
do_grow = try_grow & (
|
|
912
|
-
do_prune = try_prune & (
|
|
913
|
-
|
|
914
|
-
# pick
|
|
915
|
-
|
|
916
|
-
split_tree = jnp.where(do_grow,
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
resid_tree = jnp.where(do_grow, grow_resid_tree, resid_tree)
|
|
924
|
-
count_tree = jnp.where(do_grow, grow_count_tree, count_tree)
|
|
925
|
-
resid_tree = jnp.where(do_prune, prune_resid_tree, resid_tree)
|
|
926
|
-
count_tree = jnp.where(do_prune, prune_count_tree, count_tree)
|
|
927
|
-
|
|
928
|
-
# update acceptance counts
|
|
929
|
-
counts = counts.copy()
|
|
930
|
-
counts['grow_prop_count'] += try_grow
|
|
931
|
-
counts['grow_acc_count'] += do_grow
|
|
932
|
-
counts['prune_prop_count'] += try_prune
|
|
933
|
-
counts['prune_acc_count'] += do_prune
|
|
934
|
-
|
|
935
|
-
# compute leaves posterior
|
|
936
|
-
prec_lk = count_tree / sigma2
|
|
1290
|
+
do_grow = try_grow & (u[1] < grow_ratio)
|
|
1291
|
+
do_prune = try_prune & (u[1] < prune_ratio)
|
|
1292
|
+
|
|
1293
|
+
# pick split tree for chosen move
|
|
1294
|
+
split_tree = grow_move['split_tree']
|
|
1295
|
+
split_tree = split_tree.at[jnp.where(do_grow, split_tree.size, grow_node)].set(0)
|
|
1296
|
+
split_tree = split_tree.at[jnp.where(do_prune, prune_node, split_tree.size)].set(0)
|
|
1297
|
+
# I can leave garbage in var_tree, resid_tree, count_tree
|
|
1298
|
+
|
|
1299
|
+
# compute leaves posterior and sample leaves
|
|
1300
|
+
inv_sigma2 = lax.reciprocal(sigma2)
|
|
1301
|
+
prec_lk = count_tree * inv_sigma2
|
|
937
1302
|
var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
|
|
938
|
-
mean_post = resid_tree
|
|
939
|
-
|
|
940
|
-
# sample leaves
|
|
941
|
-
z = random.normal(key, mean_post.shape, mean_post.dtype)
|
|
1303
|
+
mean_post = resid_tree * inv_sigma2 * var_post # = mean_lk * prec_lk * var_post
|
|
1304
|
+
initial_leaf_tree = leaf_tree
|
|
942
1305
|
leaf_tree = mean_post + z * jnp.sqrt(var_post)
|
|
943
1306
|
|
|
944
|
-
#
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
1307
|
+
# copy leaves around such that the grow leaf indices select the right leaf
|
|
1308
|
+
leaf_tree = (leaf_tree
|
|
1309
|
+
.at[jnp.where(do_prune, prune_left, leaf_tree.size)]
|
|
1310
|
+
.set(leaf_tree[prune_node])
|
|
1311
|
+
.at[jnp.where(do_prune, prune_right, leaf_tree.size)]
|
|
1312
|
+
.set(leaf_tree[prune_node])
|
|
1313
|
+
)
|
|
1314
|
+
leaf_tree = (leaf_tree
|
|
1315
|
+
.at[jnp.where(do_grow, leaf_tree.size, grow_left)]
|
|
1316
|
+
.set(leaf_tree[grow_node])
|
|
1317
|
+
.at[jnp.where(do_grow, leaf_tree.size, grow_right)]
|
|
1318
|
+
.set(leaf_tree[grow_node])
|
|
1319
|
+
)
|
|
1320
|
+
|
|
1321
|
+
# replace old tree with new tree in function values
|
|
1322
|
+
resid += (initial_leaf_tree - leaf_tree)[grow_leaf_indices]
|
|
948
1323
|
|
|
949
|
-
# pack
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
1324
|
+
# pack proposal and acceptance indicators
|
|
1325
|
+
counts = dict(
|
|
1326
|
+
grow_prop_count=try_grow,
|
|
1327
|
+
grow_acc_count=do_grow,
|
|
1328
|
+
prune_prop_count=try_prune,
|
|
1329
|
+
prune_acc_count=do_prune,
|
|
1330
|
+
)
|
|
955
1331
|
|
|
956
|
-
return resid, counts,
|
|
1332
|
+
return resid, leaf_tree, split_tree, counts, ratios
|
|
957
1333
|
|
|
958
|
-
def
|
|
1334
|
+
def sum_resid(resid, leaf_indices, tree_size, batch_size):
|
|
959
1335
|
"""
|
|
960
|
-
|
|
1336
|
+
Sum the residuals in each leaf.
|
|
961
1337
|
|
|
962
1338
|
Parameters
|
|
963
1339
|
----------
|
|
@@ -968,104 +1344,56 @@ def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
|
|
|
968
1344
|
tree_size : int
|
|
969
1345
|
The size of the tree array (2 ** d).
|
|
970
1346
|
batch_size : int, None
|
|
971
|
-
The batch size for the aggregation. Batching increases numerical
|
|
1347
|
+
The data batch size for the aggregation. Batching increases numerical
|
|
972
1348
|
accuracy and parallelism.
|
|
973
1349
|
|
|
974
1350
|
Returns
|
|
975
1351
|
-------
|
|
976
1352
|
resid_tree : float array (2 ** d,)
|
|
977
1353
|
The sum of the residuals at data points in each leaf.
|
|
978
|
-
count_tree : int array (2 ** d,)
|
|
979
|
-
The number of data points in each leaf.
|
|
980
1354
|
"""
|
|
981
1355
|
if batch_size is None:
|
|
982
1356
|
aggr_func = _aggregate_scatter
|
|
983
1357
|
else:
|
|
984
|
-
aggr_func = functools.partial(
|
|
985
|
-
|
|
986
|
-
count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
|
|
987
|
-
return resid_tree, count_tree
|
|
988
|
-
|
|
989
|
-
def _aggregate_scatter(values, indices, size, dtype):
|
|
990
|
-
return (jnp
|
|
991
|
-
.zeros(size, dtype)
|
|
992
|
-
.at[indices]
|
|
993
|
-
.add(values)
|
|
994
|
-
)
|
|
1358
|
+
aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
|
|
1359
|
+
return aggr_func(resid, leaf_indices, tree_size, jnp.float32)
|
|
995
1360
|
|
|
996
|
-
def
|
|
997
|
-
|
|
998
|
-
|
|
1361
|
+
def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
|
|
1362
|
+
n, = indices.shape
|
|
1363
|
+
nbatches = n // batch_size + bool(n % batch_size)
|
|
1364
|
+
batch_indices = jnp.arange(n) % nbatches
|
|
999
1365
|
return (jnp
|
|
1000
|
-
.zeros((
|
|
1001
|
-
.at[
|
|
1366
|
+
.zeros((size, nbatches), dtype)
|
|
1367
|
+
.at[indices, batch_indices]
|
|
1002
1368
|
.add(values)
|
|
1003
|
-
.sum(axis=
|
|
1369
|
+
.sum(axis=1)
|
|
1004
1370
|
)
|
|
1005
1371
|
|
|
1006
|
-
def
|
|
1007
|
-
"""
|
|
1008
|
-
Compute the probability of proposing a prune move after doing a grow move.
|
|
1009
|
-
|
|
1010
|
-
Parameters
|
|
1011
|
-
----------
|
|
1012
|
-
new_split_tree : int array (2 ** (d - 1),)
|
|
1013
|
-
The decision boundaries of the tree, after the grow move.
|
|
1014
|
-
new_affluence_tree : bool array (2 ** (d - 1),)
|
|
1015
|
-
Which leaves have enough points to be grown, after the grow move.
|
|
1016
|
-
|
|
1017
|
-
Returns
|
|
1018
|
-
-------
|
|
1019
|
-
p_prune : float
|
|
1020
|
-
The probability of proposing a prune move after the grow move. This is
|
|
1021
|
-
0.5 if grow is possible again, and 1 if it isn't. It can't be 0 because
|
|
1022
|
-
at least the node just grown can be pruned.
|
|
1023
|
-
"""
|
|
1024
|
-
_, grow_again_allowed = growable_leaves(new_split_tree, new_affluence_tree)
|
|
1025
|
-
return jnp.where(grow_again_allowed, 0.5, 1)
|
|
1026
|
-
|
|
1027
|
-
def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_points_per_leaf):
|
|
1372
|
+
def compute_likelihood_ratio(total_resid, left_resid, right_resid, total_count, left_count, right_count, sigma2, n_tree):
|
|
1028
1373
|
"""
|
|
1029
1374
|
Compute the likelihood ratio of a grow move.
|
|
1030
1375
|
|
|
1031
1376
|
Parameters
|
|
1032
1377
|
----------
|
|
1033
|
-
|
|
1034
|
-
The sum of the residuals
|
|
1035
|
-
|
|
1036
|
-
The
|
|
1378
|
+
total_resid : float
|
|
1379
|
+
The sum of the residuals in the leaf to grow.
|
|
1380
|
+
left_resid, right_resid : float
|
|
1381
|
+
The sum of the residuals in the left/right child of the leaf to grow.
|
|
1382
|
+
total_count : int
|
|
1383
|
+
The number of datapoints in the leaf to grow.
|
|
1384
|
+
left_count, right_count : int
|
|
1385
|
+
The number of datapoints in the left/right child of the leaf to grow.
|
|
1037
1386
|
sigma2 : float
|
|
1038
1387
|
The noise variance.
|
|
1039
|
-
node : int
|
|
1040
|
-
The index of the leaf that has been grown.
|
|
1041
1388
|
n_tree : int
|
|
1042
1389
|
The number of trees in the forest.
|
|
1043
|
-
min_points_per_leaf : int or None
|
|
1044
|
-
The minimum number of data points in a leaf node.
|
|
1045
1390
|
|
|
1046
1391
|
Returns
|
|
1047
1392
|
-------
|
|
1048
1393
|
ratio : float
|
|
1049
1394
|
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
1050
|
-
|
|
1051
|
-
Notes
|
|
1052
|
-
-----
|
|
1053
|
-
The ratio is set to 0 if the grow move would create leaves with not enough
|
|
1054
|
-
datapoints per leaf, although this is part of the prior rather than the
|
|
1055
|
-
likelihood.
|
|
1056
1395
|
"""
|
|
1057
1396
|
|
|
1058
|
-
left_child = node << 1
|
|
1059
|
-
right_child = left_child + 1
|
|
1060
|
-
|
|
1061
|
-
left_resid = resid_tree[left_child]
|
|
1062
|
-
right_resid = resid_tree[right_child]
|
|
1063
|
-
total_resid = left_resid + right_resid
|
|
1064
|
-
|
|
1065
|
-
left_count = count_tree[left_child]
|
|
1066
|
-
right_count = count_tree[right_child]
|
|
1067
|
-
total_count = left_count + right_count
|
|
1068
|
-
|
|
1069
1397
|
sigma_mu2 = 1 / n_tree
|
|
1070
1398
|
sigma2_left = sigma2 + left_count * sigma_mu2
|
|
1071
1399
|
sigma2_right = sigma2 + right_count * sigma_mu2
|
|
@@ -1079,13 +1407,67 @@ def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_p
|
|
|
1079
1407
|
total_resid * total_resid / sigma2_total
|
|
1080
1408
|
)
|
|
1081
1409
|
|
|
1082
|
-
|
|
1410
|
+
return jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
|
|
1083
1411
|
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1412
|
+
def accept_moves_final_stage(bart, counts, grow_moves, prune_moves):
|
|
1413
|
+
"""
|
|
1414
|
+
The final part of accepting the moves, in parallel across trees.
|
|
1415
|
+
|
|
1416
|
+
Parameters
|
|
1417
|
+
----------
|
|
1418
|
+
bart : dict
|
|
1419
|
+
A partially updated BART mcmc state.
|
|
1420
|
+
counts : dict
|
|
1421
|
+
The indicators of proposals and acceptances for grow and prune moves.
|
|
1422
|
+
grow_moves, prune_moves : dict
|
|
1423
|
+
The proposals for the moves. See `grow_move` and `prune_move`.
|
|
1424
|
+
|
|
1425
|
+
Returns
|
|
1426
|
+
-------
|
|
1427
|
+
bart : dict
|
|
1428
|
+
The fully updated BART mcmc state.
|
|
1429
|
+
"""
|
|
1430
|
+
bart = bart.copy()
|
|
1431
|
+
|
|
1432
|
+
for k, v in counts.items():
|
|
1433
|
+
bart[k] = jnp.sum(v, axis=0)
|
|
1087
1434
|
|
|
1088
|
-
|
|
1435
|
+
bart['leaf_indices'] = apply_moves_to_indices(bart['leaf_indices'], counts, grow_moves, prune_moves)
|
|
1436
|
+
|
|
1437
|
+
return bart
|
|
1438
|
+
|
|
1439
|
+
def apply_moves_to_indices(leaf_indices, counts, grow_moves, prune_moves):
|
|
1440
|
+
"""
|
|
1441
|
+
Update the leaf indices to match the accepted move.
|
|
1442
|
+
|
|
1443
|
+
Parameters
|
|
1444
|
+
----------
|
|
1445
|
+
leaf_indices : int array (num_trees, n)
|
|
1446
|
+
The index of the leaf each datapoint falls into, if the grow move was
|
|
1447
|
+
accepted.
|
|
1448
|
+
counts : dict
|
|
1449
|
+
The indicators of proposals and acceptances for grow and prune moves.
|
|
1450
|
+
grow_moves, prune_moves : dict
|
|
1451
|
+
The proposals for the moves. See `grow_move` and `prune_move`.
|
|
1452
|
+
|
|
1453
|
+
Returns
|
|
1454
|
+
-------
|
|
1455
|
+
leaf_indices : int array (num_trees, n)
|
|
1456
|
+
The updated leaf indices.
|
|
1457
|
+
"""
|
|
1458
|
+
mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
|
|
1459
|
+
cond = (leaf_indices & mask) == grow_moves['left'][:, None]
|
|
1460
|
+
leaf_indices = jnp.where(
|
|
1461
|
+
cond & ~counts['grow_acc_count'][:, None],
|
|
1462
|
+
grow_moves['node'][:, None].astype(leaf_indices.dtype),
|
|
1463
|
+
leaf_indices,
|
|
1464
|
+
)
|
|
1465
|
+
cond = (leaf_indices & mask) == prune_moves['left'][:, None]
|
|
1466
|
+
return jnp.where(
|
|
1467
|
+
cond & counts['prune_acc_count'][:, None],
|
|
1468
|
+
prune_moves['node'][:, None].astype(leaf_indices.dtype),
|
|
1469
|
+
leaf_indices,
|
|
1470
|
+
)
|
|
1089
1471
|
|
|
1090
1472
|
def sample_sigma(bart, key):
|
|
1091
1473
|
"""
|
|
@@ -1107,7 +1489,7 @@ def sample_sigma(bart, key):
|
|
|
1107
1489
|
|
|
1108
1490
|
resid = bart['resid']
|
|
1109
1491
|
alpha = bart['sigma2_alpha'] + resid.size / 2
|
|
1110
|
-
norm2 = jnp.dot(resid, resid, preferred_element_type=bart['
|
|
1492
|
+
norm2 = jnp.dot(resid, resid, preferred_element_type=bart['opt']['large_float'])
|
|
1111
1493
|
beta = bart['sigma2_beta'] + norm2 / 2
|
|
1112
1494
|
|
|
1113
1495
|
sample = random.gamma(key, alpha)
|