bartz 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl
- bartz/.DS_Store +0 -0
- bartz/BART.py +266 -113
- bartz/__init__.py +4 -12
- bartz/_version.py +1 -1
- bartz/debug.py +42 -16
- bartz/grove.py +62 -12
- bartz/jaxext.py +111 -37
- bartz/mcmcloop.py +419 -105
- bartz/mcmcstep.py +1528 -760
- bartz/prepcovars.py +25 -10
- {bartz-0.4.1.dist-info → bartz-0.6.0.dist-info}/METADATA +14 -16
- bartz-0.6.0.dist-info/RECORD +13 -0
- bartz-0.6.0.dist-info/WHEEL +4 -0
- bartz-0.4.1.dist-info/LICENSE +0 -21
- bartz-0.4.1.dist-info/RECORD +0 -13
- bartz-0.4.1.dist-info/WHEEL +0 -4
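The largest change below is the rewrite of bartz/mcmcstep.py around the new `State`/`Forest` dataclasses, with `init` and `step` documented as the main entry points. The following is a minimal usage sketch based only on the signatures visible in this diff; the data values and hyperparameters are made up for illustration, and the high-level `bartz.BART` interface (also changed in this release) may be the more convenient entry point.

```python
import numpy as np
from jax import numpy as jnp, random

from bartz import mcmcstep

# Hypothetical pre-binned data: X is (p, n) with integer cutpoint indices,
# and max_split[j] is the number of available cutpoints for predictor j.
p, n = 5, 100
rng = np.random.default_rng(0)
X = jnp.asarray(rng.integers(0, 11, size=(p, n)), jnp.uint8)
max_split = jnp.full(p, 10, jnp.uint8)
y = jnp.asarray(rng.standard_normal(n), jnp.float32)

state = mcmcstep.init(
    X=X,
    y=y,
    max_split=max_split,
    num_trees=50,
    p_nonterminal=jnp.array([0.95, 0.95 / 2, 0.95 / 4]),  # tree depth limit follows from its length
    sigma_mu2=1.0 / 50,   # prior leaf variance (illustrative value)
    sigma2_alpha=2.0,     # inverse-gamma prior on the error variance
    sigma2_beta=1.0,
)

key = random.key(0)
for subkey in random.split(key, 100):
    # one full MCMC iteration; the input state is not modified, a new State is returned
    state = mcmcstep.step(subkey, state)
```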
bartz/mcmcstep.py
CHANGED
@@ -1,6 +1,6 @@
 # bartz/src/bartz/mcmcstep.py
 #
-# Copyright (c) 2024, Giacomo Petrillo
+# Copyright (c) 2024-2025, Giacomo Petrillo
 #
 # This file is part of bartz.
 #
@@ -26,199 +26,292 @@
 Functions that implement the BART posterior MCMC initialization and update step.

 Functions that do MCMC steps operate by taking as input a bart state, and
-outputting a new
-modified.
+outputting a new state. The inputs are not modified.

-
-
+The main entry points are:
+
+- `State`: The dataclass that represents a BART MCMC state.
+- `init`: Creates an initial `State` from data and configurations.
+- `step`: Performs one full MCMC step on a `State`, returning a new `State`.
 """

-import functools
 import math
+from dataclasses import replace
+from functools import cache, partial
+from typing import Any

 import jax
-from
+from equinox import Module, field
+from jax import lax, random
 from jax import numpy as jnp
-from
+from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt

-from . import jaxext
 from . import grove
+from .jaxext import minimal_unsigned_dtype, split, vmap_nodoc
+

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+class Forest(Module):
+    """
+    Represents the MCMC state of a sum of trees.
+
+    Parameters
+    ----------
+    leaf_trees
+        The leaf values.
+    var_trees
+        The decision axes.
+    split_trees
+        The decision boundaries.
+    p_nonterminal
+        The probability of a nonterminal node at each depth, padded with a
+        zero.
+    p_propose_grow
+        The unnormalized probability of picking a leaf for a grow proposal.
+    leaf_indices
+        The index of the leaf each datapoints falls into, for each tree.
+    min_points_per_leaf
+        The minimum number of data points in a leaf node.
+    affluence_trees
+        Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
+        datapoints. If `min_points_per_leaf` is not specified, this is None.
+    resid_batch_size
+    count_batch_size
+        The data batch sizes for computing the sufficient statistics. If `None`,
+        they are computed with no batching.
+    log_trans_prior
+        The log transition and prior Metropolis-Hastings ratio for the
+        proposed move on each tree.
+    log_likelihood
+        The log likelihood ratio.
+    grow_prop_count
+    prune_prop_count
+        The number of grow/prune proposals made during one full MCMC cycle.
+    grow_acc_count
+    prune_acc_count
+        The number of grow/prune moves accepted during one full MCMC cycle.
+    sigma_mu2
+        The prior variance of a leaf, conditional on the tree structure.
+    """
+
+    leaf_trees: Float32[Array, 'num_trees 2**d']
+    var_trees: UInt[Array, 'num_trees 2**(d-1)']
+    split_trees: UInt[Array, 'num_trees 2**(d-1)']
+    p_nonterminal: Float32[Array, 'd']
+    p_propose_grow: Float32[Array, '2**(d-1)']
+    leaf_indices: UInt[Array, 'num_trees n']
+    min_points_per_leaf: Int32[Array, ''] | None
+    affluence_trees: Bool[Array, 'num_trees 2**(d-1)'] | None
+    resid_batch_size: int | None = field(static=True)
+    count_batch_size: int | None = field(static=True)
+    log_trans_prior: Float32[Array, 'num_trees'] | None
+    log_likelihood: Float32[Array, 'num_trees'] | None
+    grow_prop_count: Int32[Array, '']
+    prune_prop_count: Int32[Array, '']
+    grow_acc_count: Int32[Array, '']
+    prune_acc_count: Int32[Array, '']
+    sigma_mu2: Float32[Array, '']
+
+
+class State(Module):
+    """
+    Represents the MCMC state of BART.
+
+    Parameters
+    ----------
+    X
+        The predictors.
+    max_split
+        The maximum split index for each predictor.
+    y
+        The response. If the data type is `bool`, the model is binary regression.
+    resid
+        The residuals (`y` or `z` minus sum of trees).
+    z
+        The latent variable for binary regression. `None` in continuous
+        regression.
+    offset
+        Constant shift added to the sum of trees.
+    sigma2
+        The error variance. `None` in binary regression.
+    prec_scale
+        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
+        `None` in binary regression.
+    sigma2_alpha
+    sigma2_beta
+        The shape and scale parameters of the inverse gamma prior on the noise
+        variance. `None` in binary regression.
+    forest
+        The sum of trees model.
+    """
+
+    X: UInt[Array, 'p n']
+    max_split: UInt[Array, 'p']
+    y: Float32[Array, 'n'] | Bool[Array, 'n']
+    z: None | Float32[Array, 'n']
+    offset: Float32[Array, '']
+    resid: Float32[Array, 'n']
+    sigma2: Float32[Array, ''] | None
+    prec_scale: Float32[Array, 'n'] | None
+    sigma2_alpha: Float32[Array, ''] | None
+    sigma2_beta: Float32[Array, ''] | None
+    forest: Forest
+
+
+def init(
+    *,
+    X: UInt[Any, 'p n'],
+    y: Float32[Any, 'n'] | Bool[Any, 'n'],
+    offset: float | Float32[Any, ''] = 0.0,
+    max_split: UInt[Any, 'p'],
+    num_trees: int,
+    p_nonterminal: Float32[Any, 'd-1'],
+    sigma_mu2: float | Float32[Any, ''],
+    sigma2_alpha: float | Float32[Any, ''] | None = None,
+    sigma2_beta: float | Float32[Any, ''] | None = None,
+    error_scale: Float32[Any, 'n'] | None = None,
+    min_points_per_leaf: int | None = None,
+    resid_batch_size: int | None | str = 'auto',
+    count_batch_size: int | None | str = 'auto',
+    save_ratios: bool = False,
+) -> State:
     """
     Make a BART posterior sampling MCMC initial state.

     Parameters
     ----------
-    X
+    X
         The predictors. Note this is trasposed compared to the usual convention.
-    y
-        The response.
-
+    y
+        The response. If the data type is `bool`, the regression model is binary
+        regression with probit.
+    offset
+        Constant shift added to the sum of trees. 0 if not specified.
+    max_split
         The maximum split index for each variable. All split ranges start at 1.
-    num_trees
+    num_trees
         The number of trees in the forest.
-    p_nonterminal
+    p_nonterminal
         The probability of a nonterminal node at each depth. The maximum depth
         of trees is fixed by the length of this array.
-
-    The
-
-
-
-
-
-
-
+    sigma_mu2
+        The prior variance of a leaf, conditional on the tree structure. The
+        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
+        prior mean of leaves is always zero.
+    sigma2_alpha
+    sigma2_beta
+        The shape and scale parameters of the inverse gamma prior on the error
+        variance. Leave unspecified for binary regression.
+    error_scale
+        Each error is scaled by the corresponding factor in `error_scale`, so
+        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
+        Not supported for binary regression. If not specified, defaults to 1 for
+        all points, but potentially skipping calculations.
+    min_points_per_leaf
         The minimum number of data points in a leaf node. 0 if not specified.
-    resid_batch_size
+    resid_batch_size
+    count_batch_size
         The batch sizes, along datapoints, for summing the residuals and
         counting the number of datapoints in each leaf. `None` for no batching.
         If 'auto', pick a value based on the device of `y`, or the default
         device.
-    save_ratios
+    save_ratios
         Whether to save the Metropolis-Hastings ratios.

     Returns
     -------
-
-
-
-
-
-
-
-
-        'split_trees' : int array (num_trees, 2 ** (d - 1))
-            The decision boundaries.
-        'resid' : large_float array (n,)
-            The residuals (data minus forest value). Large float to avoid
-            roundoff.
-        'sigma2' : large_float
-            The noise variance.
-        'grow_prop_count', 'prune_prop_count' : int
-            The number of grow/prune proposals made during one full MCMC cycle.
-        'grow_acc_count', 'prune_acc_count' : int
-            The number of grow/prune moves accepted during one full MCMC cycle.
-        'p_nonterminal' : large_float array (d,)
-            The probability of a nonterminal node at each depth, padded with a
-            zero.
-        'p_propose_grow' : large_float array (2 ** (d - 1),)
-            The unnormalized probability of picking a leaf for a grow proposal.
-        'sigma2_alpha' : large_float
-            The shape parameter of the inverse gamma prior on the noise variance.
-        'sigma2_beta' : large_float
-            The scale parameter of the inverse gamma prior on the noise variance.
-        'max_split' : int array (p,)
-            The maximum split index for each variable.
-        'y' : small_float array (n,)
-            The response.
-        'X' : int array (p, n)
-            The predictors.
-        'leaf_indices' : int array (num_trees, n)
-            The index of the leaf each datapoints falls into, for each tree.
-        'min_points_per_leaf' : int or None
-            The minimum number of data points in a leaf node.
-        'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
-            Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
-            datapoints. If `min_points_per_leaf` is not specified, this is None.
-        'opt' : LeafDict
-            A dictionary with config values:
-
-            'small_float' : dtype
-                The dtype for large arrays used in the algorithm.
-            'large_float' : dtype
-                The dtype for scalars, small arrays, and arrays which require
-                accuracy.
-            'require_min_points' : bool
-                Whether the `min_points_per_leaf` parameter is specified.
-            'resid_batch_size', 'count_batch_size' : int or None
-                The data batch sizes for computing the sufficient statistics.
-        'ratios' : dict, optional
-            If `save_ratios` is True, this field is present. It has the fields:
-
-            'log_trans_prior' : large_float array (num_trees,)
-                The log transition and prior Metropolis-Hastings ratio for the
-                proposed move on each tree.
-            'log_likelihood' : large_float array (num_trees,)
-                The log likelihood ratio.
-    """
-
-    p_nonterminal = jnp.asarray(p_nonterminal, large_float)
+    An initialized BART MCMC state.
+
+    Raises
+    ------
+    ValueError
+        If `y` is boolean and arguments unused in binary regression are set.
+    """
+    p_nonterminal = jnp.asarray(p_nonterminal)
     p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
     max_depth = p_nonterminal.size

-    @
+    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
     def make_forest(max_depth, dtype):
         return grove.make_tree(max_depth, dtype)

-
-
-
-    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    y = jnp.asarray(y)
+    offset = jnp.asarray(offset)
+
+    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
+        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
+    )
+
+    is_binary = y.dtype == bool
+    if is_binary:
+        if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
+            raise ValueError(
+                'error_scale, sigma2_alpha, and sigma2_beta must be set '
+                ' to `None` for binary regression.'
+            )
+        sigma2 = None
+    else:
+        sigma2_alpha = jnp.asarray(sigma2_alpha)
+        sigma2_beta = jnp.asarray(sigma2_beta)
+        sigma2 = sigma2_beta / sigma2_alpha
+        # sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
+        # TODO: I don't like this isfinite check, these functions should be
+        # low-level and just do the thing. Why was it here?
+
+    bart = State(
+        X=jnp.asarray(X),
         max_split=jnp.asarray(max_split),
         y=y,
-
-
-
-
-
-
-        affluence_trees=(
-            None if min_points_per_leaf is None else
-            make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
+        z=jnp.full(y.shape, offset) if is_binary else None,
+        offset=offset,
+        resid=jnp.zeros(y.shape) if is_binary else y - offset,
+        sigma2=sigma2,
+        prec_scale=(
+            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
         ),
-
-
-
-
+        sigma2_alpha=sigma2_alpha,
+        sigma2_beta=sigma2_beta,
+        forest=Forest(
+            leaf_trees=make_forest(max_depth, jnp.float32),
+            var_trees=make_forest(
+                max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)
+            ),
+            split_trees=make_forest(max_depth - 1, max_split.dtype),
+            grow_prop_count=jnp.zeros((), int),
+            grow_acc_count=jnp.zeros((), int),
+            prune_prop_count=jnp.zeros((), int),
+            prune_acc_count=jnp.zeros((), int),
+            p_nonterminal=p_nonterminal,
+            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
+            leaf_indices=jnp.ones(
+                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
+            ),
+            min_points_per_leaf=(
+                None
+                if min_points_per_leaf is None
+                else jnp.asarray(min_points_per_leaf)
+            ),
+            affluence_trees=(
+                None
+                if min_points_per_leaf is None
+                else make_forest(max_depth - 1, bool)
+                .at[:, 1]
+                .set(y.size >= 2 * min_points_per_leaf)
+            ),
             resid_batch_size=resid_batch_size,
             count_batch_size=count_batch_size,
+            log_trans_prior=jnp.full(num_trees, jnp.nan) if save_ratios else None,
+            log_likelihood=jnp.full(num_trees, jnp.nan) if save_ratios else None,
+            sigma_mu2=jnp.asarray(sigma_mu2),
         ),
     )

-    if save_ratios:
-        bart['ratios'] = dict(
-            log_trans_prior=jnp.full(num_trees, jnp.nan),
-            log_likelihood=jnp.full(num_trees, jnp.nan),
-        )
-
     return bart

-def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):

-
+def _choose_suffstat_batch_size(
+    resid_batch_size, count_batch_size, y, forest_size
+) -> tuple[int | None, ...]:
+    @cache
     def get_platform():
         try:
             device = y.devices().pop()
@@ -233,9 +326,9 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
         platform = get_platform()
         n = max(1, y.size)
         if platform == 'cpu':
-            resid_batch_size = 2 ** int(round(math.log2(n / 6)))
+            resid_batch_size = 2 ** int(round(math.log2(n / 6)))  # n/6
         elif platform == 'gpu':
-            resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3))
+            resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3))  # n^1/3
         resid_batch_size = max(1, resid_batch_size)

     if count_batch_size == 'auto':
@@ -244,9 +337,9 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
             count_batch_size = None
         elif platform == 'gpu':
             n = max(1, y.size)
-            count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2))
-
-            max_memory = 2
+            count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2))  # n^1/2
+            # /4 is good on V100, /2 on L4/T4, still haven't tried A100
+            max_memory = 2**29
             itemsize = 4
             min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
             count_batch_size = max(count_batch_size, min_batch_size)
@@ -254,210 +347,328 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si

     return resid_batch_size, count_batch_size

-
+
+@jax.jit
+def step(key: Key[Array, ''], bart: State) -> State:
     """
-
+    Do one MCMC step.

     Parameters
     ----------
-
-        A BART mcmc state, as created by `init`.
-    key : jax.dtypes.prng_key array
+    key
         A jax random key.
+    bart
+        A BART mcmc state, as created by `init`.

     Returns
     -------
-
-        The new BART mcmc state.
+    The new BART mcmc state.
     """
-
-
-
+    keys = split(key)
+
+    if bart.y.dtype == bool:  # binary regression
+        bart = replace(bart, sigma2=jnp.float32(1))
+        bart = step_trees(keys.pop(), bart)
+        bart = replace(bart, sigma2=None)
+        return step_z(keys.pop(), bart)

-
+    else:  # continuous regression
+        bart = step_trees(keys.pop(), bart)
+        return step_sigma(keys.pop(), bart)
+
+
+def step_trees(key: Key[Array, ''], bart: State) -> State:
     """
     Forest sampling step of BART MCMC.

     Parameters
     ----------
-
-        A BART mcmc state, as created by `init`.
-    key : jax.dtypes.prng_key array
+    key
         A jax random key.
+    bart
+        A BART mcmc state, as created by `init`.

     Returns
     -------
-
-        The new BART mcmc state.
+    The new BART mcmc state.

     Notes
     -----
     This function zeroes the proposal counters.
     """
-
-    moves =
-    return accept_moves_and_sample_leaves(bart, moves
+    keys = split(key)
+    moves = propose_moves(keys.pop(), bart.forest, bart.max_split)
+    return accept_moves_and_sample_leaves(keys.pop(), bart, moves)
+
+
+class Moves(Module):
+    """
+    Moves proposed to modify each tree.

-
+    Parameters
+    ----------
+    allowed
+        Whether the move is possible in the first place. There are additional
+        constraints that could forbid it, but they are computed at acceptance
+        time.
+    grow
+        Whether the move is a grow move or a prune move.
+    num_growable
+        The number of growable leaves in the original tree.
+    node
+        The index of the leaf to grow or node to prune.
+    left
+    right
+        The indices of the children of 'node'.
+    partial_ratio
+        A factor of the Metropolis-Hastings ratio of the move. It lacks
+        the likelihood ratio and the probability of proposing the prune
+        move. If the move is PRUNE, the ratio is inverted. `None` once
+        `log_trans_prior_ratio` has been computed.
+    log_trans_prior_ratio
+        The logarithm of the product of the transition and prior terms of the
+        Metropolis-Hastings ratio for the acceptance of the proposed move.
+        `None` if not yet computed.
+    grow_var
+        The decision axes of the new rules.
+    grow_split
+        The decision boundaries of the new rules.
+    var_trees
+        The updated decision axes of the trees, valid whatever move.
+    logu
+        The logarithm of a uniform (0, 1] random variable to be used to
+        accept the move. It's in (-oo, 0].
+    acc
+        Whether the move was accepted. `None` if not yet computed.
+    to_prune
+        Whether the final operation to apply the move is pruning. This indicates
+        an accepted prune move or a rejected grow move. `None` if not yet
+        computed.
+    """
+
+    allowed: Bool[Array, 'num_trees']
+    grow: Bool[Array, 'num_trees']
+    num_growable: UInt[Array, 'num_trees']
+    node: UInt[Array, 'num_trees']
+    left: UInt[Array, 'num_trees']
+    right: UInt[Array, 'num_trees']
+    partial_ratio: Float32[Array, 'num_trees'] | None
+    log_trans_prior_ratio: None | Float32[Array, 'num_trees']
+    grow_var: UInt[Array, 'num_trees']
+    grow_split: UInt[Array, 'num_trees']
+    var_trees: UInt[Array, 'num_trees 2**(d-1)']
+    logu: Float32[Array, 'num_trees']
+    acc: None | Bool[Array, 'num_trees']
+    to_prune: None | Bool[Array, 'num_trees']
+
+
+def propose_moves(
+    key: Key[Array, ''], forest: Forest, max_split: UInt[Array, 'p']
+) -> Moves:
     """
     Propose moves for all the trees.

+    There are two types of moves: GROW (convert a leaf to a decision node and
+    add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
+    leaf, deleting its children).
+
     Parameters
     ----------
-
-        BART mcmc state.
-    key : jax.dtypes.prng_key array
+    key
         A jax random key.
+    forest
+        The `forest` field of a BART MCMC state.
+    max_split
+        The maximum split index for each variable, found in `State`.

     Returns
     -------
-
-
-
-
-            Whether the move is possible.
-        'grow' : bool array (num_trees,)
-            Whether the move is a grow move or a prune move.
-        'num_growable' : int array (num_trees,)
-            The number of growable leaves in the original tree.
-        'node' : int array (num_trees,)
-            The index of the leaf to grow or node to prune.
-        'left', 'right' : int array (num_trees,)
-            The indices of the children of 'node'.
-        'partial_ratio' : float array (num_trees,)
-            A factor of the Metropolis-Hastings ratio of the move. It lacks
-            the likelihood ratio and the probability of proposing the prune
-            move. If the move is Prune, the ratio is inverted.
-        'grow_var' : int array (num_trees,)
-            The decision axes of the new rules.
-        'grow_split' : int array (num_trees,)
-            The decision boundaries of the new rules.
-        'var_trees' : int array (num_trees, 2 ** (d - 1))
-            The updated decision axes of the trees, valid whatever move.
-        'logu' : float array (num_trees,)
-            The logarithm of a uniform (0, 1] random variable to be used to
-            accept the move. It's in (-oo, 0].
-    """
-    ntree = bart['leaf_trees'].shape[0]
-    key = random.split(key, 1 + ntree)
-    key, subkey = key[0], key[1:]
+    The proposed move for each tree.
+    """
+    num_trees, _ = forest.leaf_trees.shape
+    keys = split(key, 1 + 2 * num_trees)

     # compute moves
-    grow_moves
+    grow_moves = propose_grow_moves(
+        keys.pop(num_trees),
+        forest.var_trees,
+        forest.split_trees,
+        forest.affluence_trees,
+        max_split,
+        forest.p_nonterminal,
+        forest.p_propose_grow,
+    )
+    prune_moves = propose_prune_moves(
+        keys.pop(num_trees),
+        forest.split_trees,
+        forest.affluence_trees,
+        forest.p_nonterminal,
+        forest.p_propose_grow,
+    )

-    u, logu = random.uniform(
+    u, logu = random.uniform(keys.pop(), (2, num_trees), jnp.float32)

     # choose between grow or prune
-    grow_allowed = grow_moves
-    p_grow = jnp.where(grow_allowed & prune_moves
-    grow = u < p_grow
+    grow_allowed = grow_moves.num_growable.astype(bool)
+    p_grow = jnp.where(grow_allowed & prune_moves.allowed, 0.5, grow_allowed)
+    grow = u < p_grow  # use < instead of <= because u is in [0, 1)

     # compute children indices
-    node = jnp.where(grow, grow_moves
+    node = jnp.where(grow, grow_moves.node, prune_moves.node)
     left = node << 1
     right = left + 1

-    return
-        allowed=grow | prune_moves
+    return Moves(
+        allowed=grow | prune_moves.allowed,
         grow=grow,
-        num_growable=grow_moves
+        num_growable=grow_moves.num_growable,
         node=node,
         left=left,
         right=right,
-        partial_ratio=jnp.where(
-
-
-
+        partial_ratio=jnp.where(
+            grow, grow_moves.partial_ratio, prune_moves.partial_ratio
+        ),
+        log_trans_prior_ratio=None,  # will be set in complete_ratio
+        grow_var=grow_moves.var,
+        grow_split=grow_moves.split,
+        var_trees=grow_moves.var_tree,
         logu=jnp.log1p(-logu),
+        acc=None,  # will be set in accept_moves_sequential_stage
+        to_prune=None,  # will be set in accept_moves_sequential_stage
     )

-@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
-def _sample_moves_vmap_trees(*args):
-    args, key = args[:-1], args[-1]
-    key, key1 = random.split(key)
-    grow = grow_move(*args, key)
-    prune = prune_move(*args, key1)
-    return grow, prune

-
+class GrowMoves(Module):
     """
-
+    Represent a proposed grow move for each tree.

-
-
-
+    Parameters
+    ----------
+    num_growable
+        The number of growable leaves.
+    node
+        The index of the leaf to grow. ``2 ** d`` if there are no growable
+        leaves.
+    var
+    split
+        The decision axis and boundary of the new rule.
+    partial_ratio
+        A factor of the Metropolis-Hastings ratio of the move. It lacks
+        the likelihood ratio and the probability of proposing the prune
+        move.
+    var_tree
+        The updated decision axes of the tree.
+    """
+
+    num_growable: UInt[Array, 'num_trees']
+    node: UInt[Array, 'num_trees']
+    var: UInt[Array, 'num_trees']
+    split: UInt[Array, 'num_trees']
+    partial_ratio: Float32[Array, 'num_trees']
+    var_tree: UInt[Array, 'num_trees 2**(d-1)']
+
+
+@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
+def propose_grow_moves(
+    key: Key[Array, ''],
+    var_tree: UInt[Array, '2**(d-1)'],
+    split_tree: UInt[Array, '2**(d-1)'],
+    affluence_tree: Bool[Array, '2**(d-1)'] | None,
+    max_split: UInt[Array, 'p'],
+    p_nonterminal: Float32[Array, 'd'],
+    p_propose_grow: Float32[Array, '2**(d-1)'],
+) -> GrowMoves:
+    """
+    Propose a GROW move for each tree.
+
+    A GROW move picks a leaf node and converts it to a non-terminal node with
+    two leaf children.

     Parameters
     ----------
-
-
-
+    key
+        A jax random key.
+    var_tree
+        The splitting axes of the tree.
+    split_tree
         The splitting points of the tree.
-    affluence_tree
+    affluence_tree
         Whether a leaf has enough points to be grown.
-    max_split
+    max_split
         The maximum split index for each variable.
-    p_nonterminal
+    p_nonterminal
         The probability of a nonterminal node at each depth.
-    p_propose_grow
+    p_propose_grow
         The unnormalized probability of choosing a leaf to grow.
-    key : jax.dtypes.prng_key array
-        A jax random key.

     Returns
     -------
-
-        A dictionary with fields:
+    An object representing the proposed move.

-
-
-
-
-
-        'var', 'split' : int
-            The decision axis and boundary of the new rule.
-        'partial_ratio' : float
-            A factor of the Metropolis-Hastings ratio of the move. It lacks
-            the likelihood ratio and the probability of proposing the prune
-            move.
-        'var_tree' : array (2 ** (d - 1),)
-            The updated decision axes of the tree.
-    """
+    Notes
+    -----
+    The move is not proposed if a leaf is already at maximum depth, or if a leaf
+    has less than twice the requested minimum number of datapoints per leaf.
+    This is marked by returning `num_growable` set to 0.

-
+    The move is also not be possible if the ancestors of a leaf have
+    exhausted the possible decision rules that lead to a non-empty selection.
+    This is marked by returning `var` set to `p` and `split` set to 0. But this
+    does not block the move from counting as "proposed", even though it is
+    predictably going to be rejected. This simplifies the MCMC and should not
+    reduce efficiency if not in unrealistic corner cases.
+    """
+    keys = split(key, 3)

-    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
+    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
+        keys.pop(), split_tree, affluence_tree, p_propose_grow
+    )

-    var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow
+    var = choose_variable(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
     var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))

-
+    split_idx = choose_split(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)

-    ratio = compute_partial_ratio(
+    ratio = compute_partial_ratio(
+        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
+    )

-    return
+    return GrowMoves(
         num_growable=num_growable,
         node=leaf_to_grow,
         var=var,
-        split=
+        split=split_idx,
         partial_ratio=ratio,
         var_tree=var_tree,
     )

-
+# TODO it is not clear to me how var=p and split=0 when the move is not
+# possible lead to corrent behavior downstream. Like, the move is proposed,
+# but then it's a noop? And since it's a noop, it makes no difference if
+# it's "accepted" or "rejected", it's like it's always rejected, so who
+# cares if the likelihood ratio or a lot of other numbers are wrong? Uhm.
+
+
+def choose_leaf(
+    key: Key[Array, ''],
+    split_tree: UInt[Array, '2**(d-1)'],
+    affluence_tree: Bool[Array, '2**(d-1)'] | None,
+    p_propose_grow: Float32[Array, '2**(d-1)'],
+) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
     """
     Choose a leaf node to grow in a tree.

     Parameters
     ----------
-
+    key
+        A jax random key.
+    split_tree
         The splitting points of the tree.
-    affluence_tree
-        Whether a leaf has enough points
-
+    affluence_tree
+        Whether a leaf has enough points that it could be split into two leaves
+        satisfying the `min_points_per_leaf` requirement.
+    p_propose_grow
         The unnormalized probability of choosing a leaf to grow.
-    key : jax.dtypes.prng_key array
-        A jax random key.

     Returns
     -------
@@ -465,9 +676,11 @@ def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
         The index of the leaf to grow. If ``num_growable == 0``, return
         ``2 ** d``.
     num_growable : int
-        The number of leaf nodes that can be grown.
+        The number of leaf nodes that can be grown, i.e., are nonterminal
+        and have at least twice `min_points_per_leaf` if set.
     prob_choose : float
-        The normalized probability
+        The (normalized) probability that this function had to choose that
+        specific leaf, given the arguments.
     num_prunable : int
         The number of leaf parents that could be pruned, after converting the
         selected leaf to a non-terminal node.
@@ -482,71 +695,86 @@ def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
     num_prunable = jnp.count_nonzero(is_parent)
     return leaf_to_grow, num_growable, prob_choose, num_prunable

-
+
+def growable_leaves(
+    split_tree: UInt[Array, '2**(d-1)'],
+    affluence_tree: Bool[Array, '2**(d-1)'] | None,
+) -> Bool[Array, '2**(d-1)']:
     """
     Return a mask indicating the leaf nodes that can be proposed for growth.

+    The condition is that a leaf is not at the bottom level and has at least two
+    times the number of minimum points per leaf.
+
     Parameters
     ----------
-    split_tree
+    split_tree
         The splitting points of the tree.
-    affluence_tree
+    affluence_tree
         Whether a leaf has enough points to be grown.

     Returns
     -------
-
-        The mask indicating the leaf nodes that can be proposed to grow, i.e.,
-        that are not at the bottom level and have at least two times the number
-        of minimum points per leaf.
+    The mask indicating the leaf nodes that can be proposed to grow.
     """
     is_growable = grove.is_actual_leaf(split_tree)
     if affluence_tree is not None:
         is_growable &= affluence_tree
     return is_growable

-
+
+def categorical(
+    key: Key[Array, ''], distr: Float32[Array, 'n']
+) -> tuple[Int32[Array, ''], Float32[Array, '']]:
     """
     Return a random integer from an arbitrary distribution.

     Parameters
     ----------
-    key
+    key
         A jax random key.
-    distr
+    distr
         An unnormalized probability distribution.

     Returns
     -------
-    u :
+    u : Int32[Array, '']
         A random integer in the range ``[0, n)``. If all probabilities are zero,
         return ``n``.
+    norm : Float32[Array, '']
+        The sum of `distr`.
     """
     ecdf = jnp.cumsum(distr)
     u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
     return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]

-
+
+def choose_variable(
+    key: Key[Array, ''],
+    var_tree: UInt[Array, '2**(d-1)'],
+    split_tree: UInt[Array, '2**(d-1)'],
+    max_split: UInt[Array, 'p'],
+    leaf_index: Int32[Array, ''],
+) -> Int32[Array, '']:
     """
     Choose a variable to split on for a new non-terminal node.

     Parameters
     ----------
-
+    key
+        A jax random key.
+    var_tree
         The variable indices of the tree.
-    split_tree
+    split_tree
         The splitting points of the tree.
-    max_split
+    max_split
         The maximum split index for each variable.
-    leaf_index
+    leaf_index
         The index of the leaf to grow.
-    key : jax.dtypes.prng_key array
-        A jax random key.

     Returns
     -------
-
-        The index of the variable to split on.
+    The index of the variable to split on.

     Notes
     -----
@@ -556,37 +784,49 @@ def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
     var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
     return randint_exclude(key, max_split.size, var_to_ignore)

-
+
+def fully_used_variables(
+    var_tree: UInt[Array, '2**(d-1)'],
+    split_tree: UInt[Array, '2**(d-1)'],
+    max_split: UInt[Array, 'p'],
+    leaf_index: Int32[Array, ''],
+) -> UInt[Array, 'd-2']:
     """
     Return a list of variables that have an empty split range at a given node.

     Parameters
     ----------
-    var_tree
+    var_tree
         The variable indices of the tree.
-    split_tree
+    split_tree
         The splitting points of the tree.
-    max_split
+    max_split
         The maximum split index for each variable.
-    leaf_index
+    leaf_index
         The index of the node, assumed to be valid for `var_tree`.

     Returns
     -------
-
-        The indices of the variables that have an empty split range. Since the
-        number of such variables is not fixed, unused values in the array are
-        filled with `p`. The fill values are not guaranteed to be placed in any
-        particular order. Variables may appear more than once.
-    """
+    The indices of the variables that have an empty split range.

+    Notes
+    -----
+    The number of unused variables is not known in advance. Unused values in the
+    array are filled with `p`. The fill values are not guaranteed to be placed
+    in any particular order, and variables may appear more than once.
+    """
     var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
     split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
     l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
     num_split = r - l
     return jnp.where(num_split == 0, var_to_ignore, max_split.size)

-
+
+def ancestor_variables(
+    var_tree: UInt[Array, '2**(d-1)'],
+    max_split: UInt[Array, 'p'],
+    node_index: Int32[Array, ''],
+) -> UInt[Array, 'd-2']:
     """
     Return the list of variables in the ancestors of a node.

@@ -601,13 +841,18 @@ def ancestor_variables(var_tree, max_split, node_index):

     Returns
     -------
-
-
-
+    The variable indices of the ancestors of the node.
+
+    Notes
+    -----
+    The ancestors are the nodes going from the root to the parent of the node.
+    The number of ancestors is not known at tracing time; unused spots in the
+    output array are filled with `p`.
     """
     max_num_ancestors = grove.tree_depth(var_tree) - 1
-    ancestor_vars = jnp.zeros(max_num_ancestors,
+    ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
     carry = ancestor_vars.size - 1, node_index, ancestor_vars
+
     def loop(carry, _):
         i, index, ancestor_vars = carry
         index >>= 1
@@ -615,34 +860,44 @@ def ancestor_variables(var_tree, max_split, node_index):
         var = jnp.where(index, var, max_split.size)
         ancestor_vars = ancestor_vars.at[i].set(var)
         return (i - 1, index, ancestor_vars), None
+
     (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size)
     return ancestor_vars

-
+
+def split_range(
+    var_tree: UInt[Array, '2**(d-1)'],
+    split_tree: UInt[Array, '2**(d-1)'],
+    max_split: UInt[Array, 'p'],
+    node_index: Int32[Array, ''],
+    ref_var: Int32[Array, ''],
+) -> tuple[Int32[Array, ''], Int32[Array, '']]:
     """
     Return the range of allowed splits for a variable at a given node.

     Parameters
     ----------
-    var_tree
+    var_tree
         The variable indices of the tree.
-    split_tree
+    split_tree
         The splitting points of the tree.
-    max_split
+    max_split
         The maximum split index for each variable.
-    node_index
+    node_index
         The index of the node, assumed to be valid for `var_tree`.
-    ref_var
+    ref_var
         The variable for which to measure the split range.

     Returns
     -------
-    l, r
-        The range of allowed splits is [l, r).
+    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=0.
     """
     max_num_ancestors = grove.tree_depth(var_tree) - 1
-    initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
+    initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
+        jnp.int32
+    )
     carry = 0, initial_r, node_index
+
     def loop(carry, _):
         l, r, index = carry
         right_child = (index & 1).astype(bool)
@@ -652,89 +907,114 @@ def split_range(var_tree, split_tree, max_split, node_index, ref_var):
         l = jnp.where(cond & right_child, jnp.maximum(l, split), l)
         r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r)
         return (l, r, index), None
+
     (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors)
     return l + 1, r

-
+
+def randint_exclude(
+    key: Key[Array, ''], sup: int, exclude: Integer[Array, 'n']
+) -> Int32[Array, '']:
     """
     Return a random integer in a range, excluding some values.

     Parameters
     ----------
-    key
+    key
         A jax random key.
-    sup
+    sup
         The exclusive upper bound of the range.
-    exclude
+    exclude
         The values to exclude from the range. Values greater than or equal to
         `sup` are ignored. Values can appear more than once.

     Returns
     -------
-    u
-
-
-
+    A random integer `u` in the range ``[0, sup)`` such that ``u not in exclude``.
+
+    Notes
+    -----
+    If all values in the range are excluded, return `sup`.
     """
     exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
     num_allowed = sup - jnp.count_nonzero(exclude < sup)
     u = random.randint(key, (), 0, num_allowed)
+
     def loop(u, i):
         return jnp.where(i <= u, u + 1, u), None
+
     u, _ = lax.scan(loop, u, exclude)
     return u

-
+
+def choose_split(
+    key: Key[Array, ''],
+    var_tree: UInt[Array, '2**(d-1)'],
+    split_tree: UInt[Array, '2**(d-1)'],
+    max_split: UInt[Array, 'p'],
+    leaf_index: Int32[Array, ''],
+) -> Int32[Array, '']:
     """
     Choose a split point for a new non-terminal node.

     Parameters
     ----------
-
-
-
+    key
+        A jax random key.
+    var_tree
+        The splitting axes of the tree.
+    split_tree
         The splitting points of the tree.
-    max_split
+    max_split
         The maximum split index for each variable.
-    leaf_index
+    leaf_index
         The index of the leaf to grow. It is assumed that `var_tree` already
         contains the target variable at this index.
-    key : jax.dtypes.prng_key array
-        A jax random key.

     Returns
     -------
-
-        The split point.
+    The cutpoint. If ``var_tree[leaf_index]`` is out of bounds, return 0.
     """
     var = var_tree[leaf_index]
     l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
     return random.randint(key, (), l, r)

-
+# TODO what happens if leaf_index is out of bounds? And is the value used
+# in that case?
+
+
+def compute_partial_ratio(
+    prob_choose: Float32[Array, ''],
+    num_prunable: Int32[Array, ''],
+    p_nonterminal: Float32[Array, 'd'],
+    leaf_to_grow: Int32[Array, ''],
+) -> Float32[Array, '']:
     """
     Compute the product of the transition and prior ratios of a grow move.

     Parameters
     ----------
-
-        The
-
+    prob_choose
+        The probability that the leaf had to be chosen amongst the growable
+        leaves.
+    num_prunable
         The number of leaf parents that could be pruned, after converting the
         leaf to be grown to a non-terminal node.
-    p_nonterminal
+    p_nonterminal
         The probability of a nonterminal node at each depth.
-    leaf_to_grow
+    leaf_to_grow
         The index of the leaf to grow.

     Returns
     -------
-    ratio
-        The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
-        times the prior ratio P(new tree) / P(old tree), but the transition
-        ratio is missing the factor P(propose prune) in the numerator.
-    """
+    The partial transition ratio times the prior ratio.

+    Notes
+    -----
+    The transition ratio is P(new tree => old tree) / P(old tree => new tree).
+    The "partial" transition ratio returned is missing the factor P(propose
+    prune) in the numerator. The prior ratio is P(new tree) / P(old tree).
+    """
     # the two ratios also contain factors num_available_split *
     # num_available_var, but they cancel out

@@ -742,9 +1022,9 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
     # computed in the acceptance phase

     prune_allowed = leaf_to_grow != 1
-
-
-
+    # prune allowed <---> the initial tree is not a root
+    # leaf to grow is root --> the tree can only be a root
+    # tree is a root --> the only leaf I can grow is root

     p_grow = jnp.where(prune_allowed, 0.5, 1)

@@ -757,76 +1037,104 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
|
|
|
757
1037
|
|
|
758
1038
|
return tree_ratio / inv_trans_ratio
|
|
759
1039
|
|
|
760
|
-
|
|
1040
|
+
|
|
1041
|
+
class PruneMoves(Module):
|
|
1042
|
+
"""
|
|
1043
|
+
Represent a proposed prune move for each tree.
|
|
1044
|
+
|
|
1045
|
+
Parameters
|
|
1046
|
+
----------
|
|
1047
|
+
allowed
|
|
1048
|
+
Whether the move is possible.
|
|
1049
|
+
node
|
|
1050
|
+
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
1051
|
+
partial_ratio
|
|
1052
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
1053
|
+
the likelihood ratio and the probability of proposing the prune
|
|
1054
|
+
move. This ratio is inverted, and is meant to be inverted back in
|
|
1055
|
+
`accept_move_and_sample_leaves`.
|
|
1056
|
+
"""
|
|
1057
|
+
|
|
1058
|
+
allowed: Bool[Array, 'num_trees']
|
|
1059
|
+
node: UInt[Array, 'num_trees']
|
|
1060
|
+
partial_ratio: Float32[Array, 'num_trees']
|
|
1061
|
+
|
|
1062
|
+
|
|
1063
|
+
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
|
|
1064
|
+
def propose_prune_moves(
|
|
1065
|
+
key: Key[Array, ''],
|
|
1066
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
1067
|
+
affluence_tree: Bool[Array, '2**(d-1)'] | None,
|
|
1068
|
+
p_nonterminal: Float32[Array, 'd'],
|
|
1069
|
+
p_propose_grow: Float32[Array, '2**(d-1)'],
|
|
1070
|
+
) -> PruneMoves:
|
|
761
1071
|
"""
|
|
762
1072
|
Tree structure prune move proposal of BART MCMC.
|
|
763
1073
|
|
|
764
1074
|
Parameters
|
|
765
1075
|
----------
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
split_tree
|
|
1076
|
+
key
|
|
1077
|
+
A jax random key.
|
|
1078
|
+
split_tree
|
|
769
1079
|
The splitting points of the tree.
|
|
770
|
-
affluence_tree
|
|
1080
|
+
affluence_tree
|
|
771
1081
|
Whether a leaf has enough points to be grown.
|
|
772
|
-
|
|
773
|
-
The maximum split index for each variable.
|
|
774
|
-
p_nonterminal : float array (d,)
|
|
1082
|
+
p_nonterminal
|
|
775
1083
|
The probability of a nonterminal node at each depth.
|
|
776
|
-
p_propose_grow
|
|
1084
|
+
p_propose_grow
|
|
777
1085
|
The unnormalized probability of choosing a leaf to grow.
|
|
778
|
-
key : jax.dtypes.prng_key array
|
|
779
|
-
A jax random key.
|
|
780
1086
|
|
|
781
1087
|
Returns
|
|
782
1088
|
-------
|
|
783
|
-
|
|
784
|
-
A dictionary with fields:
|
|
785
|
-
|
|
786
|
-
'allowed' : bool
|
|
787
|
-
Whether the move is possible.
|
|
788
|
-
'node' : int
|
|
789
|
-
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
790
|
-
'partial_ratio' : float
|
|
791
|
-
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
792
|
-
the likelihood ratio and the probability of proposing the prune
|
|
793
|
-
move. This ratio is inverted.
|
|
1089
|
+
An object representing the proposed moves.
|
|
794
1090
|
"""
|
|
795
|
-
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
|
|
796
|
-
|
|
1091
|
+
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
|
|
1092
|
+
key, split_tree, affluence_tree, p_propose_grow
|
|
1093
|
+
)
|
|
1094
|
+
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
797
1095
|
|
|
798
|
-
ratio = compute_partial_ratio(
|
|
1096
|
+
ratio = compute_partial_ratio(
|
|
1097
|
+
prob_choose, num_prunable, p_nonterminal, node_to_prune
|
|
1098
|
+
)
|
|
799
1099
|
|
|
800
|
-
return
|
|
1100
|
+
return PruneMoves(
|
|
801
1101
|
allowed=allowed,
|
|
802
1102
|
node=node_to_prune,
|
|
803
|
-
partial_ratio=ratio,
|
|
1103
|
+
partial_ratio=ratio,
|
|
804
1104
|
)
|
|
805
1105
|
|
|
806
|
-
|
|
1106
|
+
|
|
1107
|
+
def choose_leaf_parent(
|
|
1108
|
+
key: Key[Array, ''],
|
|
1109
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
1110
|
+
affluence_tree: Bool[Array, '2**(d-1)'] | None,
|
|
1111
|
+
p_propose_grow: Float32[Array, '2**(d-1)'],
|
|
1112
|
+
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, '']]:
|
|
807
1113
|
"""
|
|
808
1114
|
Pick a non-terminal node with leaf children to prune in a tree.
|
|
809
1115
|
|
|
810
1116
|
Parameters
|
|
811
1117
|
----------
|
|
812
|
-
|
|
1118
|
+
key
|
|
1119
|
+
A jax random key.
|
|
1120
|
+
split_tree
|
|
813
1121
|
The splitting points of the tree.
|
|
814
|
-
affluence_tree
|
|
1122
|
+
affluence_tree
|
|
815
1123
|
Whether a leaf has enough points to be grown.
|
|
816
|
-
p_propose_grow
|
|
1124
|
+
p_propose_grow
|
|
817
1125
|
The unnormalized probability of choosing a leaf to grow.
|
|
818
|
-
key : jax.dtypes.prng_key array
|
|
819
|
-
A jax random key.
|
|
820
1126
|
|
|
821
1127
|
Returns
|
|
822
1128
|
-------
|
|
823
|
-
node_to_prune :
|
|
1129
|
+
node_to_prune : Int32[Array, '']
|
|
824
1130
|
The index of the node to prune. If ``num_prunable == 0``, return
|
|
825
1131
|
``2 ** d``.
|
|
826
|
-
num_prunable :
|
|
1132
|
+
num_prunable : Int32[Array, '']
|
|
827
1133
|
The number of leaf parents that could be pruned.
|
|
828
|
-
prob_choose :
|
|
829
|
-
The normalized probability
|
|
1134
|
+
prob_choose : Float32[Array, '']
|
|
1135
|
+
The (normalized) probability that `choose_leaf` would choose
|
|
1136
|
+
`node_to_prune` as the leaf to grow, if passed the tree where
|
|
1137
|
+
`node_to_prune` had been pruned.
|
|
830
1138
|
"""
|
|
831
1139
|
is_prunable = grove.is_leaves_parent(split_tree)
|
|
832
1140
|
num_prunable = jnp.count_nonzero(is_prunable)
|
|
@@ -834,517 +1142,910 @@ def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
|
|
|
834
1142
|
node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
|
|
835
1143
|
|
|
836
1144
|
split_tree = split_tree.at[node_to_prune].set(0)
|
|
837
|
-
affluence_tree
|
|
838
|
-
|
|
839
|
-
affluence_tree.at[node_to_prune].set(True)
|
|
840
|
-
)
|
|
1145
|
+
if affluence_tree is not None:
|
|
1146
|
+
affluence_tree = affluence_tree.at[node_to_prune].set(True)
|
|
841
1147
|
is_growable_leaf = growable_leaves(split_tree, affluence_tree)
|
|
842
1148
|
prob_choose = p_propose_grow[node_to_prune]
|
|
843
1149
|
prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
|
|
844
1150
|
|
|
845
1151
|
return node_to_prune, num_prunable, prob_choose
|
|
846
1152
|
|
|
847
|
-
|
|
1153
|
+
|
|
1154
|
+
def randint_masked(key: Key[Array, ''], mask: Bool[Array, 'n']) -> Int32[Array, '']:
|
|
848
1155
|
"""
|
|
849
1156
|
Return a random integer in a range, including only some values.
|
|
850
1157
|
|
|
851
1158
|
Parameters
|
|
852
1159
|
----------
|
|
853
|
-
key
|
|
1160
|
+
key
|
|
854
1161
|
A jax random key.
|
|
855
|
-
mask
|
|
1162
|
+
mask
|
|
856
1163
|
The mask indicating the allowed values.
|
|
857
1164
|
|
|
858
1165
|
Returns
|
|
859
1166
|
-------
|
|
860
|
-
u
|
|
861
|
-
|
|
862
|
-
|
|
1167
|
+
A random integer ``u`` in the range ``[0, n)`` such that ``mask[u] == True``.
|
|
1168
|
+
|
|
1169
|
+
Notes
|
|
1170
|
+
-----
|
|
1171
|
+
If all values in the mask are `False`, return `n`.
|
|
863
1172
|
"""
|
|
864
1173
|
ecdf = jnp.cumsum(mask)
|
|
865
1174
|
u = random.randint(key, (), 0, ecdf[-1])
|
|
866
1175
|
return jnp.searchsorted(ecdf, u, 'right')
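For illustration, a minimal usage sketch of the helper above (assuming the bartz 0.6.0 module layout shown in this diff); the drawn index is uniform over the True positions of the mask:

    from jax import numpy as jnp, random
    from bartz.mcmcstep import randint_masked

    key = random.PRNGKey(0)
    mask = jnp.array([False, True, False, True, True])
    idx = randint_masked(key, mask)
    # ecdf == [0, 1, 1, 2, 3], u is drawn in [0, 3), so idx is one of 1, 3, 4;
    # per the Notes above, an all-False mask would yield mask.size instead.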
|
|
867
1176
|
|
|
868
|
-
|
|
1177
|
+
|
|
1178
|
+
def accept_moves_and_sample_leaves(
|
|
1179
|
+
key: Key[Array, ''], bart: State, moves: Moves
|
|
1180
|
+
) -> State:
|
|
869
1181
|
"""
|
|
870
1182
|
Accept or reject the proposed moves and sample the new leaf values.
|
|
871
1183
|
|
|
872
1184
|
Parameters
|
|
873
1185
|
----------
|
|
874
|
-
|
|
875
|
-
A BART mcmc state.
|
|
876
|
-
moves : dict
|
|
877
|
-
The proposed moves, see `sample_moves`.
|
|
878
|
-
key : jax.dtypes.prng_key array
|
|
1186
|
+
key
|
|
879
1187
|
A jax random key.
|
|
1188
|
+
bart
|
|
1189
|
+
A valid BART mcmc state.
|
|
1190
|
+
moves
|
|
1191
|
+
The proposed moves, see `propose_moves`.
|
|
880
1192
|
|
|
881
1193
|
Returns
|
|
882
1194
|
-------
|
|
883
|
-
|
|
884
|
-
The new BART mcmc state.
|
|
1195
|
+
A new (valid) BART mcmc state.
|
|
885
1196
|
"""
|
|
886
|
-
|
|
887
|
-
bart, moves = accept_moves_sequential_stage(
|
|
1197
|
+
pso = accept_moves_parallel_stage(key, bart, moves)
|
|
1198
|
+
bart, moves = accept_moves_sequential_stage(pso)
|
|
888
1199
|
return accept_moves_final_stage(bart, moves)
|
|
889
1200
|
|
|
890
|
-
|
|
1201
|
+
|
|
1202
|
+
class Counts(Module):
|
|
1203
|
+
"""
|
|
1204
|
+
Number of datapoints in the nodes involved in proposed moves for each tree.
|
|
1205
|
+
|
|
1206
|
+
Parameters
|
|
1207
|
+
----------
|
|
1208
|
+
left
|
|
1209
|
+
Number of datapoints in the left child.
|
|
1210
|
+
right
|
|
1211
|
+
Number of datapoints in the right child.
|
|
1212
|
+
total
|
|
1213
|
+
Number of datapoints in the parent (``= left + right``).
|
|
1214
|
+
"""
|
|
1215
|
+
|
|
1216
|
+
left: UInt[Array, 'num_trees']
|
|
1217
|
+
right: UInt[Array, 'num_trees']
|
|
1218
|
+
total: UInt[Array, 'num_trees']
|
|
1219
|
+
|
|
1220
|
+
|
|
1221
|
+
class Precs(Module):
|
|
1222
|
+
"""
|
|
1223
|
+
Likelihood precision scale in the nodes involved in proposed moves for each tree.
|
|
1224
|
+
|
|
1225
|
+
The "likelihood precision scale" of a tree node is the sum of the inverse
|
|
1226
|
+
squared error scales of the datapoints selected by the node.
|
|
1227
|
+
|
|
1228
|
+
Parameters
|
|
1229
|
+
----------
|
|
1230
|
+
left
|
|
1231
|
+
Likelihood precision scale in the left child.
|
|
1232
|
+
right
|
|
1233
|
+
Likelihood precision scale in the right child.
|
|
1234
|
+
total
|
|
1235
|
+
Likelihood precision scale in the parent (``= left + right``).
|
|
1236
|
+
"""
|
|
1237
|
+
|
|
1238
|
+
left: Float32[Array, 'num_trees']
|
|
1239
|
+
right: Float32[Array, 'num_trees']
|
|
1240
|
+
total: Float32[Array, 'num_trees']
|
|
1241
|
+
|
|
1242
|
+
|
|
1243
|
+
class PreLkV(Module):
|
|
1244
|
+
"""
|
|
1245
|
+
Non-sequential terms of the likelihood ratio for each tree.
|
|
1246
|
+
|
|
1247
|
+
These terms can be computed in parallel across trees.
|
|
1248
|
+
|
|
1249
|
+
Parameters
|
|
1250
|
+
----------
|
|
1251
|
+
sigma2_left
|
|
1252
|
+
The noise variance in the left child of the leaves grown or pruned by
|
|
1253
|
+
the moves.
|
|
1254
|
+
sigma2_right
|
|
1255
|
+
The noise variance in the right child of the leaves grown or pruned by
|
|
1256
|
+
the moves.
|
|
1257
|
+
sigma2_total
|
|
1258
|
+
The noise variance in the total of the leaves grown or pruned by the
|
|
1259
|
+
moves.
|
|
1260
|
+
sqrt_term
|
|
1261
|
+
The **logarithm** of the square root term of the likelihood ratio.
|
|
1262
|
+
"""
|
|
1263
|
+
|
|
1264
|
+
sigma2_left: Float32[Array, 'num_trees']
|
|
1265
|
+
sigma2_right: Float32[Array, 'num_trees']
|
|
1266
|
+
sigma2_total: Float32[Array, 'num_trees']
|
|
1267
|
+
sqrt_term: Float32[Array, 'num_trees']
|
|
1268
|
+
|
|
1269
|
+
|
|
1270
|
+
class PreLk(Module):
|
|
1271
|
+
"""
|
|
1272
|
+
Non-sequential terms of the likelihood ratio shared by all trees.
|
|
1273
|
+
|
|
1274
|
+
Parameters
|
|
1275
|
+
----------
|
|
1276
|
+
exp_factor
|
|
1277
|
+
The factor multiplying the exponent of the likelihood ratio, shared by all trees.
|
|
1278
|
+
"""
|
|
1279
|
+
|
|
1280
|
+
exp_factor: Float32[Array, '']
|
|
1281
|
+
|
|
1282
|
+
|
|
1283
|
+
class PreLf(Module):
|
|
1284
|
+
"""
|
|
1285
|
+
Pre-computed terms used to sample leaves from their posterior.
|
|
1286
|
+
|
|
1287
|
+
These terms can be computed in parallel across trees.
|
|
1288
|
+
|
|
1289
|
+
Parameters
|
|
1290
|
+
----------
|
|
1291
|
+
mean_factor
|
|
1292
|
+
The factor to be multiplied by the sum of the scaled residuals to
|
|
1293
|
+
obtain the posterior mean.
|
|
1294
|
+
centered_leaves
|
|
1295
|
+
The mean-zero normal values to be added to the posterior mean to
|
|
1296
|
+
obtain the posterior leaf samples.
|
|
1297
|
+
"""
|
|
1298
|
+
|
|
1299
|
+
mean_factor: Float32[Array, 'num_trees 2**d']
|
|
1300
|
+
centered_leaves: Float32[Array, 'num_trees 2**d']
|
|
1301
|
+
|
|
1302
|
+
|
|
1303
|
+
class ParallelStageOut(Module):
|
|
1304
|
+
"""
|
|
1305
|
+
The output of `accept_moves_parallel_stage`.
|
|
1306
|
+
|
|
1307
|
+
Parameters
|
|
1308
|
+
----------
|
|
1309
|
+
bart
|
|
1310
|
+
A partially updated BART mcmc state.
|
|
1311
|
+
moves
|
|
1312
|
+
The proposed moves, with `partial_ratio` set to `None` and
|
|
1313
|
+
`log_trans_prior_ratio` set to its final value.
|
|
1314
|
+
prec_trees
|
|
1315
|
+
The likelihood precision scale in each potential or actual leaf node. If
|
|
1316
|
+
there is no precision scale, this is the number of points in each leaf.
|
|
1317
|
+
move_counts
|
|
1318
|
+
The counts of the number of points in the nodes modified by the
|
|
1319
|
+
moves. If `bart.min_points_per_leaf` is not set and
|
|
1320
|
+
`bart.prec_scale` is set, they are not computed.
|
|
1321
|
+
move_precs
|
|
1322
|
+
The likelihood precision scale in each node modified by the moves. If
|
|
1323
|
+
`bart.prec_scale` is not set, this is set to `move_counts`.
|
|
1324
|
+
prelkv
|
|
1325
|
+
prelk
|
|
1326
|
+
prelf
|
|
1327
|
+
Objects with pre-computed terms of the likelihood ratios and leaf
|
|
1328
|
+
samples.
|
|
1329
|
+
"""
|
|
1330
|
+
|
|
1331
|
+
bart: State
|
|
1332
|
+
moves: Moves
|
|
1333
|
+
prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
|
|
1334
|
+
move_counts: Counts | None
|
|
1335
|
+
move_precs: Precs | Counts
|
|
1336
|
+
prelkv: PreLkV
|
|
1337
|
+
prelk: PreLk
|
|
1338
|
+
prelf: PreLf
|
|
1339
|
+
|
|
1340
|
+
|
|
1341
|
+
def accept_moves_parallel_stage(
|
|
1342
|
+
key: Key[Array, ''], bart: State, moves: Moves
|
|
1343
|
+
) -> ParallelStageOut:
|
|
891
1344
|
"""
|
|
892
1345
|
Pre-computes quantities used to accept moves, in parallel across trees.
|
|
893
1346
|
|
|
894
1347
|
Parameters
|
|
895
1348
|
----------
|
|
1349
|
+
key : jax.dtypes.prng_key array
|
|
1350
|
+
A jax random key.
|
|
896
1351
|
bart : dict
|
|
897
1352
|
A BART mcmc state.
|
|
898
1353
|
moves : dict
|
|
899
|
-
The proposed moves, see `
|
|
900
|
-
key : jax.dtypes.prng_key array
|
|
901
|
-
A jax random key.
|
|
1354
|
+
The proposed moves, see `propose_moves`.
|
|
902
1355
|
|
|
903
1356
|
Returns
|
|
904
1357
|
-------
|
|
905
|
-
|
|
906
|
-
A partially updated BART mcmc state.
|
|
907
|
-
moves : dict
|
|
908
|
-
The proposed moves, with the field 'partial_ratio' replaced
|
|
909
|
-
by 'log_trans_prior_ratio'.
|
|
910
|
-
count_trees : array (num_trees, 2 ** d)
|
|
911
|
-
The number of points in each potential or actual leaf node.
|
|
912
|
-
move_counts : dict
|
|
913
|
-
The counts of the number of points in the the nodes modified by the
|
|
914
|
-
moves.
|
|
915
|
-
prelkv, prelk, prelf : dict
|
|
916
|
-
Dictionary with pre-computed terms of the likelihood ratios and leaf
|
|
917
|
-
samples.
|
|
1358
|
+
An object collecting everything that could be computed in parallel across trees.
|
|
918
1359
|
"""
|
|
919
|
-
bart = bart.copy()
|
|
920
|
-
|
|
921
1360
|
# where the move is grow, modify the state like the move was accepted
|
|
922
|
-
bart
|
|
923
|
-
|
|
924
|
-
|
|
1361
|
+
bart = replace(
|
|
1362
|
+
bart,
|
|
1363
|
+
forest=replace(
|
|
1364
|
+
bart.forest,
|
|
1365
|
+
var_trees=moves.var_trees,
|
|
1366
|
+
leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
|
|
1367
|
+
leaf_trees=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_trees, moves),
|
|
1368
|
+
),
|
|
1369
|
+
)
|
|
925
1370
|
|
|
926
1371
|
# count number of datapoints per leaf
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
1372
|
+
if bart.forest.min_points_per_leaf is not None or bart.prec_scale is None:
|
|
1373
|
+
count_trees, move_counts = compute_count_trees(
|
|
1374
|
+
bart.forest.leaf_indices, moves, bart.forest.count_batch_size
|
|
1375
|
+
)
|
|
1376
|
+
else:
|
|
1377
|
+
# move_counts is passed later to a function, but then is unused under
|
|
1378
|
+
# this condition
|
|
1379
|
+
move_counts = None
|
|
1380
|
+
|
|
1381
|
+
# Check whether some nodes surely can't be grown because they don't have enough
|
|
1382
|
+
# datapoints. This check is not used now; it will be used at the
|
|
1383
|
+
# beginning of the next step to propose moves.
|
|
1384
|
+
if bart.forest.min_points_per_leaf is not None:
|
|
1385
|
+
count_half_trees = count_trees[:, : bart.forest.var_trees.shape[1]]
|
|
1386
|
+
bart = replace(
|
|
1387
|
+
bart,
|
|
1388
|
+
forest=replace(
|
|
1389
|
+
bart.forest,
|
|
1390
|
+
affluence_trees=count_half_trees >= 2 * bart.forest.min_points_per_leaf,
|
|
1391
|
+
),
|
|
1392
|
+
)
|
|
1393
|
+
|
|
1394
|
+
# count number of datapoints per leaf, weighted by error precision scale
|
|
1395
|
+
if bart.prec_scale is None:
|
|
1396
|
+
prec_trees = count_trees
|
|
1397
|
+
move_precs = move_counts
|
|
1398
|
+
else:
|
|
1399
|
+
prec_trees, move_precs = compute_prec_trees(
|
|
1400
|
+
bart.prec_scale,
|
|
1401
|
+
bart.forest.leaf_indices,
|
|
1402
|
+
moves,
|
|
1403
|
+
bart.forest.count_batch_size,
|
|
1404
|
+
)
|
|
931
1405
|
|
|
932
1406
|
# compute some missing information about moves
|
|
933
|
-
moves = complete_ratio(moves, move_counts, bart
|
|
934
|
-
bart
|
|
935
|
-
|
|
1407
|
+
moves = complete_ratio(moves, move_counts, bart.forest.min_points_per_leaf)
|
|
1408
|
+
bart = replace(
|
|
1409
|
+
bart,
|
|
1410
|
+
forest=replace(
|
|
1411
|
+
bart.forest,
|
|
1412
|
+
grow_prop_count=jnp.sum(moves.grow),
|
|
1413
|
+
prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
|
|
1414
|
+
),
|
|
1415
|
+
)
|
|
936
1416
|
|
|
937
|
-
prelkv, prelk = precompute_likelihood_terms(
|
|
938
|
-
|
|
1417
|
+
prelkv, prelk = precompute_likelihood_terms(
|
|
1418
|
+
bart.sigma2, bart.forest.sigma_mu2, move_precs
|
|
1419
|
+
)
|
|
1420
|
+
prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)
|
|
1421
|
+
|
|
1422
|
+
return ParallelStageOut(
|
|
1423
|
+
bart=bart,
|
|
1424
|
+
moves=moves,
|
|
1425
|
+
prec_trees=prec_trees,
|
|
1426
|
+
move_counts=move_counts,
|
|
1427
|
+
move_precs=move_precs,
|
|
1428
|
+
prelkv=prelkv,
|
|
1429
|
+
prelk=prelk,
|
|
1430
|
+
prelf=prelf,
|
|
1431
|
+
)
|
|
939
1432
|
|
|
940
|
-
return bart, moves, count_trees, move_counts, prelkv, prelk, prelf
|
|
941
1433
|
|
|
942
|
-
@
|
|
943
|
-
def apply_grow_to_indices(
|
|
1434
|
+
@partial(vmap_nodoc, in_axes=(0, 0, None))
|
|
1435
|
+
def apply_grow_to_indices(
|
|
1436
|
+
moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
|
|
1437
|
+
) -> UInt[Array, 'num_trees n']:
|
|
944
1438
|
"""
|
|
945
1439
|
Update the leaf indices to apply a grow move.
|
|
946
1440
|
|
|
947
1441
|
Parameters
|
|
948
1442
|
----------
|
|
949
|
-
moves
|
|
950
|
-
The proposed moves, see `
|
|
951
|
-
leaf_indices
|
|
1443
|
+
moves
|
|
1444
|
+
The proposed moves, see `propose_moves`.
|
|
1445
|
+
leaf_indices
|
|
952
1446
|
The index of the leaf each datapoint falls into.
|
|
953
|
-
X
|
|
1447
|
+
X
|
|
954
1448
|
The predictors matrix.
|
|
955
1449
|
|
|
956
1450
|
Returns
|
|
957
1451
|
-------
|
|
958
|
-
|
|
959
|
-
The updated leaf indices.
|
|
1452
|
+
The updated leaf indices.
|
|
960
1453
|
"""
|
|
961
|
-
left_child = moves
|
|
962
|
-
go_right = X[moves
|
|
963
|
-
tree_size = jnp.array(2 * moves
|
|
964
|
-
node_to_update = jnp.where(moves
|
|
1454
|
+
left_child = moves.node.astype(leaf_indices.dtype) << 1
|
|
1455
|
+
go_right = X[moves.grow_var, :] >= moves.grow_split
|
|
1456
|
+
tree_size = jnp.array(2 * moves.var_trees.size)
|
|
1457
|
+
node_to_update = jnp.where(moves.grow, moves.node, tree_size)
|
|
965
1458
|
return jnp.where(
|
|
966
1459
|
leaf_indices == node_to_update,
|
|
967
1460
|
left_child + go_right,
|
|
968
1461
|
leaf_indices,
|
|
969
1462
|
)
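As a side note, the shift above relies on the heap-style tree layout used throughout this module: node ``i`` has children ``2*i`` and ``2*i + 1``. A tiny standalone sketch of the index arithmetic (hypothetical values, not taken from the diff):

    node = 3                          # index of the leaf being grown
    left_child = node << 1            # == 2 * node == 6
    go_right = True                   # stands in for X[grow_var, i] >= grow_split
    new_leaf = left_child + go_right  # 6 if the datapoint goes left, 7 if right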
|
|
970
1463
|
|
|
971
|
-
|
|
1464
|
+
|
|
1465
|
+
def compute_count_trees(
|
|
1466
|
+
leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
|
|
1467
|
+
) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
|
|
972
1468
|
"""
|
|
973
1469
|
Count the number of datapoints in each leaf.
|
|
974
1470
|
|
|
975
1471
|
Parameters
|
|
976
1472
|
----------
|
|
977
|
-
|
|
978
|
-
The index of the leaf each datapoint falls into,
|
|
979
|
-
|
|
980
|
-
moves
|
|
981
|
-
The proposed moves, see `
|
|
982
|
-
batch_size
|
|
1473
|
+
leaf_indices
|
|
1474
|
+
The index of the leaf each datapoint falls into, with the deeper version
|
|
1475
|
+
of the tree (post-GROW, pre-PRUNE).
|
|
1476
|
+
moves
|
|
1477
|
+
The proposed moves, see `propose_moves`.
|
|
1478
|
+
batch_size
|
|
983
1479
|
The data batch size to use for the summation.
|
|
984
1480
|
|
|
985
1481
|
Returns
|
|
986
1482
|
-------
|
|
987
|
-
count_trees :
|
|
1483
|
+
count_trees : Int32[Array, 'num_trees 2**d']
|
|
988
1484
|
The number of points in each potential or actual leaf node.
|
|
989
|
-
counts :
|
|
990
|
-
The counts of the number of points in the
|
|
991
|
-
moves
|
|
992
|
-
'left', 'right', and 'total'.
|
|
1485
|
+
counts : Counts
|
|
1486
|
+
The counts of the number of points in the leaves grown or pruned by the
|
|
1487
|
+
moves.
|
|
993
1488
|
"""
|
|
994
|
-
|
|
995
|
-
ntree, tree_size = moves['var_trees'].shape
|
|
1489
|
+
num_trees, tree_size = moves.var_trees.shape
|
|
996
1490
|
tree_size *= 2
|
|
997
|
-
tree_indices = jnp.arange(
|
|
1491
|
+
tree_indices = jnp.arange(num_trees)
|
|
998
1492
|
|
|
999
1493
|
count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
|
|
1000
1494
|
|
|
1001
1495
|
# count datapoints in nodes modified by move
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
counts
|
|
1005
|
-
counts['total'] = counts['left'] + counts['right']
|
|
1496
|
+
left = count_trees[tree_indices, moves.left]
|
|
1497
|
+
right = count_trees[tree_indices, moves.right]
|
|
1498
|
+
counts = Counts(left=left, right=right, total=left + right)
|
|
1006
1499
|
|
|
1007
1500
|
# write count into non-leaf node
|
|
1008
|
-
count_trees = count_trees.at[tree_indices, moves
|
|
1501
|
+
count_trees = count_trees.at[tree_indices, moves.node].set(counts.total)
|
|
1009
1502
|
|
|
1010
1503
|
return count_trees, counts
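A small sketch of the per-tree fancy indexing used above (hypothetical arrays): pairing ``tree_indices`` with a per-tree node index reads or writes exactly one entry per tree.

    from jax import numpy as jnp

    count_trees = jnp.array([[0, 5, 2, 3], [0, 4, 1, 3]])  # (num_trees, 2**d)
    tree_indices = jnp.arange(2)
    nodes = jnp.array([1, 1])                       # node touched by the move, per tree
    left = count_trees[tree_indices, 2 * nodes]     # left-child counts
    right = count_trees[tree_indices, 2 * nodes + 1]  # right-child counts
    count_trees = count_trees.at[tree_indices, nodes].set(left + right)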
|
|
1011
1504
|
|
|
1012
|
-
|
|
1505
|
+
|
|
1506
|
+
def count_datapoints_per_leaf(
|
|
1507
|
+
leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
|
|
1508
|
+
) -> Int32[Array, 'num_trees 2**(d-1)']:
|
|
1013
1509
|
"""
|
|
1014
1510
|
Count the number of datapoints in each leaf.
|
|
1015
1511
|
|
|
1016
1512
|
Parameters
|
|
1017
1513
|
----------
|
|
1018
|
-
leaf_indices
|
|
1514
|
+
leaf_indices
|
|
1019
1515
|
The index of the leaf each datapoint falls into.
|
|
1020
|
-
tree_size
|
|
1516
|
+
tree_size
|
|
1021
1517
|
The size of the leaf tree array (2 ** d).
|
|
1022
|
-
batch_size
|
|
1518
|
+
batch_size
|
|
1023
1519
|
The data batch size to use for the summation.
|
|
1024
1520
|
|
|
1025
1521
|
Returns
|
|
1026
1522
|
-------
|
|
1027
|
-
|
|
1028
|
-
The number of points in each leaf node.
|
|
1523
|
+
The number of points in each leaf node.
|
|
1029
1524
|
"""
|
|
1030
1525
|
if batch_size is None:
|
|
1031
1526
|
return _count_scan(leaf_indices, tree_size)
|
|
1032
1527
|
else:
|
|
1033
1528
|
return _count_vec(leaf_indices, tree_size, batch_size)
|
|
1034
1529
|
|
|
1035
|
-
|
|
1530
|
+
|
|
1531
|
+
def _count_scan(
|
|
1532
|
+
leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
|
|
1533
|
+
) -> Int32[Array, 'num_trees {tree_size}']:
|
|
1036
1534
|
def loop(_, leaf_indices):
|
|
1037
1535
|
return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
|
|
1536
|
+
|
|
1038
1537
|
_, count_trees = lax.scan(loop, None, leaf_indices)
|
|
1039
1538
|
return count_trees
|
|
1040
1539
|
|
|
1041
|
-
def _aggregate_scatter(values, indices, size, dtype):
|
|
1042
|
-
return (jnp
|
|
1043
|
-
.zeros(size, dtype)
|
|
1044
|
-
.at[indices]
|
|
1045
|
-
.add(values)
|
|
1046
|
-
)
|
|
1047
1540
|
|
|
1048
|
-
def
|
|
1049
|
-
|
|
1050
|
-
|
|
1541
|
+
def _aggregate_scatter(
|
|
1542
|
+
values: Shaped[Array, '*'],
|
|
1543
|
+
indices: Integer[Array, '*'],
|
|
1544
|
+
size: int,
|
|
1545
|
+
dtype: jnp.dtype,
|
|
1546
|
+
) -> Shaped[Array, '{size}']:
|
|
1547
|
+
return jnp.zeros(size, dtype).at[indices].add(values)
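For reference, the scatter-add idiom above with ``values=1`` is a per-index count (a minimal sketch with hypothetical inputs):

    from jax import numpy as jnp

    leaf_indices = jnp.array([0, 2, 2, 3], dtype=jnp.uint32)
    counts = jnp.zeros(4, jnp.uint32).at[leaf_indices].add(1)
    # counts == [1, 0, 2, 1]: how many datapoints fall in each leaf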
|
|
1051
1548
|
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1549
|
+
|
|
1550
|
+
def _count_vec(
|
|
1551
|
+
leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
|
|
1552
|
+
) -> Int32[Array, 'num_trees 2**(d-1)']:
|
|
1553
|
+
return _aggregate_batched_alltrees(
|
|
1554
|
+
1, leaf_indices, tree_size, jnp.uint32, batch_size
|
|
1555
|
+
)
|
|
1556
|
+
# uint16 is super-slow on gpu, don't use it even if n < 2^16
|
|
1557
|
+
|
|
1558
|
+
|
|
1559
|
+
def _aggregate_batched_alltrees(
|
|
1560
|
+
values: Shaped[Array, '*'],
|
|
1561
|
+
indices: UInt[Array, 'num_trees n'],
|
|
1562
|
+
size: int,
|
|
1563
|
+
dtype: jnp.dtype,
|
|
1564
|
+
batch_size: int,
|
|
1565
|
+
) -> Shaped[Array, 'num_trees {size}']:
|
|
1566
|
+
num_trees, n = indices.shape
|
|
1567
|
+
tree_indices = jnp.arange(num_trees)
|
|
1055
1568
|
nbatches = n // batch_size + bool(n % batch_size)
|
|
1056
1569
|
batch_indices = jnp.arange(n) % nbatches
|
|
1057
|
-
return (
|
|
1058
|
-
.zeros((
|
|
1570
|
+
return (
|
|
1571
|
+
jnp.zeros((num_trees, size, nbatches), dtype)
|
|
1059
1572
|
.at[tree_indices[:, None], indices, batch_indices]
|
|
1060
1573
|
.add(values)
|
|
1061
1574
|
.sum(axis=2)
|
|
1062
1575
|
)
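A sketch of the batching trick with hypothetical numbers: datapoints are assigned round-robin to ``nbatches`` columns, scattered into a 3D buffer, and the columns are summed at the end, which shortens each sequential accumulation (better float accuracy) and exposes more parallelism.

    from jax import numpy as jnp

    values = jnp.ones(5)
    indices = jnp.array([[0, 1, 0, 1, 0]], dtype=jnp.uint32)  # 1 tree, 5 points
    num_trees, n = indices.shape
    size, batch_size = 2, 3
    nbatches = n // batch_size + bool(n % batch_size)         # 2
    batch_indices = jnp.arange(n) % nbatches
    tree_indices = jnp.arange(num_trees)
    out = (
        jnp.zeros((num_trees, size, nbatches))
        .at[tree_indices[:, None], indices, batch_indices]
        .add(values)
        .sum(axis=2)
    )
    # out == [[3., 2.]]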
|
|
1063
1576
|
|
|
1064
|
-
|
|
1577
|
+
|
|
1578
|
+
def compute_prec_trees(
|
|
1579
|
+
prec_scale: Float32[Array, 'n'],
|
|
1580
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1581
|
+
moves: Moves,
|
|
1582
|
+
batch_size: int | None,
|
|
1583
|
+
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
|
|
1584
|
+
"""
|
|
1585
|
+
Compute the likelihood precision scale in each leaf.
|
|
1586
|
+
|
|
1587
|
+
Parameters
|
|
1588
|
+
----------
|
|
1589
|
+
prec_scale
|
|
1590
|
+
The scale of the precision of the error on each datapoint.
|
|
1591
|
+
leaf_indices
|
|
1592
|
+
The index of the leaf each datapoint falls into, with the deeper version
|
|
1593
|
+
of the tree (post-GROW, pre-PRUNE).
|
|
1594
|
+
moves
|
|
1595
|
+
The proposed moves, see `propose_moves`.
|
|
1596
|
+
batch_size
|
|
1597
|
+
The data batch size to use for the summation.
|
|
1598
|
+
|
|
1599
|
+
Returns
|
|
1600
|
+
-------
|
|
1601
|
+
prec_trees : Float32[Array, 'num_trees 2**d']
|
|
1602
|
+
The likelihood precision scale in each potential or actual leaf node.
|
|
1603
|
+
precs : Precs
|
|
1604
|
+
The likelihood precision scale in the nodes involved in the moves.
|
|
1605
|
+
"""
|
|
1606
|
+
num_trees, tree_size = moves.var_trees.shape
|
|
1607
|
+
tree_size *= 2
|
|
1608
|
+
tree_indices = jnp.arange(num_trees)
|
|
1609
|
+
|
|
1610
|
+
prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)
|
|
1611
|
+
|
|
1612
|
+
# sum precision scales of datapoints in nodes modified by move
|
|
1613
|
+
left = prec_trees[tree_indices, moves.left]
|
|
1614
|
+
right = prec_trees[tree_indices, moves.right]
|
|
1615
|
+
precs = Precs(left=left, right=right, total=left + right)
|
|
1616
|
+
|
|
1617
|
+
# write prec into non-leaf node
|
|
1618
|
+
prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total)
|
|
1619
|
+
|
|
1620
|
+
return prec_trees, precs
|
|
1621
|
+
|
|
1622
|
+
|
|
1623
|
+
def prec_per_leaf(
|
|
1624
|
+
prec_scale: Float32[Array, 'n'],
|
|
1625
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1626
|
+
tree_size: int,
|
|
1627
|
+
batch_size: int | None,
|
|
1628
|
+
) -> Float32[Array, 'num_trees {tree_size}']:
|
|
1629
|
+
"""
|
|
1630
|
+
Compute the likelihood precision scale in each leaf.
|
|
1631
|
+
|
|
1632
|
+
Parameters
|
|
1633
|
+
----------
|
|
1634
|
+
prec_scale
|
|
1635
|
+
The scale of the precision of the error on each datapoint.
|
|
1636
|
+
leaf_indices
|
|
1637
|
+
The index of the leaf each datapoint falls into.
|
|
1638
|
+
tree_size
|
|
1639
|
+
The size of the leaf tree array (2 ** d).
|
|
1640
|
+
batch_size
|
|
1641
|
+
The data batch size to use for the summation.
|
|
1642
|
+
|
|
1643
|
+
Returns
|
|
1644
|
+
-------
|
|
1645
|
+
The likelihood precision scale in each leaf node.
|
|
1646
|
+
"""
|
|
1647
|
+
if batch_size is None:
|
|
1648
|
+
return _prec_scan(prec_scale, leaf_indices, tree_size)
|
|
1649
|
+
else:
|
|
1650
|
+
return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
|
|
1651
|
+
|
|
1652
|
+
|
|
1653
|
+
def _prec_scan(
|
|
1654
|
+
prec_scale: Float32[Array, 'n'],
|
|
1655
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1656
|
+
tree_size: int,
|
|
1657
|
+
) -> Float32[Array, 'num_trees {tree_size}']:
|
|
1658
|
+
def loop(_, leaf_indices):
|
|
1659
|
+
return None, _aggregate_scatter(
|
|
1660
|
+
prec_scale, leaf_indices, tree_size, jnp.float32
|
|
1661
|
+
)
|
|
1662
|
+
|
|
1663
|
+
_, prec_trees = lax.scan(loop, None, leaf_indices)
|
|
1664
|
+
return prec_trees
|
|
1665
|
+
|
|
1666
|
+
|
|
1667
|
+
def _prec_vec(
|
|
1668
|
+
prec_scale: Float32[Array, 'n'],
|
|
1669
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1670
|
+
tree_size: int,
|
|
1671
|
+
batch_size: int,
|
|
1672
|
+
) -> Float32[Array, 'num_trees {tree_size}']:
|
|
1673
|
+
return _aggregate_batched_alltrees(
|
|
1674
|
+
prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
|
|
1675
|
+
)
|
|
1676
|
+
|
|
1677
|
+
|
|
1678
|
+
def complete_ratio(
|
|
1679
|
+
moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
|
|
1680
|
+
) -> Moves:
|
|
1065
1681
|
"""
|
|
1066
1682
|
Complete non-likelihood MH ratio calculation.
|
|
1067
1683
|
|
|
1068
|
-
This
|
|
1684
|
+
This function adds the probability of choosing the prune move.
|
|
1069
1685
|
|
|
1070
1686
|
Parameters
|
|
1071
1687
|
----------
|
|
1072
|
-
moves
|
|
1073
|
-
The proposed moves, see `
|
|
1074
|
-
move_counts
|
|
1688
|
+
moves
|
|
1689
|
+
The proposed moves, see `propose_moves`.
|
|
1690
|
+
move_counts
|
|
1075
1691
|
The counts of the number of points in the nodes modified by the
|
|
1076
1692
|
moves.
|
|
1077
|
-
min_points_per_leaf
|
|
1693
|
+
min_points_per_leaf
|
|
1078
1694
|
The minimum number of data points in a leaf node.
|
|
1079
1695
|
|
|
1080
1696
|
Returns
|
|
1081
1697
|
-------
|
|
1082
|
-
moves
|
|
1083
|
-
The updated moves, with the field 'partial_ratio' replaced by
|
|
1084
|
-
'log_trans_prior_ratio'.
|
|
1698
|
+
The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
|
|
1085
1699
|
"""
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1700
|
+
p_prune = compute_p_prune(moves, move_counts, min_points_per_leaf)
|
|
1701
|
+
return replace(
|
|
1702
|
+
moves,
|
|
1703
|
+
log_trans_prior_ratio=jnp.log(moves.partial_ratio * p_prune),
|
|
1704
|
+
partial_ratio=None,
|
|
1705
|
+
)
|
|
1706
|
+
|
|
1090
1707
|
|
|
1091
|
-
def compute_p_prune(
|
|
1708
|
+
def compute_p_prune(
|
|
1709
|
+
moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
|
|
1710
|
+
) -> Float32[Array, 'num_trees']:
|
|
1092
1711
|
"""
|
|
1093
|
-
Compute the probability of proposing a prune move.
|
|
1712
|
+
Compute the probability of proposing a prune move for each tree.
|
|
1094
1713
|
|
|
1095
1714
|
Parameters
|
|
1096
1715
|
----------
|
|
1097
|
-
moves
|
|
1098
|
-
The proposed moves, see `
|
|
1099
|
-
|
|
1716
|
+
moves
|
|
1717
|
+
The proposed moves, see `propose_moves`.
|
|
1718
|
+
move_counts
|
|
1100
1719
|
The number of datapoints in the proposed children of the leaf to grow.
|
|
1101
|
-
|
|
1720
|
+
Not used if `min_points_per_leaf` is `None`.
|
|
1721
|
+
min_points_per_leaf
|
|
1102
1722
|
The minimum number of data points in a leaf node.
|
|
1103
1723
|
|
|
1104
1724
|
Returns
|
|
1105
1725
|
-------
|
|
1106
|
-
|
|
1107
|
-
The probability of proposing a prune move. If grow: after accepting the
|
|
1108
|
-
grow move, if prune: right away.
|
|
1109
|
-
"""
|
|
1726
|
+
The probability of proposing a prune move.
|
|
1110
1727
|
|
|
1728
|
+
Notes
|
|
1729
|
+
-----
|
|
1730
|
+
This probability is computed for going from the state with the deeper tree
|
|
1731
|
+
to the one with the shallower tree. This means, if grow: after accepting the
|
|
1732
|
+
grow move, if prune: right away.
|
|
1733
|
+
"""
|
|
1111
1734
|
# calculation in case the move is grow
|
|
1112
|
-
other_growable_leaves = moves
|
|
1113
|
-
new_leaves_growable = moves
|
|
1735
|
+
other_growable_leaves = moves.num_growable >= 2
|
|
1736
|
+
new_leaves_growable = moves.node < moves.var_trees.shape[1] // 2
|
|
1114
1737
|
if min_points_per_leaf is not None:
|
|
1115
|
-
|
|
1116
|
-
any_above_threshold
|
|
1738
|
+
assert move_counts is not None
|
|
1739
|
+
any_above_threshold = move_counts.left >= 2 * min_points_per_leaf
|
|
1740
|
+
any_above_threshold |= move_counts.right >= 2 * min_points_per_leaf
|
|
1117
1741
|
new_leaves_growable &= any_above_threshold
|
|
1118
1742
|
grow_again_allowed = other_growable_leaves | new_leaves_growable
|
|
1119
1743
|
grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
|
|
1120
1744
|
|
|
1121
1745
|
# calculation in case the move is prune
|
|
1122
|
-
prune_p_prune = jnp.where(moves
|
|
1746
|
+
prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)
|
|
1747
|
+
|
|
1748
|
+
return jnp.where(moves.grow, grow_p_prune, prune_p_prune)
|
|
1123
1749
|
|
|
1124
|
-
return jnp.where(moves['grow'], grow_p_prune, prune_p_prune)
|
|
1125
1750
|
|
|
1126
|
-
@
|
|
1127
|
-
def adapt_leaf_trees_to_grow_indices(
|
|
1751
|
+
@vmap_nodoc
|
|
1752
|
+
def adapt_leaf_trees_to_grow_indices(
|
|
1753
|
+
leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
|
|
1754
|
+
) -> Float32[Array, 'num_trees 2**d']:
|
|
1128
1755
|
"""
|
|
1129
|
-
Modify
|
|
1130
|
-
|
|
1756
|
+
Modify leaves such that post-grow indices work on the original tree.
|
|
1757
|
+
|
|
1758
|
+
The value of the leaf to grow is copied to what would be its children if the
|
|
1759
|
+
grow move was accepted.
|
|
1131
1760
|
|
|
1132
1761
|
Parameters
|
|
1133
1762
|
----------
|
|
1134
|
-
leaf_trees
|
|
1763
|
+
leaf_trees
|
|
1135
1764
|
The leaf values.
|
|
1136
|
-
moves
|
|
1137
|
-
The proposed moves, see `
|
|
1765
|
+
moves
|
|
1766
|
+
The proposed moves, see `propose_moves`.
|
|
1138
1767
|
|
|
1139
1768
|
Returns
|
|
1140
1769
|
-------
|
|
1141
|
-
|
|
1142
|
-
The modified leaf values. The value of the leaf to grow is copied to
|
|
1143
|
-
what would be its children if the grow move was accepted.
|
|
1770
|
+
The modified leaf values.
|
|
1144
1771
|
"""
|
|
1145
|
-
values_at_node = leaf_trees[moves
|
|
1146
|
-
return (
|
|
1147
|
-
.at[jnp.where(moves
|
|
1772
|
+
values_at_node = leaf_trees[moves.node]
|
|
1773
|
+
return (
|
|
1774
|
+
leaf_trees.at[jnp.where(moves.grow, moves.left, leaf_trees.size)]
|
|
1148
1775
|
.set(values_at_node)
|
|
1149
|
-
.at[jnp.where(moves
|
|
1776
|
+
.at[jnp.where(moves.grow, moves.right, leaf_trees.size)]
|
|
1150
1777
|
.set(values_at_node)
|
|
1151
1778
|
)
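The ``jnp.where(..., leaf_trees.size)`` pattern above (and in several functions below) relies on a JAX scatter property: updates at out-of-bounds indices are dropped by default, so an index equal to the array size turns the write into a no-op without branching. A minimal standalone sketch:

    from jax import numpy as jnp

    arr = jnp.zeros(4)
    do_write = jnp.array(False)
    idx = jnp.where(do_write, 1, arr.size)  # 4 is out of bounds
    arr = arr.at[idx].set(99.0)             # update dropped, arr stays all zeros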
|
|
1152
1779
|
|
|
1153
|
-
|
|
1780
|
+
|
|
1781
|
+
def precompute_likelihood_terms(
|
|
1782
|
+
sigma2: Float32[Array, ''],
|
|
1783
|
+
sigma_mu2: Float32[Array, ''],
|
|
1784
|
+
move_precs: Precs | Counts,
|
|
1785
|
+
) -> tuple[PreLkV, PreLk]:
|
|
1154
1786
|
"""
|
|
1155
1787
|
Pre-compute terms used in the likelihood ratio of the acceptance step.
|
|
1156
1788
|
|
|
1157
1789
|
Parameters
|
|
1158
1790
|
----------
|
|
1159
|
-
|
|
1160
|
-
The
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1791
|
+
sigma2
|
|
1792
|
+
The error variance, or the global error variance factor if `prec_scale`
|
|
1793
|
+
is set.
|
|
1794
|
+
sigma_mu2
|
|
1795
|
+
The prior variance of each leaf.
|
|
1796
|
+
move_precs
|
|
1797
|
+
The likelihood precision scale in the leaves grown or pruned by the
|
|
1798
|
+
moves, in the fields `left`, `right`, and `total` (left + right).
|
|
1166
1799
|
|
|
1167
1800
|
Returns
|
|
1168
1801
|
-------
|
|
1169
|
-
prelkv :
|
|
1802
|
+
prelkv : PreLkV
|
|
1170
1803
|
Dictionary with pre-computed terms of the likelihood ratio, one per
|
|
1171
1804
|
tree.
|
|
1172
|
-
prelk :
|
|
1805
|
+
prelk : PreLk
|
|
1173
1806
|
Dictionary with pre-computed terms of the likelihood ratio, shared by
|
|
1174
1807
|
all trees.
|
|
1175
1808
|
"""
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
prelkv
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
sigma2 *
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
return prelkv, dict(
|
|
1809
|
+
sigma2_left = sigma2 + move_precs.left * sigma_mu2
|
|
1810
|
+
sigma2_right = sigma2 + move_precs.right * sigma_mu2
|
|
1811
|
+
sigma2_total = sigma2 + move_precs.total * sigma_mu2
|
|
1812
|
+
prelkv = PreLkV(
|
|
1813
|
+
sigma2_left=sigma2_left,
|
|
1814
|
+
sigma2_right=sigma2_right,
|
|
1815
|
+
sigma2_total=sigma2_total,
|
|
1816
|
+
sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
|
|
1817
|
+
)
|
|
1818
|
+
return prelkv, PreLk(
|
|
1187
1819
|
exp_factor=sigma_mu2 / (2 * sigma2),
|
|
1188
1820
|
)
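For reference, these terms assemble the standard BART grow-move marginal likelihood ratio; the algebra implied by `compute_likelihood_ratio` below is, with r_left, r_right, r_total the precision-scaled residual sums in the two children and their parent:

    # log(L_new / L_old) = 1/2 * log(sigma2 * sigma2_total / (sigma2_left * sigma2_right))
    #                      + sigma_mu2 / (2 * sigma2) * (  r_left**2  / sigma2_left
    #                                                    + r_right**2 / sigma2_right
    #                                                    - r_total**2 / sigma2_total )
    # where sigma2_* = sigma2 + move_precs.* * sigma_mu2; the first line is
    # `sqrt_term` and the shared coefficient sigma_mu2 / (2 * sigma2) is `exp_factor`.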
|
|
1189
1821
|
|
|
1190
|
-
|
|
1822
|
+
|
|
1823
|
+
def precompute_leaf_terms(
|
|
1824
|
+
key: Key[Array, ''],
|
|
1825
|
+
prec_trees: Float32[Array, 'num_trees 2**d'],
|
|
1826
|
+
sigma2: Float32[Array, ''],
|
|
1827
|
+
sigma_mu2: Float32[Array, ''],
|
|
1828
|
+
) -> PreLf:
|
|
1191
1829
|
"""
|
|
1192
1830
|
Pre-compute terms used to sample leaves from their posterior.
|
|
1193
1831
|
|
|
1194
1832
|
Parameters
|
|
1195
1833
|
----------
|
|
1196
|
-
|
|
1197
|
-
The number of points in each potential or actual leaf node.
|
|
1198
|
-
sigma2 : float
|
|
1199
|
-
The noise variance.
|
|
1200
|
-
key : jax.dtypes.prng_key array
|
|
1834
|
+
key
|
|
1201
1835
|
A jax random key.
|
|
1836
|
+
prec_trees
|
|
1837
|
+
The likelihood precision scale in each potential or actual leaf node.
|
|
1838
|
+
sigma2
|
|
1839
|
+
The error variance, or the global error variance factor if `prec_scale`
|
|
1840
|
+
is set.
|
|
1841
|
+
sigma_mu2
|
|
1842
|
+
The prior variance of each leaf.
|
|
1202
1843
|
|
|
1203
1844
|
Returns
|
|
1204
1845
|
-------
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1220
|
-
|
|
1846
|
+
Pre-computed terms for leaf sampling.
|
|
1847
|
+
"""
|
|
1848
|
+
prec_lk = prec_trees / sigma2
|
|
1849
|
+
prec_prior = lax.reciprocal(sigma_mu2)
|
|
1850
|
+
var_post = lax.reciprocal(prec_lk + prec_prior)
|
|
1851
|
+
z = random.normal(key, prec_trees.shape, sigma2.dtype)
|
|
1852
|
+
return PreLf(
|
|
1853
|
+
mean_factor=var_post / sigma2,
|
|
1854
|
+
# mean = mean_lk * prec_lk * var_post
|
|
1855
|
+
# resid_tree = mean_lk * prec_tree -->
|
|
1856
|
+
# --> mean_lk = resid_tree / prec_tree (kind of)
|
|
1857
|
+
# mean_factor =
|
|
1858
|
+
# = mean / resid_tree =
|
|
1859
|
+
# = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
|
|
1860
|
+
# = 1 / prec_tree * prec_tree / sigma2 * var_post =
|
|
1861
|
+
# = var_post / sigma2
|
|
1221
1862
|
centered_leaves=z * jnp.sqrt(var_post),
|
|
1222
1863
|
)
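The quantities above are the usual conjugate normal posterior for a leaf value with prior N(0, sigma_mu2); a sketch of the algebra the code implements, with r the precision-scaled residual sum in the leaf:

    # posterior precision: 1 / var_post = prec_tree / sigma2 + 1 / sigma_mu2
    # posterior mean:      mean_post = (r / sigma2) * var_post = r * mean_factor
    #                      with mean_factor = var_post / sigma2
    # posterior sample:    leaf = mean_post + sqrt(var_post) * z,  z ~ N(0, 1)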
|
|
1223
1864
|
|
|
1224
|
-
|
|
1865
|
+
|
|
1866
|
+
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
|
|
1225
1867
|
"""
|
|
1226
|
-
|
|
1868
|
+
Accept/reject the moves one tree at a time.
|
|
1869
|
+
|
|
1870
|
+
This is the most performance-sensitive function because it contains all and
|
|
1871
|
+
only the parts of the algorithm that cannot be parallelized across trees.
|
|
1227
1872
|
|
|
1228
1873
|
Parameters
|
|
1229
1874
|
----------
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
count_trees : array (num_trees, 2 ** d)
|
|
1233
|
-
The number of points in each potential or actual leaf node.
|
|
1234
|
-
moves : dict
|
|
1235
|
-
The proposed moves, see `sample_moves`.
|
|
1236
|
-
move_counts : dict
|
|
1237
|
-
The counts of the number of points in the the nodes modified by the
|
|
1238
|
-
moves.
|
|
1239
|
-
prelkv, prelk, prelf : dict
|
|
1240
|
-
Dictionaries with pre-computed terms of the likelihood ratios and leaf
|
|
1241
|
-
samples.
|
|
1875
|
+
pso
|
|
1876
|
+
The output of `accept_moves_parallel_stage`.
|
|
1242
1877
|
|
|
1243
1878
|
Returns
|
|
1244
1879
|
-------
|
|
1245
|
-
bart :
|
|
1880
|
+
bart : State
|
|
1246
1881
|
A partially updated BART mcmc state.
|
|
1247
|
-
moves :
|
|
1248
|
-
The
|
|
1249
|
-
|
|
1250
|
-
'acc' : bool array (num_trees,)
|
|
1251
|
-
Whether the move was accepted.
|
|
1252
|
-
'to_prune' : bool array (num_trees,)
|
|
1253
|
-
Whether, to reflect the acceptance status of the move, the state
|
|
1254
|
-
should be updated by pruning the leaves involved in the move.
|
|
1882
|
+
moves : Moves
|
|
1883
|
+
The accepted/rejected moves, with `acc` and `to_prune` set.
|
|
1255
1884
|
"""
|
|
1256
|
-
bart = bart.copy()
|
|
1257
|
-
moves = moves.copy()
|
|
1258
1885
|
|
|
1259
|
-
def loop(resid,
|
|
1886
|
+
def loop(resid, pt):
|
|
1260
1887
|
resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
|
|
1261
|
-
bart['X'],
|
|
1262
|
-
len(bart['leaf_trees']),
|
|
1263
|
-
bart['opt']['resid_batch_size'],
|
|
1264
1888
|
resid,
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1268
|
-
|
|
1889
|
+
SeqStageInAllTrees(
|
|
1890
|
+
pso.bart.X,
|
|
1891
|
+
pso.bart.forest.resid_batch_size,
|
|
1892
|
+
pso.bart.prec_scale,
|
|
1893
|
+
pso.bart.forest.min_points_per_leaf,
|
|
1894
|
+
pso.bart.forest.log_likelihood is not None,
|
|
1895
|
+
pso.prelk,
|
|
1896
|
+
),
|
|
1897
|
+
pt,
|
|
1269
1898
|
)
|
|
1270
1899
|
return resid, (leaf_tree, acc, to_prune, ratios)
|
|
1271
1900
|
|
|
1272
|
-
|
|
1273
|
-
bart
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1901
|
+
pts = SeqStageInPerTree(
|
|
1902
|
+
pso.bart.forest.leaf_trees,
|
|
1903
|
+
pso.prec_trees,
|
|
1904
|
+
pso.moves,
|
|
1905
|
+
pso.move_counts,
|
|
1906
|
+
pso.move_precs,
|
|
1907
|
+
pso.bart.forest.leaf_indices,
|
|
1908
|
+
pso.prelkv,
|
|
1909
|
+
pso.prelf,
|
|
1277
1910
|
)
|
|
1278
|
-
resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
bart
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1911
|
+
resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, pso.bart.resid, pts)
|
|
1912
|
+
|
|
1913
|
+
save_ratios = pso.bart.forest.log_likelihood is not None
|
|
1914
|
+
bart = replace(
|
|
1915
|
+
pso.bart,
|
|
1916
|
+
resid=resid,
|
|
1917
|
+
forest=replace(
|
|
1918
|
+
pso.bart.forest,
|
|
1919
|
+
leaf_trees=leaf_trees,
|
|
1920
|
+
log_likelihood=ratios['log_likelihood'] if save_ratios else None,
|
|
1921
|
+
log_trans_prior=ratios['log_trans_prior'] if save_ratios else None,
|
|
1922
|
+
),
|
|
1923
|
+
)
|
|
1924
|
+
moves = replace(pso.moves, acc=acc, to_prune=to_prune)
|
|
1285
1925
|
|
|
1286
1926
|
return bart, moves
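A minimal sketch of the `lax.scan` pattern used here (toy arrays, not the real State): the residuals are the carry threaded through the trees, per-tree inputs are stacked along the leading axis, and per-tree outputs are stacked back by scan.

    from jax import lax
    from jax import numpy as jnp

    def loop(resid, leaf_correction):
        resid = resid + leaf_correction  # stand-in for one tree's update
        return resid, resid.sum()        # (new carry, per-tree output)

    resid0 = jnp.zeros(3)
    per_tree_inputs = jnp.ones((5, 3))   # leading axis plays the role of num_trees
    resid, per_tree_sums = lax.scan(loop, resid0, per_tree_inputs)
    # resid == [5., 5., 5.], per_tree_sums == [3., 6., 9., 12., 15.]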
|
|
1287
1927
|
|
|
1288
|
-
|
|
1928
|
+
|
|
1929
|
+
class SeqStageInAllTrees(Module):
|
|
1289
1930
|
"""
|
|
1290
|
-
|
|
1931
|
+
The inputs to `accept_move_and_sample_leaves` that are the same for all trees.
|
|
1291
1932
|
|
|
1292
1933
|
Parameters
|
|
1293
1934
|
----------
|
|
1294
|
-
X
|
|
1935
|
+
X
|
|
1295
1936
|
The predictors.
|
|
1296
|
-
|
|
1297
|
-
The number of trees in the forest.
|
|
1298
|
-
resid_batch_size : int, None
|
|
1937
|
+
resid_batch_size
|
|
1299
1938
|
The batch size for computing the sum of residuals in each leaf.
|
|
1300
|
-
|
|
1301
|
-
The
|
|
1302
|
-
|
|
1939
|
+
prec_scale
|
|
1940
|
+
The scale of the precision of the error on each datapoint. If None, it
|
|
1941
|
+
is assumed to be 1.
|
|
1942
|
+
min_points_per_leaf
|
|
1303
1943
|
The minimum number of data points in a leaf node.
|
|
1304
|
-
save_ratios
|
|
1944
|
+
save_ratios
|
|
1305
1945
|
Whether to save the acceptance ratios.
|
|
1306
|
-
prelk
|
|
1946
|
+
prelk
|
|
1307
1947
|
The pre-computed terms of the likelihood ratio which are shared across
|
|
1308
1948
|
trees.
|
|
1309
|
-
|
|
1949
|
+
"""
|
|
1950
|
+
|
|
1951
|
+
X: UInt[Array, 'p n']
|
|
1952
|
+
resid_batch_size: int | None
|
|
1953
|
+
prec_scale: Float32[Array, 'n'] | None
|
|
1954
|
+
min_points_per_leaf: Int32[Array, ''] | None
|
|
1955
|
+
save_ratios: bool
|
|
1956
|
+
prelk: PreLk
|
|
1957
|
+
|
|
1958
|
+
|
|
1959
|
+
class SeqStageInPerTree(Module):
|
|
1960
|
+
"""
|
|
1961
|
+
The inputs to `accept_move_and_sample_leaves` that are separate for each tree.
|
|
1962
|
+
|
|
1963
|
+
Parameters
|
|
1964
|
+
----------
|
|
1965
|
+
leaf_tree
|
|
1310
1966
|
The leaf values of the tree.
|
|
1311
|
-
|
|
1312
|
-
The
|
|
1313
|
-
move
|
|
1314
|
-
The proposed move, see `
|
|
1315
|
-
|
|
1967
|
+
prec_tree
|
|
1968
|
+
The likelihood precision scale in each potential or actual leaf node.
|
|
1969
|
+
move
|
|
1970
|
+
The proposed move, see `propose_moves`.
|
|
1971
|
+
move_counts
|
|
1972
|
+
The counts of the number of points in the nodes modified by the
|
|
1973
|
+
moves.
|
|
1974
|
+
move_precs
|
|
1975
|
+
The likelihood precision scale in each node modified by the moves.
|
|
1976
|
+
leaf_indices
|
|
1316
1977
|
The leaf indices for the largest version of the tree compatible with
|
|
1317
1978
|
the move.
|
|
1318
|
-
prelkv
|
|
1979
|
+
prelkv
|
|
1980
|
+
prelf
|
|
1319
1981
|
The pre-computed terms of the likelihood ratio and leaf sampling which
|
|
1320
1982
|
are specific to the tree.
|
|
1983
|
+
"""
|
|
1984
|
+
|
|
1985
|
+
leaf_tree: Float32[Array, '2**d']
|
|
1986
|
+
prec_tree: Float32[Array, '2**d']
|
|
1987
|
+
move: Moves
|
|
1988
|
+
move_counts: Counts | None
|
|
1989
|
+
move_precs: Precs | Counts
|
|
1990
|
+
leaf_indices: UInt[Array, 'n']
|
|
1991
|
+
prelkv: PreLkV
|
|
1992
|
+
prelf: PreLf
|
|
1993
|
+
|
|
1994
|
+
|
|
1995
|
+
def accept_move_and_sample_leaves(
|
|
1996
|
+
resid: Float32[Array, 'n'],
|
|
1997
|
+
at: SeqStageInAllTrees,
|
|
1998
|
+
pt: SeqStageInPerTree,
|
|
1999
|
+
) -> tuple[
|
|
2000
|
+
Float32[Array, 'n'],
|
|
2001
|
+
Float32[Array, '2**d'],
|
|
2002
|
+
Bool[Array, ''],
|
|
2003
|
+
Bool[Array, ''],
|
|
2004
|
+
dict[str, Float32[Array, '']],
|
|
2005
|
+
]:
|
|
2006
|
+
"""
|
|
2007
|
+
Accept or reject a proposed move and sample the new leaf values.
|
|
2008
|
+
|
|
2009
|
+
Parameters
|
|
2010
|
+
----------
|
|
2011
|
+
resid
|
|
2012
|
+
The residuals (data minus forest value).
|
|
2013
|
+
at
|
|
2014
|
+
The inputs that are the same for all trees.
|
|
2015
|
+
pt
|
|
2016
|
+
The inputs that are separate for each tree.
|
|
1321
2017
|
|
|
1322
2018
|
Returns
|
|
1323
2019
|
-------
|
|
1324
|
-
resid :
|
|
2020
|
+
resid : Float32[Array, 'n']
|
|
1325
2021
|
The updated residuals (data minus forest value).
|
|
1326
|
-
leaf_tree :
|
|
2022
|
+
leaf_tree : Float32[Array, '2**d']
|
|
1327
2023
|
The new leaf values of the tree.
|
|
1328
|
-
acc :
|
|
2024
|
+
acc : Bool[Array, '']
|
|
1329
2025
|
Whether the move was accepted.
|
|
1330
|
-
to_prune :
|
|
2026
|
+
to_prune : Bool[Array, '']
|
|
1331
2027
|
Whether, to reflect the acceptance status of the move, the state should
|
|
1332
2028
|
be updated by pruning the leaves involved in the move.
|
|
1333
|
-
ratios : dict
|
|
2029
|
+
ratios : dict[str, Float32[Array, '']]
|
|
1334
2030
|
The acceptance ratios for the moves. Empty if not to be saved.
|
|
1335
2031
|
"""
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
2032
|
+
# sum residuals in each leaf, in the tree proposed by the grow move
|
|
2033
|
+
if at.prec_scale is None:
|
|
2034
|
+
scaled_resid = resid
|
|
2035
|
+
else:
|
|
2036
|
+
scaled_resid = resid * at.prec_scale
|
|
2037
|
+
resid_tree = sum_resid(
|
|
2038
|
+
scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
|
|
2039
|
+
)
|
|
1339
2040
|
|
|
1340
2041
|
# subtract starting tree from function
|
|
1341
|
-
resid_tree +=
|
|
2042
|
+
resid_tree += pt.prec_tree * pt.leaf_tree
|
|
1342
2043
|
|
|
1343
2044
|
# get indices of move
|
|
1344
|
-
node = move
|
|
2045
|
+
node = pt.move.node
|
|
1345
2046
|
assert node.dtype == jnp.int32
|
|
1346
|
-
left = move
|
|
1347
|
-
right = move
|
|
2047
|
+
left = pt.move.left
|
|
2048
|
+
right = pt.move.right
|
|
1348
2049
|
|
|
1349
2050
|
# sum residuals in parent node modified by move
|
|
1350
2051
|
resid_left = resid_tree[left]
|
|
@@ -1353,215 +2054,282 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
|
|
|
1353
2054
|
resid_tree = resid_tree.at[node].set(resid_total)
|
|
1354
2055
|
|
|
1355
2056
|
# compute acceptance ratio
|
|
1356
|
-
log_lk_ratio = compute_likelihood_ratio(
|
|
1357
|
-
|
|
1358
|
-
|
|
2057
|
+
log_lk_ratio = compute_likelihood_ratio(
|
|
2058
|
+
resid_total, resid_left, resid_right, pt.prelkv, at.prelk
|
|
2059
|
+
)
|
|
2060
|
+
log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
|
|
2061
|
+
log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
|
|
1359
2062
|
ratios = {}
|
|
1360
|
-
if save_ratios:
|
|
2063
|
+
if at.save_ratios:
|
|
1361
2064
|
ratios.update(
|
|
1362
|
-
log_trans_prior=move
|
|
2065
|
+
log_trans_prior=pt.move.log_trans_prior_ratio,
|
|
2066
|
+
# TODO save log_trans_prior_ratio as a vector outside of this loop,
|
|
2067
|
+
# then change the option everywhere to `save_likelihood_ratio`.
|
|
1363
2068
|
log_likelihood=log_lk_ratio,
|
|
1364
2069
|
)
|
|
1365
2070
|
|
|
1366
2071
|
# determine whether to accept the move
|
|
1367
|
-
acc = move
|
|
1368
|
-
if min_points_per_leaf is not None:
|
|
1369
|
-
|
|
1370
|
-
acc &= move_counts
|
|
2072
|
+
acc = pt.move.allowed & (pt.move.logu <= log_ratio)
|
|
2073
|
+
if at.min_points_per_leaf is not None:
|
|
2074
|
+
assert pt.move_counts is not None
|
|
2075
|
+
acc &= pt.move_counts.left >= at.min_points_per_leaf
|
|
2076
|
+
acc &= pt.move_counts.right >= at.min_points_per_leaf
|
|
1371
2077
|
|
|
1372
2078
|
# compute leaves posterior and sample leaves
|
|
1373
|
-
initial_leaf_tree = leaf_tree
|
|
1374
|
-
mean_post = resid_tree * prelf
|
|
1375
|
-
leaf_tree = mean_post + prelf
|
|
1376
|
-
|
|
1377
|
-
# copy leaves around such that the leaf indices
|
|
1378
|
-
to_prune = acc ^ move
|
|
1379
|
-
leaf_tree = (
|
|
1380
|
-
.at[jnp.where(to_prune, left, leaf_tree.size)]
|
|
2079
|
+
initial_leaf_tree = pt.leaf_tree
|
|
2080
|
+
mean_post = resid_tree * pt.prelf.mean_factor
|
|
2081
|
+
leaf_tree = mean_post + pt.prelf.centered_leaves
|
|
2082
|
+
|
|
2083
|
+
# copy leaves around such that the leaf indices point to the correct leaf
|
|
2084
|
+
to_prune = acc ^ pt.move.grow
|
|
2085
|
+
leaf_tree = (
|
|
2086
|
+
leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
|
|
1381
2087
|
.set(leaf_tree[node])
|
|
1382
2088
|
.at[jnp.where(to_prune, right, leaf_tree.size)]
|
|
1383
2089
|
.set(leaf_tree[node])
|
|
1384
2090
|
)
|
|
1385
2091
|
|
|
1386
2092
|
# replace old tree with new tree in function values
|
|
1387
|
-
resid += (initial_leaf_tree - leaf_tree)[leaf_indices]
|
|
2093
|
+
resid += (initial_leaf_tree - leaf_tree)[pt.leaf_indices]
|
|
1388
2094
|
|
|
1389
2095
|
return resid, leaf_tree, acc, to_prune, ratios
|
|
1390
2096
|
|
|
1391
|
-
|
|
2097
|
+
|
|
2098
|
+
def sum_resid(
|
|
2099
|
+
scaled_resid: Float32[Array, 'n'],
|
|
2100
|
+
leaf_indices: UInt[Array, 'n'],
|
|
2101
|
+
tree_size: int,
|
|
2102
|
+
batch_size: int | None,
|
|
2103
|
+
) -> Float32[Array, '{tree_size}']:
|
|
1392
2104
|
"""
|
|
1393
2105
|
Sum the residuals in each leaf.
|
|
1394
2106
|
|
|
1395
2107
|
Parameters
|
|
1396
2108
|
----------
|
|
1397
|
-
|
|
1398
|
-
The residuals (data minus forest value)
|
|
1399
|
-
|
|
2109
|
+
scaled_resid
|
|
2110
|
+
The residuals (data minus forest value) multiplied by the error
|
|
2111
|
+
precision scale.
|
|
2112
|
+
leaf_indices
|
|
1400
2113
|
The leaf indices of the tree (in which leaf each data point falls into).
|
|
1401
|
-
tree_size
|
|
2114
|
+
tree_size
|
|
1402
2115
|
The size of the tree array (2 ** d).
|
|
1403
|
-
batch_size
|
|
2116
|
+
batch_size
|
|
1404
2117
|
The data batch size for the aggregation. Batching increases numerical
|
|
1405
2118
|
accuracy and parallelism.
|
|
1406
2119
|
|
|
1407
2120
|
Returns
|
|
1408
2121
|
-------
|
|
1409
|
-
|
|
1410
|
-
The sum of the residuals at data points in each leaf.
|
|
2122
|
+
The sum of the residuals at data points in each leaf.
|
|
1411
2123
|
"""
|
|
1412
2124
|
if batch_size is None:
|
|
1413
2125
|
aggr_func = _aggregate_scatter
|
|
1414
2126
|
else:
|
|
1415
|
-
aggr_func =
|
|
1416
|
-
return aggr_func(
|
|
1417
|
-
|
|
1418
|
-
|
|
1419
|
-
|
|
2127
|
+
aggr_func = partial(_aggregate_batched_onetree, batch_size=batch_size)
|
|
2128
|
+
return aggr_func(scaled_resid, leaf_indices, tree_size, jnp.float32)
|
|
2129
|
+
|
|
2130
|
+
|
|
2131
|
+
def _aggregate_batched_onetree(
|
|
2132
|
+
values: Shaped[Array, '*'],
|
|
2133
|
+
indices: Integer[Array, '*'],
|
|
2134
|
+
size: int,
|
|
2135
|
+
dtype: jnp.dtype,
|
|
2136
|
+
batch_size: int,
|
|
2137
|
+
) -> Float32[Array, '{size}']:
|
|
2138
|
+
(n,) = indices.shape
|
|
1420
2139
|
nbatches = n // batch_size + bool(n % batch_size)
|
|
1421
2140
|
batch_indices = jnp.arange(n) % nbatches
|
|
1422
|
-
return (
|
|
1423
|
-
.zeros((size, nbatches), dtype)
|
|
2141
|
+
return (
|
|
2142
|
+
jnp.zeros((size, nbatches), dtype)
|
|
1424
2143
|
.at[indices, batch_indices]
|
|
1425
2144
|
.add(values)
|
|
1426
2145
|
.sum(axis=1)
|
|
1427
2146
|
)
|
|
1428
2147
|
|
|
1429
|
-
|
|
2148
|
+
|
|
2149
|
+
def compute_likelihood_ratio(
|
|
2150
|
+
total_resid: Float32[Array, ''],
|
|
2151
|
+
left_resid: Float32[Array, ''],
|
|
2152
|
+
right_resid: Float32[Array, ''],
|
|
2153
|
+
prelkv: PreLkV,
|
|
2154
|
+
prelk: PreLk,
|
|
2155
|
+
) -> Float32[Array, '']:
|
|
1430
2156
|
"""
|
|
1431
2157
|
Compute the likelihood ratio of a grow move.
|
|
1432
2158
|
|
|
1433
2159
|
Parameters
|
|
1434
2160
|
----------
|
|
1435
|
-
total_resid
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
The sum of the residuals
|
|
1439
|
-
|
|
2161
|
+
total_resid
|
|
2162
|
+
left_resid
|
|
2163
|
+
right_resid
|
|
2164
|
+
The sum of the residuals (scaled by error precision scale) of the
|
|
2165
|
+
datapoints falling in the nodes involved in the moves.
|
|
2166
|
+
prelkv
|
|
2167
|
+
prelk
|
|
1440
2168
|
The pre-computed terms of the likelihood ratio, see
|
|
1441
2169
|
`precompute_likelihood_terms`.
|
|
1442
2170
|
|
|
1443
2171
|
Returns
|
|
1444
2172
|
-------
|
|
1445
|
-
ratio
|
|
1446
|
-
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
2173
|
+
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
1447
2174
|
"""
|
|
1448
|
-
exp_term = prelk
|
|
1449
|
-
left_resid * left_resid / prelkv
|
|
1450
|
-
right_resid * right_resid / prelkv
|
|
1451
|
-
total_resid * total_resid / prelkv
|
|
2175
|
+
exp_term = prelk.exp_factor * (
|
|
2176
|
+
left_resid * left_resid / prelkv.sigma2_left
|
|
2177
|
+
+ right_resid * right_resid / prelkv.sigma2_right
|
|
2178
|
+
- total_resid * total_resid / prelkv.sigma2_total
|
|
1452
2179
|
)
|
|
1453
|
-
return prelkv
|
|
2180
|
+
return prelkv.sqrt_term + exp_term
|
|
2181
|
+
|
|
1454
2182
|
|
|
1455
|
-
def accept_moves_final_stage(bart, moves):
|
|
2183
|
+
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
|
|
1456
2184
|
"""
|
|
1457
|
-
|
|
2185
|
+
Post-process the mcmc state after accepting/rejecting the moves.
|
|
2186
|
+
|
|
2187
|
+
This function is separate from `accept_moves_sequential_stage` to signal that it
|
|
2188
|
+
can work in parallel across trees.
|
|
1458
2189
|
|
|
1459
2190
|
Parameters
|
|
1460
2191
|
----------
|
|
1461
|
-
bart
|
|
2192
|
+
bart
|
|
1462
2193
|
A partially updated BART mcmc state.
|
|
1463
|
-
|
|
1464
|
-
The
|
|
1465
|
-
moves : dict
|
|
1466
|
-
The proposed moves (see `sample_moves`) as updated by
|
|
2194
|
+
moves
|
|
2195
|
+
The proposed moves (see `propose_moves`) as updated by
|
|
1467
2196
|
`accept_moves_sequential_stage`.
|
|
1468
2197
|
|
|
1469
2198
|
Returns
|
|
1470
2199
|
-------
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
2200
|
+
The fully updated BART mcmc state.
|
|
2201
|
+
"""
|
|
2202
|
+
return replace(
|
|
2203
|
+
bart,
|
|
2204
|
+
forest=replace(
|
|
2205
|
+
bart.forest,
|
|
2206
|
+
grow_acc_count=jnp.sum(moves.acc & moves.grow),
|
|
2207
|
+
prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
|
|
2208
|
+
leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
|
|
2209
|
+
split_trees=apply_moves_to_split_trees(bart.forest.split_trees, moves),
|
|
2210
|
+
),
|
|
2211
|
+
)
|
|
2212
|
+
|
|
1480
2213
|
|
|
1481
|
-
@
|
|
1482
|
-
def apply_moves_to_leaf_indices(
|
|
2214
|
+
@vmap_nodoc
|
|
2215
|
+
def apply_moves_to_leaf_indices(
|
|
2216
|
+
leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
|
|
2217
|
+
) -> UInt[Array, 'num_trees n']:
|
|
1483
2218
|
"""
|
|
1484
2219
|
Update the leaf indices to match the accepted move.
|
|
1485
2220
|
|
|
1486
2221
|
Parameters
|
|
1487
2222
|
----------
|
|
1488
|
-
leaf_indices
|
|
2223
|
+
leaf_indices
|
|
1489
2224
|
The index of the leaf each datapoint falls into, if the grow move was
|
|
1490
2225
|
accepted.
|
|
1491
|
-
moves
|
|
1492
|
-
The proposed moves (see `
|
|
2226
|
+
moves
|
|
2227
|
+
The proposed moves (see `propose_moves`), as updated by
|
|
1493
2228
|
`accept_moves_sequential_stage`.
|
|
1494
2229
|
|
|
1495
2230
|
Returns
|
|
1496
2231
|
-------
|
|
1497
|
-
|
|
1498
|
-
The updated leaf indices.
|
|
2232
|
+
The updated leaf indices.
|
|
1499
2233
|
"""
|
|
1500
|
-
mask = ~jnp.array(1, leaf_indices.dtype)
|
|
1501
|
-
is_child = (leaf_indices & mask) == moves
|
|
2234
|
+
mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
|
|
2235
|
+
is_child = (leaf_indices & mask) == moves.left
|
|
1502
2236
|
return jnp.where(
|
|
1503
|
-
is_child & moves
|
|
1504
|
-
moves
|
|
2237
|
+
is_child & moves.to_prune,
|
|
2238
|
+
moves.node.astype(leaf_indices.dtype),
|
|
1505
2239
|
leaf_indices,
|
|
1506
2240
|
)
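The mask trick above works because, in the heap layout, the two children of a node differ only in the lowest bit; clearing it maps both back to the left child. A tiny check with a hypothetical node:

    node = 3
    left, right = 2 * node, 2 * node + 1  # 6 and 7
    assert (left & ~1) == left and (right & ~1) == left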
|
|
1507
2241
|
|
|
1508
|
-
|
|
1509
|
-
|
|
2242
|
+
|
|
2243
|
+
@vmap_nodoc
|
|
2244
|
+
def apply_moves_to_split_trees(
|
|
2245
|
+
split_trees: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
|
|
2246
|
+
) -> UInt[Array, 'num_trees 2**(d-1)']:
|
|
1510
2247
|
"""
|
|
1511
2248
|
Update the split trees to match the accepted move.
|
|
1512
2249
|
|
|
1513
2250
|
Parameters
|
|
1514
2251
|
----------
|
|
1515
|
-
split_trees
|
|
2252
|
+
split_trees
|
|
1516
2253
|
The cutpoints of the decision nodes in the initial trees.
|
|
1517
|
-
moves
|
|
1518
|
-
The proposed moves (see `
|
|
2254
|
+
moves
|
|
2255
|
+
The proposed moves (see `propose_moves`), as updated by
|
|
1519
2256
|
`accept_moves_sequential_stage`.
|
|
1520
2257
|
|
|
1521
2258
|
Returns
|
|
1522
2259
|
-------
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1526
|
-
return (
|
|
1527
|
-
.at[
|
|
1528
|
-
|
|
1529
|
-
|
|
1530
|
-
|
|
1531
|
-
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
2260
|
+
The updated split trees.
|
|
2261
|
+
"""
|
|
2262
|
+
assert moves.to_prune is not None
|
|
2263
|
+
return (
|
|
2264
|
+
split_trees.at[
|
|
2265
|
+
jnp.where(
|
|
2266
|
+
moves.grow,
|
|
2267
|
+
moves.node,
|
|
2268
|
+
split_trees.size,
|
|
2269
|
+
)
|
|
2270
|
+
]
|
|
2271
|
+
.set(moves.grow_split.astype(split_trees.dtype))
|
|
2272
|
+
.at[
|
|
2273
|
+
jnp.where(
|
|
2274
|
+
moves.to_prune,
|
|
2275
|
+
moves.node,
|
|
2276
|
+
split_trees.size,
|
|
2277
|
+
)
|
|
2278
|
+
]
|
|
1538
2279
|
.set(0)
|
|
1539
2280
|
)
|
|
1540
2281
|
|
|
1541
|
-
|
|
2282
|
+
|
|
2283
|
+
def step_sigma(key: Key[Array, ''], bart: State) -> State:
|
|
1542
2284
|
"""
|
|
1543
|
-
|
|
2285
|
+
MCMC-update the error variance (factor).
|
|
1544
2286
|
|
|
1545
2287
|
Parameters
|
|
1546
2288
|
----------
|
|
1547
|
-
|
|
1548
|
-
A BART mcmc state, as created by `init`.
|
|
1549
|
-
key : jax.dtypes.prng_key array
|
|
2289
|
+
key
|
|
1550
2290
|
A jax random key.
|
|
2291
|
+
bart
|
|
2292
|
+
A BART mcmc state.
|
|
1551
2293
|
|
|
1552
2294
|
Returns
|
|
1553
2295
|
-------
|
|
1554
|
-
|
|
1555
|
-
The new BART mcmc state.
|
|
2296
|
+
The new BART mcmc state, with an updated `sigma2`.
|
|
1556
2297
|
"""
|
|
1557
|
-
|
|
1558
|
-
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
|
|
1562
|
-
|
|
2298
|
+
resid = bart.resid
|
|
2299
|
+
alpha = bart.sigma2_alpha + resid.size / 2
|
|
2300
|
+
if bart.prec_scale is None:
|
|
2301
|
+
scaled_resid = resid
|
|
2302
|
+
else:
|
|
2303
|
+
scaled_resid = resid * bart.prec_scale
|
|
2304
|
+
norm2 = resid @ scaled_resid
|
|
2305
|
+
beta = bart.sigma2_beta + norm2 / 2
|
|
1563
2306
|
|
|
1564
2307
|
sample = random.gamma(key, alpha)
|
|
1565
|
-
bart
|
|
2308
|
+
return replace(bart, sigma2=beta / sample)
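For reference, this is the conjugate update implied by the inverse-gamma prior parametrized by (sigma2_alpha, sigma2_beta); a sketch of the algebra the code implements:

    # prior:       sigma2 ~ InvGamma(alpha0, beta0)
    #              with alpha0 = sigma2_alpha, beta0 = sigma2_beta
    # likelihood:  resid_i ~ Normal(0, sigma2 / prec_scale_i)
    #              (prec_scale_i = 1 when prec_scale is None)
    # posterior:   sigma2 | resid ~ InvGamma(alpha0 + n / 2,
    #                                        beta0 + sum_i prec_scale_i * resid_i**2 / 2)
    # A draw is obtained as beta_post / Gamma(alpha_post), which is what the
    # last two lines above do.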
|
|
1566
2309
|
|
|
1567
|
-
|
|
2310
|
+
|
|
2311
|
+
def step_z(key: Key[Array, ''], bart: State) -> State:
|
|
2312
|
+
"""
|
|
2313
|
+
MCMC-update the latent variable for binary regression.
|
|
2314
|
+
|
|
2315
|
+
Parameters
|
|
2316
|
+
----------
|
|
2317
|
+
key
|
|
2318
|
+
A jax random key.
|
|
2319
|
+
bart
|
|
2320
|
+
A BART MCMC state.
|
|
2321
|
+
|
|
2322
|
+
Returns
|
|
2323
|
+
-------
|
|
2324
|
+
The updated BART MCMC state.
|
|
2325
|
+
"""
|
|
2326
|
+
trees_plus_offset = bart.z - bart.resid
|
|
2327
|
+
lower = jnp.where(bart.y, -trees_plus_offset, -jnp.inf)
|
|
2328
|
+
upper = jnp.where(bart.y, jnp.inf, -trees_plus_offset)
|
|
2329
|
+
resid = random.truncated_normal(key, lower, upper)
|
|
2330
|
+
# TODO jax's implementation of truncated_normal is not good, it just does
|
|
2331
|
+
# cdf inversion with erf and erf_inv. I can do better, at least avoiding
|
|
2332
|
+
# computing one of the boundaries, and maybe also flipping and using ndtr
|
|
2333
|
+
# instead of erf for numerical stability (open an issue in jax?)
|
|
2334
|
+
z = trees_plus_offset + resid
|
|
2335
|
+
return replace(bart, z=z, resid=resid)
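For reference, this is the usual truncated-normal data augmentation for probit-style binary regression (often attributed to Albert and Chib): under z = f(x) + eps with eps ~ N(0, 1) and y = [z > 0], the latent z is resampled around the current forest value, truncated to the half-line consistent with y. A standalone restatement of the update above, with hypothetical inputs:

    from jax import numpy as jnp, random

    def resample_latent(key, forest_value, y):
        lower = jnp.where(y, -forest_value, -jnp.inf)
        upper = jnp.where(y, jnp.inf, -forest_value)
        resid = random.truncated_normal(key, lower, upper)  # std normal, truncated
        return forest_value + resid                         # the new latent z

    key = random.PRNGKey(0)
    z = resample_latent(key, jnp.array([0.3, -1.2]), jnp.array([True, False]))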
|