bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +4 -2
- bartz/{BART.py → _interface.py} +256 -132
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +269 -314
- bartz/grove.py +124 -68
- bartz/jaxext/__init__.py +101 -27
- bartz/jaxext/_autobatch.py +257 -51
- bartz/jaxext/scipy/__init__.py +1 -1
- bartz/jaxext/scipy/special.py +3 -4
- bartz/jaxext/scipy/stats.py +1 -1
- bartz/mcmcloop.py +399 -208
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +1 -1
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/METADATA +17 -11
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/mcmcstep.py +0 -2616
- bartz-0.7.0.dist-info/RECORD +0 -17
bartz/mcmcstep.py
DELETED
|
@@ -1,2616 +0,0 @@
|
|
|
1
|
-
# bartz/src/bartz/mcmcstep.py
|
|
2
|
-
#
|
|
3
|
-
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
|
-
#
|
|
5
|
-
# This file is part of bartz.
|
|
6
|
-
#
|
|
7
|
-
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
-
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
-
# in the Software without restriction, including without limitation the rights
|
|
10
|
-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
-
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
-
# furnished to do so, subject to the following conditions:
|
|
13
|
-
#
|
|
14
|
-
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
-
# copies or substantial portions of the Software.
|
|
16
|
-
#
|
|
17
|
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
-
# SOFTWARE.
|
|
24
|
-
|
|
25
|
-
"""
|
|
26
|
-
Functions that implement the BART posterior MCMC initialization and update step.
|
|
27
|
-
|
|
28
|
-
Functions that do MCMC steps operate by taking as input a bart state, and
|
|
29
|
-
outputting a new state. The inputs are not modified.
|
|
30
|
-
|
|
31
|
-
The entry points are:
|
|
32
|
-
|
|
33
|
-
- `State`: The dataclass that represents a BART MCMC state.
|
|
34
|
-
- `init`: Creates an initial `State` from data and configurations.
|
|
35
|
-
- `step`: Performs one full MCMC step on a `State`, returning a new `State`.
|
|
36
|
-
- `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`.
|
|
37
|
-
"""
|
|
38
|
-
|
|
39
|
-
import math
|
|
40
|
-
from dataclasses import replace
|
|
41
|
-
from functools import cache, partial
|
|
42
|
-
from typing import Any, Literal
|
|
43
|
-
|
|
44
|
-
import jax
|
|
45
|
-
from equinox import Module, field, tree_at
|
|
46
|
-
from jax import lax, random
|
|
47
|
-
from jax import numpy as jnp
|
|
48
|
-
from jax.scipy.special import gammaln, logsumexp
|
|
49
|
-
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt
|
|
50
|
-
|
|
51
|
-
from bartz import grove
|
|
52
|
-
from bartz.jaxext import (
|
|
53
|
-
minimal_unsigned_dtype,
|
|
54
|
-
split,
|
|
55
|
-
truncated_normal_onesided,
|
|
56
|
-
vmap_nodoc,
|
|
57
|
-
)
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_tree
        The leaf values.
    var_tree
        The decision axes.
    split_tree
        The decision boundaries.
    affluence_tree
        Marks leaves that can be grown.
    max_split
        The maximum split index for each predictor.
    blocked_vars
        Indices of variables that are not used. This shall include at least
        the `i` such that ``max_split[i] == 0``, otherwise behavior is
        undefined.
    p_nonterminal
        The prior probability of each node being nonterminal, conditional on
        its ancestors. Includes the nodes at maximum depth which should be set
        to 0.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_decision_node
        The minimum number of data points in a decision node.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If `None`,
        they are computed with no batching.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree.
    log_likelihood
        The log likelihood ratio.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If `None`, use a uniform distribution.
    theta
        The concentration parameter for the Dirichlet prior on the variable
        distribution `s`. Required only to update `s`.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
        See `step_theta`.
    """

    # Per-tree node arrays, one row per tree. The decision-node arrays have
    # half as many slots as the leaf array (see the shape annotations).
    leaf_tree: Float32[Array, 'num_trees 2**d']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    split_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    max_split: UInt[Array, ' p']
    # `None` when `init` is called with `filter_splitless_vars=False`.
    blocked_vars: UInt[Array, ' k'] | None
    p_nonterminal: Float32[Array, ' 2**d']
    p_propose_grow: Float32[Array, ' 2**(d-1)']
    leaf_indices: UInt[Array, 'num_trees n']
    # `None` when the corresponding `init` argument is not specified.
    min_points_per_decision_node: Int32[Array, ''] | None
    min_points_per_leaf: Int32[Array, ''] | None
    # Plain Python values declared static (not array leaves of the module).
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    # Only allocated when `init` is called with `save_ratios=True`.
    log_trans_prior: Float32[Array, ' num_trees'] | None
    log_likelihood: Float32[Array, ' num_trees'] | None
    # Move counters; `step_trees` documents that it zeroes the proposal
    # counters at each step.
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    sigma_mu2: Float32[Array, '']
    # Variable-selection (sparsity) parameters; all `None` when unused.
    log_s: Float32[Array, ' p'] | None
    theta: Float32[Array, ''] | None
    a: Float32[Array, ''] | None
    b: Float32[Array, ''] | None
    rho: Float32[Array, ''] | None
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors.
    y
        The response. If the data type is `bool`, the model is binary regression.
    resid
        The residuals (`y` or `z` minus sum of trees).
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    sigma2
        The error variance. `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the noise
        variance. `None` in binary regression.
    forest
        The sum of trees model.
    """

    # Predictors stored transposed: one row per variable, one column per point.
    X: UInt[Array, 'p n']
    y: Float32[Array, ' n'] | Bool[Array, ' n']
    # Latent variable; only present in binary regression.
    z: None | Float32[Array, ' n']
    offset: Float32[Array, '']
    resid: Float32[Array, ' n']
    # `None` in binary regression.
    sigma2: Float32[Array, ''] | None
    # `None` also when `init` receives no `error_scale` (unit scale assumed).
    prec_scale: Float32[Array, ' n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    forest: Forest
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_batch_size: int | None | Literal['auto'] = 'auto',
    count_batch_size: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: bool = True,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
    log_s: Float32[Any, ' p'] | None = None,
    theta: float | Float32[Any, ''] | None = None,
    a: float | Float32[Any, ''] | None = None,
    b: float | Float32[Any, ''] | None = None,
    rho: float | Float32[Any, ''] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual convention.
    y
        The response. If the data type is `bool`, the regression model is binary
        regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1 for
        all points, but potentially skipping calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', pick a value based on the device of `y`, or the default
        device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        Whether to check `max_split` for variables without available cutpoints.
        If any are found, they are put into a list of variables to exclude from
        the MCMC. If `False`, no check is performed, but the results may be
        wrong if any variable is blocked. The function is jax-traceable only
        if this is set to `False`.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken into
        account in the Metropolis-Hastings ratio because it would be expensive
        to compute. Grow moves that would violate this constraint are vetoed.
        This parameter is independent of `min_points_per_decision_node` and
        there is no check that they are coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If not specified, use a uniform distribution. If not specified and
        `theta` or `rho`, `a`, `b` are, it's initialized automatically.
    theta
        The concentration parameter for the Dirichlet prior on `s`. Required
        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
        specified, it's initialized automatically.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set,
        or if `rho`, `a`, `b` are neither all set nor all `None`.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    # Convert all array inputs up front. `X` in particular must be an array
    # before `X.shape` is read below, so that lists and other array-likes work.
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    offset = jnp.asarray(offset)
    max_split = jnp.asarray(max_split)

    # Append a zero so that nodes at the maximum depth are always terminal.
    p_nonterminal = jnp.asarray(p_nonterminal)
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        return grove.make_tree(max_depth, dtype)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        # Use identity checks: comparing a tuple that contains arrays against
        # `None` would invoke array `__eq__` and may raise an unrelated error.
        if any(arg is not None for arg in (error_scale, sigma2_alpha, sigma2_beta)):
            msg = (
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                'to `None` for binary regression.'
            )
            raise ValueError(msg)
        sigma2 = None
    else:
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        # initial value of the error variance
        sigma2 = sigma2_beta / sigma2_alpha

    if filter_splitless_vars:
        # Not jax-traceable: the output size of `nonzero` depends on the data.
        (blocked_vars,) = jnp.nonzero(max_split == 0)
        blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
        # see `fully_used_variables` for the type cast
    else:
        blocked_vars = None

    # check and initialize sparsity parameters
    if not _all_none_or_not_none(rho, a, b):
        msg = 'rho, a, b are not either all `None` or all set'
        raise ValueError(msg)
    if theta is None and rho is not None:
        theta = rho
    if log_s is None and theta is not None:
        log_s = jnp.zeros(max_split.size)

    return State(
        X=X,
        y=y,
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=(
            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
        ),
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_tree=make_forest(max_depth, jnp.float32),
            var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
            split_tree=make_forest(max_depth - 1, max_split.dtype),
            # only the root (index 1) starts as a leaf; it is growable iff it
            # holds enough points
            affluence_tree=(
                make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(
                    True
                    if min_points_per_decision_node is None
                    else y.size >= min_points_per_decision_node
                )
            ),
            blocked_vars=blocked_vars,
            max_split=max_split,
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            # every datapoint starts in the root, which has index 1
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
            log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
            log_s=_asarray_or_none(log_s),
            theta=_asarray_or_none(theta),
            rho=_asarray_or_none(rho),
            a=_asarray_or_none(a),
            b=_asarray_or_none(b),
        ),
    )
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
def _all_none_or_not_none(*args):
|
|
403
|
-
is_none = [x is None for x in args]
|
|
404
|
-
return all(is_none) or not any(is_none)
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
def _asarray_or_none(x):
|
|
408
|
-
if x is None:
|
|
409
|
-
return None
|
|
410
|
-
return jnp.asarray(x)
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
def _choose_suffstat_batch_size(
|
|
414
|
-
resid_batch_size, count_batch_size, y, forest_size
|
|
415
|
-
) -> tuple[int | None, ...]:
|
|
416
|
-
@cache
|
|
417
|
-
def get_platform():
|
|
418
|
-
try:
|
|
419
|
-
device = y.devices().pop()
|
|
420
|
-
except jax.errors.ConcretizationTypeError:
|
|
421
|
-
device = jax.devices()[0]
|
|
422
|
-
platform = device.platform
|
|
423
|
-
if platform not in ('cpu', 'gpu'):
|
|
424
|
-
msg = f'Unknown platform: {platform}'
|
|
425
|
-
raise KeyError(msg)
|
|
426
|
-
return platform
|
|
427
|
-
|
|
428
|
-
if resid_batch_size == 'auto':
|
|
429
|
-
platform = get_platform()
|
|
430
|
-
n = max(1, y.size)
|
|
431
|
-
if platform == 'cpu':
|
|
432
|
-
resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6
|
|
433
|
-
elif platform == 'gpu':
|
|
434
|
-
resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3
|
|
435
|
-
resid_batch_size = max(1, resid_batch_size)
|
|
436
|
-
|
|
437
|
-
if count_batch_size == 'auto':
|
|
438
|
-
platform = get_platform()
|
|
439
|
-
if platform == 'cpu':
|
|
440
|
-
count_batch_size = None
|
|
441
|
-
elif platform == 'gpu':
|
|
442
|
-
n = max(1, y.size)
|
|
443
|
-
count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2
|
|
444
|
-
# /4 is good on V100, /2 on L4/T4, still haven't tried A100
|
|
445
|
-
max_memory = 2**29
|
|
446
|
-
itemsize = 4
|
|
447
|
-
min_batch_size = math.ceil(forest_size * itemsize * n / max_memory)
|
|
448
|
-
count_batch_size = max(count_batch_size, min_batch_size)
|
|
449
|
-
count_batch_size = max(1, count_batch_size)
|
|
450
|
-
|
|
451
|
-
return resid_batch_size, count_batch_size
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Advance the BART MCMC by one full step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.
    """
    keys = split(key)
    binary = bart.y.dtype == bool

    if binary:
        # In binary regression the trees are updated with a unit error
        # variance, so set it temporarily.
        bart = replace(bart, sigma2=jnp.float32(1))
    bart = step_trees(keys.pop(), bart)

    if binary:
        # Restore the `sigma2 is None` convention of binary regression, then
        # resample the latent variable.
        return step_z(keys.pop(), replace(bart, sigma2=None))
    # continuous regression: resample the error variance
    return step_sigma(keys.pop(), bart)
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Run the forest-sampling part of a BART MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    keys = split(key)
    # draw the two subkeys in the same order as the phases that use them
    propose_key = keys.pop()
    accept_key = keys.pop()
    proposed = propose_moves(propose_key, bart.forest)
    return accept_moves_and_sample_leaves(accept_key, bart, proposed)
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
class Moves(Module):
    """
    Moves proposed to modify each tree.

    Parameters
    ----------
    allowed
        Whether there is a possible move. If `False`, the other values may not
        make sense. The only case in which a move is marked as allowed but is
        then vetoed is if it does not satisfy `min_points_per_leaf`, which for
        efficiency is implemented post-hoc without changing the rest of the
        MCMC logic.
    grow
        Whether the move is a grow move or a prune move.
    num_growable
        The number of growable leaves in the original tree.
    node
        The index of the leaf to grow or node to prune.
    left
    right
        The indices of the children of `node`.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        probability that the children of the modified node are terminal. If the
        move is PRUNE, the ratio is inverted. `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        The logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings ratio for the acceptance of the proposed move.
        `None` if not yet computed. If PRUNE, the log-ratio is negated.
    grow_var
        The decision axes of the new rules.
    grow_split
        The decision boundaries of the new rules.
    var_tree
        The updated decision axes of the trees, valid whatever move.
    affluence_tree
        A partially updated `affluence_tree`, marking non-leaf nodes that would
        become leaves if the move was accepted. As produced by `propose_moves`,
        the mark only accounts for whether the new leaf would have available
        decision rules to grow; whether the node has enough datapoints is
        marked later, in `accept_moves_parallel_stage`.
    logu
        The logarithm of a uniform (0, 1] random variable to be used to
        accept the move. It's in (-oo, 0].
    acc
        Whether the move was accepted. `None` if not yet computed.
    to_prune
        Whether the final operation to apply the move is pruning. This indicates
        an accepted prune move or a rejected grow move. `None` if not yet
        computed.
    """

    allowed: Bool[Array, ' num_trees']
    # True for a GROW move, False for a PRUNE move.
    grow: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    # Heap-layout children of `node`: `2 * node` and `2 * node + 1`.
    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    # Fields below typed `| None` are filled in by later stages of the step.
    partial_ratio: Float32[Array, ' num_trees'] | None
    log_trans_prior_ratio: None | Float32[Array, ' num_trees']
    grow_var: UInt[Array, ' num_trees']
    grow_split: UInt[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    logu: Float32[Array, ' num_trees']
    acc: None | Bool[Array, ' num_trees']
    to_prune: None | Bool[Array, ' num_trees']
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Propose moves for all the trees.

    There are two types of moves: GROW (convert a leaf to a decision node and
    add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
    leaf, deleting its children).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.

    Notes
    -----
    When both a grow and a prune move are possible for a tree, each is chosen
    with probability 1/2; otherwise the only possible one is chosen. The
    acceptance-dependent fields of the result (`log_trans_prior_ratio`, `acc`,
    `to_prune`) are left as `None`.
    """
    num_trees, _ = forest.leaf_tree.shape
    # Key budget: `num_trees` keys for the grow proposals, `num_trees` for the
    # prune proposals, and one for the uniforms drawn below.
    keys = split(key, 1 + 2 * num_trees)

    # compute moves
    grow_moves = propose_grow_moves(
        keys.pop(num_trees),
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    prune_moves = propose_prune_moves(
        keys.pop(num_trees),
        forest.split_tree,
        # chain the affluence tree already updated by the grow proposal
        grow_moves.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    # `u` chooses between grow and prune; `exp1mlogu` becomes `logu` below
    u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))

    # choose between grow or prune
    # If only one move is allowed, the boolean `grow_moves.allowed` acts as a
    # probability of exactly 1 (grow) or 0 (prune).
    p_grow = jnp.where(
        grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed
    )
    grow = u < p_grow  # use < instead of <= because u is in [0, 1)

    # compute children indices
    # heap layout: the children of node i are 2i and 2i + 1
    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    left = node << 1
    right = left + 1

    return Moves(
        allowed=grow_moves.allowed | prune_moves.allowed,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=jnp.where(
            grow, grow_moves.partial_ratio, prune_moves.partial_ratio
        ),
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        # var_tree does not need to be updated if prune
        var_tree=grow_moves.var_tree,
        # affluence_tree is updated for both moves unconditionally, prune last
        affluence_tree=prune_moves.affluence_tree,
        # log(1 - U) with U in [0, 1) lies in (-oo, 0] and is always finite,
        # unlike log(U), which is -oo at U = 0
        logu=jnp.log1p(-exp1mlogu),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
class GrowMoves(Module):
    """
    Represent a proposed grow move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is allowed for proposal.
    num_growable
        The number of leaves that can be proposed for grow.
    node
        The index of the leaf to grow. ``2 ** d`` if there are no growable
        leaves.
    var
    split
        The decision axis and boundary of the new rule.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move.
    var_tree
        The updated decision axes of the tree.
    affluence_tree
        A partially updated `affluence_tree` that marks each new leaf that
        would be produced as `True` if it would have available decision rules.
    """

    # One entry per tree (see the ' num_trees' shape annotations).
    allowed: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    # New decision rule: variable index and cutpoint.
    var: UInt[Array, ' num_trees']
    split: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    # Full per-tree node arrays, updated as if the move were accepted.
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Propose a GROW move for every tree.

    A GROW move turns one leaf into a non-terminal node that has two leaf
    children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The decision axes of the tree.
    split_tree
        The cutpoints of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The largest split index of each variable.
    blocked_vars
        Indices of the variables that have no cutpoints available.
    p_nonterminal
        The prior probability that a node is non-terminal given its
        ancestors. It also covers the maximum depth, where it should be zero.
    p_propose_grow
        Unnormalized probability of choosing each leaf for growth.
    log_s
        Unnormalized log-probability of choosing each variable among the
        available ones.

    Returns
    -------
    The proposed moves.

    Notes
    -----
    No move is proposed when every leaf is at maximum depth, has fewer
    datapoints than `min_points_per_decision_node`, or has no decision rule
    left given its ancestors. This is signalled by `allowed` being `False`
    and `num_growable` being 0.
    """
    keys = split(key, 3)

    # Draw, in this order, the leaf, the variable and the cutpoint.
    node, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )
    var, num_available_var = choose_variable(
        keys.pop(), var_tree, split_tree, max_split, node, blocked_vars, log_s
    )
    cut, lo, hi = choose_split(keys.pop(), var, var_tree, split_tree, max_split, node)

    # A new child can be grown further if another variable is available, or
    # if its side of the chosen variable's range still has cutpoints. These
    # values are meaningless when the move is blocked.
    other_var_available = num_available_var > 1
    left_growable = other_var_available | (lo < cut)
    right_growable = other_var_available | (cut + 1 < hi)
    left_child = node << 1
    new_affluence = (
        affluence_tree.at[left_child]
        .set(left_growable)
        .at[left_child + 1]
        .set(right_growable)
    )

    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, node
    )
    new_var_tree = var_tree.at[node].set(var.astype(var_tree.dtype))

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=node,
        var=var,
        split=cut,
        partial_ratio=partial_ratio,
        var_tree=new_var_tree,
        affluence_tree=new_affluence,
    )
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Randomly pick a leaf of a tree to grow.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The cutpoints of the tree.
    affluence_tree
        Marks the leaves that may be grown; combined with `split_tree` by
        `growable_leaves`.
    p_propose_grow
        Unnormalized probability of choosing each leaf for growth.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The selected leaf, or ``2 ** d`` when ``num_growable == 0``.
    num_growable : Int32[Array, '']
        How many leaves `growable_leaves` marks as eligible.
    prob_choose : Float32[Array, '']
        The normalized probability of selecting `leaf_to_grow` given these
        arguments.
    num_prunable : Int32[Array, '']
        How many parents of two leaves the tree would have once
        `leaf_to_grow` becomes non-terminal.
    """
    growable = growable_leaves(split_tree, affluence_tree)
    n_growable = jnp.count_nonzero(growable)
    weights = jnp.where(growable, p_propose_grow, 0)

    drawn, total_weight = categorical(key, weights)
    no_leaf = 2 * split_tree.size
    leaf = jnp.where(n_growable, drawn, no_leaf)

    # Guard against a zero total weight, which means nothing is growable.
    safe_total = jnp.where(total_weight, total_weight, 1)
    prob = weights[leaf] / safe_total

    # Any nonzero cutpoint marks `leaf` as non-terminal in the grown tree.
    grown_split_tree = split_tree.at[leaf].set(1)
    n_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_split_tree))

    return leaf, n_growable, prob, n_prunable
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Mark the leaves that are eligible for a grow proposal.

    A leaf is eligible when it is not on the bottom level, still has decision
    rules available given its ancestors, and holds at least
    `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The cutpoints of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    Boolean mask of the leaves that may be proposed for growth.

    Notes
    -----
    `split_tree` is needed in addition to `affluence_tree`, because
    `affluence_tree` may be "dirty": it can mark unused nodes as `True`.
    """
    is_leaf = grove.is_actual_leaf(split_tree)
    return affluence_tree & is_leaf
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Draw an integer from an unnormalized discrete distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        The unnormalized probability of each integer.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in ``[0, n)``, or ``n`` if every probability is zero.
    norm : Float32[Array, '']
        The total of `distr`.

    Notes
    -----
    Sampling inverts the cumulative sum instead of using the Gumbel trick. It
    is therefore only suitable for short ranges whose probabilities are well
    above zero.
    """
    cdf = jnp.cumsum(distr)
    total = cdf[-1]
    draw = random.uniform(key, shape=(), dtype=cdf.dtype, minval=0, maxval=total)
    index = jnp.searchsorted(cdf, draw, side='right')
    return index, total
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the variable of the decision rule for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The decision axes of the tree.
    split_tree
        The cutpoints of the tree.
    max_split
        The largest split index of each variable.
    leaf_index
        The leaf being grown.
    blocked_vars
        Indices of the variables that have no cutpoints available. If `None`,
        no variable is blocked.
    log_s
        Log prior weights for choosing each variable. If `None`, all variables
        are equally likely.

    Returns
    -------
    var : Int32[Array, '']
        The chosen variable.
    num_available_var : Int32[Array, '']
        How many variables with decision rules still available `var` was drawn
        from.
    """
    excluded = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        excluded = jnp.concatenate([excluded, blocked_vars])

    if log_s is not None:
        return categorical_exclude(key, log_s, excluded)
    return randint_exclude(key, max_split.size, excluded)
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Find the ancestor variables whose split range at a node is empty.

    Parameters
    ----------
    var_tree
        The decision axes of the tree.
    split_tree
        The cutpoints of the tree.
    max_split
        The largest split index of each variable.
    leaf_index
        The node to examine, assumed to be a valid index into `var_tree`.

    Returns
    -------
    The indices of the variables with no cutpoints left at the node.

    Notes
    -----
    How many variables are exhausted is not known in advance, so unused
    entries are filled with `p`. The fillers may be in any position, and a
    variable may appear more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)

    # Evaluate the split range of every candidate variable at this node.
    batched_range = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lo, hi = batched_range(var_tree, split_tree, max_split, leaf_index, candidates)
    exhausted = lo == hi

    # The dtype of `candidates` can already hold max_split.size; see
    # ancestor_variables().
    return jnp.where(exhausted, candidates, max_split.size)
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    List the decision variables found on the path from the root to a node.

    Parameters
    ----------
    var_tree
        The decision axes of the tree.
    max_split
        The largest split index of each variable. Only its size `p` is used.
    node_index
        The node, assumed to be a valid index into `var_tree`.

    Returns
    -------
    The variables of the node's ancestors. The root comes first and the
    parent comes last.

    Notes
    -----
    The ancestors are the nodes from the root down to the node's parent.
    Their number is not known at tracing time, so unused leading entries are
    filled with `p`.
    """
    num_slots = grove.tree_depth(var_tree) - 1
    out_dtype = minimal_unsigned_dtype(max_split.size)

    def step(index, _):
        parent = index >> 1
        # Index 0 means the walk has passed the root, so emit the filler `p`.
        found = jnp.where(parent, var_tree[parent], max_split.size)
        return parent, found

    _, nearest_first = lax.scan(step, node_index, None, num_slots)

    # The scan yields the parent first; the result lists the root first.
    return nearest_first[::-1].astype(out_dtype)
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Compute the admissible cutpoints for a variable at a given node.

    Parameters
    ----------
    var_tree
        The decision axes of the tree.
    split_tree
        The cutpoints of the tree.
    max_split
        The largest split index of each variable.
    node_index
        The node, assumed to be a valid index into `var_tree`.
    ref_var
        The variable whose range is wanted.

    Returns
    -------
    The admissible cutpoints as the half-open interval [l, r). If `ref_var`
    is out of bounds, l = r = 1.
    """
    max_steps = grove.tree_depth(var_tree) - 1

    # The upper bound starts one past the variable's largest split. An
    # out-of-range variable reads 0, which collapses the range to [1, 1).
    upper = max_split.at[ref_var].get(mode='fill', fill_value=0).astype(jnp.int32) + 1

    def climb(state, _):
        lo, hi, node = state
        came_from_right = (node & 1).astype(bool)
        node = node >> 1
        cut = split_tree[node]
        # Only real ancestors (node != 0) that split on `ref_var` constrain it.
        same_var = node.astype(bool) & (var_tree[node] == ref_var)
        lo = jnp.where(same_var & came_from_right, jnp.maximum(lo, cut), lo)
        hi = jnp.where(same_var & ~came_from_right, jnp.minimum(hi, cut), hi)
        return (lo, hi, node), None

    init = (jnp.int32(0), upper, node_index)
    (lo, hi, _), _ = lax.scan(climb, init, None, max_steps)
    return lo + 1, hi
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return a random integer in a range, excluding some values.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    # `exclude` comes back sorted and deduplicated, padded with `sup`.
    exclude, num_allowed = _process_exclude(sup, exclude)
    u = random.randint(key, (), 0, num_allowed)

    def loop(u, i_excluded):
        # Shift `u` past every excluded value not above it. Entries >= `sup`
        # (including the padding) must be ignored. Otherwise, when the whole
        # range is excluded, `u` reaches `sup` and each such entry would push
        # it further beyond `sup`.
        bump = (i_excluded <= u) & (i_excluded < sup)
        return jnp.where(bump, u + 1, u), None

    u, _ = lax.scan(loop, u, exclude)
    return u, num_allowed
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
def _process_exclude(sup, exclude):
|
|
1109
|
-
exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
|
|
1110
|
-
num_allowed = sup - jnp.count_nonzero(exclude < sup)
|
|
1111
|
-
return exclude, num_allowed
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Sample a categorical distribution with some categories removed.

    Parameters
    ----------
    key
        A jax random key.
    logits
        Unnormalized log-probability of each category.
    exclude
        Categories in [0, k) that may not be drawn. Values at or above
        `logits.size` are ignored, and repetitions are allowed.

    Returns
    -------
    u : Int32[Array, '']
        The drawn category, in ``[0, k)`` and not in `exclude`.
    num_allowed : Int32[Array, '']
        How many categories were not excluded.

    Notes
    -----
    The result is unspecified when every category is excluded.
    """
    exclude, num_allowed = _process_exclude(logits.size, exclude)
    # Give excluded categories the most negative finite logit so that they
    # are effectively never drawn.
    floor = jnp.finfo(logits.dtype).min
    masked_logits = logits.at[exclude].set(floor)
    drawn = random.categorical(key, masked_logits)
    return drawn, num_allowed
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the cutpoint of the decision rule for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable the rule splits on.
    var_tree
        The decision axes of the tree. It need not already hold `var` at
        `leaf_index`.
    split_tree
        The cutpoints of the tree.
    max_split
        The largest split index of each variable.
    leaf_index
        The leaf being grown.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The half-open interval [l, r) that `split` was drawn from.

    Notes
    -----
    Returns 0 when `var` is out of bounds or its admissible range is empty.
    """
    lo, hi = split_range(var_tree, split_tree, max_split, leaf_index, var)
    draw = random.randint(key, (), lo, hi)
    cut = jnp.where(lo < hi, draw, 0)
    return cut, lo, hi
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Multiply the partial transition ratio and partial prior ratio of a grow.

    Parameters
    ----------
    prob_choose
        Probability of having selected this leaf among the growable ones.
    num_prunable
        How many parents of two leaves exist once the leaf has been turned
        into a non-terminal node.
    p_nonterminal
        Prior probability that each node is non-terminal given its ancestors.
    leaf_to_grow
        The leaf being grown.

    Returns
    -------
    The partial transition ratio times the partial prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new
    tree). The partial version omits the factor P(propose prune) from the
    numerator. The prior ratio is P(new tree) / P(old tree). The partial
    version omits the factor P(children are leaves).
    """
    # The factors num_available_split * num_available_var * s[var] appear in
    # both ratios and cancel. P(propose prune) and
    # 1 - p_nonterminal[child] * I(child is growable) need the count trees,
    # which are only available in the acceptance phase.

    # Growing the root means the old tree was a root alone. Such a tree
    # cannot be pruned, so a grow was certain; otherwise it has probability
    # 1/2.
    grows_root = leaf_to_grow == 1
    p_grow = jnp.where(grows_root, 1, 0.5)
    inv_trans_ratio = p_grow * prob_choose * num_prunable

    # If leaf_to_grow is out of bounds (move not allowed), plain indexing
    # could give 0, and later an infinity when the log is taken; read 0.5
    # instead.
    pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    tree_ratio = pnt / (1 - pnt)

    safe_denominator = jnp.where(inv_trans_ratio, inv_trans_ratio, 1)
    return tree_ratio / safe_denominator
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
class PruneMoves(Module):
    """
    Represent a proposed prune move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible.
    node
        The index of the node to prune. ``2 ** d`` if no node can be pruned.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and
        the prior probability that the children of the node to prune are
        leaves. This ratio is inverted, and is meant to be inverted back in
        `accept_moves_and_sample_leaves`.
    affluence_tree
        A partially updated `affluence_tree` in which the node to prune is
        marked as growable.
    """

    allowed: Bool[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a PRUNE move for every tree (BART MCMC tree-structure step).

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The cutpoints of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The prior probability that a node is non-terminal given its
        ancestors. It also covers the maximum depth, where it should be zero.
    p_propose_grow
        Unnormalized probability of choosing each leaf for growth.

    Returns
    -------
    The proposed moves.
    """
    node, n_prunable, prob_choose, new_affluence = choose_leaf_parent(
        key, split_tree, affluence_tree, p_propose_grow
    )
    partial_ratio = compute_partial_ratio(prob_choose, n_prunable, p_nonterminal, node)

    # A tree made of the root alone has nothing to prune.
    tree_has_split = split_tree[1].astype(bool)

    return PruneMoves(
        allowed=tree_has_split,
        node=node,
        partial_ratio=partial_ratio,
        affluence_tree=new_affluence,
    )
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as the leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree` of this single tree, marking the
        node to prune as growable.
    """
    # Sample a node to prune.
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # Compute the probability of the reverse (grow) move on the pruned tree.
    split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw an integer uniformly among the positions where a mask is set.

    Parameters
    ----------
    key
        A jax random key.
    mask
        `True` at the values that may be drawn.

    Returns
    -------
    A random integer `u` in ``[0, n)`` with ``mask[u] == True``.

    Notes
    -----
    Returns `n` when `mask` is entirely `False`.
    """
    running_count = jnp.cumsum(mask)
    num_allowed = running_count[-1]
    rank = random.randint(key, (), 0, num_allowed)
    return jnp.searchsorted(running_count, rank, side='right')
|
|
1396
|
-
|
|
1397
|
-
|
|
1398
|
-
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Accept or reject the proposed moves, then sample the new leaf values.

    The work runs in three stages: a parallel stage, a sequential stage over
    trees, and a final stage.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, as returned by `propose_moves`.

    Returns
    -------
    A new, valid BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    updated_bart, updated_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(updated_bart, updated_moves)
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
class Counts(Module):
    """
    Per-tree datapoint counts in the nodes touched by the proposed moves.

    Parameters
    ----------
    left
        Datapoints falling in the left child.
    right
        Datapoints falling in the right child.
    total
        Datapoints falling in the parent, equal to ``left + right``.
    """

    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    total: UInt[Array, ' num_trees']
|
|
1439
|
-
|
|
1440
|
-
|
|
1441
|
-
class Precs(Module):
    """
    Per-tree likelihood precision scale in the nodes touched by the moves.

    The likelihood precision scale of a node is the sum, over the datapoints
    the node selects, of the inverse squared error scales.

    Parameters
    ----------
    left
        Precision scale of the left child.
    right
        Precision scale of the right child.
    total
        Precision scale of the parent, equal to ``left + right``.
    """

    left: Float32[Array, ' num_trees']
    right: Float32[Array, ' num_trees']
    total: Float32[Array, ' num_trees']
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
class PreLkV(Module):
    """
    Per-tree terms of the likelihood ratio that do not depend on tree order.

    Because these terms are independent across trees, they are computed for
    all trees in parallel.

    Parameters
    ----------
    sigma2_left
        Noise variance in the left child of the node grown or pruned by each
        move.
    sigma2_right
        Noise variance in the right child of the node grown or pruned by each
        move.
    sigma2_total
        Noise variance in the node grown or pruned by each move, i.e. both
        children together.
    sqrt_term
        The **logarithm** of the square-root term of the likelihood ratio.
    """

    sigma2_left: Float32[Array, ' num_trees']
    sigma2_right: Float32[Array, ' num_trees']
    sigma2_total: Float32[Array, ' num_trees']
    sqrt_term: Float32[Array, ' num_trees']
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
class PreLk(Module):
    """
    Terms of the likelihood ratio that are shared by all trees.

    Parameters
    ----------
    exp_factor
        The factor ``sigma_mu2 / (2 * sigma2)`` that multiplies the
        residual-dependent quadratic terms in the exponent of the likelihood
        ratio (see `compute_likelihood_ratio`).
    """

    exp_factor: Float32[Array, '']
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
class PreLf(Module):
    """
    Terms for sampling the leaf values, pre-computed in parallel across trees.

    A leaf sample is obtained as ``mean_factor * <sum of the scaled residuals
    in the leaf> + centered_leaves``.

    Parameters
    ----------
    mean_factor
        Multiplier that turns the sum of the scaled residuals in a node into
        the posterior mean of the node value.
    centered_leaves
        Zero-mean normal draws, with the posterior variance, which are added
        to the posterior mean to obtain the posterior leaf samples.
    """

    mean_factor: Float32[Array, 'num_trees 2**d']
    centered_leaves: Float32[Array, 'num_trees 2**d']
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
class ParallelStageOut(Module):
    """
    The output of `accept_moves_parallel_stage`.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves, with `partial_ratio` set to `None` and
        `log_trans_prior_ratio` set to its final value.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
        If there is no precision scale, this is the number of points in each
        leaf.
    move_precs
        The likelihood precision scale in each node modified by the moves. If
        `bart.prec_scale` is not set, these are instead the datapoint counts
        (a `Counts` object) in the modified nodes.
    prelkv
    prelk
    prelf
        Objects with pre-computed terms of the likelihood ratios and leaf
        samples.
    """

    bart: State
    moves: Moves
    # float precision scales, or uint32 counts when there is no prec_scale
    prec_trees: Float32[Array, 'num_trees 2**d'] | UInt[Array, 'num_trees 2**d']
    move_precs: Precs | Counts
    prelkv: PreLkV
    prelk: PreLk
    prelf: PreLf
|
|
1558
|
-
|
|
1559
|
-
|
|
1560
|
-
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A jax random key, used to sample the leaf noise.
    bart : State
        A BART mcmc state.
    moves : Moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf; the counts are needed by the
    # minimum-points constraints, and as precision if there is no prec_scale
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.forest.count_batch_size
        )

    # mark which leaves & potential leaves have enough points to be grown
    if bart.forest.min_points_per_decision_node is not None:
        count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
        moves = replace(
            moves,
            affluence_tree=moves.affluence_tree
            & (count_half_trees >= bart.forest.min_points_per_decision_node),
        )

    # copy updated affluence_tree to state
    bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

    # veto moves whose two child nodes don't both have enough datapoints
    if bart.forest.min_points_per_leaf is not None:
        moves = replace(
            moves,
            allowed=moves.allowed
            & (move_counts.left >= bart.forest.min_points_per_leaf)
            & (move_counts.right >= bart.forest.min_points_per_leaf),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale,
            bart.forest.leaf_indices,
            moves,
            bart.forest.count_batch_size,
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    # pre-compute some likelihood ratio & posterior terms
    assert bart.sigma2 is not None  # `step` shall temporarily set it to 1
    prelkv, prelk = precompute_likelihood_terms(
        bart.sigma2, bart.forest.sigma_mu2, move_precs
    )
    prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )
|
|
1663
|
-
|
|
1664
|
-
|
|
1665
|
-
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Move the datapoints of each grown node into its new children.

    Vectorized over trees: `moves` and `leaf_indices` are mapped along their
    leading axis, while `X` is shared by all trees.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices. Only the datapoints in the node targeted by a
    grow move are reassigned; for other moves the indices are unchanged.
    """
    # datapoints at or above the split threshold go to the right child 2k + 1
    goes_right = X[moves.grow_var, :] >= moves.grow_split
    children = (moves.node.astype(leaf_indices.dtype) << 1) + goes_right

    # if the move is not a grow, target an index one past the end of the leaf
    # tree (twice the size of var_tree), which no datapoint should match
    past_end = jnp.array(2 * moves.var_tree.size)
    target = jnp.where(moves.grow, moves.node, past_end)

    in_target = leaf_indices == target
    return jnp.where(in_target, children, leaf_indices)
|
|
1692
|
-
|
|
1693
|
-
|
|
1694
|
-
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
) -> tuple[UInt[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    count_trees : UInt[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node, as uint32.
        The entry of the node targeted by each move holds the total count of
        its two children.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    num_trees, tree_size = moves.var_tree.shape
    tree_size *= 2  # the leaf tree array is twice as large as var_tree
    tree_indices = jnp.arange(num_trees)

    count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)

    # count datapoints in nodes modified by move
    left = count_trees[tree_indices, moves.left]
    right = count_trees[tree_indices, moves.right]
    counts = Counts(left=left, right=right, total=left + right)

    # write count into non-leaf node, which no leaf index points to
    count_trees = count_trees.at[tree_indices, moves.node].set(counts.total)

    return count_trees, counts
|
|
1733
|
-
|
|
1734
|
-
|
|
1735
|
-
def count_datapoints_per_leaf(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
) -> UInt[Array, 'num_trees {tree_size}']:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation. If `None`, the trees
        are processed one at a time with `lax.scan`.

    Returns
    -------
    The number of points in each leaf node, as uint32.
    """
    if batch_size is None:
        return _count_scan(leaf_indices, tree_size)
    else:
        return _count_vec(leaf_indices, tree_size, batch_size)
|
|
1758
|
-
|
|
1759
|
-
|
|
1760
|
-
def _count_scan(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
) -> UInt[Array, 'num_trees {tree_size}']:
    """
    Count datapoints per leaf, scanning over trees one at a time.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    tree_size
        The size of the leaf tree array.

    Returns
    -------
    The uint32 number of points in each leaf node of each tree.
    """

    def loop(_, leaf_indices):
        return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)

    _, count_trees = lax.scan(loop, None, leaf_indices)
    return count_trees
|
|
1768
|
-
|
|
1769
|
-
|
|
1770
|
-
def _aggregate_scatter(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
) -> Shaped[Array, ' {size}']:
    """
    Sum `values` into `size` bins selected by `indices` (scatter-add).

    Parameters
    ----------
    values
        The values to accumulate, or a scalar added once per index.
    indices
        The bin of each value.
    size
        The number of bins.
    dtype
        The dtype of the accumulator.

    Returns
    -------
    The per-bin sums.
    """
    bins = jnp.zeros(size, dtype)
    return bins.at[indices].add(values)
|
|
1777
|
-
|
|
1778
|
-
|
|
1779
|
-
def _count_vec(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
) -> UInt[Array, 'num_trees {tree_size}']:
    """
    Count datapoints per leaf for all trees at once, accumulating in batches.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    tree_size
        The size of the leaf tree array.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    The uint32 number of points in each leaf node of each tree.
    """
    # uint16 is super-slow on gpu, don't use it even if n < 2^16
    return _aggregate_batched_alltrees(
        1, leaf_indices, tree_size, jnp.uint32, batch_size
    )
|
|
1786
|
-
|
|
1787
|
-
|
|
1788
|
-
def _aggregate_batched_alltrees(
    values: Shaped[Array, '*'],
    indices: UInt[Array, 'num_trees n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Shaped[Array, 'num_trees {size}']:
    """
    Scatter-add `values` into per-tree bins, accumulating in data batches.

    Datapoint ``i`` is assigned to batch ``i % nbatches``, where
    ``nbatches = ceil(n / batch_size)``. Each batch accumulates into its own
    slot, and the slots are summed at the end.

    Parameters
    ----------
    values
        The values to accumulate, or a scalar added once per datapoint.
    indices
        The bin of each datapoint, for each tree.
    size
        The number of bins per tree.
    dtype
        The dtype of the accumulator.
    batch_size
        The target number of datapoints per batch.

    Returns
    -------
    The per-tree, per-bin sums.
    """
    num_trees, n = indices.shape
    nbatches = -(-n // batch_size)  # ceiling division
    batch_of_point = jnp.arange(n) % nbatches
    tree_of_point = jnp.arange(num_trees)[:, None]

    partial_sums = jnp.zeros((num_trees, size, nbatches), dtype)
    partial_sums = partial_sums.at[tree_of_point, indices, batch_of_point].add(values)
    return partial_sums.sum(axis=2)
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    batch_size: int | None,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
        The entry of the node targeted by each move holds the total of its two
        children.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    num_trees, half_tree_size = moves.var_tree.shape
    trees = jnp.arange(num_trees)

    # the leaf tree array has twice as many slots as var_tree
    prec_trees = prec_per_leaf(prec_scale, leaf_indices, 2 * half_tree_size, batch_size)

    # precision of the two children touched by each move, and of their parent
    prec_left = prec_trees[trees, moves.left]
    prec_right = prec_trees[trees, moves.right]
    precs = Precs(left=prec_left, right=prec_right, total=prec_left + prec_right)

    # the parent is internal in the post-grow tree, so no datapoint's leaf
    # index points at it: store its total explicitly
    prec_trees = prec_trees.at[trees, moves.node].set(precs.total)
    return prec_trees, precs
|
|
1850
|
-
|
|
1851
|
-
|
|
1852
|
-
def prec_per_leaf(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Sum the error precision scale of the datapoints in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation. If `None`, the trees
        are processed one at a time with `lax.scan`.

    Returns
    -------
    The likelihood precision scale in each leaf node.
    """
    if batch_size is not None:
        return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
    return _prec_scan(prec_scale, leaf_indices, tree_size)
|
|
1880
|
-
|
|
1881
|
-
|
|
1882
|
-
def _prec_scan(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Sum precision scales per leaf, scanning over trees one at a time.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    tree_size
        The size of the leaf tree array.

    Returns
    -------
    The float32 precision scale sum in each leaf node of each tree.
    """

    def one_tree(carry, tree_leaf_indices):
        sums = _aggregate_scatter(prec_scale, tree_leaf_indices, tree_size, jnp.float32)
        return carry, sums

    _, prec_trees = lax.scan(one_tree, None, leaf_indices)
    return prec_trees
|
|
1894
|
-
|
|
1895
|
-
|
|
1896
|
-
def _prec_vec(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Sum precision scales per leaf for all trees at once, in data batches.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    tree_size
        The size of the leaf tree array.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    The float32 precision scale sum in each leaf node of each tree.
    """
    accumulator_dtype = jnp.float32
    return _aggregate_batched_alltrees(
        prec_scale, leaf_indices, tree_size, accumulator_dtype, batch_size
    )
|
|
1905
|
-
|
|
1906
|
-
|
|
1907
|
-
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Complete non-likelihood MH ratio calculation.

    Multiplies the partial ratio by the probability of choosing a prune move
    over a grow move in the inverse transition, and by the a priori
    probability that the children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    num_trees, _ = moves.affluence_tree.shape
    trees = jnp.arange(num_trees)

    def is_growable(children):
        # children beyond the affluence tree (bottom level) count as not growable
        return moves.affluence_tree.at[trees, children].get(
            mode='fill', fill_value=False
        )

    left_growable = is_growable(moves.left)
    right_growable = is_growable(moves.right)

    # probability of picking prune in the inverse transition: 1/2 when a grow
    # would also be possible there, otherwise prune is certain
    can_grow_after_grow = (moves.num_growable >= 2) | left_growable | right_growable
    p_prune = jnp.where(
        moves.grow,
        jnp.where(can_grow_after_grow, 0.5, 1),
        jnp.where(moves.num_growable, 0.5, 1),
    )

    # prior probability that both children are terminal
    pt_left = 1 - p_nonterminal[moves.left] * left_growable
    pt_right = 1 - p_nonterminal[moves.right] * right_growable

    log_ratio = jnp.log(moves.partial_ratio * (pt_left * pt_right) * p_prune)
    return replace(moves, log_trans_prior_ratio=log_ratio, partial_ratio=None)
|
|
1960
|
-
|
|
1961
|
-
|
|
1962
|
-
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Make post-grow leaf indices valid on the original tree.

    For grow moves, the value of the leaf to be grown is copied into both of
    its would-be children; other moves leave the values untouched.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_value = leaf_trees[moves.node]

    # JAX drops scatter updates at out-of-bounds indices, so pointing
    # non-grow moves past the end turns their updates into no-ops
    past_end = leaf_trees.size
    left_slot = jnp.where(moves.grow, moves.left, past_end)
    right_slot = jnp.where(moves.grow, moves.right, past_end)

    with_left = leaf_trees.at[left_slot].set(parent_value)
    return with_left.at[right_slot].set(parent_value)
|
|
1990
|
-
|
|
1991
|
-
|
|
1992
|
-
def precompute_likelihood_terms(
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Parameters
    ----------
    sigma2
        The error variance, or the global error variance factor if
        `prec_scale` is set.
    sigma_mu2
        The prior variance of each leaf.
    move_precs
        The likelihood precision scale (or datapoint counts) in the leaves
        grown or pruned by the moves, with attributes `left`, `right`, and
        `total` (left + right).

    Returns
    -------
    prelkv : PreLkV
        Object with pre-computed terms of the likelihood ratio, one per
        tree.
    prelk : PreLk
        Object with pre-computed terms of the likelihood ratio, shared by
        all trees.
    """
    sigma2_left = sigma2 + move_precs.left * sigma_mu2
    sigma2_right = sigma2 + move_precs.right * sigma_mu2
    sigma2_total = sigma2 + move_precs.total * sigma_mu2
    prelkv = PreLkV(
        sigma2_left=sigma2_left,
        sigma2_right=sigma2_right,
        sigma2_total=sigma2_total,
        sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
    )
    return prelkv, PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
|
|
2030
|
-
|
|
2031
|
-
|
|
2032
|
-
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'] | UInt[Array, 'num_trees 2**d'],
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node,
        or the number of datapoints in each node if there is no error
        precision scale.
    sigma2
        The error variance, or the global error variance factor if `prec_scale`
        is set.
    sigma_mu2
        The prior variance of each leaf.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    prec_lk = prec_trees / sigma2
    prec_prior = lax.reciprocal(sigma_mu2)
    var_post = lax.reciprocal(prec_lk + prec_prior)
    z = random.normal(key, prec_trees.shape, sigma2.dtype)
    return PreLf(
        mean_factor=var_post / sigma2,
        # | mean = mean_lk * prec_lk * var_post
        # | resid_tree = mean_lk * prec_tree -->
        # |     --> mean_lk = resid_tree / prec_tree (kind of)
        # | mean_factor =
        # |     = mean / resid_tree =
        # |     = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
        # |     = 1 / prec_tree * prec_tree / sigma2 * var_post =
        # |     = var_post / sigma2
        centered_leaves=z * jnp.sqrt(var_post),
    )
|
|
2073
|
-
|
|
2074
|
-
|
|
2075
|
-
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept or reject the proposed moves, one tree after the other.

    This is the most performance-sensitive function: it holds all, and only,
    the parts of the algorithm that can not be parallelized across trees,
    because each tree sees the residuals left by the previous ones.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    # inputs shared by every tree
    shared = SeqStageInAllTrees(
        X=pso.bart.X,
        resid_batch_size=pso.bart.forest.resid_batch_size,
        prec_scale=pso.bart.prec_scale,
        save_ratios=pso.bart.forest.log_likelihood is not None,
        prelk=pso.prelk,
    )

    # inputs with a leading tree axis, sliced one tree at a time by the scan
    per_tree = SeqStageInPerTree(
        leaf_tree=pso.bart.forest.leaf_tree,
        prec_tree=pso.prec_trees,
        move=pso.moves,
        move_precs=pso.move_precs,
        leaf_indices=pso.bart.forest.leaf_indices,
        prelkv=pso.prelkv,
        prelf=pso.prelf,
    )

    def step_one_tree(resid, tree_inputs):
        new_resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            resid, shared, tree_inputs
        )
        return new_resid, (leaf_tree, acc, to_prune, lkratio)

    final_resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(
        step_one_tree, pso.bart.resid, per_tree
    )

    new_forest = replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio)
    new_bart = replace(pso.bart, resid=final_resid, forest=new_forest)
    new_moves = replace(pso.moves, acc=acc, to_prune=to_prune)
    return new_bart, new_moves
|
|
2128
|
-
|
|
2129
|
-
|
|
2130
|
-
class SeqStageInAllTrees(Module):
    """
    Inputs of `accept_move_and_sample_leaves` that are common to all trees.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size used to sum the residuals in each leaf, or `None` to
        not batch.
    prec_scale
        The scale of the precision of the error on each datapoint. `None`
        means it is 1 everywhere.
    save_ratios
        Whether the likelihood ratios should be saved.
    prelk
        The pre-computed terms of the likelihood ratio which are the same for
        all trees.
    """

    X: UInt[Array, 'p n']
    resid_batch_size: int | None = field(static=True)
    prec_scale: Float32[Array, ' n'] | None
    save_ratios: bool = field(static=True)
    prelk: PreLk
|
|
2155
|
-
|
|
2156
|
-
|
|
2157
|
-
class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node,
        or the number of datapoints in each node if there is no error
        precision scale.
    move
        The proposed move, see `propose_moves`.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    leaf_tree: Float32[Array, ' 2**d']
    # float precision scales, or uint32 counts when there is no prec_scale
    prec_tree: Float32[Array, ' 2**d'] | UInt[Array, ' 2**d']
    move: Moves
    move_precs: Precs | Counts
    leaf_indices: UInt[Array, ' n']
    prelkv: PreLkV
    prelf: PreLf
|
|
2187
|
-
|
|
2188
|
-
|
|
2189
|
-
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
) -> tuple[
    Float32[Array, ' n'],
    Float32[Array, ' 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Decide on one tree's proposed move and draw its new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # residuals weighted by the error precision scale, when there is one
    scaled_resid = resid if at.prec_scale is None else resid * at.prec_scale

    # per-node sums of scaled residuals in the tree proposed by the grow move;
    # adding back prec * leaf removes the current tree from the residuals
    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )
    resid_tree = resid_tree + pt.prec_tree * pt.leaf_tree

    # sums for the two children touched by the move and for their parent
    left_sum = resid_tree[pt.move.left]
    right_sum = resid_tree[pt.move.right]
    parent_sum = left_sum + right_sum
    assert pt.move.node.dtype == jnp.int32
    resid_tree = resid_tree.at[pt.move.node].set(parent_sum)

    # the ratio is computed in the grow direction, and inverted for prunes
    log_lk_ratio = compute_likelihood_ratio(
        parent_sum, left_sum, right_sum, pt.prelkv, at.prelk
    )
    log_grow_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(pt.move.grow, log_grow_ratio, -log_grow_ratio)
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # draw every leaf from its posterior
    leaf_tree = resid_tree * pt.prelf.mean_factor + pt.prelf.centered_leaves

    # when the resulting tree is the pruned one, copy the parent value into
    # the children so the (post-grow) leaf indices still find the right value;
    # otherwise point the scatter past the end so it has no effect
    to_prune = acc ^ pt.move.grow
    parent_value = leaf_tree[pt.move.node]
    past_end = leaf_tree.size
    leaf_tree = (
        leaf_tree.at[jnp.where(to_prune, pt.move.left, past_end)]
        .set(parent_value)
        .at[jnp.where(to_prune, pt.move.right, past_end)]
        .set(parent_value)
    )

    # swap the old tree for the new one in the residuals
    resid = resid + (pt.leaf_tree - leaf_tree)[pt.leaf_indices]

    saved_ratio = log_lk_ratio if at.save_ratios else None
    return resid, leaf_tree, acc, to_prune, saved_ratio
|
|
2273
|
-
|
|
2274
|
-
|
|
2275
|
-
def sum_resid(
    scaled_resid: Float32[Array, ' n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, ' {tree_size}']:
    """
    Add up the scaled residuals falling into each leaf of one tree.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation, or `None` for a single
        scatter-add. Batching increases numerical accuracy and parallelism.

    Returns
    -------
    The sum of the residuals at data points in each leaf.
    """
    if batch_size is not None:
        return _aggregate_batched_onetree(
            scaled_resid, leaf_indices, tree_size, jnp.float32, batch_size=batch_size
        )
    return _aggregate_scatter(scaled_resid, leaf_indices, tree_size, jnp.float32)
|
|
2306
|
-
|
|
2307
|
-
|
|
2308
|
-
def _aggregate_batched_onetree(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Float32[Array, ' {size}']:
    """
    Sum `values` into `size` bins selected by `indices`, in interleaved batches.

    The datapoints are dealt round-robin into ceil(n / batch_size) batches.
    Each batch is accumulated into its own column, and the columns are summed
    at the end.

    Parameters
    ----------
    values
        The values to aggregate; presumably the same length as `indices`.
    indices
        1-D array giving the destination bin of each value.
    size
        Number of bins in the output.
    dtype
        Data type of the accumulator and of the result.
    batch_size
        Maximum number of datapoints that go into each batch.

    Returns
    -------
    Array of shape ``(size,)`` with the per-bin totals.
    """
    (num_points,) = indices.shape
    # ceiling division: enough batches to hold `batch_size` points each
    num_batches = -(-num_points // batch_size)
    # point i goes to batch i mod num_batches (round-robin assignment)
    batch_of_point = jnp.arange(num_points) % num_batches
    partial_sums = jnp.zeros((size, num_batches), dtype)
    partial_sums = partial_sums.at[indices, batch_of_point].add(values)
    return partial_sums.sum(axis=1)
|
|
2324
|
-
|
|
2325
|
-
|
|
2326
|
-
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Compute the log-likelihood ratio of a grow move.

    Despite the function name, the result is on the log scale. The
    precomputed ``prelkv.sqrt_term`` is *added* to the quadratic exponent term
    instead of multiplying an exponential, and the surrounding code calls the
    quantity ``log_lk_ratio``.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    # Exponent: each node contributes (sum of residuals)^2 / sigma2. The two
    # new children add, and the parent they replace subtracts.
    exp_term = prelk.exp_factor * (
        left_resid * left_resid / prelkv.sigma2_left
        + right_resid * right_resid / prelkv.sigma2_right
        - total_resid * total_resid / prelkv.sigma2_total
    )
    # Both terms are logarithms, hence a sum rather than a product.
    return prelkv.sqrt_term + exp_term
|
|
2358
|
-
|
|
2359
|
-
|
|
2360
|
-
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Finish the mcmc state update once every move has been accepted or rejected.

    This is kept apart from `accept_moves_sequential_stage` to make clear that
    its work is parallel across trees.

    Parameters
    ----------
    bart
        A BART mcmc state that has only been partially updated.
    moves
        The proposed moves (see `propose_moves`), after
        `accept_moves_sequential_stage` has processed them.

    Returns
    -------
    The BART mcmc state with acceptance counts, leaf indices and split trees
    brought up to date.
    """
    old_forest = bart.forest
    accepted_grows = moves.acc & moves.grow
    accepted_prunes = moves.acc & ~moves.grow
    new_forest = replace(
        old_forest,
        grow_acc_count=jnp.sum(accepted_grows),
        prune_acc_count=jnp.sum(accepted_prunes),
        leaf_indices=apply_moves_to_leaf_indices(old_forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(old_forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)
|
|
2389
|
-
|
|
2390
|
-
|
|
2391
|
-
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Bring the leaf indices in line with the accepted moves.

    Parameters
    ----------
    leaf_indices
        For each datapoint, the index of the leaf it falls into, assuming the
        grow move was accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The leaf indices after the moves. Datapoints sitting in a child of a node
    flagged by `moves.to_prune` are sent back to that node; every other index
    is left as is.
    """
    # Clearing the lowest bit maps indices 2k and 2k + 1 onto 2k, so one
    # comparison with `moves.left` matches datapoints in either child
    # (presumably `moves.left` is even and `moves.right == moves.left + 1`).
    drop_lowest_bit = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    sibling_pair = leaf_indices & drop_lowest_bit
    collapse = (sibling_pair == moves.left) & moves.to_prune
    parent_index = moves.node.astype(leaf_indices.dtype)
    return jnp.where(collapse, parent_index, leaf_indices)
|
|
2416
|
-
|
|
2417
|
-
|
|
2418
|
-
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Write the effect of the accepted moves into the split trees.

    Parameters
    ----------
    split_tree
        Cutpoints of the decision nodes of the trees before the update.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The split trees with the moves applied.
    """
    assert moves.to_prune is not None
    # `split_tree.size` is one past the last valid index. JAX skips
    # out-of-bounds writes in `.at[...].set`, so unselected trees are untouched.
    out_of_bounds = split_tree.size
    grow_target = jnp.where(moves.grow, moves.node, out_of_bounds)
    prune_target = jnp.where(moves.to_prune, moves.node, out_of_bounds)
    grown = split_tree.at[grow_target].set(moves.grow_split.astype(split_tree.dtype))
    # The prune write happens second, so it wins over a grow at the same node.
    return grown.at[prune_target].set(0)
|
|
2444
|
-
|
|
2445
|
-
|
|
2446
|
-
def step_sigma(key: Key[Array, ''], bart: State) -> State:
    """
    Draw a new value of the error variance (factor) with an MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    A copy of `bart` whose `sigma2` has been resampled.
    """
    residuals = bart.resid
    weighted = residuals if bart.prec_scale is None else residuals * bart.prec_scale
    shape = bart.sigma2_alpha + residuals.size / 2
    scale = bart.sigma2_beta + (residuals @ weighted) / 2

    # scale / Gamma(shape, 1) is an InvGamma(shape, scale) draw.
    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    gamma_draw = random.gamma(key, shape)
    return replace(bart, sigma2=scale / gamma_draw)
|
|
2474
|
-
|
|
2475
|
-
|
|
2476
|
-
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    Resample the latent variable of binary regression with an MCMC step.

    The forest-plus-offset part ``z - resid`` stays fixed. A new residual is
    drawn with `truncated_normal_onesided`, and `z` is rebuilt from it.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The BART MCMC state with new `z` and `resid`.
    """
    forest_fit = bart.z - bart.resid
    assert bart.y.dtype == bool
    # presumably a one-sided truncated normal whose side is chosen by ~y and
    # whose bound is -forest_fit; see truncated_normal_onesided
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -forest_fit)
    return replace(bart, z=forest_fit + new_resid, resid=new_resid)
|
|
2496
|
-
|
|
2497
|
-
|
|
2498
|
-
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Resample `log_s` from its Dirichlet conditional posterior.

    Under the prior s ~ Dirichlet(theta/p, ..., theta/p), the posterior is
    s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount). Here varcount
    counts how many times each variable is used in the current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    The BART state with a freshly sampled `log_s`.
    """
    assert bart.forest.theta is not None

    forest = bart.forest
    num_vars = forest.max_split.size
    # per-variable usage counts across the forest
    usage = grove.var_histogram(num_vars, forest.var_tree, forest.split_tree)

    # Logs of independent Gamma(concentration, 1) draws. Normalizing them
    # yields a Dirichlet sample, so this is log s up to an additive constant
    # (step_theta normalizes with logsumexp).
    concentration = forest.theta / num_vars + usage
    new_log_s = random.loggamma(key, concentration)

    return replace(bart, forest=replace(forest, log_s=new_log_s))
|
|
2531
|
-
|
|
2532
|
-
|
|
2533
|
-
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Resample `theta` on a grid.

    The prior is theta / (theta + rho) ~ Beta(a, b). The conditional posterior
    of lambda = theta / (theta + rho) is evaluated on an evenly spaced grid,
    one grid point is drawn, and it is mapped back to theta.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        How many points the evenly spaced grid for theta / (theta + rho) has.

    Returns
    -------
    The BART state with a freshly sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # Midpoints of num_grid equal-width bins spanning (0, 1). The grid is
    # symmetric about 1/2, which _log_p_lamda relies on.
    half_bin = 1 / (2 * num_grid)
    grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # log_s is unnormalized; shift it so that s sums to one
    normalized_log_s = forest.log_s - logsumexp(forest.log_s)

    log_weights, theta_candidates = _log_p_lamda(
        grid, normalized_log_s, forest.rho, forest.a, forest.b
    )
    chosen = random.categorical(key, log_weights)
    return replace(bart, forest=replace(forest, theta=theta_candidates[chosen]))
|
|
2573
|
-
|
|
2574
|
-
|
|
2575
|
-
def _log_p_lamda(
    lamda: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """
    Unnormalized log posterior of lambda = theta / (theta + rho) on a grid.

    Parameters
    ----------
    lamda
        Grid of lambda values in (0, 1). It must be symmetric, i.e.
        ``lamda[::-1] == 1 - lamda``.
    log_s
        Normalized log selection probabilities of the p variables.
    rho
        Scale linking lambda and theta.
    a
    b
        Parameters of the Beta prior on lambda.

    Returns
    -------
    The log density at each grid point, up to an additive constant, and the
    matching theta values.
    """
    # The grid's reversal stands in for 1 - lamda, avoiding cancellation error.
    one_minus_lamda = lamda[::-1]
    theta = rho * lamda / one_minus_lamda
    num_vars = log_s.size

    # Beta(a, b) prior on lambda
    log_prior = (a - 1) * jnp.log1p(-one_minus_lamda) + (b - 1) * jnp.log1p(-lamda)
    # Dirichlet(theta/p) density of s, without factors that do not depend on theta
    log_density = (
        log_prior
        + gammaln(theta)
        - num_vars * gammaln(theta / num_vars)
        + theta / num_vars * jnp.sum(log_s)
    )
    return log_density, theta
|
|
2592
|
-
|
|
2593
|
-
|
|
2594
|
-
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Resample the sparsity parameters.

    `step_s` always runs. `step_theta` follows only when the theta prior is
    configured, which is detected by `rho` being set.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    The BART state with `log_s` and, if applicable, `theta` resampled.
    """
    keys = split(key)
    updated = step_s(keys.pop(), bart)
    if updated.forest.rho is None:
        return updated
    return step_theta(keys.pop(), updated)
|