bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
|
@@ -0,0 +1,444 @@
|
|
|
1
|
+
# bartz/src/bartz/jaxext/_autobatch.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2025-2026, The Bartz Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Implementation of `autobatch`."""
|
|
26
|
+
|
|
27
|
+
import math
|
|
28
|
+
from collections.abc import Callable
|
|
29
|
+
from functools import partial, wraps
|
|
30
|
+
from warnings import warn
|
|
31
|
+
|
|
32
|
+
from jax.typing import DTypeLike
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
from numpy.lib.array_utils import normalize_axis_index # numpy 2
|
|
36
|
+
except ImportError:
|
|
37
|
+
from numpy.core.numeric import normalize_axis_index # numpy 1
|
|
38
|
+
|
|
39
|
+
from jax import ShapeDtypeStruct, eval_shape, jit
|
|
40
|
+
from jax import numpy as jnp
|
|
41
|
+
from jax.lax import scan
|
|
42
|
+
from jax.tree import flatten as tree_flatten
|
|
43
|
+
from jax.tree import map as tree_map
|
|
44
|
+
from jax.tree import reduce as tree_reduce
|
|
45
|
+
from jaxtyping import Array, PyTree, Shaped
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def expand_axes(axes, tree):
    """Broadcast the (possibly prefix) `axes` spec onto the full structure of `tree`."""

    def _fill(ax, subtree):
        # replicate a single axis spec over every leaf of the matching subtree
        return tree_map(lambda _: ax, subtree)

    # `is_leaf` makes None a valid axis spec instead of an empty pytree
    return tree_map(_fill, axes, tree, is_leaf=lambda v: v is None)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def normalize_axes(
|
|
58
|
+
axes: PyTree[int | None, ' T'], tree: PyTree[Array, ' T']
|
|
59
|
+
) -> PyTree[int | None, ' T']:
|
|
60
|
+
"""Normalize axes to be non-negative and valid for the corresponding arrays in the tree."""
|
|
61
|
+
|
|
62
|
+
def normalize_axis(axis: int | None, x: Array) -> int | None:
|
|
63
|
+
if axis is None:
|
|
64
|
+
return None
|
|
65
|
+
else:
|
|
66
|
+
return normalize_axis_index(axis, len(x.shape))
|
|
67
|
+
|
|
68
|
+
return tree_map(normalize_axis, axes, tree, is_leaf=lambda x: x is None)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def check_no_nones(axes, tree):
    """Assert that no axis spec matching a leaf of `tree` is None."""

    def _ensure(_, ax):
        assert ax is not None

    tree_map(_ensure, tree, axes, is_leaf=lambda v: v is None)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def remove_axis(
    x: PyTree[ShapeDtypeStruct, ' T'], axis: PyTree[int, ' T'], ufunc: jnp.ufunc
) -> PyTree[ShapeDtypeStruct, ' T']:
    """Remove an axis from dummy arrays and change the type to reduction type."""

    def _drop(struct: ShapeDtypeStruct, ax: int) -> ShapeDtypeStruct:
        # drop dimension `ax` (already normalized to be non-negative)
        shape = struct.shape[:ax] + struct.shape[ax + 1 :]
        # e.g. an int8 reduction may accumulate into a wider dtype
        dtype = reduction_dtype(ufunc, struct.dtype)
        return ShapeDtypeStruct(shape, dtype)

    return tree_map(_drop, x, axis)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def extract_size(axes, tree):
    """Get the size of each array in tree at the axis in axes, check they are equal and return it.

    Parameters
    ----------
    axes
        Pytree of axis indices (or None for non-batched leaves), matching `tree`.
    tree
        Pytree of arrays (or array-like objects exposing `.shape`).

    Returns
    -------
    The common size of all batched axes.

    Raises
    ------
    ValueError
        If no leaf is batched (all axes are None), or if the batched axes
        disagree on their size.
    """

    def get_size(x, axis):
        if axis is None:
            return None
        return x.shape[axis]

    sizes = tree_map(get_size, tree, axes)
    # None leaves are dropped by pytree flattening, leaving only the sizes
    sizes, _ = tree_flatten(sizes)
    # explicit errors instead of `assert`/IndexError: this validates user
    # input and must survive running under `python -O`
    if not sizes:
        msg = 'no batched axes: cannot determine the batch size'
        raise ValueError(msg)
    if any(s != sizes[0] for s in sizes):
        msg = f'batched axes have mismatched sizes: {sizes}'
        raise ValueError(msg)
    return sizes[0]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def sum_nbytes(tree):
    """Total number of bytes occupied by all arrays in `tree`."""

    def _nbytes(arr):
        # element count times per-element size; works for dummy arrays too
        return math.prod(arr.shape) * arr.dtype.itemsize

    return tree_reduce(lambda acc, arr: acc + _nbytes(arr), tree, 0)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def next_divisor_small(dividend, min_divisor):
    """Find the smallest divisor >= `min_divisor` by scanning upward.

    Intended for the regime min_divisor <= sqrt(dividend); falls back to
    `dividend` itself if no divisor is found below the square root.
    """
    limit = int(math.sqrt(dividend))
    candidate = min_divisor
    while candidate <= limit:
        if dividend % candidate == 0:
            return candidate
        candidate += 1
    return dividend
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def next_divisor_large(dividend, min_divisor):
    """Find the smallest divisor >= `min_divisor` by scanning cofactors downward.

    Intended for the regime min_divisor > sqrt(dividend); falls back to
    `dividend` itself if no cofactor divides evenly.
    """
    cofactor = dividend // min_divisor
    while cofactor > 0:
        if dividend % cofactor == 0:
            return dividend // cofactor
        cofactor -= 1
    return dividend
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def next_divisor(dividend, min_divisor):
    """Return a divisor >= `min_divisor` such that ``dividend % divisor == 0``.

    For ``dividend == 0`` every integer divides, so `min_divisor` is returned.
    """
    if dividend == 0:
        return min_divisor
    # pick the scan direction that terminates quickly for this regime
    use_small = min_divisor * min_divisor <= dividend
    finder = next_divisor_small if use_small else next_divisor_large
    return finder(dividend, min_divisor)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def pull_nonbatched(axes, tree):
    """Blank out non-batched leaves (axis None) and also return the untouched tree.

    Returns a pair ``(batched_only, original_tree)``.
    """

    def _keep_batched(leaf, ax):
        return None if ax is None else leaf

    return tree_map(_keep_batched, tree, axes), tree
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def push_nonbatched(axes, tree, original_tree):
    """Restore non-batched leaves (axis None) from `original_tree` into `tree`.

    Inverse of `pull_nonbatched`.
    """

    def _merge(original_leaf, leaf, ax):
        return original_leaf if ax is None else leaf

    # original_tree goes first: it carries the full structure, while `tree`
    # may contain None placeholders at non-batched positions
    return tree_map(_merge, original_tree, tree, axes)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def move_axes_out(axes, tree):
    """Bring the batched axis of every leaf to the front (position 0)."""
    return tree_map(lambda leaf, ax: jnp.moveaxis(leaf, ax, 0), tree, axes)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def move_axes_in(axes, tree):
    """Send the leading axis of every leaf back to its batched position.

    Inverse of `move_axes_out`.
    """
    return tree_map(lambda leaf, ax: jnp.moveaxis(leaf, 0, ax), tree, axes)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def batch(tree: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']:
    """Split the first axis into two axes, the first of size `nbatches`.

    Assumes the leading axis size is divisible by `nbatches`.
    """

    def _split(x):
        per_batch = x.shape[0] // nbatches
        return x.reshape(nbatches, per_batch, *x.shape[1:])

    return tree_map(_split, tree)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def unbatch(tree: PyTree[Array, ' T']) -> PyTree[Array, ' T']:
    """Merge the first two axes into a single axis.

    Inverse of `batch`.
    """

    def _flatten_lead(x):
        n, m, *rest = x.shape
        return x.reshape(n * m, *rest)

    return tree_map(_flatten_lead, tree)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def reduce(
    ufunc: jnp.ufunc,
    x: PyTree[Array, ' T'],
    axes: PyTree[int, ' T'],
    initial: PyTree[Array, ' T'] | None,
) -> PyTree[Array, ' T']:
    """Reduce each array in `x` along the axes in `axes` starting from `initial` using `ufunc.reduce`."""
    if initial is None:
        # plain reduction, no carried accumulator
        return tree_map(lambda a, ax: ufunc.reduce(a, axis=ax), x, axes)

    def _fold(a: Array, init: Array, ax: int) -> Array:
        # reduce this leaf, then combine with the running accumulator
        return ufunc(init, ufunc.reduce(a, axis=ax))

    return tree_map(_fold, x, initial, axes)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def identity(
    ufunc: jnp.ufunc, x: PyTree[ShapeDtypeStruct, ' T']
) -> PyTree[Array, ' T']:
    """Get the identity element for `ufunc`, broadcast to match each dummy array in `x`."""

    def _one(struct: ShapeDtypeStruct) -> Array:
        elem = identity_for(ufunc, struct.dtype)
        return jnp.broadcast_to(elem, struct.shape)

    return tree_map(_one, x)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def reduction_dtype(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> DTypeLike:
    """Return the output dtype of a `ufunc` reduction over inputs of `input_dtype`."""
    # probe with a tiny dummy array rather than hard-coding promotion rules
    probe = jnp.empty(1, input_dtype)
    return ufunc.reduce(probe).dtype
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def identity_for(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> Shaped[Array, '']:
    """Return the identity element of `ufunc` as an array scalar of the reduction dtype."""
    # the accumulator dtype may be wider than the input (e.g. int8 -> int32)
    out_dtype = reduction_dtype(ufunc, input_dtype)
    # wrap as an explicitly typed scalar array
    return jnp.array(ufunc.identity, out_dtype)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def check_same(tree1, tree2):
    """Assert that the two trees match leaf-by-leaf in shape and dtype."""

    def _compare(a, b):
        assert a.shape == b.shape
        assert a.dtype == b.dtype

    tree_map(_compare, tree1, tree2)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class NotDefined:
    """Sentinel class marking `result_shape_dtype` as not provided in `autobatch`."""

    pass
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    *,
    return_nbatches: bool = False,
    reduce_ufunc: jnp.ufunc | None = None,
    warn_on_overflow: bool = True,
    result_shape_dtype: PyTree[ShapeDtypeStruct] = NotDefined,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function taking positional arguments only, with pytrees of
        arrays as inputs and outputs.
    max_io_nbytes
        Maximum number of input + output bytes per batch, not counting
        arguments that are not batched.
    in_axes
        Pytree (or prefix thereof) of axes along which to batch each input
        array; `None` marks an argument as not batched.
    out_axes
        Same for outputs, except `None` is not allowed.
    return_nbatches
        If True, also return the number of batches as a second value.
    reduce_ufunc
        If given (e.g. `jax.numpy.add`), the output is reduced with this
        ufunc along the batched axis.
    warn_on_overflow
        If True, warn when the memory limit cannot be met.
    result_shape_dtype
        Dummy pytree describing the output; when omitted, the function is
        traced once more to discover it.

    Returns
    -------
    A function with the same signature as `func` (plus the extra return value
    when `return_nbatches` is set).

    Notes
    -----
    Unless `return_nbatches` or `reduce_ufunc` are set, `autobatch` at given
    arguments is idempotent, and can be nested over multiple axes with the
    same `max_io_nbytes` without looping over inner axes when the outer ones
    already satisfy the limit.

    To account for intermediate memory: if the batched input/output totals
    ``batched_size * core_io_size`` and intermediates take ``batched_size *
    core_int_size``, divide `max_io_nbytes` by ``(1 + core_int_size /
    core_io_size)``.
    """

    @jit
    @wraps(func)
    def wrapper(*args):
        # all the real work happens in `batched_func`, traced under jit
        return batched_func(
            func,
            max_io_nbytes,
            in_axes,
            out_axes,
            return_nbatches,
            reduce_ufunc,
            warn_on_overflow,
            result_shape_dtype,
            args,
        )

    return wrapper
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def batched_func(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    return_nbatches: bool,
    reduce_ufunc: jnp.ufunc | None,
    warn_on_overflow: bool,
    result_shape_dtype: PyTree[ShapeDtypeStruct] | NotDefined,
    args: tuple[PyTree[Array], ...],
) -> PyTree[Array]:
    """Implement the wrapper used in `autobatch`.

    Decides the number of batches from the memory budget, then either calls
    `func` directly (single batch) or loops over batches with `lax.scan`,
    optionally reducing the outputs with `reduce_ufunc`.
    """
    # determine the output structure of the function (abstract eval only;
    # `eval_shape` does not execute `func`)
    if result_shape_dtype is NotDefined:
        example_result = eval_shape(func, *args)
    else:
        example_result = result_shape_dtype

    # expand the axes pytrees if they are prefixes of the args/result trees
    in_axes = expand_axes(in_axes, args)
    out_axes = expand_axes(out_axes, example_result)
    # outputs must all be batched: None is only allowed in in_axes
    check_no_nones(out_axes, example_result)

    # check the axes are valid and make them non-negative
    in_axes = normalize_axes(in_axes, args)
    out_axes = normalize_axes(out_axes, example_result)

    # get the common size of the batched axis across inputs and outputs
    size = extract_size((in_axes, out_axes), (args, example_result))

    # split arguments in batched and not batched; non-batched args are closed
    # over instead of being threaded through the scan
    original_args = args
    args, nonbatched_args = pull_nonbatched(in_axes, args)

    # determine the number of batches to respect the memory limit;
    # nbatches must divide `size` exactly, hence `next_divisor`
    total_nbytes = sum_nbytes((args, example_result))
    min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
    min_nbatches = max(1, min_nbatches)
    nbatches = next_divisor(size, min_nbatches)
    assert 1 <= nbatches <= max(1, size)
    assert size % nbatches == 0
    # NOTE(review): this assumes every batched array's nbytes scales with
    # `size`, so the total divides evenly — confirm for odd leaf layouts
    assert total_nbytes % nbatches == 0

    # warn if the memory limit could not be respected even at one row per batch
    batch_nbytes = total_nbytes // nbatches
    if batch_nbytes > max_io_nbytes and warn_on_overflow:
        assert size == nbatches
        msg = f'batch_nbytes = {batch_nbytes:_} > max_io_nbytes = {max_io_nbytes:_}'
        warn(msg)

    # squeeze out the output dims that will be reduced away by `reduce_ufunc`
    if reduce_ufunc is not None:
        example_result = remove_axis(example_result, out_axes, reduce_ufunc)

    if nbatches > 1:
        # prepare arguments for looping: batched axis to front, then split
        # into (nbatches, size // nbatches, ...)
        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)

        # prepare carry for reduction: the ufunc identity element
        if reduce_ufunc is None:
            initial = None
        else:
            initial = identity(reduce_ufunc, example_result)

        # loop and invoke the function one batch at a time
        loop = partial(
            batching_loop,
            func=func,
            nonbatched_args=nonbatched_args,
            in_axes=in_axes,
            out_axes=out_axes,
            reduce_ufunc=reduce_ufunc,
        )
        reduced_result, result = scan(loop, initial, args)

        # remove auxiliary batching axis and reverse transposition
        if reduce_ufunc is None:
            assert reduced_result is None
            result = unbatch(result)
            result = move_axes_in(out_axes, result)
        else:
            assert result is None
            result = reduced_result

    # trivial case: no batching needed, call the function once
    else:
        result = func(*original_args)
        if reduce_ufunc is not None:
            result = reduce(reduce_ufunc, result, out_axes, None)

    # sanity check: the result matches the predicted structure
    check_same(example_result, result)

    if return_nbatches:
        return result, nbatches
    return result
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def batching_loop(
    initial, args, *, func, nonbatched_args, in_axes, out_axes, reduce_ufunc
):
    """Scan body for `autobatch`: evaluate `func` on one batch, optionally reducing."""
    # restore each batched argument to its original axis position, then
    # splice the non-batched arguments back in
    call_args = move_axes_in(in_axes, args)
    call_args = push_nonbatched(in_axes, call_args, nonbatched_args)
    out = func(*call_args)

    # reduced case: fold this batch's output into the carried accumulator
    if reduce_ufunc is not None:
        return reduce(reduce_ufunc, out, out_axes, initial), None

    # unreduced case: move batched axes to the front so scan can stack them
    return None, move_axes_out(out_axes, out)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# bartz/src/bartz/jaxext/scipy/__init__.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2025, The Bartz Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Mockup of the :external:py:mod:`scipy` module."""
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
# bartz/src/bartz/jaxext/scipy/special.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2025, The Bartz Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Mockup of the :external:py:mod:`scipy.special` module."""
|
|
26
|
+
|
|
27
|
+
from functools import wraps
|
|
28
|
+
|
|
29
|
+
from jax import ShapeDtypeStruct, jit, pure_callback
|
|
30
|
+
from jax import numpy as jnp
|
|
31
|
+
from scipy.special import gammainccinv as scipy_gammainccinv
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _float_type(*args):
|
|
35
|
+
"""Determine the jax floating point result type given operands/types."""
|
|
36
|
+
t = jnp.result_type(*args)
|
|
37
|
+
return jnp.sin(jnp.empty(0, t)).dtype
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _castto(func, dtype):
|
|
41
|
+
@wraps(func)
|
|
42
|
+
def newfunc(*args, **kw):
|
|
43
|
+
return func(*args, **kw).astype(dtype)
|
|
44
|
+
|
|
45
|
+
return newfunc
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@jit
def gammainccinv(a, y):
    """Survival function inverse of the Gamma(a, 1) distribution.

    Parameters
    ----------
    a
        Shape parameter of the Gamma distribution (array-like).
    y
        Survival probability in [0, 1] (array-like); broadcast against `a`.

    Returns
    -------
    The value x such that ``sf(x; a) == y``, with floating dtype determined
    by promoting the input dtypes.

    Notes
    -----
    Computed via a host callback into `scipy.special.gammainccinv`, so it
    runs on CPU regardless of the default device.
    """
    # coerce to jax arrays so plain Python/numpy scalars are accepted
    # (otherwise `.shape`/`.dtype` below would raise AttributeError)
    a = jnp.asarray(a)
    y = jnp.asarray(y)
    shape = jnp.broadcast_shapes(a.shape, y.shape)
    dtype = _float_type(a.dtype, y.dtype)
    dummy = ShapeDtypeStruct(shape, dtype)
    # scipy computes in float64; cast back to the jax result dtype
    ufunc = _castto(scipy_gammainccinv, dtype)
    return pure_callback(ufunc, dummy, a, y, vmap_method='expand_dims')
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
################# COPIED AND ADAPTED FROM JAX ##################
|
|
59
|
+
# Copyright 2018 The JAX Authors.
|
|
60
|
+
#
|
|
61
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
62
|
+
# you may not use this file except in compliance with the License.
|
|
63
|
+
# You may obtain a copy of the License at
|
|
64
|
+
#
|
|
65
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
66
|
+
#
|
|
67
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
68
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
69
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
70
|
+
# See the License for the specific language governing permissions and
|
|
71
|
+
# limitations under the License.
|
|
72
|
+
|
|
73
|
+
import numpy as np
|
|
74
|
+
from jax import debug_infs, lax
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def ndtri(p):
    """Compute the inverse of the CDF of the Normal distribution function.

    This is a patch of `jax.scipy.special.ndtri`.
    """
    supported = (jnp.float32, jnp.float64)
    dtype = lax.dtype(p)
    if dtype not in supported:
        msg = f'x.dtype={dtype} is not supported, see docstring for supported types.'
        raise TypeError(msg)
    return _ndtri(p)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _ndtri(p):
    """Core of `ndtri`: piece-wise rational approximation of the normal quantile."""
    # Constants used in piece-wise rational approximations. Taken from the cephes
    # library:
    # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # Coefficients are reversed to lowest-degree-first so `_create_polynomial`
    # can apply Horner's method recursively.
    p0 = list(
        reversed(
            [
                -5.99633501014107895267e1,
                9.80010754185999661536e1,
                -5.66762857469070293439e1,
                1.39312609387279679503e1,
                -1.23916583867381258016e0,
            ]
        )
    )
    q0 = list(
        reversed(
            [
                1.0,
                1.95448858338141759834e0,
                4.67627912898881538453e0,
                8.63602421390890590575e1,
                -2.25462687854119370527e2,
                2.00260212380060660359e2,
                -8.20372256168333339912e1,
                1.59056225126211695515e1,
                -1.18331621121330003142e0,
            ]
        )
    )
    p1 = list(
        reversed(
            [
                4.05544892305962419923e0,
                3.15251094599893866154e1,
                5.71628192246421288162e1,
                4.40805073893200834700e1,
                1.46849561928858024014e1,
                2.18663306850790267539e0,
                -1.40256079171354495875e-1,
                -3.50424626827848203418e-2,
                -8.57456785154685413611e-4,
            ]
        )
    )
    q1 = list(
        reversed(
            [
                1.0,
                1.57799883256466749731e1,
                4.53907635128879210584e1,
                4.13172038254672030440e1,
                1.50425385692907503408e1,
                2.50464946208309415979e0,
                -1.42182922854787788574e-1,
                -3.80806407691578277194e-2,
                -9.33259480895457427372e-4,
            ]
        )
    )
    p2 = list(
        reversed(
            [
                3.23774891776946035970e0,
                6.91522889068984211695e0,
                3.93881025292474443415e0,
                1.33303460815807542389e0,
                2.01485389549179081538e-1,
                1.23716634817820021358e-2,
                3.01581553508235416007e-4,
                2.65806974686737550832e-6,
                6.23974539184983293730e-9,
            ]
        )
    )
    q2 = list(
        reversed(
            [
                1.0,
                6.02427039364742014255e0,
                3.67983563856160859403e0,
                1.37702099489081330271e0,
                2.16236993594496635890e-1,
                1.34204006088543189037e-2,
                3.28014464682127739104e-4,
                2.89247864745380683936e-6,
                6.79019408009981274425e-9,
            ]
        )
    )

    # scalar constructor of p's dtype, used to keep constants at p's precision
    dtype = lax.dtype(p).type
    shape = jnp.shape(p)

    def _create_polynomial(var, coeffs):
        """Compute n_th order polynomial via Horner's method."""
        coeffs = np.array(coeffs, dtype)
        if not coeffs.size:
            return jnp.zeros_like(var)
        return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

    # work with min(p, 1 - p) and restore the sign at the end
    maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.0)), dtype(1.0) - p, p)
    # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
    # later on. The result from the computation when p == 0 is not used so any
    # number that doesn't result in NaNs is fine.
    sanitized_mcp = jnp.where(
        maybe_complement_p == dtype(0.0),
        jnp.full(shape, dtype(0.5)),
        maybe_complement_p,
    )

    # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
    w = sanitized_mcp - dtype(0.5)
    ww = lax.square(w)
    x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) / _create_polynomial(ww, q0))
    x_for_big_p *= -dtype(np.sqrt(2.0 * np.pi))

    # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
    # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
    # arrays based on whether p < exp(-32).
    z = lax.sqrt(dtype(-2.0) * lax.log(sanitized_mcp))
    first_term = z - lax.log(z) / z
    second_term_small_p = (
        _create_polynomial(dtype(1.0) / z, p2)
        / _create_polynomial(dtype(1.0) / z, q2)
        / z
    )
    second_term_otherwise = (
        _create_polynomial(dtype(1.0) / z, p1)
        / _create_polynomial(dtype(1.0) / z, q1)
        / z
    )
    x_for_small_p = first_term - second_term_small_p
    x_otherwise = first_term - second_term_otherwise

    # select the branch matching each element's regime
    x = jnp.where(
        sanitized_mcp > dtype(np.exp(-2.0)),
        x_for_big_p,
        jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise),
    )

    # restore the sign flipped by the complement trick above
    x = jnp.where(p > dtype(1.0 - np.exp(-2.0)), x, -x)
    # p == 0 and p == 1 legitimately map to -inf/+inf; silence inf debugging
    with debug_infs(False):
        infinity = jnp.full(shape, dtype(np.inf))
        neg_infinity = -infinity
        return jnp.where(
            p == dtype(0.0), neg_infinity, jnp.where(p == dtype(1.0), infinity, x)
        )
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
################################################################
|