bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/grove.py CHANGED
@@ -1,6 +1,6 @@
  # bartz/src/bartz/grove.py
  #
- # Copyright (c) 2024-2025, Giacomo Petrillo
+ # Copyright (c) 2024-2026, The Bartz Contributors
  #
  # This file is part of bartz.
  #
@@ -28,10 +28,14 @@ import math
  from functools import partial
  from typing import Protocol

- import jax
- from jax import jit, lax
+ from jax import jit, lax, vmap
  from jax import numpy as jnp
- from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Real, Shaped, UInt
+ from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Shaped, UInt
+
+ try:
+     from numpy.lib.array_utils import normalize_axis_tuple  # numpy 2
+ except ImportError:
+     from numpy.core.numeric import normalize_axis_tuple  # numpy 1

  from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc

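
The version-dependent import above matters for the new `sum_batch_axis` arguments below: `normalize_axis_tuple` turns an int-or-tuple axis argument, possibly negative, into a canonical tuple of non-negative axes. A minimal sketch of its behavior (the numbers are illustrative):

    from numpy.lib.array_utils import normalize_axis_tuple  # numpy >= 2

    normalize_axis_tuple(-1, 3)       # (2,): negative axes count from the end
    normalize_axis_tuple((0, -1), 3)  # (0, 2)
    normalize_axis_tuple((), 3)       # (): no axes, hence no reduction
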
@@ -44,32 +48,33 @@ class TreeHeaps(Protocol):
      (left child) and :math:`2i + 1` (right child). The array element at index 0
      is unused.

-     Parameters
-     ----------
-     leaf_tree
-         The values in the leaves of the trees. This array can be dirty, i.e.,
-         unused nodes can have whatever value.
-     var_tree
-         The axes along which the decision nodes operate. This array can be
-         dirty but for the always unused node at index 0 which must be set to 0.
-     split_tree
-         The decision boundaries of the trees. The boundaries are open on the
-         right, i.e., a point belongs to the left child iff x < split. Whether a
-         node is a leaf is indicated by the corresponding 'split' element being
-         0. Unused nodes also have split set to 0. This array can't be dirty.
-
-     Notes
-     -----
      Since the nodes at the bottom can only be leaves and not decision nodes,
      `var_tree` and `split_tree` are half as long as `leaf_tree`.
+
+     Arrays may have additional initial axes to represent multiple trees.
      """

-     leaf_tree: Float32[Array, '* 2**d']
-     var_tree: UInt[Array, '* 2**(d-1)']
-     split_tree: UInt[Array, '* 2**(d-1)']
+     leaf_tree: (
+         Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']
+     )
+     """The values in the leaves of the trees. This array can be dirty, i.e.,
+     unused nodes can have whatever value. It may have an additional axis
+     for multivariate leaves."""
+
+     var_tree: UInt[Array, '*batch_shape 2**(d-1)']
+     """The axes along which the decision nodes operate. This array can be
+     dirty but for the always unused node at index 0 which must be set to 0."""
+
+     split_tree: UInt[Array, '*batch_shape 2**(d-1)']
+     """The decision boundaries of the trees. The boundaries are open on the
+     right, i.e., a point belongs to the left child iff x < split. Whether a
+     node is a leaf is indicated by the corresponding 'split' element being
+     0. Unused nodes also have split set to 0. This array can't be dirty."""


- def make_tree(depth: int, dtype: DTypeLike) -> Shaped[Array, ' 2**{depth}']:
+ def make_tree(
+     depth: int, dtype: DTypeLike, batch_shape: tuple[int, ...] = ()
+ ) -> Shaped[Array, '*batch_shape 2**{depth}']:
      """
      Make an array to represent a binary tree.

@@ -80,15 +85,19 @@ def make_tree(depth: int, dtype: DTypeLike) -> Shaped[Array, ' 2**{depth}']:
          node.
      dtype
          The dtype of the array.
+     batch_shape
+         The leading shape of the array, to represent multiple trees and/or
+         multivariate trees.

      Returns
      -------
      An array of zeroes with the appropriate shape.
      """
-     return jnp.zeros(2**depth, dtype)
+     shape = (*batch_shape, 2**depth)
+     return jnp.zeros(shape, dtype)


- def tree_depth(tree: Shaped[Array, '* 2**d']) -> int:
+ def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int:
      """
      Return the maximum depth of a tree.

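
A quick sketch of the new `batch_shape` parameter (sizes are illustrative):

    import jax.numpy as jnp
    from bartz.grove import make_tree

    tree = make_tree(3, jnp.float32)  # one tree: heap of 2**3 = 8 nodes
    assert tree.shape == (8,)

    # a forest of 10 trees with multivariate (k=2) leaves
    forest = make_tree(3, jnp.float32, batch_shape=(10, 2))
    assert forest.shape == (10, 2, 8)
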
@@ -106,10 +115,10 @@ def tree_depth(tree: Shaped[Array, '* 2**d']) -> int:


  def traverse_tree(
-     x: Real[Array, ' p'],
+     x: UInt[Array, ' p'],
      var_tree: UInt[Array, ' 2**(d-1)'],
      split_tree: UInt[Array, ' 2**(d-1)'],
- ) -> Int32[Array, '']:
+ ) -> UInt[Array, '']:
      """
      Find the leaf where a point falls into.

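
For intuition, here is the heap convention the new `UInt` signature operates on; the concrete values below are made up, and the expected result follows the `x < split` rule from `TreeHeaps`:

    import jax.numpy as jnp
    from bartz.grove import traverse_tree

    # depth-2 heap: index 0 unused, root at 1, leaves at 2 (left) and 3 (right)
    var_tree = jnp.array([0, 1], dtype=jnp.uint8)    # root splits on variable 1
    split_tree = jnp.array([0, 5], dtype=jnp.uint8)  # decision boundary at 5

    x = jnp.array([7, 3], dtype=jnp.uint8)
    # x[1] = 3 < 5, so the point should land in the left child, heap index 2
    traverse_tree(x, var_tree, split_tree)
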
@@ -148,15 +157,16 @@ def traverse_tree(
      return index


- @partial(vmap_nodoc, in_axes=(None, 0, 0))
+ @jit
+ @partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
  @partial(vmap_nodoc, in_axes=(1, None, None))
  def traverse_forest(
-     X: Real[Array, 'p n'],
-     var_trees: UInt[Array, 'm 2**(d-1)'],
-     split_trees: UInt[Array, 'm 2**(d-1)'],
- ) -> Int32[Array, 'm n']:
+     X: UInt[Array, 'p n'],
+     var_trees: UInt[Array, '*forest_shape 2**(d-1)'],
+     split_trees: UInt[Array, '*forest_shape 2**(d-1)'],
+ ) -> UInt[Array, '*forest_shape n']:
      """
-     Find the leaves where points fall into.
+     Find the leaves that points fall into, for each tree in a set.

      Parameters
      ----------
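
At the shape level, the new decorator stack works like this: the inner `vmap_nodoc` maps over the n points (axis 1 of `X`), while `jnp.vectorize` with `excluded=(0,)` maps the whole thing over any leading `forest_shape` axes of the tree arrays without batching `X`. A sketch with assumed sizes:

    import jax.numpy as jnp
    from bartz.grove import traverse_forest

    p, n, m, half = 2, 100, 50, 4  # illustrative sizes; half = 2**(d-1)
    X = jnp.zeros((p, n), dtype=jnp.uint8)
    var_trees = jnp.zeros((m, half), dtype=jnp.uint8)
    split_trees = jnp.zeros((m, half), dtype=jnp.uint8)

    leaves = traverse_forest(X, var_trees, split_trees)
    assert leaves.shape == (m, n)  # one leaf index per tree and point
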
@@ -174,35 +184,59 @@ def traverse_forest(
      return traverse_tree(X, var_trees, split_trees)


+ @partial(jit, static_argnames=('sum_batch_axis',))
  def evaluate_forest(
-     X: UInt[Array, 'p n'], trees: TreeHeaps, *, sum_trees: bool = True
- ) -> Float32[Array, ' n'] | Float32[Array, 'm n']:
+     X: UInt[Array, 'p n'],
+     trees: TreeHeaps,
+     *,
+     sum_batch_axis: int | tuple[int, ...] = (),
+ ) -> (
+     Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n']
+ ):
      """
-     Evaluate a ensemble of trees at an array of points.
+     Evaluate an ensemble of trees at an array of points.

      Parameters
      ----------
      X
          The coordinates to evaluate the trees at.
      trees
-         The tree heaps, with batch shape (m,).
-     sum_trees
-         Whether to sum the values across trees.
+         The trees.
+     sum_batch_axis
+         The batch axes to sum over. By default, no summation is performed.
+         Note that negative indices count from the end of the batch dimensions;
+         the core dimensions n and k can't be summed over by this function.

      Returns
      -------
      The (sum of) the values of the trees at the points in `X`.
      """
+     indices: UInt[Array, '*forest_shape n']
      indices = traverse_forest(X, trees.var_tree, trees.split_tree)
-     ntree, _ = trees.leaf_tree.shape
-     tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))
-     leaves = trees.leaf_tree[tree_index[:, None], indices]
-     if sum_trees:
-         return jnp.sum(leaves, axis=0, dtype=jnp.float32)
-         # this sum suggests to swap the vmaps, but I think it's better for X
-         # copying to keep it that way
-     else:
-         return leaves
+
+     is_mv = trees.leaf_tree.ndim != trees.var_tree.ndim
+
+     bc_indices: UInt[Array, '*forest_shape n 1'] | UInt[Array, '*forest_shape 1 n 1']
+     bc_indices = indices[..., None, :, None] if is_mv else indices[..., None]
+
+     bc_leaf_tree: (
+         Float32[Array, '*forest_shape 1 tree_size']
+         | Float32[Array, '*forest_shape k 1 tree_size']
+     )
+     bc_leaf_tree = (
+         trees.leaf_tree[..., :, None, :] if is_mv else trees.leaf_tree[..., None, :]
+     )
+
+     bc_leaves: (
+         Float32[Array, '*forest_shape n 1'] | Float32[Array, '*forest_shape k n 1']
+     )
+     bc_leaves = jnp.take_along_axis(bc_leaf_tree, bc_indices, -1)
+
+     leaves: Float32[Array, '*forest_shape n'] | Float32[Array, '*forest_shape k n']
+     leaves = jnp.squeeze(bc_leaves, -1)
+
+     axis = normalize_axis_tuple(sum_batch_axis, trees.var_tree.ndim - 1)
+     return jnp.sum(leaves, axis=axis)


  def is_actual_leaf(
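
A usage sketch of the reworked reduction; `Trees` is a hypothetical stand-in for a `TreeHeaps` implementation (a `NamedTuple`, so it passes through `jit` as a pytree), and `sum_batch_axis=0` over a single tree axis reproduces the old `sum_trees=True` behavior:

    from typing import NamedTuple

    import jax.numpy as jnp
    from jaxtyping import Array, Float32, UInt

    from bartz.grove import evaluate_forest, make_tree

    class Trees(NamedTuple):
        leaf_tree: Float32[Array, 'm 2**d']
        var_tree: UInt[Array, 'm 2**(d-1)']
        split_tree: UInt[Array, 'm 2**(d-1)']

    trees = Trees(
        leaf_tree=make_tree(3, jnp.float32, (50,)),
        var_tree=make_tree(2, jnp.uint8, (50,)),
        split_tree=make_tree(2, jnp.uint8, (50,)),
    )
    X = jnp.zeros((2, 100), dtype=jnp.uint8)  # p=2 predictors, n=100 points

    evaluate_forest(X, trees).shape                    # (50, 100): per tree
    evaluate_forest(X, trees, sum_batch_axis=0).shape  # (100,): summed forest
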
@@ -259,13 +293,13 @@ def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(
      # the 0-th item has split == 0, so it's not counted


- def tree_depths(tree_length: int) -> Int32[Array, ' {tree_length}']:
+ def tree_depths(tree_size: int) -> Int32[Array, ' {tree_size}']:
      """
      Return the depth of each node in a binary tree.

      Parameters
      ----------
-     tree_length
+     tree_size
          The length of the tree array, i.e., 2 ** d.

      Returns
@@ -280,7 +314,7 @@ def tree_depths(tree_length: int) -> Int32[Array, ' {tree_length}']:
      """
      depths = []
      depth = 0
-     for i in range(tree_length):
+     for i in range(tree_size):
          if i == 2**depth:
              depth += 1
          depths.append(depth - 1)
@@ -288,7 +322,10 @@ def tree_depths(tree_length: int) -> Int32[Array, ' {tree_length}']:
      return jnp.array(depths, minimal_unsigned_dtype(max(depths)))


- def is_used(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**d']:
+ @partial(jnp.vectorize, signature='(half_tree_size)->(tree_size)')
+ def is_used(
+     split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
+ ) -> Bool[Array, '*batch_shape 2**d']:
      """
      Return a mask indicating the used nodes in a tree.

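
The decorator deserves a note: `jnp.vectorize` with a gufunc-style signature maps a function written for the trailing core axis over any leading batch axes, and the output core axis may have a different size than the input one, exactly the `(half_tree_size)->(tree_size)` situation here. A standalone sketch:

    from functools import partial

    import jax.numpy as jnp

    @partial(jnp.vectorize, signature='(h)->(t)')
    def twice(x):  # written for a single 1d array of length h
        return jnp.concatenate([x, x])  # output core axis has size t = 2*h

    twice(jnp.arange(4)).shape          # (8,)
    twice(jnp.zeros((10, 3, 4))).shape  # (10, 3, 8): batch axes pass through
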
@@ -308,7 +345,7 @@ def is_used(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**d']:


  @jit
- def forest_fill(split_tree: UInt[Array, 'num_trees 2**(d-1)']) -> Float32[Array, '']:
+ def forest_fill(split_tree: UInt[Array, '*batch_shape 2**(d-1)']) -> Float32[Array, '']:
      """
      Return the fraction of used nodes in a set of trees.

@@ -321,36 +358,55 @@ def forest_fill(split_tree: UInt[Array, 'num_trees 2**(d-1)']) -> Float32[Array,
      -------
      Number of tree nodes over the maximum number that could be stored.
      """
-     num_trees, _ = split_tree.shape
-     used = jax.vmap(is_used)(split_tree)
+     used = is_used(split_tree)
      count = jnp.count_nonzero(used)
-     return count / (used.size - num_trees)
+     batch_size = split_tree.size // split_tree.shape[-1]
+     return count / (used.size - batch_size)


+ @partial(jit, static_argnames=('p', 'sum_batch_axis'))
  def var_histogram(
-     p: int, var_tree: UInt[Array, '* 2**(d-1)'], split_tree: UInt[Array, '* 2**(d-1)']
- ) -> Int32[Array, ' {p}']:
+     p: int,
+     var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
+     split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
+     *,
+     sum_batch_axis: int | tuple[int, ...] = (),
+ ) -> Int32[Array, '*reduced_batch_shape {p}']:
      """
      Count how many times each variable appears in a tree.

      Parameters
      ----------
      p
-         The number of variables (the maximum value that can occur in
-         `var_tree` is ``p - 1``).
+         The number of variables (the maximum value that can occur in `var_tree`
+         is ``p - 1``).
      var_tree
          The decision axes of the tree.
      split_tree
          The decision boundaries of the tree.
+     sum_batch_axis
+         The batch axes to sum over. By default, no summation is performed. Note
+         that negative indices count from the end of the batch dimensions; the
+         core dimension p can't be summed over by this function.

      Returns
      -------
-     The histogram of the variables used in the tree.
-
-     Notes
-     -----
-     If there are leading axes in the tree arrays (i.e., multiple trees), the
-     returned counts are cumulative over trees.
+     The histogram(s) of the variables used in the tree.
      """
      is_internal = split_tree.astype(bool)
-     return jnp.zeros(p, int).at[var_tree].add(is_internal)
+
+     def scatter_add(
+         var_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
+         is_internal: Bool[Array, '*summed_batch_axes half_tree_size'],
+     ) -> Int32[Array, ' p']:
+         return jnp.zeros(p, int).at[var_tree].add(is_internal)
+
+     # vmap scatter_add over the batch dims that are not summed over
+     batch_ndim = var_tree.ndim - 1
+     axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
+     for i in reversed(range(batch_ndim)):
+         neg_i = i - var_tree.ndim
+         if i not in axes:
+             scatter_add = vmap(scatter_add, in_axes=neg_i)
+
+     return scatter_add(var_tree, is_internal)
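
A usage sketch of the batched histogram (values are illustrative): with two trees along the batch axis, the default returns one histogram per tree, and `sum_batch_axis=0` recovers the old cumulative behavior.

    import jax.numpy as jnp
    from bartz.grove import var_histogram

    # two trees, half heap size 4, p = 3 variables; only node 1 is internal
    var_tree = jnp.array([[0, 1, 0, 0], [0, 2, 0, 0]], dtype=jnp.uint8)
    split_tree = jnp.array([[0, 5, 0, 0], [0, 7, 0, 0]], dtype=jnp.uint8)

    var_histogram(3, var_tree, split_tree)                    # [[0, 1, 0], [0, 0, 1]]
    var_histogram(3, var_tree, split_tree, sum_batch_axis=0)  # [0, 1, 1]
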
bartz/jaxext/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # bartz/src/bartz/jaxext/__init__.py
  #
- # Copyright (c) 2024-2025, Giacomo Petrillo
+ # Copyright (c) 2024-2026, The Bartz Contributors
  #
  # This file is part of bartz.
  #
@@ -24,13 +24,23 @@

  """Additions to jax."""

- import functools
  import math
  from collections.abc import Sequence
+ from contextlib import nullcontext
+ from functools import partial

  import jax
+ from jax import (
+     Device,
+     debug_key_reuse,
+     device_count,
+     ensure_compile_time_eval,
+     jit,
+     random,
+     vmap,
+ )
  from jax import numpy as jnp
- from jax import random
+ from jax.dtypes import prng_key
  from jax.lax import scan
  from jax.scipy.special import ndtr
  from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped
@@ -63,7 +73,7 @@ def minimal_unsigned_dtype(value):
      return jnp.uint64


- @functools.partial(jax.jit, static_argnums=(1,))
+ @partial(jax.jit, static_argnums=(1,))
  def unique(
      x: Shaped[Array, ' _'], size: int, fill_value: Scalar
  ) -> tuple[Shaped[Array, ' {size}'], int]:
@@ -114,24 +124,42 @@ class split:
          The key to split.
      num
          The number of keys to split into.
+
+     Notes
+     -----
+     Unlike `jax.random.split`, this class supports a vector of keys as input. In
+     this case, it behaves as if everything had been vmapped over, so `keys.pop`
+     has an additional initial output dimension equal to the number of input
+     keys, and the deterministic dependency respects this axis.
      """

-     def __init__(self, key: Key[Array, ''], num: int = 2):
-         self._keys = random.split(key, num)
+     _keys: tuple[Key[Array, '*batch'], ...]
+     _num_used: int
+
+     def __init__(self, key: Key[Array, '*batch'], num: int = 2):
+         if key.ndim:
+             context = debug_key_reuse(False)
+         else:
+             context = nullcontext()
+         with context:
+             # jitted-vmapped key split seems to be triggering a false positive
+             # with key reuse checks
+             self._keys = _split_unpack(key, num)
+         self._num_used = 0

      def __len__(self):
-         return self._keys.size
+         return len(self._keys) - self._num_used

-     def pop(self, shape: int | tuple[int, ...] | None = None) -> Key[Array, '*']:
+     def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, '*batch {shape}']:
          """
          Pop one or more keys from the list.

          Parameters
          ----------
          shape
-             The shape of the keys to pop. If `None`, a single key is popped.
-             If an integer, that many keys are popped. If a tuple, the keys are
-             reshaped to that shape.
+             The shape of the keys to pop. If empty (default), a single key is
+             popped and returned. If not empty, the popped key is split and
+             reshaped to the target shape.

          Returns
          -------
@@ -140,24 +168,41 @@ class split:
          Raises
          ------
          IndexError
-             If `shape` is larger than the number of keys left in the list.
-
-         Notes
-         -----
-         The keys are popped from the beginning of the list, so for example
-         ``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
+             If the list is empty.
          """
-         if shape is None:
-             shape = ()
-         elif not isinstance(shape, tuple):
-             shape = (shape,)
-         size_to_pop = math.prod(shape)
-         if size_to_pop > self._keys.size:
-             msg = f'Cannot pop {size_to_pop} keys from {self._keys.size} keys'
+         if len(self) == 0:
+             msg = 'No keys left to pop'
              raise IndexError(msg)
-         popped_keys = self._keys[:size_to_pop]
-         self._keys = self._keys[size_to_pop:]
-         return popped_keys.reshape(shape)
+         if not isinstance(shape, tuple):
+             shape = (shape,)
+         key = self._keys[self._num_used]
+         self._num_used += 1
+         if shape:
+             key = _split_shaped(key, shape)
+         return key
+
+
+ @partial(jit, static_argnums=(1,))
+ def _split_unpack(
+     key: Key[Array, '*batch'], num: int
+ ) -> tuple[Key[Array, '*batch'], ...]:
+     if key.ndim == 0:
+         keys = random.split(key, num)
+     elif key.ndim == 1:
+         keys = vmap(random.split, in_axes=(0, None), out_axes=1)(key, num)
+     return tuple(keys)
+
+
+ @partial(jit, static_argnums=(1,))
+ def _split_shaped(
+     key: Key[Array, '*batch'], shape: tuple[int, ...]
+ ) -> Key[Array, '*batch {shape}']:
+     num = math.prod(shape)
+     if key.ndim == 0:
+         keys = random.split(key, num)
+     elif key.ndim == 1:
+         keys = vmap(random.split, in_axes=(0, None))(key, num)
+     return keys.reshape(*key.shape, *shape)


  def truncated_normal_onesided(
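
A usage sketch of the reworked key container (shapes follow the docstrings above; `jax.random.key` makes a scalar typed key):

    from jax import random
    from bartz.jaxext import split

    keys = split(random.key(0), num=3)
    k = keys.pop()           # a single scalar key
    grid = keys.pop((2, 2))  # one key popped, then split and reshaped to (2, 2)
    assert len(keys) == 1

    # with a vector of keys, every pop gains a leading batch axis of size 5
    batched = split(random.split(random.key(1), 5), num=2)
    assert batched.pop((3,)).shape == (5, 3)
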
@@ -165,6 +210,8 @@ def truncated_normal_onesided(
      shape: Sequence[int],
      upper: Bool[Array, '*'],
      bound: Float32[Array, '*'],
+     *,
+     clip: bool = True,
  ) -> Float32[Array, '*']:
      """
      Sample from a one-sided truncated standard normal distribution.
@@ -179,6 +226,9 @@ def truncated_normal_onesided(
          True for (-∞, bound], False for [bound, ∞).
      bound
          The truncation boundary.
+     clip
+         Whether to clip the truncated uniform samples to (0, 1) before
+         transforming them to truncated normal. Intended for debugging purposes.

      Returns
      -------
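
For context on the `clip` branch added below: `ndtri` maps 0 and 1 to ∓∞, so the samples are nudged to the nearest representable floats strictly inside (0, 1). A standalone sketch of the `jnp.nextafter` trick:

    import jax.numpy as jnp

    u = jnp.array([0.0, 0.5, 1.0], dtype=jnp.float32)
    lo = jnp.nextafter(jnp.float32(0), jnp.float32(1))  # smallest positive float32
    hi = jnp.nextafter(jnp.float32(1), jnp.float32(0))  # largest float32 below 1
    jnp.clip(u, lo, hi)  # endpoints moved strictly inside (0, 1)
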
@@ -209,5 +259,29 @@ def truncated_normal_onesided(
      left_u = scale * (1 - u)  # ~ uniform in (0, ndtr(±bound)]
      right_u = shift + scale * u  # ~ uniform in [ndtr(∓bound), 1)
      truncated_u = jnp.where(upper ^ bound_pos, left_u, right_u)
+     if clip:
+         # on gpu the accuracy is lower and sometimes u can reach the boundaries
+         zero = jnp.zeros((), truncated_u.dtype)
+         one = jnp.ones((), truncated_u.dtype)
+         truncated_u = jnp.clip(
+             truncated_u, jnp.nextafter(zero, one), jnp.nextafter(one, zero)
+         )
      truncated_norm = ndtri(truncated_u)
      return jnp.where(bound_pos, -truncated_norm, truncated_norm)
+
+
+ def get_default_device() -> Device:
+     """Get the current default JAX device."""
+     with ensure_compile_time_eval():
+         return jnp.zeros(()).device
+
+
+ def get_device_count() -> int:
+     """Get the number of available devices on the default platform."""
+     device = get_default_device()
+     return device_count(device.platform)
+
+
+ def is_key(x: object) -> bool:
+     """Determine if `x` is a jax random key."""
+     return isinstance(x, Array) and jnp.issubdtype(x.dtype, prng_key)
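
Finally, a quick sketch of the three new helpers in use (outputs depend on the local JAX installation):

    from jax import numpy as jnp, random
    from bartz.jaxext import get_default_device, get_device_count, is_key

    get_default_device()   # e.g. CpuDevice(id=0) on a CPU-only install
    get_device_count()     # number of devices on the default platform
    is_key(random.key(0))  # True: a typed PRNG key
    is_key(jnp.zeros(3))   # False: an ordinary array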