bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/mcmcstep/_moves.py
ADDED
|
@@ -0,0 +1,904 @@
|
|
|
1
|
+
# bartz/src/bartz/mcmcstep/_moves.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2024-2026, The Bartz Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Implement `propose_moves` and associated dataclasses."""
|
|
26
|
+
|
|
27
|
+
from functools import partial
|
|
28
|
+
|
|
29
|
+
import jax
|
|
30
|
+
from equinox import Module
|
|
31
|
+
from jax import numpy as jnp
|
|
32
|
+
from jax import random
|
|
33
|
+
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt
|
|
34
|
+
|
|
35
|
+
from bartz import grove
|
|
36
|
+
from bartz._profiler import jit_and_block_if_profiling
|
|
37
|
+
from bartz.jaxext import minimal_unsigned_dtype, split, vmap_nodoc
|
|
38
|
+
from bartz.mcmcstep._state import Forest, field, vmap_chains
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class Moves(Module):
    """Container of the move proposed for each tree and of its acceptance state."""

    allowed: Bool[Array, '*chains num_trees'] = field(chains=True)
    """`True` where some move can be proposed. Where `False`, the remaining
    fields may hold meaningless values. A move flagged as allowed can still
    be vetoed afterwards, but only for violating `min_points_per_leaf`; that
    check is done post-hoc for efficiency, leaving the rest of the MCMC
    logic unchanged."""

    grow: Bool[Array, '*chains num_trees'] = field(chains=True)
    """`True` for a grow move, `False` for a prune move."""

    num_growable: UInt[Array, '*chains num_trees'] = field(chains=True)
    """How many leaves of the original tree are growable."""

    node: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Index of the leaf to grow, or of the node to prune."""

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Index of the left child of `node`."""

    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Index of the right child of `node`."""

    partial_ratio: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """An incomplete Metropolis-Hastings ratio of the move. The missing
    factors are the likelihood ratio, the probability of proposing the prune
    move, and the probability that the children of the modified node are
    terminal. For a PRUNE move the ratio is inverted. Becomes `None` once
    `log_trans_prior_ratio` has been computed."""

    log_trans_prior_ratio: Float32[Array, '*chains num_trees'] | None = field(
        chains=True
    )
    """Logarithm of the transition term times the prior term of the
    Metropolis-Hastings acceptance ratio of the proposed move. `None` until
    it is computed. For a PRUNE move the log-ratio is negated."""

    grow_var: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Decision axis of the new rule, one per tree."""

    grow_split: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Decision boundary of the new rule, one per tree."""

    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """Updated decision axes of the trees; valid whichever move is picked."""

    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """A partially updated `affluence_tree` that marks the non-leaf nodes
    which would turn into leaves were the move accepted. As returned by
    `propose_moves`, the mark only reflects whether decision rules would be
    available to grow the leaf; whether the node holds enough datapoints is
    checked later, in `accept_moves_parallel_stage`."""

    logu: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Logarithm of a uniform (0, 1] random variable, used to decide the
    acceptance of the move. Its range is (-oo, 0]."""

    acc: Bool[Array, '*chains num_trees'] | None = field(chains=True)
    """`True` where the move was accepted. `None` until it is computed."""

    to_prune: Bool[Array, '*chains num_trees'] | None = field(chains=True)
    """`True` where the operation that finally applies the move is a pruning,
    which happens for an accepted prune move or a rejected grow move. `None`
    until it is computed."""
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@jit_and_block_if_profiling
@vmap_chains
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Draw one candidate move for every tree of the forest.

    A move is either a GROW (a leaf becomes a decision node with two new
    leaves below it) or a PRUNE (the parent of two leaves becomes a leaf and
    its children are removed).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees = forest.leaf_tree.shape[0]
    keys = split(key, 2)
    # the order of the `pop` calls fixes which key feeds which random draw
    grow_keys, prune_keys = keys.pop((2, num_trees))

    # prepare a proposal of each kind for every tree; one is selected below
    grow_moves = propose_grow_moves(
        grow_keys,
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    # the prune proposal receives the affluence tree already updated by grow
    prune_moves = propose_prune_moves(
        prune_keys,
        forest.split_tree,
        grow_moves.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    u_kind, u_accept = random.uniform(keys.pop(), (2, num_trees))

    # pick grow with probability 1/2 if both moves are possible, otherwise
    # pick whichever is possible; strict `<` because u_kind is in [0, 1)
    can_grow = grow_moves.allowed
    can_prune = prune_moves.allowed
    p_grow = jnp.where(can_grow & can_prune, 0.5, can_grow)
    grow = u_kind < p_grow

    # indices of the modified node and of its two children
    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    children = (node << 1) | jnp.arange(2)[:, None]
    left, right = children

    partial_ratio = jnp.where(
        grow, grow_moves.partial_ratio, prune_moves.partial_ratio
    )

    return Moves(
        allowed=can_grow | can_prune,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=partial_ratio,
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        # var_tree does not need to be updated if prune
        var_tree=grow_moves.var_tree,
        # affluence_tree is updated for both moves unconditionally, prune last
        affluence_tree=prune_moves.affluence_tree,
        # 1 - u_accept is in (0, 1], so its logarithm is in (-oo, 0]
        logu=jnp.log1p(-u_accept),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class GrowMoves(Module):
    """Hold the grow move proposed for each tree."""

    allowed: Bool[Array, ' num_trees']
    """`True` where the move may be proposed."""

    num_growable: UInt[Array, ' num_trees']
    """How many leaves are candidates for a grow move."""

    node: UInt[Array, ' num_trees']
    """Index of the leaf to grow, or ``2 ** d`` when no leaf is growable."""

    var: UInt[Array, ' num_trees']
    """Decision axis of the new rule."""

    split: UInt[Array, ' num_trees']
    """Decision boundary of the new rule."""

    partial_ratio: Float32[Array, ' num_trees']
    """An incomplete Metropolis-Hastings ratio of the move: the likelihood
    ratio and the probability of proposing the prune move are left out."""

    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    """Updated decision axes of the tree."""

    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    """A partially updated `affluence_tree` in which each leaf the move would
    create is set to `True` if it would have available decision rules."""
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Draw a candidate GROW move for each tree.

    A GROW move selects a leaf and turns it into a non-terminal node with two
    leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on the
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    No move is proposed when every leaf is at the maximum depth, or has fewer
    datapoints than the threshold `min_points_per_decision_node`, or has no
    decision rule left given its ancestors. In that case `allowed` is `False`
    and `num_growable` is 0.
    """
    # the three keys are consumed in order: leaf, variable, cutpoint
    keys = split(key, 3)

    leaf, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # draw the decision rule of the new node
    var, num_available_var = choose_variable(
        keys.pop(), var_tree, split_tree, max_split, leaf, blocked_vars, log_s
    )
    cutpoint, lower, upper = choose_split(
        keys.pop(), var, var_tree, split_tree, max_split, leaf
    )

    # a new child can be grown further if another variable is still usable,
    # or if cutpoints remain on its side of the chosen one; these values may
    # not make sense when the move is blocked
    other_vars_usable = num_available_var > 1
    cutpoints_left = jnp.stack([lower < cutpoint, cutpoint + 1 < upper])
    children_growable = other_vars_usable | cutpoints_left
    children = (leaf << 1) | jnp.arange(2)
    new_affluence_tree = affluence_tree.at[children].set(children_growable)

    new_var_tree = var_tree.at[leaf].set(var.astype(var_tree.dtype))

    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf
    )

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf,
        var=var,
        split=cutpoint,
        partial_ratio=partial_ratio,
        var_tree=new_var_tree,
        affluence_tree=new_affluence_tree,
    )
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Pick at random the leaf of a tree to be grown.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two leaves
        satisfying the `min_points_per_decision_node` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the selected leaf, or ``2 ** d`` if
        ``num_growable == 0``.
    num_growable : Int32[Array, '']
        The number of leaves that can be grown, as determined by
        `growable_leaves`.
    prob_choose : Float32[Array, '']
        The normalized probability with which the selected leaf was picked,
        given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned once the selected
        leaf has become a non-terminal node.
    """
    growable = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(growable)

    # non-growable leaves get zero weight in the sampling distribution
    weights = jnp.where(growable, p_propose_grow, 0)
    sampled, total_weight = categorical(key, weights)

    # out-of-range sentinel index when nothing can be grown
    sentinel = 2 * split_tree.size
    leaf_to_grow = jnp.where(num_growable, sampled, sentinel)

    # guard the normalization against a zero total weight
    safe_total = jnp.where(total_weight, total_weight, 1)
    prob_choose = weights[leaf_to_grow] / safe_total

    # count the leaf parents of the tree as it would be after the grow move
    grown_split_tree = split_tree.at[leaf_to_grow].set(1)
    num_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_split_tree))

    return leaf_to_grow, num_growable, prob_choose, num_prunable
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Build the mask of the leaves that are eligible for a grow move.

    A leaf is eligible if it is not at the bottom level, has decision rules
    available given its ancestors, and holds at least
    `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    `affluence_tree` alone is not enough: it can be "dirty", i.e., it may be
    `True` on unused nodes, so it is combined with the actual leaves read off
    `split_tree`.
    """
    is_leaf = grove.is_actual_leaf(split_tree)
    return is_leaf & affluence_tree
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Sample an integer with probabilities proportional to given weights.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are zero,
        return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.

    Notes
    -----
    The sampling inverts a cumulative sum rather than using the Gumbel trick,
    so it is appropriate only for small ranges with probabilities well
    greater than 0.
    """
    cumulative = jnp.cumsum(distr)
    total = cumulative[-1]
    threshold = random.uniform(
        key, shape=(), dtype=cumulative.dtype, minval=0, maxval=total
    )
    # the sample is the first position where the cumulative sum exceeds the
    # threshold; if there is none (all-zero weights) the result is `n`
    index = jnp.searchsorted(
        cumulative, threshold, side='right', method='compare_all'
    )
    return index, total
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the variable on which a new non-terminal node splits.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked.
    log_s
        The logarithm of the prior probability for choosing a variable. If
        `None`, use a uniform distribution.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was chosen
        from.
    """
    # variables exhausted along the path to the leaf, plus the blocked ones
    excluded = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        excluded = jnp.concatenate([excluded, blocked_vars])

    if log_s is not None:
        return categorical_exclude(key, log_s, excluded)
    return randint_exclude(key, max_split.size, excluded)
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    List the ancestor variables of a node that have no split left.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    How many variables are exhausted is not known in advance, so the unused
    entries of the output are set to `p`. Fill values can be anywhere in the
    array, and a variable may be listed more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)
    range_per_var = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lower, upper = range_per_var(
        var_tree, split_tree, max_split, leaf_index, candidates
    )
    exhausted = lower == upper
    # the type of `candidates` is already sufficient to hold max_split.size,
    # see ancestor_variables()
    return jnp.where(exhausted, candidates, max_split.size)
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Collect the variables used by the ancestors of a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Used only to get `p`.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes on the path from the root to the parent of
    the node. Their number is not known at tracing time, so the output has a
    fixed length and its unused entries are set to `p`.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    # shifting the index right by k bits gives the k-th ancestor; shifts that
    # go past the root yield 0, which marks an unused entry
    shifts = jnp.arange(max_num_ancestors, 0, -1)
    ancestors = node_index >> shifts
    fill = jnp.array(max_split.size, minimal_unsigned_dtype(max_split.size))
    return jnp.where(ancestors, var_tree[ancestors], fill)
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Compute which splits on a variable are still allowed at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    # the node and its successive ancestors; entries past the root are 0
    path = node_index >> jnp.arange(max_num_ancestors)
    is_right_child = (path & 1).astype(bool)
    parents = path >> 1

    cuts = split_tree[parents].astype(jnp.int32)
    # keep only actual ancestors (nonzero index) that split on `ref_var`
    same_var = (var_tree[parents] == ref_var) & parents.astype(bool)

    # being a right child bounds the range from below, a left child from above
    lower = jnp.max(cuts, initial=0, where=same_var & is_right_child)
    # an out-of-bounds `ref_var` reads as 0
    max_cut = max_split.at[ref_var].get(mode='fill', fill_value=0)
    upper = jnp.min(
        cuts, initial=1 + max_cut.astype(jnp.int32), where=same_var & ~is_right_child
    )

    return lower + 1, upper
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Draw a uniform random integer from a range with some values removed.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range, must be >= 1.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    exclude, num_allowed = _process_exclude(sup, exclude)
    # draw a rank among the allowed values, then shift it by the number of
    # excluded values counted by the comparison below
    rank = random.randint(key, (), 0, num_allowed)
    shifted = jnp.minimum(rank + jnp.arange(exclude.size), sup - 1)
    num_skipped = jnp.sum(shifted >= exclude)
    return rank + num_skipped, num_allowed
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def _process_exclude(sup, exclude):
|
|
620
|
+
exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
|
|
621
|
+
num_allowed = sup - jnp.sum(exclude < sup)
|
|
622
|
+
return exclude, num_allowed
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Draw from a categorical distribution, excluding a set of values.

    Parameters
    ----------
    key
        A jax random key.
    logits
        The unnormalized log-probabilities of each category.
    exclude
        The values to exclude from the range [0, k). Values greater than or
        equal to `logits.size` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, k)`` such that ``u not in exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, the result is unspecified.
    """
    exclude, num_allowed = _process_exclude(logits.size, exclude)
    # mask the excluded categories with the most negative finite value
    # rather than with -inf
    kinda_neg_inf = jnp.finfo(logits.dtype).min
    # `exclude` is padded with `logits.size` and may hold other values past
    # the end of `logits`; request explicitly that these out-of-bounds
    # updates be skipped instead of relying on the default scatter mode
    logits = logits.at[exclude].set(kinda_neg_inf, mode='drop')
    u = random.categorical(key, logits)
    return u, num_allowed
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the cutpoint of a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable to split on.
    var_tree
        The splitting axes of the tree. Does not need to already contain `var`
        at `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The integer range `split` was drawn from is [l, r).

    Notes
    -----
    The cutpoint is 0 if `var` is out of bounds, or if no split is available
    on that variable.
    """
    lower, upper = split_range(var_tree, split_tree, max_split, leaf_index, var)
    candidate = random.randint(key, (), lower, upper)
    # the draw is meaningful only if the range is not empty
    cutpoint = jnp.where(lower < upper, candidate, 0)
    return cutpoint, lower, upper
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Evaluate the transition ratio times the prior ratio of a grow move.

    Parameters
    ----------
    prob_choose
        The probability with which the leaf was picked amongst the growable
        leaves.
    num_prunable
        How many leaf parents could be pruned once the leaf to grow has
        become a non-terminal node.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional
        on its ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree).
    The value computed here is "partial" because the numerator lacks the
    factor P(propose prune). The prior ratio is P(new tree) / P(old tree), and
    it is "partial" too because it lacks the factor P(children are leaves).
    """
    # Both ratios also carry the factors num_available_split *
    # num_available_var * s[var]; they cancel out, so they are not computed.

    # p_prune and 1 - p_nonterminal[child] * I(is the child growable) are not
    # computed here: they require the count trees, which only become available
    # in the acceptance phase.

    # Index 1 is the root. Growing the root means the starting tree is a bare
    # root (and vice versa), in which case pruning was not an option and grow
    # is proposed with probability 1 instead of 1/2.
    growing_root = leaf_to_grow == 1
    p_grow = jnp.where(growing_root, 1, 0.5)
    inv_trans_ratio = p_grow * prob_choose * num_prunable

    # A fill value of 0.5 is used for an out-of-bounds `leaf_to_grow` (move
    # not allowed): a plain 0 would turn into an inf once `complete_ratio`
    # takes the log.
    prior_nonterminal = p_nonterminal.at[leaf_to_grow].get(
        mode='fill', fill_value=0.5
    )
    tree_ratio = prior_nonterminal / (1 - prior_nonterminal)

    # guard against dividing by a zero inverse transition ratio
    denominator = jnp.where(inv_trans_ratio, inv_trans_ratio, 1)
    return tree_ratio / denominator
|
|
757
|
+
|
|
758
|
+
|
|
759
|
+
class PruneMoves(Module):
    """Represent a proposed prune move for each tree.

    Instances are built by `propose_prune_moves`, which fills the fields
    with the per-tree results of `choose_leaf_parent` and
    `compute_partial_ratio`.
    """

    # Set from `split_tree[1]` cast to bool: true iff the tree is not a
    # bare root.
    allowed: Bool[Array, ' num_trees']
    """Whether the move is possible."""

    # Comes from `choose_leaf_parent`.
    node: UInt[Array, ' num_trees']
    """The index of the node to prune. ``2 ** d`` if no node can be pruned."""

    # Computed by `compute_partial_ratio`, i.e., with the grow-move formula
    # applied to the reverse move; hence the ratio is stored inverted.
    partial_ratio: Float32[Array, ' num_trees']
    """A factor of the Metropolis-Hastings ratio of the move. It lacks the
    likelihood ratio, the probability of proposing the prune move, and the
    prior probability that the children of the node to prune are leaves.
    This ratio is inverted, and is meant to be inverted back in
    `accept_move_and_sample_leaves`."""

    # Comes from `choose_leaf_parent`, which sets the entry of the node to
    # prune to True.
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    """A partially updated `affluence_tree`, marking the node to prune as
    growable."""
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a prune move on the structure of a tree (BART MCMC).

    The function body handles a single tree; the decorator maps it over the
    leading axis of `key`, `split_tree` and `affluence_tree`, while
    `p_nonterminal` and `p_propose_grow` are shared.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    # pick the node to prune, and get the quantities needed to evaluate the
    # reverse (grow) move
    target, num_candidates, reverse_prob, updated_affluence = choose_leaf_parent(
        key, split_tree, affluence_tree, p_propose_grow
    )

    # same formula as for a grow move, evaluated on the reverse move
    partial_ratio = compute_partial_ratio(
        reverse_prob, num_candidates, p_nonterminal, target
    )

    # index 1 is the root: a nonzero split there means the tree is not a
    # bare root, which is the condition for pruning to be possible
    tree_is_not_root = split_tree[1].astype(bool)

    return PruneMoves(
        allowed=tree_is_not_root,
        node=target,
        partial_ratio=partial_ratio,
        affluence_tree=updated_affluence,
    )
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    This function operates on a single tree; all array arguments and results
    are unbatched.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable.
    """
    # sample a node to prune uniformly amongst the parents of leaves
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    # with nothing to prune, use the sentinel index 2 ** d, i.e., twice the
    # length of `split_tree`
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # compute stuff for the reverse move: build the tree as it would be after
    # pruning, with the pruned node turned into a growable leaf. When
    # `node_to_prune` is the out-of-bounds sentinel, these updates are
    # expected to be no-ops (JAX skips out-of-bounds scatter updates by
    # default).
    split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)

    # normalize the probability of picking the pruned node as leaf to grow;
    # the out-of-bounds sentinel reads as probability 0, and a zero
    # normalization is replaced with 1 to avoid dividing by zero
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw a random integer index restricted to the positions allowed by a mask.

    Parameters
    ----------
    key
        A jax random key.
    mask
        The mask indicating the allowed values.

    Returns
    -------
    A random integer in the range ``[0, n)`` such that ``mask[u] == True``.

    Notes
    -----
    If all values in the mask are `False`, return `n`. This function is
    optimized for small `n`.
    """
    # running count of allowed positions; the last entry is their total
    cumulative_count = jnp.cumsum(mask)
    num_allowed = cumulative_count[-1]

    # pick which of the allowed positions to return, counting from zero
    rank = random.randint(key, (), 0, num_allowed)

    # the first position whose running count exceeds `rank` is the allowed
    # position number `rank`
    return jnp.searchsorted(
        cumulative_count, rank, side='right', method='compare_all'
    )
|