bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/grove.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/grove.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024-
|
|
3
|
+
# Copyright (c) 2024-2026, The Bartz Contributors
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -22,93 +22,122 @@
|
|
|
22
22
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
|
-
"""
|
|
25
|
+
"""Functions to create and manipulate binary decision trees."""
|
|
26
26
|
|
|
27
|
-
|
|
27
|
+
import math
|
|
28
|
+
from functools import partial
|
|
29
|
+
from typing import Protocol
|
|
28
30
|
|
|
29
|
-
|
|
31
|
+
from jax import jit, lax, vmap
|
|
32
|
+
from jax import numpy as jnp
|
|
33
|
+
from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Shaped, UInt
|
|
30
34
|
|
|
31
|
-
|
|
35
|
+
try:
|
|
36
|
+
from numpy.lib.array_utils import normalize_axis_tuple # numpy 2
|
|
37
|
+
except ImportError:
|
|
38
|
+
from numpy.core.numeric import normalize_axis_tuple # numpy 1
|
|
32
39
|
|
|
33
|
-
|
|
40
|
+
from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
|
|
34
41
|
|
|
35
|
-
The 'var' array contains the axes along which the decision nodes operate.
|
|
36
42
|
|
|
37
|
-
|
|
43
|
+
class TreeHeaps(Protocol):
|
|
44
|
+
"""A protocol for dataclasses that represent trees.
|
|
38
45
|
|
|
39
|
-
|
|
46
|
+
A tree is represented with arrays as a heap. The root node is at index 1.
|
|
47
|
+
The children nodes of a node at index :math:`i` are at indices :math:`2i`
|
|
48
|
+
(left child) and :math:`2i + 1` (right child). The array element at index 0
|
|
49
|
+
is unused.
|
|
40
50
|
|
|
41
|
-
|
|
51
|
+
Since the nodes at the bottom can only be leaves and not decision nodes,
|
|
52
|
+
`var_tree` and `split_tree` are half as long as `leaf_tree`.
|
|
42
53
|
|
|
43
|
-
|
|
44
|
-
|
|
54
|
+
Arrays may have additional initial axes to represent multiple trees.
|
|
55
|
+
"""
|
|
45
56
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
57
|
+
leaf_tree: (
|
|
58
|
+
Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']
|
|
59
|
+
)
|
|
60
|
+
"""The values in the leaves of the trees. This array can be dirty, i.e.,
|
|
61
|
+
unused nodes can have whatever value. It may have an additional axis
|
|
62
|
+
for multivariate leaves."""
|
|
49
63
|
|
|
50
|
-
|
|
64
|
+
var_tree: UInt[Array, '*batch_shape 2**(d-1)']
|
|
65
|
+
"""The axes along which the decision nodes operate. This array can be
|
|
66
|
+
dirty but for the always unused node at index 0 which must be set to 0."""
|
|
51
67
|
|
|
68
|
+
split_tree: UInt[Array, '*batch_shape 2**(d-1)']
|
|
69
|
+
"""The decision boundaries of the trees. The boundaries are open on the
|
|
70
|
+
right, i.e., a point belongs to the left child iff x < split. Whether a
|
|
71
|
+
node is a leaf is indicated by the corresponding 'split' element being
|
|
72
|
+
0. Unused nodes also have split set to 0. This array can't be dirty."""
|
|
52
73
|
|
|
53
|
-
|
|
74
|
+
|
|
75
|
+
def make_tree(
|
|
76
|
+
depth: int, dtype: DTypeLike, batch_shape: tuple[int, ...] = ()
|
|
77
|
+
) -> Shaped[Array, '*batch_shape 2**{depth}']:
|
|
54
78
|
"""
|
|
55
79
|
Make an array to represent a binary tree.
|
|
56
80
|
|
|
57
81
|
Parameters
|
|
58
82
|
----------
|
|
59
|
-
depth
|
|
83
|
+
depth
|
|
60
84
|
The maximum depth of the tree. Depth 1 means that there is only a root
|
|
61
85
|
node.
|
|
62
|
-
dtype
|
|
86
|
+
dtype
|
|
63
87
|
The dtype of the array.
|
|
88
|
+
batch_shape
|
|
89
|
+
The leading shape of the array, to represent multiple trees and/or
|
|
90
|
+
multivariate trees.
|
|
64
91
|
|
|
65
92
|
Returns
|
|
66
93
|
-------
|
|
67
|
-
|
|
68
|
-
An array of zeroes with shape (2 ** depth,).
|
|
94
|
+
An array of zeroes with the appropriate shape.
|
|
69
95
|
"""
|
|
70
|
-
|
|
96
|
+
shape = (*batch_shape, 2**depth)
|
|
97
|
+
return jnp.zeros(shape, dtype)
|
|
71
98
|
|
|
72
99
|
|
|
73
|
-
def tree_depth(tree):
|
|
100
|
+
def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int:
|
|
74
101
|
"""
|
|
75
102
|
Return the maximum depth of a tree.
|
|
76
103
|
|
|
77
104
|
Parameters
|
|
78
105
|
----------
|
|
79
|
-
tree
|
|
106
|
+
tree
|
|
80
107
|
A tree created by `make_tree`. If the array is ND, the tree structure is
|
|
81
108
|
assumed to be along the last axis.
|
|
82
109
|
|
|
83
110
|
Returns
|
|
84
111
|
-------
|
|
85
|
-
depth
|
|
86
|
-
The maximum depth of the tree.
|
|
112
|
+
The maximum depth of the tree.
|
|
87
113
|
"""
|
|
88
|
-
return
|
|
114
|
+
return round(math.log2(tree.shape[-1]))
|
|
89
115
|
|
|
90
116
|
|
|
91
|
-
def traverse_tree(
|
|
117
|
+
def traverse_tree(
|
|
118
|
+
x: UInt[Array, ' p'],
|
|
119
|
+
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
120
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
121
|
+
) -> UInt[Array, '']:
|
|
92
122
|
"""
|
|
93
123
|
Find the leaf where a point falls into.
|
|
94
124
|
|
|
95
125
|
Parameters
|
|
96
126
|
----------
|
|
97
|
-
x
|
|
127
|
+
x
|
|
98
128
|
The coordinates to evaluate the tree at.
|
|
99
|
-
var_tree
|
|
129
|
+
var_tree
|
|
100
130
|
The decision axes of the tree.
|
|
101
|
-
split_tree
|
|
131
|
+
split_tree
|
|
102
132
|
The decision boundaries of the tree.
|
|
103
133
|
|
|
104
134
|
Returns
|
|
105
135
|
-------
|
|
106
|
-
index
|
|
107
|
-
The index of the leaf.
|
|
136
|
+
The index of the leaf.
|
|
108
137
|
"""
|
|
109
138
|
carry = (
|
|
110
139
|
jnp.zeros((), bool),
|
|
111
|
-
jnp.ones((),
|
|
140
|
+
jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
|
|
112
141
|
)
|
|
113
142
|
|
|
114
143
|
def loop(carry, _):
|
|
@@ -128,111 +157,132 @@ def traverse_tree(x, var_tree, split_tree):
|
|
|
128
157
|
return index
|
|
129
158
|
|
|
130
159
|
|
|
131
|
-
@
|
|
132
|
-
@
|
|
133
|
-
|
|
160
|
+
@jit
|
|
161
|
+
@partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
|
|
162
|
+
@partial(vmap_nodoc, in_axes=(1, None, None))
|
|
163
|
+
def traverse_forest(
|
|
164
|
+
X: UInt[Array, 'p n'],
|
|
165
|
+
var_trees: UInt[Array, '*forest_shape 2**(d-1)'],
|
|
166
|
+
split_trees: UInt[Array, '*forest_shape 2**(d-1)'],
|
|
167
|
+
) -> UInt[Array, '*forest_shape n']:
|
|
134
168
|
"""
|
|
135
|
-
Find the leaves where points
|
|
169
|
+
Find the leaves where points falls into for each tree in a set.
|
|
136
170
|
|
|
137
171
|
Parameters
|
|
138
172
|
----------
|
|
139
|
-
X
|
|
173
|
+
X
|
|
140
174
|
The coordinates to evaluate the trees at.
|
|
141
|
-
var_trees
|
|
175
|
+
var_trees
|
|
142
176
|
The decision axes of the trees.
|
|
143
|
-
split_trees
|
|
177
|
+
split_trees
|
|
144
178
|
The decision boundaries of the trees.
|
|
145
179
|
|
|
146
180
|
Returns
|
|
147
181
|
-------
|
|
148
|
-
indices
|
|
149
|
-
The indices of the leaves.
|
|
182
|
+
The indices of the leaves.
|
|
150
183
|
"""
|
|
151
184
|
return traverse_tree(X, var_trees, split_trees)
|
|
152
185
|
|
|
153
186
|
|
|
154
|
-
|
|
187
|
+
@partial(jit, static_argnames=('sum_batch_axis',))
|
|
188
|
+
def evaluate_forest(
|
|
189
|
+
X: UInt[Array, 'p n'],
|
|
190
|
+
trees: TreeHeaps,
|
|
191
|
+
*,
|
|
192
|
+
sum_batch_axis: int | tuple[int, ...] = (),
|
|
193
|
+
) -> (
|
|
194
|
+
Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n']
|
|
195
|
+
):
|
|
155
196
|
"""
|
|
156
|
-
Evaluate
|
|
197
|
+
Evaluate an ensemble of trees at an array of points.
|
|
157
198
|
|
|
158
199
|
Parameters
|
|
159
200
|
----------
|
|
160
|
-
X
|
|
201
|
+
X
|
|
161
202
|
The coordinates to evaluate the trees at.
|
|
162
|
-
|
|
163
|
-
The
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
The decision boundaries of the trees.
|
|
169
|
-
dtype : dtype, optional
|
|
170
|
-
The dtype of the output. Ignored if `sum_trees` is `False`.
|
|
171
|
-
sum_trees : bool, default True
|
|
172
|
-
Whether to sum the values across trees.
|
|
203
|
+
trees
|
|
204
|
+
The trees.
|
|
205
|
+
sum_batch_axis
|
|
206
|
+
The batch axes to sum over. By default, no summation is performed.
|
|
207
|
+
Note that negative indices count from the end of the batch dimensions,
|
|
208
|
+
the core dimensions n and k can't be summed over by this function.
|
|
173
209
|
|
|
174
210
|
Returns
|
|
175
211
|
-------
|
|
176
|
-
|
|
177
|
-
The (sum of) the values of the trees at the points in `X`.
|
|
212
|
+
The (sum of) the values of the trees at the points in `X`.
|
|
178
213
|
"""
|
|
179
|
-
indices
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
214
|
+
indices: UInt[Array, '*forest_shape n']
|
|
215
|
+
indices = traverse_forest(X, trees.var_tree, trees.split_tree)
|
|
216
|
+
|
|
217
|
+
is_mv = trees.leaf_tree.ndim != trees.var_tree.ndim
|
|
218
|
+
|
|
219
|
+
bc_indices: UInt[Array, '*forest_shape n 1'] | UInt[Array, '*forest_shape 1 n 1']
|
|
220
|
+
bc_indices = indices[..., None, :, None] if is_mv else indices[..., None]
|
|
221
|
+
|
|
222
|
+
bc_leaf_tree: (
|
|
223
|
+
Float32[Array, '*forest_shape 1 tree_size']
|
|
224
|
+
| Float32[Array, '*forest_shape k 1 tree_size']
|
|
225
|
+
)
|
|
226
|
+
bc_leaf_tree = (
|
|
227
|
+
trees.leaf_tree[..., :, None, :] if is_mv else trees.leaf_tree[..., None, :]
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
bc_leaves: (
|
|
231
|
+
Float32[Array, '*forest_shape n 1'] | Float32[Array, '*forest_shape k n 1']
|
|
232
|
+
)
|
|
233
|
+
bc_leaves = jnp.take_along_axis(bc_leaf_tree, bc_indices, -1)
|
|
234
|
+
|
|
235
|
+
leaves: Float32[Array, '*forest_shape n'] | Float32[Array, '*forest_shape k n']
|
|
236
|
+
leaves = jnp.squeeze(bc_leaves, -1)
|
|
237
|
+
|
|
238
|
+
axis = normalize_axis_tuple(sum_batch_axis, trees.var_tree.ndim - 1)
|
|
239
|
+
return jnp.sum(leaves, axis=axis)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def is_actual_leaf(
|
|
243
|
+
split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
|
|
244
|
+
) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
|
|
192
245
|
"""
|
|
193
246
|
Return a mask indicating the leaf nodes in a tree.
|
|
194
247
|
|
|
195
248
|
Parameters
|
|
196
249
|
----------
|
|
197
|
-
split_tree
|
|
250
|
+
split_tree
|
|
198
251
|
The splitting points of the tree.
|
|
199
|
-
add_bottom_level
|
|
252
|
+
add_bottom_level
|
|
200
253
|
If True, the bottom level of the tree is also considered.
|
|
201
254
|
|
|
202
255
|
Returns
|
|
203
256
|
-------
|
|
204
|
-
|
|
205
|
-
The mask indicating the leaf nodes. The length is doubled if
|
|
206
|
-
`add_bottom_level` is True.
|
|
257
|
+
The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True.
|
|
207
258
|
"""
|
|
208
259
|
size = split_tree.size
|
|
209
260
|
is_leaf = split_tree == 0
|
|
210
261
|
if add_bottom_level:
|
|
211
262
|
size *= 2
|
|
212
263
|
is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
|
|
213
|
-
index = jnp.arange(size, dtype=
|
|
264
|
+
index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
|
|
214
265
|
parent_index = index >> 1
|
|
215
266
|
parent_nonleaf = split_tree[parent_index].astype(bool)
|
|
216
267
|
parent_nonleaf = parent_nonleaf.at[1].set(True)
|
|
217
268
|
return is_leaf & parent_nonleaf
|
|
218
269
|
|
|
219
270
|
|
|
220
|
-
def is_leaves_parent(split_tree):
|
|
271
|
+
def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
|
|
221
272
|
"""
|
|
222
273
|
Return a mask indicating the nodes with leaf (and only leaf) children.
|
|
223
274
|
|
|
224
275
|
Parameters
|
|
225
276
|
----------
|
|
226
|
-
split_tree
|
|
277
|
+
split_tree
|
|
227
278
|
The decision boundaries of the tree.
|
|
228
279
|
|
|
229
280
|
Returns
|
|
230
281
|
-------
|
|
231
|
-
|
|
232
|
-
The mask indicating which nodes have leaf children.
|
|
282
|
+
The mask indicating which nodes have leaf children.
|
|
233
283
|
"""
|
|
234
284
|
index = jnp.arange(
|
|
235
|
-
split_tree.size, dtype=
|
|
285
|
+
split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1)
|
|
236
286
|
)
|
|
237
287
|
left_index = index << 1 # left child
|
|
238
288
|
right_index = left_index + 1 # right child
|
|
@@ -243,45 +293,50 @@ def is_leaves_parent(split_tree):
|
|
|
243
293
|
# the 0-th item has split == 0, so it's not counted
|
|
244
294
|
|
|
245
295
|
|
|
246
|
-
def tree_depths(
|
|
296
|
+
def tree_depths(tree_size: int) -> Int32[Array, ' {tree_size}']:
|
|
247
297
|
"""
|
|
248
298
|
Return the depth of each node in a binary tree.
|
|
249
299
|
|
|
250
300
|
Parameters
|
|
251
301
|
----------
|
|
252
|
-
|
|
302
|
+
tree_size
|
|
253
303
|
The length of the tree array, i.e., 2 ** d.
|
|
254
304
|
|
|
255
305
|
Returns
|
|
256
306
|
-------
|
|
257
|
-
depth
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
307
|
+
The depth of each node.
|
|
308
|
+
|
|
309
|
+
Notes
|
|
310
|
+
-----
|
|
311
|
+
The root node (index 1) has depth 0. The depth is the position of the most
|
|
312
|
+
significant non-zero bit in the index. The first element (the unused node)
|
|
313
|
+
is marked as depth 0.
|
|
261
314
|
"""
|
|
262
315
|
depths = []
|
|
263
316
|
depth = 0
|
|
264
|
-
for i in range(
|
|
317
|
+
for i in range(tree_size):
|
|
265
318
|
if i == 2**depth:
|
|
266
319
|
depth += 1
|
|
267
320
|
depths.append(depth - 1)
|
|
268
321
|
depths[0] = 0
|
|
269
|
-
return jnp.array(depths,
|
|
322
|
+
return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
|
|
270
323
|
|
|
271
324
|
|
|
272
|
-
|
|
325
|
+
@partial(jnp.vectorize, signature='(half_tree_size)->(tree_size)')
|
|
326
|
+
def is_used(
|
|
327
|
+
split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
328
|
+
) -> Bool[Array, '*batch_shape 2**d']:
|
|
273
329
|
"""
|
|
274
330
|
Return a mask indicating the used nodes in a tree.
|
|
275
331
|
|
|
276
332
|
Parameters
|
|
277
333
|
----------
|
|
278
|
-
split_tree
|
|
334
|
+
split_tree
|
|
279
335
|
The decision boundaries of the tree.
|
|
280
336
|
|
|
281
337
|
Returns
|
|
282
338
|
-------
|
|
283
|
-
|
|
284
|
-
A mask indicating which nodes are actually used.
|
|
339
|
+
A mask indicating which nodes are actually used.
|
|
285
340
|
"""
|
|
286
341
|
internal_node = split_tree.astype(bool)
|
|
287
342
|
internal_node = jnp.concatenate([internal_node, jnp.zeros_like(internal_node)])
|
|
@@ -289,22 +344,69 @@ def is_used(split_tree):
|
|
|
289
344
|
return internal_node | actual_leaf
|
|
290
345
|
|
|
291
346
|
|
|
292
|
-
|
|
347
|
+
@jit
|
|
348
|
+
def forest_fill(split_tree: UInt[Array, '*batch_shape 2**(d-1)']) -> Float32[Array, '']:
|
|
293
349
|
"""
|
|
294
350
|
Return the fraction of used nodes in a set of trees.
|
|
295
351
|
|
|
296
352
|
Parameters
|
|
297
353
|
----------
|
|
298
|
-
|
|
354
|
+
split_tree
|
|
299
355
|
The decision boundaries of the trees.
|
|
300
356
|
|
|
301
357
|
Returns
|
|
302
358
|
-------
|
|
303
|
-
|
|
304
|
-
The number of tree nodes in the forest over the maximum number that
|
|
305
|
-
could be stored in the arrays.
|
|
359
|
+
Number of tree nodes over the maximum number that could be stored.
|
|
306
360
|
"""
|
|
307
|
-
|
|
308
|
-
used = jax.vmap(is_used)(split_trees)
|
|
361
|
+
used = is_used(split_tree)
|
|
309
362
|
count = jnp.count_nonzero(used)
|
|
310
|
-
|
|
363
|
+
batch_size = split_tree.size // split_tree.shape[-1]
|
|
364
|
+
return count / (used.size - batch_size)
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
@partial(jit, static_argnames=('p', 'sum_batch_axis'))
|
|
368
|
+
def var_histogram(
|
|
369
|
+
p: int,
|
|
370
|
+
var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
371
|
+
split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
372
|
+
*,
|
|
373
|
+
sum_batch_axis: int | tuple[int, ...] = (),
|
|
374
|
+
) -> Int32[Array, '*reduced_batch_shape {p}']:
|
|
375
|
+
"""
|
|
376
|
+
Count how many times each variable appears in a tree.
|
|
377
|
+
|
|
378
|
+
Parameters
|
|
379
|
+
----------
|
|
380
|
+
p
|
|
381
|
+
The number of variables (the maximum value that can occur in `var_tree`
|
|
382
|
+
is ``p - 1``).
|
|
383
|
+
var_tree
|
|
384
|
+
The decision axes of the tree.
|
|
385
|
+
split_tree
|
|
386
|
+
The decision boundaries of the tree.
|
|
387
|
+
sum_batch_axis
|
|
388
|
+
The batch axes to sum over. By default, no summation is performed. Note
|
|
389
|
+
that negative indices count from the end of the batch dimensions, the
|
|
390
|
+
core dimension p can't be summed over by this function.
|
|
391
|
+
|
|
392
|
+
Returns
|
|
393
|
+
-------
|
|
394
|
+
The histogram(s) of the variables used in the tree.
|
|
395
|
+
"""
|
|
396
|
+
is_internal = split_tree.astype(bool)
|
|
397
|
+
|
|
398
|
+
def scatter_add(
|
|
399
|
+
var_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
|
|
400
|
+
is_internal: Bool[Array, '*summed_batch_axes half_tree_size'],
|
|
401
|
+
) -> Int32[Array, ' p']:
|
|
402
|
+
return jnp.zeros(p, int).at[var_tree].add(is_internal)
|
|
403
|
+
|
|
404
|
+
# vmap scatter_add over non-batched dims
|
|
405
|
+
batch_ndim = var_tree.ndim - 1
|
|
406
|
+
axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
|
|
407
|
+
for i in reversed(range(batch_ndim)):
|
|
408
|
+
neg_i = i - var_tree.ndim
|
|
409
|
+
if i not in axes:
|
|
410
|
+
scatter_add = vmap(scatter_add, in_axes=neg_i)
|
|
411
|
+
|
|
412
|
+
return scatter_add(var_tree, is_internal)
|