bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  # bartz/src/bartz/jaxext/_autobatch.py
  #
- # Copyright (c) 2025, Giacomo Petrillo
+ # Copyright (c) 2025-2026, The Bartz Contributors
  #
  # This file is part of bartz.
  #
@@ -26,16 +26,23 @@
 
  import math
  from collections.abc import Callable
- from functools import wraps
+ from functools import partial, wraps
  from warnings import warn
 
- from jax import eval_shape, jit
+ from jax.typing import DTypeLike
+
+ try:
+     from numpy.lib.array_utils import normalize_axis_index  # numpy 2
+ except ImportError:
+     from numpy.core.numeric import normalize_axis_index  # numpy 1
+
+ from jax import ShapeDtypeStruct, eval_shape, jit
  from jax import numpy as jnp
  from jax.lax import scan
  from jax.tree import flatten as tree_flatten
  from jax.tree import map as tree_map
  from jax.tree import reduce as tree_reduce
- from jaxtyping import PyTree
+ from jaxtyping import Array, PyTree, Shaped
 
 
  def expand_axes(axes, tree):
@@ -47,14 +54,43 @@ def expand_axes(axes, tree):
      return tree_map(expand_axis, axes, tree, is_leaf=lambda x: x is None)
 
 
+ def normalize_axes(
+     axes: PyTree[int | None, ' T'], tree: PyTree[Array, ' T']
+ ) -> PyTree[int | None, ' T']:
+     """Normalize axes to be non-negative and valid for the corresponding arrays in the tree."""
+
+     def normalize_axis(axis: int | None, x: Array) -> int | None:
+         if axis is None:
+             return None
+         else:
+             return normalize_axis_index(axis, len(x.shape))
+
+     return tree_map(normalize_axis, axes, tree, is_leaf=lambda x: x is None)
+
+
  def check_no_nones(axes, tree):
      def check_not_none(_, axis):
          assert axis is not None
 
-     tree_map(check_not_none, tree, axes)
+     tree_map(check_not_none, tree, axes, is_leaf=lambda x: x is None)
+
+
+ def remove_axis(
+     x: PyTree[ShapeDtypeStruct, ' T'], axis: PyTree[int, ' T'], ufunc: jnp.ufunc
+ ) -> PyTree[ShapeDtypeStruct, ' T']:
+     """Remove an axis from dummy arrays and change the dtype to the reduction dtype."""
+
+     def remove_axis(x: ShapeDtypeStruct, axis: int) -> ShapeDtypeStruct:
+         new_shape = x.shape[:axis] + x.shape[axis + 1 :]
+         new_dtype = reduction_dtype(ufunc, x.dtype)
+         return ShapeDtypeStruct(new_shape, new_dtype)
+
+     return tree_map(remove_axis, x, axis)
 
 
  def extract_size(axes, tree):
+     """Get the size of each array in `tree` along the corresponding axis in `axes`, check that the sizes agree, and return the common size."""
+
      def get_size(x, axis):
          if axis is None:
              return None
@@ -90,6 +126,7 @@ def next_divisor_large(dividend, min_divisor):
 
 
  def next_divisor(dividend, min_divisor):
+     """Return a divisor >= min_divisor such that dividend % divisor == 0."""
      if dividend == 0:
          return min_divisor
      if min_divisor * min_divisor <= dividend:
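For reference, this divisor search is what picks the batch count in `autobatch` below. A quick sketch of the arithmetic (hypothetical numbers; `next_divisor` is module-private):

    # with 294_912 total I/O bytes and a 65_536-byte limit, the ceiling
    # division gives min_nbatches = 5; the smallest divisor of a batched
    # axis of size 1024 that is >= 5 is 8, so each of the 8 batches
    # processes 1024 // 8 == 128 elements
    assert next_divisor(1024, 5) == 8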
@@ -131,20 +168,73 @@ def move_axes_in(axes, tree):
      return tree_map(move_axis_in, tree, axes)
 
 
- def batch(tree, nbatches):
+ def batch(tree: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']:
+     """Split the first axis into two axes, the first of size `nbatches`."""
+
      def batch(x):
-         return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
+         return x.reshape(nbatches, x.shape[0] // nbatches, *x.shape[1:])
 
      return tree_map(batch, tree)
 
 
- def unbatch(tree):
+ def unbatch(tree: PyTree[Array, ' T']) -> PyTree[Array, ' T']:
+     """Merge the first two axes into a single axis."""
+
      def unbatch(x):
-         return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
+         return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
 
      return tree_map(unbatch, tree)
 
 
+ def reduce(
+     ufunc: jnp.ufunc,
+     x: PyTree[Array, ' T'],
+     axes: PyTree[int, ' T'],
+     initial: PyTree[Array, ' T'] | None,
+ ) -> PyTree[Array, ' T']:
+     """Reduce each array in `x` along the axes in `axes`, starting from `initial`, using `ufunc.reduce`."""
+     if initial is None:
+
+         def reduce(x: Array, axis: int) -> Array:
+             return ufunc.reduce(x, axis=axis)
+
+         return tree_map(reduce, x, axes)
+
+     else:
+
+         def reduce(x: Array, initial: Array, axis: int) -> Array:
+             reduced = ufunc.reduce(x, axis=axis)
+             return ufunc(initial, reduced)
+
+         return tree_map(reduce, x, initial, axes)
+
+
+ def identity(
+     ufunc: jnp.ufunc, x: PyTree[ShapeDtypeStruct, ' T']
+ ) -> PyTree[Array, ' T']:
+     """Get the identity element of `ufunc` for each array in `x`."""
+
+     def identity(x: ShapeDtypeStruct) -> Array:
+         identity = identity_for(ufunc, x.dtype)
+         return jnp.broadcast_to(identity, x.shape)
+
+     return tree_map(identity, x)
+
+
+ def reduction_dtype(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> DTypeLike:
+     """Return the output dtype of a reduction with `ufunc` on inputs of dtype `input_dtype`."""
+     return ufunc.reduce(jnp.empty(1, input_dtype)).dtype
+
+
+ def identity_for(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> Shaped[Array, '']:
+     """Return the identity of `ufunc` as an array scalar with the right dtype."""
+     # get the output dtype from the input dtype, e.g., int8 is accumulated to int32
+     dtype = reduction_dtype(ufunc, input_dtype)
+
+     # return as an explicitly typed array
+     return jnp.array(ufunc.identity, dtype)
+
+
  def check_same(tree1, tree2):
      def check_same(x1, x2):
          assert x1.shape == x2.shape
@@ -153,12 +243,20 @@ def check_same(tree1, tree2):
      tree_map(check_same, tree1, tree2)
 
 
+ class NotDefined:
+     pass
+
+
  def autobatch(
      func: Callable,
      max_io_nbytes: int,
      in_axes: PyTree[int | None] = 0,
      out_axes: PyTree[int] = 0,
+     *,
      return_nbatches: bool = False,
+     reduce_ufunc: jnp.ufunc | None = None,
+     warn_on_overflow: bool = True,
+     result_shape_dtype: PyTree[ShapeDtypeStruct] = NotDefined,
  ) -> Callable:
      """
      Batch a function such that each batch is smaller than a threshold.
@@ -179,60 +277,168 @@ def autobatch(
          The same for outputs (but non-batching is not allowed).
      return_nbatches
          If True, the number of batches is returned as a second output.
+     reduce_ufunc
+         Function used to reduce the output along the batched axis (e.g.,
+         `jax.numpy.add`).
+     warn_on_overflow
+         If True, a warning is raised if the memory limit could not be
+         respected.
+     result_shape_dtype
+         A pytree of dummy arrays matching the expected output. If not
+         provided, the function is traced an additional time to determine the
+         output structure.
 
      Returns
      -------
      A function with the same signature as `func`, save for the return value if `return_nbatches`.
+
+     Notes
+     -----
+     Unless `return_nbatches` or `reduce_ufunc` are set, `autobatch` at given
+     arguments is idempotent. Furthermore, `autobatch` can be applied multiple
+     times with the same `max_io_nbytes` limit to batch over multiple axes; in
+     that case it won't unnecessarily loop over additional axes if one or more
+     outer `autobatch` applications are already sufficient.
+
+     To account for memory used by intermediate values: assuming all
+     intermediate values have a size that scales linearly with the batched
+     axis, if the total size of the batched inputs/outputs is ``batched_size *
+     core_io_size`` and the intermediate values have size ``batched_size *
+     core_int_size``, divide `max_io_nbytes` by ``(1 + core_int_size /
+     core_io_size)`` to take the intermediates into account.
      """
-     initial_in_axes = in_axes
-     initial_out_axes = out_axes
 
      @jit
      @wraps(func)
-     def batched_func(*args):
-         example_result = eval_shape(func, *args)
-
-         in_axes = expand_axes(initial_in_axes, args)
-         out_axes = expand_axes(initial_out_axes, example_result)
-         check_no_nones(out_axes, example_result)
-
-         size = extract_size((in_axes, out_axes), (args, example_result))
+     def autobatch_wrapper(*args):
+         return batched_func(
+             func,
+             max_io_nbytes,
+             in_axes,
+             out_axes,
+             return_nbatches,
+             reduce_ufunc,
+             warn_on_overflow,
+             result_shape_dtype,
+             args,
+         )
 
-         args, nonbatched_args = pull_nonbatched(in_axes, args)
+     return autobatch_wrapper
 
-         total_nbytes = sum_nbytes((args, example_result))
-         min_nbatches = total_nbytes // max_io_nbytes + bool(
-             total_nbytes % max_io_nbytes
-         )
-         min_nbatches = max(1, min_nbatches)
-         nbatches = next_divisor(size, min_nbatches)
-         assert 1 <= nbatches <= max(1, size)
-         assert size % nbatches == 0
-         assert total_nbytes % nbatches == 0
-
-         batch_nbytes = total_nbytes // nbatches
-         if batch_nbytes > max_io_nbytes:
-             assert size == nbatches
-             msg = f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
-             warn(msg)
-
-         def loop(_, args):
-             args = move_axes_in(in_axes, args)
-             args = push_nonbatched(in_axes, args, nonbatched_args)
-             result = func(*args)
-             result = move_axes_out(out_axes, result)
-             return None, result
 
+ def batched_func(
+     func: Callable,
+     max_io_nbytes: int,
+     in_axes: PyTree[int | None],
+     out_axes: PyTree[int],
+     return_nbatches: bool,
+     reduce_ufunc: jnp.ufunc | None,
+     warn_on_overflow: bool,
+     result_shape_dtype: PyTree[ShapeDtypeStruct] | NotDefined,
+     args: tuple[PyTree[Array], ...],
+ ) -> PyTree[Array]:
+     """Implement the wrapper used in `autobatch`."""
+     # determine the output structure of the function
+     if result_shape_dtype is NotDefined:
+         example_result = eval_shape(func, *args)
+     else:
+         example_result = result_shape_dtype
+
+     # expand the axes pytrees if they are prefixes
+     in_axes = expand_axes(in_axes, args)
+     out_axes = expand_axes(out_axes, example_result)
+     check_no_nones(out_axes, example_result)
+
+     # check that the axes are valid
+     in_axes = normalize_axes(in_axes, args)
+     out_axes = normalize_axes(out_axes, example_result)
+
+     # get the size of the batched axis
+     size = extract_size((in_axes, out_axes), (args, example_result))
+
+     # split the arguments into batched and non-batched
+     original_args = args
+     args, nonbatched_args = pull_nonbatched(in_axes, args)
+
+     # determine the number of batches needed to respect the memory limit
+     total_nbytes = sum_nbytes((args, example_result))
+     min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
+     min_nbatches = max(1, min_nbatches)
+     nbatches = next_divisor(size, min_nbatches)
+     assert 1 <= nbatches <= max(1, size)
+     assert size % nbatches == 0
+     assert total_nbytes % nbatches == 0
+
+     # warn if the memory limit could not be respected
+     batch_nbytes = total_nbytes // nbatches
+     if batch_nbytes > max_io_nbytes and warn_on_overflow:
+         assert size == nbatches
+         msg = f'batch_nbytes = {batch_nbytes:_} > max_io_nbytes = {max_io_nbytes:_}'
+         warn(msg)
+
+     # squeeze out the output dims that will be reduced
+     if reduce_ufunc is not None:
+         example_result = remove_axis(example_result, out_axes, reduce_ufunc)
+
+     if nbatches > 1:
+         # prepare the arguments for looping
          args = move_axes_out(in_axes, args)
          args = batch(args, nbatches)
-         _, result = scan(loop, None, args)
-         result = unbatch(result)
-         result = move_axes_in(out_axes, result)
-
-         check_same(example_result, result)
 
-         if return_nbatches:
-             return result, nbatches
-         return result
+         # prepare the carry for the reduction
+         if reduce_ufunc is None:
+             initial = None
+         else:
+             initial = identity(reduce_ufunc, example_result)
+
+         # loop and invoke the function in batches
+         loop = partial(
+             batching_loop,
+             func=func,
+             nonbatched_args=nonbatched_args,
+             in_axes=in_axes,
+             out_axes=out_axes,
+             reduce_ufunc=reduce_ufunc,
+         )
+         reduced_result, result = scan(loop, initial, args)
 
-     return batched_func
+         # remove the auxiliary batching axis and reverse the transposition
+         if reduce_ufunc is None:
+             assert reduced_result is None
+             result = unbatch(result)
+             result = move_axes_in(out_axes, result)
+         else:
+             assert result is None
+             result = reduced_result
+
+     # trivial case: no batching needed
+     else:
+         result = func(*original_args)
+         if reduce_ufunc is not None:
+             result = reduce(reduce_ufunc, result, out_axes, None)
+
+     check_same(example_result, result)
+
+     if return_nbatches:
+         return result, nbatches
+     return result
+
+
+ def batching_loop(
+     initial, args, *, func, nonbatched_args, in_axes, out_axes, reduce_ufunc
+ ):
+     """Implement the batching loop in `autobatch`."""
+     # evaluate the function
+     args = move_axes_in(in_axes, args)
+     args = push_nonbatched(in_axes, args, nonbatched_args)
+     result = func(*args)
+
+     # unreduced case: transpose for concatenation and return
+     if reduce_ufunc is None:
+         result = move_axes_out(out_axes, result)
+         return None, result
+
+     # reduced case: reduce starting from `initial`
+     else:
+         reduced_result = reduce(reduce_ufunc, result, out_axes, initial)
+         return reduced_result, None
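The new `reduce_ufunc`, `warn_on_overflow`, and `result_shape_dtype` options compose as follows. A minimal usage sketch (hypothetical function, shapes, and byte limit; assuming `autobatch` is re-exported from `bartz.jaxext`):

    import jax.numpy as jnp
    from bartz.jaxext import autobatch

    def outer(x):
        # (n, p) -> (n, p, p): one outer product per row
        return x[:, :, None] * x[:, None, :]

    x = jnp.ones((1024, 8))

    # Without reduce_ufunc, batch outputs are concatenated along axis 0.
    f = autobatch(outer, max_io_nbytes=2**16, return_nbatches=True)
    out, nbatches = f(x)  # out.shape == (1024, 8, 8)

    # With reduce_ufunc, each batch is folded into a running reduction as it
    # is produced, so the full (1024, 8, 8) output never materializes at once.
    g = autobatch(outer, max_io_nbytes=2**16, reduce_ufunc=jnp.add)
    total = g(x)  # total.shape == (8, 8), the sum over axis 0

Following the docstring's note on intermediate values: if a function's temporaries are about as large as its batched inputs and outputs (``core_int_size == core_io_size``), pass half the intended byte budget as `max_io_nbytes`.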
@@ -1,6 +1,6 @@
  # bartz/src/bartz/jaxext/scipy/__init__.py
  #
- # Copyright (c) 2025, Giacomo Petrillo
+ # Copyright (c) 2025, The Bartz Contributors
  #
  # This file is part of bartz.
  #
@@ -1,6 +1,6 @@
  # bartz/src/bartz/jaxext/scipy/special.py
  #
- # Copyright (c) 2025, Giacomo Petrillo
+ # Copyright (c) 2025, The Bartz Contributors
  #
  # This file is part of bartz.
  #
@@ -26,7 +26,7 @@
 
  from functools import wraps
 
- from jax import ShapeDtypeStruct, pure_callback
+ from jax import ShapeDtypeStruct, jit, pure_callback
  from jax import numpy as jnp
  from scipy.special import gammainccinv as scipy_gammainccinv
 
@@ -45,10 +45,9 @@ def _castto(func, dtype):
      return newfunc
 
 
+ @jit
  def gammainccinv(a, y):
      """Survival function inverse of the Gamma(a, 1) distribution."""
-     a = jnp.asarray(a)
-     y = jnp.asarray(y)
      shape = jnp.broadcast_shapes(a.shape, y.shape)
      dtype = _float_type(a.dtype, y.dtype)
      dummy = ShapeDtypeStruct(shape, dtype)
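Since `gammainccinv` is now wrapped in `jit`, its arguments are traced as JAX arrays before the body runs, which is why the explicit `jnp.asarray` coercions could be dropped: tracers already expose `.shape` and `.dtype`, and the `pure_callback` into SciPy still executes under `jit`. A minimal usage sketch (hypothetical values; module path taken from the file header above):

    import jax.numpy as jnp
    from bartz.jaxext.scipy.special import gammainccinv

    # x such that the Gamma(3, 1) survival function at x equals 0.05
    x = gammainccinv(jnp.float32(3.0), jnp.float32(0.05))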
@@ -1,6 +1,6 @@
  # bartz/src/bartz/jaxext/scipy/stats.py
  #
- # Copyright (c) 2025, Giacomo Petrillo
+ # Copyright (c) 2025, The Bartz Contributors
  #
  # This file is part of bartz.
  #