bartz 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
- bartz/BART.py +196 -103
- bartz/__init__.py +1 -1
- bartz/_version.py +1 -1
- bartz/debug.py +1 -1
- bartz/grove.py +43 -2
- bartz/jaxext.py +82 -33
- bartz/mcmcloop.py +367 -114
- bartz/mcmcstep.py +1322 -807
- bartz/prepcovars.py +3 -1
- {bartz-0.5.0.dist-info → bartz-0.6.0.dist-info}/METADATA +7 -5
- bartz-0.6.0.dist-info/RECORD +13 -0
- {bartz-0.5.0.dist-info → bartz-0.6.0.dist-info}/WHEEL +1 -1
- bartz-0.5.0.dist-info/RECORD +0 -13
bartz/mcmcstep.py
CHANGED
|
@@ -26,220 +26,292 @@
|
|
|
26
26
|
Functions that implement the BART posterior MCMC initialization and update step.
|
|
27
27
|
|
|
28
28
|
Functions that do MCMC steps operate by taking as input a bart state, and
|
|
29
|
-
outputting a new
|
|
30
|
-
modified.
|
|
29
|
+
outputting a new state. The inputs are not modified.
|
|
31
30
|
|
|
32
|
-
|
|
33
|
-
|
|
31
|
+
The main entry points are:
|
|
32
|
+
|
|
33
|
+
- `State`: The dataclass that represents a BART MCMC state.
|
|
34
|
+
- `init`: Creates an initial `State` from data and configurations.
|
|
35
|
+
- `step`: Performs one full MCMC step on a `State`, returning a new `State`.
|
|
34
36
|
"""
|
|
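To make the new entry points concrete, here is a minimal usage sketch assembled from the signatures shown in this diff. The synthetic dataset, number of trees, and prior values are illustrative assumptions, not package defaults.

```python
import jax
import jax.numpy as jnp

from bartz import mcmcstep

# Tiny synthetic dataset with 2 predictors already binned to integer split
# indices (illustrative only; not how the package prepares real data).
n, p = 100, 2
key, kx, ky = jax.random.split(jax.random.key(0), 3)
X = jax.random.randint(kx, (p, n), 0, 10).astype(jnp.uint8)  # note the (p, n) layout
max_split = jnp.full(p, 9, jnp.uint8)                         # split indices run 1..9
y = (X[0] > 4).astype(jnp.float32) + 0.1 * jax.random.normal(ky, (n,))

num_trees = 50
state = mcmcstep.init(
    X=X,
    y=y,
    max_split=max_split,
    num_trees=num_trees,
    p_nonterminal=jnp.array([0.95, 0.75, 0.55, 0.35]),  # length bounds the tree depth
    sigma_mu2=1.0 / num_trees,  # illustrative leaf-variance prior
    sigma2_alpha=2.0,           # illustrative inverse-gamma shape
    sigma2_beta=1.0,            # illustrative inverse-gamma scale
)

for _ in range(100):
    key, subkey = jax.random.split(key)
    state = mcmcstep.step(subkey, state)  # returns a new State; inputs are not modified
```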
35
37
|
|
|
36
|
-
import functools
|
|
37
38
|
import math
|
|
39
|
+
from dataclasses import replace
|
|
40
|
+
from functools import cache, partial
|
|
41
|
+
from typing import Any
|
|
38
42
|
|
|
39
43
|
import jax
|
|
44
|
+
from equinox import Module, field
|
|
40
45
|
from jax import lax, random
|
|
41
46
|
from jax import numpy as jnp
|
|
47
|
+
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt
|
|
48
|
+
|
|
49
|
+
from . import grove
|
|
50
|
+
from .jaxext import minimal_unsigned_dtype, split, vmap_nodoc
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class Forest(Module):
|
|
54
|
+
"""
|
|
55
|
+
Represents the MCMC state of a sum of trees.
|
|
56
|
+
|
|
57
|
+
Parameters
|
|
58
|
+
----------
|
|
59
|
+
leaf_trees
|
|
60
|
+
The leaf values.
|
|
61
|
+
var_trees
|
|
62
|
+
The decision axes.
|
|
63
|
+
split_trees
|
|
64
|
+
The decision boundaries.
|
|
65
|
+
p_nonterminal
|
|
66
|
+
The probability of a nonterminal node at each depth, padded with a
|
|
67
|
+
zero.
|
|
68
|
+
p_propose_grow
|
|
69
|
+
The unnormalized probability of picking a leaf for a grow proposal.
|
|
70
|
+
leaf_indices
|
|
71
|
+
The index of the leaf each datapoint falls into, for each tree.
|
|
72
|
+
min_points_per_leaf
|
|
73
|
+
The minimum number of data points in a leaf node.
|
|
74
|
+
affluence_trees
|
|
75
|
+
Whether a non-bottom leaf node contains at least twice `min_points_per_leaf`
|
|
76
|
+
datapoints. If `min_points_per_leaf` is not specified, this is None.
|
|
77
|
+
resid_batch_size
|
|
78
|
+
count_batch_size
|
|
79
|
+
The data batch sizes for computing the sufficient statistics. If `None`,
|
|
80
|
+
they are computed with no batching.
|
|
81
|
+
log_trans_prior
|
|
82
|
+
The log transition and prior Metropolis-Hastings ratio for the
|
|
83
|
+
proposed move on each tree.
|
|
84
|
+
log_likelihood
|
|
85
|
+
The log likelihood ratio.
|
|
86
|
+
grow_prop_count
|
|
87
|
+
prune_prop_count
|
|
88
|
+
The number of grow/prune proposals made during one full MCMC cycle.
|
|
89
|
+
grow_acc_count
|
|
90
|
+
prune_acc_count
|
|
91
|
+
The number of grow/prune moves accepted during one full MCMC cycle.
|
|
92
|
+
sigma_mu2
|
|
93
|
+
The prior variance of a leaf, conditional on the tree structure.
|
|
94
|
+
"""
|
|
95
|
+
|
|
96
|
+
leaf_trees: Float32[Array, 'num_trees 2**d']
|
|
97
|
+
var_trees: UInt[Array, 'num_trees 2**(d-1)']
|
|
98
|
+
split_trees: UInt[Array, 'num_trees 2**(d-1)']
|
|
99
|
+
p_nonterminal: Float32[Array, 'd']
|
|
100
|
+
p_propose_grow: Float32[Array, '2**(d-1)']
|
|
101
|
+
leaf_indices: UInt[Array, 'num_trees n']
|
|
102
|
+
min_points_per_leaf: Int32[Array, ''] | None
|
|
103
|
+
affluence_trees: Bool[Array, 'num_trees 2**(d-1)'] | None
|
|
104
|
+
resid_batch_size: int | None = field(static=True)
|
|
105
|
+
count_batch_size: int | None = field(static=True)
|
|
106
|
+
log_trans_prior: Float32[Array, 'num_trees'] | None
|
|
107
|
+
log_likelihood: Float32[Array, 'num_trees'] | None
|
|
108
|
+
grow_prop_count: Int32[Array, '']
|
|
109
|
+
prune_prop_count: Int32[Array, '']
|
|
110
|
+
grow_acc_count: Int32[Array, '']
|
|
111
|
+
prune_acc_count: Int32[Array, '']
|
|
112
|
+
sigma_mu2: Float32[Array, '']
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class State(Module):
|
|
116
|
+
"""
|
|
117
|
+
Represents the MCMC state of BART.
|
|
42
118
|
|
|
43
|
-
|
|
119
|
+
Parameters
|
|
120
|
+
----------
|
|
121
|
+
X
|
|
122
|
+
The predictors.
|
|
123
|
+
max_split
|
|
124
|
+
The maximum split index for each predictor.
|
|
125
|
+
y
|
|
126
|
+
The response. If the data type is `bool`, the model is binary regression.
|
|
127
|
+
resid
|
|
128
|
+
The residuals (`y` or `z` minus sum of trees).
|
|
129
|
+
z
|
|
130
|
+
The latent variable for binary regression. `None` in continuous
|
|
131
|
+
regression.
|
|
132
|
+
offset
|
|
133
|
+
Constant shift added to the sum of trees.
|
|
134
|
+
sigma2
|
|
135
|
+
The error variance. `None` in binary regression.
|
|
136
|
+
prec_scale
|
|
137
|
+
The scale on the error precision, i.e., ``1 / error_scale ** 2``.
|
|
138
|
+
`None` in binary regression.
|
|
139
|
+
sigma2_alpha
|
|
140
|
+
sigma2_beta
|
|
141
|
+
The shape and scale parameters of the inverse gamma prior on the noise
|
|
142
|
+
variance. `None` in binary regression.
|
|
143
|
+
forest
|
|
144
|
+
The sum of trees model.
|
|
145
|
+
"""
|
|
146
|
+
|
|
147
|
+
X: UInt[Array, 'p n']
|
|
148
|
+
max_split: UInt[Array, 'p']
|
|
149
|
+
y: Float32[Array, 'n'] | Bool[Array, 'n']
|
|
150
|
+
z: None | Float32[Array, 'n']
|
|
151
|
+
offset: Float32[Array, '']
|
|
152
|
+
resid: Float32[Array, 'n']
|
|
153
|
+
sigma2: Float32[Array, ''] | None
|
|
154
|
+
prec_scale: Float32[Array, 'n'] | None
|
|
155
|
+
sigma2_alpha: Float32[Array, ''] | None
|
|
156
|
+
sigma2_beta: Float32[Array, ''] | None
|
|
157
|
+
forest: Forest
|
|
44
158
|
|
|
45
159
|
|
|
46
160
|
def init(
|
|
47
161
|
*,
|
|
48
|
-
X,
|
|
49
|
-
y,
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
min_points_per_leaf=None,
|
|
59
|
-
resid_batch_size='auto',
|
|
60
|
-
count_batch_size='auto',
|
|
61
|
-
save_ratios=False,
|
|
62
|
-
):
|
|
162
|
+
X: UInt[Any, 'p n'],
|
|
163
|
+
y: Float32[Any, 'n'] | Bool[Any, 'n'],
|
|
164
|
+
offset: float | Float32[Any, ''] = 0.0,
|
|
165
|
+
max_split: UInt[Any, 'p'],
|
|
166
|
+
num_trees: int,
|
|
167
|
+
p_nonterminal: Float32[Any, 'd-1'],
|
|
168
|
+
sigma_mu2: float | Float32[Any, ''],
|
|
169
|
+
sigma2_alpha: float | Float32[Any, ''] | None = None,
|
|
170
|
+
sigma2_beta: float | Float32[Any, ''] | None = None,
|
|
171
|
+
error_scale: Float32[Any, 'n'] | None = None,
|
|
172
|
+
min_points_per_leaf: int | None = None,
|
|
173
|
+
resid_batch_size: int | None | str = 'auto',
|
|
174
|
+
count_batch_size: int | None | str = 'auto',
|
|
175
|
+
save_ratios: bool = False,
|
|
176
|
+
) -> State:
|
|
63
177
|
"""
|
|
64
178
|
Make a BART posterior sampling MCMC initial state.
|
|
65
179
|
|
|
66
180
|
Parameters
|
|
67
181
|
----------
|
|
68
|
-
X
|
|
182
|
+
X
|
|
69
183
|
The predictors. Note this is transposed compared to the usual convention.
|
|
70
|
-
y
|
|
71
|
-
The response.
|
|
72
|
-
|
|
184
|
+
y
|
|
185
|
+
The response. If the data type is `bool`, the regression model is binary
|
|
186
|
+
regression with probit.
|
|
187
|
+
offset
|
|
188
|
+
Constant shift added to the sum of trees. 0 if not specified.
|
|
189
|
+
max_split
|
|
73
190
|
The maximum split index for each variable. All split ranges start at 1.
|
|
74
|
-
num_trees
|
|
191
|
+
num_trees
|
|
75
192
|
The number of trees in the forest.
|
|
76
|
-
p_nonterminal
|
|
193
|
+
p_nonterminal
|
|
77
194
|
The probability of a nonterminal node at each depth. The maximum depth
|
|
78
195
|
of trees is fixed by the length of this array.
|
|
79
|
-
|
|
80
|
-
The
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
196
|
+
sigma_mu2
|
|
197
|
+
The prior variance of a leaf, conditional on the tree structure. The
|
|
198
|
+
prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
|
|
199
|
+
prior mean of leaves is always zero.
|
|
200
|
+
sigma2_alpha
|
|
201
|
+
sigma2_beta
|
|
202
|
+
The shape and scale parameters of the inverse gamma prior on the error
|
|
203
|
+
variance. Leave unspecified for binary regression.
|
|
204
|
+
error_scale
|
|
84
205
|
Each error is scaled by the corresponding factor in `error_scale`, so
|
|
85
206
|
the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
The dtype for scalars, small arrays, and arrays which require accuracy.
|
|
90
|
-
min_points_per_leaf : int, optional
|
|
207
|
+
Not supported for binary regression. If not specified, defaults to 1 for
|
|
208
|
+
all points (in which case the related calculations may be skipped).
|
|
209
|
+
min_points_per_leaf
|
|
91
210
|
The minimum number of data points in a leaf node. 0 if not specified.
|
|
92
|
-
resid_batch_size
|
|
211
|
+
resid_batch_size
|
|
212
|
+
count_batch_size
|
|
93
213
|
The batch sizes, along datapoints, for summing the residuals and
|
|
94
214
|
counting the number of datapoints in each leaf. `None` for no batching.
|
|
95
215
|
If 'auto', pick a value based on the device of `y`, or the default
|
|
96
216
|
device.
|
|
97
|
-
save_ratios
|
|
217
|
+
save_ratios
|
|
98
218
|
Whether to save the Metropolis-Hastings ratios.
|
|
99
219
|
|
|
100
220
|
Returns
|
|
101
221
|
-------
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
'split_trees' : int array (num_trees, 2 ** (d - 1))
|
|
111
|
-
The decision boundaries.
|
|
112
|
-
'resid' : large_float array (n,)
|
|
113
|
-
The residuals (data minus forest value). Large float to avoid
|
|
114
|
-
roundoff.
|
|
115
|
-
'sigma2' : large_float
|
|
116
|
-
The noise variance.
|
|
117
|
-
'prec_scale' : large_float array (n,) or None
|
|
118
|
-
The scale on the error precision, i.e., ``1 / error_scale ** 2``.
|
|
119
|
-
'grow_prop_count', 'prune_prop_count' : int
|
|
120
|
-
The number of grow/prune proposals made during one full MCMC cycle.
|
|
121
|
-
'grow_acc_count', 'prune_acc_count' : int
|
|
122
|
-
The number of grow/prune moves accepted during one full MCMC cycle.
|
|
123
|
-
'p_nonterminal' : large_float array (d,)
|
|
124
|
-
The probability of a nonterminal node at each depth, padded with a
|
|
125
|
-
zero.
|
|
126
|
-
'p_propose_grow' : large_float array (2 ** (d - 1),)
|
|
127
|
-
The unnormalized probability of picking a leaf for a grow proposal.
|
|
128
|
-
'sigma2_alpha' : large_float
|
|
129
|
-
The shape parameter of the inverse gamma prior on the noise variance.
|
|
130
|
-
'sigma2_beta' : large_float
|
|
131
|
-
The scale parameter of the inverse gamma prior on the noise variance.
|
|
132
|
-
'max_split' : int array (p,)
|
|
133
|
-
The maximum split index for each variable.
|
|
134
|
-
'y' : small_float array (n,)
|
|
135
|
-
The response.
|
|
136
|
-
'X' : int array (p, n)
|
|
137
|
-
The predictors.
|
|
138
|
-
'leaf_indices' : int array (num_trees, n)
|
|
139
|
-
The index of the leaf each datapoints falls into, for each tree.
|
|
140
|
-
'min_points_per_leaf' : int or None
|
|
141
|
-
The minimum number of data points in a leaf node.
|
|
142
|
-
'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
|
|
143
|
-
Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
|
|
144
|
-
datapoints. If `min_points_per_leaf` is not specified, this is None.
|
|
145
|
-
'opt' : LeafDict
|
|
146
|
-
A dictionary with config values:
|
|
147
|
-
|
|
148
|
-
'small_float' : dtype
|
|
149
|
-
The dtype for large arrays used in the algorithm.
|
|
150
|
-
'large_float' : dtype
|
|
151
|
-
The dtype for scalars, small arrays, and arrays which require
|
|
152
|
-
accuracy.
|
|
153
|
-
'require_min_points' : bool
|
|
154
|
-
Whether the `min_points_per_leaf` parameter is specified.
|
|
155
|
-
'resid_batch_size', 'count_batch_size' : int or None
|
|
156
|
-
The data batch sizes for computing the sufficient statistics.
|
|
157
|
-
'ratios' : dict, optional
|
|
158
|
-
If `save_ratios` is True, this field is present. It has the fields:
|
|
159
|
-
|
|
160
|
-
'log_trans_prior' : large_float array (num_trees,)
|
|
161
|
-
The log transition and prior Metropolis-Hastings ratio for the
|
|
162
|
-
proposed move on each tree.
|
|
163
|
-
'log_likelihood' : large_float array (num_trees,)
|
|
164
|
-
The log likelihood ratio.
|
|
165
|
-
"""
|
|
166
|
-
|
|
167
|
-
p_nonterminal = jnp.asarray(p_nonterminal, large_float)
|
|
222
|
+
An initialized BART MCMC state.
|
|
223
|
+
|
|
224
|
+
Raises
|
|
225
|
+
------
|
|
226
|
+
ValueError
|
|
227
|
+
If `y` is boolean and arguments unused in binary regression are set.
|
|
228
|
+
"""
|
|
229
|
+
p_nonterminal = jnp.asarray(p_nonterminal)
|
|
168
230
|
p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
|
|
169
231
|
max_depth = p_nonterminal.size
|
|
170
232
|
|
|
171
|
-
@
|
|
233
|
+
@partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
|
|
172
234
|
def make_forest(max_depth, dtype):
|
|
173
235
|
return grove.make_tree(max_depth, dtype)
|
|
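The decorator above uses `jax.vmap` with `in_axes=None` and an explicit `axis_size` purely as a broadcasting device: nothing is mapped over, so the wrapped function runs once and its output gains a leading axis of length `num_trees`, yielding one fresh tree per member of the forest. A self-contained illustration of the same pattern (the names below are made up):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.vmap, in_axes=None, out_axes=0, axis_size=4)
def make_blank_tree(depth, dtype):
    # Runs once with unmapped arguments; vmap replicates the result 4 times.
    return jnp.zeros(2**depth, dtype)


print(make_blank_tree(3, jnp.float32).shape)  # (4, 8)
```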
174
236
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
237
|
+
y = jnp.asarray(y)
|
|
238
|
+
offset = jnp.asarray(offset)
|
|
239
|
+
|
|
178
240
|
resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
|
|
179
241
|
resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
|
|
180
242
|
)
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
sigma2=
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
grow_acc_count=jnp.zeros((), int),
|
|
201
|
-
prune_prop_count=jnp.zeros((), int),
|
|
202
|
-
prune_acc_count=jnp.zeros((), int),
|
|
203
|
-
p_nonterminal=p_nonterminal,
|
|
204
|
-
p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
|
|
205
|
-
sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
|
|
206
|
-
sigma2_beta=jnp.asarray(sigma2_beta, large_float),
|
|
243
|
+
|
|
244
|
+
is_binary = y.dtype == bool
|
|
245
|
+
if is_binary:
|
|
246
|
+
if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
|
|
247
|
+
raise ValueError(
|
|
248
|
+
'error_scale, sigma2_alpha, and sigma2_beta must be set '
|
|
249
|
+
' to `None` for binary regression.'
|
|
250
|
+
)
|
|
251
|
+
sigma2 = None
|
|
252
|
+
else:
|
|
253
|
+
sigma2_alpha = jnp.asarray(sigma2_alpha)
|
|
254
|
+
sigma2_beta = jnp.asarray(sigma2_beta)
|
|
255
|
+
sigma2 = sigma2_beta / sigma2_alpha
|
|
256
|
+
# sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
|
|
257
|
+
# TODO: I don't like this isfinite check, these functions should be
|
|
258
|
+
# low-level and just do the thing. Why was it here?
|
|
259
|
+
|
|
260
|
+
bart = State(
|
|
261
|
+
X=jnp.asarray(X),
|
|
207
262
|
max_split=jnp.asarray(max_split),
|
|
208
263
|
y=y,
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
None if
|
|
215
|
-
),
|
|
216
|
-
affluence_trees=(
|
|
217
|
-
None
|
|
218
|
-
if min_points_per_leaf is None
|
|
219
|
-
else make_forest(max_depth - 1, bool)
|
|
220
|
-
.at[:, 1]
|
|
221
|
-
.set(y.size >= 2 * min_points_per_leaf)
|
|
264
|
+
z=jnp.full(y.shape, offset) if is_binary else None,
|
|
265
|
+
offset=offset,
|
|
266
|
+
resid=jnp.zeros(y.shape) if is_binary else y - offset,
|
|
267
|
+
sigma2=sigma2,
|
|
268
|
+
prec_scale=(
|
|
269
|
+
None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
|
|
222
270
|
),
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
271
|
+
sigma2_alpha=sigma2_alpha,
|
|
272
|
+
sigma2_beta=sigma2_beta,
|
|
273
|
+
forest=Forest(
|
|
274
|
+
leaf_trees=make_forest(max_depth, jnp.float32),
|
|
275
|
+
var_trees=make_forest(
|
|
276
|
+
max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)
|
|
277
|
+
),
|
|
278
|
+
split_trees=make_forest(max_depth - 1, max_split.dtype),
|
|
279
|
+
grow_prop_count=jnp.zeros((), int),
|
|
280
|
+
grow_acc_count=jnp.zeros((), int),
|
|
281
|
+
prune_prop_count=jnp.zeros((), int),
|
|
282
|
+
prune_acc_count=jnp.zeros((), int),
|
|
283
|
+
p_nonterminal=p_nonterminal,
|
|
284
|
+
p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
|
|
285
|
+
leaf_indices=jnp.ones(
|
|
286
|
+
(num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
|
|
287
|
+
),
|
|
288
|
+
min_points_per_leaf=(
|
|
289
|
+
None
|
|
290
|
+
if min_points_per_leaf is None
|
|
291
|
+
else jnp.asarray(min_points_per_leaf)
|
|
292
|
+
),
|
|
293
|
+
affluence_trees=(
|
|
294
|
+
None
|
|
295
|
+
if min_points_per_leaf is None
|
|
296
|
+
else make_forest(max_depth - 1, bool)
|
|
297
|
+
.at[:, 1]
|
|
298
|
+
.set(y.size >= 2 * min_points_per_leaf)
|
|
299
|
+
),
|
|
227
300
|
resid_batch_size=resid_batch_size,
|
|
228
301
|
count_batch_size=count_batch_size,
|
|
302
|
+
log_trans_prior=jnp.full(num_trees, jnp.nan) if save_ratios else None,
|
|
303
|
+
log_likelihood=jnp.full(num_trees, jnp.nan) if save_ratios else None,
|
|
304
|
+
sigma_mu2=jnp.asarray(sigma_mu2),
|
|
229
305
|
),
|
|
230
306
|
)
|
|
231
307
|
|
|
232
|
-
if save_ratios:
|
|
233
|
-
bart['ratios'] = dict(
|
|
234
|
-
log_trans_prior=jnp.full(num_trees, jnp.nan),
|
|
235
|
-
log_likelihood=jnp.full(num_trees, jnp.nan),
|
|
236
|
-
)
|
|
237
|
-
|
|
238
308
|
return bart
|
|
239
309
|
|
|
240
310
|
|
|
241
|
-
def _choose_suffstat_batch_size(
|
|
242
|
-
|
|
311
|
+
def _choose_suffstat_batch_size(
|
|
312
|
+
resid_batch_size, count_batch_size, y, forest_size
|
|
313
|
+
) -> tuple[int | None, ...]:
|
|
314
|
+
@cache
|
|
243
315
|
def get_platform():
|
|
244
316
|
try:
|
|
245
317
|
device = y.devices().pop()
|
|
@@ -276,231 +348,327 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
|
|
|
276
348
|
return resid_batch_size, count_batch_size
|
|
277
349
|
|
|
278
350
|
|
|
279
|
-
|
|
351
|
+
@jax.jit
|
|
352
|
+
def step(key: Key[Array, ''], bart: State) -> State:
|
|
280
353
|
"""
|
|
281
|
-
|
|
354
|
+
Do one MCMC step.
|
|
282
355
|
|
|
283
356
|
Parameters
|
|
284
357
|
----------
|
|
285
|
-
key
|
|
358
|
+
key
|
|
286
359
|
A jax random key.
|
|
287
|
-
bart
|
|
360
|
+
bart
|
|
288
361
|
A BART mcmc state, as created by `init`.
|
|
289
362
|
|
|
290
363
|
Returns
|
|
291
364
|
-------
|
|
292
|
-
|
|
293
|
-
The new BART mcmc state.
|
|
365
|
+
The new BART mcmc state.
|
|
294
366
|
"""
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
367
|
+
keys = split(key)
|
|
368
|
+
|
|
369
|
+
if bart.y.dtype == bool: # binary regression
|
|
370
|
+
bart = replace(bart, sigma2=jnp.float32(1))
|
|
371
|
+
bart = step_trees(keys.pop(), bart)
|
|
372
|
+
bart = replace(bart, sigma2=None)
|
|
373
|
+
return step_z(keys.pop(), bart)
|
|
298
374
|
|
|
375
|
+
else: # continuous regression
|
|
376
|
+
bart = step_trees(keys.pop(), bart)
|
|
377
|
+
return step_sigma(keys.pop(), bart)
|
|
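In the binary branch above, `sigma2` is temporarily set to 1 (presumably so the tree update can treat the latent `z` as a unit-variance continuous response) and then cleared again. Because `State` is an `equinox.Module`, i.e. an immutable dataclass-backed pytree, these updates go through `dataclasses.replace`, which builds a new object instead of mutating the input. A toy sketch of that pattern, using a made-up module rather than the real `State`:

```python
from dataclasses import replace

import jax
import jax.numpy as jnp
from equinox import Module


class Toy(Module):
    sigma2: jax.Array | None
    resid: jax.Array


toy = Toy(sigma2=None, resid=jnp.zeros(3))
tmp = replace(toy, sigma2=jnp.float32(1.0))  # a new object with sigma2 filled in
assert toy.sigma2 is None                    # the original state is untouched
```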
299
378
|
|
|
300
|
-
|
|
379
|
+
|
|
380
|
+
def step_trees(key: Key[Array, ''], bart: State) -> State:
|
|
301
381
|
"""
|
|
302
382
|
Forest sampling step of BART MCMC.
|
|
303
383
|
|
|
304
384
|
Parameters
|
|
305
385
|
----------
|
|
306
|
-
key
|
|
386
|
+
key
|
|
307
387
|
A jax random key.
|
|
308
|
-
bart
|
|
388
|
+
bart
|
|
309
389
|
A BART mcmc state, as created by `init`.
|
|
310
390
|
|
|
311
391
|
Returns
|
|
312
392
|
-------
|
|
313
|
-
|
|
314
|
-
The new BART mcmc state.
|
|
393
|
+
The new BART mcmc state.
|
|
315
394
|
|
|
316
395
|
Notes
|
|
317
396
|
-----
|
|
318
397
|
This function zeroes the proposal counters.
|
|
319
398
|
"""
|
|
320
|
-
|
|
321
|
-
moves =
|
|
322
|
-
return accept_moves_and_sample_leaves(
|
|
399
|
+
keys = split(key)
|
|
400
|
+
moves = propose_moves(keys.pop(), bart.forest, bart.max_split)
|
|
401
|
+
return accept_moves_and_sample_leaves(keys.pop(), bart, moves)
|
|
323
402
|
|
|
324
403
|
|
|
325
|
-
|
|
404
|
+
class Moves(Module):
|
|
405
|
+
"""
|
|
406
|
+
Moves proposed to modify each tree.
|
|
407
|
+
|
|
408
|
+
Parameters
|
|
409
|
+
----------
|
|
410
|
+
allowed
|
|
411
|
+
Whether the move is possible in the first place. There are additional
|
|
412
|
+
constraints that could forbid it, but they are computed at acceptance
|
|
413
|
+
time.
|
|
414
|
+
grow
|
|
415
|
+
Whether the move is a grow move or a prune move.
|
|
416
|
+
num_growable
|
|
417
|
+
The number of growable leaves in the original tree.
|
|
418
|
+
node
|
|
419
|
+
The index of the leaf to grow or node to prune.
|
|
420
|
+
left
|
|
421
|
+
right
|
|
422
|
+
The indices of the children of 'node'.
|
|
423
|
+
partial_ratio
|
|
424
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
425
|
+
the likelihood ratio and the probability of proposing the prune
|
|
426
|
+
move. If the move is PRUNE, the ratio is inverted. `None` once
|
|
427
|
+
`log_trans_prior_ratio` has been computed.
|
|
428
|
+
log_trans_prior_ratio
|
|
429
|
+
The logarithm of the product of the transition and prior terms of the
|
|
430
|
+
Metropolis-Hastings ratio for the acceptance of the proposed move.
|
|
431
|
+
`None` if not yet computed.
|
|
432
|
+
grow_var
|
|
433
|
+
The decision axes of the new rules.
|
|
434
|
+
grow_split
|
|
435
|
+
The decision boundaries of the new rules.
|
|
436
|
+
var_trees
|
|
437
|
+
The updated decision axes of the trees, valid whatever move.
|
|
438
|
+
logu
|
|
439
|
+
The logarithm of a uniform (0, 1] random variable to be used to
|
|
440
|
+
accept the move. It's in (-oo, 0].
|
|
441
|
+
acc
|
|
442
|
+
Whether the move was accepted. `None` if not yet computed.
|
|
443
|
+
to_prune
|
|
444
|
+
Whether the final operation to apply the move is pruning. This indicates
|
|
445
|
+
an accepted prune move or a rejected grow move. `None` if not yet
|
|
446
|
+
computed.
|
|
447
|
+
"""
|
|
448
|
+
|
|
449
|
+
allowed: Bool[Array, 'num_trees']
|
|
450
|
+
grow: Bool[Array, 'num_trees']
|
|
451
|
+
num_growable: UInt[Array, 'num_trees']
|
|
452
|
+
node: UInt[Array, 'num_trees']
|
|
453
|
+
left: UInt[Array, 'num_trees']
|
|
454
|
+
right: UInt[Array, 'num_trees']
|
|
455
|
+
partial_ratio: Float32[Array, 'num_trees'] | None
|
|
456
|
+
log_trans_prior_ratio: None | Float32[Array, 'num_trees']
|
|
457
|
+
grow_var: UInt[Array, 'num_trees']
|
|
458
|
+
grow_split: UInt[Array, 'num_trees']
|
|
459
|
+
var_trees: UInt[Array, 'num_trees 2**(d-1)']
|
|
460
|
+
logu: Float32[Array, 'num_trees']
|
|
461
|
+
acc: None | Bool[Array, 'num_trees']
|
|
462
|
+
to_prune: None | Bool[Array, 'num_trees']
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def propose_moves(
|
|
466
|
+
key: Key[Array, ''], forest: Forest, max_split: UInt[Array, 'p']
|
|
467
|
+
) -> Moves:
|
|
326
468
|
"""
|
|
327
469
|
Propose moves for all the trees.
|
|
328
470
|
|
|
471
|
+
There are two types of moves: GROW (convert a leaf to a decision node and
|
|
472
|
+
add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
|
|
473
|
+
leaf, deleting its children).
|
|
474
|
+
|
|
329
475
|
Parameters
|
|
330
476
|
----------
|
|
331
|
-
key
|
|
477
|
+
key
|
|
332
478
|
A jax random key.
|
|
333
|
-
|
|
334
|
-
BART
|
|
479
|
+
forest
|
|
480
|
+
The `forest` field of a BART MCMC state.
|
|
481
|
+
max_split
|
|
482
|
+
The maximum split index for each variable, found in `State`.
|
|
335
483
|
|
|
336
484
|
Returns
|
|
337
485
|
-------
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
Whether the move is possible.
|
|
343
|
-
'grow' : bool array (num_trees,)
|
|
344
|
-
Whether the move is a grow move or a prune move.
|
|
345
|
-
'num_growable' : int array (num_trees,)
|
|
346
|
-
The number of growable leaves in the original tree.
|
|
347
|
-
'node' : int array (num_trees,)
|
|
348
|
-
The index of the leaf to grow or node to prune.
|
|
349
|
-
'left', 'right' : int array (num_trees,)
|
|
350
|
-
The indices of the children of 'node'.
|
|
351
|
-
'partial_ratio' : float array (num_trees,)
|
|
352
|
-
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
353
|
-
the likelihood ratio and the probability of proposing the prune
|
|
354
|
-
move. If the move is Prune, the ratio is inverted.
|
|
355
|
-
'grow_var' : int array (num_trees,)
|
|
356
|
-
The decision axes of the new rules.
|
|
357
|
-
'grow_split' : int array (num_trees,)
|
|
358
|
-
The decision boundaries of the new rules.
|
|
359
|
-
'var_trees' : int array (num_trees, 2 ** (d - 1))
|
|
360
|
-
The updated decision axes of the trees, valid whatever move.
|
|
361
|
-
'logu' : float array (num_trees,)
|
|
362
|
-
The logarithm of a uniform (0, 1] random variable to be used to
|
|
363
|
-
accept the move. It's in (-oo, 0].
|
|
364
|
-
"""
|
|
365
|
-
ntree = bart['leaf_trees'].shape[0]
|
|
366
|
-
key = random.split(key, 1 + ntree)
|
|
367
|
-
key, subkey = key[0], key[1:]
|
|
486
|
+
The proposed move for each tree.
|
|
487
|
+
"""
|
|
488
|
+
num_trees, _ = forest.leaf_trees.shape
|
|
489
|
+
keys = split(key, 1 + 2 * num_trees)
|
|
368
490
|
|
|
369
491
|
# compute moves
|
|
370
|
-
grow_moves
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
492
|
+
grow_moves = propose_grow_moves(
|
|
493
|
+
keys.pop(num_trees),
|
|
494
|
+
forest.var_trees,
|
|
495
|
+
forest.split_trees,
|
|
496
|
+
forest.affluence_trees,
|
|
497
|
+
max_split,
|
|
498
|
+
forest.p_nonterminal,
|
|
499
|
+
forest.p_propose_grow,
|
|
500
|
+
)
|
|
501
|
+
prune_moves = propose_prune_moves(
|
|
502
|
+
keys.pop(num_trees),
|
|
503
|
+
forest.split_trees,
|
|
504
|
+
forest.affluence_trees,
|
|
505
|
+
forest.p_nonterminal,
|
|
506
|
+
forest.p_propose_grow,
|
|
378
507
|
)
|
|
379
508
|
|
|
380
|
-
u, logu = random.uniform(
|
|
509
|
+
u, logu = random.uniform(keys.pop(), (2, num_trees), jnp.float32)
|
|
381
510
|
|
|
382
511
|
# choose between grow or prune
|
|
383
|
-
grow_allowed = grow_moves
|
|
384
|
-
p_grow = jnp.where(grow_allowed & prune_moves
|
|
512
|
+
grow_allowed = grow_moves.num_growable.astype(bool)
|
|
513
|
+
p_grow = jnp.where(grow_allowed & prune_moves.allowed, 0.5, grow_allowed)
|
|
385
514
|
grow = u < p_grow # use < instead of <= because u is in [0, 1)
|
|
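The `p_grow` expression above covers three cases: if both a grow and a prune move are available the choice is a fair coin, if only a grow move is available it is chosen with probability 1, and if no leaf is growable the probability is 0. A tiny runnable check of that logic (derived from the expression, not additional package code):

```python
import jax.numpy as jnp

grow_allowed = jnp.array([True, True, False, False])
prune_allowed = jnp.array([True, False, True, False])
p_grow = jnp.where(grow_allowed & prune_allowed, 0.5, grow_allowed)
print(p_grow)  # [0.5 1.  0.  0. ]
```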
386
515
|
|
|
387
516
|
# compute children indices
|
|
388
|
-
node = jnp.where(grow, grow_moves
|
|
517
|
+
node = jnp.where(grow, grow_moves.node, prune_moves.node)
|
|
389
518
|
left = node << 1
|
|
390
519
|
right = left + 1
|
|
391
520
|
|
|
392
|
-
return
|
|
393
|
-
allowed=grow | prune_moves
|
|
521
|
+
return Moves(
|
|
522
|
+
allowed=grow | prune_moves.allowed,
|
|
394
523
|
grow=grow,
|
|
395
|
-
num_growable=grow_moves
|
|
524
|
+
num_growable=grow_moves.num_growable,
|
|
396
525
|
node=node,
|
|
397
526
|
left=left,
|
|
398
527
|
right=right,
|
|
399
528
|
partial_ratio=jnp.where(
|
|
400
|
-
grow, grow_moves
|
|
529
|
+
grow, grow_moves.partial_ratio, prune_moves.partial_ratio
|
|
401
530
|
),
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
531
|
+
log_trans_prior_ratio=None, # will be set in complete_ratio
|
|
532
|
+
grow_var=grow_moves.var,
|
|
533
|
+
grow_split=grow_moves.split,
|
|
534
|
+
var_trees=grow_moves.var_tree,
|
|
405
535
|
logu=jnp.log1p(-logu),
|
|
536
|
+
acc=None, # will be set in accept_moves_sequential_stage
|
|
537
|
+
to_prune=None, # will be set in accept_moves_sequential_stage
|
|
406
538
|
)
|
|
407
539
|
|
|
408
540
|
|
|
409
|
-
|
|
410
|
-
def _sample_moves_vmap_trees(*args):
|
|
411
|
-
key, args = args[0], args[1:]
|
|
412
|
-
key, key1 = random.split(key)
|
|
413
|
-
grow = grow_move(key, *args)
|
|
414
|
-
prune = prune_move(key1, *args)
|
|
415
|
-
return grow, prune
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
def grow_move(
|
|
419
|
-
key, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow
|
|
420
|
-
):
|
|
541
|
+
class GrowMoves(Module):
|
|
421
542
|
"""
|
|
422
|
-
|
|
543
|
+
Represent a proposed grow move for each tree.
|
|
423
544
|
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
545
|
+
Parameters
|
|
546
|
+
----------
|
|
547
|
+
num_growable
|
|
548
|
+
The number of growable leaves.
|
|
549
|
+
node
|
|
550
|
+
The index of the leaf to grow. ``2 ** d`` if there are no growable
|
|
551
|
+
leaves.
|
|
552
|
+
var
|
|
553
|
+
split
|
|
554
|
+
The decision axis and boundary of the new rule.
|
|
555
|
+
partial_ratio
|
|
556
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
557
|
+
the likelihood ratio and the probability of proposing the prune
|
|
558
|
+
move.
|
|
559
|
+
var_tree
|
|
560
|
+
The updated decision axes of the tree.
|
|
561
|
+
"""
|
|
562
|
+
|
|
563
|
+
num_growable: UInt[Array, 'num_trees']
|
|
564
|
+
node: UInt[Array, 'num_trees']
|
|
565
|
+
var: UInt[Array, 'num_trees']
|
|
566
|
+
split: UInt[Array, 'num_trees']
|
|
567
|
+
partial_ratio: Float32[Array, 'num_trees']
|
|
568
|
+
var_tree: UInt[Array, 'num_trees 2**(d-1)']
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
|
|
572
|
+
def propose_grow_moves(
|
|
573
|
+
key: Key[Array, ''],
|
|
574
|
+
var_tree: UInt[Array, '2**(d-1)'],
|
|
575
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
576
|
+
affluence_tree: Bool[Array, '2**(d-1)'] | None,
|
|
577
|
+
max_split: UInt[Array, 'p'],
|
|
578
|
+
p_nonterminal: Float32[Array, 'd'],
|
|
579
|
+
p_propose_grow: Float32[Array, '2**(d-1)'],
|
|
580
|
+
) -> GrowMoves:
|
|
581
|
+
"""
|
|
582
|
+
Propose a GROW move for each tree.
|
|
583
|
+
|
|
584
|
+
A GROW move picks a leaf node and converts it to a non-terminal node with
|
|
585
|
+
two leaf children.
|
|
427
586
|
|
|
428
587
|
Parameters
|
|
429
588
|
----------
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
589
|
+
key
|
|
590
|
+
A jax random key.
|
|
591
|
+
var_tree
|
|
592
|
+
The splitting axes of the tree.
|
|
593
|
+
split_tree
|
|
433
594
|
The splitting points of the tree.
|
|
434
|
-
affluence_tree
|
|
595
|
+
affluence_tree
|
|
435
596
|
Whether a leaf has enough points to be grown.
|
|
436
|
-
max_split
|
|
597
|
+
max_split
|
|
437
598
|
The maximum split index for each variable.
|
|
438
|
-
p_nonterminal
|
|
599
|
+
p_nonterminal
|
|
439
600
|
The probability of a nonterminal node at each depth.
|
|
440
|
-
p_propose_grow
|
|
601
|
+
p_propose_grow
|
|
441
602
|
The unnormalized probability of choosing a leaf to grow.
|
|
442
|
-
key : jax.dtypes.prng_key array
|
|
443
|
-
A jax random key.
|
|
444
603
|
|
|
445
604
|
Returns
|
|
446
605
|
-------
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
"""
|
|
464
|
-
|
|
465
|
-
key, key1, key2 = random.split(key, 3)
|
|
606
|
+
An object representing the proposed move.
|
|
607
|
+
|
|
608
|
+
Notes
|
|
609
|
+
-----
|
|
610
|
+
The move is not proposed if a leaf is already at maximum depth, or if a leaf
|
|
611
|
+
has less than twice the requested minimum number of datapoints per leaf.
|
|
612
|
+
This is marked by returning `num_growable` set to 0.
|
|
613
|
+
|
|
614
|
+
The move is also not possible if the ancestors of a leaf have
|
|
615
|
+
exhausted the possible decision rules that lead to a non-empty selection.
|
|
616
|
+
This is marked by returning `var` set to `p` and `split` set to 0. But this
|
|
617
|
+
does not block the move from counting as "proposed", even though it is
|
|
618
|
+
predictably going to be rejected. This simplifies the MCMC and should not
|
|
619
|
+
reduce efficiency except in unrealistic corner cases.
|
|
620
|
+
"""
|
|
621
|
+
keys = split(key, 3)
|
|
466
622
|
|
|
467
623
|
leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
|
|
468
|
-
|
|
624
|
+
keys.pop(), split_tree, affluence_tree, p_propose_grow
|
|
469
625
|
)
|
|
470
626
|
|
|
471
|
-
var = choose_variable(
|
|
627
|
+
var = choose_variable(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
|
|
472
628
|
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
473
629
|
|
|
474
|
-
|
|
630
|
+
split_idx = choose_split(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
|
|
475
631
|
|
|
476
632
|
ratio = compute_partial_ratio(
|
|
477
633
|
prob_choose, num_prunable, p_nonterminal, leaf_to_grow
|
|
478
634
|
)
|
|
479
635
|
|
|
480
|
-
return
|
|
636
|
+
return GrowMoves(
|
|
481
637
|
num_growable=num_growable,
|
|
482
638
|
node=leaf_to_grow,
|
|
483
639
|
var=var,
|
|
484
|
-
split=
|
|
640
|
+
split=split_idx,
|
|
485
641
|
partial_ratio=ratio,
|
|
486
642
|
var_tree=var_tree,
|
|
487
643
|
)
|
|
488
644
|
|
|
645
|
+
# TODO it is not clear to me how var=p and split=0 when the move is not
|
|
646
|
+
# possible lead to correct behavior downstream. Like, the move is proposed,
|
|
647
|
+
# but then it's a noop? And since it's a noop, it makes no difference if
|
|
648
|
+
# it's "accepted" or "rejected", it's like it's always rejected, so who
|
|
649
|
+
# cares if the likelihood ratio or a lot of other numbers are wrong? Uhm.
|
|
650
|
+
|
|
489
651
|
|
|
490
|
-
def choose_leaf(
|
|
652
|
+
def choose_leaf(
|
|
653
|
+
key: Key[Array, ''],
|
|
654
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
655
|
+
affluence_tree: Bool[Array, '2**(d-1)'] | None,
|
|
656
|
+
p_propose_grow: Float32[Array, '2**(d-1)'],
|
|
657
|
+
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
|
|
491
658
|
"""
|
|
492
659
|
Choose a leaf node to grow in a tree.
|
|
493
660
|
|
|
494
661
|
Parameters
|
|
495
662
|
----------
|
|
496
|
-
|
|
663
|
+
key
|
|
664
|
+
A jax random key.
|
|
665
|
+
split_tree
|
|
497
666
|
The splitting points of the tree.
|
|
498
|
-
affluence_tree
|
|
499
|
-
Whether a leaf has enough points
|
|
500
|
-
|
|
667
|
+
affluence_tree
|
|
668
|
+
Whether a leaf has enough points that it could be split into two leaves
|
|
669
|
+
satisfying the `min_points_per_leaf` requirement.
|
|
670
|
+
p_propose_grow
|
|
501
671
|
The unnormalized probability of choosing a leaf to grow.
|
|
502
|
-
key : jax.dtypes.prng_key array
|
|
503
|
-
A jax random key.
|
|
504
672
|
|
|
505
673
|
Returns
|
|
506
674
|
-------
|
|
@@ -508,9 +676,11 @@ def choose_leaf(key, split_tree, affluence_tree, p_propose_grow):
|
|
|
508
676
|
The index of the leaf to grow. If ``num_growable == 0``, return
|
|
509
677
|
``2 ** d``.
|
|
510
678
|
num_growable : int
|
|
511
|
-
The number of leaf nodes that can be grown.
|
|
679
|
+
The number of leaf nodes that can be grown, i.e., that are not at the bottom level
|
|
680
|
+
and have at least twice `min_points_per_leaf` if set.
|
|
512
681
|
prob_choose : float
|
|
513
|
-
The normalized probability
|
|
682
|
+
The (normalized) probability that this function had to choose that
|
|
683
|
+
specific leaf, given the arguments.
|
|
514
684
|
num_prunable : int
|
|
515
685
|
The number of leaf parents that could be pruned, after converting the
|
|
516
686
|
selected leaf to a non-terminal node.
|
|
@@ -526,23 +696,26 @@ def choose_leaf(key, split_tree, affluence_tree, p_propose_grow):
|
|
|
526
696
|
return leaf_to_grow, num_growable, prob_choose, num_prunable
|
|
527
697
|
|
|
528
698
|
|
|
529
|
-
def growable_leaves(
|
|
699
|
+
def growable_leaves(
|
|
700
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
701
|
+
affluence_tree: Bool[Array, '2**(d-1)'] | None,
|
|
702
|
+
) -> Bool[Array, '2**(d-1)']:
|
|
530
703
|
"""
|
|
531
704
|
Return a mask indicating the leaf nodes that can be proposed for growth.
|
|
532
705
|
|
|
706
|
+
The condition is that a leaf is not at the bottom level and has at least two
|
|
707
|
+
times the number of minimum points per leaf.
|
|
708
|
+
|
|
533
709
|
Parameters
|
|
534
710
|
----------
|
|
535
|
-
split_tree
|
|
711
|
+
split_tree
|
|
536
712
|
The splitting points of the tree.
|
|
537
|
-
affluence_tree
|
|
713
|
+
affluence_tree
|
|
538
714
|
Whether a leaf has enough points to be grown.
|
|
539
715
|
|
|
540
716
|
Returns
|
|
541
717
|
-------
|
|
542
|
-
|
|
543
|
-
The mask indicating the leaf nodes that can be proposed to grow, i.e.,
|
|
544
|
-
that are not at the bottom level and have at least two times the number
|
|
545
|
-
of minimum points per leaf.
|
|
718
|
+
The mask indicating the leaf nodes that can be proposed to grow.
|
|
546
719
|
"""
|
|
547
720
|
is_growable = grove.is_actual_leaf(split_tree)
|
|
548
721
|
if affluence_tree is not None:
|
|
@@ -550,23 +723,25 @@ def growable_leaves(split_tree, affluence_tree):
|
|
|
550
723
|
return is_growable
|
|
551
724
|
|
|
552
725
|
|
|
553
|
-
def categorical(
|
|
726
|
+
def categorical(
|
|
727
|
+
key: Key[Array, ''], distr: Float32[Array, 'n']
|
|
728
|
+
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
|
|
554
729
|
"""
|
|
555
730
|
Return a random integer from an arbitrary distribution.
|
|
556
731
|
|
|
557
732
|
Parameters
|
|
558
733
|
----------
|
|
559
|
-
key
|
|
734
|
+
key
|
|
560
735
|
A jax random key.
|
|
561
|
-
distr
|
|
736
|
+
distr
|
|
562
737
|
An unnormalized probability distribution.
|
|
563
738
|
|
|
564
739
|
Returns
|
|
565
740
|
-------
|
|
566
|
-
u :
|
|
741
|
+
u : Int32[Array, '']
|
|
567
742
|
A random integer in the range ``[0, n)``. If all probabilities are zero,
|
|
568
743
|
return ``n``.
|
|
569
|
-
norm :
|
|
744
|
+
norm : Float32[Array, '']
|
|
570
745
|
The sum of `distr`.
|
|
571
746
|
"""
|
|
572
747
|
ecdf = jnp.cumsum(distr)
|
|
@@ -574,27 +749,32 @@ def categorical(key, distr):
|
|
|
574
749
|
return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
|
|
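The body of `categorical` (partially elided by the diff) inverts the unnormalized CDF: cumulative-sum the weights, draw a uniform value below the total, and locate it with `searchsorted`. A hedged standalone re-implementation of the same technique; the uniform-draw line is an assumption, since the diff hides it:

```python
import jax
import jax.numpy as jnp
from jax import random


def categorical_sketch(key, distr):
    ecdf = jnp.cumsum(distr)
    # Assumed draw: uniform in [0, sum(distr)); the elided line may differ.
    u = random.uniform(key, (), minval=0.0, maxval=ecdf[-1])
    return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]


idx, norm = categorical_sketch(jax.random.key(0), jnp.array([0.0, 2.0, 1.0]))
# idx is 1 or 2 (weights 2:1); an all-zero distr would give idx == 3, the documented sentinel.
```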
575
750
|
|
|
576
751
|
|
|
577
|
-
def choose_variable(
|
|
752
|
+
def choose_variable(
|
|
753
|
+
key: Key[Array, ''],
|
|
754
|
+
var_tree: UInt[Array, '2**(d-1)'],
|
|
755
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
756
|
+
max_split: UInt[Array, 'p'],
|
|
757
|
+
leaf_index: Int32[Array, ''],
|
|
758
|
+
) -> Int32[Array, '']:
|
|
578
759
|
"""
|
|
579
760
|
Choose a variable to split on for a new non-terminal node.
|
|
580
761
|
|
|
581
762
|
Parameters
|
|
582
763
|
----------
|
|
583
|
-
|
|
764
|
+
key
|
|
765
|
+
A jax random key.
|
|
766
|
+
var_tree
|
|
584
767
|
The variable indices of the tree.
|
|
585
|
-
split_tree
|
|
768
|
+
split_tree
|
|
586
769
|
The splitting points of the tree.
|
|
587
|
-
max_split
|
|
770
|
+
max_split
|
|
588
771
|
The maximum split index for each variable.
|
|
589
|
-
leaf_index
|
|
772
|
+
leaf_index
|
|
590
773
|
The index of the leaf to grow.
|
|
591
|
-
key : jax.dtypes.prng_key array
|
|
592
|
-
A jax random key.
|
|
593
774
|
|
|
594
775
|
Returns
|
|
595
776
|
-------
|
|
596
|
-
|
|
597
|
-
The index of the variable to split on.
|
|
777
|
+
The index of the variable to split on.
|
|
598
778
|
|
|
599
779
|
Notes
|
|
600
780
|
-----
|
|
@@ -605,30 +785,36 @@ def choose_variable(key, var_tree, split_tree, max_split, leaf_index):
|
|
|
605
785
|
return randint_exclude(key, max_split.size, var_to_ignore)
|
|
606
786
|
|
|
607
787
|
|
|
608
|
-
def fully_used_variables(
|
|
788
|
+
def fully_used_variables(
|
|
789
|
+
var_tree: UInt[Array, '2**(d-1)'],
|
|
790
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
791
|
+
max_split: UInt[Array, 'p'],
|
|
792
|
+
leaf_index: Int32[Array, ''],
|
|
793
|
+
) -> UInt[Array, 'd-2']:
|
|
609
794
|
"""
|
|
610
795
|
Return a list of variables that have an empty split range at a given node.
|
|
611
796
|
|
|
612
797
|
Parameters
|
|
613
798
|
----------
|
|
614
|
-
var_tree
|
|
799
|
+
var_tree
|
|
615
800
|
The variable indices of the tree.
|
|
616
|
-
split_tree
|
|
801
|
+
split_tree
|
|
617
802
|
The splitting points of the tree.
|
|
618
|
-
max_split
|
|
803
|
+
max_split
|
|
619
804
|
The maximum split index for each variable.
|
|
620
|
-
leaf_index
|
|
805
|
+
leaf_index
|
|
621
806
|
The index of the node, assumed to be valid for `var_tree`.
|
|
622
807
|
|
|
623
808
|
Returns
|
|
624
809
|
-------
|
|
625
|
-
|
|
626
|
-
The indices of the variables that have an empty split range. Since the
|
|
627
|
-
number of such variables is not fixed, unused values in the array are
|
|
628
|
-
filled with `p`. The fill values are not guaranteed to be placed in any
|
|
629
|
-
particular order. Variables may appear more than once.
|
|
630
|
-
"""
|
|
810
|
+
The indices of the variables that have an empty split range.
|
|
631
811
|
|
|
812
|
+
Notes
|
|
813
|
+
-----
|
|
814
|
+
The number of such variables is not known in advance. Unused values in the
|
|
815
|
+
array are filled with `p`. The fill values are not guaranteed to be placed
|
|
816
|
+
in any particular order, and variables may appear more than once.
|
|
817
|
+
"""
|
|
632
818
|
var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
|
|
633
819
|
split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
|
|
634
820
|
l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
|
|
@@ -636,7 +822,11 @@ def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
|
|
|
636
822
|
return jnp.where(num_split == 0, var_to_ignore, max_split.size)
|
|
637
823
|
|
|
638
824
|
|
|
639
|
-
def ancestor_variables(
|
|
825
|
+
def ancestor_variables(
|
|
826
|
+
var_tree: UInt[Array, '2**(d-1)'],
|
|
827
|
+
max_split: UInt[Array, 'p'],
|
|
828
|
+
node_index: Int32[Array, ''],
|
|
829
|
+
) -> UInt[Array, 'd-2']:
|
|
640
830
|
"""
|
|
641
831
|
Return the list of variables in the ancestors of a node.
|
|
642
832
|
|
|
@@ -651,14 +841,16 @@ def ancestor_variables(var_tree, max_split, node_index):
|
|
|
651
841
|
|
|
652
842
|
Returns
|
|
653
843
|
-------
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
844
|
+
The variable indices of the ancestors of the node.
|
|
845
|
+
|
|
846
|
+
Notes
|
|
847
|
+
-----
|
|
848
|
+
The ancestors are the nodes going from the root to the parent of the node.
|
|
849
|
+
The number of ancestors is not known at tracing time; unused spots in the
|
|
850
|
+
output array are filled with `p`.
|
|
657
851
|
"""
|
|
658
852
|
max_num_ancestors = grove.tree_depth(var_tree) - 1
|
|
659
|
-
ancestor_vars = jnp.zeros(
|
|
660
|
-
max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size)
|
|
661
|
-
)
|
|
853
|
+
ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
|
|
662
854
|
carry = ancestor_vars.size - 1, node_index, ancestor_vars
|
|
663
855
|
|
|
664
856
|
def loop(carry, _):
|
|
@@ -673,27 +865,32 @@ def ancestor_variables(var_tree, max_split, node_index):
|
|
|
673
865
|
return ancestor_vars
|
|
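Throughout this module trees are stored as heaps: node `i`'s children sit at `2*i` and `2*i + 1` (see `left = node << 1` in `propose_moves`), so the ancestor chain of a node is obtained by repeated halving of its index, which is presumably what the elided scan above walks while recording `var_tree` entries. For intuition, in plain Python:

```python
def ancestors(node_index):
    # Parent of node i is i >> 1; stop at the root (index 1).
    out = []
    i = node_index >> 1
    while i >= 1:
        out.append(i)
        i >>= 1
    return out


print(ancestors(11))  # [5, 2, 1]: parent, grandparent, root
```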
674
866
|
|
|
675
867
|
|
|
676
|
-
def split_range(
|
|
868
|
+
def split_range(
|
|
869
|
+
var_tree: UInt[Array, '2**(d-1)'],
|
|
870
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
871
|
+
max_split: UInt[Array, 'p'],
|
|
872
|
+
node_index: Int32[Array, ''],
|
|
873
|
+
ref_var: Int32[Array, ''],
|
|
874
|
+
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
|
|
677
875
|
"""
|
|
678
876
|
Return the range of allowed splits for a variable at a given node.
|
|
679
877
|
|
|
680
878
|
Parameters
|
|
681
879
|
----------
|
|
682
|
-
var_tree
|
|
880
|
+
var_tree
|
|
683
881
|
The variable indices of the tree.
|
|
684
|
-
split_tree
|
|
882
|
+
split_tree
|
|
685
883
|
The splitting points of the tree.
|
|
686
|
-
max_split
|
|
884
|
+
max_split
|
|
687
885
|
The maximum split index for each variable.
|
|
688
|
-
node_index
|
|
886
|
+
node_index
|
|
689
887
|
The index of the node, assumed to be valid for `var_tree`.
|
|
690
|
-
ref_var
|
|
888
|
+
ref_var
|
|
691
889
|
The variable for which to measure the split range.
|
|
692
890
|
|
|
693
891
|
Returns
|
|
694
892
|
-------
|
|
695
|
-
l, r
|
|
696
|
-
The range of allowed splits is [l, r).
|
|
893
|
+
The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=0.
|
|
697
894
|
"""
|
|
698
895
|
max_num_ancestors = grove.tree_depth(var_tree) - 1
|
|
699
896
|
initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
|
|
@@ -715,26 +912,29 @@ def split_range(var_tree, split_tree, max_split, node_index, ref_var):
|
|
|
715
912
|
return l + 1, r
|
|
716
913
|
|
|
717
914
|
|
|
718
|
-
def randint_exclude(
|
|
915
|
+
def randint_exclude(
|
|
916
|
+
key: Key[Array, ''], sup: int, exclude: Integer[Array, 'n']
|
|
917
|
+
) -> Int32[Array, '']:
|
|
719
918
|
"""
|
|
720
919
|
Return a random integer in a range, excluding some values.
|
|
721
920
|
|
|
722
921
|
Parameters
|
|
723
922
|
----------
|
|
724
|
-
key
|
|
923
|
+
key
|
|
725
924
|
A jax random key.
|
|
726
|
-
sup
|
|
925
|
+
sup
|
|
727
926
|
The exclusive upper bound of the range.
|
|
728
|
-
exclude
|
|
927
|
+
exclude
|
|
729
928
|
The values to exclude from the range. Values greater than or equal to
|
|
730
929
|
`sup` are ignored. Values can appear more than once.
|
|
731
930
|
|
|
732
931
|
Returns
|
|
733
932
|
-------
|
|
734
|
-
u
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
933
|
+
A random integer `u` in the range ``[0, sup)`` such that ``u not in exclude``.
|
|
934
|
+
|
|
935
|
+
Notes
|
|
936
|
+
-----
|
|
937
|
+
If all values in the range are excluded, return `sup`.
|
|
738
938
|
"""
|
|
739
939
|
exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
|
|
740
940
|
num_allowed = sup - jnp.count_nonzero(exclude < sup)
|
|
@@ -747,58 +947,74 @@ def randint_exclude(key, sup, exclude):
|
|
|
747
947
|
return u
|
|
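The middle of `randint_exclude` is elided by the diff. Below is a hedged sketch of one standard way to get the documented behavior: draw among the allowed count, then shift the draw past the excluded values. The shifting scan is an assumption, not the package's actual code, and the all-excluded corner case is covered by the documented fallback rather than handled explicitly here.

```python
import jax
import jax.numpy as jnp
from jax import lax, random


def randint_exclude_sketch(key, sup, exclude):
    exclude = jnp.sort(jnp.unique(exclude, size=exclude.size, fill_value=sup))
    num_allowed = sup - jnp.count_nonzero(exclude < sup)
    u = random.randint(key, (), 0, num_allowed)

    def shift(u, e):
        # Skip over each excluded value at or below the running draw.
        return jnp.where(u >= e, u + 1, u), None

    u, _ = lax.scan(shift, u, exclude)
    return u


print(randint_exclude_sketch(jax.random.key(0), 5, jnp.array([1, 3])))  # 0, 2, or 4
```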
748
948
|
|
|
749
949
|
|
|
750
|
-
def choose_split(
|
|
950
|
+
def choose_split(
|
|
951
|
+
key: Key[Array, ''],
|
|
952
|
+
var_tree: UInt[Array, '2**(d-1)'],
|
|
953
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
954
|
+
max_split: UInt[Array, 'p'],
|
|
955
|
+
leaf_index: Int32[Array, ''],
|
|
956
|
+
) -> Int32[Array, '']:
|
|
751
957
|
"""
|
|
752
958
|
Choose a split point for a new non-terminal node.
|
|
753
959
|
|
|
754
960
|
Parameters
|
|
755
961
|
----------
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
962
|
+
key
|
|
963
|
+
A jax random key.
|
|
964
|
+
var_tree
|
|
965
|
+
The splitting axes of the tree.
|
|
966
|
+
split_tree
|
|
759
967
|
The splitting points of the tree.
|
|
760
|
-
max_split
|
|
968
|
+
max_split
|
|
761
969
|
The maximum split index for each variable.
|
|
762
|
-
leaf_index
|
|
970
|
+
leaf_index
|
|
763
971
|
The index of the leaf to grow. It is assumed that `var_tree` already
|
|
764
972
|
contains the target variable at this index.
|
|
765
|
-
key : jax.dtypes.prng_key array
|
|
766
|
-
A jax random key.
|
|
767
973
|
|
|
768
974
|
Returns
|
|
769
975
|
-------
|
|
770
|
-
|
|
771
|
-
The split point.
|
|
976
|
+
The cutpoint. If ``var_tree[leaf_index]`` is out of bounds, return 0.
|
|
772
977
|
"""
|
|
773
978
|
var = var_tree[leaf_index]
|
|
774
979
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
775
980
|
return random.randint(key, (), l, r)
|
|
776
981
|
|
|
982
|
+
# TODO what happens if leaf_index is out of bounds? And is the value used
|
|
983
|
+
# in that case?
|
|
777
984
|
|
|
778
|
-
|
|
985
|
+
|
|
986
|
+
def compute_partial_ratio(
|
|
987
|
+
prob_choose: Float32[Array, ''],
|
|
988
|
+
num_prunable: Int32[Array, ''],
|
|
989
|
+
p_nonterminal: Float32[Array, 'd'],
|
|
990
|
+
leaf_to_grow: Int32[Array, ''],
|
|
991
|
+
) -> Float32[Array, '']:
|
|
779
992
|
"""
|
|
780
993
|
Compute the product of the transition and prior ratios of a grow move.
|
|
781
994
|
|
|
782
995
|
Parameters
|
|
783
996
|
----------
|
|
784
|
-
|
|
785
|
-
The
|
|
786
|
-
|
|
997
|
+
prob_choose
|
|
998
|
+
The probability that the leaf had to be chosen amongst the growable
|
|
999
|
+
leaves.
|
|
1000
|
+
num_prunable
|
|
787
1001
|
The number of leaf parents that could be pruned, after converting the
|
|
788
1002
|
leaf to be grown to a non-terminal node.
|
|
789
|
-
p_nonterminal
|
|
1003
|
+
p_nonterminal
|
|
790
1004
|
The probability of a nonterminal node at each depth.
|
|
791
|
-
leaf_to_grow
|
|
1005
|
+
leaf_to_grow
|
|
792
1006
|
The index of the leaf to grow.
|
|
793
1007
|
|
|
794
1008
|
Returns
|
|
795
1009
|
-------
|
|
796
|
-
ratio
|
|
797
|
-
The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
|
|
798
|
-
times the prior ratio P(new tree) / P(old tree), but the transition
|
|
799
|
-
ratio is missing the factor P(propose prune) in the numerator.
|
|
800
|
-
"""
|
|
1010
|
+
The partial transition ratio times the prior ratio.
|
|
801
1011
|
|
|
1012
|
+
Notes
|
|
1013
|
+
-----
|
|
1014
|
+
The transition ratio is P(new tree => old tree) / P(old tree => new tree).
|
|
1015
|
+
The "partial" transition ratio returned is missing the factor P(propose
|
|
1016
|
+
prune) in the numerator. The prior ratio is P(new tree) / P(old tree).
|
|
1017
|
+
"""
|
|
802
1018
|
# the two ratios also contain factors num_available_split *
|
|
803
1019
|
# num_available_var, but they cancel out
|
|
804
1020
|
|
|
@@ -822,42 +1038,55 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
|
|
|
822
1038
|
return tree_ratio / inv_trans_ratio
|
|
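For reference, the docstring's pieces assemble into the standard Metropolis-Hastings acceptance probability for a GROW proposal (a sketch of the standard formula in the docstring's notation, not code from the package):

$$
\alpha_{\mathrm{grow}}
= \min\!\left(1,\;
\frac{P(\text{new tree} \to \text{old tree})}{P(\text{old tree} \to \text{new tree})}
\cdot
\frac{P(\text{new tree})}{P(\text{old tree})}
\cdot
\frac{P(\text{data} \mid \text{new tree})}{P(\text{data} \mid \text{old tree})}
\right)
$$

`compute_partial_ratio` returns the first two factors with the `P(propose prune)` term of the reverse transition left out, and the likelihood factor is folded in later at acceptance time (hence "partial").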
823
1039
|
|
|
824
1040
|
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
1041
|
+
class PruneMoves(Module):
|
|
1042
|
+
"""
|
|
1043
|
+
Represent a proposed prune move for each tree.
|
|
1044
|
+
|
|
1045
|
+
Parameters
|
|
1046
|
+
----------
|
|
1047
|
+
allowed
|
|
1048
|
+
Whether the move is possible.
|
|
1049
|
+
node
|
|
1050
|
+
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
1051
|
+
partial_ratio
|
|
1052
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
1053
|
+
the likelihood ratio and the probability of proposing the prune
|
|
1054
|
+
move. This ratio is inverted, and is meant to be inverted back in
|
|
1055
|
+
`accept_moves_and_sample_leaves`.
|
|
1056
|
+
"""
|
|
1057
|
+
|
|
1058
|
+
allowed: Bool[Array, 'num_trees']
|
|
1059
|
+
node: UInt[Array, 'num_trees']
|
|
1060
|
+
partial_ratio: Float32[Array, 'num_trees']
|
|
1061
|
+
|
|
1062
|
+
|
|
1063
|
+
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
|
|
1064
|
+
def propose_prune_moves(
|
|
1065
|
+
key: Key[Array, ''],
|
|
1066
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
1067
|
+
affluence_tree: Bool[Array, '2**(d-1)'] | None,
|
|
1068
|
+
p_nonterminal: Float32[Array, 'd'],
|
|
1069
|
+
p_propose_grow: Float32[Array, '2**(d-1)'],
|
|
1070
|
+
) -> PruneMoves:
|
|
828
1071
|
"""
|
|
829
1072
|
Tree structure prune move proposal of BART MCMC.
|
|
830
1073
|
|
|
831
1074
|
Parameters
|
|
832
1075
|
----------
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
split_tree
|
|
1076
|
+
key
|
|
1077
|
+
A jax random key.
|
|
1078
|
+
split_tree
|
|
836
1079
|
The splitting points of the tree.
|
|
837
|
-
affluence_tree
|
|
1080
|
+
affluence_tree
|
|
838
1081
|
Whether a leaf has enough points to be grown.
|
|
839
|
-
|
|
840
|
-
The maximum split index for each variable.
|
|
841
|
-
p_nonterminal : float array (d,)
|
|
1082
|
+
p_nonterminal
|
|
842
1083
|
The probability of a nonterminal node at each depth.
|
|
843
|
-
p_propose_grow
|
|
1084
|
+
p_propose_grow
|
|
844
1085
|
The unnormalized probability of choosing a leaf to grow.
|
|
845
|
-
key : jax.dtypes.prng_key array
|
|
846
|
-
A jax random key.
|
|
847
1086
|
|
|
848
1087
|
Returns
|
|
849
1088
|
-------
|
|
850
|
-
|
|
851
|
-
A dictionary with fields:
|
|
852
|
-
|
|
853
|
-
'allowed' : bool
|
|
854
|
-
Whether the move is possible.
|
|
855
|
-
'node' : int
|
|
856
|
-
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
857
|
-
'partial_ratio' : float
|
|
858
|
-
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
859
|
-
the likelihood ratio and the probability of proposing the prune
|
|
860
|
-
move. This ratio is inverted.
|
|
1089
|
+
An object representing the proposed moves.
|
|
861
1090
|
"""
|
|
862
1091
|
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
|
|
863
1092
|
key, split_tree, affluence_tree, p_propose_grow
|
|
@@ -868,37 +1097,44 @@ def prune_move(
|
|
|
868
1097
|
prob_choose, num_prunable, p_nonterminal, node_to_prune
|
|
869
1098
|
)
|
|
870
1099
|
|
|
871
|
-
return
|
|
1100
|
+
return PruneMoves(
|
|
872
1101
|
allowed=allowed,
|
|
873
1102
|
node=node_to_prune,
|
|
874
|
-
partial_ratio=ratio,
|
|
1103
|
+
partial_ratio=ratio,
|
|
875
1104
|
)
|
|
876
1105
|
|
|
877
1106
|
|
|
878
|
-
def choose_leaf_parent(
|
|
1107
|
+
def choose_leaf_parent(
|
|
1108
|
+
key: Key[Array, ''],
|
|
1109
|
+
split_tree: UInt[Array, '2**(d-1)'],
|
|
1110
|
+
affluence_tree: Bool[Array, '2**(d-1)'] | None,
|
|
1111
|
+
p_propose_grow: Float32[Array, '2**(d-1)'],
|
|
1112
|
+
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, '']]:
|
|
879
1113
|
"""
|
|
880
1114
|
Pick a non-terminal node with leaf children to prune in a tree.
|
|
881
1115
|
|
|
882
1116
|
Parameters
|
|
883
1117
|
----------
|
|
884
|
-
|
|
1118
|
+
key
|
|
1119
|
+
A jax random key.
|
|
1120
|
+
split_tree
|
|
885
1121
|
The splitting points of the tree.
|
|
886
|
-
affluence_tree
|
|
1122
|
+
affluence_tree
|
|
887
1123
|
Whether a leaf has enough points to be grown.
|
|
888
|
-
p_propose_grow
|
|
1124
|
+
p_propose_grow
|
|
889
1125
|
The unnormalized probability of choosing a leaf to grow.
|
|
890
|
-
key : jax.dtypes.prng_key array
|
|
891
|
-
A jax random key.
|
|
892
1126
|
|
|
893
1127
|
Returns
|
|
894
1128
|
-------
|
|
895
|
-
node_to_prune :
|
|
1129
|
+
node_to_prune : Int32[Array, '']
|
|
896
1130
|
The index of the node to prune. If ``num_prunable == 0``, return
|
|
897
1131
|
``2 ** d``.
|
|
898
|
-
num_prunable :
|
|
1132
|
+
num_prunable : Int32[Array, '']
|
|
899
1133
|
The number of leaf parents that could be pruned.
|
|
900
|
-
prob_choose :
|
|
901
|
-
The normalized probability
|
|
1134
|
+
prob_choose : Float32[Array, '']
|
|
1135
|
+
The (normalized) probability that `choose_leaf` would choose
|
|
1136
|
+
`node_to_prune` as leaf to grow, if passed the tree where
|
|
1137
|
+
`node_to_prune` had been pruned.
|
|
902
1138
|
"""
|
|
903
1139
|
is_prunable = grove.is_leaves_parent(split_tree)
|
|
904
1140
|
num_prunable = jnp.count_nonzero(is_prunable)
|
|
@@ -906,9 +1142,8 @@ def choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow):
|
|
|
906
1142
|
node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
|
|
907
1143
|
|
|
908
1144
|
split_tree = split_tree.at[node_to_prune].set(0)
|
|
909
|
-
affluence_tree
|
|
910
|
-
|
|
911
|
-
)
|
|
1145
|
+
if affluence_tree is not None:
|
|
1146
|
+
affluence_tree = affluence_tree.at[node_to_prune].set(True)
|
|
912
1147
|
is_growable_leaf = growable_leaves(split_tree, affluence_tree)
|
|
913
1148
|
prob_choose = p_propose_grow[node_to_prune]
|
|
914
1149
|
prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
|
|
@@ -916,56 +1151,196 @@ def choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow):
|
|
|
916
1151
|
return node_to_prune, num_prunable, prob_choose
|
|
917
1152
|
|
|
918
1153
|
|
|
919
|
-
def randint_masked(key, mask):
|
|
1154
|
+
def randint_masked(key: Key[Array, ''], mask: Bool[Array, 'n']) -> Int32[Array, '']:
|
|
920
1155
|
"""
|
|
921
1156
|
Return a random integer in a range, including only some values.
|
|
922
1157
|
|
|
923
1158
|
Parameters
|
|
924
1159
|
----------
|
|
925
|
-
key
|
|
1160
|
+
key
|
|
926
1161
|
A jax random key.
|
|
927
|
-
mask
|
|
1162
|
+
mask
|
|
928
1163
|
The mask indicating the allowed values.
|
|
929
1164
|
|
|
930
1165
|
Returns
|
|
931
1166
|
-------
|
|
932
|
-
u
|
|
933
|
-
|
|
934
|
-
|
|
1167
|
+
A random integer in the range ``[0, n)`` such that ``mask[u] == True``.
|
|
1168
|
+
|
|
1169
|
+
Notes
|
|
1170
|
+
-----
|
|
1171
|
+
If all values in the mask are `False`, return `n`.
|
|
935
1172
|
"""
|
|
936
1173
|
ecdf = jnp.cumsum(mask)
|
|
937
1174
|
u = random.randint(key, (), 0, ecdf[-1])
|
|
938
1175
|
return jnp.searchsorted(ecdf, u, 'right')
|
|
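A quick usage check of `randint_masked` with made-up masks, matching the documented behavior (assuming the function is importable at module level as defined above):

```python
import jax
import jax.numpy as jnp

from bartz.mcmcstep import randint_masked

key = jax.random.key(0)
i = randint_masked(key, jnp.array([False, True, False, True]))  # 1 or 3, each with prob. 1/2
j = randint_masked(key, jnp.zeros(4, bool))  # documented to return 4 (== n) for an all-False mask
```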
939
1176
|
|
|
940
1177
|
|
|
941
|
-
def accept_moves_and_sample_leaves(
|
|
1178
|
+
def accept_moves_and_sample_leaves(
|
|
1179
|
+
key: Key[Array, ''], bart: State, moves: Moves
|
|
1180
|
+
) -> State:
|
|
942
1181
|
"""
|
|
943
1182
|
Accept or reject the proposed moves and sample the new leaf values.
|
|
944
1183
|
|
|
945
1184
|
Parameters
|
|
946
1185
|
----------
|
|
947
|
-
key
|
|
1186
|
+
key
|
|
948
1187
|
A jax random key.
|
|
949
|
-
bart
|
|
950
|
-
A BART mcmc state.
|
|
951
|
-
moves
|
|
952
|
-
The proposed moves, see `
|
|
1188
|
+
bart
|
|
1189
|
+
A valid BART mcmc state.
|
|
1190
|
+
moves
|
|
1191
|
+
The proposed moves, see `propose_moves`.
|
|
953
1192
|
|
|
954
1193
|
Returns
|
|
955
1194
|
-------
|
|
956
|
-
|
|
957
|
-
The new BART mcmc state.
|
|
1195
|
+
A new (valid) BART mcmc state.
|
|
958
1196
|
"""
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
)
|
|
962
|
-
bart, moves = accept_moves_sequential_stage(
|
|
963
|
-
bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
|
|
964
|
-
)
|
|
1197
|
+
pso = accept_moves_parallel_stage(key, bart, moves)
|
|
1198
|
+
bart, moves = accept_moves_sequential_stage(pso)
|
|
965
1199
|
return accept_moves_final_stage(bart, moves)
|
|
966
1200
|
|
|
967
1201
|
|
|
968
|
-
|
|
1202
|
+
class Counts(Module):
|
|
1203
|
+
"""
|
|
1204
|
+
Number of datapoints in the nodes involved in proposed moves for each tree.
|
|
1205
|
+
|
|
1206
|
+
Parameters
|
|
1207
|
+
----------
|
|
1208
|
+
left
|
|
1209
|
+
Number of datapoints in the left child.
|
|
1210
|
+
right
|
|
1211
|
+
Number of datapoints in the right child.
|
|
1212
|
+
total
|
|
1213
|
+
Number of datapoints in the parent (``= left + right``).
|
|
1214
|
+
"""
|
|
1215
|
+
|
|
1216
|
+
left: UInt[Array, 'num_trees']
|
|
1217
|
+
right: UInt[Array, 'num_trees']
|
|
1218
|
+
total: UInt[Array, 'num_trees']
|
|
1219
|
+
|
|
1220
|
+
|
|
1221
|
+
class Precs(Module):
|
|
1222
|
+
"""
|
|
1223
|
+
Likelihood precision scale in the nodes involved in proposed moves for each tree.
|
|
1224
|
+
|
|
1225
|
+
The "likelihood precision scale" of a tree node is the sum of the inverse
|
|
1226
|
+
squared error scales of the datapoints selected by the node.
|
|
1227
|
+
|
|
1228
|
+
Parameters
|
|
1229
|
+
----------
|
|
1230
|
+
left
|
|
1231
|
+
Likelihood precision scale in the left child.
|
|
1232
|
+
right
|
|
1233
|
+
Likelihood precision scale in the right child.
|
|
1234
|
+
total
|
|
1235
|
+
Likelihood precision scale in the parent (``= left + right``).
|
|
1236
|
+
"""
|
|
1237
|
+
|
|
1238
|
+
left: Float32[Array, 'num_trees']
|
|
1239
|
+
right: Float32[Array, 'num_trees']
|
|
1240
|
+
total: Float32[Array, 'num_trees']
|
|
1241
|
+
|
|
1242
|
+
|
|
1243
|
+
class PreLkV(Module):
|
|
1244
|
+
"""
|
|
1245
|
+
Non-sequential terms of the likelihood ratio for each tree.
|
|
1246
|
+
|
|
1247
|
+
These terms can be computed in parallel across trees.
|
|
1248
|
+
|
|
1249
|
+
Parameters
|
|
1250
|
+
----------
|
|
1251
|
+
sigma2_left
|
|
1252
|
+
The noise variance in the left child of the leaves grown or pruned by
|
|
1253
|
+
the moves.
|
|
1254
|
+
sigma2_right
|
|
1255
|
+
The noise variance in the right child of the leaves grown or pruned by
|
|
1256
|
+
the moves.
|
|
1257
|
+
sigma2_total
|
|
1258
|
+
The noise variance in the total of the leaves grown or pruned by the
|
|
1259
|
+
moves.
|
|
1260
|
+
sqrt_term
|
|
1261
|
+
The **logarithm** of the square root term of the likelihood ratio.
|
|
1262
|
+
"""
|
|
1263
|
+
|
|
1264
|
+
sigma2_left: Float32[Array, 'num_trees']
|
|
1265
|
+
sigma2_right: Float32[Array, 'num_trees']
|
|
1266
|
+
sigma2_total: Float32[Array, 'num_trees']
|
|
1267
|
+
sqrt_term: Float32[Array, 'num_trees']
|
|
1268
|
+
|
|
1269
|
+
|
|
1270
|
+
class PreLk(Module):
|
|
1271
|
+
"""
|
|
1272
|
+
Non-sequential terms of the likelihood ratio shared by all trees.
|
|
1273
|
+
|
|
1274
|
+
Parameters
|
|
1275
|
+
----------
|
|
1276
|
+
exp_factor
|
|
1277
|
+
The factor to multiply the likelihood ratio by, shared by all trees.
|
|
1278
|
+
"""
|
|
1279
|
+
|
|
1280
|
+
exp_factor: Float32[Array, '']
|
|
1281
|
+
|
|
1282
|
+
|
|
1283
|
+
class PreLf(Module):
|
|
1284
|
+
"""
|
|
1285
|
+
Pre-computed terms used to sample leaves from their posterior.
|
|
1286
|
+
|
|
1287
|
+
These terms can be computed in parallel across trees.
|
|
1288
|
+
|
|
1289
|
+
Parameters
|
|
1290
|
+
----------
|
|
1291
|
+
mean_factor
|
|
1292
|
+
The factor to be multiplied by the sum of the scaled residuals to
|
|
1293
|
+
obtain the posterior mean.
|
|
1294
|
+
centered_leaves
|
|
1295
|
+
The mean-zero normal values to be added to the posterior mean to
|
|
1296
|
+
obtain the posterior leaf samples.
|
|
1297
|
+
"""
|
|
1298
|
+
|
|
1299
|
+
mean_factor: Float32[Array, 'num_trees 2**d']
|
|
1300
|
+
centered_leaves: Float32[Array, 'num_trees 2**d']
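
For reference, a minimal sketch (hypothetical numbers) of the conjugate normal posterior behind `mean_factor`: with a zero-mean prior of variance `sigma_mu2` on a leaf and likelihood precision `prec / sigma2`, the posterior mean is the sum of scaled residuals times `var_post / sigma2`, which is the quantity later stored by `precompute_leaf_terms`:

    import jax.numpy as jnp

    sigma2, sigma_mu2 = 1.0, 0.5   # hypothetical error and leaf prior variances
    prec = 7.0                     # datapoints (or precision scale) in the leaf
    resid_sum = 2.3                # sum of (scaled) residuals in the leaf
    var_post = 1 / (prec / sigma2 + 1 / sigma_mu2)
    mean_factor = var_post / sigma2
    mean_post = mean_factor * resid_sum
    # a posterior draw is then mean_post + jnp.sqrt(var_post) * z, z ~ N(0, 1)
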
|
|
1301
|
+
|
|
1302
|
+
|
|
1303
|
+
class ParallelStageOut(Module):
|
|
1304
|
+
"""
|
|
1305
|
+
The output of `accept_moves_parallel_stage`.
|
|
1306
|
+
|
|
1307
|
+
Parameters
|
|
1308
|
+
----------
|
|
1309
|
+
bart
|
|
1310
|
+
A partially updated BART mcmc state.
|
|
1311
|
+
moves
|
|
1312
|
+
The proposed moves, with `partial_ratio` set to `None` and
|
|
1313
|
+
`log_trans_prior_ratio` set to its final value.
|
|
1314
|
+
prec_trees
|
|
1315
|
+
The likelihood precision scale in each potential or actual leaf node. If
|
|
1316
|
+
there is no precision scale, this is the number of points in each leaf.
|
|
1317
|
+
move_counts
|
|
1318
|
+
The counts of the number of points in the nodes modified by the
|
|
1319
|
+
moves. If `bart.min_points_per_leaf` is not set and
|
|
1320
|
+
`bart.prec_scale` is set, they are not computed.
|
|
1321
|
+
move_precs
|
|
1322
|
+
The likelihood precision scale in each node modified by the moves. If
|
|
1323
|
+
`bart.prec_scale` is not set, this is set to `move_counts`.
|
|
1324
|
+
prelkv
|
|
1325
|
+
prelk
|
|
1326
|
+
prelf
|
|
1327
|
+
Objects with pre-computed terms of the likelihood ratios and leaf
|
|
1328
|
+
samples.
|
|
1329
|
+
"""
|
|
1330
|
+
|
|
1331
|
+
bart: State
|
|
1332
|
+
moves: Moves
|
|
1333
|
+
prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
|
|
1334
|
+
move_counts: Counts | None
|
|
1335
|
+
move_precs: Precs | Counts
|
|
1336
|
+
prelkv: PreLkV
|
|
1337
|
+
prelk: PreLk
|
|
1338
|
+
prelf: PreLf
|
|
1339
|
+
|
|
1340
|
+
|
|
1341
|
+
def accept_moves_parallel_stage(
|
|
1342
|
+
key: Key[Array, ''], bart: State, moves: Moves
|
|
1343
|
+
) -> ParallelStageOut:
|
|
969
1344
|
"""
|
|
970
1345
|
Pre-computes quantities used to accept moves, in parallel across trees.
|
|
971
1346
|
|
|
@@ -976,88 +1351,110 @@ def accept_moves_parallel_stage(key, bart, moves):
|
|
|
976
1351
|
bart : dict
|
|
977
1352
|
A BART mcmc state.
|
|
978
1353
|
moves : dict
|
|
979
|
-
The proposed moves, see `
|
|
1354
|
+
The proposed moves, see `propose_moves`.
|
|
980
1355
|
|
|
981
1356
|
Returns
|
|
982
1357
|
-------
|
|
983
|
-
|
|
984
|
-
A partially updated BART mcmc state.
|
|
985
|
-
moves : dict
|
|
986
|
-
The proposed moves, with the field 'partial_ratio' replaced
|
|
987
|
-
by 'log_trans_prior_ratio'.
|
|
988
|
-
prec_trees : float array (num_trees, 2 ** d)
|
|
989
|
-
The likelihood precision scale in each potential or actual leaf node. If
|
|
990
|
-
there is no precision scale, this is the number of points in each leaf.
|
|
991
|
-
move_counts : dict
|
|
992
|
-
The counts of the number of points in the the nodes modified by the
|
|
993
|
-
moves.
|
|
994
|
-
move_precs : dict
|
|
995
|
-
The likelihood precision scale in each node modified by the moves.
|
|
996
|
-
prelkv, prelk, prelf : dict
|
|
997
|
-
Dictionary with pre-computed terms of the likelihood ratios and leaf
|
|
998
|
-
samples.
|
|
1358
|
+
An object with all that could be done in parallel.
|
|
999
1359
|
"""
|
|
1000
|
-
bart = bart.copy()
|
|
1001
|
-
|
|
1002
1360
|
# where the move is grow, modify the state like the move was accepted
|
|
1003
|
-
bart
|
|
1004
|
-
|
|
1005
|
-
|
|
1361
|
+
bart = replace(
|
|
1362
|
+
bart,
|
|
1363
|
+
forest=replace(
|
|
1364
|
+
bart.forest,
|
|
1365
|
+
var_trees=moves.var_trees,
|
|
1366
|
+
leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
|
|
1367
|
+
leaf_trees=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_trees, moves),
|
|
1368
|
+
),
|
|
1369
|
+
)
|
|
1006
1370
|
|
|
1007
1371
|
# count number of datapoints per leaf
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1372
|
+
if bart.forest.min_points_per_leaf is not None or bart.prec_scale is None:
|
|
1373
|
+
count_trees, move_counts = compute_count_trees(
|
|
1374
|
+
bart.forest.leaf_indices, moves, bart.forest.count_batch_size
|
|
1375
|
+
)
|
|
1376
|
+
else:
|
|
1377
|
+
# move_counts is passed later to a function, but then is unused under
|
|
1378
|
+
# this condition
|
|
1379
|
+
move_counts = None
|
|
1380
|
+
|
|
1381
|
+
# Check if some nodes can't surely be grown because they don't have enough
|
|
1382
|
+
# datapoints. This check is not actually used now, it will be used at the
|
|
1383
|
+
# beginning of the next step to propose moves.
|
|
1384
|
+
if bart.forest.min_points_per_leaf is not None:
|
|
1385
|
+
count_half_trees = count_trees[:, : bart.forest.var_trees.shape[1]]
|
|
1386
|
+
bart = replace(
|
|
1387
|
+
bart,
|
|
1388
|
+
forest=replace(
|
|
1389
|
+
bart.forest,
|
|
1390
|
+
affluence_trees=count_half_trees >= 2 * bart.forest.min_points_per_leaf,
|
|
1391
|
+
),
|
|
1392
|
+
)
|
|
1014
1393
|
|
|
1015
1394
|
# count number of datapoints per leaf, weighted by error precision scale
|
|
1016
|
-
if bart
|
|
1395
|
+
if bart.prec_scale is None:
|
|
1017
1396
|
prec_trees = count_trees
|
|
1018
1397
|
move_precs = move_counts
|
|
1019
1398
|
else:
|
|
1020
1399
|
prec_trees, move_precs = compute_prec_trees(
|
|
1021
|
-
bart
|
|
1022
|
-
bart
|
|
1400
|
+
bart.prec_scale,
|
|
1401
|
+
bart.forest.leaf_indices,
|
|
1023
1402
|
moves,
|
|
1024
|
-
bart
|
|
1403
|
+
bart.forest.count_batch_size,
|
|
1025
1404
|
)
|
|
1026
1405
|
|
|
1027
1406
|
# compute some missing information about moves
|
|
1028
|
-
moves = complete_ratio(moves, move_counts, bart
|
|
1029
|
-
bart
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1407
|
+
moves = complete_ratio(moves, move_counts, bart.forest.min_points_per_leaf)
|
|
1408
|
+
bart = replace(
|
|
1409
|
+
bart,
|
|
1410
|
+
forest=replace(
|
|
1411
|
+
bart.forest,
|
|
1412
|
+
grow_prop_count=jnp.sum(moves.grow),
|
|
1413
|
+
prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
|
|
1414
|
+
),
|
|
1415
|
+
)
|
|
1034
1416
|
|
|
1035
|
-
|
|
1417
|
+
prelkv, prelk = precompute_likelihood_terms(
|
|
1418
|
+
bart.sigma2, bart.forest.sigma_mu2, move_precs
|
|
1419
|
+
)
|
|
1420
|
+
prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)
|
|
1421
|
+
|
|
1422
|
+
return ParallelStageOut(
|
|
1423
|
+
bart=bart,
|
|
1424
|
+
moves=moves,
|
|
1425
|
+
prec_trees=prec_trees,
|
|
1426
|
+
move_counts=move_counts,
|
|
1427
|
+
move_precs=move_precs,
|
|
1428
|
+
prelkv=prelkv,
|
|
1429
|
+
prelk=prelk,
|
|
1430
|
+
prelf=prelf,
|
|
1431
|
+
)
|
|
1036
1432
|
|
|
1037
1433
|
|
|
1038
|
-
@
|
|
1039
|
-
def apply_grow_to_indices(
|
|
1434
|
+
@partial(vmap_nodoc, in_axes=(0, 0, None))
|
|
1435
|
+
def apply_grow_to_indices(
|
|
1436
|
+
moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
|
|
1437
|
+
) -> UInt[Array, 'num_trees n']:
|
|
1040
1438
|
"""
|
|
1041
1439
|
Update the leaf indices to apply a grow move.
|
|
1042
1440
|
|
|
1043
1441
|
Parameters
|
|
1044
1442
|
----------
|
|
1045
|
-
moves
|
|
1046
|
-
The proposed moves, see `
|
|
1047
|
-
leaf_indices
|
|
1443
|
+
moves
|
|
1444
|
+
The proposed moves, see `propose_moves`.
|
|
1445
|
+
leaf_indices
|
|
1048
1446
|
The index of the leaf each datapoint falls into.
|
|
1049
|
-
X
|
|
1447
|
+
X
|
|
1050
1448
|
The predictors matrix.
|
|
1051
1449
|
|
|
1052
1450
|
Returns
|
|
1053
1451
|
-------
|
|
1054
|
-
|
|
1055
|
-
The updated leaf indices.
|
|
1452
|
+
The updated leaf indices.
|
|
1056
1453
|
"""
|
|
1057
|
-
left_child = moves
|
|
1058
|
-
go_right = X[moves
|
|
1059
|
-
tree_size = jnp.array(2 * moves
|
|
1060
|
-
node_to_update = jnp.where(moves
|
|
1454
|
+
left_child = moves.node.astype(leaf_indices.dtype) << 1
|
|
1455
|
+
go_right = X[moves.grow_var, :] >= moves.grow_split
|
|
1456
|
+
tree_size = jnp.array(2 * moves.var_trees.size)
|
|
1457
|
+
node_to_update = jnp.where(moves.grow, moves.node, tree_size)
|
|
1061
1458
|
return jnp.where(
|
|
1062
1459
|
leaf_indices == node_to_update,
|
|
1063
1460
|
left_child + go_right,
|
|
@@ -1065,64 +1462,65 @@ def apply_grow_to_indices(moves, leaf_indices, X):
|
|
|
1065
1462
|
)
|
|
1066
1463
|
|
|
1067
1464
|
|
|
1068
|
-
def compute_count_trees(
|
|
1465
|
+
def compute_count_trees(
|
|
1466
|
+
leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
|
|
1467
|
+
) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
|
|
1069
1468
|
"""
|
|
1070
1469
|
Count the number of datapoints in each leaf.
|
|
1071
1470
|
|
|
1072
1471
|
Parameters
|
|
1073
1472
|
----------
|
|
1074
|
-
leaf_indices
|
|
1473
|
+
leaf_indices
|
|
1075
1474
|
The index of the leaf each datapoint falls into, with the deeper version
|
|
1076
1475
|
of the tree (post-GROW, pre-PRUNE).
|
|
1077
|
-
moves
|
|
1078
|
-
The proposed moves, see `
|
|
1079
|
-
batch_size
|
|
1476
|
+
moves
|
|
1477
|
+
The proposed moves, see `propose_moves`.
|
|
1478
|
+
batch_size
|
|
1080
1479
|
The data batch size to use for the summation.
|
|
1081
1480
|
|
|
1082
1481
|
Returns
|
|
1083
1482
|
-------
|
|
1084
|
-
count_trees :
|
|
1483
|
+
count_trees : Int32[Array, 'num_trees 2**d']
|
|
1085
1484
|
The number of points in each potential or actual leaf node.
|
|
1086
|
-
counts :
|
|
1485
|
+
counts : Counts
|
|
1087
1486
|
The counts of the number of points in the leaves grown or pruned by the
|
|
1088
|
-
moves
|
|
1487
|
+
moves.
|
|
1089
1488
|
"""
|
|
1090
|
-
|
|
1091
|
-
ntree, tree_size = moves['var_trees'].shape
|
|
1489
|
+
num_trees, tree_size = moves.var_trees.shape
|
|
1092
1490
|
tree_size *= 2
|
|
1093
|
-
tree_indices = jnp.arange(
|
|
1491
|
+
tree_indices = jnp.arange(num_trees)
|
|
1094
1492
|
|
|
1095
1493
|
count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
|
|
1096
1494
|
|
|
1097
1495
|
# count datapoints in nodes modified by move
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
counts
|
|
1101
|
-
counts['total'] = counts['left'] + counts['right']
|
|
1496
|
+
left = count_trees[tree_indices, moves.left]
|
|
1497
|
+
right = count_trees[tree_indices, moves.right]
|
|
1498
|
+
counts = Counts(left=left, right=right, total=left + right)
|
|
1102
1499
|
|
|
1103
1500
|
# write count into non-leaf node
|
|
1104
|
-
count_trees = count_trees.at[tree_indices, moves
|
|
1501
|
+
count_trees = count_trees.at[tree_indices, moves.node].set(counts.total)
|
|
1105
1502
|
|
|
1106
1503
|
return count_trees, counts
|
|
1107
1504
|
|
|
1108
1505
|
|
|
1109
|
-
def count_datapoints_per_leaf(
|
|
1506
|
+
def count_datapoints_per_leaf(
|
|
1507
|
+
leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
|
|
1508
|
+
) -> Int32[Array, 'num_trees 2**(d-1)']:
|
|
1110
1509
|
"""
|
|
1111
1510
|
Count the number of datapoints in each leaf.
|
|
1112
1511
|
|
|
1113
1512
|
Parameters
|
|
1114
1513
|
----------
|
|
1115
|
-
leaf_indices
|
|
1514
|
+
leaf_indices
|
|
1116
1515
|
The index of the leaf each datapoint falls into.
|
|
1117
|
-
tree_size
|
|
1516
|
+
tree_size
|
|
1118
1517
|
The size of the leaf tree array (2 ** d).
|
|
1119
|
-
batch_size
|
|
1518
|
+
batch_size
|
|
1120
1519
|
The data batch size to use for the summation.
|
|
1121
1520
|
|
|
1122
1521
|
Returns
|
|
1123
1522
|
-------
|
|
1124
|
-
|
|
1125
|
-
The number of points in each leaf node.
|
|
1523
|
+
The number of points in each leaf node.
|
|
1126
1524
|
"""
|
|
1127
1525
|
if batch_size is None:
|
|
1128
1526
|
return _count_scan(leaf_indices, tree_size)
|
|
@@ -1130,7 +1528,9 @@ def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
|
|
|
1130
1528
|
return _count_vec(leaf_indices, tree_size, batch_size)
|
|
1131
1529
|
|
|
1132
1530
|
|
|
1133
|
-
def _count_scan(
|
|
1531
|
+
def _count_scan(
|
|
1532
|
+
leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
|
|
1533
|
+
) -> Int32[Array, 'num_trees {tree_size}']:
|
|
1134
1534
|
def loop(_, leaf_indices):
|
|
1135
1535
|
return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
|
|
1136
1536
|
|
|
@@ -1138,92 +1538,111 @@ def _count_scan(leaf_indices, tree_size):
|
|
|
1138
1538
|
return count_trees
|
|
1139
1539
|
|
|
1140
1540
|
|
|
1141
|
-
def _aggregate_scatter(
|
|
1541
|
+
def _aggregate_scatter(
|
|
1542
|
+
values: Shaped[Array, '*'],
|
|
1543
|
+
indices: Integer[Array, '*'],
|
|
1544
|
+
size: int,
|
|
1545
|
+
dtype: jnp.dtype,
|
|
1546
|
+
) -> Shaped[Array, '{size}']:
|
|
1142
1547
|
return jnp.zeros(size, dtype).at[indices].add(values)
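
A minimal sketch (made-up indices) of the scatter-add used here to count datapoints per leaf: every datapoint adds 1 (or its precision scale) at the heap index of the leaf it falls into:

    import jax.numpy as jnp

    leaf_indices = jnp.array([3, 3, 5, 2, 3], dtype=jnp.uint32)
    tree_size = 8
    counts = jnp.zeros(tree_size, jnp.uint32).at[leaf_indices].add(1)
    # counts == [0, 0, 1, 3, 0, 1, 0, 0]
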
|
|
1143
1548
|
|
|
1144
1549
|
|
|
1145
|
-
def _count_vec(
|
|
1550
|
+
def _count_vec(
|
|
1551
|
+
leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
|
|
1552
|
+
) -> Int32[Array, 'num_trees 2**(d-1)']:
|
|
1146
1553
|
return _aggregate_batched_alltrees(
|
|
1147
1554
|
1, leaf_indices, tree_size, jnp.uint32, batch_size
|
|
1148
1555
|
)
|
|
1149
1556
|
# uint16 is super-slow on gpu, don't use it even if n < 2^16
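
The batched variants follow the same idea but scatter into `nbatches` interleaved columns and sum them at the end, which exposes more parallelism and reduces floating-point accumulation error. A single-tree sketch (hypothetical sizes) of that pattern:

    import jax.numpy as jnp

    values = jnp.ones(10)
    indices = jnp.array([0, 1, 1, 2, 0, 2, 2, 1, 0, 2])
    size, batch_size = 4, 4
    n = indices.size
    nbatches = n // batch_size + bool(n % batch_size)   # 3
    batch_indices = jnp.arange(n) % nbatches
    out = (
        jnp.zeros((size, nbatches))
        .at[indices, batch_indices]
        .add(values)
        .sum(axis=1)
    )
    # out == [3., 3., 4., 0.], same totals as an unbatched scatter-add
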
|
|
1150
1557
|
|
|
1151
1558
|
|
|
1152
|
-
def _aggregate_batched_alltrees(
|
|
1153
|
-
|
|
1154
|
-
|
|
1559
|
+
def _aggregate_batched_alltrees(
|
|
1560
|
+
values: Shaped[Array, '*'],
|
|
1561
|
+
indices: UInt[Array, 'num_trees n'],
|
|
1562
|
+
size: int,
|
|
1563
|
+
dtype: jnp.dtype,
|
|
1564
|
+
batch_size: int,
|
|
1565
|
+
) -> Shaped[Array, 'num_trees {size}']:
|
|
1566
|
+
num_trees, n = indices.shape
|
|
1567
|
+
tree_indices = jnp.arange(num_trees)
|
|
1155
1568
|
nbatches = n // batch_size + bool(n % batch_size)
|
|
1156
1569
|
batch_indices = jnp.arange(n) % nbatches
|
|
1157
1570
|
return (
|
|
1158
|
-
jnp.zeros((
|
|
1571
|
+
jnp.zeros((num_trees, size, nbatches), dtype)
|
|
1159
1572
|
.at[tree_indices[:, None], indices, batch_indices]
|
|
1160
1573
|
.add(values)
|
|
1161
1574
|
.sum(axis=2)
|
|
1162
1575
|
)
|
|
1163
1576
|
|
|
1164
1577
|
|
|
1165
|
-
def compute_prec_trees(
|
|
1578
|
+
def compute_prec_trees(
|
|
1579
|
+
prec_scale: Float32[Array, 'n'],
|
|
1580
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1581
|
+
moves: Moves,
|
|
1582
|
+
batch_size: int | None,
|
|
1583
|
+
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
|
|
1166
1584
|
"""
|
|
1167
1585
|
Compute the likelihood precision scale in each leaf.
|
|
1168
1586
|
|
|
1169
1587
|
Parameters
|
|
1170
1588
|
----------
|
|
1171
|
-
prec_scale
|
|
1589
|
+
prec_scale
|
|
1172
1590
|
The scale of the precision of the error on each datapoint.
|
|
1173
|
-
leaf_indices
|
|
1591
|
+
leaf_indices
|
|
1174
1592
|
The index of the leaf each datapoint falls into, with the deeper version
|
|
1175
1593
|
of the tree (post-GROW, pre-PRUNE).
|
|
1176
|
-
moves
|
|
1177
|
-
The proposed moves, see `
|
|
1178
|
-
batch_size
|
|
1594
|
+
moves
|
|
1595
|
+
The proposed moves, see `propose_moves`.
|
|
1596
|
+
batch_size
|
|
1179
1597
|
The data batch size to use for the summation.
|
|
1180
1598
|
|
|
1181
1599
|
Returns
|
|
1182
1600
|
-------
|
|
1183
|
-
prec_trees :
|
|
1601
|
+
prec_trees : Float32[Array, 'num_trees 2**d']
|
|
1184
1602
|
The likelihood precision scale in each potential or actual leaf node.
|
|
1185
|
-
|
|
1186
|
-
The likelihood precision scale in the
|
|
1187
|
-
moves, under keys 'left', 'right', and 'total' (left + right).
|
|
1603
|
+
precs : Precs
|
|
1604
|
+
The likelihood precision scale in the nodes involved in the moves.
|
|
1188
1605
|
"""
|
|
1189
|
-
|
|
1190
|
-
ntree, tree_size = moves['var_trees'].shape
|
|
1606
|
+
num_trees, tree_size = moves.var_trees.shape
|
|
1191
1607
|
tree_size *= 2
|
|
1192
|
-
tree_indices = jnp.arange(
|
|
1608
|
+
tree_indices = jnp.arange(num_trees)
|
|
1193
1609
|
|
|
1194
1610
|
prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)
|
|
1195
1611
|
|
|
1196
1612
|
# prec datapoints in nodes modified by move
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
precs
|
|
1200
|
-
precs['total'] = precs['left'] + precs['right']
|
|
1613
|
+
left = prec_trees[tree_indices, moves.left]
|
|
1614
|
+
right = prec_trees[tree_indices, moves.right]
|
|
1615
|
+
precs = Precs(left=left, right=right, total=left + right)
|
|
1201
1616
|
|
|
1202
1617
|
# write prec into non-leaf node
|
|
1203
|
-
prec_trees = prec_trees.at[tree_indices, moves
|
|
1618
|
+
prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total)
|
|
1204
1619
|
|
|
1205
1620
|
return prec_trees, precs
|
|
1206
1621
|
|
|
1207
1622
|
|
|
1208
|
-
def prec_per_leaf(
|
|
1623
|
+
def prec_per_leaf(
|
|
1624
|
+
prec_scale: Float32[Array, 'n'],
|
|
1625
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1626
|
+
tree_size: int,
|
|
1627
|
+
batch_size: int | None,
|
|
1628
|
+
) -> Float32[Array, 'num_trees {tree_size}']:
|
|
1209
1629
|
"""
|
|
1210
1630
|
Compute the likelihood precision scale in each leaf.
|
|
1211
1631
|
|
|
1212
1632
|
Parameters
|
|
1213
1633
|
----------
|
|
1214
|
-
prec_scale
|
|
1634
|
+
prec_scale
|
|
1215
1635
|
The scale of the precision of the error on each datapoint.
|
|
1216
|
-
leaf_indices
|
|
1636
|
+
leaf_indices
|
|
1217
1637
|
The index of the leaf each datapoint falls into.
|
|
1218
|
-
tree_size
|
|
1638
|
+
tree_size
|
|
1219
1639
|
The size of the leaf tree array (2 ** d).
|
|
1220
|
-
batch_size
|
|
1640
|
+
batch_size
|
|
1221
1641
|
The data batch size to use for the summation.
|
|
1222
1642
|
|
|
1223
1643
|
Returns
|
|
1224
1644
|
-------
|
|
1225
|
-
|
|
1226
|
-
The likelihood precision scale in each leaf node.
|
|
1645
|
+
The likelihood precision scale in each leaf node.
|
|
1227
1646
|
"""
|
|
1228
1647
|
if batch_size is None:
|
|
1229
1648
|
return _prec_scan(prec_scale, leaf_indices, tree_size)
|
|
@@ -1231,23 +1650,34 @@ def prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size):
|
|
|
1231
1650
|
return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
|
|
1232
1651
|
|
|
1233
1652
|
|
|
1234
|
-
def _prec_scan(
|
|
1653
|
+
def _prec_scan(
|
|
1654
|
+
prec_scale: Float32[Array, 'n'],
|
|
1655
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1656
|
+
tree_size: int,
|
|
1657
|
+
) -> Float32[Array, 'num_trees {tree_size}']:
|
|
1235
1658
|
def loop(_, leaf_indices):
|
|
1236
1659
|
return None, _aggregate_scatter(
|
|
1237
1660
|
prec_scale, leaf_indices, tree_size, jnp.float32
|
|
1238
|
-
)
|
|
1661
|
+
)
|
|
1239
1662
|
|
|
1240
1663
|
_, prec_trees = lax.scan(loop, None, leaf_indices)
|
|
1241
1664
|
return prec_trees
|
|
1242
1665
|
|
|
1243
1666
|
|
|
1244
|
-
def _prec_vec(
|
|
1667
|
+
def _prec_vec(
|
|
1668
|
+
prec_scale: Float32[Array, 'n'],
|
|
1669
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1670
|
+
tree_size: int,
|
|
1671
|
+
batch_size: int,
|
|
1672
|
+
) -> Float32[Array, 'num_trees {tree_size}']:
|
|
1245
1673
|
return _aggregate_batched_alltrees(
|
|
1246
1674
|
prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
|
|
1247
|
-
)
|
|
1675
|
+
)
|
|
1248
1676
|
|
|
1249
1677
|
|
|
1250
|
-
def complete_ratio(
|
|
1678
|
+
def complete_ratio(
|
|
1679
|
+
moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
|
|
1680
|
+
) -> Moves:
|
|
1251
1681
|
"""
|
|
1252
1682
|
Complete non-likelihood MH ratio calculation.
|
|
1253
1683
|
|
|
@@ -1255,330 +1685,367 @@ def complete_ratio(moves, move_counts, min_points_per_leaf):
|
|
|
1255
1685
|
|
|
1256
1686
|
Parameters
|
|
1257
1687
|
----------
|
|
1258
|
-
moves
|
|
1259
|
-
The proposed moves, see `
|
|
1260
|
-
move_counts
|
|
1688
|
+
moves
|
|
1689
|
+
The proposed moves, see `propose_moves`.
|
|
1690
|
+
move_counts
|
|
1261
1691
|
The counts of the number of points in the nodes modified by the
|
|
1262
1692
|
moves.
|
|
1263
|
-
min_points_per_leaf
|
|
1693
|
+
min_points_per_leaf
|
|
1264
1694
|
The minimum number of data points in a leaf node.
|
|
1265
1695
|
|
|
1266
1696
|
Returns
|
|
1267
1697
|
-------
|
|
1268
|
-
moves
|
|
1269
|
-
The updated moves, with the field 'partial_ratio' replaced by
|
|
1270
|
-
'log_trans_prior_ratio'.
|
|
1698
|
+
The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
|
|
1271
1699
|
"""
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
moves,
|
|
1700
|
+
p_prune = compute_p_prune(moves, move_counts, min_points_per_leaf)
|
|
1701
|
+
return replace(
|
|
1702
|
+
moves,
|
|
1703
|
+
log_trans_prior_ratio=jnp.log(moves.partial_ratio * p_prune),
|
|
1704
|
+
partial_ratio=None,
|
|
1275
1705
|
)
|
|
1276
|
-
moves['log_trans_prior_ratio'] = jnp.log(moves.pop('partial_ratio') * p_prune)
|
|
1277
|
-
return moves
|
|
1278
1706
|
|
|
1279
1707
|
|
|
1280
|
-
def compute_p_prune(
|
|
1708
|
+
def compute_p_prune(
|
|
1709
|
+
moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
|
|
1710
|
+
) -> Float32[Array, 'num_trees']:
|
|
1281
1711
|
"""
|
|
1282
|
-
Compute the probability of proposing a prune move.
|
|
1712
|
+
Compute the probability of proposing a prune move for each tree.
|
|
1283
1713
|
|
|
1284
1714
|
Parameters
|
|
1285
1715
|
----------
|
|
1286
|
-
moves
|
|
1287
|
-
The proposed moves, see `
|
|
1288
|
-
|
|
1716
|
+
moves
|
|
1717
|
+
The proposed moves, see `propose_moves`.
|
|
1718
|
+
move_counts
|
|
1289
1719
|
The number of datapoints in the proposed children of the leaf to grow.
|
|
1290
|
-
|
|
1720
|
+
Not used if `min_points_per_leaf` is `None`.
|
|
1721
|
+
min_points_per_leaf
|
|
1291
1722
|
The minimum number of data points in a leaf node.
|
|
1292
1723
|
|
|
1293
1724
|
Returns
|
|
1294
1725
|
-------
|
|
1295
|
-
|
|
1296
|
-
The probability of proposing a prune move. If grow: after accepting the
|
|
1297
|
-
grow move, if prune: right away.
|
|
1298
|
-
"""
|
|
1726
|
+
The probability of proposing a prune move.
|
|
1299
1727
|
|
|
1728
|
+
Notes
|
|
1729
|
+
-----
|
|
1730
|
+
This probability is computed for going from the state with the deeper tree
|
|
1731
|
+
to the one with the shallower one. This means, if grow: after accepting the
|
|
1732
|
+
grow move, if prune: right away.
|
|
1733
|
+
"""
|
|
1300
1734
|
# calculation in case the move is grow
|
|
1301
|
-
other_growable_leaves = moves
|
|
1302
|
-
new_leaves_growable = moves
|
|
1735
|
+
other_growable_leaves = moves.num_growable >= 2
|
|
1736
|
+
new_leaves_growable = moves.node < moves.var_trees.shape[1] // 2
|
|
1303
1737
|
if min_points_per_leaf is not None:
|
|
1304
|
-
|
|
1305
|
-
any_above_threshold
|
|
1738
|
+
assert move_counts is not None
|
|
1739
|
+
any_above_threshold = move_counts.left >= 2 * min_points_per_leaf
|
|
1740
|
+
any_above_threshold |= move_counts.right >= 2 * min_points_per_leaf
|
|
1306
1741
|
new_leaves_growable &= any_above_threshold
|
|
1307
1742
|
grow_again_allowed = other_growable_leaves | new_leaves_growable
|
|
1308
1743
|
grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
|
|
1309
1744
|
|
|
1310
1745
|
# calculation in case the move is prune
|
|
1311
|
-
prune_p_prune = jnp.where(moves
|
|
1746
|
+
prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)
|
|
1312
1747
|
|
|
1313
|
-
return jnp.where(moves
|
|
1748
|
+
return jnp.where(moves.grow, grow_p_prune, prune_p_prune)
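
A small worked sketch (hypothetical flags) of the logic above: the prune-proposal probability is evaluated on the deeper of the two trees involved in the move, so it is 1/2 whenever that tree still admits a grow proposal and 1 otherwise:

    import jax.numpy as jnp

    grow = jnp.array([True, True, False])
    num_growable = jnp.array([1, 1, 0])              # in the shallower tree
    new_children_growable = jnp.array([True, False, False])
    grow_again = (num_growable >= 2) | new_children_growable
    grow_p_prune = jnp.where(grow_again, 0.5, 1.0)     # after accepting the grow
    prune_p_prune = jnp.where(num_growable, 0.5, 1.0)  # on the current tree
    p_prune = jnp.where(grow, grow_p_prune, prune_p_prune)   # [0.5, 1.0, 1.0]
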
|
|
1314
1749
|
|
|
1315
1750
|
|
|
1316
|
-
@
|
|
1317
|
-
def adapt_leaf_trees_to_grow_indices(
|
|
1751
|
+
@vmap_nodoc
|
|
1752
|
+
def adapt_leaf_trees_to_grow_indices(
|
|
1753
|
+
leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
|
|
1754
|
+
) -> Float32[Array, 'num_trees 2**d']:
|
|
1318
1755
|
"""
|
|
1319
|
-
Modify
|
|
1320
|
-
|
|
1756
|
+
Modify leaves such that post-grow indices work on the original tree.
|
|
1757
|
+
|
|
1758
|
+
The value of the leaf to grow is copied to what would be its children if the
|
|
1759
|
+
grow move was accepted.
|
|
1321
1760
|
|
|
1322
1761
|
Parameters
|
|
1323
1762
|
----------
|
|
1324
|
-
leaf_trees
|
|
1763
|
+
leaf_trees
|
|
1325
1764
|
The leaf values.
|
|
1326
|
-
moves
|
|
1327
|
-
The proposed moves, see `
|
|
1765
|
+
moves
|
|
1766
|
+
The proposed moves, see `propose_moves`.
|
|
1328
1767
|
|
|
1329
1768
|
Returns
|
|
1330
1769
|
-------
|
|
1331
|
-
|
|
1332
|
-
The modified leaf values. The value of the leaf to grow is copied to
|
|
1333
|
-
what would be its children if the grow move was accepted.
|
|
1770
|
+
The modified leaf values.
|
|
1334
1771
|
"""
|
|
1335
|
-
values_at_node = leaf_trees[moves
|
|
1772
|
+
values_at_node = leaf_trees[moves.node]
|
|
1336
1773
|
return (
|
|
1337
|
-
leaf_trees.at[jnp.where(moves
|
|
1774
|
+
leaf_trees.at[jnp.where(moves.grow, moves.left, leaf_trees.size)]
|
|
1338
1775
|
.set(values_at_node)
|
|
1339
|
-
.at[jnp.where(moves
|
|
1776
|
+
.at[jnp.where(moves.grow, moves.right, leaf_trees.size)]
|
|
1340
1777
|
.set(values_at_node)
|
|
1341
1778
|
)
|
|
1342
1779
|
|
|
1343
1780
|
|
|
1344
|
-
def precompute_likelihood_terms(
|
|
1781
|
+
def precompute_likelihood_terms(
|
|
1782
|
+
sigma2: Float32[Array, ''],
|
|
1783
|
+
sigma_mu2: Float32[Array, ''],
|
|
1784
|
+
move_precs: Precs | Counts,
|
|
1785
|
+
) -> tuple[PreLkV, PreLk]:
|
|
1345
1786
|
"""
|
|
1346
1787
|
Pre-compute terms used in the likelihood ratio of the acceptance step.
|
|
1347
1788
|
|
|
1348
1789
|
Parameters
|
|
1349
1790
|
----------
|
|
1350
|
-
sigma2
|
|
1351
|
-
The
|
|
1352
|
-
|
|
1791
|
+
sigma2
|
|
1792
|
+
The error variance, or the global error variance factor if `prec_scale`
|
|
1793
|
+
is set.
|
|
1794
|
+
sigma_mu2
|
|
1795
|
+
The prior variance of each leaf.
|
|
1796
|
+
move_precs
|
|
1353
1797
|
The likelihood precision scale in the leaves grown or pruned by the
|
|
1354
1798
|
moves, under keys 'left', 'right', and 'total' (left + right).
|
|
1355
1799
|
|
|
1356
1800
|
Returns
|
|
1357
1801
|
-------
|
|
1358
|
-
prelkv :
|
|
1802
|
+
prelkv : PreLkV
|
|
1359
1803
|
Dictionary with pre-computed terms of the likelihood ratio, one per
|
|
1360
1804
|
tree.
|
|
1361
|
-
prelk :
|
|
1805
|
+
prelk : PreLk
|
|
1362
1806
|
Dictionary with pre-computed terms of the likelihood ratio, shared by
|
|
1363
1807
|
all trees.
|
|
1364
1808
|
"""
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
|
|
1368
|
-
prelkv
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
jnp.log(
|
|
1373
|
-
sigma2
|
|
1374
|
-
* prelkv['sigma2_total']
|
|
1375
|
-
/ (prelkv['sigma2_left'] * prelkv['sigma2_right'])
|
|
1376
|
-
)
|
|
1377
|
-
/ 2
|
|
1809
|
+
sigma2_left = sigma2 + move_precs.left * sigma_mu2
|
|
1810
|
+
sigma2_right = sigma2 + move_precs.right * sigma_mu2
|
|
1811
|
+
sigma2_total = sigma2 + move_precs.total * sigma_mu2
|
|
1812
|
+
prelkv = PreLkV(
|
|
1813
|
+
sigma2_left=sigma2_left,
|
|
1814
|
+
sigma2_right=sigma2_right,
|
|
1815
|
+
sigma2_total=sigma2_total,
|
|
1816
|
+
sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
|
|
1378
1817
|
)
|
|
1379
|
-
return prelkv,
|
|
1818
|
+
return prelkv, PreLk(
|
|
1380
1819
|
exp_factor=sigma_mu2 / (2 * sigma2),
|
|
1381
1820
|
)
|
|
1382
1821
|
|
|
1383
1822
|
|
|
1384
|
-
def precompute_leaf_terms(
|
|
1823
|
+
def precompute_leaf_terms(
|
|
1824
|
+
key: Key[Array, ''],
|
|
1825
|
+
prec_trees: Float32[Array, 'num_trees 2**d'],
|
|
1826
|
+
sigma2: Float32[Array, ''],
|
|
1827
|
+
sigma_mu2: Float32[Array, ''],
|
|
1828
|
+
) -> PreLf:
|
|
1385
1829
|
"""
|
|
1386
1830
|
Pre-compute terms used to sample leaves from their posterior.
|
|
1387
1831
|
|
|
1388
1832
|
Parameters
|
|
1389
1833
|
----------
|
|
1390
|
-
key
|
|
1834
|
+
key
|
|
1391
1835
|
A jax random key.
|
|
1392
|
-
prec_trees
|
|
1836
|
+
prec_trees
|
|
1393
1837
|
The likelihood precision scale in each potential or actual leaf node.
|
|
1394
|
-
sigma2
|
|
1395
|
-
The
|
|
1838
|
+
sigma2
|
|
1839
|
+
The error variance, or the global error variance factor if `prec_scale`
|
|
1840
|
+
is set.
|
|
1841
|
+
sigma_mu2
|
|
1842
|
+
The prior variance of each leaf.
|
|
1396
1843
|
|
|
1397
1844
|
Returns
|
|
1398
1845
|
-------
|
|
1399
|
-
|
|
1400
|
-
Dictionary with pre-computed terms of the leaf sampling, with fields:
|
|
1401
|
-
|
|
1402
|
-
'mean_factor' : float array (num_trees, 2 ** d)
|
|
1403
|
-
The factor to be multiplied by the sum of the scaled residuals to
|
|
1404
|
-
obtain the posterior mean.
|
|
1405
|
-
'centered_leaves' : float array (num_trees, 2 ** d)
|
|
1406
|
-
The mean-zero normal values to be added to the posterior mean to
|
|
1407
|
-
obtain the posterior leaf samples.
|
|
1846
|
+
Pre-computed terms for leaf sampling.
|
|
1408
1847
|
"""
|
|
1409
|
-
ntree = len(prec_trees)
|
|
1410
1848
|
prec_lk = prec_trees / sigma2
|
|
1411
|
-
|
|
1849
|
+
prec_prior = lax.reciprocal(sigma_mu2)
|
|
1850
|
+
var_post = lax.reciprocal(prec_lk + prec_prior)
|
|
1412
1851
|
z = random.normal(key, prec_trees.shape, sigma2.dtype)
|
|
1413
|
-
return
|
|
1414
|
-
mean_factor=var_post / sigma2,
|
|
1852
|
+
return PreLf(
|
|
1853
|
+
mean_factor=var_post / sigma2,
|
|
1854
|
+
# mean = mean_lk * prec_lk * var_post
|
|
1855
|
+
# resid_tree = mean_lk * prec_tree -->
|
|
1856
|
+
# --> mean_lk = resid_tree / prec_tree (kind of)
|
|
1857
|
+
# mean_factor =
|
|
1858
|
+
# = mean / resid_tree =
|
|
1859
|
+
# = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
|
|
1860
|
+
# = 1 / prec_tree * prec_tree / sigma2 * var_post =
|
|
1861
|
+
# = var_post / sigma2
|
|
1415
1862
|
centered_leaves=z * jnp.sqrt(var_post),
|
|
1416
1863
|
)
|
|
1417
1864
|
|
|
1418
1865
|
|
|
1419
|
-
def accept_moves_sequential_stage(
|
|
1420
|
-
bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
|
|
1421
|
-
):
|
|
1866
|
+
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
|
|
1422
1867
|
"""
|
|
1423
|
-
|
|
1868
|
+
Accept/reject the moves one tree at a time.
|
|
1869
|
+
|
|
1870
|
+
This is the most performance-sensitive function because it contains all and
|
|
1871
|
+
only the parts of the algorithm that can not be parallelized across trees.
|
|
1424
1872
|
|
|
1425
1873
|
Parameters
|
|
1426
1874
|
----------
|
|
1427
|
-
|
|
1428
|
-
|
|
1429
|
-
prec_trees : float array (num_trees, 2 ** d)
|
|
1430
|
-
The likelihood precision scale in each potential or actual leaf node.
|
|
1431
|
-
moves : dict
|
|
1432
|
-
The proposed moves, see `sample_moves`.
|
|
1433
|
-
move_counts : dict
|
|
1434
|
-
The counts of the number of points in the the nodes modified by the
|
|
1435
|
-
moves.
|
|
1436
|
-
move_precs : dict
|
|
1437
|
-
The likelihood precision scale in each node modified by the moves.
|
|
1438
|
-
prelkv, prelk, prelf : dict
|
|
1439
|
-
Dictionaries with pre-computed terms of the likelihood ratios and leaf
|
|
1440
|
-
samples.
|
|
1875
|
+
pso
|
|
1876
|
+
The output of `accept_moves_parallel_stage`.
|
|
1441
1877
|
|
|
1442
1878
|
Returns
|
|
1443
1879
|
-------
|
|
1444
|
-
bart :
|
|
1880
|
+
bart : State
|
|
1445
1881
|
A partially updated BART mcmc state.
|
|
1446
|
-
moves :
|
|
1447
|
-
The
|
|
1448
|
-
|
|
1449
|
-
'acc' : bool array (num_trees,)
|
|
1450
|
-
Whether the move was accepted.
|
|
1451
|
-
'to_prune' : bool array (num_trees,)
|
|
1452
|
-
Whether, to reflect the acceptance status of the move, the state
|
|
1453
|
-
should be updated by pruning the leaves involved in the move.
|
|
1882
|
+
moves : Moves
|
|
1883
|
+
The accepted/rejected moves, with `acc` and `to_prune` set.
|
|
1454
1884
|
"""
|
|
1455
|
-
bart = bart.copy()
|
|
1456
|
-
moves = moves.copy()
|
|
1457
1885
|
|
|
1458
|
-
def loop(resid,
|
|
1886
|
+
def loop(resid, pt):
|
|
1459
1887
|
resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
|
|
1460
|
-
bart['X'],
|
|
1461
|
-
len(bart['leaf_trees']),
|
|
1462
|
-
bart['opt']['resid_batch_size'],
|
|
1463
1888
|
resid,
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
|
-
|
|
1468
|
-
|
|
1889
|
+
SeqStageInAllTrees(
|
|
1890
|
+
pso.bart.X,
|
|
1891
|
+
pso.bart.forest.resid_batch_size,
|
|
1892
|
+
pso.bart.prec_scale,
|
|
1893
|
+
pso.bart.forest.min_points_per_leaf,
|
|
1894
|
+
pso.bart.forest.log_likelihood is not None,
|
|
1895
|
+
pso.prelk,
|
|
1896
|
+
),
|
|
1897
|
+
pt,
|
|
1469
1898
|
)
|
|
1470
1899
|
return resid, (leaf_tree, acc, to_prune, ratios)
|
|
1471
1900
|
|
|
1472
|
-
|
|
1473
|
-
bart
|
|
1474
|
-
prec_trees,
|
|
1475
|
-
moves,
|
|
1476
|
-
move_counts,
|
|
1477
|
-
move_precs,
|
|
1478
|
-
bart
|
|
1479
|
-
prelkv,
|
|
1480
|
-
prelf,
|
|
1901
|
+
pts = SeqStageInPerTree(
|
|
1902
|
+
pso.bart.forest.leaf_trees,
|
|
1903
|
+
pso.prec_trees,
|
|
1904
|
+
pso.moves,
|
|
1905
|
+
pso.move_counts,
|
|
1906
|
+
pso.move_precs,
|
|
1907
|
+
pso.bart.forest.leaf_indices,
|
|
1908
|
+
pso.prelkv,
|
|
1909
|
+
pso.prelf,
|
|
1481
1910
|
)
|
|
1482
|
-
resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
bart
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1911
|
+
resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, pso.bart.resid, pts)
|
|
1912
|
+
|
|
1913
|
+
save_ratios = pso.bart.forest.log_likelihood is not None
|
|
1914
|
+
bart = replace(
|
|
1915
|
+
pso.bart,
|
|
1916
|
+
resid=resid,
|
|
1917
|
+
forest=replace(
|
|
1918
|
+
pso.bart.forest,
|
|
1919
|
+
leaf_trees=leaf_trees,
|
|
1920
|
+
log_likelihood=ratios['log_likelihood'] if save_ratios else None,
|
|
1921
|
+
log_trans_prior=ratios['log_trans_prior'] if save_ratios else None,
|
|
1922
|
+
),
|
|
1923
|
+
)
|
|
1924
|
+
moves = replace(pso.moves, acc=acc, to_prune=to_prune)
|
|
1489
1925
|
|
|
1490
1926
|
return bart, moves
|
|
1491
1927
|
|
|
1492
1928
|
|
|
1493
|
-
|
|
1494
|
-
X,
|
|
1495
|
-
ntree,
|
|
1496
|
-
resid_batch_size,
|
|
1497
|
-
resid,
|
|
1498
|
-
prec_scale,
|
|
1499
|
-
min_points_per_leaf,
|
|
1500
|
-
save_ratios,
|
|
1501
|
-
prelk,
|
|
1502
|
-
leaf_tree,
|
|
1503
|
-
prec_tree,
|
|
1504
|
-
move,
|
|
1505
|
-
move_counts,
|
|
1506
|
-
move_precs,
|
|
1507
|
-
leaf_indices,
|
|
1508
|
-
prelkv,
|
|
1509
|
-
prelf,
|
|
1510
|
-
):
|
|
1929
|
+
class SeqStageInAllTrees(Module):
|
|
1511
1930
|
"""
|
|
1512
|
-
|
|
1931
|
+
The inputs to `accept_move_and_sample_leaves` that are the same for all trees.
|
|
1513
1932
|
|
|
1514
1933
|
Parameters
|
|
1515
1934
|
----------
|
|
1516
|
-
X
|
|
1935
|
+
X
|
|
1517
1936
|
The predictors.
|
|
1518
|
-
|
|
1519
|
-
The number of trees in the forest.
|
|
1520
|
-
resid_batch_size : int, None
|
|
1937
|
+
resid_batch_size
|
|
1521
1938
|
The batch size for computing the sum of residuals in each leaf.
|
|
1522
|
-
|
|
1523
|
-
The residuals (data minus forest value).
|
|
1524
|
-
prec_scale : float array (n,) or None
|
|
1939
|
+
prec_scale
|
|
1525
1940
|
The scale of the precision of the error on each datapoint. If None, it
|
|
1526
1941
|
is assumed to be 1.
|
|
1527
|
-
min_points_per_leaf
|
|
1942
|
+
min_points_per_leaf
|
|
1528
1943
|
The minimum number of data points in a leaf node.
|
|
1529
|
-
save_ratios
|
|
1944
|
+
save_ratios
|
|
1530
1945
|
Whether to save the acceptance ratios.
|
|
1531
|
-
prelk
|
|
1946
|
+
prelk
|
|
1532
1947
|
The pre-computed terms of the likelihood ratio which are shared across
|
|
1533
1948
|
trees.
|
|
1534
|
-
|
|
1949
|
+
"""
|
|
1950
|
+
|
|
1951
|
+
X: UInt[Array, 'p n']
|
|
1952
|
+
resid_batch_size: int | None
|
|
1953
|
+
prec_scale: Float32[Array, 'n'] | None
|
|
1954
|
+
min_points_per_leaf: Int32[Array, ''] | None
|
|
1955
|
+
save_ratios: bool
|
|
1956
|
+
prelk: PreLk
|
|
1957
|
+
|
|
1958
|
+
|
|
1959
|
+
class SeqStageInPerTree(Module):
|
|
1960
|
+
"""
|
|
1961
|
+
The inputs to `accept_move_and_sample_leaves` that are separate for each tree.
|
|
1962
|
+
|
|
1963
|
+
Parameters
|
|
1964
|
+
----------
|
|
1965
|
+
leaf_tree
|
|
1535
1966
|
The leaf values of the tree.
|
|
1536
|
-
prec_tree
|
|
1967
|
+
prec_tree
|
|
1537
1968
|
The likelihood precision scale in each potential or actual leaf node.
|
|
1538
|
-
move
|
|
1539
|
-
The proposed move, see `
|
|
1540
|
-
move_counts
|
|
1969
|
+
move
|
|
1970
|
+
The proposed move, see `propose_moves`.
|
|
1971
|
+
move_counts
|
|
1541
1972
|
The counts of the number of points in the nodes modified by the
|
|
1542
1973
|
moves.
|
|
1543
|
-
move_precs
|
|
1974
|
+
move_precs
|
|
1544
1975
|
The likelihood precision scale in each node modified by the moves.
|
|
1545
|
-
leaf_indices
|
|
1976
|
+
leaf_indices
|
|
1546
1977
|
The leaf indices for the largest version of the tree compatible with
|
|
1547
1978
|
the move.
|
|
1548
|
-
prelkv
|
|
1979
|
+
prelkv
|
|
1980
|
+
prelf
|
|
1549
1981
|
The pre-computed terms of the likelihood ratio and leaf sampling which
|
|
1550
1982
|
are specific to the tree.
|
|
1983
|
+
"""
|
|
1984
|
+
|
|
1985
|
+
leaf_tree: Float32[Array, '2**d']
|
|
1986
|
+
prec_tree: Float32[Array, '2**d']
|
|
1987
|
+
move: Moves
|
|
1988
|
+
move_counts: Counts | None
|
|
1989
|
+
move_precs: Precs | Counts
|
|
1990
|
+
leaf_indices: UInt[Array, 'n']
|
|
1991
|
+
prelkv: PreLkV
|
|
1992
|
+
prelf: PreLf
|
|
1993
|
+
|
|
1994
|
+
|
|
1995
|
+
def accept_move_and_sample_leaves(
|
|
1996
|
+
resid: Float32[Array, 'n'],
|
|
1997
|
+
at: SeqStageInAllTrees,
|
|
1998
|
+
pt: SeqStageInPerTree,
|
|
1999
|
+
) -> tuple[
|
|
2000
|
+
Float32[Array, 'n'],
|
|
2001
|
+
Float32[Array, '2**d'],
|
|
2002
|
+
Bool[Array, ''],
|
|
2003
|
+
Bool[Array, ''],
|
|
2004
|
+
dict[str, Float32[Array, '']],
|
|
2005
|
+
]:
|
|
2006
|
+
"""
|
|
2007
|
+
Accept or reject a proposed move and sample the new leaf values.
|
|
2008
|
+
|
|
2009
|
+
Parameters
|
|
2010
|
+
----------
|
|
2011
|
+
resid
|
|
2012
|
+
The residuals (data minus forest value).
|
|
2013
|
+
at
|
|
2014
|
+
The inputs that are the same for all trees.
|
|
2015
|
+
pt
|
|
2016
|
+
The inputs that are separate for each tree.
|
|
1551
2017
|
|
|
1552
2018
|
Returns
|
|
1553
2019
|
-------
|
|
1554
|
-
resid :
|
|
2020
|
+
resid : Float32[Array, 'n']
|
|
1555
2021
|
The updated residuals (data minus forest value).
|
|
1556
|
-
leaf_tree :
|
|
2022
|
+
leaf_tree : Float32[Array, '2**d']
|
|
1557
2023
|
The new leaf values of the tree.
|
|
1558
|
-
acc :
|
|
2024
|
+
acc : Bool[Array, '']
|
|
1559
2025
|
Whether the move was accepted.
|
|
1560
|
-
to_prune :
|
|
2026
|
+
to_prune : Bool[Array, '']
|
|
1561
2027
|
Whether, to reflect the acceptance status of the move, the state should
|
|
1562
2028
|
be updated by pruning the leaves involved in the move.
|
|
1563
|
-
ratios : dict
|
|
2029
|
+
ratios : dict[str, Float32[Array, '']]
|
|
1564
2030
|
The acceptance ratios for the moves. Empty if not to be saved.
|
|
1565
2031
|
"""
|
|
1566
|
-
|
|
1567
2032
|
# sum residuals in each leaf, in tree proposed by grow move
|
|
1568
|
-
if prec_scale is None:
|
|
2033
|
+
if at.prec_scale is None:
|
|
1569
2034
|
scaled_resid = resid
|
|
1570
2035
|
else:
|
|
1571
|
-
scaled_resid = resid * prec_scale
|
|
1572
|
-
resid_tree = sum_resid(
|
|
2036
|
+
scaled_resid = resid * at.prec_scale
|
|
2037
|
+
resid_tree = sum_resid(
|
|
2038
|
+
scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
|
|
2039
|
+
)
|
|
1573
2040
|
|
|
1574
2041
|
# subtract starting tree from function
|
|
1575
|
-
resid_tree += prec_tree * leaf_tree
|
|
2042
|
+
resid_tree += pt.prec_tree * pt.leaf_tree
|
|
1576
2043
|
|
|
1577
2044
|
# get indices of move
|
|
1578
|
-
node = move
|
|
2045
|
+
node = pt.move.node
|
|
1579
2046
|
assert node.dtype == jnp.int32
|
|
1580
|
-
left = move
|
|
1581
|
-
right = move
|
|
2047
|
+
left = pt.move.left
|
|
2048
|
+
right = pt.move.right
|
|
1582
2049
|
|
|
1583
2050
|
# sum residuals in parent node modified by move
|
|
1584
2051
|
resid_left = resid_tree[left]
|
|
@@ -1588,30 +2055,33 @@ def accept_move_and_sample_leaves(
|
|
|
1588
2055
|
|
|
1589
2056
|
# compute acceptance ratio
|
|
1590
2057
|
log_lk_ratio = compute_likelihood_ratio(
|
|
1591
|
-
resid_total, resid_left, resid_right, prelkv, prelk
|
|
2058
|
+
resid_total, resid_left, resid_right, pt.prelkv, at.prelk
|
|
1592
2059
|
)
|
|
1593
|
-
log_ratio = move
|
|
1594
|
-
log_ratio = jnp.where(move
|
|
2060
|
+
log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
|
|
2061
|
+
log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
|
|
1595
2062
|
ratios = {}
|
|
1596
|
-
if save_ratios:
|
|
2063
|
+
if at.save_ratios:
|
|
1597
2064
|
ratios.update(
|
|
1598
|
-
log_trans_prior=move
|
|
2065
|
+
log_trans_prior=pt.move.log_trans_prior_ratio,
|
|
2066
|
+
# TODO save log_trans_prior_ratio as a vector outside of this loop,
|
|
2067
|
+
# then change the option everywhere to `save_likelihood_ratio`.
|
|
1599
2068
|
log_likelihood=log_lk_ratio,
|
|
1600
2069
|
)
|
|
1601
2070
|
|
|
1602
2071
|
# determine whether to accept the move
|
|
1603
|
-
acc = move
|
|
1604
|
-
if min_points_per_leaf is not None:
|
|
1605
|
-
|
|
1606
|
-
acc &= move_counts
|
|
2072
|
+
acc = pt.move.allowed & (pt.move.logu <= log_ratio)
|
|
2073
|
+
if at.min_points_per_leaf is not None:
|
|
2074
|
+
assert pt.move_counts is not None
|
|
2075
|
+
acc &= pt.move_counts.left >= at.min_points_per_leaf
|
|
2076
|
+
acc &= pt.move_counts.right >= at.min_points_per_leaf
|
|
1607
2077
|
|
|
1608
2078
|
# compute leaves posterior and sample leaves
|
|
1609
|
-
initial_leaf_tree = leaf_tree
|
|
1610
|
-
mean_post = resid_tree * prelf
|
|
1611
|
-
leaf_tree = mean_post + prelf
|
|
2079
|
+
initial_leaf_tree = pt.leaf_tree
|
|
2080
|
+
mean_post = resid_tree * pt.prelf.mean_factor
|
|
2081
|
+
leaf_tree = mean_post + pt.prelf.centered_leaves
|
|
1612
2082
|
|
|
1613
2083
|
# copy leaves around such that the leaf indices point to the correct leaf
|
|
1614
|
-
to_prune = acc ^ move
|
|
2084
|
+
to_prune = acc ^ pt.move.grow
|
|
1615
2085
|
leaf_tree = (
|
|
1616
2086
|
leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
|
|
1617
2087
|
.set(leaf_tree[node])
|
|
@@ -1620,43 +2090,51 @@ def accept_move_and_sample_leaves(
|
|
|
1620
2090
|
)
|
|
1621
2091
|
|
|
1622
2092
|
# replace old tree with new tree in function values
|
|
1623
|
-
resid += (initial_leaf_tree - leaf_tree)[leaf_indices]
|
|
2093
|
+
resid += (initial_leaf_tree - leaf_tree)[pt.leaf_indices]
|
|
1624
2094
|
|
|
1625
2095
|
return resid, leaf_tree, acc, to_prune, ratios
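
The acceptance test above is a standard Metropolis-Hastings step carried out in log space; a minimal sketch (hypothetical ratio) of the same test:

    import jax.numpy as jnp
    from jax import random

    key = random.key(0)
    log_ratio = jnp.log(0.3)                 # hypothetical transition * prior * likelihood ratio
    logu = jnp.log(random.uniform(key, ()))  # drawn once per proposed move
    acc = logu <= log_ratio                  # accepts with probability min(1, 0.3)
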
|
|
1626
2096
|
|
|
1627
2097
|
|
|
1628
|
-
def sum_resid(
|
|
2098
|
+
def sum_resid(
|
|
2099
|
+
scaled_resid: Float32[Array, 'n'],
|
|
2100
|
+
leaf_indices: UInt[Array, 'n'],
|
|
2101
|
+
tree_size: int,
|
|
2102
|
+
batch_size: int | None,
|
|
2103
|
+
) -> Float32[Array, '{tree_size}']:
|
|
1629
2104
|
"""
|
|
1630
2105
|
Sum the residuals in each leaf.
|
|
1631
2106
|
|
|
1632
2107
|
Parameters
|
|
1633
2108
|
----------
|
|
1634
|
-
scaled_resid
|
|
2109
|
+
scaled_resid
|
|
1635
2110
|
The residuals (data minus forest value) multiplied by the error
|
|
1636
2111
|
precision scale.
|
|
1637
|
-
leaf_indices
|
|
2112
|
+
leaf_indices
|
|
1638
2113
|
The leaf indices of the tree (in which leaf each data point falls into).
|
|
1639
|
-
tree_size
|
|
2114
|
+
tree_size
|
|
1640
2115
|
The size of the tree array (2 ** d).
|
|
1641
|
-
batch_size
|
|
2116
|
+
batch_size
|
|
1642
2117
|
The data batch size for the aggregation. Batching increases numerical
|
|
1643
2118
|
accuracy and parallelism.
|
|
1644
2119
|
|
|
1645
2120
|
Returns
|
|
1646
2121
|
-------
|
|
1647
|
-
|
|
1648
|
-
The sum of the residuals at data points in each leaf.
|
|
2122
|
+
The sum of the residuals at data points in each leaf.
|
|
1649
2123
|
"""
|
|
1650
2124
|
if batch_size is None:
|
|
1651
2125
|
aggr_func = _aggregate_scatter
|
|
1652
2126
|
else:
|
|
1653
|
-
aggr_func =
|
|
1654
|
-
return aggr_func(
|
|
1655
|
-
scaled_resid, leaf_indices, tree_size, jnp.float32
|
|
1656
|
-
) # TODO: use large_float
|
|
2127
|
+
aggr_func = partial(_aggregate_batched_onetree, batch_size=batch_size)
|
|
2128
|
+
return aggr_func(scaled_resid, leaf_indices, tree_size, jnp.float32)
|
|
1657
2129
|
|
|
1658
2130
|
|
|
1659
|
-
def _aggregate_batched_onetree(
|
|
2131
|
+
def _aggregate_batched_onetree(
|
|
2132
|
+
values: Shaped[Array, '*'],
|
|
2133
|
+
indices: Integer[Array, '*'],
|
|
2134
|
+
size: int,
|
|
2135
|
+
dtype: jnp.dtype,
|
|
2136
|
+
batch_size: int,
|
|
2137
|
+
) -> Float32[Array, '{size}']:
|
|
1660
2138
|
(n,) = indices.shape
|
|
1661
2139
|
nbatches = n // batch_size + bool(n % batch_size)
|
|
1662
2140
|
batch_indices = jnp.arange(n) % nbatches
|
|
@@ -1668,118 +2146,133 @@ def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
|
|
|
1668
2146
|
)
|
|
1669
2147
|
|
|
1670
2148
|
|
|
1671
|
-
def compute_likelihood_ratio(
|
|
2149
|
+
def compute_likelihood_ratio(
|
|
2150
|
+
total_resid: Float32[Array, ''],
|
|
2151
|
+
left_resid: Float32[Array, ''],
|
|
2152
|
+
right_resid: Float32[Array, ''],
|
|
2153
|
+
prelkv: PreLkV,
|
|
2154
|
+
prelk: PreLk,
|
|
2155
|
+
) -> Float32[Array, '']:
|
|
1672
2156
|
"""
|
|
1673
2157
|
Compute the likelihood ratio of a grow move.
|
|
1674
2158
|
|
|
1675
2159
|
Parameters
|
|
1676
2160
|
----------
|
|
1677
|
-
total_resid
|
|
2161
|
+
total_resid
|
|
2162
|
+
left_resid
|
|
2163
|
+
right_resid
|
|
1678
2164
|
The sum of the residuals (scaled by error precision scale) of the
|
|
1679
2165
|
datapoints falling in the nodes involved in the moves.
|
|
1680
|
-
prelkv
|
|
2166
|
+
prelkv
|
|
2167
|
+
prelk
|
|
1681
2168
|
The pre-computed terms of the likelihood ratio, see
|
|
1682
2169
|
`precompute_likelihood_terms`.
|
|
1683
2170
|
|
|
1684
2171
|
Returns
|
|
1685
2172
|
-------
|
|
1686
|
-
ratio
|
|
1687
|
-
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
2173
|
+
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
1688
2174
|
"""
|
|
1689
|
-
exp_term = prelk
|
|
1690
|
-
left_resid * left_resid / prelkv
|
|
1691
|
-
+ right_resid * right_resid / prelkv
|
|
1692
|
-
- total_resid * total_resid / prelkv
|
|
2175
|
+
exp_term = prelk.exp_factor * (
|
|
2176
|
+
left_resid * left_resid / prelkv.sigma2_left
|
|
2177
|
+
+ right_resid * right_resid / prelkv.sigma2_right
|
|
2178
|
+
- total_resid * total_resid / prelkv.sigma2_total
|
|
1693
2179
|
)
|
|
1694
|
-
return prelkv
|
|
2180
|
+
return prelkv.sqrt_term + exp_term
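
Putting `precompute_likelihood_terms` and this function together, the (log) grow-move likelihood ratio can be written out directly; a sketch with hypothetical numbers for the homoskedastic case, where the precision scales are just datapoint counts:

    import jax.numpy as jnp

    sigma2, sigma_mu2 = 1.0, 0.5
    n_left, n_right = 4, 6                 # points in the proposed children
    r_left, r_right = 1.5, -2.0            # sums of residuals in the children
    n_tot, r_tot = n_left + n_right, r_left + r_right

    s2_left = sigma2 + n_left * sigma_mu2
    s2_right = sigma2 + n_right * sigma_mu2
    s2_tot = sigma2 + n_tot * sigma_mu2
    log_ratio = (
        jnp.log(sigma2 * s2_tot / (s2_left * s2_right)) / 2
        + sigma_mu2 / (2 * sigma2)
        * (r_left**2 / s2_left + r_right**2 / s2_right - r_tot**2 / s2_tot)
    )
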
|
|
1695
2181
|
|
|
1696
2182
|
|
|
1697
|
-
def accept_moves_final_stage(bart, moves):
|
|
2183
|
+
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
|
|
1698
2184
|
"""
|
|
1699
|
-
|
|
2185
|
+
Post-process the mcmc state after accepting/rejecting the moves.
|
|
2186
|
+
|
|
2187
|
+
This function is separate from `accept_moves_sequential_stage` to signal it
|
|
2188
|
+
can work in parallel across trees.
|
|
1700
2189
|
|
|
1701
2190
|
Parameters
|
|
1702
2191
|
----------
|
|
1703
|
-
bart
|
|
2192
|
+
bart
|
|
1704
2193
|
A partially updated BART mcmc state.
|
|
1705
|
-
|
|
1706
|
-
The
|
|
1707
|
-
moves : dict
|
|
1708
|
-
The proposed moves (see `sample_moves`) as updated by
|
|
2194
|
+
moves
|
|
2195
|
+
The proposed moves (see `propose_moves`) as updated by
|
|
1709
2196
|
`accept_moves_sequential_stage`.
|
|
1710
2197
|
|
|
1711
2198
|
Returns
|
|
1712
2199
|
-------
|
|
1713
|
-
|
|
1714
|
-
|
|
1715
|
-
|
|
1716
|
-
|
|
1717
|
-
|
|
1718
|
-
|
|
1719
|
-
|
|
1720
|
-
|
|
1721
|
-
|
|
2200
|
+
The fully updated BART mcmc state.
|
|
2201
|
+
"""
|
|
2202
|
+
return replace(
|
|
2203
|
+
bart,
|
|
2204
|
+
forest=replace(
|
|
2205
|
+
bart.forest,
|
|
2206
|
+
grow_acc_count=jnp.sum(moves.acc & moves.grow),
|
|
2207
|
+
prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
|
|
2208
|
+
leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
|
|
2209
|
+
split_trees=apply_moves_to_split_trees(bart.forest.split_trees, moves),
|
|
2210
|
+
),
|
|
2211
|
+
)
|
|
1722
2212
|
|
|
1723
2213
|
|
|
1724
|
-
@
|
|
1725
|
-
def apply_moves_to_leaf_indices(
|
|
2214
|
+
@vmap_nodoc
|
|
2215
|
+
def apply_moves_to_leaf_indices(
|
|
2216
|
+
leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
|
|
2217
|
+
) -> UInt[Array, 'num_trees n']:
|
|
1726
2218
|
"""
|
|
1727
2219
|
Update the leaf indices to match the accepted move.
|
|
1728
2220
|
|
|
1729
2221
|
Parameters
|
|
1730
2222
|
----------
|
|
1731
|
-
leaf_indices
|
|
2223
|
+
leaf_indices
|
|
1732
2224
|
The index of the leaf each datapoint falls into, if the grow move was
|
|
1733
2225
|
accepted.
|
|
1734
|
-
moves
|
|
1735
|
-
The proposed moves (see `
|
|
2226
|
+
moves
|
|
2227
|
+
The proposed moves (see `propose_moves`), as updated by
|
|
1736
2228
|
`accept_moves_sequential_stage`.
|
|
1737
2229
|
|
|
1738
2230
|
Returns
|
|
1739
2231
|
-------
|
|
1740
|
-
|
|
1741
|
-
The updated leaf indices.
|
|
2232
|
+
The updated leaf indices.
|
|
1742
2233
|
"""
|
|
1743
2234
|
mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
|
|
1744
|
-
is_child = (leaf_indices & mask) == moves
|
|
2235
|
+
is_child = (leaf_indices & mask) == moves.left
|
|
1745
2236
|
return jnp.where(
|
|
1746
|
-
is_child & moves
|
|
1747
|
-
moves
|
|
2237
|
+
is_child & moves.to_prune,
|
|
2238
|
+
moves.node.astype(leaf_indices.dtype),
|
|
1748
2239
|
leaf_indices,
|
|
1749
2240
|
)
|
|
1750
2241
|
|
|
1751
2242
|
|
|
1752
|
-
@
|
|
1753
|
-
def apply_moves_to_split_trees(
|
|
2243
|
+
@vmap_nodoc
|
|
2244
|
+
def apply_moves_to_split_trees(
|
|
2245
|
+
split_trees: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
|
|
2246
|
+
) -> UInt[Array, 'num_trees 2**(d-1)']:
|
|
1754
2247
|
"""
|
|
1755
2248
|
Update the split trees to match the accepted move.
|
|
1756
2249
|
|
|
1757
2250
|
Parameters
|
|
1758
2251
|
----------
|
|
1759
|
-
split_trees
|
|
2252
|
+
split_trees
|
|
1760
2253
|
The cutpoints of the decision nodes in the initial trees.
|
|
1761
|
-
moves
|
|
1762
|
-
The proposed moves (see `
|
|
2254
|
+
moves
|
|
2255
|
+
The proposed moves (see `propose_moves`), as updated by
|
|
1763
2256
|
`accept_moves_sequential_stage`.
|
|
1764
2257
|
|
|
1765
2258
|
Returns
|
|
1766
2259
|
-------
|
|
1767
|
-
|
|
1768
|
-
The updated split trees.
|
|
2260
|
+
The updated split trees.
|
|
1769
2261
|
"""
|
|
2262
|
+
assert moves.to_prune is not None
|
|
1770
2263
|
return (
|
|
1771
2264
|
split_trees.at[
|
|
1772
2265
|
jnp.where(
|
|
1773
|
-
moves
|
|
1774
|
-
moves
|
|
2266
|
+
moves.grow,
|
|
2267
|
+
moves.node,
|
|
1775
2268
|
split_trees.size,
|
|
1776
2269
|
)
|
|
1777
2270
|
]
|
|
1778
|
-
.set(moves
|
|
2271
|
+
.set(moves.grow_split.astype(split_trees.dtype))
|
|
1779
2272
|
.at[
|
|
1780
2273
|
jnp.where(
|
|
1781
|
-
moves
|
|
1782
|
-
moves
|
|
2274
|
+
moves.to_prune,
|
|
2275
|
+
moves.node,
|
|
1783
2276
|
split_trees.size,
|
|
1784
2277
|
)
|
|
1785
2278
|
]
|
|
@@ -1787,34 +2280,56 @@ def apply_moves_to_split_trees(split_trees, moves):
|
|
|
1787
2280
|
)
|
|
1788
2281
|
|
|
1789
2282
|
|
|
1790
|
-
def
|
|
2283
|
+
def step_sigma(key: Key[Array, ''], bart: State) -> State:
|
|
1791
2284
|
"""
|
|
1792
|
-
|
|
2285
|
+
MCMC-update the error variance (factor).
|
|
1793
2286
|
|
|
1794
2287
|
Parameters
|
|
1795
2288
|
----------
|
|
1796
|
-
key
|
|
2289
|
+
key
|
|
1797
2290
|
A jax random key.
|
|
1798
|
-
bart
|
|
1799
|
-
A BART mcmc state
|
|
2291
|
+
bart
|
|
2292
|
+
A BART mcmc state.
|
|
1800
2293
|
|
|
1801
2294
|
Returns
|
|
1802
2295
|
-------
|
|
1803
|
-
|
|
1804
|
-
The new BART mcmc state.
|
|
2296
|
+
The new BART mcmc state, with an updated `sigma2`.
|
|
1805
2297
|
"""
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
alpha = bart['sigma2_alpha'] + resid.size / 2
|
|
1810
|
-
if bart['prec_scale'] is None:
|
|
2298
|
+
resid = bart.resid
|
|
2299
|
+
alpha = bart.sigma2_alpha + resid.size / 2
|
|
2300
|
+
if bart.prec_scale is None:
|
|
1811
2301
|
scaled_resid = resid
|
|
1812
2302
|
else:
|
|
1813
|
-
scaled_resid = resid * bart
|
|
2303
|
+
scaled_resid = resid * bart.prec_scale
|
|
1814
2304
|
norm2 = resid @ scaled_resid
|
|
1815
|
-
beta = bart
|
|
2305
|
+
beta = bart.sigma2_beta + norm2 / 2
|
|
1816
2306
|
|
|
1817
2307
|
sample = random.gamma(key, alpha)
|
|
1818
|
-
bart
|
|
2308
|
+
return replace(bart, sigma2=beta / sample)
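
The update above is the standard conjugate draw for the error variance: with an inverse-gamma prior and Gaussian residuals, the posterior is again inverse-gamma, and an inverse-gamma variate is obtained as `beta / Gamma(alpha)`. A sketch with hypothetical prior parameters:

    import jax.numpy as jnp
    from jax import random

    key = random.key(0)
    resid = jnp.array([0.3, -1.2, 0.7, 0.1])
    a0, b0 = 2.0, 1.0                     # hypothetical sigma2_alpha, sigma2_beta
    alpha = a0 + resid.size / 2
    beta = b0 + resid @ resid / 2
    sigma2 = beta / random.gamma(key, alpha)   # InvGamma(alpha, beta) draw
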
|
|
1819
2309
|
|
|
1820
|
-
|
|
2310
|
+
|
|
2311
|
+
def step_z(key: Key[Array, ''], bart: State) -> State:
|
|
2312
|
+
"""
|
|
2313
|
+
MCMC-update the latent variable for binary regression.
|
|
2314
|
+
|
|
2315
|
+
Parameters
|
|
2316
|
+
----------
|
|
2317
|
+
key
|
|
2318
|
+
A jax random key.
|
|
2319
|
+
bart
|
|
2320
|
+
A BART MCMC state.
|
|
2321
|
+
|
|
2322
|
+
Returns
|
|
2323
|
+
-------
|
|
2324
|
+
The updated BART MCMC state.
|
|
2325
|
+
"""
|
|
2326
|
+
trees_plus_offset = bart.z - bart.resid
|
|
2327
|
+
lower = jnp.where(bart.y, -trees_plus_offset, -jnp.inf)
|
|
2328
|
+
upper = jnp.where(bart.y, jnp.inf, -trees_plus_offset)
|
|
2329
|
+
resid = random.truncated_normal(key, lower, upper)
|
|
2330
|
+
# TODO jax's implementation of truncated_normal is not good, it just does
|
|
2331
|
+
# cdf inversion with erf and erf_inv. I can do better, at least avoiding to
|
|
2332
|
+
# compute one of the boundaries, and maybe also flipping and using ndtr
|
|
2333
|
+
# instead of erf for numerical stability (open an issue in jax?)
|
|
2334
|
+
z = trees_plus_offset + resid
|
|
2335
|
+
return replace(bart, z=z, resid=resid)
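
This is the usual data-augmentation update for probit-type binary regression: given the current sum of trees f(x), the latent z is a standard normal centered at f(x) and truncated to the side selected by y, and the trees then fit the residual z - f(x). A minimal sketch with hypothetical values:

    import jax.numpy as jnp
    from jax import random

    key = random.key(0)
    y = jnp.array([True, False, True])
    f = jnp.array([0.2, -0.1, 1.3])                     # current forest values
    lower = jnp.where(y, -f, -jnp.inf)
    upper = jnp.where(y, jnp.inf, -f)
    resid = random.truncated_normal(key, lower, upper)  # truncated standard normal
    z = f + resid                                       # positive where y, negative otherwise
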
|