bartz-0.6.0-py3-none-any.whl → bartz-0.7.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +464 -254
- bartz/__init__.py +2 -2
- bartz/_version.py +1 -1
- bartz/debug.py +1259 -79
- bartz/grove.py +139 -93
- bartz/jaxext/__init__.py +213 -0
- bartz/jaxext/_autobatch.py +238 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +240 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +468 -311
- bartz/mcmcstep.py +734 -453
- bartz/prepcovars.py +139 -43
- {bartz-0.6.0.dist-info → bartz-0.7.0.dist-info}/METADATA +2 -3
- bartz-0.7.0.dist-info/RECORD +17 -0
- {bartz-0.6.0.dist-info → bartz-0.7.0.dist-info}/WHEEL +1 -1
- bartz/jaxext.py +0 -423
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/mcmcstep.py
CHANGED
@@ -28,26 +28,33 @@ Functions that implement the BART posterior MCMC initialization and update step.
 Functions that do MCMC steps operate by taking as input a bart state, and
 outputting a new state. The inputs are not modified.

-The
+The entry points are:

 - `State`: The dataclass that represents a BART MCMC state.
 - `init`: Creates an initial `State` from data and configurations.
 - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
+- `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`.
 """

 import math
 from dataclasses import replace
 from functools import cache, partial
-from typing import Any
+from typing import Any, Literal

 import jax
-from equinox import Module, field
+from equinox import Module, field, tree_at
 from jax import lax, random
 from jax import numpy as jnp
+from jax.scipy.special import gammaln, logsumexp
 from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt

-from
-from .jaxext import
+from bartz import grove
+from bartz.jaxext import (
+    minimal_unsigned_dtype,
+    split,
+    truncated_normal_onesided,
+    vmap_nodoc,
+)


 class Forest(Module):
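A minimal usage sketch of these entry points (data shapes and hyperparameter values are invented; `init`'s keywords follow the signature shown later in this diff, while `step` taking `(key, state)` is an assumption based on its description above):

    import jax
    from jax import numpy as jnp
    from bartz import mcmcstep

    p, n, num_trees = 2, 100, 50
    X = jnp.zeros((p, n), dtype=jnp.uint8)           # binned predictors
    y = jnp.zeros(n, dtype=jnp.float32)              # continuous response
    state = mcmcstep.init(
        X=X,
        y=y,
        max_split=jnp.ones(p, dtype=jnp.uint8),      # one available cutpoint per predictor
        num_trees=num_trees,
        p_nonterminal=jnp.array([0.95, 0.75, 0.5]),  # depth-wise nonterminal prior
        sigma_mu2=1 / num_trees,
        sigma2_alpha=2.0,
        sigma2_beta=2.0,
    )
    key = jax.random.PRNGKey(0)
    for subkey in jax.random.split(key, 10):
        state = mcmcstep.step(subkey, state)         # assumed call convention: (key, state) -> state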
@@ -56,24 +63,32 @@ class Forest(Module):

     Parameters
     ----------
-
+    leaf_tree
         The leaf values.
-
+    var_tree
         The decision axes.
-
+    split_tree
         The decision boundaries.
+    affluence_tree
+        Marks leaves that can be grown.
+    max_split
+        The maximum split index for each predictor.
+    blocked_vars
+        Indices of variables that are not used. This shall include at least
+        the `i` such that ``max_split[i] == 0``, otherwise behavior is
+        undefined.
     p_nonterminal
-        The probability of
-
+        The prior probability of each node being nonterminal, conditional on
+        its ancestors. Includes the nodes at maximum depth which should be set
+        to 0.
     p_propose_grow
         The unnormalized probability of picking a leaf for a grow proposal.
     leaf_indices
         The index of the leaf each datapoints falls into, for each tree.
+    min_points_per_decision_node
+        The minimum number of data points in a decision node.
     min_points_per_leaf
         The minimum number of data points in a leaf node.
-    affluence_trees
-        Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
-        datapoints. If `min_points_per_leaf` is not specified, this is None.
     resid_batch_size
     count_batch_size
         The data batch sizes for computing the sufficient statistics. If `None`,
@@ -91,25 +106,45 @@ class Forest(Module):
         The number of grow/prune moves accepted during one full MCMC cycle.
     sigma_mu2
         The prior variance of a leaf, conditional on the tree structure.
-
-
-
-
-
-
-
+    log_s
+        The logarithm of the prior probability for choosing a variable to split
+        along in a decision rule, conditional on the ancestors. Not normalized.
+        If `None`, use a uniform distribution.
+    theta
+        The concentration parameter for the Dirichlet prior on the variable
+        distribution `s`. Required only to update `s`.
+    a
+    b
+    rho
+        Parameters of the prior on `theta`. Required only to sample `theta`.
+        See `step_theta`.
+    """
+
+    leaf_tree: Float32[Array, 'num_trees 2**d']
+    var_tree: UInt[Array, 'num_trees 2**(d-1)']
+    split_tree: UInt[Array, 'num_trees 2**(d-1)']
+    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
+    max_split: UInt[Array, ' p']
+    blocked_vars: UInt[Array, ' k'] | None
+    p_nonterminal: Float32[Array, ' 2**d']
+    p_propose_grow: Float32[Array, ' 2**(d-1)']
     leaf_indices: UInt[Array, 'num_trees n']
+    min_points_per_decision_node: Int32[Array, ''] | None
     min_points_per_leaf: Int32[Array, ''] | None
-    affluence_trees: Bool[Array, 'num_trees 2**(d-1)'] | None
     resid_batch_size: int | None = field(static=True)
     count_batch_size: int | None = field(static=True)
-    log_trans_prior: Float32[Array, 'num_trees'] | None
-    log_likelihood: Float32[Array, 'num_trees'] | None
+    log_trans_prior: Float32[Array, ' num_trees'] | None
+    log_likelihood: Float32[Array, ' num_trees'] | None
     grow_prop_count: Int32[Array, '']
     prune_prop_count: Int32[Array, '']
     grow_acc_count: Int32[Array, '']
     prune_acc_count: Int32[Array, '']
     sigma_mu2: Float32[Array, '']
+    log_s: Float32[Array, ' p'] | None
+    theta: Float32[Array, ''] | None
+    a: Float32[Array, ''] | None
+    b: Float32[Array, ''] | None
+    rho: Float32[Array, ''] | None


 class State(Module):
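The `2**d` and `2**(d-1)` axis sizes above come from storing each tree as a heap: the root sits at index 1 and node `i` has children `2*i` and `2*i + 1`, which is why the GROW move later in this diff computes `left = leaf_to_grow << 1`. A toy illustration of that indexing (not taken from the package):

    import jax.numpy as jnp

    d = 3                                      # maximum depth -> 2**d slots per tree
    leaf_tree = jnp.zeros(2**d)                # index 0 unused, root stored at index 1
    node = 2
    left, right = node << 1, (node << 1) + 1   # children of node 2 live at indices 4 and 5
    depths = jnp.log2(jnp.arange(1, 2**d)).astype(int)  # depths of nodes 1..7: 0,1,1,2,2,2,2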
@@ -120,8 +155,6 @@ class State(Module):
     ----------
     X
         The predictors.
-    max_split
-        The maximum split index for each predictor.
     y
         The response. If the data type is `bool`, the model is binary regression.
     resid
@@ -145,13 +178,12 @@
     """

     X: UInt[Array, 'p n']
-
-
-    z: None | Float32[Array, 'n']
+    y: Float32[Array, ' n'] | Bool[Array, ' n']
+    z: None | Float32[Array, ' n']
     offset: Float32[Array, '']
-    resid: Float32[Array, 'n']
+    resid: Float32[Array, ' n']
     sigma2: Float32[Array, ''] | None
-    prec_scale: Float32[Array, 'n'] | None
+    prec_scale: Float32[Array, ' n'] | None
     sigma2_alpha: Float32[Array, ''] | None
     sigma2_beta: Float32[Array, ''] | None
     forest: Forest
@@ -160,19 +192,26 @@
 def init(
     *,
     X: UInt[Any, 'p n'],
-    y: Float32[Any, 'n'] | Bool[Any, 'n'],
+    y: Float32[Any, ' n'] | Bool[Any, ' n'],
     offset: float | Float32[Any, ''] = 0.0,
-    max_split: UInt[Any, 'p'],
+    max_split: UInt[Any, ' p'],
     num_trees: int,
-    p_nonterminal: Float32[Any, 'd-1'],
+    p_nonterminal: Float32[Any, ' d-1'],
     sigma_mu2: float | Float32[Any, ''],
     sigma2_alpha: float | Float32[Any, ''] | None = None,
     sigma2_beta: float | Float32[Any, ''] | None = None,
-    error_scale: Float32[Any, 'n'] | None = None,
-
-    resid_batch_size: int | None
-    count_batch_size: int | None
+    error_scale: Float32[Any, ' n'] | None = None,
+    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
+    resid_batch_size: int | None | Literal['auto'] = 'auto',
+    count_batch_size: int | None | Literal['auto'] = 'auto',
     save_ratios: bool = False,
+    filter_splitless_vars: bool = True,
+    min_points_per_leaf: int | Integer[Any, ''] | None = None,
+    log_s: Float32[Any, ' p'] | None = None,
+    theta: float | Float32[Any, ''] | None = None,
+    a: float | Float32[Any, ''] | None = None,
+    b: float | Float32[Any, ''] | None = None,
+    rho: float | Float32[Any, ''] | None = None,
 ) -> State:
     """
     Make a BART posterior sampling MCMC initial state.
@@ -206,8 +245,9 @@ def init(
         the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
         Not supported for binary regression. If not specified, defaults to 1 for
         all points, but potentially skipping calculations.
-
-        The minimum number of data points in a
+    min_points_per_decision_node
+        The minimum number of data points in a decision node. 0 if not
+        specified.
     resid_batch_size
     count_batch_size
         The batch sizes, along datapoints, for summing the residuals and
@@ -216,6 +256,33 @@ def init(
         device.
     save_ratios
         Whether to save the Metropolis-Hastings ratios.
+    filter_splitless_vars
+        Whether to check `max_split` for variables without available cutpoints.
+        If any are found, they are put into a list of variables to exclude from
+        the MCMC. If `False`, no check is performed, but the results may be
+        wrong if any variable is blocked. The function is jax-traceable only
+        if this is set to `False`.
+    min_points_per_leaf
+        The minimum number of datapoints in a leaf node. 0 if not specified.
+        Unlike `min_points_per_decision_node`, this constraint is not taken into
+        account in the Metropolis-Hastings ratio because it would be expensive
+        to compute. Grow moves that would violate this constraint are vetoed.
+        This parameter is independent of `min_points_per_decision_node` and
+        there is no check that they are coherent. It makes sense to set
+        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
+    log_s
+        The logarithm of the prior probability for choosing a variable to split
+        along in a decision rule, conditional on the ancestors. Not normalized.
+        If not specified, use a uniform distribution. If not specified and
+        `theta` or `rho`, `a`, `b` are, it's initialized automatically.
+    theta
+        The concentration parameter for the Dirichlet prior on `s`. Required
+        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
+        specified, it's initialized automatically.
+    a
+    b
+    rho
+        Parameters of the prior on `theta`. Required only to sample `theta`.

     Returns
     -------
@@ -225,6 +292,13 @@ def init(
     ------
     ValueError
         If `y` is boolean and arguments unused in binary regression are set.
+
+    Notes
+    -----
+    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
+    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
+    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
+    integers in the range ``[0, 1, ..., max_split[i]]``.
     """
     p_nonterminal = jnp.asarray(p_nonterminal)
     p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
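A toy illustration of the cutpoint convention stated in these Notes (values invented for the example):

    import jax.numpy as jnp

    X = jnp.array([[0, 1, 2, 3]], dtype=jnp.uint8)  # p=1 predictor, n=4 points, max_split[0] == 3
    var, cutpoint = 0, 2                            # cutpoint drawn from [1, ..., max_split[var]]
    goes_left = X[var, :] < cutpoint                # -> [True, True, False, False]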
@@ -244,22 +318,37 @@ def init(
     is_binary = y.dtype == bool
     if is_binary:
         if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
-
+            msg = (
                 'error_scale, sigma2_alpha, and sigma2_beta must be set '
                 ' to `None` for binary regression.'
             )
+            raise ValueError(msg)
         sigma2 = None
     else:
         sigma2_alpha = jnp.asarray(sigma2_alpha)
         sigma2_beta = jnp.asarray(sigma2_beta)
         sigma2 = sigma2_beta / sigma2_alpha
-        # sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
-        # TODO: I don't like this isfinite check, these functions should be
-        # low-level and just do the thing. Why was it here?

-
+    max_split = jnp.asarray(max_split)
+
+    if filter_splitless_vars:
+        (blocked_vars,) = jnp.nonzero(max_split == 0)
+        blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
+        # see `fully_used_variables` for the type cast
+    else:
+        blocked_vars = None
+
+    # check and initialize sparsity parameters
+    if not _all_none_or_not_none(rho, a, b):
+        msg = 'rho, a, b are not either all `None` or all set'
+        raise ValueError(msg)
+    if theta is None and rho is not None:
+        theta = rho
+    if log_s is None and theta is not None:
+        log_s = jnp.zeros(max_split.size)
+
+    return State(
         X=jnp.asarray(X),
-        max_split=jnp.asarray(max_split),
         y=y,
         z=jnp.full(y.shape, offset) if is_binary else None,
         offset=offset,
@@ -271,41 +360,54 @@ def init(
         sigma2_alpha=sigma2_alpha,
         sigma2_beta=sigma2_beta,
         forest=Forest(
-
-
-
+            leaf_tree=make_forest(max_depth, jnp.float32),
+            var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
+            split_tree=make_forest(max_depth - 1, max_split.dtype),
+            affluence_tree=(
+                make_forest(max_depth - 1, bool)
+                .at[:, 1]
+                .set(
+                    True
+                    if min_points_per_decision_node is None
+                    else y.size >= min_points_per_decision_node
+                )
             ),
-
+            blocked_vars=blocked_vars,
+            max_split=max_split,
             grow_prop_count=jnp.zeros((), int),
             grow_acc_count=jnp.zeros((), int),
             prune_prop_count=jnp.zeros((), int),
             prune_acc_count=jnp.zeros((), int),
-            p_nonterminal=p_nonterminal,
+            p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
             p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
             leaf_indices=jnp.ones(
                 (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
-
-
-                if min_points_per_leaf is None
-                else jnp.asarray(min_points_per_leaf)
-            ),
-            affluence_trees=(
-                None
-                if min_points_per_leaf is None
-                else make_forest(max_depth - 1, bool)
-                .at[:, 1]
-                .set(y.size >= 2 * min_points_per_leaf)
-            ),
+            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
+            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
             resid_batch_size=resid_batch_size,
             count_batch_size=count_batch_size,
-            log_trans_prior=jnp.
-            log_likelihood=jnp.
+            log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
+            log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
             sigma_mu2=jnp.asarray(sigma_mu2),
+            log_s=_asarray_or_none(log_s),
+            theta=_asarray_or_none(theta),
+            rho=_asarray_or_none(rho),
+            a=_asarray_or_none(a),
+            b=_asarray_or_none(b),
         ),
     )

-
+
+def _all_none_or_not_none(*args):
+    is_none = [x is None for x in args]
+    return all(is_none) or not any(is_none)
+
+
+def _asarray_or_none(x):
+    if x is None:
+        return None
+    return jnp.asarray(x)


 def _choose_suffstat_batch_size(
@@ -319,16 +421,17 @@ def _choose_suffstat_batch_size(
         device = jax.devices()[0]
         platform = device.platform
         if platform not in ('cpu', 'gpu'):
-
+            msg = f'Unknown platform: {platform}'
+            raise KeyError(msg)
         return platform

     if resid_batch_size == 'auto':
         platform = get_platform()
         n = max(1, y.size)
         if platform == 'cpu':
-            resid_batch_size = 2 **
+            resid_batch_size = 2 ** round(math.log2(n / 6))  # n/6
         elif platform == 'gpu':
-            resid_batch_size = 2 **
+            resid_batch_size = 2 ** round((1 + math.log2(n)) / 3)  # n^1/3
         resid_batch_size = max(1, resid_batch_size)

     if count_batch_size == 'auto':
@@ -337,11 +440,11 @@ def _choose_suffstat_batch_size(
             count_batch_size = None
         elif platform == 'gpu':
             n = max(1, y.size)
-            count_batch_size = 2 **
+            count_batch_size = 2 ** round(math.log2(n) / 2 - 2)  # n^1/2
             # /4 is good on V100, /2 on L4/T4, still haven't tried A100
             max_memory = 2**29
             itemsize = 4
-            min_batch_size =
+            min_batch_size = math.ceil(forest_size * itemsize * n / max_memory)
             count_batch_size = max(count_batch_size, min_batch_size)
             count_batch_size = max(1, count_batch_size)

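To make the batch-size heuristics concrete, here is the arithmetic for a hypothetical n = 100_000, worked by hand rather than taken from the package:

    import math

    n = 100_000
    cpu_resid_batch = 2 ** round(math.log2(n / 6))        # ~ n/6, power of two -> 16384
    gpu_resid_batch = 2 ** round((1 + math.log2(n)) / 3)  # ~ n**(1/3), power of two -> 64
    gpu_count_batch = 2 ** round(math.log2(n) / 2 - 2)    # ~ sqrt(n)/4, power of two -> 64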
@@ -397,7 +500,7 @@ def step_trees(key: Key[Array, ''], bart: State) -> State:
     This function zeroes the proposal counters.
     """
     keys = split(key)
-    moves = propose_moves(keys.pop(), bart.forest
+    moves = propose_moves(keys.pop(), bart.forest)
     return accept_moves_and_sample_leaves(keys.pop(), bart, moves)


@@ -408,9 +511,11 @@ class Moves(Module):
     Parameters
     ----------
     allowed
-        Whether
-
-
+        Whether there is a possible move. If `False`, the other values may not
+        make sense. The only case in which a move is marked as allowed but is
+        then vetoed is if it does not satisfy `min_points_per_leaf`, which for
+        efficiency is implemented post-hoc without changing the rest of the
+        MCMC logic.
     grow
         Whether the move is a grow move or a prune move.
     num_growable
@@ -421,20 +526,27 @@
     right
         The indices of the children of 'node'.
     partial_ratio
-        A factor of the Metropolis-Hastings ratio of the move. It lacks
-
-
+        A factor of the Metropolis-Hastings ratio of the move. It lacks the
+        likelihood ratio, the probability of proposing the prune move, and the
+        probability that the children of the modified node are terminal. If the
+        move is PRUNE, the ratio is inverted. `None` once
         `log_trans_prior_ratio` has been computed.
     log_trans_prior_ratio
         The logarithm of the product of the transition and prior terms of the
         Metropolis-Hastings ratio for the acceptance of the proposed move.
-        `None` if not yet computed.
+        `None` if not yet computed. If PRUNE, the log-ratio is negated.
     grow_var
         The decision axes of the new rules.
     grow_split
         The decision boundaries of the new rules.
-
+    var_tree
         The updated decision axes of the trees, valid whatever move.
+    affluence_tree
+        A partially updated `affluence_tree`, marking non-leaf nodes that would
+        become leaves if the move was accepted. This mark initially (out of
+        `propose_moves`) takes into account if there would be available decision
+        rules to grow the leaf, and whether there are enough datapoints in the
+        node is marked in `accept_moves_parallel_stage`.
     logu
         The logarithm of a uniform (0, 1] random variable to be used to
         accept the move. It's in (-oo, 0].
@@ -446,25 +558,24 @@
         computed.
     """

-    allowed: Bool[Array, 'num_trees']
-    grow: Bool[Array, 'num_trees']
-    num_growable: UInt[Array, 'num_trees']
-    node: UInt[Array, 'num_trees']
-    left: UInt[Array, 'num_trees']
-    right: UInt[Array, 'num_trees']
-    partial_ratio: Float32[Array, 'num_trees'] | None
-    log_trans_prior_ratio: None | Float32[Array, 'num_trees']
-    grow_var: UInt[Array, 'num_trees']
-    grow_split: UInt[Array, 'num_trees']
-
-
-
-
+    allowed: Bool[Array, ' num_trees']
+    grow: Bool[Array, ' num_trees']
+    num_growable: UInt[Array, ' num_trees']
+    node: UInt[Array, ' num_trees']
+    left: UInt[Array, ' num_trees']
+    right: UInt[Array, ' num_trees']
+    partial_ratio: Float32[Array, ' num_trees'] | None
+    log_trans_prior_ratio: None | Float32[Array, ' num_trees']
+    grow_var: UInt[Array, ' num_trees']
+    grow_split: UInt[Array, ' num_trees']
+    var_tree: UInt[Array, 'num_trees 2**(d-1)']
+    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
+    logu: Float32[Array, ' num_trees']
+    acc: None | Bool[Array, ' num_trees']
+    to_prune: None | Bool[Array, ' num_trees']


-def propose_moves(
-    key: Key[Array, ''], forest: Forest, max_split: UInt[Array, 'p']
-) -> Moves:
+def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
     """
     Propose moves for all the trees.

@@ -478,39 +589,40 @@ def propose_moves(
         A jax random key.
     forest
         The `forest` field of a BART MCMC state.
-    max_split
-        The maximum split index for each variable, found in `State`.

     Returns
     -------
     The proposed move for each tree.
     """
-    num_trees, _ = forest.
+    num_trees, _ = forest.leaf_tree.shape
     keys = split(key, 1 + 2 * num_trees)

     # compute moves
     grow_moves = propose_grow_moves(
         keys.pop(num_trees),
-        forest.
-        forest.
-        forest.
-        max_split,
+        forest.var_tree,
+        forest.split_tree,
+        forest.affluence_tree,
+        forest.max_split,
+        forest.blocked_vars,
         forest.p_nonterminal,
         forest.p_propose_grow,
+        forest.log_s,
     )
     prune_moves = propose_prune_moves(
         keys.pop(num_trees),
-        forest.
-
+        forest.split_tree,
+        grow_moves.affluence_tree,
         forest.p_nonterminal,
         forest.p_propose_grow,
     )

-    u,
+    u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))

     # choose between grow or prune
-
-
+    p_grow = jnp.where(
+        grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed
+    )
     grow = u < p_grow  # use < instead of <= because u is in [0, 1)

     # compute children indices
@@ -519,7 +631,7 @@ def propose_moves(
     right = left + 1

     return Moves(
-        allowed=
+        allowed=grow_moves.allowed | prune_moves.allowed,
         grow=grow,
         num_growable=grow_moves.num_growable,
         node=node,
@@ -531,8 +643,11 @@ def propose_moves(
         log_trans_prior_ratio=None,  # will be set in complete_ratio
         grow_var=grow_moves.var,
         grow_split=grow_moves.split,
-
-
+        # var_tree does not need to be updated if prune
+        var_tree=grow_moves.var_tree,
+        # affluence_tree is updated for both moves unconditionally, prune last
+        affluence_tree=prune_moves.affluence_tree,
+        logu=jnp.log1p(-exp1mlogu),
         acc=None,  # will be set in accept_moves_sequential_stage
         to_prune=None,  # will be set in accept_moves_sequential_stage
     )
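One detail worth unpacking from `propose_moves`: `random.uniform` samples in [0, 1), so `1 - exp1mlogu` lies in (0, 1] and `jnp.log1p(-exp1mlogu)` computes `log(1 - exp1mlogu)` without cancellation, giving the `logu` in (-oo, 0] that the `Moves` docstring promises. A standalone sketch of the same trick (not package code):

    import jax.numpy as jnp
    from jax import random

    key = random.PRNGKey(0)
    exp1mlogu = random.uniform(key, (5,))   # draws in [0, 1)
    logu = jnp.log1p(-exp1mlogu)            # log of a uniform variate in (0, 1], always <= 0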
@@ -544,8 +659,10 @@ class GrowMoves(Module):

     Parameters
     ----------
+    allowed
+        Whether the move is allowed for proposal.
     num_growable
-        The number of
+        The number of leaves that can be proposed for grow.
     node
         The index of the leaf to grow. ``2 ** d`` if there are no growable
         leaves.
@@ -558,25 +675,32 @@ class GrowMoves(Module):
         move.
     var_tree
         The updated decision axes of the tree.
+    affluence_tree
+        A partially updated `affluence_tree` that marks each new leaf that
+        would be produced as `True` if it would have available decision rules.
     """

-
-
-
-
-
+    allowed: Bool[Array, ' num_trees']
+    num_growable: UInt[Array, ' num_trees']
+    node: UInt[Array, ' num_trees']
+    var: UInt[Array, ' num_trees']
+    split: UInt[Array, ' num_trees']
+    partial_ratio: Float32[Array, ' num_trees']
     var_tree: UInt[Array, 'num_trees 2**(d-1)']
+    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']


-@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
+@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
 def propose_grow_moves(
-    key: Key[Array, ''],
-    var_tree: UInt[Array, '2**(d-1)'],
-    split_tree: UInt[Array, '2**(d-1)'],
-    affluence_tree: Bool[Array, '2**(d-1)']
-    max_split: UInt[Array, 'p'],
-
-
+    key: Key[Array, ' num_trees'],
+    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
+    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
+    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
+    max_split: UInt[Array, ' p'],
+    blocked_vars: Int32[Array, ' k'] | None,
+    p_nonterminal: Float32[Array, ' 2**d'],
+    p_propose_grow: Float32[Array, ' 2**(d-1)'],
+    log_s: Float32[Array, ' p'] | None,
 ) -> GrowMoves:
     """
     Propose a GROW move for each tree.
@@ -593,13 +717,19 @@ def propose_grow_moves(
     split_tree
         The splitting points of the tree.
     affluence_tree
-        Whether
+        Whether each leaf has enough points to be grown.
     max_split
         The maximum split index for each variable.
+    blocked_vars
+        The indices of the variables that have no available cutpoints.
     p_nonterminal
-        The probability of a nonterminal
+        The a priori probability of a node to be nonterminal conditional on the
+        ancestors, including at the maximum depth where it should be zero.
     p_propose_grow
         The unnormalized probability of choosing a leaf to grow.
+    log_s
+        Unnormalized log-probability used to choose a variable to split on
+        amongst the available ones.

     Returns
     -------
@@ -607,16 +737,10 @@ def propose_grow_moves(

     Notes
     -----
-    The move is not proposed if
-
-
-
-    The move is also not be possible if the ancestors of a leaf have
-    exhausted the possible decision rules that lead to a non-empty selection.
-    This is marked by returning `var` set to `p` and `split` set to 0. But this
-    does not block the move from counting as "proposed", even though it is
-    predictably going to be rejected. This simplifies the MCMC and should not
-    reduce efficiency if not in unrealistic corner cases.
+    The move is not proposed if each leaf is already at maximum depth, or has
+    less datapoints than the requested threshold `min_points_per_decision_node`,
+    or it does not have any available decision rules given its ancestors. This
+    is marked by setting `allowed` to `False` and `num_growable` to 0.
     """
     keys = split(key, 3)

@@ -624,36 +748,45 @@ def propose_grow_moves(
         keys.pop(), split_tree, affluence_tree, p_propose_grow
     )

-
-
+    # sample a decision rule
+    var, num_available_var = choose_variable(
+        keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
+    )
+    split_idx, l, r = choose_split(
+        keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
+    )

-
+    # determine if the new leaves would have available decision rules; if the
+    # move is blocked, these values may not make sense
+    left_growable = right_growable = num_available_var > 1
+    left_growable |= l < split_idx
+    right_growable |= split_idx + 1 < r
+    left = leaf_to_grow << 1
+    right = left + 1
+    affluence_tree = affluence_tree.at[left].set(left_growable)
+    affluence_tree = affluence_tree.at[right].set(right_growable)

     ratio = compute_partial_ratio(
         prob_choose, num_prunable, p_nonterminal, leaf_to_grow
     )

     return GrowMoves(
+        allowed=num_growable > 0,
         num_growable=num_growable,
         node=leaf_to_grow,
         var=var,
         split=split_idx,
         partial_ratio=ratio,
-        var_tree=var_tree,
+        var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
+        affluence_tree=affluence_tree,
     )

-    # TODO it is not clear to me how var=p and split=0 when the move is not
-    # possible lead to corrent behavior downstream. Like, the move is proposed,
-    # but then it's a noop? And since it's a noop, it makes no difference if
-    # it's "accepted" or "rejected", it's like it's always rejected, so who
-    # cares if the likelihood ratio or a lot of other numbers are wrong? Uhm.
-

 def choose_leaf(
     key: Key[Array, ''],
-    split_tree: UInt[Array, '2**(d-1)'],
-    affluence_tree: Bool[Array, '2**(d-1)']
-    p_propose_grow: Float32[Array, '2**(d-1)'],
+    split_tree: UInt[Array, ' 2**(d-1)'],
+    affluence_tree: Bool[Array, ' 2**(d-1)'],
+    p_propose_grow: Float32[Array, ' 2**(d-1)'],
 ) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
     """
     Choose a leaf node to grow in a tree.
@@ -672,16 +805,16 @@ def choose_leaf(

     Returns
     -------
-    leaf_to_grow :
+    leaf_to_grow : Int32[Array, '']
         The index of the leaf to grow. If ``num_growable == 0``, return
         ``2 ** d``.
-    num_growable :
+    num_growable : Int32[Array, '']
         The number of leaf nodes that can be grown, i.e., are nonterminal
-        and have at least twice `min_points_per_leaf
-    prob_choose :
+        and have at least twice `min_points_per_leaf`.
+    prob_choose : Float32[Array, '']
         The (normalized) probability that this function had to choose that
         specific leaf, given the arguments.
-    num_prunable :
+    num_prunable : Int32[Array, '']
         The number of leaf parents that could be pruned, after converting the
         selected leaf to a non-terminal node.
     """
@@ -690,41 +823,43 @@ def choose_leaf(
     distr = jnp.where(is_growable, p_propose_grow, 0)
     leaf_to_grow, distr_norm = categorical(key, distr)
     leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
-    prob_choose = distr[leaf_to_grow] / distr_norm
+    prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1)
     is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
     num_prunable = jnp.count_nonzero(is_parent)
     return leaf_to_grow, num_growable, prob_choose, num_prunable


 def growable_leaves(
-    split_tree: UInt[Array, '2**(d-1)'],
-
-) -> Bool[Array, '2**(d-1)']:
+    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
+) -> Bool[Array, ' 2**(d-1)']:
     """
     Return a mask indicating the leaf nodes that can be proposed for growth.

-    The condition is that a leaf is not at the bottom level
-
+    The condition is that a leaf is not at the bottom level, has available
+    decision rules given its ancestors, and has at least
+    `min_points_per_decision_node` points.

     Parameters
     ----------
     split_tree
         The splitting points of the tree.
     affluence_tree
-
+        Marks leaves that can be grown.

     Returns
     -------
     The mask indicating the leaf nodes that can be proposed to grow.
+
+    Notes
+    -----
+    This function needs `split_tree` and not just `affluence_tree` because
+    `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
     """
-
-    if affluence_tree is not None:
-        is_growable &= affluence_tree
-    return is_growable
+    return grove.is_actual_leaf(split_tree) & affluence_tree


 def categorical(
-    key: Key[Array, ''], distr: Float32[Array, 'n']
+    key: Key[Array, ''], distr: Float32[Array, ' n']
 ) -> tuple[Int32[Array, ''], Float32[Array, '']]:
     """
     Return a random integer from an arbitrary distribution.
@@ -743,6 +878,11 @@ def categorical(
         return ``n``.
     norm : Float32[Array, '']
         The sum of `distr`.
+
+    Notes
+    -----
+    This function uses a cumsum instead of the Gumbel trick, so it's ok only
+    for small ranges with probabilities well greater than 0.
     """
     ecdf = jnp.cumsum(distr)
     u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
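The new Notes point at the design choice: `categorical` inverts a cumulative sum rather than using the Gumbel-max trick, which is fine for the short, well-conditioned distributions used here. A self-contained sketch of that technique (the final `searchsorted` step is an assumption, since the diff cuts off before the function's last lines):

    import jax.numpy as jnp
    from jax import random

    def categorical_cumsum(key, distr):
        """Draw an index ~ distr (unnormalized weights) via inverse-CDF on a cumsum."""
        ecdf = jnp.cumsum(distr)
        u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
        return jnp.searchsorted(ecdf, u, side='right'), ecdf[-1]

    idx, norm = categorical_cumsum(random.PRNGKey(0), jnp.array([0.2, 0.5, 0.3]))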
@@ -751,11 +891,13 @@

 def choose_variable(
     key: Key[Array, ''],
-    var_tree: UInt[Array, '2**(d-1)'],
-    split_tree: UInt[Array, '2**(d-1)'],
-    max_split: UInt[Array, 'p'],
+    var_tree: UInt[Array, ' 2**(d-1)'],
+    split_tree: UInt[Array, ' 2**(d-1)'],
+    max_split: UInt[Array, ' p'],
     leaf_index: Int32[Array, ''],
-
+    blocked_vars: Int32[Array, ' k'] | None,
+    log_s: Float32[Array, ' p'] | None,
+) -> tuple[Int32[Array, ''], Int32[Array, '']]:
     """
     Choose a variable to split on for a new non-terminal node.

@@ -771,28 +913,39 @@ def choose_variable(
         The maximum split index for each variable.
     leaf_index
         The index of the leaf to grow.
+    blocked_vars
+        The indices of the variables that have no available cutpoints. If
+        `None`, all variables are assumed unblocked.
+    log_s
+        The logarithm of the prior probability for choosing a variable. If
+        `None`, use a uniform distribution.

     Returns
     -------
-
-
-
-
-
-        allowed splits. If no variable has a non-empty range, return `p`.
+    var : Int32[Array, '']
+        The index of the variable to split on.
+    num_available_var : Int32[Array, '']
+        The number of variables with available decision rules `var` was chosen
+        from.
     """
     var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
-
+    if blocked_vars is not None:
+        var_to_ignore = jnp.concatenate([var_to_ignore, blocked_vars])
+
+    if log_s is None:
+        return randint_exclude(key, max_split.size, var_to_ignore)
+    else:
+        return categorical_exclude(key, log_s, var_to_ignore)


 def fully_used_variables(
-    var_tree: UInt[Array, '2**(d-1)'],
-    split_tree: UInt[Array, '2**(d-1)'],
-    max_split: UInt[Array, 'p'],
+    var_tree: UInt[Array, ' 2**(d-1)'],
+    split_tree: UInt[Array, ' 2**(d-1)'],
+    max_split: UInt[Array, ' p'],
     leaf_index: Int32[Array, ''],
-) -> UInt[Array, 'd-2']:
+) -> UInt[Array, ' d-2']:
     """
-
+    Find variables in the ancestors of a node that have an empty split range.

     Parameters
     ----------
@@ -820,23 +973,25 @@ def fully_used_variables(
     l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
     num_split = r - l
     return jnp.where(num_split == 0, var_to_ignore, max_split.size)
+    # the type of var_to_ignore is already sufficient to hold max_split.size,
+    # see ancestor_variables()


 def ancestor_variables(
-    var_tree: UInt[Array, '2**(d-1)'],
-    max_split: UInt[Array, 'p'],
+    var_tree: UInt[Array, ' 2**(d-1)'],
+    max_split: UInt[Array, ' p'],
     node_index: Int32[Array, ''],
-) -> UInt[Array, 'd-2']:
+) -> UInt[Array, ' d-2']:
     """
     Return the list of variables in the ancestors of a node.

     Parameters
     ----------
-    var_tree
+    var_tree
         The variable indices of the tree.
-    max_split
+    max_split
         The maximum split index for each variable. Used only to get `p`.
-    node_index
+    node_index
         The index of the node, assumed to be valid for `var_tree`.

     Returns
@@ -866,9 +1021,9 @@ def ancestor_variables(


 def split_range(
-    var_tree: UInt[Array, '2**(d-1)'],
-    split_tree: UInt[Array, '2**(d-1)'],
-    max_split: UInt[Array, 'p'],
+    var_tree: UInt[Array, ' 2**(d-1)'],
+    split_tree: UInt[Array, ' 2**(d-1)'],
+    max_split: UInt[Array, ' p'],
     node_index: Int32[Array, ''],
     ref_var: Int32[Array, ''],
 ) -> tuple[Int32[Array, ''], Int32[Array, '']]:
@@ -890,13 +1045,13 @@ def split_range(

     Returns
     -------
-    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=
+    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
     """
     max_num_ancestors = grove.tree_depth(var_tree) - 1
     initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
         jnp.int32
     )
-    carry = 0, initial_r, node_index
+    carry = jnp.int32(0), initial_r, node_index

     def loop(carry, _):
         l, r, index = carry
@@ -913,8 +1068,8 @@ def split_range(


 def randint_exclude(
-    key: Key[Array, ''], sup: int, exclude: Integer[Array, 'n']
-) -> Int32[Array, '']:
+    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
+) -> tuple[Int32[Array, ''], Int32[Array, '']]:
     """
     Return a random integer in a range, excluding some values.

@@ -930,30 +1085,74 @@ def randint_exclude(

     Returns
     -------
-
+    u : Int32[Array, '']
+        A random integer `u` in the range ``[0, sup)`` such that ``u not in
+        exclude``.
+    num_allowed : Int32[Array, '']
+        The number of integers in the range that were not excluded.

     Notes
     -----
     If all values in the range are excluded, return `sup`.
     """
-    exclude =
-    num_allowed = sup - jnp.count_nonzero(exclude < sup)
+    exclude, num_allowed = _process_exclude(sup, exclude)
     u = random.randint(key, (), 0, num_allowed)

-    def loop(u,
-        return jnp.where(
+    def loop(u, i_excluded):
+        return jnp.where(i_excluded <= u, u + 1, u), None

     u, _ = lax.scan(loop, u, exclude)
-    return u
+    return u, num_allowed
+
+
+def _process_exclude(sup, exclude):
+    exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
+    num_allowed = sup - jnp.count_nonzero(exclude < sup)
+    return exclude, num_allowed
+
+
+def categorical_exclude(
+    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
+) -> tuple[Int32[Array, ''], Int32[Array, '']]:
+    """
+    Draw from a categorical distribution, excluding a set of values.
+
+    Parameters
+    ----------
+    key
+        A jax random key.
+    logits
+        The unnormalized log-probabilities of each category.
+    exclude
+        The values to exclude from the range [0, k). Values greater than or
+        equal to `logits.size` are ignored. Values can appear more than once.
+
+    Returns
+    -------
+    u : Int32[Array, '']
+        A random integer in the range ``[0, k)`` such that ``u not in exclude``.
+    num_allowed : Int32[Array, '']
+        The number of integers in the range that were not excluded.
+
+    Notes
+    -----
+    If all values in the range are excluded, the result is unspecified.
+    """
+    exclude, num_allowed = _process_exclude(logits.size, exclude)
+    kinda_neg_inf = jnp.finfo(logits.dtype).min
+    logits = logits.at[exclude].set(kinda_neg_inf)
+    u = random.categorical(key, logits)
+    return u, num_allowed


 def choose_split(
     key: Key[Array, ''],
-
-
-
+    var: Int32[Array, ''],
+    var_tree: UInt[Array, ' 2**(d-1)'],
+    split_tree: UInt[Array, ' 2**(d-1)'],
+    max_split: UInt[Array, ' p'],
     leaf_index: Int32[Array, ''],
-) -> Int32[Array, '']:
+) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
     """
     Choose a split point for a new non-terminal node.

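The `loop`/`lax.scan` pair in `randint_exclude` implements a standard shift trick: draw uniformly over the `num_allowed` surviving values, then bump the draw past each excluded value it has reached, mapping `[0, num_allowed)` onto the allowed subset of `[0, sup)`. A plain-NumPy illustration of the same idea (not package code):

    import numpy as np

    sup, exclude = 10, np.array([2, 5, 6])   # allowed values: 0, 1, 3, 4, 7, 8, 9
    num_allowed = sup - len(exclude)
    u = np.random.randint(num_allowed)       # uniform over the 7 allowed slots
    for e in np.sort(exclude):               # shift past each excluded value already passed
        if e <= u:
            u += 1
    # u is now uniform over the allowed values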
@@ -961,32 +1160,39 @@ def choose_split(
     ----------
     key
         A jax random key.
+    var
+        The variable to split on.
     var_tree
-        The splitting axes of the tree.
+        The splitting axes of the tree. Does not need to already contain `var`
+        at `leaf_index`.
     split_tree
         The splitting points of the tree.
     max_split
         The maximum split index for each variable.
     leaf_index
-        The index of the leaf to grow.
-        contains the target variable at this index.
+        The index of the leaf to grow.

     Returns
     -------
-
+    split : Int32[Array, '']
+        The cutpoint.
+    l : Int32[Array, '']
+    r : Int32[Array, '']
+        The integer range `split` was drawn from is [l, r).
+
+    Notes
+    -----
+    If `var` is out of bounds, or if the available split range on that variable
+    is empty, return 0.
     """
-    var = var_tree[leaf_index]
     l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
-    return random.randint(key, (), l, r)
-
-    # TODO what happens if leaf_index is out of bounds? And is the value used
-    # in that case?
+    return jnp.where(l < r, random.randint(key, (), l, r), 0), l, r


 def compute_partial_ratio(
     prob_choose: Float32[Array, ''],
     num_prunable: Int32[Array, ''],
-    p_nonterminal: Float32[Array, 'd'],
+    p_nonterminal: Float32[Array, ' 2**d'],
     leaf_to_grow: Int32[Array, ''],
 ) -> Float32[Array, '']:
     """
@@ -1001,7 +1207,8 @@ def compute_partial_ratio(
         The number of leaf parents that could be pruned, after converting the
         leaf to be grown to a non-terminal node.
     p_nonterminal
-        The probability of
+        The a priori probability of each node being nonterminal conditional on
+        its ancestors.
     leaf_to_grow
         The index of the leaf to grow.

@@ -1013,29 +1220,29 @@ def compute_partial_ratio(
     -----
     The transition ratio is P(new tree => old tree) / P(old tree => new tree).
     The "partial" transition ratio returned is missing the factor P(propose
-    prune) in the numerator. The prior ratio is P(new tree) / P(old tree).
+    prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The
+    "partial" prior ratio is missing the factor P(children are leaves).
     """
     # the two ratios also contain factors num_available_split *
-    # num_available_var, but they cancel out
+    # num_available_var * s[var], but they cancel out

-    # p_prune
-    # computed
+    # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be
+    # computed here because they need the count trees, which are computed in the
+    # acceptance phase

     prune_allowed = leaf_to_grow != 1
     # prune allowed <---> the initial tree is not a root
     # leaf to grow is root --> the tree can only be a root
     # tree is a root --> the only leaf I can grow is root
-
     p_grow = jnp.where(prune_allowed, 0.5, 1)
-
     inv_trans_ratio = p_grow * prob_choose * num_prunable

-
-
-
-    tree_ratio =
+    # .at.get because if leaf_to_grow is out of bounds (move not allowed), this
+    # would produce a 0 and then an inf when `complete_ratio` takes the log
+    pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
+    tree_ratio = pnt / (1 - pnt)

-    return tree_ratio / inv_trans_ratio
+    return tree_ratio / jnp.where(inv_trans_ratio, inv_trans_ratio, 1)


 class PruneMoves(Module):
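To see how the partial ratio is meant to be completed downstream, here is a toy assembly of the factors named in the docstring (variable names and numbers are illustrative, not identifiers from the package):

    # Toy numbers; shows how the docstring's missing factors would combine with
    # the partial value for a GROW move.
    pnt, pnt_left, pnt_right = 0.45, 0.25, 0.25    # prior nonterminal probabilities
    p_grow, prob_choose, num_prunable = 0.5, 0.2, 3
    p_prune, likelihood_ratio = 0.5, 1.1           # only known later, in the acceptance phase

    partial = (pnt / (1 - pnt)) / (p_grow * prob_choose * num_prunable)
    full_ratio = partial * p_prune * (1 - pnt_left) * (1 - pnt_right) * likelihood_ratio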
@@ -1049,24 +1256,26 @@ class PruneMoves(Module):
|
|
|
1049
1256
|
node
|
|
1050
1257
|
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
1051
1258
|
partial_ratio
|
|
1052
|
-
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
1053
|
-
|
|
1054
|
-
|
|
1259
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks the
|
|
1260
|
+
likelihood ratio, the probability of proposing the prune move, and the
|
|
1261
|
+
prior probability that the children of the node to prune are leaves.
|
|
1262
|
+
This ratio is inverted, and is meant to be inverted back in
|
|
1055
1263
|
`accept_move_and_sample_leaves`.
|
|
1056
1264
|
"""
|
|
1057
1265
|
|
|
1058
|
-
allowed: Bool[Array, 'num_trees']
|
|
1059
|
-
node: UInt[Array, 'num_trees']
|
|
1060
|
-
partial_ratio: Float32[Array, 'num_trees']
|
|
1266
|
+
allowed: Bool[Array, ' num_trees']
|
|
1267
|
+
node: UInt[Array, ' num_trees']
|
|
1268
|
+
partial_ratio: Float32[Array, ' num_trees']
|
|
1269
|
+
affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
|
|
1061
1270
|
|
|
1062
1271
|
|
|
1063
1272
|
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
|
|
1064
1273
|
def propose_prune_moves(
|
|
1065
1274
|
key: Key[Array, ''],
|
|
1066
|
-
split_tree: UInt[Array, '2**(d-1)'],
|
|
1067
|
-
affluence_tree: Bool[Array, '2**(d-1)']
|
|
1068
|
-
p_nonterminal: Float32[Array, 'd'],
|
|
1069
|
-
p_propose_grow: Float32[Array, '2**(d-1)'],
|
|
1275
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
1276
|
+
affluence_tree: Bool[Array, ' 2**(d-1)'],
|
|
1277
|
+
p_nonterminal: Float32[Array, ' 2**d'],
|
|
1278
|
+
p_propose_grow: Float32[Array, ' 2**(d-1)'],
|
|
1070
1279
|
) -> PruneMoves:
|
|
1071
1280
|
"""
|
|
1072
1281
|
Tree structure prune move proposal of BART MCMC.
|
|
@@ -1078,9 +1287,10 @@ def propose_prune_moves(
|
|
|
1078
1287
|
split_tree
|
|
1079
1288
|
The splitting points of the tree.
|
|
1080
1289
|
affluence_tree
|
|
1081
|
-
Whether
|
|
1290
|
+
Whether each leaf can be grown.
|
|
1082
1291
|
p_nonterminal
|
|
1083
|
-
The probability of a
|
|
1292
|
+
The a priori probability of a node to be nonterminal conditional on
|
|
1293
|
+
the ancestors, including at the maximum depth where it should be zero.
|
|
1084
1294
|
p_propose_grow
|
|
1085
1295
|
The unnormalized probability of choosing a leaf to grow.
|
|
1086
1296
|
|
|
@@ -1088,28 +1298,33 @@ def propose_prune_moves(
|
|
|
1088
1298
|
-------
|
|
1089
1299
|
An object representing the proposed moves.
|
|
1090
1300
|
"""
|
|
1091
|
-
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
|
|
1301
|
+
node_to_prune, num_prunable, prob_choose, affluence_tree = choose_leaf_parent(
|
|
1092
1302
|
key, split_tree, affluence_tree, p_propose_grow
|
|
1093
1303
|
)
|
|
1094
|
-
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
1095
1304
|
|
|
1096
1305
|
ratio = compute_partial_ratio(
|
|
1097
1306
|
prob_choose, num_prunable, p_nonterminal, node_to_prune
|
|
1098
1307
|
)
|
|
1099
1308
|
|
|
1100
1309
|
return PruneMoves(
|
|
1101
|
-
allowed=allowed
|
|
1310
|
+
allowed=split_tree[1].astype(bool), # allowed iff the tree is not a root
|
|
1102
1311
|
node=node_to_prune,
|
|
1103
1312
|
partial_ratio=ratio,
|
|
1313
|
+
affluence_tree=affluence_tree,
|
|
1104
1314
|
)
|
|
1105
1315
|
|
|
1106
1316
|
|
|
1107
1317
|
def choose_leaf_parent(
|
|
1108
1318
|
key: Key[Array, ''],
|
|
1109
|
-
split_tree: UInt[Array, '2**(d-1)'],
|
|
1110
|
-
affluence_tree: Bool[Array, '2**(d-1)']
|
|
1111
|
-
p_propose_grow: Float32[Array, '2**(d-1)'],
|
|
1112
|
-
) -> tuple[
|
|
1319
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
1320
|
+
affluence_tree: Bool[Array, ' 2**(d-1)'],
|
|
1321
|
+
p_propose_grow: Float32[Array, ' 2**(d-1)'],
|
|
1322
|
+
) -> tuple[
|
|
1323
|
+
Int32[Array, ''],
|
|
1324
|
+
Int32[Array, ''],
|
|
1325
|
+
Float32[Array, ''],
|
|
1326
|
+
Bool[Array, 'num_trees 2**(d-1)'],
|
|
1327
|
+
]:
|
|
1113
1328
|
"""
|
|
1114
1329
|
Pick a non-terminal node with leaf children to prune in a tree.
|
|
1115
1330
|
|
|
@@ -1135,23 +1350,28 @@ def choose_leaf_parent(
|
|
|
1135
1350
|
The (normalized) probability that `choose_leaf` would chose
|
|
1136
1351
|
`node_to_prune` as leaf to grow, if passed the tree where
|
|
1137
1352
|
`node_to_prune` had been pruned.
|
|
1353
|
+
affluence_tree : Bool[Array, 'num_trees 2**(d-1)']
|
|
1354
|
+
A partially updated `affluence_tree`, marking the node to prune as
|
|
1355
|
+
growable.
|
|
1138
1356
|
"""
|
|
1357
|
+
# sample a node to prune
|
|
1139
1358
|
is_prunable = grove.is_leaves_parent(split_tree)
|
|
1140
1359
|
num_prunable = jnp.count_nonzero(is_prunable)
|
|
1141
1360
|
node_to_prune = randint_masked(key, is_prunable)
|
|
1142
1361
|
node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
|
|
1143
1362
|
|
|
1363
|
+
# compute stuff for reverse move
|
|
1144
1364
|
split_tree = split_tree.at[node_to_prune].set(0)
|
|
1145
|
-
|
|
1146
|
-
affluence_tree = affluence_tree.at[node_to_prune].set(True)
|
|
1365
|
+
affluence_tree = affluence_tree.at[node_to_prune].set(True)
|
|
1147
1366
|
is_growable_leaf = growable_leaves(split_tree, affluence_tree)
|
|
1148
|
-
|
|
1149
|
-
prob_choose
|
|
1367
|
+
distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
|
|
1368
|
+
prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
|
|
1369
|
+
prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)
|
|
1150
1370
|
|
|
1151
|
-
return node_to_prune, num_prunable, prob_choose
|
|
1371
|
+
return node_to_prune, num_prunable, prob_choose, affluence_tree
|
|
1152
1372
|
|
|
1153
1373
|
|
|
1154
|
-
def randint_masked(key: Key[Array, ''], mask: Bool[Array, 'n']) -> Int32[Array, '']:
|
|
1374
|
+
def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
|
|
1155
1375
|
"""
|
|
1156
1376
|
Return a random integer in a range, including only some values.
|
|
1157
1377
|
|
|
@@ -1213,9 +1433,9 @@ class Counts(Module):
         Number of datapoints in the parent (``= left + right``).
     """

-    left: UInt[Array, 'num_trees']
-    right: UInt[Array, 'num_trees']
-    total: UInt[Array, 'num_trees']
+    left: UInt[Array, ' num_trees']
+    right: UInt[Array, ' num_trees']
+    total: UInt[Array, ' num_trees']


 class Precs(Module):
@@ -1235,9 +1455,9 @@ class Precs(Module):
         Likelihood precision scale in the parent (``= left + right``).
     """

-    left: Float32[Array, 'num_trees']
-    right: Float32[Array, 'num_trees']
-    total: Float32[Array, 'num_trees']
+    left: Float32[Array, ' num_trees']
+    right: Float32[Array, ' num_trees']
+    total: Float32[Array, ' num_trees']


 class PreLkV(Module):
@@ -1261,10 +1481,10 @@ class PreLkV(Module):
         The **logarithm** of the square root term of the likelihood ratio.
     """

-    sigma2_left: Float32[Array, 'num_trees']
-    sigma2_right: Float32[Array, 'num_trees']
-    sigma2_total: Float32[Array, 'num_trees']
-    sqrt_term: Float32[Array, 'num_trees']
+    sigma2_left: Float32[Array, ' num_trees']
+    sigma2_right: Float32[Array, ' num_trees']
+    sigma2_total: Float32[Array, ' num_trees']
+    sqrt_term: Float32[Array, ' num_trees']


 class PreLk(Module):
@@ -1331,7 +1551,6 @@ class ParallelStageOut(Module):
     bart: State
     moves: Moves
     prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
-    move_counts: Counts | None
     move_precs: Precs | Counts
     prelkv: PreLkV
     prelk: PreLk
@@ -1342,7 +1561,7 @@ def accept_moves_parallel_stage(
     key: Key[Array, ''], bart: State, moves: Moves
 ) -> ParallelStageOut:
     """
-    Pre-
+    Pre-compute quantities used to accept moves, in parallel across trees.

     Parameters
     ----------
@@ -1362,33 +1581,41 @@ def accept_moves_parallel_stage(
         bart,
         forest=replace(
             bart.forest,
-
+            var_tree=moves.var_tree,
             leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
-
+            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
         ),
     )

     # count number of datapoints per leaf
-    if
+    if (
+        bart.forest.min_points_per_decision_node is not None
+        or bart.forest.min_points_per_leaf is not None
+        or bart.prec_scale is None
+    ):
         count_trees, move_counts = compute_count_trees(
             bart.forest.leaf_indices, moves, bart.forest.count_batch_size
         )
-    else:
-        # move_counts is passed later to a function, but then is unused under
-        # this condition
-        move_counts = None

-    #
-
-
+    # mark which leaves & potential leaves have enough points to be grown
+    if bart.forest.min_points_per_decision_node is not None:
+        count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
+        moves = replace(
+            moves,
+            affluence_tree=moves.affluence_tree
+            & (count_half_trees >= bart.forest.min_points_per_decision_node),
+        )
+
+        # copy updated affluence_tree to state
+        bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)
+
+    # veto grove move if new leaves don't have enough datapoints
     if bart.forest.min_points_per_leaf is not None:
-
-
-
-        forest
-
-                affluence_trees=count_half_trees >= 2 * bart.forest.min_points_per_leaf,
-            ),
+        moves = replace(
+            moves,
+            allowed=moves.allowed
+            & (move_counts.left >= bart.forest.min_points_per_leaf)
+            & (move_counts.right >= bart.forest.min_points_per_leaf),
         )

     # count number of datapoints per leaf, weighted by error precision scale
@@ -1402,18 +1629,23 @@ def accept_moves_parallel_stage(
         moves,
         bart.forest.count_batch_size,
     )
+    assert move_precs is not None

     # compute some missing information about moves
-    moves = complete_ratio(moves,
+    moves = complete_ratio(moves, bart.forest.p_nonterminal)
+    save_ratios = bart.forest.log_likelihood is not None
     bart = replace(
         bart,
         forest=replace(
             bart.forest,
             grow_prop_count=jnp.sum(moves.grow),
             prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
+            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
         ),
     )

+    # pre-compute some likelihood ratio & posterior terms
+    assert bart.sigma2 is not None  # `step` shall temporarily set it to 1
     prelkv, prelk = precompute_likelihood_terms(
         bart.sigma2, bart.forest.sigma_mu2, move_precs
     )
@@ -1423,7 +1655,6 @@ def accept_moves_parallel_stage(
         bart=bart,
         moves=moves,
         prec_trees=prec_trees,
-        move_counts=move_counts,
         move_precs=move_precs,
         prelkv=prelkv,
         prelk=prelk,
@@ -1453,12 +1684,10 @@ def apply_grow_to_indices(
     """
     left_child = moves.node.astype(leaf_indices.dtype) << 1
     go_right = X[moves.grow_var, :] >= moves.grow_split
-    tree_size = jnp.array(2 * moves.
+    tree_size = jnp.array(2 * moves.var_tree.size)
     node_to_update = jnp.where(moves.grow, moves.node, tree_size)
     return jnp.where(
-        leaf_indices == node_to_update,
-        left_child + go_right,
-        leaf_indices,
+        leaf_indices == node_to_update, left_child + go_right, leaf_indices
     )

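For reference, `apply_grow_to_indices` relies on the heap indexing of tree nodes: the children of node `i` are `2*i` and `2*i + 1`, and datapoints sitting in a grown leaf are rerouted to one of its children by the decision rule. A small self-contained illustration with made-up data (not taken from the package):

```python
import jax.numpy as jnp

# toy data: 5 datapoints currently assigned to leaves 3 and 5 of a heap-indexed tree
leaf_indices = jnp.array([3, 5, 3, 3, 5], dtype=jnp.uint32)
node = 3                                                  # leaf being grown
go_right = jnp.array([False, False, True, False, True])   # X[grow_var, :] >= grow_split

left_child = node << 1                                    # 6; the right child is 7
updated = jnp.where(leaf_indices == node, left_child + go_right, leaf_indices)
print(updated)  # [6 5 7 6 5]: points in the grown leaf move to its children
```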
@@ -1486,7 +1715,7 @@ def compute_count_trees(
         The counts of the number of points in the leaves grown or pruned by the
         moves.
     """
-    num_trees, tree_size = moves.
+    num_trees, tree_size = moves.var_tree.shape
     tree_size *= 2
     tree_indices = jnp.arange(num_trees)

@@ -1543,7 +1772,7 @@ def _aggregate_scatter(
     indices: Integer[Array, '*'],
     size: int,
     dtype: jnp.dtype,
-) -> Shaped[Array, '{size}']:
+) -> Shaped[Array, ' {size}']:
     return jnp.zeros(size, dtype).at[indices].add(values)

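`_aggregate_scatter` above is the core of the per-leaf counting: a scatter-add into a zero-initialized vector. A tiny standalone example of the same pattern, with toy numbers rather than package data:

```python
import jax.numpy as jnp

leaf_indices = jnp.array([6, 5, 7, 6, 5])   # leaf of each datapoint, heap-indexed
tree_size = 8
counts = jnp.zeros(tree_size, jnp.uint32).at[leaf_indices].add(1)
print(counts)  # [0 0 0 0 0 2 2 1]
```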
@@ -1576,7 +1805,7 @@ def _aggregate_batched_alltrees(


 def compute_prec_trees(
-    prec_scale: Float32[Array, 'n'],
+    prec_scale: Float32[Array, ' n'],
     leaf_indices: UInt[Array, 'num_trees n'],
     moves: Moves,
     batch_size: int | None,
@@ -1603,7 +1832,7 @@ def compute_prec_trees(
     precs : Precs
         The likelihood precision scale in the nodes involved in the moves.
     """
-    num_trees, tree_size = moves.
+    num_trees, tree_size = moves.var_tree.shape
     tree_size *= 2
     tree_indices = jnp.arange(num_trees)

@@ -1621,7 +1850,7 @@ def compute_prec_trees(


 def prec_per_leaf(
-    prec_scale: Float32[Array, 'n'],
+    prec_scale: Float32[Array, ' n'],
     leaf_indices: UInt[Array, 'num_trees n'],
     tree_size: int,
     batch_size: int | None,
@@ -1651,7 +1880,7 @@ def prec_per_leaf(


 def _prec_scan(
-    prec_scale: Float32[Array, 'n'],
+    prec_scale: Float32[Array, ' n'],
     leaf_indices: UInt[Array, 'num_trees n'],
     tree_size: int,
 ) -> Float32[Array, 'num_trees {tree_size}']:
@@ -1665,7 +1894,7 @@ def _prec_scan(


 def _prec_vec(
-    prec_scale: Float32[Array, 'n'],
+    prec_scale: Float32[Array, ' n'],
     leaf_indices: UInt[Array, 'num_trees n'],
     tree_size: int,
     batch_size: int,
@@ -1675,77 +1904,59 @@ def _prec_vec(
     )


-def complete_ratio(
-    moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
-) -> Moves:
+def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
     """
     Complete non-likelihood MH ratio calculation.

-    This function adds the probability of choosing
+    This function adds the probability of choosing a prune move over the grow
+    move in the inverse transition, and the a priori probability that the
+    children nodes are leaves.

     Parameters
     ----------
     moves
-        The proposed moves
-
-
-
-
-
+        The proposed moves. Must have already been updated to keep into account
+        the thresholds on the number of datapoints per node, this happens in
+        `accept_moves_parallel_stage`.
+    p_nonterminal
+        The a priori probability of each node being nonterminal conditional on
+        its ancestors, including at the maximum depth where it should be zero.

     Returns
     -------
     The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
     """
-
-
-
-
-
+    # can the leaves can be grown?
+    num_trees, _ = moves.affluence_tree.shape
+    tree_indices = jnp.arange(num_trees)
+    left_growable = moves.affluence_tree.at[tree_indices, moves.left].get(
+        mode='fill', fill_value=False
+    )
+    right_growable = moves.affluence_tree.at[tree_indices, moves.right].get(
+        mode='fill', fill_value=False
     )

-
-def compute_p_prune(
-    moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
-) -> Float32[Array, 'num_trees']:
-    """
-    Compute the probability of proposing a prune move for each tree.
-
-    Parameters
-    ----------
-    moves
-        The proposed moves, see `propose_moves`.
-    move_counts
-        The number of datapoints in the proposed children of the leaf to grow.
-        Not used if `min_points_per_leaf` is `None`.
-    min_points_per_leaf
-        The minimum number of data points in a leaf node.
-
-    Returns
-    -------
-    The probability of proposing a prune move.
-
-    Notes
-    -----
-    This probability is computed for going from the state with the deeper tree
-    to the one with the shallower one. This means, if grow: after accepting the
-    grow move, if prune: right away.
-    """
-    # calculation in case the move is grow
+    # p_prune if grow
     other_growable_leaves = moves.num_growable >= 2
-
-    if min_points_per_leaf is not None:
-        assert move_counts is not None
-        any_above_threshold = move_counts.left >= 2 * min_points_per_leaf
-        any_above_threshold |= move_counts.right >= 2 * min_points_per_leaf
-        new_leaves_growable &= any_above_threshold
-    grow_again_allowed = other_growable_leaves | new_leaves_growable
+    grow_again_allowed = other_growable_leaves | left_growable | right_growable
     grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)

-    #
+    # p_prune if prune
     prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)

-
+    # select p_prune
+    p_prune = jnp.where(moves.grow, grow_p_prune, prune_p_prune)
+
+    # prior probability of both children being terminal
+    pt_left = 1 - p_nonterminal[moves.left] * left_growable
+    pt_right = 1 - p_nonterminal[moves.right] * right_growable
+    pt_children = pt_left * pt_right
+
+    return replace(
+        moves,
+        log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune),
+        partial_ratio=None,
+    )


 @vmap_nodoc
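The rewritten `complete_ratio` looks up the children's affluence with `.get(mode='fill', fill_value=False)`, so that out-of-range child indices (which occur for moves that are not grows) read back as `False` instead of being clamped. A minimal demonstration of that JAX indexing mode, with made-up values:

```python
import jax.numpy as jnp

affluence_tree = jnp.array([[False, True, True, False]])
rows = jnp.arange(1)

in_range = jnp.array([2])
out_of_range = jnp.array([17])
print(affluence_tree.at[rows, in_range].get(mode='fill', fill_value=False))      # [ True]
print(affluence_tree.at[rows, out_of_range].get(mode='fill', fill_value=False))  # [False]
```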
@@ -1815,9 +2026,7 @@ def precompute_likelihood_terms(
         sigma2_total=sigma2_total,
         sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
     )
-    return prelkv, PreLk(
-        exp_factor=sigma_mu2 / (2 * sigma2),
-    )
+    return prelkv, PreLk(exp_factor=sigma_mu2 / (2 * sigma2))


 def precompute_leaf_terms(
@@ -1851,14 +2060,14 @@ def precompute_leaf_terms(
     z = random.normal(key, prec_trees.shape, sigma2.dtype)
     return PreLf(
         mean_factor=var_post / sigma2,
-        # mean = mean_lk * prec_lk * var_post
-        # resid_tree = mean_lk * prec_tree -->
-        # --> mean_lk = resid_tree / prec_tree (kind of)
-        # mean_factor =
-        # = mean / resid_tree =
-        # = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
-        # = 1 / prec_tree * prec_tree / sigma2 * var_post =
-        # = var_post / sigma2
+        # | mean = mean_lk * prec_lk * var_post
+        # | resid_tree = mean_lk * prec_tree -->
+        # | --> mean_lk = resid_tree / prec_tree (kind of)
+        # | mean_factor =
+        # | = mean / resid_tree =
+        # | = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
+        # | = 1 / prec_tree * prec_tree / sigma2 * var_post =
+        # | = var_post / sigma2
         centered_leaves=z * jnp.sqrt(var_post),
     )

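The comment block above spells out why `mean_factor = var_post / sigma2`. The algebra is the usual conjugate normal update for a leaf with a zero-mean prior of variance `sigma_mu2`; the numerical check below assumes `var_post = 1 / (1/sigma_mu2 + prec_tree/sigma2)`, which is defined outside this hunk, and uses made-up numbers:

```python
import jax.numpy as jnp

sigma2, sigma_mu2 = 2.0, 0.5   # error variance and leaf prior variance (made up)
prec_tree = 7.0                # precision scale of the leaf (its datapoint count here)
resid_tree = 3.1               # sum of residuals falling in the leaf

var_post = 1 / (1 / sigma_mu2 + prec_tree / sigma2)
mean_via_factor = (var_post / sigma2) * resid_tree            # mean_factor * resid_tree
mean_direct = (resid_tree / sigma2) / (1 / sigma_mu2 + prec_tree / sigma2)
print(jnp.allclose(mean_via_factor, mean_direct))             # True
```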
@@ -1884,42 +2093,34 @@ def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
     """

     def loop(resid, pt):
-        resid, leaf_tree, acc, to_prune,
+        resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
             resid,
             SeqStageInAllTrees(
                 pso.bart.X,
                 pso.bart.forest.resid_batch_size,
                 pso.bart.prec_scale,
-                pso.bart.forest.min_points_per_leaf,
                 pso.bart.forest.log_likelihood is not None,
                 pso.prelk,
             ),
             pt,
         )
-        return resid, (leaf_tree, acc, to_prune,
+        return resid, (leaf_tree, acc, to_prune, lkratio)

     pts = SeqStageInPerTree(
-        pso.bart.forest.
+        pso.bart.forest.leaf_tree,
         pso.prec_trees,
         pso.moves,
-        pso.move_counts,
         pso.move_precs,
         pso.bart.forest.leaf_indices,
         pso.prelkv,
         pso.prelf,
     )
-    resid, (leaf_trees, acc, to_prune,
+    resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(loop, pso.bart.resid, pts)

-    save_ratios = pso.bart.forest.log_likelihood is not None
     bart = replace(
         pso.bart,
         resid=resid,
-        forest=replace(
-            pso.bart.forest,
-            leaf_trees=leaf_trees,
-            log_likelihood=ratios['log_likelihood'] if save_ratios else None,
-            log_trans_prior=ratios['log_trans_prior'] if save_ratios else None,
-        ),
+        forest=replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio),
     )
     moves = replace(pso.moves, acc=acc, to_prune=to_prune)

@@ -1928,7 +2129,7 @@ def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:


 class SeqStageInAllTrees(Module):
-    The inputs to `accept_move_and_sample_leaves` that are
+    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

     Parameters
     ----------
@@ -1939,8 +2140,6 @@ class SeqStageInAllTrees(Module):
     prec_scale
         The scale of the precision of the error on each datapoint. If None, it
         is assumed to be 1.
-    min_points_per_leaf
-        The minimum number of data points in a leaf node.
     save_ratios
         Whether to save the acceptance ratios.
     prelk
@@ -1949,10 +2148,9 @@ class SeqStageInAllTrees(Module):
     """

     X: UInt[Array, 'p n']
-    resid_batch_size: int | None
-    prec_scale: Float32[Array, 'n'] | None
-
-    save_ratios: bool
+    resid_batch_size: int | None = field(static=True)
+    prec_scale: Float32[Array, ' n'] | None
+    save_ratios: bool = field(static=True)
     prelk: PreLk


@@ -1968,9 +2166,6 @@ class SeqStageInPerTree(Module):
         The likelihood precision scale in each potential or actual leaf node.
     move
         The proposed move, see `propose_moves`.
-    move_counts
-        The counts of the number of points in the the nodes modified by the
-        moves.
     move_precs
         The likelihood precision scale in each node modified by the moves.
     leaf_indices
@@ -1982,26 +2177,23 @@ class SeqStageInPerTree(Module):
     are specific to the tree.
     """

-    leaf_tree: Float32[Array, '2**d']
-    prec_tree: Float32[Array, '2**d']
+    leaf_tree: Float32[Array, ' 2**d']
+    prec_tree: Float32[Array, ' 2**d']
     move: Moves
-    move_counts: Counts | None
     move_precs: Precs | Counts
-    leaf_indices: UInt[Array, 'n']
+    leaf_indices: UInt[Array, ' n']
     prelkv: PreLkV
     prelf: PreLf


 def accept_move_and_sample_leaves(
-    resid: Float32[Array, 'n'],
-    at: SeqStageInAllTrees,
-    pt: SeqStageInPerTree,
+    resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
 ) -> tuple[
-    Float32[Array, 'n'],
-    Float32[Array, '2**d'],
+    Float32[Array, ' n'],
+    Float32[Array, ' 2**d'],
     Bool[Array, ''],
     Bool[Array, ''],
-
+    Float32[Array, ''] | None,
 ]:
     """
     Accept or reject a proposed move and sample the new leaf values.
@@ -2026,8 +2218,9 @@ def accept_move_and_sample_leaves(
     to_prune : Bool[Array, '']
         Whether, to reflect the acceptance status of the move, the state should
         be updated by pruning the leaves involved in the move.
-
-        The
+    log_lk_ratio : Float32[Array, ''] | None
+        The logarithm of the likelihood ratio for the move. `None` if not to be
+        saved.
     """
     # sum residuals in each leaf, in tree proposed by grow move
     if at.prec_scale is None:
@@ -2041,17 +2234,12 @@ def accept_move_and_sample_leaves(
     # subtract starting tree from function
     resid_tree += pt.prec_tree * pt.leaf_tree

-    # get indices of move
-    node = pt.move.node
-    assert node.dtype == jnp.int32
-    left = pt.move.left
-    right = pt.move.right
-
     # sum residuals in parent node modified by move
-    resid_left = resid_tree[left]
-    resid_right = resid_tree[right]
+    resid_left = resid_tree[pt.move.left]
+    resid_right = resid_tree[pt.move.right]
     resid_total = resid_left + resid_right
-
+    assert pt.move.node.dtype == jnp.int32
+    resid_tree = resid_tree.at[pt.move.node].set(resid_total)

     # compute acceptance ratio
     log_lk_ratio = compute_likelihood_ratio(
@@ -2059,48 +2247,37 @@ def accept_move_and_sample_leaves(
     )
     log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
     log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
-
-
-    ratios.update(
-        log_trans_prior=pt.move.log_trans_prior_ratio,
-        # TODO save log_trans_prior_ratio as a vector outside of this loop,
-        # then change the option everywhere to `save_likelihood_ratio`.
-        log_likelihood=log_lk_ratio,
-    )
+    if not at.save_ratios:
+        log_lk_ratio = None

     # determine whether to accept the move
     acc = pt.move.allowed & (pt.move.logu <= log_ratio)
-    if at.min_points_per_leaf is not None:
-        assert pt.move_counts is not None
-        acc &= pt.move_counts.left >= at.min_points_per_leaf
-        acc &= pt.move_counts.right >= at.min_points_per_leaf

     # compute leaves posterior and sample leaves
-    initial_leaf_tree = pt.leaf_tree
     mean_post = resid_tree * pt.prelf.mean_factor
     leaf_tree = mean_post + pt.prelf.centered_leaves

     # copy leaves around such that the leaf indices point to the correct leaf
     to_prune = acc ^ pt.move.grow
     leaf_tree = (
-        leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
-        .set(leaf_tree[node])
-        .at[jnp.where(to_prune, right, leaf_tree.size)]
-        .set(leaf_tree[node])
+        leaf_tree.at[jnp.where(to_prune, pt.move.left, leaf_tree.size)]
+        .set(leaf_tree[pt.move.node])
+        .at[jnp.where(to_prune, pt.move.right, leaf_tree.size)]
+        .set(leaf_tree[pt.move.node])
     )

     # replace old tree with new tree in function values
-    resid += (
+    resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices]

-    return resid, leaf_tree, acc, to_prune,
+    return resid, leaf_tree, acc, to_prune, log_lk_ratio


 def sum_resid(
-    scaled_resid: Float32[Array, 'n'],
-    leaf_indices: UInt[Array, 'n'],
+    scaled_resid: Float32[Array, ' n'],
+    leaf_indices: UInt[Array, ' n'],
     tree_size: int,
     batch_size: int | None,
-) -> Float32[Array, '{tree_size}']:
+) -> Float32[Array, ' {tree_size}']:
     """
     Sum the residuals in each leaf.

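The acceptance test in `accept_move_and_sample_leaves` compares a pre-drawn `logu` with the log ratio, which is the standard Metropolis-Hastings rule written in log space (accept iff `log U <= log ratio`, with `U ~ Uniform(0, 1)`). A toy version with arbitrary numbers, not package state:

```python
import jax.numpy as jnp
from jax import random

logu = jnp.log(random.uniform(random.key(0), (5,)))     # log U, U ~ Uniform(0, 1)
log_ratio = jnp.array([0.3, -0.1, -1.0, -3.0, 2.0])     # made-up log MH ratios
allowed = jnp.array([True, True, True, False, True])    # e.g. vetoed proposals
acc = allowed & (logu <= log_ratio)
print(acc)  # entries with log_ratio >= 0 are always accepted when allowed
```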
@@ -2134,7 +2311,7 @@ def _aggregate_batched_onetree(
     size: int,
     dtype: jnp.dtype,
     batch_size: int,
-) -> Float32[Array, '{size}']:
+) -> Float32[Array, ' {size}']:
     (n,) = indices.shape
     nbatches = n // batch_size + bool(n % batch_size)
     batch_indices = jnp.arange(n) % nbatches
@@ -2206,7 +2383,7 @@ def accept_moves_final_stage(bart: State, moves: Moves) -> State:
             grow_acc_count=jnp.sum(moves.acc & moves.grow),
             prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
             leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
-
+            split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
         ),
     )

@@ -2234,22 +2411,20 @@ def apply_moves_to_leaf_indices(
     mask = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
     is_child = (leaf_indices & mask) == moves.left
     return jnp.where(
-        is_child & moves.to_prune,
-        moves.node.astype(leaf_indices.dtype),
-        leaf_indices,
+        is_child & moves.to_prune, moves.node.astype(leaf_indices.dtype), leaf_indices
     )


 @vmap_nodoc
 def apply_moves_to_split_trees(
-
+    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
 ) -> UInt[Array, 'num_trees 2**(d-1)']:
     """
     Update the split trees to match the accepted move.

     Parameters
     ----------
-
+    split_tree
         The cutpoints of the decision nodes in the initial trees.
     moves
         The proposed moves (see `propose_moves`), as updated by
@@ -2261,21 +2436,9 @@ def apply_moves_to_split_trees(
     """
     assert moves.to_prune is not None
     return (
-
-
-
-                moves.node,
-                split_trees.size,
-            )
-        ]
-        .set(moves.grow_split.astype(split_trees.dtype))
-        .at[
-            jnp.where(
-                moves.to_prune,
-                moves.node,
-                split_trees.size,
-            )
-        ]
+        split_tree.at[jnp.where(moves.grow, moves.node, split_tree.size)]
+        .set(moves.grow_split.astype(split_tree.dtype))
+        .at[jnp.where(moves.to_prune, moves.node, split_tree.size)]
         .set(0)
     )

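`apply_moves_to_split_trees` uses another JAX indexing idiom: scattering at index `split_tree.size` is out of bounds, and out-of-bounds updates are dropped by default, so the write becomes a no-op for trees whose move should not modify that node. A small demonstration with toy values:

```python
import jax.numpy as jnp

split_tree = jnp.array([0, 3, 0, 0], dtype=jnp.uint8)
node, grow_split = 2, 5

# grow accepted: write the new cutpoint at `node`
print(split_tree.at[jnp.where(True, node, split_tree.size)].set(grow_split))   # [0 3 5 0]
# grow rejected: the index is out of bounds, so the update is silently dropped
print(split_tree.at[jnp.where(False, node, split_tree.size)].set(grow_split))  # [0 3 0 0]
```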
@@ -2305,6 +2468,8 @@ def step_sigma(key: Key[Array, ''], bart: State) -> State:
     beta = bart.sigma2_beta + norm2 / 2

     sample = random.gamma(key, alpha)
+    # random.gamma seems to be slow at compiling, maybe cdf inversion would
+    # be better, but it's not implemented in jax
     return replace(bart, sigma2=beta / sample)

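The hunk above only adds a comment inside `step_sigma`, but the surrounding lines show the conjugate update pattern: sigma2 is drawn as an inverse gamma by sampling a Gamma variate and dividing. A standalone sketch with made-up `alpha` and `beta` (the actual state fields, such as `sigma2_beta` and `norm2`, are not reproduced here):

```python
from jax import random

key = random.key(0)
alpha = 2.0 + 50.0    # e.g. prior shape plus n / 2 (made-up numbers)
beta = 1.0 + 21.0     # e.g. prior scale plus sum of squared residuals / 2

# beta / Gamma(alpha, 1) is distributed as InvGamma(alpha, beta)
sigma2 = beta / random.gamma(key, alpha)
print(sigma2)
```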
@@ -2324,12 +2489,128 @@ def step_z(key: Key[Array, ''], bart: State) -> State:
     The updated BART MCMC state.
     """
     trees_plus_offset = bart.z - bart.resid
-
-
-    resid = random.truncated_normal(key, lower, upper)
-    # TODO jax's implementation of truncated_normal is not good, it just does
-    # cdf inversion with erf and erf_inv. I can do better, at least avoiding to
-    # compute one of the boundaries, and maybe also flipping and using ndtr
-    # instead of erf for numerical stability (open an issue in jax?)
+    assert bart.y.dtype == bool
+    resid = truncated_normal_onesided(key, (), ~bart.y, -trees_plus_offset)
     z = trees_plus_offset + resid
     return replace(bart, z=z, resid=resid)
+
+
+def step_s(key: Key[Array, ''], bart: State) -> State:
+    """
+    Update `log_s` using Dirichlet sampling.
+
+    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
+    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
+    varcount is the count of how many times each variable is used in the
+    current forest.
+
+    Parameters
+    ----------
+    key
+        Random key for sampling.
+    bart
+        The current BART state.
+
+    Returns
+    -------
+    Updated BART state with re-sampled `log_s`.
+
+    """
+    assert bart.forest.theta is not None
+
+    # histogram current variable usage
+    p = bart.forest.max_split.size
+    varcount = grove.var_histogram(p, bart.forest.var_tree, bart.forest.split_tree)
+
+    # sample from Dirichlet posterior
+    alpha = bart.forest.theta / p + varcount
+    log_s = random.loggamma(key, alpha)
+
+    # update forest with new s
+    return replace(bart, forest=replace(bart.forest, log_s=log_s))
+
+
+def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
+    """
+    Update `theta`.
+
+    The prior is theta / (theta + rho) ~ Beta(a, b).
+
+    Parameters
+    ----------
+    key
+        Random key for sampling.
+    bart
+        The current BART state.
+    num_grid
+        The number of points in the evenly-spaced grid used to sample
+        theta / (theta + rho).
+
+    Returns
+    -------
+    Updated BART state with re-sampled `theta`.
+    """
+    assert bart.forest.log_s is not None
+    assert bart.forest.rho is not None
+    assert bart.forest.a is not None
+    assert bart.forest.b is not None
+
+    # the grid points are the midpoints of num_grid bins in (0, 1)
+    padding = 1 / (2 * num_grid)
+    lamda_grid = jnp.linspace(padding, 1 - padding, num_grid)
+
+    # normalize s
+    log_s = bart.forest.log_s - logsumexp(bart.forest.log_s)
+
+    # sample lambda
+    logp, theta_grid = _log_p_lamda(
+        lamda_grid, log_s, bart.forest.rho, bart.forest.a, bart.forest.b
+    )
+    i = random.categorical(key, logp)
+    theta = theta_grid[i]
+
+    return replace(bart, forest=replace(bart.forest, theta=theta))
+
+
+def _log_p_lamda(
+    lamda: Float32[Array, ' num_grid'],
+    log_s: Float32[Array, ' p'],
+    rho: Float32[Array, ''],
+    a: Float32[Array, ''],
+    b: Float32[Array, ''],
+) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
+    # in the following I use lamda[::-1] == 1 - lamda
+    theta = rho * lamda / lamda[::-1]
+    p = log_s.size
+    return (
+        (a - 1) * jnp.log1p(-lamda[::-1])  # log(lambda)
+        + (b - 1) * jnp.log1p(-lamda)  # log(1 - lambda)
+        + gammaln(theta)
+        - p * gammaln(theta / p)
+        + theta / p * jnp.sum(log_s)
+    ), theta
+
+
+def step_sparse(key: Key[Array, ''], bart: State) -> State:
+    """
+    Update the sparsity parameters.
+
+    This invokes `step_s`, and then `step_theta` only if the parameters of
+    the theta prior are defined.
+
+    Parameters
+    ----------
+    key
+        Random key for sampling.
+    bart
+        The current BART state.
+
+    Returns
+    -------
+    Updated BART state with re-sampled `log_s` and `theta`.
+    """
+    keys = split(key)
+    bart = step_s(keys.pop(), bart)
+    if bart.forest.rho is not None:
+        bart = step_theta(keys.pop(), bart)
+    return bart