bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/mcmcstep/_step.py
ADDED
|
@@ -0,0 +1,1603 @@
|
|
|
1
|
+
# bartz/src/bartz/mcmcstep/_step.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2024-2026, The Bartz Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Implement `step`, `step_trees`, and the accept-reject logic."""
|
|
26
|
+
|
|
27
|
+
from dataclasses import replace
|
|
28
|
+
from functools import partial
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
# available since jax v0.6.1
|
|
32
|
+
from jax import shard_map
|
|
33
|
+
except ImportError:
|
|
34
|
+
# deprecated in jax v0.8.0
|
|
35
|
+
from jax.experimental.shard_map import shard_map
|
|
36
|
+
|
|
37
|
+
import jax
|
|
38
|
+
from equinox import Module, tree_at
|
|
39
|
+
from jax import lax, random, vmap
|
|
40
|
+
from jax import numpy as jnp
|
|
41
|
+
from jax.lax import cond
|
|
42
|
+
from jax.scipy.linalg import solve_triangular
|
|
43
|
+
from jax.scipy.special import gammaln, logsumexp
|
|
44
|
+
from jax.sharding import Mesh, PartitionSpec
|
|
45
|
+
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt, UInt32
|
|
46
|
+
|
|
47
|
+
from bartz._profiler import (
|
|
48
|
+
get_profile_mode,
|
|
49
|
+
jit_and_block_if_profiling,
|
|
50
|
+
jit_if_not_profiling,
|
|
51
|
+
jit_if_profiling,
|
|
52
|
+
vmap_chains_if_not_profiling,
|
|
53
|
+
vmap_chains_if_profiling,
|
|
54
|
+
)
|
|
55
|
+
from bartz.grove import var_histogram
|
|
56
|
+
from bartz.jaxext import split, truncated_normal_onesided, vmap_nodoc
|
|
57
|
+
from bartz.mcmcstep._moves import Moves, propose_moves
|
|
58
|
+
from bartz.mcmcstep._state import State, StepConfig, chol_with_gersh, field
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@partial(jit_if_not_profiling, donate_argnums=(1,))
@partial(vmap_chains_if_not_profiling, auto_split_keys=True)
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Advance a BART MCMC state by one step.

    The trees are always updated first. For a continuous or multivariate
    outcome this is followed by `step_error_cov_inv`; for a boolean outcome by
    `step_z`. In both cases `step_sparse` and `step_config` run last.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    The memory of the input state is re-used for the output state, so the input
    state can not be used any more after calling `step`. All this applies
    outside of `jax.jit`.
    """
    # When profiling with several chains and a single key, split the key by
    # hand into one key per chain.
    num_chains = bart.forest.num_chains()
    chain_shape = (num_chains,) if num_chains is not None else ()
    needs_per_chain_keys = (
        get_profile_mode() and num_chains is not None and key.ndim == 0
    )
    if needs_per_chain_keys:
        key = random.split(key, num_chains)
    assert key.shape == chain_shape

    subkeys = split(key, 3)

    if bart.y.dtype != bool:
        # continuous or multivariate regression
        bart = step_trees(subkeys.pop(), bart)
        bart = step_error_cov_inv(subkeys.pop(), bart)
    else:
        # binary regression: the trees are sampled with a unit error
        # precision, which is removed again before `step_z`
        bart = replace(bart, error_cov_inv=jnp.ones(chain_shape))
        bart = step_trees(subkeys.pop(), bart)
        bart = replace(bart, error_cov_inv=None)
        bart = step_z(subkeys.pop(), bart)

    bart = step_sparse(subkeys.pop(), bart)
    return step_config(bart)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Run the forest-sampling part of a BART MCMC step.

    Tree moves are proposed with `propose_moves` and then accepted or rejected,
    with new leaf values sampled, by `accept_moves_and_sample_leaves`.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    The proposal counters of the state are overwritten by this step, not
    accumulated.
    """
    subkeys = split(key)
    proposed = propose_moves(subkeys.pop(), bart.forest)
    return accept_moves_and_sample_leaves(subkeys.pop(), bart, proposed)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Decide on the proposed tree moves and draw new leaf values.

    The work is split into three stages: `accept_moves_parallel_stage`,
    `accept_moves_sequential_stage` and `accept_moves_final_stage`.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    new_bart, new_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(new_bart, new_moves)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class Counts(Module):
    """Number of datapoints in the nodes involved in proposed moves for each tree.

    Every field holds one unsigned count per tree, optionally preceded by chain
    axes (``*chains``). Instances are returned by `compute_count_trees`.
    """

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the left child."""

    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the right child."""

    total: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the parent (``= left + right``)."""
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class Precs(Module):
    """Likelihood precision scale in the nodes involved in proposed moves for each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node.

    This is the weighted counterpart of `Counts`, with the same fields. It is
    used when the state has a `prec_scale`, and instances are returned by
    `compute_prec_trees`.
    """

    left: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the left child."""

    right: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the right child."""

    total: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the parent (``= left + right``)."""
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class PreLkV(Module):
    """Non-sequential terms of the likelihood ratio for each tree.

    These terms can be computed in parallel across trees.

    Built by `precompute_likelihood_terms`. Each field has one entry per tree:
    a scalar in the univariate case, a ``k x k`` matrix in the multivariate case.
    """

    left: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_left / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_left * error_cov_inv) @ error_cov_inv``.

    ``n_left`` is the number of datapoints in the left child, or the
    likelihood precision scale in the heteroskedastic case."""

    right: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_right / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_right * error_cov_inv) @ error_cov_inv``.

    ``n_right`` is the number of datapoints in the right child, or the
    likelihood precision scale in the heteroskedastic case."""

    total: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_total / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_total * error_cov_inv) @ error_cov_inv``.

    ``n_total`` is the number of datapoints in the parent node, or the
    likelihood precision scale in the heteroskedastic case."""

    # always one scalar per tree, also in the multivariate case
    log_sqrt_term: Float32[Array, '*chains num_trees'] = field(chains=True)
    """The logarithm of the square root term of the likelihood ratio."""
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class PreLk(Module):
    """Non-sequential terms of the likelihood ratio shared by all trees.

    Only built in the univariate case. In the multivariate case
    `precompute_likelihood_terms` returns `None` in its place.
    """

    # one scalar per chain, not per tree
    exp_factor: Float32[Array, '*chains'] = field(chains=True)
    """The factor to multiply the likelihood ratio by, shared by all trees."""
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class PreLf(Module):
    """Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees.

    For each tree and leaf, the terms are scalars in the univariate case, and
    matrices/vectors in the multivariate case. Instances are built by
    `precompute_leaf_terms`.
    """

    mean_factor: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k k 2**d']
    ) = field(chains=True)
    """The factor to be right-multiplied by the sum of the scaled residuals to
    obtain the posterior mean."""

    centered_leaves: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """The mean-zero normal values to be added to the posterior mean to
    obtain the posterior leaf samples."""
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
class ParallelStageOut(Module):
    """The output of `accept_moves_parallel_stage`."""

    bart: State
    """A partially updated BART mcmc state."""

    moves: Moves
    """The proposed moves, with `partial_ratio` set to `None` and
    `log_trans_prior_ratio` set to its final value."""

    # Counts are produced as uint32 by `compute_count_trees`, precision scales
    # as float32 by `compute_prec_trees`.
    prec_trees: (
        Float32[Array, '*chains num_trees 2**d']
        | UInt32[Array, '*chains num_trees 2**d']
    ) = field(chains=True)
    """The likelihood precision scale in each potential or actual leaf node. If
    there is no precision scale, this is the number of points in each leaf."""

    move_precs: Precs | Counts
    """The likelihood precision scale in each node modified by the moves. If
    `bart.prec_scale` is not set, this is set to `move_counts`."""

    prelkv: PreLkV
    """Object with pre-computed terms of the likelihood ratios, one per tree."""

    prelk: PreLk | None
    """Object with pre-computed terms of the likelihood ratios shared by all
    trees, or `None` in the multivariate case."""

    prelf: PreLf
    """Object with pre-computed terms of the leaf samples."""
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
@partial(jit_and_block_if_profiling, donate_argnums=(1, 2))
@vmap_chains_if_profiling
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.

    Notes
    -----
    In the returned state, `var_tree`, `leaf_indices` and `leaf_tree` are
    modified as if every grow move was accepted. The proposal counters are
    overwritten with the proposals of this step.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf
    # (needed by either datapoint threshold, and also used in place of the
    # precision trees when there is no precision scale, see below)
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.config
        )

    # mark which leaves & potential leaves have enough points to be grown
    if bart.forest.min_points_per_decision_node is not None:
        # keep only the first var_tree.shape[1] nodes, i.e., the decision-node
        # half of the tree that affluence_tree covers
        count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
        moves = replace(
            moves,
            affluence_tree=moves.affluence_tree
            & (count_half_trees >= bart.forest.min_points_per_decision_node),
        )

    # copy updated affluence_tree to state
    # NOTE(review): indentation was lost in this copy of the file; this copy is
    # reconstructed as unconditional, since moves.affluence_tree is always set
    # (complete_ratio reads it unconditionally) - confirm against upstream.
    bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

    # veto grow move if new leaves don't have enough datapoints
    if bart.forest.min_points_per_leaf is not None:
        moves = replace(
            moves,
            allowed=moves.allowed
            & (move_counts.left >= bart.forest.min_points_per_leaf)
            & (move_counts.right >= bart.forest.min_points_per_leaf),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale, bart.forest.leaf_indices, moves, bart.config
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    # the transition prior ratios are kept in the state only if the state
    # also tracks the log likelihood
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    assert bart.error_cov_inv is not None
    prelkv, prelk = precompute_likelihood_terms(
        bart.error_cov_inv, bart.forest.leaf_prior_cov_inv, move_precs
    )
    prelf = precompute_leaf_terms(
        key, prec_trees, bart.error_cov_inv, bart.forest.leaf_prior_cov_inv
    )

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Move the datapoints of a grown leaf into its two new children.

    Only trees whose move is a grow are affected; for the others the indices
    are returned unchanged.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    # A node index no datapoint can have, used to disable non-grow moves. It
    # is wrapped in jnp.array, so it gets a concrete integer dtype instead of
    # being a weakly-typed Python int (presumably to avoid overflowing a small
    # unsigned node dtype).
    beyond_tree = jnp.array(2 * moves.var_tree.size)
    grown_leaf = jnp.where(moves.grow, moves.node, beyond_tree)
    in_grown_leaf = leaf_indices == grown_leaf

    # the children of node i are 2i (left) and 2i + 1 (right)
    split_values: UInt[Array, ' n'] = X[moves.grow_var, :]
    child = (moves.node.astype(leaf_indices.dtype) << 1) + (
        split_values >= moves.grow_split
    )
    return jnp.where(in_grown_leaf, child, leaf_indices)
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, 'num_trees 2**d'], Counts]
    | tuple[Float32[Array, 'num_trees 2**d'], Precs]
):
    """Implement `compute_count_trees` and `compute_prec_trees`.

    The per-tree computation is vmapped over all trees at once, unless
    `config.prec_count_num_trees` is set. In that case it runs with `lax.map`,
    processing that many trees per batch (presumably to limit memory use).
    """
    batch_size = config.prec_count_num_trees

    if batch_size is not None:
        return lax.map(
            lambda per_tree: _compute_count_or_prec_tree(
                prec_scale, per_tree[0], per_tree[1], config
            ),
            (leaf_indices, moves),
            batch_size=batch_size,
        )

    all_trees_at_once = vmap(_compute_count_or_prec_tree, in_axes=(None, 0, 0, None))
    return all_trees_at_once(prec_scale, leaf_indices, moves, config)
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def _compute_count_or_prec_tree(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, ' n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]:
    """Compute count or precision tree for a single tree.

    Without `prec_scale`, each datapoint adds 1 to its leaf and the result is
    `Counts` (uint32). Otherwise each datapoint adds its precision scale and
    the result is `Precs` (float32).
    """
    (half_tree_size,) = moves.var_tree.shape

    if prec_scale is not None:
        weights, result_cls, dtype, num_batches = (
            prec_scale,
            Precs,
            jnp.float32,
            config.prec_num_batches,
        )
    else:
        weights, result_cls, dtype, num_batches = (
            1,
            Counts,
            jnp.uint32,
            config.count_num_batches,
        )

    # accumulate each datapoint's weight into the slot of its leaf
    trees = _scatter_add(
        weights, leaf_indices, 2 * half_tree_size, dtype, num_batches, config.mesh
    )

    # totals in the two children involved in each move, and in their parent
    left_total = trees[moves.left]
    right_total = trees[moves.right]
    totals = result_cls(
        left=left_total, right=right_total, total=left_total + right_total
    )

    # store the parent total in the (non-leaf) node targeted by the move
    return trees.at[moves.node].set(totals.total), totals
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, config: StepConfig
) -> tuple[UInt32[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    count_trees : UInt32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node. The entry
        of the node targeted by each move holds the total of its two children.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    return _compute_count_or_prec_trees(None, leaf_indices, moves, config)
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Sum the error precision scales of the datapoints in each leaf.

    This is the weighted version of `compute_count_trees`: each datapoint
    contributes its entry of `prec_scale` instead of 1.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    return _compute_count_or_prec_trees(
        prec_scale=prec_scale, leaf_indices=leaf_indices, moves=moves, config=config
    )
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
@partial(vmap_nodoc, in_axes=(0, None))
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Complete non-likelihood MH ratio calculation.

    This function adds the probability of choosing a prune move over the grow
    move in the inverse transition, and the a priori probability that the
    children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    # can the leaves be grown? The children may lie outside `affluence_tree`,
    # in which case mode='fill' marks them as not growable
    left_growable = moves.affluence_tree.at[moves.left].get(
        mode='fill', fill_value=False
    )
    right_growable = moves.affluence_tree.at[moves.right].get(
        mode='fill', fill_value=False
    )

    # p_prune if grow: another grow is possible after this one if another leaf
    # (presumably num_growable also counts the leaf being grown, hence >= 2)
    # or one of the new children is growable
    other_growable_leaves = moves.num_growable >= 2
    grow_again_allowed = other_growable_leaves | left_growable | right_growable
    grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1.0)

    # p_prune if prune; explicit comparison instead of relying on the
    # truthiness of an integer array, consistent with the grow case above
    prune_p_prune = jnp.where(moves.num_growable > 0, 0.5, 1.0)

    # select p_prune
    p_prune = jnp.where(moves.grow, grow_p_prune, prune_p_prune)

    # prior probability of both children being terminal
    pt_left = 1 - p_nonterminal[moves.left] * left_growable
    pt_right = 1 - p_nonterminal[moves.right] * right_growable
    pt_children = pt_left * pt_right

    assert moves.partial_ratio is not None
    return replace(
        moves,
        log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune),
        partial_ratio=None,
    )
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    The value of the leaf to grow is copied to what would be its children if the
    grow move was accepted. Trees whose move is not a grow are left unchanged.

    Parameters
    ----------
    leaf_trees
        The leaf values. The node index is the last axis.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    values_at_node = leaf_trees[..., moves.node]

    # For non-grow moves, the writes are redirected to an index one past the
    # last node and discarded by mode='drop'. The sentinel:
    # - uses the node axis length, not `leaf_trees.size`, which in the
    #   multivariate case is k times larger;
    # - is wrapped in jnp.array, so it is not a weakly-typed Python int that
    #   `jnp.where` would have to fit into a possibly small unsigned index dtype.
    out_of_bounds = jnp.array(leaf_trees.shape[-1])
    left = jnp.where(moves.grow, moves.left, out_of_bounds)
    right = jnp.where(moves.grow, moves.right, out_of_bounds)

    return (
        leaf_trees.at[..., left]
        .set(values_at_node, mode='drop')
        .at[..., right]
        .set(values_at_node, mode='drop')
    )
|
|
634
|
+
|
|
635
|
+
|
|
636
|
+
def _logdet_from_chol(L: Float32[Array, '... k k']) -> Float32[Array, '...']:
    """Return log det(L L') from a triangular Cholesky factor `L`.

    det(L L') is the square of the product of the diagonal of `L`, so the log
    determinant is twice the sum of the logs of that diagonal. Leading axes are
    batch axes.
    """
    log_diag = jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1))
    return 2.0 * log_diag.sum(axis=-1)
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
def _precompute_likelihood_terms_uv(
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """Univariate case of `precompute_likelihood_terms`."""
    sigma2 = lax.reciprocal(error_cov_inv)  # error variance
    sigma_mu2 = lax.reciprocal(leaf_prior_cov_inv)  # leaf prior variance

    def lk_term(prec):
        # the scalar ``1 / error_cov_inv + prec / leaf_prior_cov_inv``
        return sigma2 + prec * sigma_mu2

    term_left = lk_term(move_precs.left)
    term_right = lk_term(move_precs.right)
    term_total = lk_term(move_precs.total)

    log_sqrt_term = jnp.log(sigma2 * term_total / (term_left * term_right)) / 2
    prelkv = PreLkV(
        left=term_left,
        right=term_right,
        total=term_total,
        log_sqrt_term=log_sqrt_term,
    )
    prelk = PreLk(exp_factor=error_cov_inv / leaf_prior_cov_inv / 2)
    return prelkv, prelk
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
def _precompute_likelihood_terms_mv(
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    move_precs: Counts,
) -> tuple[PreLkV, None]:
    """Multivariate (homoskedastic) case of `precompute_likelihood_terms`."""

    def chol_post_prec(
        n: UInt[Array, ' num_trees'],
    ) -> Float32[Array, 'num_trees k k']:
        # Cholesky factor of ``n * error_cov_inv + leaf_prior_cov_inv``, per tree
        return chol_with_gersh(error_cov_inv * n[..., None, None] + leaf_prior_cov_inv)

    def quad_term(
        chol: Float32[Array, 'num_trees k k'],
    ) -> Float32[Array, 'num_trees k k']:
        # error_cov_inv.T @ inv(chol @ chol.T) @ error_cov_inv, computed as
        # Y.T @ Y with Y = solve(chol, error_cov_inv). The right-hand side is
        # broadcast explicitly to the batch shape of `chol`.
        rhs = jnp.broadcast_to(error_cov_inv, chol.shape)
        y = solve_triangular(chol, rhs, lower=True)
        return y.mT @ y

    chol_left = chol_post_prec(move_precs.left)
    chol_right = chol_post_prec(move_precs.right)
    chol_total = chol_post_prec(move_precs.total)
    logdet_prior = _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv))

    log_sqrt_term: Float32[Array, ' num_trees'] = 0.5 * (
        logdet_prior
        + _logdet_from_chol(chol_total)
        - _logdet_from_chol(chol_left)
        - _logdet_from_chol(chol_right)
    )

    prelkv = PreLkV(
        left=quad_term(chol_left),
        right=quad_term(chol_right),
        total=quad_term(chol_total),
        log_sqrt_term=log_sqrt_term,
    )
    return prelkv, None
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
def precompute_likelihood_terms(
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk | None]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    A 2-D `error_cov_inv` selects the multivariate implementation; any other
    shape selects the univariate one. The multivariate implementation assumes
    a homoskedastic error model (i.e., the residual covariance is the same for
    all observations), so it accepts only `Counts`.

    Parameters
    ----------
    error_cov_inv
        The inverse error variance (univariate) or the inverse of the error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under keys 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio, one per tree.
    prelk : PreLk | None
        Pre-computed terms of the likelihood ratio, shared by all trees, or
        `None` in the multivariate case.
    """
    if error_cov_inv.ndim != 2:
        return _precompute_likelihood_terms_uv(
            error_cov_inv, leaf_prior_cov_inv, move_precs
        )

    # homoskedastic only: no per-datapoint precision scale
    assert isinstance(move_precs, Counts)
    return _precompute_likelihood_terms_mv(
        error_cov_inv, leaf_prior_cov_inv, move_precs
    )
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def _precompute_leaf_terms_uv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    z: Float32[Array, 'num_trees 2**d'] | None = None,
) -> PreLf:
    """Univariate implementation of `precompute_leaf_terms`."""
    # standard normal noise, unless supplied by the caller (testing)
    if z is None:
        z = random.normal(key, prec_trees.shape, error_cov_inv.dtype)

    # posterior variance of each leaf: 1 / (likelihood precision + prior precision)
    likelihood_prec = prec_trees * error_cov_inv
    posterior_var = lax.reciprocal(likelihood_prec + leaf_prior_cov_inv)

    # Factor that maps the scaled residual sum of a leaf to its posterior mean:
    #   mean        = mean_lk * prec_lk * var_post
    #   resid_tree  = mean_lk * prec_tree  -->  mean_lk = resid_tree / prec_tree (kind of)
    #   mean_factor = mean / resid_tree
    #               = resid_tree / prec_tree * prec_lk * var_post / resid_tree
    #               = 1 / prec_tree * prec_tree / sigma2 * var_post
    #               = var_post / sigma2
    mean_factor = posterior_var * error_cov_inv

    # zero-mean part of the posterior draw: noise times posterior std. dev.
    centered_leaves = z * jnp.sqrt(posterior_var)

    return PreLf(mean_factor=mean_factor, centered_leaves=centered_leaves)
|
|
770
|
+
|
|
771
|
+
|
|
772
|
+
def _precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d k'] | None = None,
) -> PreLf:
    """
    Multivariate implementation of `precompute_leaf_terms`.

    With ``P = leaf_prior_cov_inv + n * error_cov_inv`` the posterior
    precision of a leaf (``n`` taken from `prec_trees`), and ``P = L L^T``
    with ``L`` lower triangular, this computes:

    - ``mean_factor = P^{-1} error_cov_inv``, returned transposed and with
      the leaf axis moved last, shape ``(num_trees, k, k, tree_size)``;
    - ``centered_leaves = L^{-T} z``, shape ``(num_trees, k, tree_size)``.
      When ``z`` is standard normal, this has covariance
      ``L^{-T} L^{-1} = P^{-1}``.
    """
    num_trees, tree_size = prec_trees.shape
    k = error_cov_inv.shape[0]
    n_k: Float32[Array, 'num_trees tree_size 1 1'] = prec_trees[..., None, None]

    # Only broadcast the inverse of error covariance matrix to satisfy JAX's
    # batching rules for `lax.linalg.solve_triangular`, which does not support
    # implicit broadcasting.
    error_cov_inv_batched = jnp.broadcast_to(
        error_cov_inv, (num_trees, tree_size, k, k)
    )

    posterior_precision: Float32[Array, 'num_trees tree_size k k'] = (
        leaf_prior_cov_inv + n_k * error_cov_inv_batched
    )

    # lower-triangular factor of the posterior precision
    L_prec: Float32[Array, 'num_trees tree_size k k'] = chol_with_gersh(
        posterior_precision
    )

    # P^{-1} error_cov_inv = L^{-T} (L^{-1} error_cov_inv)
    Y: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, error_cov_inv_batched, lower=True
    )
    mean_factor: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, Y, trans='T', lower=True
    )
    mean_factor = mean_factor.mT
    mean_factor_out: Float32[Array, 'num_trees k k tree_size'] = jnp.moveaxis(
        mean_factor, 1, -1
    )

    if z is None:
        # same dtype convention as the univariate implementation
        z = random.normal(key, (num_trees, tree_size, k), error_cov_inv.dtype)
    # L_prec is lower triangular, as assumed by the two solves above, so
    # `lower=True` is required here too. With the default `lower=False`, the
    # solve would read the upper triangle of L_prec instead, and the sampled
    # leaves would not have covariance P^{-1}.
    centered_leaves: Float32[Array, 'num_trees tree_size k'] = solve_triangular(
        L_prec, z, trans='T', lower=True
    )
    centered_leaves_out: Float32[Array, 'num_trees k tree_size'] = jnp.swapaxes(
        centered_leaves, -1, -2
    )

    return PreLf(mean_factor=mean_factor_out, centered_leaves=centered_leaves_out)
|
|
818
|
+
|
|
819
|
+
|
|
820
|
+
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d']
    | Float32[Array, 'num_trees 2**d k']
    | None = None,
) -> PreLf:
    """
    Pre-compute the terms needed to sample the leaves from their posterior.

    The multivariate implementation is used when `error_cov_inv` is a 2-D
    matrix; otherwise the univariate one is used.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        Univariate: the inverse error variance (the inverse global error
        variance factor if `prec_scale` is set). Multivariate: the inverse of
        the error covariance matrix.
    leaf_prior_cov_inv
        Univariate: the inverse prior variance of each leaf. Multivariate: the
        inverse of the prior covariance matrix of each leaf.
    z
        Optional standard normal noise to use for sampling the centered leaves,
        replacing the draw from `key`. This is intended for testing purposes
        only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    implementation = (
        _precompute_leaf_terms_mv
        if error_cov_inv.ndim == 2
        else _precompute_leaf_terms_uv
    )
    return implementation(key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z)
|
|
864
|
+
|
|
865
|
+
|
|
866
|
+
@partial(jit_and_block_if_profiling, donate_argnums=(0,))
@vmap_chains_if_profiling
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept or reject the moves, one tree after the other.

    This is the most performance-sensitive function, because it contains all
    and only the parts of the algorithm that cannot be parallelized across
    trees. The residuals are threaded through a `lax.scan` over the trees, so
    each tree sees the residuals as updated by the previous ones.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    state = pso.bart
    forest = state.forest

    # inputs that are the same for every tree
    shared = SeqStageInAllTrees(
        state.X,
        state.config.resid_num_batches,
        state.config.mesh,
        state.prec_scale,
        forest.log_likelihood is not None,
        pso.prelk,
    )

    # inputs with a leading tree axis, sliced one tree at a time by the scan
    per_tree = SeqStageInPerTree(
        forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )

    def scan_body(carried_resid, tree_inputs):
        new_resid, *tree_outputs = accept_move_and_sample_leaves(
            carried_resid, shared, tree_inputs
        )
        return new_resid, tuple(tree_outputs)

    final_resid, (new_leaf_trees, acc, to_prune, log_lk_ratios) = lax.scan(
        scan_body, state.resid, per_tree
    )

    new_forest = replace(
        forest, leaf_tree=new_leaf_trees, log_likelihood=log_lk_ratios
    )
    new_state = replace(state, resid=final_resid, forest=new_forest)
    new_moves = replace(pso.moves, acc=acc, to_prune=to_prune)
    return new_state, new_moves
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    `accept_moves_sequential_stage` constructs this class positionally, so the
    order of the fields is part of its interface. The same values are used in
    every iteration of the scan over trees.
    """

    X: UInt[Array, 'p n']
    """The predictors."""

    # static field: a plain Python value, not an array
    resid_num_batches: int | None = field(static=True)
    """The number of batches for computing the sum of residuals in each leaf,
    or None to sum without batching."""

    # static field: a plain Python value, not an array
    mesh: Mesh | None = field(static=True)
    """The mesh of devices to use. If None, or without a 'data' axis, the
    residual sums are computed without sharding."""

    prec_scale: Float32[Array, ' n'] | None
    """The scale of the precision of the error on each datapoint. If None, it
    is assumed to be 1."""

    # static field: a plain Python value, not an array
    save_ratios: bool = field(static=True)
    """Whether to return the log-likelihood ratio of each move; if False,
    `None` is returned in its place."""

    prelk: PreLk | None
    """The pre-computed terms of the likelihood ratio which are shared across
    trees. Only the univariate likelihood ratio uses them."""
|
|
946
|
+
|
|
947
|
+
|
|
948
|
+
class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    In `accept_moves_sequential_stage` every field carries a leading tree
    axis, and `lax.scan` slices out one tree per iteration. The shapes
    annotated below are those of a single-tree slice. The class is
    constructed positionally there, so the order of the fields matters.
    """

    leaf_tree: Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d']
    """The leaf values of the tree."""

    prec_tree: Float32[Array, ' 2**d']
    """The likelihood precision scale in each potential or actual leaf node."""

    move: Moves
    """The proposed move, see `propose_moves`."""

    move_precs: Precs | Counts
    """The likelihood precision scale in each node modified by the moves."""

    leaf_indices: UInt[Array, ' n']
    """The leaf indices for the largest version of the tree compatible with
    the move."""

    prelkv: PreLkV
    """The pre-computed terms of the likelihood ratio which are specific to the tree."""

    prelf: PreLf
    """The pre-computed terms of the leaf sampling which are specific to the tree."""
|
|
972
|
+
|
|
973
|
+
|
|
974
|
+
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, ' k n'],
    Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    The returned residuals are updated to reflect the change in this tree's
    leaf values, so that the next tree in the sequential scan sees them.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n'] | Float32[Array, ' k n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d'] | Float32[Array, ' k 2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale

    tree_size = pt.leaf_tree.shape[-1]  # 2**d

    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, tree_size, at.resid_num_batches, at.mesh
    )

    # subtract starting tree from function
    # (adds this tree's current leaves back into the per-leaf sums, weighted by
    # prec_tree -- presumably the per-leaf total of the precision scale)
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move and compute likelihood
    resid_left = resid_tree[..., pt.move.left]
    resid_right = resid_tree[..., pt.move.right]
    resid_total = resid_left + resid_right
    assert pt.move.node.dtype == jnp.int32
    # the parent holds the merged sum; its sampled leaf value is the one
    # copied into both children below if the tree ends up pruned
    resid_tree = resid_tree.at[..., pt.move.node].set(resid_total)

    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )

    # calculate accept/reject ratio
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    # the ratio is computed in the grow direction; a prune move uses its inverse
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move
    # (`logu` is presumably the log of a uniform draw made upstream)
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    if resid.ndim > 1:
        # multivariate: mean_post[i, l] = sum_k mean_factor[k, i, l] * resid_tree[k, l]
        mean_post = jnp.einsum('kil,kl->il', pt.prelf.mean_factor, resid_tree)
    else:
        mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf
    # `leaf_indices` refer to the grown version of the tree. If the final tree
    # is the pruned one (accepted prune or rejected grow), both children take
    # the parent's value. Otherwise the index is `tree_size`, which is out of
    # bounds, and JAX drops out-of-bounds scatter updates.
    to_prune = acc ^ pt.move.grow
    leaf_tree = (
        leaf_tree.at[..., jnp.where(to_prune, pt.move.left, tree_size)]
        .set(leaf_tree[..., pt.move.node])
        .at[..., jnp.where(to_prune, pt.move.right, tree_size)]
        .set(leaf_tree[..., pt.move.node])
    )
    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[..., pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio
|
|
1066
|
+
|
|
1067
|
+
|
|
1068
|
+
@partial(jnp.vectorize, excluded=(1, 2, 3, 4), signature='(n)->(ts)')
def sum_resid(
    scaled_resid: Float32[Array, ' n'] | Float32[Array, 'k n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    resid_num_batches: int | None,
    mesh: Mesh | None,
) -> Float32[Array, ' {tree_size}'] | Float32[Array, 'k {tree_size}']:
    """
    Add up the residuals that fall in each leaf.

    `jnp.vectorize` maps the body over any leading axes of `scaled_resid`
    (e.g., the ``k`` outcome columns in the multivariate case), so the body
    only handles a single 1-D residual vector. The other arguments are
    excluded from vectorization.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale. For multivariate case, shape is ``(k, n)`` where ``k``
        is the number of outcome columns.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    resid_num_batches
        The number of batches for computing the sum of residuals in each leaf.
    mesh
        The mesh of devices to use.

    Returns
    -------
    The per-leaf sums of the residuals, computed in float32. For the
    multivariate case, there is one row of sums per outcome column.
    """
    return _scatter_add(
        values=scaled_resid,
        indices=leaf_indices,
        size=tree_size,
        dtype=jnp.float32,
        batch_size=resid_num_batches,
        mesh=mesh,
    )
|
|
1105
|
+
|
|
1106
|
+
|
|
1107
|
+
def _scatter_add(
    values: Float32[Array, ' n'] | int,
    indices: Integer[Array, ' n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int | None,
    mesh: Mesh | None,
) -> Shaped[Array, ' {size}']:
    """
    Indexed reduce with optional batching.

    Adds `values` into a zero array of length `size` at positions `indices`.
    A scalar `values` is added once per entry of `indices`. Despite its name,
    `batch_size` is forwarded to `_scatter_add_impl` as the *number* of
    batches. If `mesh` has a 'data' axis, the computation runs under
    `shard_map`: `indices` (and `values`, when not scalar) are split along
    'data', and the partial sums are combined with a `psum`.
    """
    values = jnp.asarray(values)
    assert values.ndim == 0 or values.shape == indices.shape

    kernel = partial(
        _scatter_add_impl, size=size, dtype=dtype, num_batches=batch_size
    )

    use_mesh = mesh is not None and 'data' in mesh.axis_names
    if use_mesh:
        # a scalar value is replicated; an array is split like the indices
        values_spec = PartitionSpec('data') if values.shape else PartitionSpec()
        kernel = shard_map(
            partial(kernel, final_psum=True),
            in_specs=(values_spec, PartitionSpec('data')),
            out_specs=PartitionSpec(),
            mesh=mesh,
            **_get_shard_map_patch_kwargs(),
        )

    return kernel(values, indices)
|
|
1143
|
+
|
|
1144
|
+
|
|
1145
|
+
def _get_shard_map_patch_kwargs():
    """
    Return extra keyword arguments for `shard_map` that work around a jax bug.

    See jax/issues/#34249, problem with vmap(shard_map(psum)). On the affected
    jax versions, the `check_vma` check is disabled; the config
    jax_disable_vmap_shmap_error was tried but didn't work.
    """
    affected_version = jax.__version__ in ('0.8.1', '0.8.2')
    return {'check_vma': False} if affected_version else {}
|
|
1152
|
+
|
|
1153
|
+
|
|
1154
|
+
def _scatter_add_impl(
    values: Float32[Array, ' n'] | Int32[Array, ''],
    indices: Integer[Array, ' n'],
    /,
    *,
    size: int,
    dtype: jnp.dtype,
    num_batches: int | None,
    final_psum: bool = False,
) -> Shaped[Array, ' {size}']:
    """
    Scatter-add `values` into a zero array of length `size` at `indices`.

    If `num_batches` is not None, datapoint ``i`` is accumulated into column
    ``i % num_batches`` of a ``(size, num_batches)`` array, and the columns
    are summed at the end. If `final_psum` is set, the result is summed over
    the 'data' mesh axis, for use inside `shard_map`.
    """
    if num_batches is None:
        result = jnp.zeros(size, dtype).at[indices].add(values)
    else:
        # in the sharded case, n is the size of the local shard, not the full size
        (n,) = indices.shape
        column = jnp.arange(n) % num_batches
        per_batch = (
            jnp.zeros((size, num_batches), dtype).at[indices, column].add(values)
        )
        result = per_batch.sum(axis=1)

    return lax.psum(result, 'data') if final_psum else result
|
|
1181
|
+
|
|
1182
|
+
|
|
1183
|
+
def _compute_likelihood_ratio_uv(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """Univariate log-likelihood ratio of a grow move, see `compute_likelihood_ratio`."""
    # squared residual sum of each node divided by its pre-computed term
    left_term = left_resid * left_resid / prelkv.left
    right_term = right_resid * right_resid / prelkv.right
    total_term = total_resid * total_resid / prelkv.total

    exp_term = prelk.exp_factor * (left_term + right_term - total_term)
    return prelkv.log_sqrt_term + exp_term
|
|
1196
|
+
|
|
1197
|
+
|
|
1198
|
+
def _compute_likelihood_ratio_mv(
    total_resid: Float32[Array, ' k'],
    left_resid: Float32[Array, ' k'],
    right_resid: Float32[Array, ' k'],
    prelkv: PreLkV,
) -> Float32[Array, '']:
    """Multivariate log-likelihood ratio of a grow move, see `compute_likelihood_ratio`."""
    # quadratic forms r^T M r of each node's residual sum vector
    left_form = left_resid @ prelkv.left @ left_resid
    right_form = right_resid @ prelkv.right @ right_resid
    total_form = total_resid @ prelkv.total @ total_resid

    exp_term = 0.5 * (left_form + right_form - total_form)
    return prelkv.log_sqrt_term + exp_term
|
|
1212
|
+
|
|
1213
|
+
|
|
1214
|
+
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''] | Float32[Array, ' k'],
    left_resid: Float32[Array, ''] | Float32[Array, ' k'],
    right_resid: Float32[Array, ''] | Float32[Array, ' k'],
    prelkv: PreLkV,
    prelk: PreLk | None,
) -> Float32[Array, '']:
    """
    Compute the log-likelihood ratio of a grow move.

    Scalar residual sums select the univariate implementation, which also
    requires `prelk`. Vector residual sums select the multivariate one, which
    ignores `prelk`.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    if total_resid.ndim == 0:
        assert prelk is not None
        return _compute_likelihood_ratio_uv(
            total_resid, left_resid, right_resid, prelkv, prelk
        )
    return _compute_likelihood_ratio_mv(total_resid, left_resid, right_resid, prelkv)
|
|
1252
|
+
|
|
1253
|
+
|
|
1254
|
+
@partial(jit_and_block_if_profiling, donate_argnums=(0, 1))
@vmap_chains_if_profiling
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Finish the MCMC update after the moves have been accepted or rejected.

    This is kept apart from `accept_moves_sequential_stage` to signal that
    everything here can work in parallel across trees.

    Parameters
    ----------
    bart
        The BART mcmc state as partially updated by
        `accept_moves_sequential_stage`.
    moves
        The proposed moves (see `propose_moves`), with the fields set by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state: acceptance counts, leaf indices and
    split trees now reflect the accepted moves.
    """
    accepted_grows = moves.acc & moves.grow
    accepted_prunes = moves.acc & ~moves.grow
    new_forest = replace(
        bart.forest,
        grow_acc_count=jnp.sum(accepted_grows),
        prune_acc_count=jnp.sum(accepted_prunes),
        leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)
|
|
1285
|
+
|
|
1286
|
+
|
|
1287
|
+
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Datapoints that sit in either child of a move whose tree ends up pruned
    are moved to the parent node; all other indices are kept.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # clearing the lowest bit maps both `left` and `left | 1` onto `left`
    # (this relies on the right child being `left + 1` with `left` even)
    clear_low_bit = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_children = (leaf_indices & clear_low_bit) == moves.left
    assert moves.to_prune is not None
    parent = moves.node.astype(leaf_indices.dtype)
    return jnp.where(in_children & moves.to_prune, parent, leaf_indices)
|
|
1313
|
+
|
|
1314
|
+
|
|
1315
|
+
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # scatter updates at this out-of-bounds index are dropped by JAX
    skip = split_tree.size
    grow_at = jnp.where(moves.grow, moves.node, skip)
    prune_at = jnp.where(moves.to_prune, moves.node, skip)

    # Write the proposed split for every grow move first, then zero the node
    # wherever the tree ends up pruned. The order matters: a rejected grow
    # move has `to_prune` set and must end with a zero.
    grown = split_tree.at[grow_at].set(moves.grow_split.astype(split_tree.dtype))
    return grown.at[prune_at].set(0)
|
|
1341
|
+
|
|
1342
|
+
|
|
1343
|
+
@jax.jit
def _sample_wishart_bartlett(
    key: Key[Array, ''], df: Float32[Array, ''], scale_inv: Float32[Array, 'k k']
) -> Float32[Array, 'k k']:
    """
    Sample a precision matrix W ~ Wishart(df, scale_inv^-1) using Bartlett decomposition.

    Parameters
    ----------
    key
        A JAX random key
    df
        Degrees of freedom
    scale_inv
        Scale matrix of the corresponding Inverse Wishart distribution

    Returns
    -------
    A sample from Wishart(df, scale_inv^-1).
    """
    keys = split(key)
    k, _ = scale_inv.shape

    # Bartlett factor A, lower triangular:
    # - diagonal: A_ii ~ sqrt(chi^2(df - i)), with chi^2(nu) = 2 * Gamma(nu / 2)
    # - strictly below the diagonal: standard normals
    chi2_draws = random.gamma(keys.pop(), (df - jnp.arange(k)) / 2.0) * 2.0
    normals = random.normal(keys.pop(), (k, k))
    A = jnp.tril(normals, -1) + jnp.diag(jnp.sqrt(chi2_draws))

    # With scale_inv = L L^T, T = L^{-T} A gives
    # T T^T = L^{-T} (A A^T) L^{-1}, a Wishart draw with scale
    # L^{-T} L^{-1} = scale_inv^{-1}.
    L = chol_with_gersh(scale_inv, absolute_eps=True)
    T = solve_triangular(L, A, lower=True, trans='T')
    return T @ T.T
|
|
1378
|
+
|
|
1379
|
+
|
|
1380
|
+
def _step_error_cov_inv_uv(key: Key[Array, ''], bart: State) -> State:
    """
    Gibbs-update the inverse error variance (univariate case).

    Conjugate update of the inverse gamma prior with alpha = df / 2 and
    beta = scale / 2. Each squared residual is weighted by `prec_scale`
    when that is set.
    """
    resid = bart.resid
    weighted = resid if bart.prec_scale is None else resid * bart.prec_scale

    alpha = bart.error_cov_df / 2 + resid.size / 2
    beta = bart.error_cov_scale / 2 + (resid @ weighted) / 2

    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    draw = random.gamma(key, alpha)
    return replace(bart, error_cov_inv=draw / beta)
|
|
1395
|
+
|
|
1396
|
+
|
|
1397
|
+
def _step_error_cov_inv_mv(key: Key[Array, ''], bart: State) -> State:
    """
    Gibbs-update the error precision matrix (multivariate case).

    Conjugate update: df -> df + n and scale -> scale + R R^T, where R is
    the residual matrix with the datapoints along its last axis. The new
    precision is drawn with `_sample_wishart_bartlett`. `prec_scale` is not
    used here.
    """
    resid = bart.resid
    df_post = bart.error_cov_df + resid.shape[-1]
    scale_post = bart.error_cov_scale + resid @ resid.T
    return replace(
        bart, error_cov_inv=_sample_wishart_bartlett(key, df_post, scale_post)
    )
|
|
1404
|
+
|
|
1405
|
+
|
|
1406
|
+
@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_error_cov_inv(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the inverse error covariance.

    The multivariate update is used when the current `bart.error_cov_inv` is
    a 2-D matrix. Otherwise the univariate update of the inverse error
    variance is used.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state. Its `error_cov_inv` must not be None.

    Returns
    -------
    The new BART mcmc state, with an updated `error_cov_inv`.
    """
    assert bart.error_cov_inv is not None
    if bart.error_cov_inv.ndim == 2:
        return _step_error_cov_inv_mv(key, bart)
    return _step_error_cov_inv_uv(key, bart)
|
|
1431
|
+
|
|
1432
|
+
|
|
1433
|
+
@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    The new residuals are drawn with `truncated_normal_onesided`, truncated
    at minus the current sum of trees plus offset; the side of the truncation
    is chosen by the negated binary outcome. The latent variable is then
    rebuilt from them.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    # z - resid recovers the current sum of trees plus offset
    trees_and_offset = bart.z - bart.resid
    assert bart.y.dtype == bool
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -trees_and_offset)
    new_z = trees_and_offset + new_resid
    return replace(bart, z=new_z, resid=new_resid)
|
|
1455
|
+
|
|
1456
|
+
|
|
1457
|
+
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest. The Dirichlet draw is represented by independent
    log-gamma variates, so the stored `log_s` is not normalized.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    Notes
    -----
    This full conditional is approximated, because it does not take into account
    that there are forbidden decision rules.
    """
    forest = bart.forest
    assert forest.theta is not None

    # how many times each of the p variables is used in the current forest
    num_vars = forest.max_split.size
    usage = var_histogram(
        num_vars, forest.var_tree, forest.split_tree, sum_batch_axis=-1
    )

    concentration = forest.theta / num_vars + usage
    new_log_s = random.loggamma(key, concentration)
    return replace(bart, forest=replace(forest, log_s=new_log_s))
|
|
1496
|
+
|
|
1497
|
+
|
|
1498
|
+
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b). The posterior of
    lambda = theta / (theta + rho) is evaluated on a grid in (0, 1), one grid
    point is drawn, and it is mapped back to theta.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # midpoints of num_grid equal bins in (0, 1); symmetric about 1/2
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # log_s is stored unnormalized: turn it into a log-probability vector
    normalized_log_s = forest.log_s - logsumexp(forest.log_s)

    logp, theta_grid = _log_p_lamda(
        lamda_grid, normalized_log_s, forest.rho, forest.a, forest.b
    )
    chosen = random.categorical(key, logp)
    return replace(bart, forest=replace(forest, theta=theta_grid[chosen]))
|
|
1538
|
+
|
|
1539
|
+
|
|
1540
|
+
def _log_p_lamda(
    lamda: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """
    Unnormalized log posterior of lambda = theta / (theta + rho) on a grid.

    Returns the log density at each grid point, and the corresponding theta.
    Assumes the grid is symmetric about 1/2 (as built by `step_theta`), so that
    ``lamda[::-1] == 1 - lamda``. Terms that do not depend on theta are
    dropped.
    """
    one_minus_lamda = lamda[::-1]
    theta = rho * lamda / one_minus_lamda
    p = log_s.size

    # the terms are accumulated in the same order as the closed-form sum
    logp = (a - 1) * jnp.log1p(-one_minus_lamda)  # log(lambda)
    logp = logp + (b - 1) * jnp.log1p(-lamda)  # log(1 - lambda)
    logp = logp + gammaln(theta)
    logp = logp - p * gammaln(theta / p)
    logp = logp + theta / p * jnp.sum(log_s)
    return logp, theta
|
|
1557
|
+
|
|
1558
|
+
|
|
1559
|
+
@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    Nothing happens if `bart.config.sparse_on_at` is None, or while
    `bart.config.steps_done` is below it. Otherwise `step_s` is invoked,
    followed by `step_theta` only if the parameters of the theta prior are
    defined.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    config = bart.config
    if config.sparse_on_at is None:
        return bart

    not_yet_active = config.steps_done < config.sparse_on_at
    return cond(not_yet_active, lambda _key, state: state, _step_sparse, key, bart)
|
|
1588
|
+
|
|
1589
|
+
|
|
1590
|
+
def _step_sparse(key, bart):
    """Run `step_s`, then `step_theta` when `rho` is set, with separate keys."""
    keys = split(key)
    bart = step_s(keys.pop(), bart)
    if bart.forest.rho is None:
        return bart
    return step_theta(keys.pop(), bart)
|
|
1596
|
+
|
|
1597
|
+
|
|
1598
|
+
@jit_if_profiling
# jit to avoid the overhead of replace(_: Module)
def step_config(bart):
    """Return `bart` with `config.steps_done` incremented by one."""
    old_config = bart.config
    new_config = replace(old_config, steps_done=old_config.steps_done + 1)
    return replace(bart, config=new_config)
|