bartz-0.6.0-py3-none-any.whl → bartz-0.8.0-py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
bartz/jaxext.py DELETED
@@ -1,423 +0,0 @@
- # bartz/src/bartz/jaxext.py
- #
- # Copyright (c) 2024-2025, Giacomo Petrillo
- #
- # This file is part of bartz.
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
-
- """Additions to jax."""
-
- import functools
- import math
- import warnings
-
- import jax
- from jax import lax, random, tree_util
- from jax import numpy as jnp
- from scipy import special
-
-
- def float_type(*args):
-     """Determine the jax floating point result type given operands/types."""
-     t = jnp.result_type(*args)
-     return jnp.sin(jnp.empty(0, t)).dtype
-
-
- def _castto(func, type):
-     @functools.wraps(func)
-     def newfunc(*args, **kw):
-         return func(*args, **kw).astype(type)
-
-     return newfunc
-
-
- class scipy:
-     """Mockup of the :external:py:mod:`scipy` module."""
-
-     class special:
-         """Mockup of the :external:py:mod:`scipy.special` module."""
-
-         @staticmethod
-         def gammainccinv(a, y):
-             """Survival function inverse of the Gamma(a, 1) distribution."""
-             a = jnp.asarray(a)
-             y = jnp.asarray(y)
-             shape = jnp.broadcast_shapes(a.shape, y.shape)
-             dtype = float_type(a.dtype, y.dtype)
-             dummy = jax.ShapeDtypeStruct(shape, dtype)
-             ufunc = _castto(special.gammainccinv, dtype)
-             return jax.pure_callback(ufunc, dummy, a, y, vmap_method='expand_dims')
-
-     class stats:
-         """Mockup of the :external:py:mod:`scipy.stats` module."""
-
-         class invgamma:
-             """Class that represents the distribution InvGamma(a, 1)."""
-
-             @staticmethod
-             def ppf(q, a):
-                 """Percentile point function."""
-                 return 1 / scipy.special.gammainccinv(a, q)
-
-
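The `scipy` class above mirrors only the small slice of SciPy's API that bartz needs, routing the actual computation through `jax.pure_callback` so it stays usable inside jitted code. A minimal sketch of how the removed helper could be called from bartz 0.6.0 (the numbers are made up for illustration):

    import jax
    from jax import numpy as jnp
    from bartz.jaxext import scipy as jaxext_scipy  # the mockup above, not the real scipy

    @jax.jit
    def invgamma_quantiles(q, a):
        # InvGamma(a, 1) quantiles via the survival-function inverse of Gamma(a, 1)
        return jaxext_scipy.stats.invgamma.ppf(q, a)

    print(invgamma_quantiles(jnp.array([0.1, 0.5, 0.9]), 2.0))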
- def vmap_nodoc(fun, *args, **kw):
-     """
-     Acts like `jax.vmap` but preserves the docstring of the function unchanged.
-
-     This is useful if the docstring already takes into account that the
-     arguments have additional axes due to vmap.
-     """
-     doc = fun.__doc__
-     fun = jax.vmap(fun, *args, **kw)
-     fun.__doc__ = doc
-     return fun
-
-
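A short sketch of why `vmap_nodoc` exists: when a docstring already describes the batched axes, the docstring generated by `jax.vmap` is unwanted, so the original one is put back (the example function is hypothetical):

    from jax import numpy as jnp
    from bartz.jaxext import vmap_nodoc

    def rowdot(x, y):
        """Dot products of batches of vectors `x` and `y` along the leading axis."""
        return x @ y

    batched_rowdot = vmap_nodoc(rowdot)  # jax.vmap(rowdot), but with rowdot's docstring kept
    assert batched_rowdot.__doc__ == rowdot.__doc__
    batched_rowdot(jnp.ones((5, 3)), jnp.ones((5, 3)))  # shape (5,)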
- def huge_value(x):
-     """
-     Return the maximum value that can be stored in `x`.
-
-     Parameters
-     ----------
-     x : array
-         A numerical numpy or jax array.
-
-     Returns
-     -------
-     maxval : scalar
-         The maximum value allowed by `x`'s type (+inf for floats).
-     """
-     if jnp.issubdtype(x.dtype, jnp.integer):
-         return jnp.iinfo(x.dtype).max
-     else:
-         return jnp.inf
-
-
- def minimal_unsigned_dtype(value):
-     """Return the smallest unsigned integer dtype that can represent `value`."""
-     if value < 2**8:
-         return jnp.uint8
-     if value < 2**16:
-         return jnp.uint16
-     if value < 2**32:
-         return jnp.uint32
-     return jnp.uint64
-
-
- def signed_to_unsigned(int_dtype):
-     """
-     Map a signed integer type to its unsigned counterpart.
-
-     Unsigned types are passed through.
-     """
-     assert jnp.issubdtype(int_dtype, jnp.integer)
-     if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
-         return int_dtype
-     if int_dtype == jnp.int8:
-         return jnp.uint8
-     if int_dtype == jnp.int16:
-         return jnp.uint16
-     if int_dtype == jnp.int32:
-         return jnp.uint32
-     if int_dtype == jnp.int64:
-         return jnp.uint64
-
-
- def ensure_unsigned(x):
-     """If x has signed integer type, cast it to the unsigned dtype of the same size."""
-     return x.astype(signed_to_unsigned(x.dtype))
-
-
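For instance, with the three dtype helpers above (the values are chosen only for illustration):

    from jax import numpy as jnp
    from bartz.jaxext import ensure_unsigned, minimal_unsigned_dtype

    minimal_unsigned_dtype(300)       # uint16: 300 needs more than 8 bits but fewer than 16
    x = jnp.arange(5, dtype=jnp.int16)
    ensure_unsigned(x).dtype          # uint16, same width as the signed input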
- @functools.partial(jax.jit, static_argnums=(1,))
- def unique(x, size, fill_value):
-     """
-     Restricted version of `jax.numpy.unique` that uses less memory.
-
-     Parameters
-     ----------
-     x : 1d array
-         The input array.
-     size : int
-         The length of the output.
-     fill_value : scalar
-         The value to fill the output with if `size` is greater than the number
-         of unique values in `x`.
-
-     Returns
-     -------
-     out : array (size,)
-         The unique values in `x`, sorted, and right-padded with `fill_value`.
-     actual_length : int
-         The number of used values in `out`.
-     """
-     if x.size == 0:
-         return jnp.full(size, fill_value, x.dtype), 0
-     if size == 0:
-         return jnp.empty(0, x.dtype), 0
-     x = jnp.sort(x)
-
-     def loop(carry, x):
-         i_out, i_in, last, out = carry
-         i_out = jnp.where(x == last, i_out, i_out + 1)
-         out = out.at[i_out].set(x)
-         return (i_out, i_in + 1, x, out), None
-
-     carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
-     (actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
-     return out, actual_length + 1
-
-
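A sketch of the padding behaviour described in the docstring (the input values are illustrative):

    from jax import numpy as jnp
    from bartz.jaxext import unique

    x = jnp.array([3, 1, 3, 2])
    out, n = unique(x, 6, -1)
    # out -> [1, 2, 3, -1, -1, -1] (sorted, right-padded with the fill value), n -> 3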
- def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
-     """
-     Batch a function such that each batch is smaller than a threshold.
-
-     Parameters
-     ----------
-     func : callable
-         A jittable function with positional arguments only, with inputs and
-         outputs pytrees of arrays.
-     max_io_nbytes : int
-         The maximum number of input + output bytes in each batch (excluding
-         unbatched arguments.)
-     in_axes : pytree of int or None, default 0
-         A tree matching the structure of the function input, indicating along
-         which axes each array should be batched. If a single integer, it is
-         used for all arrays. A `None` axis indicates to not batch an argument.
-     out_axes : pytree of ints, default 0
-         The same for outputs (but non-batching is not allowed).
-     return_nbatches : bool, default False
-         If True, the number of batches is returned as a second output.
-
-     Returns
-     -------
-     batched_func : callable
-         A function with the same signature as `func`, but that processes the
-         input and output in batches in a loop.
-     """
-
-     def expand_axes(axes, tree):
-         if isinstance(axes, int):
-             return tree_util.tree_map(lambda _: axes, tree)
-         return tree_util.tree_map(lambda _, axis: axis, tree, axes)
-
-     def check_no_nones(axes, tree):
-         def check_not_none(_, axis):
-             assert axis is not None
-
-         tree_util.tree_map(check_not_none, tree, axes)
-
-     def extract_size(axes, tree):
-         def get_size(x, axis):
-             if axis is None:
-                 return None
-             else:
-                 return x.shape[axis]
-
-         sizes = tree_util.tree_map(get_size, tree, axes)
-         sizes, _ = tree_util.tree_flatten(sizes)
-         assert all(s == sizes[0] for s in sizes)
-         return sizes[0]
-
-     def sum_nbytes(tree):
-         def nbytes(x):
-             return math.prod(x.shape) * x.dtype.itemsize
-
-         return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
-
-     def next_divisor_small(dividend, min_divisor):
-         for divisor in range(min_divisor, int(math.sqrt(dividend)) + 1):
-             if dividend % divisor == 0:
-                 return divisor
-         return dividend
-
-     def next_divisor_large(dividend, min_divisor):
-         max_inv_divisor = dividend // min_divisor
-         for inv_divisor in range(max_inv_divisor, 0, -1):
-             if dividend % inv_divisor == 0:
-                 return dividend // inv_divisor
-         return dividend
-
-     def next_divisor(dividend, min_divisor):
-         if dividend == 0:
-             return min_divisor
-         if min_divisor * min_divisor <= dividend:
-             return next_divisor_small(dividend, min_divisor)
-         return next_divisor_large(dividend, min_divisor)
-
-     def pull_nonbatched(axes, tree):
-         def pull_nonbatched(x, axis):
-             if axis is None:
-                 return None
-             else:
-                 return x
-
-         return tree_util.tree_map(pull_nonbatched, tree, axes), tree
-
-     def push_nonbatched(axes, tree, original_tree):
-         def push_nonbatched(original_x, x, axis):
-             if axis is None:
-                 return original_x
-             else:
-                 return x
-
-         return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)
-
-     def move_axes_out(axes, tree):
-         def move_axis_out(x, axis):
-             return jnp.moveaxis(x, axis, 0)
-
-         return tree_util.tree_map(move_axis_out, tree, axes)
-
-     def move_axes_in(axes, tree):
-         def move_axis_in(x, axis):
-             return jnp.moveaxis(x, 0, axis)
-
-         return tree_util.tree_map(move_axis_in, tree, axes)
-
-     def batch(tree, nbatches):
-         def batch(x):
-             return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
-
-         return tree_util.tree_map(batch, tree)
-
-     def unbatch(tree):
-         def unbatch(x):
-             return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
-
-         return tree_util.tree_map(unbatch, tree)
-
-     def check_same(tree1, tree2):
-         def check_same(x1, x2):
-             assert x1.shape == x2.shape
-             assert x1.dtype == x2.dtype
-
-         tree_util.tree_map(check_same, tree1, tree2)
-
-     initial_in_axes = in_axes
-     initial_out_axes = out_axes
-
-     @jax.jit
-     @functools.wraps(func)
-     def batched_func(*args):
-         example_result = jax.eval_shape(func, *args)
-
-         in_axes = expand_axes(initial_in_axes, args)
-         out_axes = expand_axes(initial_out_axes, example_result)
-         check_no_nones(out_axes, example_result)
-
-         size = extract_size((in_axes, out_axes), (args, example_result))
-
-         args, nonbatched_args = pull_nonbatched(in_axes, args)
-
-         total_nbytes = sum_nbytes((args, example_result))
-         min_nbatches = total_nbytes // max_io_nbytes + bool(
-             total_nbytes % max_io_nbytes
-         )
-         min_nbatches = max(1, min_nbatches)
-         nbatches = next_divisor(size, min_nbatches)
-         assert 1 <= nbatches <= max(1, size)
-         assert size % nbatches == 0
-         assert total_nbytes % nbatches == 0
-
-         batch_nbytes = total_nbytes // nbatches
-         if batch_nbytes > max_io_nbytes:
-             assert size == nbatches
-             warnings.warn(
-                 f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
-             )
-
-         def loop(_, args):
-             args = move_axes_in(in_axes, args)
-             args = push_nonbatched(in_axes, args, nonbatched_args)
-             result = func(*args)
-             result = move_axes_out(out_axes, result)
-             return None, result
-
-         args = move_axes_out(in_axes, args)
-         args = batch(args, nbatches)
-         _, result = lax.scan(loop, None, args)
-         result = unbatch(result)
-         result = move_axes_in(out_axes, result)
-
-         check_same(example_result, result)
-
-         if return_nbatches:
-             return result, nbatches
-         return result
-
-     return batched_func
-
-
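A sketch of how the removed `autobatch` wrapper could be used; the function, shapes, and 1 MiB budget are made-up values chosen so that the call is split along the leading axis rather than run in one piece:

    from jax import numpy as jnp
    from bartz.jaxext import autobatch

    def rowsums(x):
        return x.sum(axis=1)

    # Keep each batch's input + output under roughly 1 MiB.
    batched_rowsums = autobatch(rowsums, max_io_nbytes=2**20)
    x = jnp.ones((10_000, 100))   # about 4 MB of float32 input
    out = batched_rowsums(x)      # same values as rowsums(x), shape (10_000,)

Because the number of batches must divide the batched axis exactly, `batched_func` picks the next divisor of the axis length above the minimum implied by the byte budget, and it only warns (rather than failing) if even one-element batches exceed `max_io_nbytes`.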
- class split:
-     """
-     Split a key into `num` keys.
-
-     Parameters
-     ----------
-     key : jax.dtypes.prng_key array
-         The key to split.
-     num : int
-         The number of keys to split into.
-     """
-
-     def __init__(self, key, num=2):
-         self._keys = random.split(key, num)
-
-     def __len__(self):
-         return self._keys.size
-
-     def pop(self, shape=None):
-         """
-         Pop one or more keys from the list.
-
-         Parameters
-         ----------
-         shape : int or tuple of int, optional
-             The shape of the keys to pop. If `None`, a single key is popped.
-             If an integer, that many keys are popped. If a tuple, the keys are
-             reshaped to that shape.
-
-         Returns
-         -------
-         keys : jax.dtypes.prng_key array
-             The popped keys.
-
-         Raises
-         ------
-         IndexError
-             If `shape` is larger than the number of keys left in the list.
-
-         Notes
-         -----
-         The keys are popped from the beginning of the list, so for example
-         ``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
-         """
-         if shape is None:
-             shape = ()
-         elif not isinstance(shape, tuple):
-             shape = (shape,)
-         size_to_pop = math.prod(shape)
-         if size_to_pop > self._keys.size:
-             raise IndexError(
-                 f'Cannot pop {size_to_pop} keys from {self._keys.size} keys'
-             )
-         popped_keys = self._keys[:size_to_pop]
-         self._keys = self._keys[size_to_pop:]
-         return popped_keys.reshape(shape)
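Finally, a sketch of the removed `split` container: it wraps `jax.random.split` and hands out the pre-split keys on demand (the seed and counts are illustrative):

    from jax import random
    from bartz.jaxext import split

    keys = split(random.key(0), num=3)
    k1 = keys.pop()      # a single key
    rest = keys.pop(2)   # the remaining two keys, shape (2,)
    len(keys)            # 0; popping again would raise IndexError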