bartz 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +582 -279
- bartz/__init__.py +3 -3
- bartz/_version.py +1 -1
- bartz/debug.py +1259 -79
- bartz/grove.py +168 -81
- bartz/jaxext/__init__.py +213 -0
- bartz/jaxext/_autobatch.py +238 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +240 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +568 -158
- bartz/mcmcstep.py +1722 -926
- bartz/prepcovars.py +142 -44
- {bartz-0.5.0.dist-info → bartz-0.7.0.dist-info}/METADATA +6 -5
- bartz-0.7.0.dist-info/RECORD +17 -0
- {bartz-0.5.0.dist-info → bartz-0.7.0.dist-info}/WHEEL +1 -1
- bartz/jaxext.py +0 -374
- bartz-0.5.0.dist-info/RECORD +0 -13
bartz/mcmcstep.py
CHANGED
|
@@ -26,220 +26,394 @@
|
|
|
26
26
|
Functions that implement the BART posterior MCMC initialization and update step.
|
|
27
27
|
|
|
28
28
|
Functions that do MCMC steps operate by taking as input a bart state, and
|
|
29
|
-
outputting a new
|
|
30
|
-
modified.
|
|
29
|
+
outputting a new state. The inputs are not modified.
|
|
31
30
|
|
|
32
|
-
|
|
33
|
-
|
|
31
|
+
The entry points are:
|
|
32
|
+
|
|
33
|
+
- `State`: The dataclass that represents a BART MCMC state.
|
|
34
|
+
- `init`: Creates an initial `State` from data and configurations.
|
|
35
|
+
- `step`: Performs one full MCMC step on a `State`, returning a new `State`.
|
|
36
|
+
- `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`.
|
|
34
37
|
"""
|
|
35
38
|
|
|
36
|
-
import functools
|
|
37
39
|
import math
|
|
40
|
+
from dataclasses import replace
|
|
41
|
+
from functools import cache, partial
|
|
42
|
+
from typing import Any, Literal
|
|
38
43
|
|
|
39
44
|
import jax
|
|
45
|
+
from equinox import Module, field, tree_at
|
|
40
46
|
from jax import lax, random
|
|
41
47
|
from jax import numpy as jnp
|
|
48
|
+
from jax.scipy.special import gammaln, logsumexp
|
|
49
|
+
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt
|
|
50
|
+
|
|
51
|
+
from bartz import grove
|
|
52
|
+
from bartz.jaxext import (
|
|
53
|
+
minimal_unsigned_dtype,
|
|
54
|
+
split,
|
|
55
|
+
truncated_normal_onesided,
|
|
56
|
+
vmap_nodoc,
|
|
57
|
+
)
|
|
42
58
|
|
|
43
|
-
|
|
59
|
+
|
|
60
|
+
class Forest(Module):
|
|
61
|
+
"""
|
|
62
|
+
Represents the MCMC state of a sum of trees.
|
|
63
|
+
|
|
64
|
+
Parameters
|
|
65
|
+
----------
|
|
66
|
+
leaf_tree
|
|
67
|
+
The leaf values.
|
|
68
|
+
var_tree
|
|
69
|
+
The decision axes.
|
|
70
|
+
split_tree
|
|
71
|
+
The decision boundaries.
|
|
72
|
+
affluence_tree
|
|
73
|
+
Marks leaves that can be grown.
|
|
74
|
+
max_split
|
|
75
|
+
The maximum split index for each predictor.
|
|
76
|
+
blocked_vars
|
|
77
|
+
Indices of variables that are not used. This shall include at least
|
|
78
|
+
the `i` such that ``max_split[i] == 0``, otherwise behavior is
|
|
79
|
+
undefined.
|
|
80
|
+
p_nonterminal
|
|
81
|
+
The prior probability of each node being nonterminal, conditional on
|
|
82
|
+
its ancestors. Includes the nodes at maximum depth which should be set
|
|
83
|
+
to 0.
|
|
84
|
+
p_propose_grow
|
|
85
|
+
The unnormalized probability of picking a leaf for a grow proposal.
|
|
86
|
+
leaf_indices
|
|
87
|
+
The index of the leaf each datapoints falls into, for each tree.
|
|
88
|
+
min_points_per_decision_node
|
|
89
|
+
The minimum number of data points in a decision node.
|
|
90
|
+
min_points_per_leaf
|
|
91
|
+
The minimum number of data points in a leaf node.
|
|
92
|
+
resid_batch_size
|
|
93
|
+
count_batch_size
|
|
94
|
+
The data batch sizes for computing the sufficient statistics. If `None`,
|
|
95
|
+
they are computed with no batching.
|
|
96
|
+
log_trans_prior
|
|
97
|
+
The log transition and prior Metropolis-Hastings ratio for the
|
|
98
|
+
proposed move on each tree.
|
|
99
|
+
log_likelihood
|
|
100
|
+
The log likelihood ratio.
|
|
101
|
+
grow_prop_count
|
|
102
|
+
prune_prop_count
|
|
103
|
+
The number of grow/prune proposals made during one full MCMC cycle.
|
|
104
|
+
grow_acc_count
|
|
105
|
+
prune_acc_count
|
|
106
|
+
The number of grow/prune moves accepted during one full MCMC cycle.
|
|
107
|
+
sigma_mu2
|
|
108
|
+
The prior variance of a leaf, conditional on the tree structure.
|
|
109
|
+
log_s
|
|
110
|
+
The logarithm of the prior probability for choosing a variable to split
|
|
111
|
+
along in a decision rule, conditional on the ancestors. Not normalized.
|
|
112
|
+
If `None`, use a uniform distribution.
|
|
113
|
+
theta
|
|
114
|
+
The concentration parameter for the Dirichlet prior on the variable
|
|
115
|
+
distribution `s`. Required only to update `s`.
|
|
116
|
+
a
|
|
117
|
+
b
|
|
118
|
+
rho
|
|
119
|
+
Parameters of the prior on `theta`. Required only to sample `theta`.
|
|
120
|
+
See `step_theta`.
|
|
121
|
+
"""
|
|
122
|
+
|
|
123
|
+
leaf_tree: Float32[Array, 'num_trees 2**d']
|
|
124
|
+
var_tree: UInt[Array, 'num_trees 2**(d-1)']
|
|
125
|
+
split_tree: UInt[Array, 'num_trees 2**(d-1)']
|
|
126
|
+
affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
|
|
127
|
+
max_split: UInt[Array, ' p']
|
|
128
|
+
blocked_vars: UInt[Array, ' k'] | None
|
|
129
|
+
p_nonterminal: Float32[Array, ' 2**d']
|
|
130
|
+
p_propose_grow: Float32[Array, ' 2**(d-1)']
|
|
131
|
+
leaf_indices: UInt[Array, 'num_trees n']
|
|
132
|
+
min_points_per_decision_node: Int32[Array, ''] | None
|
|
133
|
+
min_points_per_leaf: Int32[Array, ''] | None
|
|
134
|
+
resid_batch_size: int | None = field(static=True)
|
|
135
|
+
count_batch_size: int | None = field(static=True)
|
|
136
|
+
log_trans_prior: Float32[Array, ' num_trees'] | None
|
|
137
|
+
log_likelihood: Float32[Array, ' num_trees'] | None
|
|
138
|
+
grow_prop_count: Int32[Array, '']
|
|
139
|
+
prune_prop_count: Int32[Array, '']
|
|
140
|
+
grow_acc_count: Int32[Array, '']
|
|
141
|
+
prune_acc_count: Int32[Array, '']
|
|
142
|
+
sigma_mu2: Float32[Array, '']
|
|
143
|
+
log_s: Float32[Array, ' p'] | None
|
|
144
|
+
theta: Float32[Array, ''] | None
|
|
145
|
+
a: Float32[Array, ''] | None
|
|
146
|
+
b: Float32[Array, ''] | None
|
|
147
|
+
rho: Float32[Array, ''] | None
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class State(Module):
|
|
151
|
+
"""
|
|
152
|
+
Represents the MCMC state of BART.
|
|
153
|
+
|
|
154
|
+
Parameters
|
|
155
|
+
----------
|
|
156
|
+
X
|
|
157
|
+
The predictors.
|
|
158
|
+
y
|
|
159
|
+
The response. If the data type is `bool`, the model is binary regression.
|
|
160
|
+
resid
|
|
161
|
+
The residuals (`y` or `z` minus sum of trees).
|
|
162
|
+
z
|
|
163
|
+
The latent variable for binary regression. `None` in continuous
|
|
164
|
+
regression.
|
|
165
|
+
offset
|
|
166
|
+
Constant shift added to the sum of trees.
|
|
167
|
+
sigma2
|
|
168
|
+
The error variance. `None` in binary regression.
|
|
169
|
+
prec_scale
|
|
170
|
+
The scale on the error precision, i.e., ``1 / error_scale ** 2``.
|
|
171
|
+
`None` in binary regression.
|
|
172
|
+
sigma2_alpha
|
|
173
|
+
sigma2_beta
|
|
174
|
+
The shape and scale parameters of the inverse gamma prior on the noise
|
|
175
|
+
variance. `None` in binary regression.
|
|
176
|
+
forest
|
|
177
|
+
The sum of trees model.
|
|
178
|
+
"""
|
|
179
|
+
|
|
180
|
+
X: UInt[Array, 'p n']
|
|
181
|
+
y: Float32[Array, ' n'] | Bool[Array, ' n']
|
|
182
|
+
z: None | Float32[Array, ' n']
|
|
183
|
+
offset: Float32[Array, '']
|
|
184
|
+
resid: Float32[Array, ' n']
|
|
185
|
+
sigma2: Float32[Array, ''] | None
|
|
186
|
+
prec_scale: Float32[Array, ' n'] | None
|
|
187
|
+
sigma2_alpha: Float32[Array, ''] | None
|
|
188
|
+
sigma2_beta: Float32[Array, ''] | None
|
|
189
|
+
forest: Forest
|
|
44
190
|
|
|
45
191
|
|
|
46
192
|
def init(
|
|
47
193
|
*,
|
|
48
|
-
X,
|
|
49
|
-
y,
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
resid_batch_size='auto',
|
|
60
|
-
count_batch_size='auto',
|
|
61
|
-
save_ratios=False,
|
|
62
|
-
|
|
194
|
+
X: UInt[Any, 'p n'],
|
|
195
|
+
y: Float32[Any, ' n'] | Bool[Any, ' n'],
|
|
196
|
+
offset: float | Float32[Any, ''] = 0.0,
|
|
197
|
+
max_split: UInt[Any, ' p'],
|
|
198
|
+
num_trees: int,
|
|
199
|
+
p_nonterminal: Float32[Any, ' d-1'],
|
|
200
|
+
sigma_mu2: float | Float32[Any, ''],
|
|
201
|
+
sigma2_alpha: float | Float32[Any, ''] | None = None,
|
|
202
|
+
sigma2_beta: float | Float32[Any, ''] | None = None,
|
|
203
|
+
error_scale: Float32[Any, ' n'] | None = None,
|
|
204
|
+
min_points_per_decision_node: int | Integer[Any, ''] | None = None,
|
|
205
|
+
resid_batch_size: int | None | Literal['auto'] = 'auto',
|
|
206
|
+
count_batch_size: int | None | Literal['auto'] = 'auto',
|
|
207
|
+
save_ratios: bool = False,
|
|
208
|
+
filter_splitless_vars: bool = True,
|
|
209
|
+
min_points_per_leaf: int | Integer[Any, ''] | None = None,
|
|
210
|
+
log_s: Float32[Any, ' p'] | None = None,
|
|
211
|
+
theta: float | Float32[Any, ''] | None = None,
|
|
212
|
+
a: float | Float32[Any, ''] | None = None,
|
|
213
|
+
b: float | Float32[Any, ''] | None = None,
|
|
214
|
+
rho: float | Float32[Any, ''] | None = None,
|
|
215
|
+
) -> State:
|
|
63
216
|
"""
|
|
64
217
|
Make a BART posterior sampling MCMC initial state.
|
|
65
218
|
|
|
66
219
|
Parameters
|
|
67
220
|
----------
|
|
68
|
-
X
|
|
221
|
+
X
|
|
69
222
|
The predictors. Note this is trasposed compared to the usual convention.
|
|
70
|
-
y
|
|
71
|
-
The response.
|
|
72
|
-
|
|
223
|
+
y
|
|
224
|
+
The response. If the data type is `bool`, the regression model is binary
|
|
225
|
+
regression with probit.
|
|
226
|
+
offset
|
|
227
|
+
Constant shift added to the sum of trees. 0 if not specified.
|
|
228
|
+
max_split
|
|
73
229
|
The maximum split index for each variable. All split ranges start at 1.
|
|
74
|
-
num_trees
|
|
230
|
+
num_trees
|
|
75
231
|
The number of trees in the forest.
|
|
76
|
-
p_nonterminal
|
|
232
|
+
p_nonterminal
|
|
77
233
|
The probability of a nonterminal node at each depth. The maximum depth
|
|
78
234
|
of trees is fixed by the length of this array.
|
|
79
|
-
|
|
80
|
-
The
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
235
|
+
sigma_mu2
|
|
236
|
+
The prior variance of a leaf, conditional on the tree structure. The
|
|
237
|
+
prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
|
|
238
|
+
prior mean of leaves is always zero.
|
|
239
|
+
sigma2_alpha
|
|
240
|
+
sigma2_beta
|
|
241
|
+
The shape and scale parameters of the inverse gamma prior on the error
|
|
242
|
+
variance. Leave unspecified for binary regression.
|
|
243
|
+
error_scale
|
|
84
244
|
Each error is scaled by the corresponding factor in `error_scale`, so
|
|
85
245
|
the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
The
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
246
|
+
Not supported for binary regression. If not specified, defaults to 1 for
|
|
247
|
+
all points, but potentially skipping calculations.
|
|
248
|
+
min_points_per_decision_node
|
|
249
|
+
The minimum number of data points in a decision node. 0 if not
|
|
250
|
+
specified.
|
|
251
|
+
resid_batch_size
|
|
252
|
+
count_batch_size
|
|
93
253
|
The batch sizes, along datapoints, for summing the residuals and
|
|
94
254
|
counting the number of datapoints in each leaf. `None` for no batching.
|
|
95
255
|
If 'auto', pick a value based on the device of `y`, or the default
|
|
96
256
|
device.
|
|
97
|
-
save_ratios
|
|
257
|
+
save_ratios
|
|
98
258
|
Whether to save the Metropolis-Hastings ratios.
|
|
259
|
+
filter_splitless_vars
|
|
260
|
+
Whether to check `max_split` for variables without available cutpoints.
|
|
261
|
+
If any are found, they are put into a list of variables to exclude from
|
|
262
|
+
the MCMC. If `False`, no check is performed, but the results may be
|
|
263
|
+
wrong if any variable is blocked. The function is jax-traceable only
|
|
264
|
+
if this is set to `False`.
|
|
265
|
+
min_points_per_leaf
|
|
266
|
+
The minimum number of datapoints in a leaf node. 0 if not specified.
|
|
267
|
+
Unlike `min_points_per_decision_node`, this constraint is not taken into
|
|
268
|
+
account in the Metropolis-Hastings ratio because it would be expensive
|
|
269
|
+
to compute. Grow moves that would violate this constraint are vetoed.
|
|
270
|
+
This parameter is independent of `min_points_per_decision_node` and
|
|
271
|
+
there is no check that they are coherent. It makes sense to set
|
|
272
|
+
``min_points_per_decision_node >= 2 * min_points_per_leaf``.
|
|
273
|
+
log_s
|
|
274
|
+
The logarithm of the prior probability for choosing a variable to split
|
|
275
|
+
along in a decision rule, conditional on the ancestors. Not normalized.
|
|
276
|
+
If not specified, use a uniform distribution. If not specified and
|
|
277
|
+
`theta` or `rho`, `a`, `b` are, it's initialized automatically.
|
|
278
|
+
theta
|
|
279
|
+
The concentration parameter for the Dirichlet prior on `s`. Required
|
|
280
|
+
only to update `log_s`. If not specified, and `rho`, `a`, `b` are
|
|
281
|
+
specified, it's initialized automatically.
|
|
282
|
+
a
|
|
283
|
+
b
|
|
284
|
+
rho
|
|
285
|
+
Parameters of the prior on `theta`. Required only to sample `theta`.
|
|
99
286
|
|
|
100
287
|
Returns
|
|
101
288
|
-------
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
'prec_scale' : large_float array (n,) or None
|
|
118
|
-
The scale on the error precision, i.e., ``1 / error_scale ** 2``.
|
|
119
|
-
'grow_prop_count', 'prune_prop_count' : int
|
|
120
|
-
The number of grow/prune proposals made during one full MCMC cycle.
|
|
121
|
-
'grow_acc_count', 'prune_acc_count' : int
|
|
122
|
-
The number of grow/prune moves accepted during one full MCMC cycle.
|
|
123
|
-
'p_nonterminal' : large_float array (d,)
|
|
124
|
-
The probability of a nonterminal node at each depth, padded with a
|
|
125
|
-
zero.
|
|
126
|
-
'p_propose_grow' : large_float array (2 ** (d - 1),)
|
|
127
|
-
The unnormalized probability of picking a leaf for a grow proposal.
|
|
128
|
-
'sigma2_alpha' : large_float
|
|
129
|
-
The shape parameter of the inverse gamma prior on the noise variance.
|
|
130
|
-
'sigma2_beta' : large_float
|
|
131
|
-
The scale parameter of the inverse gamma prior on the noise variance.
|
|
132
|
-
'max_split' : int array (p,)
|
|
133
|
-
The maximum split index for each variable.
|
|
134
|
-
'y' : small_float array (n,)
|
|
135
|
-
The response.
|
|
136
|
-
'X' : int array (p, n)
|
|
137
|
-
The predictors.
|
|
138
|
-
'leaf_indices' : int array (num_trees, n)
|
|
139
|
-
The index of the leaf each datapoints falls into, for each tree.
|
|
140
|
-
'min_points_per_leaf' : int or None
|
|
141
|
-
The minimum number of data points in a leaf node.
|
|
142
|
-
'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
|
|
143
|
-
Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
|
|
144
|
-
datapoints. If `min_points_per_leaf` is not specified, this is None.
|
|
145
|
-
'opt' : LeafDict
|
|
146
|
-
A dictionary with config values:
|
|
147
|
-
|
|
148
|
-
'small_float' : dtype
|
|
149
|
-
The dtype for large arrays used in the algorithm.
|
|
150
|
-
'large_float' : dtype
|
|
151
|
-
The dtype for scalars, small arrays, and arrays which require
|
|
152
|
-
accuracy.
|
|
153
|
-
'require_min_points' : bool
|
|
154
|
-
Whether the `min_points_per_leaf` parameter is specified.
|
|
155
|
-
'resid_batch_size', 'count_batch_size' : int or None
|
|
156
|
-
The data batch sizes for computing the sufficient statistics.
|
|
157
|
-
'ratios' : dict, optional
|
|
158
|
-
If `save_ratios` is True, this field is present. It has the fields:
|
|
159
|
-
|
|
160
|
-
'log_trans_prior' : large_float array (num_trees,)
|
|
161
|
-
The log transition and prior Metropolis-Hastings ratio for the
|
|
162
|
-
proposed move on each tree.
|
|
163
|
-
'log_likelihood' : large_float array (num_trees,)
|
|
164
|
-
The log likelihood ratio.
|
|
165
|
-
"""
|
|
166
|
-
|
|
167
|
-
p_nonterminal = jnp.asarray(p_nonterminal, large_float)
|
|
289
|
+
An initialized BART MCMC state.
|
|
290
|
+
|
|
291
|
+
Raises
|
|
292
|
+
------
|
|
293
|
+
ValueError
|
|
294
|
+
If `y` is boolean and arguments unused in binary regression are set.
|
|
295
|
+
|
|
296
|
+
Notes
|
|
297
|
+
-----
|
|
298
|
+
In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
|
|
299
|
+
of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
|
|
300
|
+
child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
|
|
301
|
+
integers in the range ``[0, 1, ..., max_split[i]]``.
|
|
302
|
+
"""
|
|
303
|
+
p_nonterminal = jnp.asarray(p_nonterminal)
|
|
168
304
|
p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
|
|
169
305
|
max_depth = p_nonterminal.size
|
|
170
306
|
|
|
171
|
-
@
|
|
307
|
+
@partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
|
|
172
308
|
def make_forest(max_depth, dtype):
|
|
173
309
|
return grove.make_tree(max_depth, dtype)
|
|
174
310
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
311
|
+
y = jnp.asarray(y)
|
|
312
|
+
offset = jnp.asarray(offset)
|
|
313
|
+
|
|
178
314
|
resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
|
|
179
315
|
resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
|
|
180
316
|
)
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
317
|
+
|
|
318
|
+
is_binary = y.dtype == bool
|
|
319
|
+
if is_binary:
|
|
320
|
+
if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
|
|
321
|
+
msg = (
|
|
322
|
+
'error_scale, sigma2_alpha, and sigma2_beta must be set '
|
|
323
|
+
' to `None` for binary regression.'
|
|
324
|
+
)
|
|
325
|
+
raise ValueError(msg)
|
|
326
|
+
sigma2 = None
|
|
327
|
+
else:
|
|
328
|
+
sigma2_alpha = jnp.asarray(sigma2_alpha)
|
|
329
|
+
sigma2_beta = jnp.asarray(sigma2_beta)
|
|
330
|
+
sigma2 = sigma2_beta / sigma2_alpha
|
|
331
|
+
|
|
332
|
+
max_split = jnp.asarray(max_split)
|
|
333
|
+
|
|
334
|
+
if filter_splitless_vars:
|
|
335
|
+
(blocked_vars,) = jnp.nonzero(max_split == 0)
|
|
336
|
+
blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
|
|
337
|
+
# see `fully_used_variables` for the type cast
|
|
338
|
+
else:
|
|
339
|
+
blocked_vars = None
|
|
340
|
+
|
|
341
|
+
# check and initialize sparsity parameters
|
|
342
|
+
if not _all_none_or_not_none(rho, a, b):
|
|
343
|
+
msg = 'rho, a, b are not either all `None` or all set'
|
|
344
|
+
raise ValueError(msg)
|
|
345
|
+
if theta is None and rho is not None:
|
|
346
|
+
theta = rho
|
|
347
|
+
if log_s is None and theta is not None:
|
|
348
|
+
log_s = jnp.zeros(max_split.size)
|
|
349
|
+
|
|
350
|
+
return State(
|
|
351
|
+
X=jnp.asarray(X),
|
|
352
|
+
y=y,
|
|
353
|
+
z=jnp.full(y.shape, offset) if is_binary else None,
|
|
354
|
+
offset=offset,
|
|
355
|
+
resid=jnp.zeros(y.shape) if is_binary else y - offset,
|
|
193
356
|
sigma2=sigma2,
|
|
194
357
|
prec_scale=(
|
|
195
|
-
None
|
|
196
|
-
if error_scale is None
|
|
197
|
-
else lax.reciprocal(jnp.square(jnp.asarray(error_scale, large_float)))
|
|
198
|
-
),
|
|
199
|
-
grow_prop_count=jnp.zeros((), int),
|
|
200
|
-
grow_acc_count=jnp.zeros((), int),
|
|
201
|
-
prune_prop_count=jnp.zeros((), int),
|
|
202
|
-
prune_acc_count=jnp.zeros((), int),
|
|
203
|
-
p_nonterminal=p_nonterminal,
|
|
204
|
-
p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
|
|
205
|
-
sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
|
|
206
|
-
sigma2_beta=jnp.asarray(sigma2_beta, large_float),
|
|
207
|
-
max_split=jnp.asarray(max_split),
|
|
208
|
-
y=y,
|
|
209
|
-
X=jnp.asarray(X),
|
|
210
|
-
leaf_indices=jnp.ones(
|
|
211
|
-
(num_trees, y.size), jaxext.minimal_unsigned_dtype(2**max_depth - 1)
|
|
212
|
-
),
|
|
213
|
-
min_points_per_leaf=(
|
|
214
|
-
None if min_points_per_leaf is None else jnp.asarray(min_points_per_leaf)
|
|
358
|
+
None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
|
|
215
359
|
),
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
.
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
360
|
+
sigma2_alpha=sigma2_alpha,
|
|
361
|
+
sigma2_beta=sigma2_beta,
|
|
362
|
+
forest=Forest(
|
|
363
|
+
leaf_tree=make_forest(max_depth, jnp.float32),
|
|
364
|
+
var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
|
|
365
|
+
split_tree=make_forest(max_depth - 1, max_split.dtype),
|
|
366
|
+
affluence_tree=(
|
|
367
|
+
make_forest(max_depth - 1, bool)
|
|
368
|
+
.at[:, 1]
|
|
369
|
+
.set(
|
|
370
|
+
True
|
|
371
|
+
if min_points_per_decision_node is None
|
|
372
|
+
else y.size >= min_points_per_decision_node
|
|
373
|
+
)
|
|
374
|
+
),
|
|
375
|
+
blocked_vars=blocked_vars,
|
|
376
|
+
max_split=max_split,
|
|
377
|
+
grow_prop_count=jnp.zeros((), int),
|
|
378
|
+
grow_acc_count=jnp.zeros((), int),
|
|
379
|
+
prune_prop_count=jnp.zeros((), int),
|
|
380
|
+
prune_acc_count=jnp.zeros((), int),
|
|
381
|
+
p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
|
|
382
|
+
p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
|
|
383
|
+
leaf_indices=jnp.ones(
|
|
384
|
+
(num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
|
|
385
|
+
),
|
|
386
|
+
min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
|
|
387
|
+
min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
|
|
227
388
|
resid_batch_size=resid_batch_size,
|
|
228
389
|
count_batch_size=count_batch_size,
|
|
390
|
+
log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
|
|
391
|
+
log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
|
|
392
|
+
sigma_mu2=jnp.asarray(sigma_mu2),
|
|
393
|
+
log_s=_asarray_or_none(log_s),
|
|
394
|
+
theta=_asarray_or_none(theta),
|
|
395
|
+
rho=_asarray_or_none(rho),
|
|
396
|
+
a=_asarray_or_none(a),
|
|
397
|
+
b=_asarray_or_none(b),
|
|
229
398
|
),
|
|
230
399
|
)
|
|
231
400
|
|
|
232
|
-
if save_ratios:
|
|
233
|
-
bart['ratios'] = dict(
|
|
234
|
-
log_trans_prior=jnp.full(num_trees, jnp.nan),
|
|
235
|
-
log_likelihood=jnp.full(num_trees, jnp.nan),
|
|
236
|
-
)
|
|
237
401
|
|
|
238
|
-
|
|
402
|
+
def _all_none_or_not_none(*args):
|
|
403
|
+
is_none = [x is None for x in args]
|
|
404
|
+
return all(is_none) or not any(is_none)
|
|
239
405
|
|
|
240
406
|
|
|
241
|
-
def
|
|
242
|
-
|
|
407
|
+
def _asarray_or_none(x):
|
|
408
|
+
if x is None:
|
|
409
|
+
return None
|
|
410
|
+
return jnp.asarray(x)
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def _choose_suffstat_batch_size(
|
|
414
|
+
resid_batch_size, count_batch_size, y, forest_size
|
|
415
|
+
) -> tuple[int | None, ...]:
|
|
416
|
+
@cache
|
|
243
417
|
def get_platform():
|
|
244
418
|
try:
|
|
245
419
|
device = y.devices().pop()
|
|
@@ -247,16 +421,17 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
|
|
|
247
421
|
device = jax.devices()[0]
|
|
248
422
|
platform = device.platform
|
|
249
423
|
if platform not in ('cpu', 'gpu'):
|
|
250
|
-
|
|
424
|
+
msg = f'Unknown platform: {platform}'
|
|
425
|
+
raise KeyError(msg)
|
|
251
426
|
return platform
|
|
252
427
|
|
|
253
428
|
if resid_batch_size == 'auto':
|
|
254
429
|
platform = get_platform()
|
|
255
430
|
n = max(1, y.size)
|
|
256
431
|
if platform == 'cpu':
|
|
257
|
-
resid_batch_size = 2 **
|
|
432
|
+
resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6
|
|
258
433
|
elif platform == 'gpu':
|
|
259
|
-
resid_batch_size = 2 **
|
|
434
|
+
resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3
|
|
260
435
|
resid_batch_size = max(1, resid_batch_size)
|
|
261
436
|
|
|
262
437
|
if count_batch_size == 'auto':
|
|
@@ -265,253 +440,381 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
|
|
|
265
440
|
count_batch_size = None
|
|
266
441
|
elif platform == 'gpu':
|
|
267
442
|
n = max(1, y.size)
|
|
268
|
-
count_batch_size = 2 **
|
|
443
|
+
count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2
|
|
269
444
|
# /4 is good on V100, /2 on L4/T4, still haven't tried A100
|
|
270
445
|
max_memory = 2**29
|
|
271
446
|
itemsize = 4
|
|
272
|
-
min_batch_size =
|
|
447
|
+
min_batch_size = math.ceil(forest_size * itemsize * n / max_memory)
|
|
273
448
|
count_batch_size = max(count_batch_size, min_batch_size)
|
|
274
449
|
count_batch_size = max(1, count_batch_size)
|
|
275
450
|
|
|
276
451
|
return resid_batch_size, count_batch_size
|
|
277
452
|
|
|
278
453
|
|
|
279
|
-
|
|
454
|
+
@jax.jit
|
|
455
|
+
def step(key: Key[Array, ''], bart: State) -> State:
|
|
280
456
|
"""
|
|
281
|
-
|
|
457
|
+
Do one MCMC step.
|
|
282
458
|
|
|
283
459
|
Parameters
|
|
284
460
|
----------
|
|
285
|
-
key
|
|
461
|
+
key
|
|
286
462
|
A jax random key.
|
|
287
|
-
bart
|
|
463
|
+
bart
|
|
288
464
|
A BART mcmc state, as created by `init`.
|
|
289
465
|
|
|
290
466
|
Returns
|
|
291
467
|
-------
|
|
292
|
-
|
|
293
|
-
The new BART mcmc state.
|
|
468
|
+
The new BART mcmc state.
|
|
294
469
|
"""
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
470
|
+
keys = split(key)
|
|
471
|
+
|
|
472
|
+
if bart.y.dtype == bool: # binary regression
|
|
473
|
+
bart = replace(bart, sigma2=jnp.float32(1))
|
|
474
|
+
bart = step_trees(keys.pop(), bart)
|
|
475
|
+
bart = replace(bart, sigma2=None)
|
|
476
|
+
return step_z(keys.pop(), bart)
|
|
298
477
|
|
|
478
|
+
else: # continuous regression
|
|
479
|
+
bart = step_trees(keys.pop(), bart)
|
|
480
|
+
return step_sigma(keys.pop(), bart)
|
|
299
481
|
|
|
300
|
-
|
|
482
|
+
|
|
483
|
+
def step_trees(key: Key[Array, ''], bart: State) -> State:
|
|
301
484
|
"""
|
|
302
485
|
Forest sampling step of BART MCMC.
|
|
303
486
|
|
|
304
487
|
Parameters
|
|
305
488
|
----------
|
|
306
|
-
key
|
|
489
|
+
key
|
|
307
490
|
A jax random key.
|
|
308
|
-
bart
|
|
491
|
+
bart
|
|
309
492
|
A BART mcmc state, as created by `init`.
|
|
310
493
|
|
|
311
494
|
Returns
|
|
312
495
|
-------
|
|
313
|
-
|
|
314
|
-
The new BART mcmc state.
|
|
496
|
+
The new BART mcmc state.
|
|
315
497
|
|
|
316
498
|
Notes
|
|
317
499
|
-----
|
|
318
500
|
This function zeroes the proposal counters.
|
|
319
501
|
"""
|
|
320
|
-
|
|
321
|
-
moves =
|
|
322
|
-
return accept_moves_and_sample_leaves(
|
|
502
|
+
keys = split(key)
|
|
503
|
+
moves = propose_moves(keys.pop(), bart.forest)
|
|
504
|
+
return accept_moves_and_sample_leaves(keys.pop(), bart, moves)
|
|
505
|
+
|
|
323
506
|
|
|
507
|
+
class Moves(Module):
|
|
508
|
+
"""
|
|
509
|
+
Moves proposed to modify each tree.
|
|
324
510
|
|
|
325
|
-
|
|
511
|
+
Parameters
|
|
512
|
+
----------
|
|
513
|
+
allowed
|
|
514
|
+
Whether there is a possible move. If `False`, the other values may not
|
|
515
|
+
make sense. The only case in which a move is marked as allowed but is
|
|
516
|
+
then vetoed is if it does not satisfy `min_points_per_leaf`, which for
|
|
517
|
+
efficiency is implemented post-hoc without changing the rest of the
|
|
518
|
+
MCMC logic.
|
|
519
|
+
grow
|
|
520
|
+
Whether the move is a grow move or a prune move.
|
|
521
|
+
num_growable
|
|
522
|
+
The number of growable leaves in the original tree.
|
|
523
|
+
node
|
|
524
|
+
The index of the leaf to grow or node to prune.
|
|
525
|
+
left
|
|
526
|
+
right
|
|
527
|
+
The indices of the children of 'node'.
|
|
528
|
+
partial_ratio
|
|
529
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks the
|
|
530
|
+
likelihood ratio, the probability of proposing the prune move, and the
|
|
531
|
+
probability that the children of the modified node are terminal. If the
|
|
532
|
+
move is PRUNE, the ratio is inverted. `None` once
|
|
533
|
+
`log_trans_prior_ratio` has been computed.
|
|
534
|
+
log_trans_prior_ratio
|
|
535
|
+
The logarithm of the product of the transition and prior terms of the
|
|
536
|
+
Metropolis-Hastings ratio for the acceptance of the proposed move.
|
|
537
|
+
`None` if not yet computed. If PRUNE, the log-ratio is negated.
|
|
538
|
+
grow_var
|
|
539
|
+
The decision axes of the new rules.
|
|
540
|
+
grow_split
|
|
541
|
+
The decision boundaries of the new rules.
|
|
542
|
+
var_tree
|
|
543
|
+
The updated decision axes of the trees, valid whatever move.
|
|
544
|
+
affluence_tree
|
|
545
|
+
A partially updated `affluence_tree`, marking non-leaf nodes that would
|
|
546
|
+
become leaves if the move was accepted. This mark initially (out of
|
|
547
|
+
`propose_moves`) takes into account if there would be available decision
|
|
548
|
+
rules to grow the leaf, and whether there are enough datapoints in the
|
|
549
|
+
node is marked in `accept_moves_parallel_stage`.
|
|
550
|
+
logu
|
|
551
|
+
The logarithm of a uniform (0, 1] random variable to be used to
|
|
552
|
+
accept the move. It's in (-oo, 0].
|
|
553
|
+
acc
|
|
554
|
+
Whether the move was accepted. `None` if not yet computed.
|
|
555
|
+
to_prune
|
|
556
|
+
Whether the final operation to apply the move is pruning. This indicates
|
|
557
|
+
an accepted prune move or a rejected grow move. `None` if not yet
|
|
558
|
+
computed.
|
|
559
|
+
"""
|
|
560
|
+
|
|
561
|
+
allowed: Bool[Array, ' num_trees']
|
|
562
|
+
grow: Bool[Array, ' num_trees']
|
|
563
|
+
num_growable: UInt[Array, ' num_trees']
|
|
564
|
+
node: UInt[Array, ' num_trees']
|
|
565
|
+
left: UInt[Array, ' num_trees']
|
|
566
|
+
right: UInt[Array, ' num_trees']
|
|
567
|
+
partial_ratio: Float32[Array, ' num_trees'] | None
|
|
568
|
+
log_trans_prior_ratio: None | Float32[Array, ' num_trees']
|
|
569
|
+
grow_var: UInt[Array, ' num_trees']
|
|
570
|
+
grow_split: UInt[Array, ' num_trees']
|
|
571
|
+
var_tree: UInt[Array, 'num_trees 2**(d-1)']
|
|
572
|
+
affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
|
|
573
|
+
logu: Float32[Array, ' num_trees']
|
|
574
|
+
acc: None | Bool[Array, ' num_trees']
|
|
575
|
+
to_prune: None | Bool[Array, ' num_trees']
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
|
|
326
579
|
"""
|
|
327
580
|
Propose moves for all the trees.
|
|
328
581
|
|
|
582
|
+
There are two types of moves: GROW (convert a leaf to a decision node and
|
|
583
|
+
add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
|
|
584
|
+
leaf, deleting its children).
|
|
585
|
+
|
|
329
586
|
Parameters
|
|
330
587
|
----------
|
|
331
|
-
key
|
|
588
|
+
key
|
|
332
589
|
A jax random key.
|
|
333
|
-
|
|
334
|
-
BART
|
|
590
|
+
forest
|
|
591
|
+
The `forest` field of a BART MCMC state.
|
|
335
592
|
|
|
336
593
|
Returns
|
|
337
594
|
-------
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
Whether the move is possible.
|
|
343
|
-
'grow' : bool array (num_trees,)
|
|
344
|
-
Whether the move is a grow move or a prune move.
|
|
345
|
-
'num_growable' : int array (num_trees,)
|
|
346
|
-
The number of growable leaves in the original tree.
|
|
347
|
-
'node' : int array (num_trees,)
|
|
348
|
-
The index of the leaf to grow or node to prune.
|
|
349
|
-
'left', 'right' : int array (num_trees,)
|
|
350
|
-
The indices of the children of 'node'.
|
|
351
|
-
'partial_ratio' : float array (num_trees,)
|
|
352
|
-
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
353
|
-
the likelihood ratio and the probability of proposing the prune
|
|
354
|
-
move. If the move is Prune, the ratio is inverted.
|
|
355
|
-
'grow_var' : int array (num_trees,)
|
|
356
|
-
The decision axes of the new rules.
|
|
357
|
-
'grow_split' : int array (num_trees,)
|
|
358
|
-
The decision boundaries of the new rules.
|
|
359
|
-
'var_trees' : int array (num_trees, 2 ** (d - 1))
|
|
360
|
-
The updated decision axes of the trees, valid whatever move.
|
|
361
|
-
'logu' : float array (num_trees,)
|
|
362
|
-
The logarithm of a uniform (0, 1] random variable to be used to
|
|
363
|
-
accept the move. It's in (-oo, 0].
|
|
364
|
-
"""
|
|
365
|
-
ntree = bart['leaf_trees'].shape[0]
|
|
366
|
-
key = random.split(key, 1 + ntree)
|
|
367
|
-
key, subkey = key[0], key[1:]
|
|
595
|
+
The proposed move for each tree.
|
|
596
|
+
"""
|
|
597
|
+
num_trees, _ = forest.leaf_tree.shape
|
|
598
|
+
keys = split(key, 1 + 2 * num_trees)
|
|
368
599
|
|
|
369
600
|
# compute moves
|
|
370
|
-
grow_moves
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
601
|
+
grow_moves = propose_grow_moves(
|
|
602
|
+
keys.pop(num_trees),
|
|
603
|
+
forest.var_tree,
|
|
604
|
+
forest.split_tree,
|
|
605
|
+
forest.affluence_tree,
|
|
606
|
+
forest.max_split,
|
|
607
|
+
forest.blocked_vars,
|
|
608
|
+
forest.p_nonterminal,
|
|
609
|
+
forest.p_propose_grow,
|
|
610
|
+
forest.log_s,
|
|
611
|
+
)
|
|
612
|
+
prune_moves = propose_prune_moves(
|
|
613
|
+
keys.pop(num_trees),
|
|
614
|
+
forest.split_tree,
|
|
615
|
+
grow_moves.affluence_tree,
|
|
616
|
+
forest.p_nonterminal,
|
|
617
|
+
forest.p_propose_grow,
|
|
378
618
|
)
|
|
379
619
|
|
|
380
|
-
u,
|
|
620
|
+
u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))
|
|
381
621
|
|
|
382
622
|
# choose between grow or prune
|
|
383
|
-
|
|
384
|
-
|
|
623
|
+
p_grow = jnp.where(
|
|
624
|
+
grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed
|
|
625
|
+
)
|
|
385
626
|
grow = u < p_grow # use < instead of <= because u is in [0, 1)
|
|
386
627
|
|
|
387
628
|
# compute children indices
|
|
388
|
-
node = jnp.where(grow, grow_moves
|
|
629
|
+
node = jnp.where(grow, grow_moves.node, prune_moves.node)
|
|
389
630
|
left = node << 1
|
|
390
631
|
right = left + 1
|
|
391
632
|
|
|
392
|
-
return
|
|
393
|
-
allowed=
|
|
633
|
+
return Moves(
|
|
634
|
+
allowed=grow_moves.allowed | prune_moves.allowed,
|
|
394
635
|
grow=grow,
|
|
395
|
-
num_growable=grow_moves
|
|
636
|
+
num_growable=grow_moves.num_growable,
|
|
396
637
|
node=node,
|
|
397
638
|
left=left,
|
|
398
639
|
right=right,
|
|
399
640
|
partial_ratio=jnp.where(
|
|
400
|
-
grow, grow_moves
|
|
641
|
+
grow, grow_moves.partial_ratio, prune_moves.partial_ratio
|
|
401
642
|
),
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
643
|
+
log_trans_prior_ratio=None, # will be set in complete_ratio
|
|
644
|
+
grow_var=grow_moves.var,
|
|
645
|
+
grow_split=grow_moves.split,
|
|
646
|
+
# var_tree does not need to be updated if prune
|
|
647
|
+
var_tree=grow_moves.var_tree,
|
|
648
|
+
# affluence_tree is updated for both moves unconditionally, prune last
|
|
649
|
+
affluence_tree=prune_moves.affluence_tree,
|
|
650
|
+
logu=jnp.log1p(-exp1mlogu),
|
|
651
|
+
acc=None, # will be set in accept_moves_sequential_stage
|
|
652
|
+
to_prune=None, # will be set in accept_moves_sequential_stage
|
|
406
653
|
)
|
|
407
654
|
|
|
408
655
|
|
|
409
|
-
|
|
410
|
-
def _sample_moves_vmap_trees(*args):
|
|
411
|
-
key, args = args[0], args[1:]
|
|
412
|
-
key, key1 = random.split(key)
|
|
413
|
-
grow = grow_move(key, *args)
|
|
414
|
-
prune = prune_move(key1, *args)
|
|
415
|
-
return grow, prune
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
def grow_move(
|
|
419
|
-
key, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow
|
|
420
|
-
):
|
|
656
|
+
class GrowMoves(Module):
|
|
421
657
|
"""
|
|
422
|
-
|
|
658
|
+
Represent a proposed grow move for each tree.
|
|
423
659
|
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
660
|
+
Parameters
|
|
661
|
+
----------
|
|
662
|
+
allowed
|
|
663
|
+
Whether the move is allowed for proposal.
|
|
664
|
+
num_growable
|
|
665
|
+
The number of leaves that can be proposed for grow.
|
|
666
|
+
node
|
|
667
|
+
The index of the leaf to grow. ``2 ** d`` if there are no growable
|
|
668
|
+
leaves.
|
|
669
|
+
var
|
|
670
|
+
split
|
|
671
|
+
The decision axis and boundary of the new rule.
|
|
672
|
+
partial_ratio
|
|
673
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
674
|
+
the likelihood ratio and the probability of proposing the prune
|
|
675
|
+
move.
|
|
676
|
+
var_tree
|
|
677
|
+
The updated decision axes of the tree.
|
|
678
|
+
affluence_tree
|
|
679
|
+
A partially updated `affluence_tree` that marks each new leaf that
|
|
680
|
+
would be produced as `True` if it would have available decision rules.
|
|
681
|
+
"""
|
|
682
|
+
|
|
683
|
+
allowed: Bool[Array, ' num_trees']
|
|
684
|
+
num_growable: UInt[Array, ' num_trees']
|
|
685
|
+
node: UInt[Array, ' num_trees']
|
|
686
|
+
var: UInt[Array, ' num_trees']
|
|
687
|
+
split: UInt[Array, ' num_trees']
|
|
688
|
+
partial_ratio: Float32[Array, ' num_trees']
|
|
689
|
+
var_tree: UInt[Array, 'num_trees 2**(d-1)']
|
|
690
|
+
affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
|
|
694
|
+
def propose_grow_moves(
|
|
695
|
+
key: Key[Array, ' num_trees'],
|
|
696
|
+
var_tree: UInt[Array, 'num_trees 2**(d-1)'],
|
|
697
|
+
split_tree: UInt[Array, 'num_trees 2**(d-1)'],
|
|
698
|
+
affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
|
|
699
|
+
max_split: UInt[Array, ' p'],
|
|
700
|
+
blocked_vars: Int32[Array, ' k'] | None,
|
|
701
|
+
p_nonterminal: Float32[Array, ' 2**d'],
|
|
702
|
+
p_propose_grow: Float32[Array, ' 2**(d-1)'],
|
|
703
|
+
log_s: Float32[Array, ' p'] | None,
|
|
704
|
+
) -> GrowMoves:
|
|
705
|
+
"""
|
|
706
|
+
Propose a GROW move for each tree.
|
|
707
|
+
|
|
708
|
+
A GROW move picks a leaf node and converts it to a non-terminal node with
|
|
709
|
+
two leaf children.
|
|
427
710
|
|
|
428
711
|
Parameters
|
|
429
712
|
----------
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
713
|
+
key
|
|
714
|
+
A jax random key.
|
|
715
|
+
var_tree
|
|
716
|
+
The splitting axes of the tree.
|
|
717
|
+
split_tree
|
|
433
718
|
The splitting points of the tree.
|
|
434
|
-
affluence_tree
|
|
435
|
-
Whether
|
|
436
|
-
max_split
|
|
719
|
+
affluence_tree
|
|
720
|
+
Whether each leaf has enough points to be grown.
|
|
721
|
+
max_split
|
|
437
722
|
The maximum split index for each variable.
|
|
438
|
-
|
|
439
|
-
The
|
|
440
|
-
|
|
723
|
+
blocked_vars
|
|
724
|
+
The indices of the variables that have no available cutpoints.
|
|
725
|
+
p_nonterminal
|
|
726
|
+
The a priori probability of a node to be nonterminal conditional on the
|
|
727
|
+
ancestors, including at the maximum depth where it should be zero.
|
|
728
|
+
p_propose_grow
|
|
441
729
|
The unnormalized probability of choosing a leaf to grow.
|
|
442
|
-
|
|
443
|
-
|
|
730
|
+
log_s
|
|
731
|
+
Unnormalized log-probability used to choose a variable to split on
|
|
732
|
+
amongst the available ones.
|
|
444
733
|
|
|
445
734
|
Returns
|
|
446
735
|
-------
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
'partial_ratio' : float
|
|
458
|
-
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
459
|
-
the likelihood ratio and the probability of proposing the prune
|
|
460
|
-
move.
|
|
461
|
-
'var_tree' : array (2 ** (d - 1),)
|
|
462
|
-
The updated decision axes of the tree.
|
|
463
|
-
"""
|
|
464
|
-
|
|
465
|
-
key, key1, key2 = random.split(key, 3)
|
|
736
|
+
An object representing the proposed move.
|
|
737
|
+
|
|
738
|
+
Notes
|
|
739
|
+
-----
|
|
740
|
+
The move is not proposed if each leaf is already at maximum depth, or has
|
|
741
|
+
less datapoints than the requested threshold `min_points_per_decision_node`,
|
|
742
|
+
or it does not have any available decision rules given its ancestors. This
|
|
743
|
+
is marked by setting `allowed` to `False` and `num_growable` to 0.
|
|
744
|
+
"""
|
|
745
|
+
keys = split(key, 3)
|
|
466
746
|
|
|
467
747
|
leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
|
|
468
|
-
|
|
748
|
+
keys.pop(), split_tree, affluence_tree, p_propose_grow
|
|
469
749
|
)
|
|
470
750
|
|
|
471
|
-
|
|
472
|
-
|
|
751
|
+
# sample a decision rule
|
|
752
|
+
var, num_available_var = choose_variable(
|
|
753
|
+
keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
|
|
754
|
+
)
|
|
755
|
+
split_idx, l, r = choose_split(
|
|
756
|
+
keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
|
|
757
|
+
)
|
|
473
758
|
|
|
474
|
-
|
|
759
|
+
# determine if the new leaves would have available decision rules; if the
|
|
760
|
+
# move is blocked, these values may not make sense
|
|
761
|
+
left_growable = right_growable = num_available_var > 1
|
|
762
|
+
left_growable |= l < split_idx
|
|
763
|
+
right_growable |= split_idx + 1 < r
|
|
764
|
+
left = leaf_to_grow << 1
|
|
765
|
+
right = left + 1
|
|
766
|
+
affluence_tree = affluence_tree.at[left].set(left_growable)
|
|
767
|
+
affluence_tree = affluence_tree.at[right].set(right_growable)
|
|
475
768
|
|
|
476
769
|
ratio = compute_partial_ratio(
|
|
477
770
|
prob_choose, num_prunable, p_nonterminal, leaf_to_grow
|
|
478
771
|
)
|
|
479
772
|
|
|
480
|
-
return
|
|
773
|
+
return GrowMoves(
|
|
774
|
+
allowed=num_growable > 0,
|
|
481
775
|
num_growable=num_growable,
|
|
482
776
|
node=leaf_to_grow,
|
|
483
777
|
var=var,
|
|
484
|
-
split=
|
|
778
|
+
split=split_idx,
|
|
485
779
|
partial_ratio=ratio,
|
|
486
|
-
var_tree=var_tree,
|
|
780
|
+
var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
|
|
781
|
+
affluence_tree=affluence_tree,
|
|
487
782
|
)
|
|
488
783
|
|
|
489
784
|
|
|
490
|
-
def choose_leaf(
|
|
785
|
+
def choose_leaf(
|
|
786
|
+
key: Key[Array, ''],
|
|
787
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
788
|
+
affluence_tree: Bool[Array, ' 2**(d-1)'],
|
|
789
|
+
p_propose_grow: Float32[Array, ' 2**(d-1)'],
|
|
790
|
+
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
|
|
491
791
|
"""
|
|
492
792
|
Choose a leaf node to grow in a tree.
|
|
493
793
|
|
|
494
794
|
Parameters
|
|
495
795
|
----------
|
|
496
|
-
|
|
796
|
+
key
|
|
797
|
+
A jax random key.
|
|
798
|
+
split_tree
|
|
497
799
|
The splitting points of the tree.
|
|
498
|
-
affluence_tree
|
|
499
|
-
Whether a leaf has enough points
|
|
500
|
-
|
|
800
|
+
affluence_tree
|
|
801
|
+
Whether a leaf has enough points that it could be split into two leaves
|
|
802
|
+
satisfying the `min_points_per_leaf` requirement.
|
|
803
|
+
p_propose_grow
|
|
501
804
|
The unnormalized probability of choosing a leaf to grow.
|
|
502
|
-
key : jax.dtypes.prng_key array
|
|
503
|
-
A jax random key.
|
|
504
805
|
|
|
505
806
|
Returns
|
|
506
807
|
-------
|
|
507
|
-
leaf_to_grow :
|
|
808
|
+
leaf_to_grow : Int32[Array, '']
|
|
508
809
|
The index of the leaf to grow. If ``num_growable == 0``, return
|
|
509
810
|
``2 ** d``.
|
|
510
|
-
num_growable :
|
|
511
|
-
The number of leaf nodes that can be grown.
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
811
|
+
num_growable : Int32[Array, '']
|
|
812
|
+
The number of leaf nodes that can be grown, i.e., are nonterminal
|
|
813
|
+
and have at least twice `min_points_per_leaf`.
|
|
814
|
+
prob_choose : Float32[Array, '']
|
|
815
|
+
The (normalized) probability that this function had to choose that
|
|
816
|
+
specific leaf, given the arguments.
|
|
817
|
+
num_prunable : Int32[Array, '']
|
|
515
818
|
The number of leaf parents that could be pruned, after converting the
|
|
516
819
|
selected leaf to a non-terminal node.
|
|
517
820
|
"""
|
|
@@ -520,145 +823,189 @@ def choose_leaf(key, split_tree, affluence_tree, p_propose_grow):
|
|
|
520
823
|
distr = jnp.where(is_growable, p_propose_grow, 0)
|
|
521
824
|
leaf_to_grow, distr_norm = categorical(key, distr)
|
|
522
825
|
leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
|
|
523
|
-
prob_choose = distr[leaf_to_grow] / distr_norm
|
|
826
|
+
prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1)
|
|
524
827
|
is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
|
|
525
828
|
num_prunable = jnp.count_nonzero(is_parent)
|
|
526
829
|
return leaf_to_grow, num_growable, prob_choose, num_prunable
|
|
527
830
|
|
|
528
831
|
|
|
529
|
-
def growable_leaves(
|
|
832
|
+
def growable_leaves(
|
|
833
|
+
split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
|
|
834
|
+
) -> Bool[Array, ' 2**(d-1)']:
|
|
530
835
|
"""
|
|
531
836
|
Return a mask indicating the leaf nodes that can be proposed for growth.
|
|
532
837
|
|
|
838
|
+
The condition is that a leaf is not at the bottom level, has available
|
|
839
|
+
decision rules given its ancestors, and has at least
|
|
840
|
+
`min_points_per_decision_node` points.
|
|
841
|
+
|
|
533
842
|
Parameters
|
|
534
843
|
----------
|
|
535
|
-
split_tree
|
|
844
|
+
split_tree
|
|
536
845
|
The splitting points of the tree.
|
|
537
|
-
affluence_tree
|
|
538
|
-
|
|
846
|
+
affluence_tree
|
|
847
|
+
Marks leaves that can be grown.
|
|
539
848
|
|
|
540
849
|
Returns
|
|
541
850
|
-------
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
851
|
+
The mask indicating the leaf nodes that can be proposed to grow.
|
|
852
|
+
|
|
853
|
+
Notes
|
|
854
|
+
-----
|
|
855
|
+
This function needs `split_tree` and not just `affluence_tree` because
|
|
856
|
+
`affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
|
|
546
857
|
"""
|
|
547
|
-
|
|
548
|
-
if affluence_tree is not None:
|
|
549
|
-
is_growable &= affluence_tree
|
|
550
|
-
return is_growable
|
|
858
|
+
return grove.is_actual_leaf(split_tree) & affluence_tree
|
|
551
859
|
|
|
552
860
|
|
|
553
|
-
def categorical(
|
|
861
|
+
def categorical(
|
|
862
|
+
key: Key[Array, ''], distr: Float32[Array, ' n']
|
|
863
|
+
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
|
|
554
864
|
"""
|
|
555
865
|
Return a random integer from an arbitrary distribution.
|
|
556
866
|
|
|
557
867
|
Parameters
|
|
558
868
|
----------
|
|
559
|
-
key
|
|
869
|
+
key
|
|
560
870
|
A jax random key.
|
|
561
|
-
distr
|
|
871
|
+
distr
|
|
562
872
|
An unnormalized probability distribution.
|
|
563
873
|
|
|
564
874
|
Returns
|
|
565
875
|
-------
|
|
566
|
-
u :
|
|
876
|
+
u : Int32[Array, '']
|
|
567
877
|
A random integer in the range ``[0, n)``. If all probabilities are zero,
|
|
568
878
|
return ``n``.
|
|
569
|
-
norm :
|
|
879
|
+
norm : Float32[Array, '']
|
|
570
880
|
The sum of `distr`.
|
|
881
|
+
|
|
882
|
+
Notes
|
|
883
|
+
-----
|
|
884
|
+
This function uses a cumsum instead of the Gumbel trick, so it's ok only
|
|
885
|
+
for small ranges with probabilities well greater than 0.
|
|
571
886
|
"""
|
|
572
887
|
ecdf = jnp.cumsum(distr)
|
|
573
888
|
u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
|
|
574
889
|
return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
|
|
575
890
|
|
|
576
891
|
|
|
577
|
-
def choose_variable(
|
|
892
|
+
def choose_variable(
|
|
893
|
+
key: Key[Array, ''],
|
|
894
|
+
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
895
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
896
|
+
max_split: UInt[Array, ' p'],
|
|
897
|
+
leaf_index: Int32[Array, ''],
|
|
898
|
+
blocked_vars: Int32[Array, ' k'] | None,
|
|
899
|
+
log_s: Float32[Array, ' p'] | None,
|
|
900
|
+
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
|
|
578
901
|
"""
|
|
579
902
|
Choose a variable to split on for a new non-terminal node.
|
|
580
903
|
|
|
581
904
|
Parameters
|
|
582
905
|
----------
|
|
583
|
-
|
|
906
|
+
key
|
|
907
|
+
A jax random key.
|
|
908
|
+
var_tree
|
|
584
909
|
The variable indices of the tree.
|
|
585
|
-
split_tree
|
|
910
|
+
split_tree
|
|
586
911
|
The splitting points of the tree.
|
|
587
|
-
max_split
|
|
912
|
+
max_split
|
|
588
913
|
The maximum split index for each variable.
|
|
589
|
-
leaf_index
|
|
914
|
+
leaf_index
|
|
590
915
|
The index of the leaf to grow.
|
|
591
|
-
|
|
592
|
-
|
|
916
|
+
blocked_vars
|
|
917
|
+
The indices of the variables that have no available cutpoints. If
|
|
918
|
+
`None`, all variables are assumed unblocked.
|
|
919
|
+
log_s
|
|
920
|
+
The logarithm of the prior probability for choosing a variable. If
|
|
921
|
+
`None`, use a uniform distribution.
|
|
593
922
|
|
|
594
923
|
Returns
|
|
595
924
|
-------
|
|
596
|
-
var :
|
|
925
|
+
var : Int32[Array, '']
|
|
597
926
|
The index of the variable to split on.
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
The variable is chosen among the variables that have a non-empty range of
|
|
602
|
-
allowed splits. If no variable has a non-empty range, return `p`.
|
|
927
|
+
num_available_var : Int32[Array, '']
|
|
928
|
+
The number of variables with available decision rules `var` was chosen
|
|
929
|
+
from.
|
|
603
930
|
"""
|
|
604
931
|
var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
|
|
605
|
-
|
|
932
|
+
if blocked_vars is not None:
|
|
933
|
+
var_to_ignore = jnp.concatenate([var_to_ignore, blocked_vars])
|
|
934
|
+
|
|
935
|
+
if log_s is None:
|
|
936
|
+
return randint_exclude(key, max_split.size, var_to_ignore)
|
|
937
|
+
else:
|
|
938
|
+
return categorical_exclude(key, log_s, var_to_ignore)
|
|
606
939
|
|
|
607
940
|
|
|
608
|
-
def fully_used_variables(
|
|
941
|
+
def fully_used_variables(
|
|
942
|
+
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
943
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
944
|
+
max_split: UInt[Array, ' p'],
|
|
945
|
+
leaf_index: Int32[Array, ''],
|
|
946
|
+
) -> UInt[Array, ' d-2']:
|
|
609
947
|
"""
|
|
610
|
-
|
|
948
|
+
Find variables in the ancestors of a node that have an empty split range.
|
|
611
949
|
|
|
612
950
|
Parameters
|
|
613
951
|
----------
|
|
614
|
-
var_tree
|
|
952
|
+
var_tree
|
|
615
953
|
The variable indices of the tree.
|
|
616
|
-
split_tree
|
|
954
|
+
split_tree
|
|
617
955
|
The splitting points of the tree.
|
|
618
|
-
max_split
|
|
956
|
+
max_split
|
|
619
957
|
The maximum split index for each variable.
|
|
620
|
-
leaf_index
|
|
958
|
+
leaf_index
|
|
621
959
|
The index of the node, assumed to be valid for `var_tree`.
|
|
622
960
|
|
|
623
961
|
Returns
|
|
624
962
|
-------
|
|
625
|
-
|
|
626
|
-
The indices of the variables that have an empty split range. Since the
|
|
627
|
-
number of such variables is not fixed, unused values in the array are
|
|
628
|
-
filled with `p`. The fill values are not guaranteed to be placed in any
|
|
629
|
-
particular order. Variables may appear more than once.
|
|
630
|
-
"""
|
|
963
|
+
The indices of the variables that have an empty split range.
|
|
631
964
|
|
|
965
|
+
Notes
|
|
966
|
+
-----
|
|
967
|
+
The number of unused variables is not known in advance. Unused values in the
|
|
968
|
+
array are filled with `p`. The fill values are not guaranteed to be placed
|
|
969
|
+
in any particular order, and variables may appear more than once.
|
|
970
|
+
"""
|
|
632
971
|
var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
|
|
633
972
|
split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
|
|
634
973
|
l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
|
|
635
974
|
num_split = r - l
|
|
636
975
|
return jnp.where(num_split == 0, var_to_ignore, max_split.size)
|
|
976
|
+
# the type of var_to_ignore is already sufficient to hold max_split.size,
|
|
977
|
+
# see ancestor_variables()
|
|
637
978
|
|
|
638
979
|
|
|
639
|
-
def ancestor_variables(
|
|
980
|
+
def ancestor_variables(
|
|
981
|
+
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
982
|
+
max_split: UInt[Array, ' p'],
|
|
983
|
+
node_index: Int32[Array, ''],
|
|
984
|
+
) -> UInt[Array, ' d-2']:
|
|
640
985
|
"""
|
|
641
986
|
Return the list of variables in the ancestors of a node.
|
|
642
987
|
|
|
643
988
|
Parameters
|
|
644
989
|
----------
|
|
645
|
-
var_tree
|
|
990
|
+
var_tree
|
|
646
991
|
The variable indices of the tree.
|
|
647
|
-
max_split
|
|
992
|
+
max_split
|
|
648
993
|
The maximum split index for each variable. Used only to get `p`.
|
|
649
|
-
node_index
|
|
994
|
+
node_index
|
|
650
995
|
The index of the node, assumed to be valid for `var_tree`.
|
|
651
996
|
|
|
652
997
|
Returns
|
|
653
998
|
-------
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
999
|
+
The variable indices of the ancestors of the node.
|
|
1000
|
+
|
|
1001
|
+
Notes
|
|
1002
|
+
-----
|
|
1003
|
+
The ancestors are the nodes going from the root to the parent of the node.
|
|
1004
|
+
The number of ancestors is not known at tracing time; unused spots in the
|
|
1005
|
+
output array are filled with `p`.
|
|
657
1006
|
"""
|
|
658
1007
|
max_num_ancestors = grove.tree_depth(var_tree) - 1
|
|
659
|
-
ancestor_vars = jnp.zeros(
|
|
660
|
-
max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size)
|
|
661
|
-
)
|
|
1008
|
+
ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
|
|
662
1009
|
carry = ancestor_vars.size - 1, node_index, ancestor_vars
|
|
663
1010
|
|
|
664
1011
|
def loop(carry, _):
|
|
@@ -673,33 +1020,38 @@ def ancestor_variables(var_tree, max_split, node_index):
|
|
|
673
1020
|
return ancestor_vars
|
|
674
1021
|
|
|
675
1022
|
|
|
676
|
-
def split_range(
|
|
1023
|
+
def split_range(
|
|
1024
|
+
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
1025
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
1026
|
+
max_split: UInt[Array, ' p'],
|
|
1027
|
+
node_index: Int32[Array, ''],
|
|
1028
|
+
ref_var: Int32[Array, ''],
|
|
1029
|
+
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
|
|
677
1030
|
"""
|
|
678
1031
|
Return the range of allowed splits for a variable at a given node.
|
|
679
1032
|
|
|
680
1033
|
Parameters
|
|
681
1034
|
----------
|
|
682
|
-
var_tree
|
|
1035
|
+
var_tree
|
|
683
1036
|
The variable indices of the tree.
|
|
684
|
-
split_tree
|
|
1037
|
+
split_tree
|
|
685
1038
|
The splitting points of the tree.
|
|
686
|
-
max_split
|
|
1039
|
+
max_split
|
|
687
1040
|
The maximum split index for each variable.
|
|
688
|
-
node_index
|
|
1041
|
+
node_index
|
|
689
1042
|
The index of the node, assumed to be valid for `var_tree`.
|
|
690
|
-
ref_var
|
|
1043
|
+
ref_var
|
|
691
1044
|
The variable for which to measure the split range.
|
|
692
1045
|
|
|
693
1046
|
Returns
|
|
694
1047
|
-------
|
|
695
|
-
l, r
|
|
696
|
-
The range of allowed splits is [l, r).
|
|
1048
|
+
The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
|
|
697
1049
|
"""
|
|
698
1050
|
max_num_ancestors = grove.tree_depth(var_tree) - 1
|
|
699
1051
|
initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
|
|
700
1052
|
jnp.int32
|
|
701
1053
|
)
|
|
702
|
-
carry = 0, initial_r, node_index
|
|
1054
|
+
carry = jnp.int32(0), initial_r, node_index
|
|
703
1055
|
|
|
704
1056
|
def loop(carry, _):
|
|
705
1057
|
l, r, index = carry
|
|
@@ -715,259 +1067,501 @@ def split_range(var_tree, split_tree, max_split, node_index, ref_var):
|
|
|
715
1067
|
return l + 1, r
|
|
716
1068
|
|
|
717
1069
|
|
|
718
|
-
def randint_exclude(
|
|
1070
|
+
def randint_exclude(
|
|
1071
|
+
key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
|
|
1072
|
+
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
|
|
719
1073
|
"""
|
|
720
1074
|
Return a random integer in a range, excluding some values.
|
|
721
1075
|
|
|
722
1076
|
Parameters
|
|
723
1077
|
----------
|
|
724
|
-
key
|
|
1078
|
+
key
|
|
725
1079
|
A jax random key.
|
|
726
|
-
sup
|
|
1080
|
+
sup
|
|
727
1081
|
The exclusive upper bound of the range.
|
|
728
|
-
exclude
|
|
1082
|
+
exclude
|
|
729
1083
|
The values to exclude from the range. Values greater than or equal to
|
|
730
1084
|
`sup` are ignored. Values can appear more than once.
|
|
731
1085
|
|
|
732
1086
|
Returns
|
|
733
1087
|
-------
|
|
734
|
-
u :
|
|
735
|
-
A random integer in the range ``[0, sup)
|
|
736
|
-
|
|
737
|
-
|
|
1088
|
+
u : Int32[Array, '']
|
|
1089
|
+
A random integer `u` in the range ``[0, sup)`` such that ``u not in
|
|
1090
|
+
exclude``.
|
|
1091
|
+
num_allowed : Int32[Array, '']
|
|
1092
|
+
The number of integers in the range that were not excluded.
|
|
1093
|
+
|
|
1094
|
+
Notes
|
|
1095
|
+
-----
|
|
1096
|
+
If all values in the range are excluded, return `sup`.
|
|
738
1097
|
"""
|
|
739
|
-
exclude =
|
|
740
|
-
num_allowed = sup - jnp.count_nonzero(exclude < sup)
|
|
1098
|
+
exclude, num_allowed = _process_exclude(sup, exclude)
|
|
741
1099
|
u = random.randint(key, (), 0, num_allowed)
|
|
742
1100
|
|
|
743
|
-
def loop(u,
|
|
744
|
-
return jnp.where(
|
|
1101
|
+
def loop(u, i_excluded):
|
|
1102
|
+
return jnp.where(i_excluded <= u, u + 1, u), None
|
|
745
1103
|
|
|
746
1104
|
u, _ = lax.scan(loop, u, exclude)
|
|
747
|
-
return u
|
|
1105
|
+
return u, num_allowed
|
|
1106
|
+
|
|
1107
|
+
|
|
1108
|
+
def _process_exclude(sup, exclude):
|
|
1109
|
+
exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
|
|
1110
|
+
num_allowed = sup - jnp.count_nonzero(exclude < sup)
|
|
1111
|
+
return exclude, num_allowed
|
|
1112
|
+
|
|
1113
|
+
|
|
1114
|
+
def categorical_exclude(
|
|
1115
|
+
key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
|
|
1116
|
+
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
|
|
1117
|
+
"""
|
|
1118
|
+
Draw from a categorical distribution, excluding a set of values.
|
|
1119
|
+
|
|
1120
|
+
Parameters
|
|
1121
|
+
----------
|
|
1122
|
+
key
|
|
1123
|
+
A jax random key.
|
|
1124
|
+
logits
|
|
1125
|
+
The unnormalized log-probabilities of each category.
|
|
1126
|
+
exclude
|
|
1127
|
+
The values to exclude from the range [0, k). Values greater than or
|
|
1128
|
+
equal to `logits.size` are ignored. Values can appear more than once.
|
|
1129
|
+
|
|
1130
|
+
Returns
|
|
1131
|
+
-------
|
|
1132
|
+
u : Int32[Array, '']
|
|
1133
|
+
A random integer in the range ``[0, k)`` such that ``u not in exclude``.
|
|
1134
|
+
num_allowed : Int32[Array, '']
|
|
1135
|
+
The number of integers in the range that were not excluded.
|
|
1136
|
+
|
|
1137
|
+
Notes
|
|
1138
|
+
-----
|
|
1139
|
+
If all values in the range are excluded, the result is unspecified.
|
|
1140
|
+
"""
|
|
1141
|
+
exclude, num_allowed = _process_exclude(logits.size, exclude)
|
|
1142
|
+
kinda_neg_inf = jnp.finfo(logits.dtype).min
|
|
1143
|
+
logits = logits.at[exclude].set(kinda_neg_inf)
|
|
1144
|
+
u = random.categorical(key, logits)
|
|
1145
|
+
return u, num_allowed
|
|
748
1146
|
|
|
749
1147
|
|
|
750
|
-
def choose_split(
|
|
1148
|
+
def choose_split(
|
|
1149
|
+
key: Key[Array, ''],
|
|
1150
|
+
var: Int32[Array, ''],
|
|
1151
|
+
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
1152
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
1153
|
+
max_split: UInt[Array, ' p'],
|
|
1154
|
+
leaf_index: Int32[Array, ''],
|
|
1155
|
+
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
|
|
751
1156
|
"""
|
|
752
1157
|
Choose a split point for a new non-terminal node.
|
|
753
1158
|
|
|
754
1159
|
Parameters
|
|
755
1160
|
----------
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
1161
|
+
key
|
|
1162
|
+
A jax random key.
|
|
1163
|
+
var
|
|
1164
|
+
The variable to split on.
|
|
1165
|
+
var_tree
|
|
1166
|
+
The splitting axes of the tree. Does not need to already contain `var`
|
|
1167
|
+
at `leaf_index`.
|
|
1168
|
+
split_tree
|
|
759
1169
|
The splitting points of the tree.
|
|
760
|
-
max_split
|
|
1170
|
+
max_split
|
|
761
1171
|
The maximum split index for each variable.
|
|
762
|
-
leaf_index
|
|
763
|
-
The index of the leaf to grow.
|
|
764
|
-
contains the target variable at this index.
|
|
765
|
-
key : jax.dtypes.prng_key array
|
|
766
|
-
A jax random key.
|
|
1172
|
+
leaf_index
|
|
1173
|
+
The index of the leaf to grow.
|
|
767
1174
|
|
|
768
1175
|
Returns
|
|
769
1176
|
-------
|
|
770
|
-
split :
|
|
771
|
-
The
|
|
1177
|
+
split : Int32[Array, '']
|
|
1178
|
+
The cutpoint.
|
|
1179
|
+
l : Int32[Array, '']
|
|
1180
|
+
r : Int32[Array, '']
|
|
1181
|
+
The integer range `split` was drawn from is [l, r).
|
|
1182
|
+
|
|
1183
|
+
Notes
|
|
1184
|
+
-----
|
|
1185
|
+
If `var` is out of bounds, or if the available split range on that variable
|
|
1186
|
+
is empty, return 0.
|
|
772
1187
|
"""
|
|
773
|
-
var = var_tree[leaf_index]
|
|
774
1188
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
775
|
-
return random.randint(key, (), l, r)
|
|
1189
|
+
return jnp.where(l < r, random.randint(key, (), l, r), 0), l, r
|
|
776
1190
|
|
|
777
1191
|
|
|
778
|
-
def compute_partial_ratio(
|
|
1192
|
+
def compute_partial_ratio(
|
|
1193
|
+
prob_choose: Float32[Array, ''],
|
|
1194
|
+
num_prunable: Int32[Array, ''],
|
|
1195
|
+
p_nonterminal: Float32[Array, ' 2**d'],
|
|
1196
|
+
leaf_to_grow: Int32[Array, ''],
|
|
1197
|
+
) -> Float32[Array, '']:
|
|
779
1198
|
"""
|
|
780
1199
|
Compute the product of the transition and prior ratios of a grow move.
|
|
781
1200
|
|
|
782
1201
|
Parameters
|
|
783
1202
|
----------
|
|
784
|
-
|
|
785
|
-
The
|
|
786
|
-
|
|
1203
|
+
prob_choose
|
|
1204
|
+
The probability that the leaf had to be chosen amongst the growable
|
|
1205
|
+
leaves.
|
|
1206
|
+
num_prunable
|
|
787
1207
|
The number of leaf parents that could be pruned, after converting the
|
|
788
1208
|
leaf to be grown to a non-terminal node.
|
|
789
|
-
p_nonterminal
|
|
790
|
-
The probability of
|
|
791
|
-
|
|
1209
|
+
p_nonterminal
|
|
1210
|
+
The a priori probability of each node being nonterminal conditional on
|
|
1211
|
+
its ancestors.
|
|
1212
|
+
leaf_to_grow
|
|
792
1213
|
The index of the leaf to grow.
|
|
793
1214
|
|
|
794
1215
|
Returns
|
|
795
1216
|
-------
|
|
796
|
-
ratio
|
|
797
|
-
The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
|
|
798
|
-
times the prior ratio P(new tree) / P(old tree), but the transition
|
|
799
|
-
ratio is missing the factor P(propose prune) in the numerator.
|
|
800
|
-
"""
|
|
1217
|
+
The partial transition ratio times the prior ratio.
|
|
801
1218
|
|
|
1219
|
+
Notes
|
|
1220
|
+
-----
|
|
1221
|
+
The transition ratio is P(new tree => old tree) / P(old tree => new tree).
|
|
1222
|
+
The "partial" transition ratio returned is missing the factor P(propose
|
|
1223
|
+
prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The
|
|
1224
|
+
"partial" prior ratio is missing the factor P(children are leaves).
|
|
1225
|
+
"""
|
|
802
1226
|
# the two ratios also contain factors num_available_split *
|
|
803
|
-
# num_available_var, but they cancel out
|
|
1227
|
+
# num_available_var * s[var], but they cancel out
|
|
804
1228
|
|
|
805
|
-
# p_prune
|
|
806
|
-
# computed
|
|
1229
|
+
# p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be
|
|
1230
|
+
# computed here because they need the count trees, which are computed in the
|
|
1231
|
+
# acceptance phase
|
|
807
1232
|
|
|
808
1233
|
prune_allowed = leaf_to_grow != 1
|
|
809
1234
|
# prune allowed <---> the initial tree is not a root
|
|
810
1235
|
# leaf to grow is root --> the tree can only be a root
|
|
811
1236
|
# tree is a root --> the only leaf I can grow is root
|
|
812
|
-
|
|
813
1237
|
p_grow = jnp.where(prune_allowed, 0.5, 1)
|
|
814
|
-
|
|
815
1238
|
inv_trans_ratio = p_grow * prob_choose * num_prunable
|
|
816
1239
|
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
tree_ratio =
|
|
1240
|
+
# .at.get because if leaf_to_grow is out of bounds (move not allowed), this
|
|
1241
|
+
# would produce a 0 and then an inf when `complete_ratio` takes the log
|
|
1242
|
+
pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
|
|
1243
|
+
tree_ratio = pnt / (1 - pnt)
|
|
1244
|
+
|
|
1245
|
+
return tree_ratio / jnp.where(inv_trans_ratio, inv_trans_ratio, 1)
|
|
821
1246
|
|
|
822
|
-
return tree_ratio / inv_trans_ratio
|
|
823
1247
|
|
|
1248
|
+
class PruneMoves(Module):
|
|
1249
|
+
"""
|
|
1250
|
+
Represent a proposed prune move for each tree.
|
|
824
1251
|
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
1252
|
+
Parameters
|
|
1253
|
+
----------
|
|
1254
|
+
allowed
|
|
1255
|
+
Whether the move is possible.
|
|
1256
|
+
node
|
|
1257
|
+
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
1258
|
+
partial_ratio
|
|
1259
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks the
|
|
1260
|
+
likelihood ratio, the probability of proposing the prune move, and the
|
|
1261
|
+
prior probability that the children of the node to prune are leaves.
|
|
1262
|
+
This ratio is inverted, and is meant to be inverted back in
|
|
1263
|
+
`accept_move_and_sample_leaves`.
|
|
1264
|
+
"""
|
|
1265
|
+
|
|
1266
|
+
allowed: Bool[Array, ' num_trees']
|
|
1267
|
+
node: UInt[Array, ' num_trees']
|
|
1268
|
+
partial_ratio: Float32[Array, ' num_trees']
|
|
1269
|
+
affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
|
|
1270
|
+
|
|
1271
|
+
|
|
1272
|
+
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
|
|
1273
|
+
def propose_prune_moves(
|
|
1274
|
+
key: Key[Array, ''],
|
|
1275
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
1276
|
+
affluence_tree: Bool[Array, ' 2**(d-1)'],
|
|
1277
|
+
p_nonterminal: Float32[Array, ' 2**d'],
|
|
1278
|
+
p_propose_grow: Float32[Array, ' 2**(d-1)'],
|
|
1279
|
+
) -> PruneMoves:
|
|
828
1280
|
"""
|
|
829
1281
|
Tree structure prune move proposal of BART MCMC.
|
|
830
1282
|
|
|
831
1283
|
Parameters
|
|
832
1284
|
----------
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
split_tree
|
|
1285
|
+
key
|
|
1286
|
+
A jax random key.
|
|
1287
|
+
split_tree
|
|
836
1288
|
The splitting points of the tree.
|
|
837
|
-
affluence_tree
|
|
838
|
-
Whether
|
|
839
|
-
|
|
840
|
-
The
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
p_propose_grow : float array (2 ** (d - 1),)
|
|
1289
|
+
affluence_tree
|
|
1290
|
+
Whether each leaf can be grown.
|
|
1291
|
+
p_nonterminal
|
|
1292
|
+
The a priori probability of a node to be nonterminal conditional on
|
|
1293
|
+
the ancestors, including at the maximum depth where it should be zero.
|
|
1294
|
+
p_propose_grow
|
|
844
1295
|
The unnormalized probability of choosing a leaf to grow.
|
|
845
|
-
key : jax.dtypes.prng_key array
|
|
846
|
-
A jax random key.
|
|
847
1296
|
|
|
848
1297
|
Returns
|
|
849
1298
|
-------
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
'allowed' : bool
|
|
854
|
-
Whether the move is possible.
|
|
855
|
-
'node' : int
|
|
856
|
-
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
857
|
-
'partial_ratio' : float
|
|
858
|
-
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
859
|
-
the likelihood ratio and the probability of proposing the prune
|
|
860
|
-
move. This ratio is inverted.
|
|
861
|
-
"""
|
|
862
|
-
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
|
|
1299
|
+
An object representing the proposed moves.
|
|
1300
|
+
"""
|
|
1301
|
+
node_to_prune, num_prunable, prob_choose, affluence_tree = choose_leaf_parent(
|
|
863
1302
|
key, split_tree, affluence_tree, p_propose_grow
|
|
864
1303
|
)
|
|
865
|
-
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
866
1304
|
|
|
867
1305
|
ratio = compute_partial_ratio(
|
|
868
1306
|
prob_choose, num_prunable, p_nonterminal, node_to_prune
|
|
869
1307
|
)
|
|
870
1308
|
|
|
871
|
-
return
|
|
872
|
-
allowed=allowed
|
|
1309
|
+
return PruneMoves(
|
|
1310
|
+
allowed=split_tree[1].astype(bool), # allowed iff the tree is not a root
|
|
873
1311
|
node=node_to_prune,
|
|
874
|
-
partial_ratio=ratio,
|
|
1312
|
+
partial_ratio=ratio,
|
|
1313
|
+
affluence_tree=affluence_tree,
|
|
875
1314
|
)
|
|
876
1315
|
|
|
877
1316
|
|
|
878
|
-
def choose_leaf_parent(
|
|
1317
|
+
def choose_leaf_parent(
|
|
1318
|
+
key: Key[Array, ''],
|
|
1319
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
1320
|
+
affluence_tree: Bool[Array, ' 2**(d-1)'],
|
|
1321
|
+
p_propose_grow: Float32[Array, ' 2**(d-1)'],
|
|
1322
|
+
) -> tuple[
|
|
1323
|
+
Int32[Array, ''],
|
|
1324
|
+
Int32[Array, ''],
|
|
1325
|
+
Float32[Array, ''],
|
|
1326
|
+
Bool[Array, 'num_trees 2**(d-1)'],
|
|
1327
|
+
]:
|
|
879
1328
|
"""
|
|
880
1329
|
Pick a non-terminal node with leaf children to prune in a tree.
|
|
881
1330
|
|
|
882
1331
|
Parameters
|
|
883
1332
|
----------
|
|
884
|
-
|
|
1333
|
+
key
|
|
1334
|
+
A jax random key.
|
|
1335
|
+
split_tree
|
|
885
1336
|
The splitting points of the tree.
|
|
886
|
-
affluence_tree
|
|
1337
|
+
affluence_tree
|
|
887
1338
|
Whether a leaf has enough points to be grown.
|
|
888
|
-
p_propose_grow
|
|
1339
|
+
p_propose_grow
|
|
889
1340
|
The unnormalized probability of choosing a leaf to grow.
|
|
890
|
-
key : jax.dtypes.prng_key array
|
|
891
|
-
A jax random key.
|
|
892
1341
|
|
|
893
1342
|
Returns
|
|
894
1343
|
-------
|
|
895
|
-
node_to_prune :
|
|
1344
|
+
node_to_prune : Int32[Array, '']
|
|
896
1345
|
The index of the node to prune. If ``num_prunable == 0``, return
|
|
897
1346
|
``2 ** d``.
|
|
898
|
-
num_prunable :
|
|
1347
|
+
num_prunable : Int32[Array, '']
|
|
899
1348
|
The number of leaf parents that could be pruned.
|
|
900
|
-
prob_choose :
|
|
901
|
-
The normalized probability
|
|
902
|
-
|
|
1349
|
+
prob_choose : Float32[Array, '']
|
|
1350
|
+
The (normalized) probability that `choose_leaf` would chose
|
|
1351
|
+
`node_to_prune` as leaf to grow, if passed the tree where
|
|
1352
|
+
`node_to_prune` had been pruned.
|
|
1353
|
+
affluence_tree : Bool[Array, 'num_trees 2**(d-1)']
|
|
1354
|
+
A partially updated `affluence_tree`, marking the node to prune as
|
|
1355
|
+
growable.
|
|
1356
|
+
"""
|
|
1357
|
+
# sample a node to prune
|
|
903
1358
|
is_prunable = grove.is_leaves_parent(split_tree)
|
|
904
1359
|
num_prunable = jnp.count_nonzero(is_prunable)
|
|
905
1360
|
node_to_prune = randint_masked(key, is_prunable)
|
|
906
1361
|
node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
|
|
907
1362
|
|
|
1363
|
+
# compute stuff for reverse move
|
|
908
1364
|
split_tree = split_tree.at[node_to_prune].set(0)
|
|
909
|
-
affluence_tree = (
|
|
910
|
-
None if affluence_tree is None else affluence_tree.at[node_to_prune].set(True)
|
|
911
|
-
)
|
|
1365
|
+
affluence_tree = affluence_tree.at[node_to_prune].set(True)
|
|
912
1366
|
is_growable_leaf = growable_leaves(split_tree, affluence_tree)
|
|
913
|
-
|
|
914
|
-
prob_choose
|
|
1367
|
+
distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
|
|
1368
|
+
prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
|
|
1369
|
+
prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)
|
|
915
1370
|
|
|
916
|
-
return node_to_prune, num_prunable, prob_choose
|
|
1371
|
+
return node_to_prune, num_prunable, prob_choose, affluence_tree
|
|
917
1372
|
|
|
918
1373
|
|
|
919
|
-
def randint_masked(key, mask):
|
|
1374
|
+
def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
|
|
920
1375
|
"""
|
|
921
1376
|
Return a random integer in a range, including only some values.
|
|
922
1377
|
|
|
923
1378
|
Parameters
|
|
924
1379
|
----------
|
|
925
|
-
key
|
|
1380
|
+
key
|
|
926
1381
|
A jax random key.
|
|
927
|
-
mask
|
|
1382
|
+
mask
|
|
928
1383
|
The mask indicating the allowed values.
|
|
929
1384
|
|
|
930
1385
|
Returns
|
|
931
1386
|
-------
|
|
932
|
-
u
|
|
933
|
-
|
|
934
|
-
|
|
1387
|
+
A random integer in the range ``[0, n)`` such that ``mask[u] == True``.
|
|
1388
|
+
|
|
1389
|
+
Notes
|
|
1390
|
+
-----
|
|
1391
|
+
If all values in the mask are `False`, return `n`.
|
|
935
1392
|
"""
|
|
936
1393
|
ecdf = jnp.cumsum(mask)
|
|
937
1394
|
u = random.randint(key, (), 0, ecdf[-1])
|
|
938
1395
|
return jnp.searchsorted(ecdf, u, 'right')
|
|
939
1396
|
|
|
940
1397
|
|
|
941
|
-
def accept_moves_and_sample_leaves(
|
|
1398
|
+
def accept_moves_and_sample_leaves(
|
|
1399
|
+
key: Key[Array, ''], bart: State, moves: Moves
|
|
1400
|
+
) -> State:
|
|
942
1401
|
"""
|
|
943
1402
|
Accept or reject the proposed moves and sample the new leaf values.
|
|
944
1403
|
|
|
945
1404
|
Parameters
|
|
946
1405
|
----------
|
|
947
|
-
key
|
|
1406
|
+
key
|
|
948
1407
|
A jax random key.
|
|
949
|
-
bart
|
|
950
|
-
A BART mcmc state.
|
|
951
|
-
moves
|
|
952
|
-
The proposed moves, see `
|
|
1408
|
+
bart
|
|
1409
|
+
A valid BART mcmc state.
|
|
1410
|
+
moves
|
|
1411
|
+
The proposed moves, see `propose_moves`.
|
|
953
1412
|
|
|
954
1413
|
Returns
|
|
955
1414
|
-------
|
|
956
|
-
|
|
957
|
-
The new BART mcmc state.
|
|
1415
|
+
A new (valid) BART mcmc state.
|
|
958
1416
|
"""
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
)
|
|
962
|
-
bart, moves = accept_moves_sequential_stage(
|
|
963
|
-
bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
|
|
964
|
-
)
|
|
1417
|
+
pso = accept_moves_parallel_stage(key, bart, moves)
|
|
1418
|
+
bart, moves = accept_moves_sequential_stage(pso)
|
|
965
1419
|
return accept_moves_final_stage(bart, moves)
|
|
966
1420
|
|
|
967
1421
|
|
|
968
|
-
|
|
1422
|
+
class Counts(Module):
|
|
1423
|
+
"""
|
|
1424
|
+
Number of datapoints in the nodes involved in proposed moves for each tree.
|
|
1425
|
+
|
|
1426
|
+
Parameters
|
|
1427
|
+
----------
|
|
1428
|
+
left
|
|
1429
|
+
Number of datapoints in the left child.
|
|
1430
|
+
right
|
|
1431
|
+
Number of datapoints in the right child.
|
|
1432
|
+
total
|
|
1433
|
+
Number of datapoints in the parent (``= left + right``).
|
|
1434
|
+
"""
|
|
1435
|
+
|
|
1436
|
+
left: UInt[Array, ' num_trees']
|
|
1437
|
+
right: UInt[Array, ' num_trees']
|
|
1438
|
+
total: UInt[Array, ' num_trees']
|
|
1439
|
+
|
|
1440
|
+
|
|
1441
|
+
class Precs(Module):
|
|
969
1442
|
"""
|
|
970
|
-
|
|
1443
|
+
Likelihood precision scale in the nodes involved in proposed moves for each tree.
|
|
1444
|
+
|
|
1445
|
+
The "likelihood precision scale" of a tree node is the sum of the inverse
|
|
1446
|
+
squared error scales of the datapoints selected by the node.
|
|
1447
|
+
|
|
1448
|
+
Parameters
|
|
1449
|
+
----------
|
|
1450
|
+
left
|
|
1451
|
+
Likelihood precision scale in the left child.
|
|
1452
|
+
right
|
|
1453
|
+
Likelihood precision scale in the right child.
|
|
1454
|
+
total
|
|
1455
|
+
Likelihood precision scale in the parent (``= left + right``).
|
|
1456
|
+
"""
|
|
1457
|
+
|
|
1458
|
+
left: Float32[Array, ' num_trees']
|
|
1459
|
+
right: Float32[Array, ' num_trees']
|
|
1460
|
+
total: Float32[Array, ' num_trees']
|
|
1461
|
+
|
|
1462
|
+
|
|
1463
|
+
class PreLkV(Module):
|
|
1464
|
+
"""
|
|
1465
|
+
Non-sequential terms of the likelihood ratio for each tree.
|
|
1466
|
+
|
|
1467
|
+
These terms can be computed in parallel across trees.
|
|
1468
|
+
|
|
1469
|
+
Parameters
|
|
1470
|
+
----------
|
|
1471
|
+
sigma2_left
|
|
1472
|
+
The noise variance in the left child of the leaves grown or pruned by
|
|
1473
|
+
the moves.
|
|
1474
|
+
sigma2_right
|
|
1475
|
+
The noise variance in the right child of the leaves grown or pruned by
|
|
1476
|
+
the moves.
|
|
1477
|
+
sigma2_total
|
|
1478
|
+
The noise variance in the total of the leaves grown or pruned by the
|
|
1479
|
+
moves.
|
|
1480
|
+
sqrt_term
|
|
1481
|
+
The **logarithm** of the square root term of the likelihood ratio.
|
|
1482
|
+
"""
|
|
1483
|
+
|
|
1484
|
+
sigma2_left: Float32[Array, ' num_trees']
|
|
1485
|
+
sigma2_right: Float32[Array, ' num_trees']
|
|
1486
|
+
sigma2_total: Float32[Array, ' num_trees']
|
|
1487
|
+
sqrt_term: Float32[Array, ' num_trees']
|
|
1488
|
+
|
|
1489
|
+
|
|
1490
|
+
class PreLk(Module):
|
|
1491
|
+
"""
|
|
1492
|
+
Non-sequential terms of the likelihood ratio shared by all trees.
|
|
1493
|
+
|
|
1494
|
+
Parameters
|
|
1495
|
+
----------
|
|
1496
|
+
exp_factor
|
|
1497
|
+
The factor to multiply the likelihood ratio by, shared by all trees.
|
|
1498
|
+
"""
|
|
1499
|
+
|
|
1500
|
+
exp_factor: Float32[Array, '']
|
|
1501
|
+
|
|
1502
|
+
|
|
1503
|
+
class PreLf(Module):
|
|
1504
|
+
"""
|
|
1505
|
+
Pre-computed terms used to sample leaves from their posterior.
|
|
1506
|
+
|
|
1507
|
+
These terms can be computed in parallel across trees.
|
|
1508
|
+
|
|
1509
|
+
Parameters
|
|
1510
|
+
----------
|
|
1511
|
+
mean_factor
|
|
1512
|
+
The factor to be multiplied by the sum of the scaled residuals to
|
|
1513
|
+
obtain the posterior mean.
|
|
1514
|
+
centered_leaves
|
|
1515
|
+
The mean-zero normal values to be added to the posterior mean to
|
|
1516
|
+
obtain the posterior leaf samples.
|
|
1517
|
+
"""
|
|
1518
|
+
|
|
1519
|
+
mean_factor: Float32[Array, 'num_trees 2**d']
|
|
1520
|
+
centered_leaves: Float32[Array, 'num_trees 2**d']
|
|
1521
|
+
|
|
1522
|
+
|
|
1523
|
+
class ParallelStageOut(Module):
|
|
1524
|
+
"""
|
|
1525
|
+
The output of `accept_moves_parallel_stage`.
|
|
1526
|
+
|
|
1527
|
+
Parameters
|
|
1528
|
+
----------
|
|
1529
|
+
bart
|
|
1530
|
+
A partially updated BART mcmc state.
|
|
1531
|
+
moves
|
|
1532
|
+
The proposed moves, with `partial_ratio` set to `None` and
|
|
1533
|
+
`log_trans_prior_ratio` set to its final value.
|
|
1534
|
+
prec_trees
|
|
1535
|
+
The likelihood precision scale in each potential or actual leaf node. If
|
|
1536
|
+
there is no precision scale, this is the number of points in each leaf.
|
|
1537
|
+
move_counts
|
|
1538
|
+
The counts of the number of points in the the nodes modified by the
|
|
1539
|
+
moves. If `bart.min_points_per_leaf` is not set and
|
|
1540
|
+
`bart.prec_scale` is set, they are not computed.
|
|
1541
|
+
move_precs
|
|
1542
|
+
The likelihood precision scale in each node modified by the moves. If
|
|
1543
|
+
`bart.prec_scale` is not set, this is set to `move_counts`.
|
|
1544
|
+
prelkv
|
|
1545
|
+
prelk
|
|
1546
|
+
prelf
|
|
1547
|
+
Objects with pre-computed terms of the likelihood ratios and leaf
|
|
1548
|
+
samples.
|
|
1549
|
+
"""
|
|
1550
|
+
|
|
1551
|
+
bart: State
|
|
1552
|
+
moves: Moves
|
|
1553
|
+
prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
|
|
1554
|
+
move_precs: Precs | Counts
|
|
1555
|
+
prelkv: PreLkV
|
|
1556
|
+
prelk: PreLk
|
|
1557
|
+
prelf: PreLf
|
|
1558
|
+
|
|
1559
|
+
|
|
1560
|
+
def accept_moves_parallel_stage(
|
|
1561
|
+
key: Key[Array, ''], bart: State, moves: Moves
|
|
1562
|
+
) -> ParallelStageOut:
|
|
1563
|
+
"""
|
|
1564
|
+
Pre-compute quantities used to accept moves, in parallel across trees.
|
|
971
1565
|
|
|
972
1566
|
Parameters
|
|
973
1567
|
----------
|
|
@@ -976,153 +1570,186 @@ def accept_moves_parallel_stage(key, bart, moves):
|
|
|
976
1570
|
bart : dict
|
|
977
1571
|
A BART mcmc state.
|
|
978
1572
|
moves : dict
|
|
979
|
-
The proposed moves, see `
|
|
1573
|
+
The proposed moves, see `propose_moves`.
|
|
980
1574
|
|
|
981
1575
|
Returns
|
|
982
1576
|
-------
|
|
983
|
-
|
|
984
|
-
A partially updated BART mcmc state.
|
|
985
|
-
moves : dict
|
|
986
|
-
The proposed moves, with the field 'partial_ratio' replaced
|
|
987
|
-
by 'log_trans_prior_ratio'.
|
|
988
|
-
prec_trees : float array (num_trees, 2 ** d)
|
|
989
|
-
The likelihood precision scale in each potential or actual leaf node. If
|
|
990
|
-
there is no precision scale, this is the number of points in each leaf.
|
|
991
|
-
move_counts : dict
|
|
992
|
-
The counts of the number of points in the the nodes modified by the
|
|
993
|
-
moves.
|
|
994
|
-
move_precs : dict
|
|
995
|
-
The likelihood precision scale in each node modified by the moves.
|
|
996
|
-
prelkv, prelk, prelf : dict
|
|
997
|
-
Dictionary with pre-computed terms of the likelihood ratios and leaf
|
|
998
|
-
samples.
|
|
1577
|
+
An object with all that could be done in parallel.
|
|
999
1578
|
"""
|
|
1000
|
-
bart = bart.copy()
|
|
1001
|
-
|
|
1002
1579
|
# where the move is grow, modify the state like the move was accepted
|
|
1003
|
-
bart
|
|
1004
|
-
|
|
1005
|
-
|
|
1580
|
+
bart = replace(
|
|
1581
|
+
bart,
|
|
1582
|
+
forest=replace(
|
|
1583
|
+
bart.forest,
|
|
1584
|
+
var_tree=moves.var_tree,
|
|
1585
|
+
leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
|
|
1586
|
+
leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
|
|
1587
|
+
),
|
|
1588
|
+
)
|
|
1006
1589
|
|
|
1007
1590
|
# count number of datapoints per leaf
|
|
1008
|
-
|
|
1009
|
-
bart
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1591
|
+
if (
|
|
1592
|
+
bart.forest.min_points_per_decision_node is not None
|
|
1593
|
+
or bart.forest.min_points_per_leaf is not None
|
|
1594
|
+
or bart.prec_scale is None
|
|
1595
|
+
):
|
|
1596
|
+
count_trees, move_counts = compute_count_trees(
|
|
1597
|
+
bart.forest.leaf_indices, moves, bart.forest.count_batch_size
|
|
1598
|
+
)
|
|
1599
|
+
|
|
1600
|
+
# mark which leaves & potential leaves have enough points to be grown
|
|
1601
|
+
if bart.forest.min_points_per_decision_node is not None:
|
|
1602
|
+
count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
|
|
1603
|
+
moves = replace(
|
|
1604
|
+
moves,
|
|
1605
|
+
affluence_tree=moves.affluence_tree
|
|
1606
|
+
& (count_half_trees >= bart.forest.min_points_per_decision_node),
|
|
1607
|
+
)
|
|
1608
|
+
|
|
1609
|
+
# copy updated affluence_tree to state
|
|
1610
|
+
bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)
|
|
1611
|
+
|
|
1612
|
+
# veto grove move if new leaves don't have enough datapoints
|
|
1613
|
+
if bart.forest.min_points_per_leaf is not None:
|
|
1614
|
+
moves = replace(
|
|
1615
|
+
moves,
|
|
1616
|
+
allowed=moves.allowed
|
|
1617
|
+
& (move_counts.left >= bart.forest.min_points_per_leaf)
|
|
1618
|
+
& (move_counts.right >= bart.forest.min_points_per_leaf),
|
|
1619
|
+
)
|
|
1014
1620
|
|
|
1015
1621
|
# count number of datapoints per leaf, weighted by error precision scale
|
|
1016
|
-
if bart
|
|
1622
|
+
if bart.prec_scale is None:
|
|
1017
1623
|
prec_trees = count_trees
|
|
1018
1624
|
move_precs = move_counts
|
|
1019
1625
|
else:
|
|
1020
1626
|
prec_trees, move_precs = compute_prec_trees(
|
|
1021
|
-
bart
|
|
1022
|
-
bart
|
|
1627
|
+
bart.prec_scale,
|
|
1628
|
+
bart.forest.leaf_indices,
|
|
1023
1629
|
moves,
|
|
1024
|
-
bart
|
|
1630
|
+
bart.forest.count_batch_size,
|
|
1025
1631
|
)
|
|
1632
|
+
assert move_precs is not None
|
|
1026
1633
|
|
|
1027
1634
|
# compute some missing information about moves
|
|
1028
|
-
moves = complete_ratio(moves,
|
|
1029
|
-
|
|
1030
|
-
bart
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1635
|
+
moves = complete_ratio(moves, bart.forest.p_nonterminal)
|
|
1636
|
+
save_ratios = bart.forest.log_likelihood is not None
|
|
1637
|
+
bart = replace(
|
|
1638
|
+
bart,
|
|
1639
|
+
forest=replace(
|
|
1640
|
+
bart.forest,
|
|
1641
|
+
grow_prop_count=jnp.sum(moves.grow),
|
|
1642
|
+
prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
|
|
1643
|
+
log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
|
|
1644
|
+
),
|
|
1645
|
+
)
|
|
1034
1646
|
|
|
1035
|
-
|
|
1647
|
+
# pre-compute some likelihood ratio & posterior terms
|
|
1648
|
+
assert bart.sigma2 is not None # `step` shall temporarily set it to 1
|
|
1649
|
+
prelkv, prelk = precompute_likelihood_terms(
|
|
1650
|
+
bart.sigma2, bart.forest.sigma_mu2, move_precs
|
|
1651
|
+
)
|
|
1652
|
+
prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)
|
|
1653
|
+
|
|
1654
|
+
return ParallelStageOut(
|
|
1655
|
+
bart=bart,
|
|
1656
|
+
moves=moves,
|
|
1657
|
+
prec_trees=prec_trees,
|
|
1658
|
+
move_precs=move_precs,
|
|
1659
|
+
prelkv=prelkv,
|
|
1660
|
+
prelk=prelk,
|
|
1661
|
+
prelf=prelf,
|
|
1662
|
+
)
|
|
1036
1663
|
|
|
1037
1664
|
|
|
1038
|
-
@
|
|
1039
|
-
def apply_grow_to_indices(
|
|
1665
|
+
@partial(vmap_nodoc, in_axes=(0, 0, None))
|
|
1666
|
+
def apply_grow_to_indices(
|
|
1667
|
+
moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
|
|
1668
|
+
) -> UInt[Array, 'num_trees n']:
|
|
1040
1669
|
"""
|
|
1041
1670
|
Update the leaf indices to apply a grow move.
|
|
1042
1671
|
|
|
1043
1672
|
Parameters
|
|
1044
1673
|
----------
|
|
1045
|
-
moves
|
|
1046
|
-
The proposed moves, see `
|
|
1047
|
-
leaf_indices
|
|
1674
|
+
moves
|
|
1675
|
+
The proposed moves, see `propose_moves`.
|
|
1676
|
+
leaf_indices
|
|
1048
1677
|
The index of the leaf each datapoint falls into.
|
|
1049
|
-
X
|
|
1678
|
+
X
|
|
1050
1679
|
The predictors matrix.
|
|
1051
1680
|
|
|
1052
1681
|
Returns
|
|
1053
1682
|
-------
|
|
1054
|
-
|
|
1055
|
-
The updated leaf indices.
|
|
1683
|
+
The updated leaf indices.
|
|
1056
1684
|
"""
|
|
1057
|
-
left_child = moves
|
|
1058
|
-
go_right = X[moves
|
|
1059
|
-
tree_size = jnp.array(2 * moves
|
|
1060
|
-
node_to_update = jnp.where(moves
|
|
1685
|
+
left_child = moves.node.astype(leaf_indices.dtype) << 1
|
|
1686
|
+
go_right = X[moves.grow_var, :] >= moves.grow_split
|
|
1687
|
+
tree_size = jnp.array(2 * moves.var_tree.size)
|
|
1688
|
+
node_to_update = jnp.where(moves.grow, moves.node, tree_size)
|
|
1061
1689
|
return jnp.where(
|
|
1062
|
-
leaf_indices == node_to_update,
|
|
1063
|
-
left_child + go_right,
|
|
1064
|
-
leaf_indices,
|
|
1690
|
+
leaf_indices == node_to_update, left_child + go_right, leaf_indices
|
|
1065
1691
|
)
|
|
1066
1692
|
|
|
1067
1693
|
|
|
1068
|
-
def compute_count_trees(
|
|
1694
|
+
def compute_count_trees(
|
|
1695
|
+
leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
|
|
1696
|
+
) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
|
|
1069
1697
|
"""
|
|
1070
1698
|
Count the number of datapoints in each leaf.
|
|
1071
1699
|
|
|
1072
1700
|
Parameters
|
|
1073
1701
|
----------
|
|
1074
|
-
leaf_indices
|
|
1702
|
+
leaf_indices
|
|
1075
1703
|
The index of the leaf each datapoint falls into, with the deeper version
|
|
1076
1704
|
of the tree (post-GROW, pre-PRUNE).
|
|
1077
|
-
moves
|
|
1078
|
-
The proposed moves, see `
|
|
1079
|
-
batch_size
|
|
1705
|
+
moves
|
|
1706
|
+
The proposed moves, see `propose_moves`.
|
|
1707
|
+
batch_size
|
|
1080
1708
|
The data batch size to use for the summation.
|
|
1081
1709
|
|
|
1082
1710
|
Returns
|
|
1083
1711
|
-------
|
|
1084
|
-
count_trees :
|
|
1712
|
+
count_trees : Int32[Array, 'num_trees 2**d']
|
|
1085
1713
|
The number of points in each potential or actual leaf node.
|
|
1086
|
-
counts :
|
|
1714
|
+
counts : Counts
|
|
1087
1715
|
The counts of the number of points in the leaves grown or pruned by the
|
|
1088
|
-
moves
|
|
1716
|
+
moves.
|
|
1089
1717
|
"""
|
|
1090
|
-
|
|
1091
|
-
ntree, tree_size = moves['var_trees'].shape
|
|
1718
|
+
num_trees, tree_size = moves.var_tree.shape
|
|
1092
1719
|
tree_size *= 2
|
|
1093
|
-
tree_indices = jnp.arange(
|
|
1720
|
+
tree_indices = jnp.arange(num_trees)
|
|
1094
1721
|
|
|
1095
1722
|
count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
|
|
1096
1723
|
|
|
1097
1724
|
# count datapoints in nodes modified by move
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
counts
|
|
1101
|
-
counts['total'] = counts['left'] + counts['right']
|
|
1725
|
+
left = count_trees[tree_indices, moves.left]
|
|
1726
|
+
right = count_trees[tree_indices, moves.right]
|
|
1727
|
+
counts = Counts(left=left, right=right, total=left + right)
|
|
1102
1728
|
|
|
1103
1729
|
# write count into non-leaf node
|
|
1104
|
-
count_trees = count_trees.at[tree_indices, moves
|
|
1730
|
+
count_trees = count_trees.at[tree_indices, moves.node].set(counts.total)
|
|
1105
1731
|
|
|
1106
1732
|
return count_trees, counts
|
|
1107
1733
|
|
|
1108
1734
|
|
|
1109
|
-
def count_datapoints_per_leaf(
|
|
1735
|
+
def count_datapoints_per_leaf(
|
|
1736
|
+
leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
|
|
1737
|
+
) -> Int32[Array, 'num_trees 2**(d-1)']:
|
|
1110
1738
|
"""
|
|
1111
1739
|
Count the number of datapoints in each leaf.
|
|
1112
1740
|
|
|
1113
1741
|
Parameters
|
|
1114
1742
|
----------
|
|
1115
|
-
leaf_indices
|
|
1743
|
+
leaf_indices
|
|
1116
1744
|
The index of the leaf each datapoint falls into.
|
|
1117
|
-
tree_size
|
|
1745
|
+
tree_size
|
|
1118
1746
|
The size of the leaf tree array (2 ** d).
|
|
1119
|
-
batch_size
|
|
1747
|
+
batch_size
|
|
1120
1748
|
The data batch size to use for the summation.
|
|
1121
1749
|
|
|
1122
1750
|
Returns
|
|
1123
1751
|
-------
|
|
1124
|
-
|
|
1125
|
-
The number of points in each leaf node.
|
|
1752
|
+
The number of points in each leaf node.
|
|
1126
1753
|
"""
|
|
1127
1754
|
if batch_size is None:
|
|
1128
1755
|
return _count_scan(leaf_indices, tree_size)
|
|
@@ -1130,7 +1757,9 @@ def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
|
|
|
1130
1757
|
return _count_vec(leaf_indices, tree_size, batch_size)
|
|
1131
1758
|
|
|
1132
1759
|
|
|
1133
|
-
def _count_scan(
|
|
1760
|
+
def _count_scan(
|
|
1761
|
+
leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
|
|
1762
|
+
) -> Int32[Array, 'num_trees {tree_size}']:
|
|
1134
1763
|
def loop(_, leaf_indices):
|
|
1135
1764
|
return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
|
|
1136
1765
|
|
|
@@ -1138,92 +1767,111 @@ def _count_scan(leaf_indices, tree_size):
|
|
|
1138
1767
|
return count_trees
|
|
1139
1768
|
|
|
1140
1769
|
|
|
1141
|
-
def _aggregate_scatter(
|
|
1770
|
+
def _aggregate_scatter(
|
|
1771
|
+
values: Shaped[Array, '*'],
|
|
1772
|
+
indices: Integer[Array, '*'],
|
|
1773
|
+
size: int,
|
|
1774
|
+
dtype: jnp.dtype,
|
|
1775
|
+
) -> Shaped[Array, ' {size}']:
|
|
1142
1776
|
return jnp.zeros(size, dtype).at[indices].add(values)
|
|
1143
1777
|
|
|
1144
1778
|
|
|
1145
|
-
def _count_vec(
|
|
1779
|
+
def _count_vec(
|
|
1780
|
+
leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
|
|
1781
|
+
) -> Int32[Array, 'num_trees 2**(d-1)']:
|
|
1146
1782
|
return _aggregate_batched_alltrees(
|
|
1147
1783
|
1, leaf_indices, tree_size, jnp.uint32, batch_size
|
|
1148
1784
|
)
|
|
1149
1785
|
# uint16 is super-slow on gpu, don't use it even if n < 2^16
|
|
1150
1786
|
|
|
1151
1787
|
|
|
1152
|
-
def _aggregate_batched_alltrees(
|
|
1153
|
-
|
|
1154
|
-
|
|
1788
|
+
def _aggregate_batched_alltrees(
|
|
1789
|
+
values: Shaped[Array, '*'],
|
|
1790
|
+
indices: UInt[Array, 'num_trees n'],
|
|
1791
|
+
size: int,
|
|
1792
|
+
dtype: jnp.dtype,
|
|
1793
|
+
batch_size: int,
|
|
1794
|
+
) -> Shaped[Array, 'num_trees {size}']:
|
|
1795
|
+
num_trees, n = indices.shape
|
|
1796
|
+
tree_indices = jnp.arange(num_trees)
|
|
1155
1797
|
nbatches = n // batch_size + bool(n % batch_size)
|
|
1156
1798
|
batch_indices = jnp.arange(n) % nbatches
|
|
1157
1799
|
return (
|
|
1158
|
-
jnp.zeros((
|
|
1800
|
+
jnp.zeros((num_trees, size, nbatches), dtype)
|
|
1159
1801
|
.at[tree_indices[:, None], indices, batch_indices]
|
|
1160
1802
|
.add(values)
|
|
1161
1803
|
.sum(axis=2)
|
|
1162
1804
|
)
|
|
1163
1805
|
|
|
1164
1806
|
|
|
1165
|
-
def compute_prec_trees(
|
|
1807
|
+
def compute_prec_trees(
|
|
1808
|
+
prec_scale: Float32[Array, ' n'],
|
|
1809
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1810
|
+
moves: Moves,
|
|
1811
|
+
batch_size: int | None,
|
|
1812
|
+
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
|
|
1166
1813
|
"""
|
|
1167
1814
|
Compute the likelihood precision scale in each leaf.
|
|
1168
1815
|
|
|
1169
1816
|
Parameters
|
|
1170
1817
|
----------
|
|
1171
|
-
prec_scale
|
|
1818
|
+
prec_scale
|
|
1172
1819
|
The scale of the precision of the error on each datapoint.
|
|
1173
|
-
leaf_indices
|
|
1820
|
+
leaf_indices
|
|
1174
1821
|
The index of the leaf each datapoint falls into, with the deeper version
|
|
1175
1822
|
of the tree (post-GROW, pre-PRUNE).
|
|
1176
|
-
moves
|
|
1177
|
-
The proposed moves, see `
|
|
1178
|
-
batch_size
|
|
1823
|
+
moves
|
|
1824
|
+
The proposed moves, see `propose_moves`.
|
|
1825
|
+
batch_size
|
|
1179
1826
|
The data batch size to use for the summation.
|
|
1180
1827
|
|
|
1181
1828
|
Returns
|
|
1182
1829
|
-------
|
|
1183
|
-
prec_trees :
|
|
1830
|
+
prec_trees : Float32[Array, 'num_trees 2**d']
|
|
1184
1831
|
The likelihood precision scale in each potential or actual leaf node.
|
|
1185
|
-
|
|
1186
|
-
The likelihood precision scale in the
|
|
1187
|
-
moves, under keys 'left', 'right', and 'total' (left + right).
|
|
1832
|
+
precs : Precs
|
|
1833
|
+
The likelihood precision scale in the nodes involved in the moves.
|
|
1188
1834
|
"""
|
|
1189
|
-
|
|
1190
|
-
ntree, tree_size = moves['var_trees'].shape
|
|
1835
|
+
num_trees, tree_size = moves.var_tree.shape
|
|
1191
1836
|
tree_size *= 2
|
|
1192
|
-
tree_indices = jnp.arange(
|
|
1837
|
+
tree_indices = jnp.arange(num_trees)
|
|
1193
1838
|
|
|
1194
1839
|
prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)
|
|
1195
1840
|
|
|
1196
1841
|
# prec datapoints in nodes modified by move
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
precs
|
|
1200
|
-
precs['total'] = precs['left'] + precs['right']
|
|
1842
|
+
left = prec_trees[tree_indices, moves.left]
|
|
1843
|
+
right = prec_trees[tree_indices, moves.right]
|
|
1844
|
+
precs = Precs(left=left, right=right, total=left + right)
|
|
1201
1845
|
|
|
1202
1846
|
# write prec into non-leaf node
|
|
1203
|
-
prec_trees = prec_trees.at[tree_indices, moves
|
|
1847
|
+
prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total)
|
|
1204
1848
|
|
|
1205
1849
|
return prec_trees, precs
|
|
1206
1850
|
|
|
1207
1851
|
|
|
1208
|
-
def prec_per_leaf(
|
|
1852
|
+
def prec_per_leaf(
|
|
1853
|
+
prec_scale: Float32[Array, ' n'],
|
|
1854
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1855
|
+
tree_size: int,
|
|
1856
|
+
batch_size: int | None,
|
|
1857
|
+
) -> Float32[Array, 'num_trees {tree_size}']:
|
|
1209
1858
|
"""
|
|
1210
1859
|
Compute the likelihood precision scale in each leaf.
|
|
1211
1860
|
|
|
1212
1861
|
Parameters
|
|
1213
1862
|
----------
|
|
1214
|
-
prec_scale
|
|
1863
|
+
prec_scale
|
|
1215
1864
|
The scale of the precision of the error on each datapoint.
|
|
1216
|
-
leaf_indices
|
|
1865
|
+
leaf_indices
|
|
1217
1866
|
The index of the leaf each datapoint falls into.
|
|
1218
|
-
tree_size
|
|
1867
|
+
tree_size
|
|
1219
1868
|
The size of the leaf tree array (2 ** d).
|
|
1220
|
-
batch_size
|
|
1869
|
+
batch_size
|
|
1221
1870
|
The data batch size to use for the summation.
|
|
1222
1871
|
|
|
1223
1872
|
Returns
|
|
1224
1873
|
-------
|
|
1225
|
-
|
|
1226
|
-
The likelihood precision scale in each leaf node.
|
|
1874
|
+
The likelihood precision scale in each leaf node.
|
|
1227
1875
|
"""
|
|
1228
1876
|
if batch_size is None:
|
|
1229
1877
|
return _prec_scan(prec_scale, leaf_indices, tree_size)
|
|
@@ -1231,432 +1879,439 @@ def prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size):
|
|
|
1231
1879
|
return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
|
|
1232
1880
|
|
|
1233
1881
|
|
|
1234
|
-
def _prec_scan(
|
|
1882
|
+
def _prec_scan(
|
|
1883
|
+
prec_scale: Float32[Array, ' n'],
|
|
1884
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1885
|
+
tree_size: int,
|
|
1886
|
+
) -> Float32[Array, 'num_trees {tree_size}']:
|
|
1235
1887
|
def loop(_, leaf_indices):
|
|
1236
1888
|
return None, _aggregate_scatter(
|
|
1237
1889
|
prec_scale, leaf_indices, tree_size, jnp.float32
|
|
1238
|
-
)
|
|
1890
|
+
)
|
|
1239
1891
|
|
|
1240
1892
|
_, prec_trees = lax.scan(loop, None, leaf_indices)
|
|
1241
1893
|
return prec_trees
|
|
1242
1894
|
|
|
1243
1895
|
|
|
1244
|
-
def _prec_vec(
|
|
1896
|
+
def _prec_vec(
|
|
1897
|
+
prec_scale: Float32[Array, ' n'],
|
|
1898
|
+
leaf_indices: UInt[Array, 'num_trees n'],
|
|
1899
|
+
tree_size: int,
|
|
1900
|
+
batch_size: int,
|
|
1901
|
+
) -> Float32[Array, 'num_trees {tree_size}']:
|
|
1245
1902
|
return _aggregate_batched_alltrees(
|
|
1246
1903
|
prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
|
|
1247
|
-
)
|
|
1904
|
+
)
|
|
1248
1905
|
|
|
1249
1906
|
|
|
1250
|
-
def complete_ratio(moves,
|
|
1907
|
+
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
|
|
1251
1908
|
"""
|
|
1252
1909
|
Complete non-likelihood MH ratio calculation.
|
|
1253
1910
|
|
|
1254
|
-
This function adds the probability of choosing
|
|
1911
|
+
This function adds the probability of choosing a prune move over the grow
|
|
1912
|
+
move in the inverse transition, and the a priori probability that the
|
|
1913
|
+
children nodes are leaves.
|
|
1255
1914
|
|
|
1256
1915
|
Parameters
|
|
1257
1916
|
----------
|
|
1258
|
-
moves
|
|
1259
|
-
The proposed moves
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1917
|
+
moves
|
|
1918
|
+
The proposed moves. Must have already been updated to keep into account
|
|
1919
|
+
the thresholds on the number of datapoints per node, this happens in
|
|
1920
|
+
`accept_moves_parallel_stage`.
|
|
1921
|
+
p_nonterminal
|
|
1922
|
+
The a priori probability of each node being nonterminal conditional on
|
|
1923
|
+
its ancestors, including at the maximum depth where it should be zero.
|
|
1265
1924
|
|
|
1266
1925
|
Returns
|
|
1267
1926
|
-------
|
|
1268
|
-
moves
|
|
1269
|
-
The updated moves, with the field 'partial_ratio' replaced by
|
|
1270
|
-
'log_trans_prior_ratio'.
|
|
1927
|
+
The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
|
|
1271
1928
|
"""
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1929
|
+
# can the leaves can be grown?
|
|
1930
|
+
num_trees, _ = moves.affluence_tree.shape
|
|
1931
|
+
tree_indices = jnp.arange(num_trees)
|
|
1932
|
+
left_growable = moves.affluence_tree.at[tree_indices, moves.left].get(
|
|
1933
|
+
mode='fill', fill_value=False
|
|
1934
|
+
)
|
|
1935
|
+
right_growable = moves.affluence_tree.at[tree_indices, moves.right].get(
|
|
1936
|
+
mode='fill', fill_value=False
|
|
1275
1937
|
)
|
|
1276
|
-
moves['log_trans_prior_ratio'] = jnp.log(moves.pop('partial_ratio') * p_prune)
|
|
1277
|
-
return moves
|
|
1278
|
-
|
|
1279
1938
|
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1939
|
+
# p_prune if grow
|
|
1940
|
+
other_growable_leaves = moves.num_growable >= 2
|
|
1941
|
+
grow_again_allowed = other_growable_leaves | left_growable | right_growable
|
|
1942
|
+
grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
|
|
1283
1943
|
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
moves : dict
|
|
1287
|
-
The proposed moves, see `sample_moves`.
|
|
1288
|
-
left_count, right_count : int
|
|
1289
|
-
The number of datapoints in the proposed children of the leaf to grow.
|
|
1290
|
-
min_points_per_leaf : int or None
|
|
1291
|
-
The minimum number of data points in a leaf node.
|
|
1944
|
+
# p_prune if prune
|
|
1945
|
+
prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)
|
|
1292
1946
|
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
p_prune : float
|
|
1296
|
-
The probability of proposing a prune move. If grow: after accepting the
|
|
1297
|
-
grow move, if prune: right away.
|
|
1298
|
-
"""
|
|
1299
|
-
|
|
1300
|
-
# calculation in case the move is grow
|
|
1301
|
-
other_growable_leaves = moves['num_growable'] >= 2
|
|
1302
|
-
new_leaves_growable = moves['node'] < moves['var_trees'].shape[1] // 2
|
|
1303
|
-
if min_points_per_leaf is not None:
|
|
1304
|
-
any_above_threshold = left_count >= 2 * min_points_per_leaf
|
|
1305
|
-
any_above_threshold |= right_count >= 2 * min_points_per_leaf
|
|
1306
|
-
new_leaves_growable &= any_above_threshold
|
|
1307
|
-
grow_again_allowed = other_growable_leaves | new_leaves_growable
|
|
1308
|
-
grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
|
|
1947
|
+
# select p_prune
|
|
1948
|
+
p_prune = jnp.where(moves.grow, grow_p_prune, prune_p_prune)
|
|
1309
1949
|
|
|
1310
|
-
#
|
|
1311
|
-
|
|
1950
|
+
# prior probability of both children being terminal
|
|
1951
|
+
pt_left = 1 - p_nonterminal[moves.left] * left_growable
|
|
1952
|
+
pt_right = 1 - p_nonterminal[moves.right] * right_growable
|
|
1953
|
+
pt_children = pt_left * pt_right
|
|
1312
1954
|
|
|
1313
|
-
return
|
|
1955
|
+
return replace(
|
|
1956
|
+
moves,
|
|
1957
|
+
log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune),
|
|
1958
|
+
partial_ratio=None,
|
|
1959
|
+
)
|
|
1314
1960
|
|
|
1315
1961
|
|
|
1316
|
-
@
|
|
1317
|
-
def adapt_leaf_trees_to_grow_indices(
|
|
1962
|
+
@vmap_nodoc
|
|
1963
|
+
def adapt_leaf_trees_to_grow_indices(
|
|
1964
|
+
leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
|
|
1965
|
+
) -> Float32[Array, 'num_trees 2**d']:
|
|
1318
1966
|
"""
|
|
1319
|
-
Modify
|
|
1320
|
-
|
|
1967
|
+
Modify leaves such that post-grow indices work on the original tree.
|
|
1968
|
+
|
|
1969
|
+
The value of the leaf to grow is copied to what would be its children if the
|
|
1970
|
+
grow move was accepted.
|
|
1321
1971
|
|
|
1322
1972
|
Parameters
|
|
1323
1973
|
----------
|
|
1324
|
-
leaf_trees
|
|
1974
|
+
leaf_trees
|
|
1325
1975
|
The leaf values.
|
|
1326
|
-
moves
|
|
1327
|
-
The proposed moves, see `
|
|
1976
|
+
moves
|
|
1977
|
+
The proposed moves, see `propose_moves`.
|
|
1328
1978
|
|
|
1329
1979
|
Returns
|
|
1330
1980
|
-------
|
|
1331
|
-
|
|
1332
|
-
The modified leaf values. The value of the leaf to grow is copied to
|
|
1333
|
-
what would be its children if the grow move was accepted.
|
|
1981
|
+
The modified leaf values.
|
|
1334
1982
|
"""
|
|
1335
|
-
values_at_node = leaf_trees[moves
|
|
1983
|
+
values_at_node = leaf_trees[moves.node]
|
|
1336
1984
|
return (
|
|
1337
|
-
leaf_trees.at[jnp.where(moves
|
|
1985
|
+
leaf_trees.at[jnp.where(moves.grow, moves.left, leaf_trees.size)]
|
|
1338
1986
|
.set(values_at_node)
|
|
1339
|
-
.at[jnp.where(moves
|
|
1987
|
+
.at[jnp.where(moves.grow, moves.right, leaf_trees.size)]
|
|
1340
1988
|
.set(values_at_node)
|
|
1341
1989
|
)
|
|
1342
1990
|
|
|
1343
1991
|
|
|
1344
|
-
def precompute_likelihood_terms(
|
|
1992
|
+
def precompute_likelihood_terms(
|
|
1993
|
+
sigma2: Float32[Array, ''],
|
|
1994
|
+
sigma_mu2: Float32[Array, ''],
|
|
1995
|
+
move_precs: Precs | Counts,
|
|
1996
|
+
) -> tuple[PreLkV, PreLk]:
|
|
1345
1997
|
"""
|
|
1346
1998
|
Pre-compute terms used in the likelihood ratio of the acceptance step.
|
|
1347
1999
|
|
|
1348
2000
|
Parameters
|
|
1349
2001
|
----------
|
|
1350
|
-
sigma2
|
|
1351
|
-
The
|
|
1352
|
-
|
|
2002
|
+
sigma2
|
|
2003
|
+
The error variance, or the global error variance factor is `prec_scale`
|
|
2004
|
+
is set.
|
|
2005
|
+
sigma_mu2
|
|
2006
|
+
The prior variance of each leaf.
|
|
2007
|
+
move_precs
|
|
1353
2008
|
The likelihood precision scale in the leaves grown or pruned by the
|
|
1354
2009
|
moves, under keys 'left', 'right', and 'total' (left + right).
|
|
1355
2010
|
|
|
1356
2011
|
Returns
|
|
1357
2012
|
-------
|
|
1358
|
-
prelkv :
|
|
2013
|
+
prelkv : PreLkV
|
|
1359
2014
|
Dictionary with pre-computed terms of the likelihood ratio, one per
|
|
1360
2015
|
tree.
|
|
1361
|
-
prelk :
|
|
2016
|
+
prelk : PreLk
|
|
1362
2017
|
Dictionary with pre-computed terms of the likelihood ratio, shared by
|
|
1363
2018
|
all trees.
|
|
1364
2019
|
"""
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
|
|
1368
|
-
prelkv
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
jnp.log(
|
|
1373
|
-
sigma2
|
|
1374
|
-
* prelkv['sigma2_total']
|
|
1375
|
-
/ (prelkv['sigma2_left'] * prelkv['sigma2_right'])
|
|
1376
|
-
)
|
|
1377
|
-
/ 2
|
|
1378
|
-
)
|
|
1379
|
-
return prelkv, dict(
|
|
1380
|
-
exp_factor=sigma_mu2 / (2 * sigma2),
|
|
2020
|
+
sigma2_left = sigma2 + move_precs.left * sigma_mu2
|
|
2021
|
+
sigma2_right = sigma2 + move_precs.right * sigma_mu2
|
|
2022
|
+
sigma2_total = sigma2 + move_precs.total * sigma_mu2
|
|
2023
|
+
prelkv = PreLkV(
|
|
2024
|
+
sigma2_left=sigma2_left,
|
|
2025
|
+
sigma2_right=sigma2_right,
|
|
2026
|
+
sigma2_total=sigma2_total,
|
|
2027
|
+
sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
|
|
1381
2028
|
)
|
|
2029
|
+
return prelkv, PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
|
|
1382
2030
|
|
|
1383
2031
|
|
|
1384
|
-
def precompute_leaf_terms(
|
|
2032
|
+
def precompute_leaf_terms(
|
|
2033
|
+
key: Key[Array, ''],
|
|
2034
|
+
prec_trees: Float32[Array, 'num_trees 2**d'],
|
|
2035
|
+
sigma2: Float32[Array, ''],
|
|
2036
|
+
sigma_mu2: Float32[Array, ''],
|
|
2037
|
+
) -> PreLf:
|
|
1385
2038
|
"""
|
|
1386
2039
|
Pre-compute terms used to sample leaves from their posterior.
|
|
1387
2040
|
|
|
1388
2041
|
Parameters
|
|
1389
2042
|
----------
|
|
1390
|
-
key
|
|
2043
|
+
key
|
|
1391
2044
|
A jax random key.
|
|
1392
|
-
prec_trees
|
|
2045
|
+
prec_trees
|
|
1393
2046
|
The likelihood precision scale in each potential or actual leaf node.
|
|
1394
|
-
sigma2
|
|
1395
|
-
The
|
|
2047
|
+
sigma2
|
|
2048
|
+
The error variance, or the global error variance factor if `prec_scale`
|
|
2049
|
+
is set.
|
|
2050
|
+
sigma_mu2
|
|
2051
|
+
The prior variance of each leaf.
|
|
1396
2052
|
|
|
1397
2053
|
Returns
|
|
1398
2054
|
-------
|
|
1399
|
-
|
|
1400
|
-
Dictionary with pre-computed terms of the leaf sampling, with fields:
|
|
1401
|
-
|
|
1402
|
-
'mean_factor' : float array (num_trees, 2 ** d)
|
|
1403
|
-
The factor to be multiplied by the sum of the scaled residuals to
|
|
1404
|
-
obtain the posterior mean.
|
|
1405
|
-
'centered_leaves' : float array (num_trees, 2 ** d)
|
|
1406
|
-
The mean-zero normal values to be added to the posterior mean to
|
|
1407
|
-
obtain the posterior leaf samples.
|
|
2055
|
+
Pre-computed terms for leaf sampling.
|
|
1408
2056
|
"""
|
|
1409
|
-
ntree = len(prec_trees)
|
|
1410
2057
|
prec_lk = prec_trees / sigma2
|
|
1411
|
-
|
|
2058
|
+
prec_prior = lax.reciprocal(sigma_mu2)
|
|
2059
|
+
var_post = lax.reciprocal(prec_lk + prec_prior)
|
|
1412
2060
|
z = random.normal(key, prec_trees.shape, sigma2.dtype)
|
|
1413
|
-
return
|
|
1414
|
-
mean_factor=var_post / sigma2,
|
|
2061
|
+
return PreLf(
|
|
2062
|
+
mean_factor=var_post / sigma2,
|
|
2063
|
+
# | mean = mean_lk * prec_lk * var_post
|
|
2064
|
+
# | resid_tree = mean_lk * prec_tree -->
|
|
2065
|
+
# | --> mean_lk = resid_tree / prec_tree (kind of)
|
|
2066
|
+
# | mean_factor =
|
|
2067
|
+
# | = mean / resid_tree =
|
|
2068
|
+
# | = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
|
|
2069
|
+
# | = 1 / prec_tree * prec_tree / sigma2 * var_post =
|
|
2070
|
+
# | = var_post / sigma2
|
|
1415
2071
|
centered_leaves=z * jnp.sqrt(var_post),
|
|
1416
2072
|
)
|
|
1417
2073
|
|
|
1418
2074
|
|
|
1419
|
-
def accept_moves_sequential_stage(
|
|
1420
|
-
bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
|
|
1421
|
-
):
|
|
2075
|
+
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
|
|
1422
2076
|
"""
|
|
1423
|
-
|
|
2077
|
+
Accept/reject the moves one tree at a time.
|
|
2078
|
+
|
|
2079
|
+
This is the most performance-sensitive function because it contains all and
|
|
2080
|
+
only the parts of the algorithm that can not be parallelized across trees.
|
|
1424
2081
|
|
|
1425
2082
|
Parameters
|
|
1426
2083
|
----------
|
|
1427
|
-
|
|
1428
|
-
|
|
1429
|
-
prec_trees : float array (num_trees, 2 ** d)
|
|
1430
|
-
The likelihood precision scale in each potential or actual leaf node.
|
|
1431
|
-
moves : dict
|
|
1432
|
-
The proposed moves, see `sample_moves`.
|
|
1433
|
-
move_counts : dict
|
|
1434
|
-
The counts of the number of points in the the nodes modified by the
|
|
1435
|
-
moves.
|
|
1436
|
-
move_precs : dict
|
|
1437
|
-
The likelihood precision scale in each node modified by the moves.
|
|
1438
|
-
prelkv, prelk, prelf : dict
|
|
1439
|
-
Dictionaries with pre-computed terms of the likelihood ratios and leaf
|
|
1440
|
-
samples.
|
|
2084
|
+
pso
|
|
2085
|
+
The output of `accept_moves_parallel_stage`.
|
|
1441
2086
|
|
|
1442
2087
|
Returns
|
|
1443
2088
|
-------
|
|
1444
|
-
bart :
|
|
2089
|
+
bart : State
|
|
1445
2090
|
A partially updated BART mcmc state.
|
|
1446
|
-
moves :
|
|
1447
|
-
The
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
Whether, to reflect the acceptance status of the move, the state
|
|
1453
|
-
should be updated by pruning the leaves involved in the move.
|
|
1454
|
-
"""
|
|
1455
|
-
bart = bart.copy()
|
|
1456
|
-
moves = moves.copy()
|
|
1457
|
-
|
|
1458
|
-
def loop(resid, item):
|
|
1459
|
-
resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
|
|
1460
|
-
bart['X'],
|
|
1461
|
-
len(bart['leaf_trees']),
|
|
1462
|
-
bart['opt']['resid_batch_size'],
|
|
2091
|
+
moves : Moves
|
|
2092
|
+
The accepted/rejected moves, with `acc` and `to_prune` set.
|
|
2093
|
+
"""
|
|
2094
|
+
|
|
2095
|
+
def loop(resid, pt):
|
|
2096
|
+
resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
|
|
1463
2097
|
resid,
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
|
-
|
|
1468
|
-
|
|
2098
|
+
SeqStageInAllTrees(
|
|
2099
|
+
pso.bart.X,
|
|
2100
|
+
pso.bart.forest.resid_batch_size,
|
|
2101
|
+
pso.bart.prec_scale,
|
|
2102
|
+
pso.bart.forest.log_likelihood is not None,
|
|
2103
|
+
pso.prelk,
|
|
2104
|
+
),
|
|
2105
|
+
pt,
|
|
1469
2106
|
)
|
|
1470
|
-
return resid, (leaf_tree, acc, to_prune,
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
bart
|
|
1474
|
-
prec_trees,
|
|
1475
|
-
moves,
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
prelf,
|
|
2107
|
+
return resid, (leaf_tree, acc, to_prune, lkratio)
|
|
2108
|
+
|
|
2109
|
+
pts = SeqStageInPerTree(
|
|
2110
|
+
pso.bart.forest.leaf_tree,
|
|
2111
|
+
pso.prec_trees,
|
|
2112
|
+
pso.moves,
|
|
2113
|
+
pso.move_precs,
|
|
2114
|
+
pso.bart.forest.leaf_indices,
|
|
2115
|
+
pso.prelkv,
|
|
2116
|
+
pso.prelf,
|
|
1481
2117
|
)
|
|
1482
|
-
resid, (leaf_trees, acc, to_prune,
|
|
2118
|
+
resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(loop, pso.bart.resid, pts)
|
|
1483
2119
|
|
|
1484
|
-
bart
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
2120
|
+
bart = replace(
|
|
2121
|
+
pso.bart,
|
|
2122
|
+
resid=resid,
|
|
2123
|
+
forest=replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio),
|
|
2124
|
+
)
|
|
2125
|
+
moves = replace(pso.moves, acc=acc, to_prune=to_prune)
|
|
1489
2126
|
|
|
1490
2127
|
return bart, moves
|
|
1491
2128
|
|
|
1492
2129
|
|
|
1493
|
-
|
|
1494
|
-
X,
|
|
1495
|
-
ntree,
|
|
1496
|
-
resid_batch_size,
|
|
1497
|
-
resid,
|
|
1498
|
-
prec_scale,
|
|
1499
|
-
min_points_per_leaf,
|
|
1500
|
-
save_ratios,
|
|
1501
|
-
prelk,
|
|
1502
|
-
leaf_tree,
|
|
1503
|
-
prec_tree,
|
|
1504
|
-
move,
|
|
1505
|
-
move_counts,
|
|
1506
|
-
move_precs,
|
|
1507
|
-
leaf_indices,
|
|
1508
|
-
prelkv,
|
|
1509
|
-
prelf,
|
|
1510
|
-
):
|
|
2130
|
+
class SeqStageInAllTrees(Module):
|
|
1511
2131
|
"""
|
|
1512
|
-
|
|
2132
|
+
The inputs to `accept_move_and_sample_leaves` that are shared by all trees.
|
|
1513
2133
|
|
|
1514
2134
|
Parameters
|
|
1515
2135
|
----------
|
|
1516
|
-
X
|
|
2136
|
+
X
|
|
1517
2137
|
The predictors.
|
|
1518
|
-
|
|
1519
|
-
The number of trees in the forest.
|
|
1520
|
-
resid_batch_size : int, None
|
|
2138
|
+
resid_batch_size
|
|
1521
2139
|
The batch size for computing the sum of residuals in each leaf.
|
|
1522
|
-
|
|
1523
|
-
The residuals (data minus forest value).
|
|
1524
|
-
prec_scale : float array (n,) or None
|
|
2140
|
+
prec_scale
|
|
1525
2141
|
The scale of the precision of the error on each datapoint. If None, it
|
|
1526
2142
|
is assumed to be 1.
|
|
1527
|
-
|
|
1528
|
-
The minimum number of data points in a leaf node.
|
|
1529
|
-
save_ratios : bool
|
|
2143
|
+
save_ratios
|
|
1530
2144
|
Whether to save the acceptance ratios.
|
|
1531
|
-
prelk
|
|
2145
|
+
prelk
|
|
1532
2146
|
The pre-computed terms of the likelihood ratio which are shared across
|
|
1533
2147
|
trees.
|
|
1534
|
-
|
|
2148
|
+
"""
|
|
2149
|
+
|
|
2150
|
+
X: UInt[Array, 'p n']
|
|
2151
|
+
resid_batch_size: int | None = field(static=True)
|
|
2152
|
+
prec_scale: Float32[Array, ' n'] | None
|
|
2153
|
+
save_ratios: bool = field(static=True)
|
|
2154
|
+
prelk: PreLk
|
|
2155
|
+
|
|
2156
|
+
|
|
2157
|
+
class SeqStageInPerTree(Module):
|
|
2158
|
+
"""
|
|
2159
|
+
The inputs to `accept_move_and_sample_leaves` that are separate for each tree.
|
|
2160
|
+
|
|
2161
|
+
Parameters
|
|
2162
|
+
----------
|
|
2163
|
+
leaf_tree
|
|
1535
2164
|
The leaf values of the tree.
|
|
1536
|
-
prec_tree
|
|
2165
|
+
prec_tree
|
|
1537
2166
|
The likelihood precision scale in each potential or actual leaf node.
|
|
1538
|
-
move
|
|
1539
|
-
The proposed move, see `
|
|
1540
|
-
|
|
1541
|
-
The counts of the number of points in the the nodes modified by the
|
|
1542
|
-
moves.
|
|
1543
|
-
move_precs : dict
|
|
2167
|
+
move
|
|
2168
|
+
The proposed move, see `propose_moves`.
|
|
2169
|
+
move_precs
|
|
1544
2170
|
The likelihood precision scale in each node modified by the moves.
|
|
1545
|
-
leaf_indices
|
|
2171
|
+
leaf_indices
|
|
1546
2172
|
The leaf indices for the largest version of the tree compatible with
|
|
1547
2173
|
the move.
|
|
1548
|
-
prelkv
|
|
2174
|
+
prelkv
|
|
2175
|
+
prelf
|
|
1549
2176
|
The pre-computed terms of the likelihood ratio and leaf sampling which
|
|
1550
2177
|
are specific to the tree.
|
|
2178
|
+
"""
|
|
2179
|
+
|
|
2180
|
+
leaf_tree: Float32[Array, ' 2**d']
|
|
2181
|
+
prec_tree: Float32[Array, ' 2**d']
|
|
2182
|
+
move: Moves
|
|
2183
|
+
move_precs: Precs | Counts
|
|
2184
|
+
leaf_indices: UInt[Array, ' n']
|
|
2185
|
+
prelkv: PreLkV
|
|
2186
|
+
prelf: PreLf
|
|
2187
|
+
|
|
2188
|
+
|
|
2189
|
+
def accept_move_and_sample_leaves(
|
|
2190
|
+
resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
|
|
2191
|
+
) -> tuple[
|
|
2192
|
+
Float32[Array, ' n'],
|
|
2193
|
+
Float32[Array, ' 2**d'],
|
|
2194
|
+
Bool[Array, ''],
|
|
2195
|
+
Bool[Array, ''],
|
|
2196
|
+
Float32[Array, ''] | None,
|
|
2197
|
+
]:
|
|
2198
|
+
"""
|
|
2199
|
+
Accept or reject a proposed move and sample the new leaf values.
|
|
2200
|
+
|
|
2201
|
+
Parameters
|
|
2202
|
+
----------
|
|
2203
|
+
resid
|
|
2204
|
+
The residuals (data minus forest value).
|
|
2205
|
+
at
|
|
2206
|
+
The inputs that are the same for all trees.
|
|
2207
|
+
pt
|
|
2208
|
+
The inputs that are separate for each tree.
|
|
1551
2209
|
|
|
1552
2210
|
Returns
|
|
1553
2211
|
-------
|
|
1554
|
-
resid :
|
|
2212
|
+
resid : Float32[Array, 'n']
|
|
1555
2213
|
The updated residuals (data minus forest value).
|
|
1556
|
-
leaf_tree :
|
|
2214
|
+
leaf_tree : Float32[Array, '2**d']
|
|
1557
2215
|
The new leaf values of the tree.
|
|
1558
|
-
acc :
|
|
2216
|
+
acc : Bool[Array, '']
|
|
1559
2217
|
Whether the move was accepted.
|
|
1560
|
-
to_prune :
|
|
2218
|
+
to_prune : Bool[Array, '']
|
|
1561
2219
|
Whether, to reflect the acceptance status of the move, the state should
|
|
1562
2220
|
be updated by pruning the leaves involved in the move.
|
|
1563
|
-
|
|
1564
|
-
The
|
|
2221
|
+
log_lk_ratio : Float32[Array, ''] | None
|
|
2222
|
+
The logarithm of the likelihood ratio for the move. `None` if not to be
|
|
2223
|
+
saved.
|
|
1565
2224
|
"""
|
|
1566
|
-
|
|
1567
2225
|
# sum residuals in each leaf, in tree proposed by grow move
|
|
1568
|
-
if prec_scale is None:
|
|
2226
|
+
if at.prec_scale is None:
|
|
1569
2227
|
scaled_resid = resid
|
|
1570
2228
|
else:
|
|
1571
|
-
scaled_resid = resid * prec_scale
|
|
1572
|
-
resid_tree = sum_resid(
|
|
2229
|
+
scaled_resid = resid * at.prec_scale
|
|
2230
|
+
resid_tree = sum_resid(
|
|
2231
|
+
scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
|
|
2232
|
+
)
|
|
1573
2233
|
|
|
1574
2234
|
# subtract starting tree from function
|
|
1575
|
-
resid_tree += prec_tree * leaf_tree
|
|
1576
|
-
|
|
1577
|
-
# get indices of move
|
|
1578
|
-
node = move['node']
|
|
1579
|
-
assert node.dtype == jnp.int32
|
|
1580
|
-
left = move['left']
|
|
1581
|
-
right = move['right']
|
|
2235
|
+
resid_tree += pt.prec_tree * pt.leaf_tree
|
|
1582
2236
|
|
|
1583
2237
|
# sum residuals in parent node modified by move
|
|
1584
|
-
resid_left = resid_tree[left]
|
|
1585
|
-
resid_right = resid_tree[right]
|
|
2238
|
+
resid_left = resid_tree[pt.move.left]
|
|
2239
|
+
resid_right = resid_tree[pt.move.right]
|
|
1586
2240
|
resid_total = resid_left + resid_right
|
|
1587
|
-
|
|
2241
|
+
assert pt.move.node.dtype == jnp.int32
|
|
2242
|
+
resid_tree = resid_tree.at[pt.move.node].set(resid_total)
|
|
1588
2243
|
|
|
1589
2244
|
# compute acceptance ratio
|
|
1590
2245
|
log_lk_ratio = compute_likelihood_ratio(
|
|
1591
|
-
resid_total, resid_left, resid_right, prelkv, prelk
|
|
2246
|
+
resid_total, resid_left, resid_right, pt.prelkv, at.prelk
|
|
1592
2247
|
)
|
|
1593
|
-
log_ratio = move
|
|
1594
|
-
log_ratio = jnp.where(move
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
ratios.update(
|
|
1598
|
-
log_trans_prior=move['log_trans_prior_ratio'],
|
|
1599
|
-
log_likelihood=log_lk_ratio,
|
|
1600
|
-
)
|
|
2248
|
+
log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
|
|
2249
|
+
log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
|
|
2250
|
+
if not at.save_ratios:
|
|
2251
|
+
log_lk_ratio = None
|
|
1601
2252
|
|
|
1602
2253
|
# determine whether to accept the move
|
|
1603
|
-
acc = move
|
|
1604
|
-
if min_points_per_leaf is not None:
|
|
1605
|
-
acc &= move_counts['left'] >= min_points_per_leaf
|
|
1606
|
-
acc &= move_counts['right'] >= min_points_per_leaf
|
|
2254
|
+
acc = pt.move.allowed & (pt.move.logu <= log_ratio)
|
|
1607
2255
|
|
|
1608
2256
|
# compute leaves posterior and sample leaves
|
|
1609
|
-
|
|
1610
|
-
|
|
1611
|
-
leaf_tree = mean_post + prelf['centered_leaves']
|
|
2257
|
+
mean_post = resid_tree * pt.prelf.mean_factor
|
|
2258
|
+
leaf_tree = mean_post + pt.prelf.centered_leaves
|
|
1612
2259
|
|
|
1613
2260
|
# copy leaves around such that the leaf indices point to the correct leaf
|
|
1614
|
-
to_prune = acc ^ move
|
|
2261
|
+
to_prune = acc ^ pt.move.grow
|
|
1615
2262
|
leaf_tree = (
|
|
1616
|
-
leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
|
|
1617
|
-
.set(leaf_tree[node])
|
|
1618
|
-
.at[jnp.where(to_prune, right, leaf_tree.size)]
|
|
1619
|
-
.set(leaf_tree[node])
|
|
2263
|
+
leaf_tree.at[jnp.where(to_prune, pt.move.left, leaf_tree.size)]
|
|
2264
|
+
.set(leaf_tree[pt.move.node])
|
|
2265
|
+
.at[jnp.where(to_prune, pt.move.right, leaf_tree.size)]
|
|
2266
|
+
.set(leaf_tree[pt.move.node])
|
|
1620
2267
|
)
|
|
1621
2268
|
|
|
1622
2269
|
# replace old tree with new tree in function values
|
|
1623
|
-
resid += (
|
|
2270
|
+
resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices]
|
|
1624
2271
|
|
|
1625
|
-
return resid, leaf_tree, acc, to_prune,
|
|
2272
|
+
return resid, leaf_tree, acc, to_prune, log_lk_ratio
|
|
1626
2273
|
|
|
1627
2274
|
|
|
1628
|
-
def sum_resid(
|
|
2275
|
+
def sum_resid(
|
|
2276
|
+
scaled_resid: Float32[Array, ' n'],
|
|
2277
|
+
leaf_indices: UInt[Array, ' n'],
|
|
2278
|
+
tree_size: int,
|
|
2279
|
+
batch_size: int | None,
|
|
2280
|
+
) -> Float32[Array, ' {tree_size}']:
|
|
1629
2281
|
"""
|
|
1630
2282
|
Sum the residuals in each leaf.
|
|
1631
2283
|
|
|
1632
2284
|
Parameters
|
|
1633
2285
|
----------
|
|
1634
|
-
scaled_resid
|
|
2286
|
+
scaled_resid
|
|
1635
2287
|
The residuals (data minus forest value) multiplied by the error
|
|
1636
2288
|
precision scale.
|
|
1637
|
-
leaf_indices
|
|
2289
|
+
leaf_indices
|
|
1638
2290
|
The leaf indices of the tree (in which leaf each data point falls into).
|
|
1639
|
-
tree_size
|
|
2291
|
+
tree_size
|
|
1640
2292
|
The size of the tree array (2 ** d).
|
|
1641
|
-
batch_size
|
|
2293
|
+
batch_size
|
|
1642
2294
|
The data batch size for the aggregation. Batching increases numerical
|
|
1643
2295
|
accuracy and parallelism.
|
|
1644
2296
|
|
|
1645
2297
|
Returns
|
|
1646
2298
|
-------
|
|
1647
|
-
|
|
1648
|
-
The sum of the residuals at data points in each leaf.
|
|
2299
|
+
The sum of the residuals at data points in each leaf.
|
|
1649
2300
|
"""
|
|
1650
2301
|
if batch_size is None:
|
|
1651
2302
|
aggr_func = _aggregate_scatter
|
|
1652
2303
|
else:
|
|
1653
|
-
aggr_func =
|
|
1654
|
-
return aggr_func(
|
|
1655
|
-
scaled_resid, leaf_indices, tree_size, jnp.float32
|
|
1656
|
-
) # TODO: use large_float
|
|
2304
|
+
aggr_func = partial(_aggregate_batched_onetree, batch_size=batch_size)
|
|
2305
|
+
return aggr_func(scaled_resid, leaf_indices, tree_size, jnp.float32)
|
|
1657
2306
|
|
|
1658
2307
|
|
|
1659
|
-
def _aggregate_batched_onetree(
|
|
2308
|
+
def _aggregate_batched_onetree(
|
|
2309
|
+
values: Shaped[Array, '*'],
|
|
2310
|
+
indices: Integer[Array, '*'],
|
|
2311
|
+
size: int,
|
|
2312
|
+
dtype: jnp.dtype,
|
|
2313
|
+
batch_size: int,
|
|
2314
|
+
) -> Float32[Array, ' {size}']:
|
|
1660
2315
|
(n,) = indices.shape
|
|
1661
2316
|
nbatches = n // batch_size + bool(n % batch_size)
|
|
1662
2317
|
batch_indices = jnp.arange(n) % nbatches
|
|
@@ -1668,153 +2323,294 @@ def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
|
|
|
1668
2323
|
)
|
|
1669
2324
|
|
|
1670
2325
|
|
|
1671
|
-
def compute_likelihood_ratio(
|
|
2326
|
+
def compute_likelihood_ratio(
|
|
2327
|
+
total_resid: Float32[Array, ''],
|
|
2328
|
+
left_resid: Float32[Array, ''],
|
|
2329
|
+
right_resid: Float32[Array, ''],
|
|
2330
|
+
prelkv: PreLkV,
|
|
2331
|
+
prelk: PreLk,
|
|
2332
|
+
) -> Float32[Array, '']:
|
|
1672
2333
|
"""
|
|
1673
2334
|
Compute the likelihood ratio of a grow move.
|
|
1674
2335
|
|
|
1675
2336
|
Parameters
|
|
1676
2337
|
----------
|
|
1677
|
-
total_resid
|
|
2338
|
+
total_resid
|
|
2339
|
+
left_resid
|
|
2340
|
+
right_resid
|
|
1678
2341
|
The sum of the residuals (scaled by error precision scale) of the
|
|
1679
2342
|
datapoints falling in the nodes involved in the moves.
|
|
1680
|
-
prelkv
|
|
2343
|
+
prelkv
|
|
2344
|
+
prelk
|
|
1681
2345
|
The pre-computed terms of the likelihood ratio, see
|
|
1682
2346
|
`precompute_likelihood_terms`.
|
|
1683
2347
|
|
|
1684
2348
|
Returns
|
|
1685
2349
|
-------
|
|
1686
|
-
ratio
|
|
1687
|
-
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
2350
|
+
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
1688
2351
|
"""
|
|
1689
|
-
exp_term = prelk
|
|
1690
|
-
left_resid * left_resid / prelkv
|
|
1691
|
-
+ right_resid * right_resid / prelkv
|
|
1692
|
-
- total_resid * total_resid / prelkv
|
|
2352
|
+
exp_term = prelk.exp_factor * (
|
|
2353
|
+
left_resid * left_resid / prelkv.sigma2_left
|
|
2354
|
+
+ right_resid * right_resid / prelkv.sigma2_right
|
|
2355
|
+
- total_resid * total_resid / prelkv.sigma2_total
|
|
1693
2356
|
)
|
|
1694
|
-
return prelkv
|
|
2357
|
+
return prelkv.sqrt_term + exp_term
|
|
1695
2358
|
|
|
1696
2359
|
|
|
1697
|
-
def accept_moves_final_stage(bart, moves):
|
|
2360
|
+
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
|
|
1698
2361
|
"""
|
|
1699
|
-
|
|
2362
|
+
Post-process the mcmc state after accepting/rejecting the moves.
|
|
2363
|
+
|
|
2364
|
+
This function is separate from `accept_moves_sequential_stage` to signal it
|
|
2365
|
+
can work in parallel across trees.
|
|
1700
2366
|
|
|
1701
2367
|
Parameters
|
|
1702
2368
|
----------
|
|
1703
|
-
bart
|
|
2369
|
+
bart
|
|
1704
2370
|
A partially updated BART mcmc state.
|
|
1705
|
-
|
|
1706
|
-
The
|
|
1707
|
-
moves : dict
|
|
1708
|
-
The proposed moves (see `sample_moves`) as updated by
|
|
2371
|
+
moves
|
|
2372
|
+
The proposed moves (see `propose_moves`) as updated by
|
|
1709
2373
|
`accept_moves_sequential_stage`.
|
|
1710
2374
|
|
|
1711
2375
|
Returns
|
|
1712
2376
|
-------
|
|
1713
|
-
|
|
1714
|
-
|
|
1715
|
-
|
|
1716
|
-
|
|
1717
|
-
|
|
1718
|
-
|
|
1719
|
-
|
|
1720
|
-
|
|
1721
|
-
|
|
2377
|
+
The fully updated BART mcmc state.
|
|
2378
|
+
"""
|
|
2379
|
+
return replace(
|
|
2380
|
+
bart,
|
|
2381
|
+
forest=replace(
|
|
2382
|
+
bart.forest,
|
|
2383
|
+
grow_acc_count=jnp.sum(moves.acc & moves.grow),
|
|
2384
|
+
prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
|
|
2385
|
+
leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
|
|
2386
|
+
split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
|
|
2387
|
+
),
|
|
2388
|
+
)
|
|
1722
2389
|
|
|
1723
2390
|
|
|
1724
|
-
@
|
|
1725
|
-
def apply_moves_to_leaf_indices(
|
|
2391
|
+
@vmap_nodoc
|
|
2392
|
+
def apply_moves_to_leaf_indices(
|
|
2393
|
+
leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
|
|
2394
|
+
) -> UInt[Array, 'num_trees n']:
|
|
1726
2395
|
"""
|
|
1727
2396
|
Update the leaf indices to match the accepted move.
|
|
1728
2397
|
|
|
1729
2398
|
Parameters
|
|
1730
2399
|
----------
|
|
1731
|
-
leaf_indices
|
|
2400
|
+
leaf_indices
|
|
1732
2401
|
The index of the leaf each datapoint falls into, if the grow move was
|
|
1733
2402
|
accepted.
|
|
1734
|
-
moves
|
|
1735
|
-
The proposed moves (see `
|
|
2403
|
+
moves
|
|
2404
|
+
The proposed moves (see `propose_moves`), as updated by
|
|
1736
2405
|
`accept_moves_sequential_stage`.
|
|
1737
2406
|
|
|
1738
2407
|
Returns
|
|
1739
2408
|
-------
|
|
1740
|
-
|
|
1741
|
-
The updated leaf indices.
|
|
2409
|
+
The updated leaf indices.
|
|
1742
2410
|
"""
|
|
1743
2411
|
mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
|
|
1744
|
-
is_child = (leaf_indices & mask) == moves
|
|
2412
|
+
is_child = (leaf_indices & mask) == moves.left
|
|
1745
2413
|
return jnp.where(
|
|
1746
|
-
is_child & moves
|
|
1747
|
-
moves['node'].astype(leaf_indices.dtype),
|
|
1748
|
-
leaf_indices,
|
|
2414
|
+
is_child & moves.to_prune, moves.node.astype(leaf_indices.dtype), leaf_indices
|
|
1749
2415
|
)
|
|
1750
2416
|
|
|
1751
2417
|
|
|
1752
|
-
@
|
|
1753
|
-
def apply_moves_to_split_trees(
|
|
2418
|
+
@vmap_nodoc
|
|
2419
|
+
def apply_moves_to_split_trees(
|
|
2420
|
+
split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
|
|
2421
|
+
) -> UInt[Array, 'num_trees 2**(d-1)']:
|
|
1754
2422
|
"""
|
|
1755
2423
|
Update the split trees to match the accepted move.
|
|
1756
2424
|
|
|
1757
2425
|
Parameters
|
|
1758
2426
|
----------
|
|
1759
|
-
|
|
2427
|
+
split_tree
|
|
1760
2428
|
The cutpoints of the decision nodes in the initial trees.
|
|
1761
|
-
moves
|
|
1762
|
-
The proposed moves (see `
|
|
2429
|
+
moves
|
|
2430
|
+
The proposed moves (see `propose_moves`), as updated by
|
|
1763
2431
|
`accept_moves_sequential_stage`.
|
|
1764
2432
|
|
|
1765
2433
|
Returns
|
|
1766
2434
|
-------
|
|
1767
|
-
|
|
1768
|
-
The updated split trees.
|
|
2435
|
+
The updated split trees.
|
|
1769
2436
|
"""
|
|
2437
|
+
assert moves.to_prune is not None
|
|
1770
2438
|
return (
|
|
1771
|
-
|
|
1772
|
-
|
|
1773
|
-
|
|
1774
|
-
moves['node'],
|
|
1775
|
-
split_trees.size,
|
|
1776
|
-
)
|
|
1777
|
-
]
|
|
1778
|
-
.set(moves['grow_split'].astype(split_trees.dtype))
|
|
1779
|
-
.at[
|
|
1780
|
-
jnp.where(
|
|
1781
|
-
moves['to_prune'],
|
|
1782
|
-
moves['node'],
|
|
1783
|
-
split_trees.size,
|
|
1784
|
-
)
|
|
1785
|
-
]
|
|
2439
|
+
split_tree.at[jnp.where(moves.grow, moves.node, split_tree.size)]
|
|
2440
|
+
.set(moves.grow_split.astype(split_tree.dtype))
|
|
2441
|
+
.at[jnp.where(moves.to_prune, moves.node, split_tree.size)]
|
|
1786
2442
|
.set(0)
|
|
1787
2443
|
)
|
|
1788
2444
|
|
|
1789
2445
|
|
|
1790
|
-
def
|
|
2446
|
+
def step_sigma(key: Key[Array, ''], bart: State) -> State:
|
|
1791
2447
|
"""
|
|
1792
|
-
|
|
2448
|
+
MCMC-update the error variance (factor).
|
|
1793
2449
|
|
|
1794
2450
|
Parameters
|
|
1795
2451
|
----------
|
|
1796
|
-
key
|
|
2452
|
+
key
|
|
1797
2453
|
A jax random key.
|
|
1798
|
-
bart
|
|
1799
|
-
A BART mcmc state
|
|
2454
|
+
bart
|
|
2455
|
+
A BART mcmc state.
|
|
1800
2456
|
|
|
1801
2457
|
Returns
|
|
1802
2458
|
-------
|
|
1803
|
-
|
|
1804
|
-
The new BART mcmc state.
|
|
2459
|
+
The new BART mcmc state, with an updated `sigma2`.
|
|
1805
2460
|
"""
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
alpha = bart['sigma2_alpha'] + resid.size / 2
|
|
1810
|
-
if bart['prec_scale'] is None:
|
|
2461
|
+
resid = bart.resid
|
|
2462
|
+
alpha = bart.sigma2_alpha + resid.size / 2
|
|
2463
|
+
if bart.prec_scale is None:
|
|
1811
2464
|
scaled_resid = resid
|
|
1812
2465
|
else:
|
|
1813
|
-
scaled_resid = resid * bart
|
|
2466
|
+
scaled_resid = resid * bart.prec_scale
|
|
1814
2467
|
norm2 = resid @ scaled_resid
|
|
1815
|
-
beta = bart
|
|
2468
|
+
beta = bart.sigma2_beta + norm2 / 2
|
|
1816
2469
|
|
|
1817
2470
|
sample = random.gamma(key, alpha)
|
|
1818
|
-
|
|
2471
|
+
# random.gamma seems to be slow at compiling, maybe cdf inversion would
|
|
2472
|
+
# be better, but it's not implemented in jax
|
|
2473
|
+
return replace(bart, sigma2=beta / sample)
|
|
2474
|
+
|
|
2475
|
+
|
|
2476
|
+
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    # recover the forest prediction (plus offset) hidden in z = trees + resid
    mean = bart.z - bart.resid
    assert bart.y.dtype == bool
    # sample the residual from a one-sided truncated normal; the side is
    # chosen per-observation from the binary response y
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -mean)
    return replace(bart, z=mean + new_resid, resid=new_resid)
|
|
2496
|
+
|
|
2497
|
+
|
|
2498
|
+
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    """
    forest = bart.forest
    assert forest.theta is not None

    # count how many decision rules in the forest use each predictor
    num_predictors = forest.max_split.size
    usage = grove.var_histogram(num_predictors, forest.var_tree, forest.split_tree)

    # a Dirichlet draw in log-space: independent log-gamma variates with the
    # posterior concentrations (normalization is left to consumers of log_s)
    concentration = forest.theta / num_predictors + usage
    new_log_s = random.loggamma(key, concentration)

    return replace(bart, forest=replace(forest, log_s=new_log_s))
|
|
2531
|
+
|
|
1819
2532
|
|
|
2533
|
+
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # grid of bin midpoints over (0, 1) for lambda = theta / (theta + rho)
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # work with normalized s so the likelihood term is well defined
    norm_log_s = forest.log_s - logsumexp(forest.log_s)

    # evaluate the unnormalized log-posterior on the grid, then sample a cell
    logp, theta_grid = _log_p_lamda(
        lamda_grid, norm_log_s, forest.rho, forest.a, forest.b
    )
    idx = random.categorical(key, logp)

    return replace(bart, forest=replace(forest, theta=theta_grid[idx]))
|
|
2573
|
+
|
|
2574
|
+
|
|
2575
|
+
def _log_p_lamda(
|
|
2576
|
+
lamda: Float32[Array, ' num_grid'],
|
|
2577
|
+
log_s: Float32[Array, ' p'],
|
|
2578
|
+
rho: Float32[Array, ''],
|
|
2579
|
+
a: Float32[Array, ''],
|
|
2580
|
+
b: Float32[Array, ''],
|
|
2581
|
+
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
|
|
2582
|
+
# in the following I use lamda[::-1] == 1 - lamda
|
|
2583
|
+
theta = rho * lamda / lamda[::-1]
|
|
2584
|
+
p = log_s.size
|
|
2585
|
+
return (
|
|
2586
|
+
(a - 1) * jnp.log1p(-lamda[::-1]) # log(lambda)
|
|
2587
|
+
+ (b - 1) * jnp.log1p(-lamda) # log(1 - lambda)
|
|
2588
|
+
+ gammaln(theta)
|
|
2589
|
+
- p * gammaln(theta / p)
|
|
2590
|
+
+ theta / p * jnp.sum(log_s)
|
|
2591
|
+
), theta
|
|
2592
|
+
|
|
2593
|
+
|
|
2594
|
+
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    This invokes `step_s`, and then `step_theta` only if the parameters of
    the theta prior are defined.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    keys = split(key)
    # always refresh the variable-selection weights
    bart = step_s(keys.pop(), bart)
    # theta has a prior (and hence an update) only when rho is configured
    if bart.forest.rho is None:
        return bart
    return step_theta(keys.pop(), bart)
|