bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,904 @@
1
+ # bartz/src/bartz/mcmcstep/_moves.py
2
+ #
3
+ # Copyright (c) 2024-2026, The Bartz Contributors
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Implement `propose_moves` and associated dataclasses."""
26
+
27
+ from functools import partial
28
+
29
+ import jax
30
+ from equinox import Module
31
+ from jax import numpy as jnp
32
+ from jax import random
33
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt
34
+
35
+ from bartz import grove
36
+ from bartz._profiler import jit_and_block_if_profiling
37
+ from bartz.jaxext import minimal_unsigned_dtype, split, vmap_nodoc
38
+ from bartz.mcmcstep._state import Forest, field, vmap_chains
39
+
40
+
41
class Moves(Module):
    """Description of the structural move proposed for every tree."""

    allowed: Bool[Array, '*chains num_trees'] = field(chains=True)
    """`True` if some move can be proposed for the tree. When `False`, the
    remaining fields are not guaranteed to be meaningful. A move flagged as
    allowed can still be vetoed later, but only for violating
    `min_points_per_leaf`: that check is applied post-hoc for efficiency,
    leaving the rest of the MCMC logic untouched."""

    grow: Bool[Array, '*chains num_trees'] = field(chains=True)
    """`True` for a grow move, `False` for a prune move."""

    num_growable: UInt[Array, '*chains num_trees'] = field(chains=True)
    """How many leaves of the original tree could be grown."""

    node: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Index of the leaf to grow, or of the node to prune."""

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Index of the left child of 'node'."""

    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Index of the right child of 'node'."""

    partial_ratio: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """One factor of the Metropolis-Hastings ratio of the move. The factors
    still missing are the likelihood ratio, the probability of proposing the
    prune move, and the probability that the children of the modified node
    are terminal. For a PRUNE move the ratio is stored inverted. Becomes
    `None` once `log_trans_prior_ratio` has been computed."""

    log_trans_prior_ratio: None | Float32[Array, '*chains num_trees'] = field(
        chains=True
    )
    """Logarithm of the product of the transition and prior terms of the
    Metropolis-Hastings acceptance ratio of the proposed move. `None` until
    it is computed. For a PRUNE move the log-ratio is negated."""

    grow_var: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Decision axes of the new rules."""

    grow_split: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Decision boundaries of the new rules."""

    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """Updated decision axes of the trees; valid whichever move is chosen."""

    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """Partially updated `affluence_tree`, marking the non-leaf nodes that
    would become leaves if the move was accepted. As returned by
    `propose_moves`, the mark only reflects whether decision rules would be
    available to grow the leaf; whether the node holds enough datapoints is
    checked later, in `accept_moves_parallel_stage`."""

    logu: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Logarithm of a uniform (0, 1] random variable, used to decide the
    acceptance of the move. It lies in (-oo, 0]."""

    acc: None | Bool[Array, '*chains num_trees'] = field(chains=True)
    """Whether the move was accepted. `None` until it is computed."""

    to_prune: None | Bool[Array, '*chains num_trees'] = field(chains=True)
    """Whether the final operation that applies the move is a pruning, i.e.,
    the move is an accepted prune or a rejected grow. `None` until it is
    computed."""
107
+
108
+
109
@jit_and_block_if_profiling
@vmap_chains
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Propose one structural move per tree.

    Two kinds of move exist: GROW turns a leaf into a decision node with two
    new leaves below it, while PRUNE turns the parent of two leaves into a
    leaf, removing its children.

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees = forest.leaf_tree.shape[0]

    # the first pop yields per-tree keys for the two proposals, the second
    # one (below) the key for the uniform draws
    keys = split(key, 2)
    grow_keys, prune_keys = keys.pop((2, num_trees))

    # build both candidate moves for every tree
    grow_proposal = propose_grow_moves(
        grow_keys,
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    # the prune proposal works on the affluence tree already updated by the
    # grow proposal, so the final affluence tree carries both updates
    prune_proposal = propose_prune_moves(
        prune_keys,
        forest.split_tree,
        grow_proposal.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))

    # pick grow with probability 1/2 if both moves are possible, otherwise
    # take whichever is possible
    can_grow = grow_proposal.allowed
    can_prune = prune_proposal.allowed
    p_grow = jnp.where(can_grow & can_prune, 0.5, can_grow)
    # strict inequality because u is in [0, 1)
    grow = u < p_grow

    # indices of the node to modify and of its two children
    node = jnp.where(grow, grow_proposal.node, prune_proposal.node)
    children = (node << 1) | jnp.arange(2)[:, None]
    left, right = children

    partial_ratio = jnp.where(
        grow, grow_proposal.partial_ratio, prune_proposal.partial_ratio
    )

    return Moves(
        allowed=can_grow | can_prune,
        grow=grow,
        num_growable=grow_proposal.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=partial_ratio,
        # filled in later by complete_ratio
        log_trans_prior_ratio=None,
        grow_var=grow_proposal.var,
        grow_split=grow_proposal.split,
        # a prune move does not require any change to var_tree
        var_tree=grow_proposal.var_tree,
        # updated unconditionally for both moves, the prune update last
        affluence_tree=prune_proposal.affluence_tree,
        logu=jnp.log1p(-exp1mlogu),
        # both filled in later by accept_moves_sequential_stage
        acc=None,
        to_prune=None,
    )
187
+
188
+
189
class GrowMoves(Module):
    """Grow move proposed for each tree."""

    allowed: Bool[Array, ' num_trees']
    """Whether the grow move can be proposed."""

    num_growable: UInt[Array, ' num_trees']
    """How many leaves are eligible to be grown."""

    node: UInt[Array, ' num_trees']
    """Index of the leaf to grow, or ``2 ** d`` when no leaf is
    growable."""

    var: UInt[Array, ' num_trees']
    """Decision axis of the new rule."""

    split: UInt[Array, ' num_trees']
    """Decision boundary of the new rule."""

    partial_ratio: Float32[Array, ' num_trees']
    """One factor of the Metropolis-Hastings ratio of the move. The
    likelihood ratio and the probability of proposing the prune move are
    not included."""

    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    """Updated decision axes of the tree."""

    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    """Partially updated `affluence_tree`, where each leaf the move would
    create is marked `True` if it would have available decision rules."""
219
+
220
+
221
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move selects a leaf and turns it into a non-terminal node with two
    leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on the
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    No move is proposed when every leaf is at maximum depth, or has fewer
    datapoints than the requested threshold `min_points_per_decision_node`,
    or has no decision rule left given its ancestors. This case is signalled
    by `allowed` set to `False` and `num_growable` set to 0.
    """
    # draw the three keys up front, in the order they are consumed
    keys = split(key, 3)
    leaf_key = keys.pop()
    var_key = keys.pop()
    split_key = keys.pop()

    # pick the leaf
    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        leaf_key, split_tree, affluence_tree, p_propose_grow
    )

    # pick the decision rule: first the axis, then the cutpoint
    var, num_available_var = choose_variable(
        var_key, var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
    )
    split_idx, lower, upper = choose_split(
        split_key, var, var_tree, split_tree, max_split, leaf_to_grow
    )

    # would the two new leaves have decision rules available? Either another
    # variable is usable, or some cutpoint is left on their side of the new
    # split. These values may be meaningless if the move is blocked.
    left_has_splits = lower < split_idx
    right_has_splits = split_idx + 1 < upper
    has_other_vars = num_available_var > 1
    children_growable = has_other_vars | jnp.stack(
        [left_has_splits, right_has_splits]
    )
    children = (leaf_to_grow << 1) | jnp.arange(2)
    new_affluence_tree = affluence_tree.at[children].set(children_growable)

    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    new_var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=partial_ratio,
        var_tree=new_var_tree,
        affluence_tree=new_affluence_tree,
    )
309
+
310
+
311
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Pick the leaf of a tree that a GROW move would expand.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two leaves
        satisfying the `min_points_per_decision_node` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the chosen leaf, or ``2 ** d`` if
        ``num_growable == 0``.
    num_growable : Int32[Array, '']
        The number of leaves selected by `growable_leaves`.
    prob_choose : Float32[Array, '']
        The (normalized) probability this function had of picking that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned once the selected
        leaf has been turned into a non-terminal node.
    """
    growable_mask = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(growable_mask)

    # sample amongst growable leaves only, weighted by p_propose_grow
    weights = jnp.where(growable_mask, p_propose_grow, 0)
    sampled, weights_sum = categorical(key, weights)

    # out-of-bounds sentinel when nothing can be grown
    out_of_bounds = 2 * split_tree.size
    leaf_to_grow = jnp.where(num_growable, sampled, out_of_bounds)

    # guard the normalization against a zero total weight
    safe_sum = jnp.where(weights_sum, weights_sum, 1)
    prob_choose = weights[leaf_to_grow] / safe_sum

    # count the prunable nodes of the tree as it would be after the grow
    grown_split_tree = split_tree.at[leaf_to_grow].set(1)
    num_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_split_tree))

    return leaf_to_grow, num_growable, prob_choose, num_prunable
356
+
357
+
358
def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Build the mask of the leaves that are candidates for a GROW move.

    A leaf qualifies if it is not at the bottom level, has decision rules
    available given its ancestors, and contains at least
    `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    `affluence_tree` alone is not enough, `split_tree` is required too:
    `affluence_tree` can be "dirty", i.e., have `True` on unused nodes.
    """
    is_leaf = grove.is_actual_leaf(split_tree)
    return is_leaf & affluence_tree
385
+
386
+
387
def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Draw an integer with probabilities proportional to the given weights.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``, or ``n`` if all the
        probabilities are zero.
    norm : Float32[Array, '']
        The sum of `distr`.

    Notes
    -----
    The draw inverts the cumulative sum instead of using the Gumbel trick,
    so it is appropriate only for small ranges with probabilities well
    greater than 0.
    """
    cumulative = jnp.cumsum(distr)
    total = cumulative[-1]
    # uniform point on [0, total), then find the bin of the cumulative sum it
    # falls into
    point = random.uniform(
        key, shape=(), dtype=cumulative.dtype, minval=0, maxval=total
    )
    index = jnp.searchsorted(cumulative, point, side='right', method='compare_all')
    return index, total
416
+
417
+
418
def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the splitting variable of a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked.
    log_s
        The logarithm of the prior probability for choosing a variable. If
        `None`, use a uniform distribution.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was chosen
        from.
    """
    # variables that cannot be used: those exhausted along the path to the
    # leaf, plus the ones blocked from the start
    excluded = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        excluded = jnp.concatenate([excluded, blocked_vars])

    if log_s is not None:
        # weighted draw
        return categorical_exclude(key, log_s, excluded)

    # uniform draw
    return randint_exclude(key, max_split.size, excluded)
465
+
466
+
467
def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    List the ancestor variables of a node that have no split left.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    How many variables are fully used is not known in advance, so the unused
    entries of the output are set to `p`. Fill values can be anywhere in the
    array, and a variable can be listed more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)

    # measure the split range left at the node for each candidate variable
    ranges_of = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lower, upper = ranges_of(var_tree, split_tree, max_split, leaf_index, candidates)
    is_exhausted = upper == lower

    # the type of `candidates` is already sufficient to hold max_split.size,
    # see ancestor_variables()
    return jnp.where(is_exhausted, candidates, max_split.size)
504
+
505
+
506
def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Collect the variables used by the ancestors of a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Used only to get `p`.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes on the path from the root to the parent of
    the node. Since their number is not known at tracing time, the output has
    a fixed length and its unused entries are set to `p`.
    """
    num_slots = grove.tree_depth(var_tree) - 1

    # right-shifting the index by k gives the k-th ancestor; shifting past
    # the root gives 0, which is not a valid node and marks an unused slot
    shifts = jnp.arange(num_slots, 0, -1)
    ancestors = node_index >> shifts

    fill_dtype = minimal_unsigned_dtype(max_split.size)
    fill = jnp.array(max_split.size, fill_dtype)
    return jnp.where(ancestors, var_tree[ancestors], fill)
539
+
540
+
541
def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Compute the interval of splits still allowed for a variable at a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    num_levels = grove.tree_depth(var_tree) - 1

    # walk up the tree: `descendant` is the node itself, then its parent, and
    # so on; `ancestor` is the parent of each of those
    descendant = node_index >> jnp.arange(num_levels)
    came_from_right = (descendant & 1).astype(bool)
    ancestor = descendant >> 1

    ancestor_split = split_tree[ancestor].astype(jnp.int32)

    # keep only actual ancestors (index 0 means past the root) that split on
    # the reference variable
    relevant = (var_tree[ancestor] == ref_var) & ancestor.astype(bool)

    # ancestors reached from their right child bound the range from below
    lower = jnp.max(ancestor_split, initial=0, where=relevant & came_from_right)

    # ancestors reached from their left child bound it from above; with no
    # such ancestor the bound is 1 + max_split[ref_var], or 1 if ref_var is
    # out of bounds
    full_upper = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
        jnp.int32
    )
    upper = jnp.min(
        ancestor_split, initial=full_upper, where=relevant & ~came_from_right
    )

    return lower + 1, upper
581
+
582
+
583
def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Draw a uniform random integer from a range with some values removed.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range, must be >= 1.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    sorted_exclude, num_allowed = _process_exclude(sup, exclude)

    # rank of the result amongst the allowed values
    rank = random.randint(key, shape=(), minval=0, maxval=num_allowed)

    # turn the rank into a value by skipping the excluded entries; capping at
    # sup - 1 keeps the entries equal to `sup` from ever being counted
    probes = jnp.minimum(rank + jnp.arange(sorted_exclude.size), sup - 1)
    num_skipped = jnp.sum(probes >= sorted_exclude)

    return rank + num_skipped, num_allowed
617
+
618
+
619
+ def _process_exclude(sup, exclude):
620
+ exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
621
+ num_allowed = sup - jnp.sum(exclude < sup)
622
+ return exclude, num_allowed
623
+
624
+
625
def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Sample a category from given logits, skipping a set of categories.

    Parameters
    ----------
    key
        A jax random key.
    logits
        The unnormalized log-probabilities of each category.
    exclude
        The values to exclude from the range [0, k). Values greater than or
        equal to `logits.size` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, k)`` such that ``u not in exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, the result is unspecified.
    """
    unique_exclude, num_allowed = _process_exclude(logits.size, exclude)

    # overwrite the logits of the excluded categories with the lowest finite
    # value of the dtype
    lowest = jnp.finfo(logits.dtype).min
    masked_logits = logits.at[unique_exclude].set(lowest)

    return random.categorical(key, masked_logits), num_allowed
657
+
658
+
659
def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the cutpoint of a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable to split on.
    var_tree
        The splitting axes of the tree. Does not need to already contain `var`
        at `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The integer range `split` was drawn from is [l, r).

    Notes
    -----
    The cutpoint is 0 if `var` is out of bounds or if no split is available
    on that variable.
    """
    lower, upper = split_range(var_tree, split_tree, max_split, leaf_index, var)
    drawn = random.randint(key, (), lower, upper)
    # the draw is discarded if the range is empty
    cutpoint = jnp.where(lower < upper, drawn, 0)
    return cutpoint, lower, upper
701
+
702
+
703
def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Compute the transition ratio times the prior ratio of a grow move.

    Parameters
    ----------
    prob_choose
        The probability that the leaf had to be chosen amongst the growable
        leaves.
    num_prunable
        The number of leaf parents that could be pruned, after converting the
        leaf to be grown to a non-terminal node.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree);
    the "partial" version returned here leaves out the factor P(propose
    prune) in the numerator. The prior ratio is P(new tree) / P(old tree); its
    "partial" version leaves out the factor P(children are leaves).
    """
    # Both ratios also contain the factors num_available_split *
    # num_available_var * s[var], which cancel out.
    #
    # p_prune and 1 - p_nonterminal[child] * I(is the child growable) are
    # left out because they need the count trees, which are computed in the
    # acceptance phase.

    # The leaf to grow is the root exactly when the initial tree is a root:
    # in that case pruning was not an option and grow was proposed with
    # probability 1 instead of 1/2.
    is_root = leaf_to_grow == 1
    p_grow = jnp.where(is_root, 1, 0.5)
    inv_trans_ratio = p_grow * prob_choose * num_prunable
    safe_inv_trans_ratio = jnp.where(inv_trans_ratio, inv_trans_ratio, 1)

    # An out of bounds leaf_to_grow (move not allowed) reads 0.5: without the
    # fill, it would produce a 0 and then an inf when `complete_ratio` takes
    # the log.
    p_nt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    prior_ratio = p_nt / (1 - p_nt)

    return prior_ratio / safe_inv_trans_ratio
757
+
758
+
759
class PruneMoves(Module):
    """Prune move proposed for each tree."""

    allowed: Bool[Array, ' num_trees']
    """Whether the prune move can be carried out."""

    node: UInt[Array, ' num_trees']
    """Index of the node to prune, or ``2 ** d`` when no node is prunable."""

    partial_ratio: Float32[Array, ' num_trees']
    """One factor of the Metropolis-Hastings ratio of the move. The factors
    not included are the likelihood ratio, the probability of proposing the
    prune move, and the prior probability that the children of the node to
    prune are leaves. The ratio is stored inverted, and is meant to be
    inverted back in `accept_move_and_sample_leaves`."""

    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    """Partially updated `affluence_tree`, where the node to prune is marked
    as growable."""
778
+
779
+
780
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a PRUNE move for the tree structure, as part of the BART MCMC.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    chosen = choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow)
    node_to_prune, num_prunable, prob_choose, pruned_affluence_tree = chosen

    # ratio of the reverse (grow) move, hence stored inverted
    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, node_to_prune
    )

    # pruning is possible iff the tree is not a root, i.e., the root has a
    # split
    root_is_split = split_tree[1].astype(bool)

    return PruneMoves(
        allowed=root_is_split,
        node=node_to_prune,
        partial_ratio=partial_ratio,
        affluence_tree=pruned_affluence_tree,
    )
823
+
824
+
825
def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable. It has the same single-tree shape as the input.
    """
    # sample a node to prune uniformly amongst the parents of two leaves
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    # out-of-bounds sentinel when nothing can be pruned
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # build the pruned tree: the node becomes a leaf and is marked growable
    split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)

    # probability that the reverse (grow) move would pick this node
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    # an out of bounds node_to_prune reads a weight of 0
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    # guard the normalization against a zero total weight
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree
880
+
881
+
882
def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw a uniform random integer amongst the positions enabled by a mask.

    Parameters
    ----------
    key
        A jax random key.
    mask
        The mask indicating the allowed values.

    Returns
    -------
    A random integer in the range ``[0, n)`` such that ``mask[u] == True``.

    Notes
    -----
    The result is `n` if the mask is `False` everywhere. This function is
    optimized for small `n`.
    """
    running_count = jnp.cumsum(mask)
    num_enabled = running_count[-1]
    # rank of the enabled position to return
    rank = random.randint(key, shape=(), minval=0, maxval=num_enabled)
    # first position whose running count of enabled entries exceeds the rank
    return jnp.searchsorted(running_count, rank, side='right', method='compare_all')