bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/grove.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/grove.py
2
2
  #
3
- # Copyright (c) 2024-2025, Giacomo Petrillo
3
+ # Copyright (c) 2024-2026, The Bartz Contributors
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -22,93 +22,122 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
- """
25
+ """Functions to create and manipulate binary decision trees."""
26
26
 
27
- Functions to create and manipulate binary trees.
27
+ import math
28
+ from functools import partial
29
+ from typing import Protocol
28
30
 
29
- A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index :math:`i` are at indices :math:`2i` (left child) and :math:`2i + 1` (right child). The array element at index 0 is unused.
31
+ from jax import jit, lax, vmap
32
+ from jax import numpy as jnp
33
+ from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Shaped, UInt
30
34
 
31
- A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'.
35
+ try:
36
+ from numpy.lib.array_utils import normalize_axis_tuple # numpy 2
37
+ except ImportError:
38
+ from numpy.core.numeric import normalize_axis_tuple # numpy 1
32
39
 
33
- The 'leaf' array contains the values in the leaves.
40
+ from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
34
41
 
35
- The 'var' array contains the axes along which the decision nodes operate.
36
42
 
37
- The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0. Unused nodes also have split set to 0.
43
+ class TreeHeaps(Protocol):
44
+ """A protocol for dataclasses that represent trees.
38
45
 
39
- Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
46
+ A tree is represented with arrays as a heap. The root node is at index 1.
47
+ The children nodes of a node at index :math:`i` are at indices :math:`2i`
48
+ (left child) and :math:`2i + 1` (right child). The array element at index 0
49
+ is unused.
40
50
 
41
- """
51
+ Since the nodes at the bottom can only be leaves and not decision nodes,
52
+ `var_tree` and `split_tree` are half as long as `leaf_tree`.
42
53
 
43
- import functools
44
- import math
54
+ Arrays may have additional initial axes to represent multiple trees.
55
+ """
45
56
 
46
- import jax
47
- from jax import lax
48
- from jax import numpy as jnp
57
+ leaf_tree: (
58
+ Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']
59
+ )
60
+ """The values in the leaves of the trees. This array can be dirty, i.e.,
61
+ unused nodes can have whatever value. It may have an additional axis
62
+ for multivariate leaves."""
49
63
 
50
- from . import jaxext
64
+ var_tree: UInt[Array, '*batch_shape 2**(d-1)']
65
+ """The axes along which the decision nodes operate. This array can be
66
+ dirty but for the always unused node at index 0 which must be set to 0."""
51
67
 
68
+ split_tree: UInt[Array, '*batch_shape 2**(d-1)']
69
+ """The decision boundaries of the trees. The boundaries are open on the
70
+ right, i.e., a point belongs to the left child iff x < split. Whether a
71
+ node is a leaf is indicated by the corresponding 'split' element being
72
+ 0. Unused nodes also have split set to 0. This array can't be dirty."""
52
73
 
53
- def make_tree(depth, dtype):
74
+
75
+ def make_tree(
76
+ depth: int, dtype: DTypeLike, batch_shape: tuple[int, ...] = ()
77
+ ) -> Shaped[Array, '*batch_shape 2**{depth}']:
54
78
  """
55
79
  Make an array to represent a binary tree.
56
80
 
57
81
  Parameters
58
82
  ----------
59
- depth : int
83
+ depth
60
84
  The maximum depth of the tree. Depth 1 means that there is only a root
61
85
  node.
62
- dtype : dtype
86
+ dtype
63
87
  The dtype of the array.
88
+ batch_shape
89
+ The leading shape of the array, to represent multiple trees and/or
90
+ multivariate trees.
64
91
 
65
92
  Returns
66
93
  -------
67
- tree : array
68
- An array of zeroes with shape (2 ** depth,).
94
+ An array of zeroes with the appropriate shape.
69
95
  """
70
- return jnp.zeros(2**depth, dtype)
96
+ shape = (*batch_shape, 2**depth)
97
+ return jnp.zeros(shape, dtype)
71
98
 
72
99
 
73
- def tree_depth(tree):
100
+ def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int:
74
101
  """
75
102
  Return the maximum depth of a tree.
76
103
 
77
104
  Parameters
78
105
  ----------
79
- tree : array
106
+ tree
80
107
  A tree created by `make_tree`. If the array is ND, the tree structure is
81
108
  assumed to be along the last axis.
82
109
 
83
110
  Returns
84
111
  -------
85
- depth : int
86
- The maximum depth of the tree.
112
+ The maximum depth of the tree.
87
113
  """
88
- return int(round(math.log2(tree.shape[-1])))
114
+ return round(math.log2(tree.shape[-1]))
89
115
 
90
116
 
91
- def traverse_tree(x, var_tree, split_tree):
117
+ def traverse_tree(
118
+ x: UInt[Array, ' p'],
119
+ var_tree: UInt[Array, ' 2**(d-1)'],
120
+ split_tree: UInt[Array, ' 2**(d-1)'],
121
+ ) -> UInt[Array, '']:
92
122
  """
93
123
  Find the leaf where a point falls into.
94
124
 
95
125
  Parameters
96
126
  ----------
97
- x : array (p,)
127
+ x
98
128
  The coordinates to evaluate the tree at.
99
- var_tree : array (2 ** (d - 1),)
129
+ var_tree
100
130
  The decision axes of the tree.
101
- split_tree : array (2 ** (d - 1),)
131
+ split_tree
102
132
  The decision boundaries of the tree.
103
133
 
104
134
  Returns
105
135
  -------
106
- index : int
107
- The index of the leaf.
136
+ The index of the leaf.
108
137
  """
109
138
  carry = (
110
139
  jnp.zeros((), bool),
111
- jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
140
+ jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
112
141
  )
113
142
 
114
143
  def loop(carry, _):
@@ -128,111 +157,132 @@ def traverse_tree(x, var_tree, split_tree):
128
157
  return index
129
158
 
130
159
 
131
- @functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
132
- @functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
133
- def traverse_forest(X, var_trees, split_trees):
160
+ @jit
161
+ @partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
162
+ @partial(vmap_nodoc, in_axes=(1, None, None))
163
+ def traverse_forest(
164
+ X: UInt[Array, 'p n'],
165
+ var_trees: UInt[Array, '*forest_shape 2**(d-1)'],
166
+ split_trees: UInt[Array, '*forest_shape 2**(d-1)'],
167
+ ) -> UInt[Array, '*forest_shape n']:
134
168
  """
135
- Find the leaves where points fall into.
169
+ Find the leaves where points falls into for each tree in a set.
136
170
 
137
171
  Parameters
138
172
  ----------
139
- X : array (p, n)
173
+ X
140
174
  The coordinates to evaluate the trees at.
141
- var_trees : array (m, 2 ** (d - 1))
175
+ var_trees
142
176
  The decision axes of the trees.
143
- split_trees : array (m, 2 ** (d - 1))
177
+ split_trees
144
178
  The decision boundaries of the trees.
145
179
 
146
180
  Returns
147
181
  -------
148
- indices : array (m, n)
149
- The indices of the leaves.
182
+ The indices of the leaves.
150
183
  """
151
184
  return traverse_tree(X, var_trees, split_trees)
152
185
 
153
186
 
154
- def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
187
+ @partial(jit, static_argnames=('sum_batch_axis',))
188
+ def evaluate_forest(
189
+ X: UInt[Array, 'p n'],
190
+ trees: TreeHeaps,
191
+ *,
192
+ sum_batch_axis: int | tuple[int, ...] = (),
193
+ ) -> (
194
+ Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n']
195
+ ):
155
196
  """
156
- Evaluate a ensemble of trees at an array of points.
197
+ Evaluate an ensemble of trees at an array of points.
157
198
 
158
199
  Parameters
159
200
  ----------
160
- X : array (p, n)
201
+ X
161
202
  The coordinates to evaluate the trees at.
162
- leaf_trees : array (m, 2 ** d)
163
- The leaf values of the tree or forest. If the input is a forest, the
164
- first axis is the tree index, and the values are summed.
165
- var_trees : array (m, 2 ** (d - 1))
166
- The decision axes of the trees.
167
- split_trees : array (m, 2 ** (d - 1))
168
- The decision boundaries of the trees.
169
- dtype : dtype, optional
170
- The dtype of the output. Ignored if `sum_trees` is `False`.
171
- sum_trees : bool, default True
172
- Whether to sum the values across trees.
203
+ trees
204
+ The trees.
205
+ sum_batch_axis
206
+ The batch axes to sum over. By default, no summation is performed.
207
+ Note that negative indices count from the end of the batch dimensions,
208
+ the core dimensions n and k can't be summed over by this function.
173
209
 
174
210
  Returns
175
211
  -------
176
- out : array (n,) or (m, n)
177
- The (sum of) the values of the trees at the points in `X`.
212
+ The (sum of) the values of the trees at the points in `X`.
178
213
  """
179
- indices = traverse_forest(X, var_trees, split_trees)
180
- ntree, _ = leaf_trees.shape
181
- tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))
182
- leaves = leaf_trees[tree_index[:, None], indices]
183
- if sum_trees:
184
- return jnp.sum(leaves, axis=0, dtype=dtype)
185
- # this sum suggests to swap the vmaps, but I think it's better for X
186
- # copying to keep it that way
187
- else:
188
- return leaves
189
-
190
-
191
- def is_actual_leaf(split_tree, *, add_bottom_level=False):
214
+ indices: UInt[Array, '*forest_shape n']
215
+ indices = traverse_forest(X, trees.var_tree, trees.split_tree)
216
+
217
+ is_mv = trees.leaf_tree.ndim != trees.var_tree.ndim
218
+
219
+ bc_indices: UInt[Array, '*forest_shape n 1'] | UInt[Array, '*forest_shape 1 n 1']
220
+ bc_indices = indices[..., None, :, None] if is_mv else indices[..., None]
221
+
222
+ bc_leaf_tree: (
223
+ Float32[Array, '*forest_shape 1 tree_size']
224
+ | Float32[Array, '*forest_shape k 1 tree_size']
225
+ )
226
+ bc_leaf_tree = (
227
+ trees.leaf_tree[..., :, None, :] if is_mv else trees.leaf_tree[..., None, :]
228
+ )
229
+
230
+ bc_leaves: (
231
+ Float32[Array, '*forest_shape n 1'] | Float32[Array, '*forest_shape k n 1']
232
+ )
233
+ bc_leaves = jnp.take_along_axis(bc_leaf_tree, bc_indices, -1)
234
+
235
+ leaves: Float32[Array, '*forest_shape n'] | Float32[Array, '*forest_shape k n']
236
+ leaves = jnp.squeeze(bc_leaves, -1)
237
+
238
+ axis = normalize_axis_tuple(sum_batch_axis, trees.var_tree.ndim - 1)
239
+ return jnp.sum(leaves, axis=axis)
240
+
241
+
242
+ def is_actual_leaf(
243
+ split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
244
+ ) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
192
245
  """
193
246
  Return a mask indicating the leaf nodes in a tree.
194
247
 
195
248
  Parameters
196
249
  ----------
197
- split_tree : int array (2 ** (d - 1),)
250
+ split_tree
198
251
  The splitting points of the tree.
199
- add_bottom_level : bool, default False
252
+ add_bottom_level
200
253
  If True, the bottom level of the tree is also considered.
201
254
 
202
255
  Returns
203
256
  -------
204
- is_actual_leaf : bool array (2 ** (d - 1) or 2 ** d,)
205
- The mask indicating the leaf nodes. The length is doubled if
206
- `add_bottom_level` is True.
257
+ The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True.
207
258
  """
208
259
  size = split_tree.size
209
260
  is_leaf = split_tree == 0
210
261
  if add_bottom_level:
211
262
  size *= 2
212
263
  is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
213
- index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
264
+ index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
214
265
  parent_index = index >> 1
215
266
  parent_nonleaf = split_tree[parent_index].astype(bool)
216
267
  parent_nonleaf = parent_nonleaf.at[1].set(True)
217
268
  return is_leaf & parent_nonleaf
218
269
 
219
270
 
220
- def is_leaves_parent(split_tree):
271
+ def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
221
272
  """
222
273
  Return a mask indicating the nodes with leaf (and only leaf) children.
223
274
 
224
275
  Parameters
225
276
  ----------
226
- split_tree : int array (2 ** (d - 1),)
277
+ split_tree
227
278
  The decision boundaries of the tree.
228
279
 
229
280
  Returns
230
281
  -------
231
- is_leaves_parent : bool array (2 ** (d - 1),)
232
- The mask indicating which nodes have leaf children.
282
+ The mask indicating which nodes have leaf children.
233
283
  """
234
284
  index = jnp.arange(
235
- split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1)
285
+ split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1)
236
286
  )
237
287
  left_index = index << 1 # left child
238
288
  right_index = left_index + 1 # right child
@@ -243,45 +293,50 @@ def is_leaves_parent(split_tree):
243
293
  # the 0-th item has split == 0, so it's not counted
244
294
 
245
295
 
246
- def tree_depths(tree_length):
296
+ def tree_depths(tree_size: int) -> Int32[Array, ' {tree_size}']:
247
297
  """
248
298
  Return the depth of each node in a binary tree.
249
299
 
250
300
  Parameters
251
301
  ----------
252
- tree_length : int
302
+ tree_size
253
303
  The length of the tree array, i.e., 2 ** d.
254
304
 
255
305
  Returns
256
306
  -------
257
- depth : array (tree_length,)
258
- The depth of each node. The root node (index 1) has depth 0. The depth
259
- is the position of the most significant non-zero bit in the index. The
260
- first element (the unused node) is marked as depth 0.
307
+ The depth of each node.
308
+
309
+ Notes
310
+ -----
311
+ The root node (index 1) has depth 0. The depth is the position of the most
312
+ significant non-zero bit in the index. The first element (the unused node)
313
+ is marked as depth 0.
261
314
  """
262
315
  depths = []
263
316
  depth = 0
264
- for i in range(tree_length):
317
+ for i in range(tree_size):
265
318
  if i == 2**depth:
266
319
  depth += 1
267
320
  depths.append(depth - 1)
268
321
  depths[0] = 0
269
- return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
322
+ return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
270
323
 
271
324
 
272
- def is_used(split_tree):
325
+ @partial(jnp.vectorize, signature='(half_tree_size)->(tree_size)')
326
+ def is_used(
327
+ split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
328
+ ) -> Bool[Array, '*batch_shape 2**d']:
273
329
  """
274
330
  Return a mask indicating the used nodes in a tree.
275
331
 
276
332
  Parameters
277
333
  ----------
278
- split_tree : int array (2 ** (d - 1),)
334
+ split_tree
279
335
  The decision boundaries of the tree.
280
336
 
281
337
  Returns
282
338
  -------
283
- is_used : bool array (2 ** d,)
284
- A mask indicating which nodes are actually used.
339
+ A mask indicating which nodes are actually used.
285
340
  """
286
341
  internal_node = split_tree.astype(bool)
287
342
  internal_node = jnp.concatenate([internal_node, jnp.zeros_like(internal_node)])
@@ -289,22 +344,69 @@ def is_used(split_tree):
289
344
  return internal_node | actual_leaf
290
345
 
291
346
 
292
- def forest_fill(split_trees):
347
+ @jit
348
+ def forest_fill(split_tree: UInt[Array, '*batch_shape 2**(d-1)']) -> Float32[Array, '']:
293
349
  """
294
350
  Return the fraction of used nodes in a set of trees.
295
351
 
296
352
  Parameters
297
353
  ----------
298
- split_trees : array (m, 2 ** (d - 1),)
354
+ split_tree
299
355
  The decision boundaries of the trees.
300
356
 
301
357
  Returns
302
358
  -------
303
- fill : float
304
- The number of tree nodes in the forest over the maximum number that
305
- could be stored in the arrays.
359
+ Number of tree nodes over the maximum number that could be stored.
306
360
  """
307
- m, _ = split_trees.shape
308
- used = jax.vmap(is_used)(split_trees)
361
+ used = is_used(split_tree)
309
362
  count = jnp.count_nonzero(used)
310
- return count / (used.size - m)
363
+ batch_size = split_tree.size // split_tree.shape[-1]
364
+ return count / (used.size - batch_size)
365
+
366
+
367
+ @partial(jit, static_argnames=('p', 'sum_batch_axis'))
368
+ def var_histogram(
369
+ p: int,
370
+ var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
371
+ split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
372
+ *,
373
+ sum_batch_axis: int | tuple[int, ...] = (),
374
+ ) -> Int32[Array, '*reduced_batch_shape {p}']:
375
+ """
376
+ Count how many times each variable appears in a tree.
377
+
378
+ Parameters
379
+ ----------
380
+ p
381
+ The number of variables (the maximum value that can occur in `var_tree`
382
+ is ``p - 1``).
383
+ var_tree
384
+ The decision axes of the tree.
385
+ split_tree
386
+ The decision boundaries of the tree.
387
+ sum_batch_axis
388
+ The batch axes to sum over. By default, no summation is performed. Note
389
+ that negative indices count from the end of the batch dimensions, the
390
+ core dimension p can't be summed over by this function.
391
+
392
+ Returns
393
+ -------
394
+ The histogram(s) of the variables used in the tree.
395
+ """
396
+ is_internal = split_tree.astype(bool)
397
+
398
+ def scatter_add(
399
+ var_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
400
+ is_internal: Bool[Array, '*summed_batch_axes half_tree_size'],
401
+ ) -> Int32[Array, ' p']:
402
+ return jnp.zeros(p, int).at[var_tree].add(is_internal)
403
+
404
+ # vmap scatter_add over non-batched dims
405
+ batch_ndim = var_tree.ndim - 1
406
+ axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
407
+ for i in reversed(range(batch_ndim)):
408
+ neg_i = i - var_tree.ndim
409
+ if i not in axes:
410
+ scatter_add = vmap(scatter_add, in_axes=neg_i)
411
+
412
+ return scatter_add(var_tree, is_internal)