bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,444 @@
1
+ # bartz/src/bartz/jaxext/_autobatch.py
2
+ #
3
+ # Copyright (c) 2025-2026, The Bartz Contributors
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Implementation of `autobatch`."""
26
+
27
+ import math
28
+ from collections.abc import Callable
29
+ from functools import partial, wraps
30
+ from warnings import warn
31
+
32
+ from jax.typing import DTypeLike
33
+
34
+ try:
35
+ from numpy.lib.array_utils import normalize_axis_index # numpy 2
36
+ except ImportError:
37
+ from numpy.core.numeric import normalize_axis_index # numpy 1
38
+
39
+ from jax import ShapeDtypeStruct, eval_shape, jit
40
+ from jax import numpy as jnp
41
+ from jax.lax import scan
42
+ from jax.tree import flatten as tree_flatten
43
+ from jax.tree import map as tree_map
44
+ from jax.tree import reduce as tree_reduce
45
+ from jaxtyping import Array, PyTree, Shaped
46
+
47
+
48
def expand_axes(axes, tree):
    """Broadcast each axis spec in `axes` onto the matching subtree of `tree`."""

    def broadcast_axis(axis, subtree):
        # replicate the single axis value onto every leaf of the subtree
        return tree_map(lambda _: axis, subtree)

    # treat None as a leaf of `axes`: it marks a non-batched argument
    return tree_map(broadcast_axis, axes, tree, is_leaf=lambda x: x is None)
55
+
56
+
57
+ def normalize_axes(
58
+ axes: PyTree[int | None, ' T'], tree: PyTree[Array, ' T']
59
+ ) -> PyTree[int | None, ' T']:
60
+ """Normalize axes to be non-negative and valid for the corresponding arrays in the tree."""
61
+
62
+ def normalize_axis(axis: int | None, x: Array) -> int | None:
63
+ if axis is None:
64
+ return None
65
+ else:
66
+ return normalize_axis_index(axis, len(x.shape))
67
+
68
+ return tree_map(normalize_axis, axes, tree, is_leaf=lambda x: x is None)
69
+
70
+
71
def check_no_nones(axes, tree):
    """Assert that no leaf of `axes` is None, i.e. every leaf is batched."""

    def ensure_batched(_, axis):
        assert axis is not None

    tree_map(ensure_batched, tree, axes, is_leaf=lambda x: x is None)
76
+
77
+
78
def remove_axis(
    x: 'PyTree[ShapeDtypeStruct, " T"]', axis: 'PyTree[int, " T"]', ufunc: jnp.ufunc
) -> 'PyTree[ShapeDtypeStruct, " T"]':
    """Drop `axis` from each dummy array and switch its dtype to the reduction dtype."""

    def drop(leaf, ax):
        squeezed_shape = leaf.shape[:ax] + leaf.shape[ax + 1 :]
        return ShapeDtypeStruct(squeezed_shape, reduction_dtype(ufunc, leaf.dtype))

    return tree_map(drop, x, axis)
89
+
90
+
91
def extract_size(axes, tree):
    """Return the common size of the batched axes, asserting they all agree."""

    def size_at(x, axis):
        # non-batched leaves (axis None) contribute no size
        return None if axis is None else x.shape[axis]

    sizes, _ = tree_flatten(tree_map(size_at, tree, axes))
    first = sizes[0]
    assert all(size == first for size in sizes)
    return first
104
+
105
+
106
def sum_nbytes(tree):
    """Total the byte sizes of all (possibly dummy) arrays in `tree`."""
    # works on ShapeDtypeStruct too: only .shape and .dtype are accessed
    return tree_reduce(
        lambda total, x: total + math.prod(x.shape) * x.dtype.itemsize, tree, 0
    )
111
+
112
+
113
def next_divisor_small(dividend, min_divisor):
    """Return the smallest divisor of `dividend` >= `min_divisor`, scanning upward.

    Assumes ``min_divisor**2 <= dividend`` so the answer, if any, lies at or
    below ``sqrt(dividend)``. Falls back to `dividend` itself when no divisor
    is found in that range.
    """
    # math.isqrt is exact for arbitrary ints, unlike int(math.sqrt(...)) which
    # can be off by one for large dividends due to float rounding
    for divisor in range(min_divisor, math.isqrt(dividend) + 1):
        if dividend % divisor == 0:
            return divisor
    return dividend
118
+
119
+
120
def next_divisor_large(dividend, min_divisor):
    """Return the smallest divisor of `dividend` >= `min_divisor`, scanning cofactors.

    Walks the complementary factor downward from ``dividend // min_divisor``;
    the first cofactor that divides exactly yields the smallest valid divisor.
    """
    for codivisor in range(dividend // min_divisor, 0, -1):
        if dividend % codivisor == 0:
            return dividend // codivisor
    return dividend
126
+
127
+
128
def next_divisor(dividend, min_divisor):
    """Return a divisor >= `min_divisor` such that ``dividend % divisor == 0``.

    Parameters
    ----------
    dividend
        The number to be divided; if 0, `min_divisor` is returned unchanged.
    min_divisor
        The lower bound on the divisor.

    Returns
    -------
    The smallest divisor of `dividend` that is >= `min_divisor` (`dividend`
    itself in the worst case).
    """
    if dividend == 0:
        return min_divisor
    # choose the scan direction based on which side of sqrt(dividend) we start
    if min_divisor * min_divisor <= dividend:
        return next_divisor_small(dividend, min_divisor)
    return next_divisor_large(dividend, min_divisor)
135
+
136
+
137
def pull_nonbatched(axes, tree):
    """Blank out non-batched leaves of `tree` (axis None), returning also the original."""

    def keep_batched(x, axis):
        return x if axis is not None else None

    return tree_map(keep_batched, tree, axes), tree
145
+
146
+
147
def push_nonbatched(axes, tree, original_tree):
    """Restore the non-batched leaves of `tree` (axis None) from `original_tree`."""

    def restore(original_leaf, leaf, axis):
        return leaf if axis is not None else original_leaf

    return tree_map(restore, original_tree, tree, axes)
155
+
156
+
157
def move_axes_out(axes, tree):
    """Bring the batched axis of every array to the front (position 0)."""
    return tree_map(lambda x, axis: jnp.moveaxis(x, axis, 0), tree, axes)
162
+
163
+
164
def move_axes_in(axes, tree):
    """Send the leading axis of every array back to its batched position."""
    return tree_map(lambda x, axis: jnp.moveaxis(x, 0, axis), tree, axes)
169
+
170
+
171
def batch(tree: 'PyTree[Array, " T"]', nbatches: int) -> 'PyTree[Array, " T"]':
    """Split the leading axis in two, the outer one of size `nbatches`."""

    def split(x):
        per_batch = x.shape[0] // nbatches
        return x.reshape(nbatches, per_batch, *x.shape[1:])

    return tree_map(split, tree)
178
+
179
+
180
def unbatch(tree: 'PyTree[Array, " T"]') -> 'PyTree[Array, " T"]':
    """Collapse the two leading axes of every array into one."""

    def merge(x):
        return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])

    return tree_map(merge, tree)
187
+
188
+
189
+ def reduce(
190
+ ufunc: jnp.ufunc,
191
+ x: PyTree[Array, ' T'],
192
+ axes: PyTree[int, ' T'],
193
+ initial: PyTree[Array, ' T'] | None,
194
+ ) -> PyTree[Array, ' T']:
195
+ """Reduce each array in `x` along the axes in `axes` starting from `initial` using `ufunc.reduce`."""
196
+ if initial is None:
197
+
198
+ def reduce(x: Array, axis: int) -> Array:
199
+ return ufunc.reduce(x, axis=axis)
200
+
201
+ return tree_map(reduce, x, axes)
202
+
203
+ else:
204
+
205
+ def reduce(x: Array, initial: Array, axis: int) -> Array:
206
+ reduced = ufunc.reduce(x, axis=axis)
207
+ return ufunc(initial, reduced)
208
+
209
+ return tree_map(reduce, x, initial, axes)
210
+
211
+
212
def identity(
    ufunc: jnp.ufunc, x: 'PyTree[ShapeDtypeStruct, " T"]'
) -> 'PyTree[Array, " T"]':
    """Build, for every dummy array in `x`, an array filled with `ufunc`'s identity."""

    def fill(leaf):
        return jnp.broadcast_to(identity_for(ufunc, leaf.dtype), leaf.shape)

    return tree_map(fill, x)
222
+
223
+
224
+ def reduction_dtype(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> DTypeLike:
225
+ """Return the output dtype for a reduction with `ufunc` on inputs of type `dtype`."""
226
+ return ufunc.reduce(jnp.empty(1, input_dtype)).dtype
227
+
228
+
229
def identity_for(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> 'Shaped[Array, ""]':
    """Return the identity of `ufunc` as a scalar array of the reduction dtype."""
    # match the accumulator type: e.g. small ints may be reduced in a wider dtype
    out_dtype = reduction_dtype(ufunc, input_dtype)
    return jnp.array(ufunc.identity, out_dtype)
236
+
237
+
238
def check_same(tree1, tree2):
    """Assert that the two trees match leaf-by-leaf in shape and dtype."""

    def compare(x1, x2):
        assert x1.shape == x2.shape
        assert x1.dtype == x2.dtype

    tree_map(compare, tree1, tree2)
244
+
245
+
246
class NotDefined:
    """Sentinel used as the default of `result_shape_dtype` in `autobatch`."""

    pass
248
+
249
+
250
def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    *,
    return_nbatches: bool = False,
    reduce_ufunc: jnp.ufunc | None = None,
    warn_on_overflow: bool = True,
    result_shape_dtype: PyTree[ShapeDtypeStruct] = NotDefined,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments.)
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        indicating along which axes each array should be batched. A `None` axis
        indicates to not batch an argument.
    out_axes
        The same for outputs (but non-batching is not allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.
    reduce_ufunc
        Function used to reduce the output along the batched axis (e.g.,
        `jax.numpy.add`).
    warn_on_overflow
        If True, a warning is raised if the memory limit could not be
        respected.
    result_shape_dtype
        A pytree of dummy arrays matching the expected output. If not provided,
        the function is traced an additional time to determine the output
        structure.

    Returns
    -------
    A function with the same signature as `func`, save for the return value if `return_nbatches`.

    Notes
    -----
    Unless `return_nbatches` or `reduce_ufunc` are set, `autobatch` at given
    arguments is idempotent. Furthermore, `autobatch` can be applied multiple
    times over multiple axes with the same `max_io_nbytes` limit to work on
    multiple axes; in this case it won't unnecessarily loop over additional axes
    if one or more outer `autobatch` are already sufficient.

    To handle memory used in intermediate values: assuming all intermediate
    values have size that scales linearly with the axis batched over, say the
    batched input/output total size is ``batched_size * core_io_size``, and the
    intermediate values have size ``batched_size * core_int_size``, then to take
    them into account divide `max_io_nbytes` by ``(1 + core_int_size /
    core_io_size)``.
    """

    # jit the wrapper so the whole batched computation is compiled as one unit;
    # all configuration is closed over, only `args` are traced inputs
    @jit
    @wraps(func)
    def autobatch_wrapper(*args):
        # the actual logic lives in `batched_func`, shared with the docstring above
        return batched_func(
            func,
            max_io_nbytes,
            in_axes,
            out_axes,
            return_nbatches,
            reduce_ufunc,
            warn_on_overflow,
            result_shape_dtype,
            args,
        )

    return autobatch_wrapper
327
+
328
+
329
def batched_func(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    return_nbatches: bool,
    reduce_ufunc: jnp.ufunc | None,
    warn_on_overflow: bool,
    result_shape_dtype: PyTree[ShapeDtypeStruct] | NotDefined,
    args: tuple[PyTree[Array], ...],
) -> PyTree[Array]:
    """Implement the wrapper used in `autobatch`.

    See `autobatch` for the meaning of the parameters; `args` is the tuple of
    positional arguments the wrapped function was called with.
    """
    # determine the output structure of the function
    if result_shape_dtype is NotDefined:
        # abstract trace: no actual computation happens here
        example_result = eval_shape(func, *args)
    else:
        example_result = result_shape_dtype

    # expand the axes pytrees if they are prefixes
    in_axes = expand_axes(in_axes, args)
    out_axes = expand_axes(out_axes, example_result)
    check_no_nones(out_axes, example_result)

    # check the axes are valid
    in_axes = normalize_axes(in_axes, args)
    out_axes = normalize_axes(out_axes, example_result)

    # get the size of the batched axis
    size = extract_size((in_axes, out_axes), (args, example_result))

    # split arguments in batched and not batched
    original_args = args
    args, nonbatched_args = pull_nonbatched(in_axes, args)

    # determine the number of batches to respect the memory limit
    total_nbytes = sum_nbytes((args, example_result))
    # ceil division: the smallest batch count that fits under the limit
    min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
    min_nbatches = max(1, min_nbatches)
    # round up to a divisor of `size` so batches split evenly
    nbatches = next_divisor(size, min_nbatches)
    assert 1 <= nbatches <= max(1, size)
    assert size % nbatches == 0
    assert total_nbytes % nbatches == 0

    # warn if the memory limit could not be respected
    batch_nbytes = total_nbytes // nbatches
    if batch_nbytes > max_io_nbytes and warn_on_overflow:
        # overflow is only possible once every batch holds a single element
        assert size == nbatches
        msg = f'batch_nbytes = {batch_nbytes:_} > max_io_nbytes = {max_io_nbytes:_}'
        warn(msg)

    # squeeze out the output dims that will be reduced
    if reduce_ufunc is not None:
        example_result = remove_axis(example_result, out_axes, reduce_ufunc)

    if nbatches > 1:
        # prepare arguments for looping
        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)

        # prepare carry for reduction
        if reduce_ufunc is None:
            initial = None
        else:
            initial = identity(reduce_ufunc, example_result)

        # loop and invoke the function in batches
        loop = partial(
            batching_loop,
            func=func,
            nonbatched_args=nonbatched_args,
            in_axes=in_axes,
            out_axes=out_axes,
            reduce_ufunc=reduce_ufunc,
        )
        # scan carries the running reduction; the stacked output holds
        # per-batch results in the unreduced case
        reduced_result, result = scan(loop, initial, args)

        # remove auxiliary batching axis and reverse transposition
        if reduce_ufunc is None:
            assert reduced_result is None
            result = unbatch(result)
            result = move_axes_in(out_axes, result)
        else:
            assert result is None
            result = reduced_result

    # trivial case: no batching needed
    else:
        result = func(*original_args)
        if reduce_ufunc is not None:
            result = reduce(reduce_ufunc, result, out_axes, None)

    # sanity check: the real output must match the predicted structure
    check_same(example_result, result)

    if return_nbatches:
        return result, nbatches
    return result
425
+
426
+
427
def batching_loop(
    initial, args, *, func, nonbatched_args, in_axes, out_axes, reduce_ufunc
):
    """Implement the batching loop in `autobatch`.

    This is the body passed to `jax.lax.scan`: `initial` is the scan carry
    (the running reduction, or None when not reducing) and `args` is one
    batch of the batched arguments with the batch axis already at the front.
    """
    # evaluate the function
    args = move_axes_in(in_axes, args)
    args = push_nonbatched(in_axes, args, nonbatched_args)
    result = func(*args)

    # unreduced case: transpose for concatenation and return
    if reduce_ufunc is None:
        result = move_axes_out(out_axes, result)
        return None, result

    # reduced case: reduce starting from initial
    else:
        reduced_result = reduce(reduce_ufunc, result, out_axes, initial)
        return reduced_result, None
@@ -0,0 +1,25 @@
1
+ # bartz/src/bartz/jaxext/scipy/__init__.py
2
+ #
3
+ # Copyright (c) 2025, The Bartz Contributors
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Mockup of the :external:py:mod:`scipy` module."""
@@ -0,0 +1,239 @@
1
+ # bartz/src/bartz/jaxext/scipy/special.py
2
+ #
3
+ # Copyright (c) 2025, The Bartz Contributors
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Mockup of the :external:py:mod:`scipy.special` module."""
26
+
27
+ from functools import wraps
28
+
29
+ from jax import ShapeDtypeStruct, jit, pure_callback
30
+ from jax import numpy as jnp
31
+ from scipy.special import gammainccinv as scipy_gammainccinv
32
+
33
+
34
+ def _float_type(*args):
35
+ """Determine the jax floating point result type given operands/types."""
36
+ t = jnp.result_type(*args)
37
+ return jnp.sin(jnp.empty(0, t)).dtype
38
+
39
+
40
+ def _castto(func, dtype):
41
+ @wraps(func)
42
+ def newfunc(*args, **kw):
43
+ return func(*args, **kw).astype(dtype)
44
+
45
+ return newfunc
46
+
47
+
48
@jit
def gammainccinv(a, y):
    """Survival function inverse of the Gamma(a, 1) distribution."""
    out_shape = jnp.broadcast_shapes(a.shape, y.shape)
    out_dtype = _float_type(a.dtype, y.dtype)
    result_struct = ShapeDtypeStruct(out_shape, out_dtype)
    # the computation runs host-side in scipy; cast its output to the jax dtype
    host_func = _castto(scipy_gammainccinv, out_dtype)
    return pure_callback(host_func, result_struct, a, y, vmap_method='expand_dims')
56
+
57
+
58
+ ################# COPIED AND ADAPTED FROM JAX ##################
59
+ # Copyright 2018 The JAX Authors.
60
+ #
61
+ # Licensed under the Apache License, Version 2.0 (the "License");
62
+ # you may not use this file except in compliance with the License.
63
+ # You may obtain a copy of the License at
64
+ #
65
+ # https://www.apache.org/licenses/LICENSE-2.0
66
+ #
67
+ # Unless required by applicable law or agreed to in writing, software
68
+ # distributed under the License is distributed on an "AS IS" BASIS,
69
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70
+ # See the License for the specific language governing permissions and
71
+ # limitations under the License.
72
+
73
+ import numpy as np
74
+ from jax import debug_infs, lax
75
+
76
+
77
def ndtri(p):
    """Compute the inverse of the CDF of the Normal distribution function.

    This is a patch of `jax.scipy.special.ndtri`.
    """
    dtype = lax.dtype(p)
    if dtype in (jnp.float32, jnp.float64):
        return _ndtri(p)
    msg = f'x.dtype={dtype} is not supported, see docstring for supported types.'
    raise TypeError(msg)
87
+
88
+
89
def _ndtri(p):
    """Core rational approximation behind `ndtri` (adapted from JAX/Cephes)."""
    # Constants used in piece-wise rational approximations. Taken from the cephes
    # library:
    # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # Coefficients are reversed so index 0 is the constant term (Horner order).
    p0 = list(
        reversed(
            [
                -5.99633501014107895267e1,
                9.80010754185999661536e1,
                -5.66762857469070293439e1,
                1.39312609387279679503e1,
                -1.23916583867381258016e0,
            ]
        )
    )
    q0 = list(
        reversed(
            [
                1.0,
                1.95448858338141759834e0,
                4.67627912898881538453e0,
                8.63602421390890590575e1,
                -2.25462687854119370527e2,
                2.00260212380060660359e2,
                -8.20372256168333339912e1,
                1.59056225126211695515e1,
                -1.18331621121330003142e0,
            ]
        )
    )
    p1 = list(
        reversed(
            [
                4.05544892305962419923e0,
                3.15251094599893866154e1,
                5.71628192246421288162e1,
                4.40805073893200834700e1,
                1.46849561928858024014e1,
                2.18663306850790267539e0,
                -1.40256079171354495875e-1,
                -3.50424626827848203418e-2,
                -8.57456785154685413611e-4,
            ]
        )
    )
    q1 = list(
        reversed(
            [
                1.0,
                1.57799883256466749731e1,
                4.53907635128879210584e1,
                4.13172038254672030440e1,
                1.50425385692907503408e1,
                2.50464946208309415979e0,
                -1.42182922854787788574e-1,
                -3.80806407691578277194e-2,
                -9.33259480895457427372e-4,
            ]
        )
    )
    p2 = list(
        reversed(
            [
                3.23774891776946035970e0,
                6.91522889068984211695e0,
                3.93881025292474443415e0,
                1.33303460815807542389e0,
                2.01485389549179081538e-1,
                1.23716634817820021358e-2,
                3.01581553508235416007e-4,
                2.65806974686737550832e-6,
                6.23974539184983293730e-9,
            ]
        )
    )
    q2 = list(
        reversed(
            [
                1.0,
                6.02427039364742014255e0,
                3.67983563856160859403e0,
                1.37702099489081330271e0,
                2.16236993594496635890e-1,
                1.34204006088543189037e-2,
                3.28014464682127739104e-4,
                2.89247864745380683936e-6,
                6.79019408009981274425e-9,
            ]
        )
    )

    dtype = lax.dtype(p).type
    shape = jnp.shape(p)

    def _create_polynomial(var, coeffs):
        """Compute n_th order polynomial via Horner's method."""
        coeffs = np.array(coeffs, dtype)
        if not coeffs.size:
            return jnp.zeros_like(var)
        return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

    # work with min(p, 1 - p) to exploit the symmetry of the Normal CDF
    maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.0)), dtype(1.0) - p, p)
    # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
    # later on. The result from the computation when p == 0 is not used so any
    # number that doesn't result in NaNs is fine.
    sanitized_mcp = jnp.where(
        maybe_complement_p == dtype(0.0),
        jnp.full(shape, dtype(0.5)),
        maybe_complement_p,
    )

    # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
    w = sanitized_mcp - dtype(0.5)
    ww = lax.square(w)
    x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) / _create_polynomial(ww, q0))
    x_for_big_p *= -dtype(np.sqrt(2.0 * np.pi))

    # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
    # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
    # arrays based on whether p < exp(-32).
    z = lax.sqrt(dtype(-2.0) * lax.log(sanitized_mcp))
    first_term = z - lax.log(z) / z
    second_term_small_p = (
        _create_polynomial(dtype(1.0) / z, p2)
        / _create_polynomial(dtype(1.0) / z, q2)
        / z
    )
    second_term_otherwise = (
        _create_polynomial(dtype(1.0) / z, p1)
        / _create_polynomial(dtype(1.0) / z, q1)
        / z
    )
    x_for_small_p = first_term - second_term_small_p
    x_otherwise = first_term - second_term_otherwise

    # select the branch matching each element's regime
    x = jnp.where(
        sanitized_mcp > dtype(np.exp(-2.0)),
        x_for_big_p,
        jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise),
    )

    # undo the complement symmetry: flip sign for the lower half
    x = jnp.where(p > dtype(1.0 - np.exp(-2.0)), x, -x)
    # the endpoint infinities are intentional, so silence jax's inf debugging
    with debug_infs(False):
        infinity = jnp.full(shape, dtype(np.inf))
        neg_infinity = -infinity
        return jnp.where(
            p == dtype(0.0), neg_infinity, jnp.where(p == dtype(1.0), infinity, x)
        )
237
+
238
+
239
+ ################################################################