bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,287 @@
1
+ # bartz/src/bartz/jaxext/__init__.py
2
+ #
3
+ # Copyright (c) 2024-2026, The Bartz Contributors
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Additions to jax."""
26
+
27
+ import math
28
+ from collections.abc import Sequence
29
+ from contextlib import nullcontext
30
+ from functools import partial
31
+
32
+ import jax
33
+ from jax import (
34
+ Device,
35
+ debug_key_reuse,
36
+ device_count,
37
+ ensure_compile_time_eval,
38
+ jit,
39
+ random,
40
+ vmap,
41
+ )
42
+ from jax import numpy as jnp
43
+ from jax.dtypes import prng_key
44
+ from jax.lax import scan
45
+ from jax.scipy.special import ndtr
46
+ from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped
47
+
48
+ from bartz.jaxext._autobatch import autobatch # noqa: F401
49
+ from bartz.jaxext.scipy.special import ndtri
50
+
51
+
52
def vmap_nodoc(fun, *args, **kw):
    """
    Vectorize `fun` with `jax.vmap`, keeping the original docstring verbatim.

    `jax.vmap` rewrites the docstring of the function it wraps; use this when
    the docstring of `fun` already describes the batched signature.

    Parameters
    ----------
    fun
        The function to vectorize.
    *args, **kw
        Forwarded unchanged to `jax.vmap`.

    Returns
    -------
    The vmapped function, whose ``__doc__`` is the one of `fun`.
    """
    saved_doc = fun.__doc__
    batched = jax.vmap(fun, *args, **kw)
    # overwrite whatever docstring jax.vmap generated
    batched.__doc__ = saved_doc
    return batched
63
+
64
+
65
def minimal_unsigned_dtype(value):
    """
    Return the smallest unsigned integer dtype that can represent `value`.

    The candidates are tried from narrowest to widest. No sign or overflow
    check is performed: any value below ``2**8`` (negatives included) maps to
    `uint8`, and any value of at least ``2**32`` maps to `uint64`.
    """
    for n_bits, dtype in ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32)):
        if value < 2**n_bits:
            return dtype
    return jnp.uint64
74
+
75
+
76
@partial(jax.jit, static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: Scalar
) -> tuple[Shaped[Array, ' {size}'], int]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        ones are kept.
    actual_length : int
        The number of used values in `out`, at most `size`.

    Notes
    -----
    Values are compared with ``==``, so every NaN counts as a distinct value.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, value):
        i_out, last, out = carry
        # advance the output cursor only when a new distinct value starts
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # past `size` distinct values the cursor runs beyond the end of `out`;
        # those writes are explicitly discarded
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # scan over the whole sorted input (not just its first `size` elements),
    # otherwise distinct values hidden behind duplicates would be lost
    (last_index, _, out), _ = scan(loop, carry, x)
    actual_length = jnp.minimum(last_index + 1, size)
    return out, actual_length
115
+
116
+
117
class split:
    """
    Split a key into `num` keys.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.

    Notes
    -----
    Unlike `jax.random.split`, this class supports a vector of keys as input. In
    this case, it behaves as if everything had been vmapped over, so `keys.pop`
    has an additional initial output dimension equal to the number of input
    keys, and the deterministic dependency respects this axis.
    """

    # the child keys produced at construction time
    _keys: tuple[Key[Array, '*batch'], ...]
    # how many entries of `_keys` have already been handed out by `pop`
    _num_used: int

    def __init__(self, key: Key[Array, '*batch'], num: int = 2):
        # jitted-vmapped key split seems to be triggering a false positive
        # with key reuse checks, so the check is disabled for batched keys
        guard = debug_key_reuse(False) if key.ndim else nullcontext()
        with guard:
            self._keys = _split_unpack(key, num)
        self._num_used = 0

    def __len__(self):
        remaining = len(self._keys) - self._num_used
        return remaining

    def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, '*batch {shape}']:
        """
        Take the next unused key.

        Parameters
        ----------
        shape
            Requested shape of the result. With the default ``()`` the next key
            is returned as is; otherwise it is further split into
            ``prod(shape)`` keys laid out with shape ``(*batch, *shape)``. An
            `int` is treated as a 1-tuple.

        Returns
        -------
        The popped keys as a jax array with the requested shape.

        Raises
        ------
        IndexError
            If every key has already been popped.
        """
        if not len(self):
            msg = 'No keys left to pop'
            raise IndexError(msg)
        target_shape = shape if isinstance(shape, tuple) else (shape,)
        popped = self._keys[self._num_used]
        self._num_used += 1
        return _split_shaped(popped, target_shape) if target_shape else popped
183
+
184
+
185
@partial(jit, static_argnums=(1,))
def _split_unpack(
    key: Key[Array, '*batch'], num: int
) -> tuple[Key[Array, '*batch'], ...]:
    """
    Split `key` into `num` keys and return them as a tuple.

    Every tuple element has the same shape as `key`. With a batch of keys,
    each key is split independently (as if vmapped), and the i-th tuple
    element collects the i-th child of every key in the batch. Any number of
    batch axes is supported.
    """
    if key.ndim == 0:
        keys = random.split(key, num)  # (num,)
    else:
        # flatten all batch axes into one, split each key via vmap, then
        # restore the batch axes; for 1-d input both reshapes are no-ops
        flat = key.reshape(-1)
        keys = vmap(random.split, in_axes=(0, None), out_axes=1)(flat, num)
        keys = keys.reshape(num, *key.shape)  # (num, *batch)
    return tuple(keys)
194
+
195
+
196
@partial(jit, static_argnums=(1,))
def _split_shaped(
    key: Key[Array, '*batch'], shape: tuple[int, ...]
) -> Key[Array, '*batch {shape}']:
    """
    Split `key` into ``prod(shape)`` keys arranged with shape ``(*batch, *shape)``.

    With a batch of keys, each key is split independently (as if vmapped).
    Any number of batch axes is supported. `shape` is a static argument of the
    jit, so it must be a hashable tuple.
    """
    num = math.prod(shape)
    if key.ndim == 0:
        keys = random.split(key, num)  # (num,)
    else:
        # flatten all batch axes into one and split each key via vmap; the
        # final reshape restores the batch axes (C order matches the flatten)
        keys = vmap(random.split, in_axes=(0, None))(key.reshape(-1), num)
    return keys.reshape(*key.shape, *shape)
206
+
207
+
208
def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '*'],
    bound: Float32[Array, '*'],
    *,
    clip: bool = True,
) -> Float32[Array, '*']:
    """
    Draw samples from a standard normal truncated on one side.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Shape of output array, broadcasted with other inputs.
    upper
        If True, the support is (-∞, bound]; if False, it is [bound, ∞).
    bound
        The truncation point.
    clip
        If True (default), clamp the truncated uniform variates into the open
        interval (0, 1) before applying the inverse normal cdf. Intended for
        debugging purposes.

    Returns
    -------
    Samples of the truncated normal, with the broadcast shape.
    """
    # Inverse-cdf sampling. Working on the side where ndtr is small keeps
    # precision; results for positive bounds are obtained by reflection:
    # | if upper:
    # |     if bound < 0:
    # |         ndtri(uniform(0, ndtr(bound))) =
    # |         ndtri(ndtr(bound) * u)
    # |     if bound > 0:
    # |         -ndtri(uniform(ndtr(-bound), 1)) =
    # |         -ndtri(ndtr(-bound) + ndtr(bound) * (1 - u))
    # | if not upper:
    # |     if bound < 0:
    # |         ndtri(uniform(ndtr(bound), 1)) =
    # |         ndtri(ndtr(bound) + ndtr(-bound) * (1 - u))
    # |     if bound > 0:
    # |         -ndtri(uniform(0, ndtr(-bound))) =
    # |         -ndtri(ndtr(-bound) * u)
    out_shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    reflect = bound > 0
    cdf_at_bound = ndtr(bound)
    cdf_at_neg_bound = ndtr(-bound)
    width = jnp.where(upper, cdf_at_bound, cdf_at_neg_bound)
    offset = jnp.where(upper, cdf_at_neg_bound, cdf_at_bound)
    unif = random.uniform(key, out_shape)
    use_lower_tail = upper ^ reflect
    p = jnp.where(
        use_lower_tail,
        width * (1 - unif),  # in (0, width]
        offset + width * unif,  # in [offset, 1)
    )
    if clip:
        # on gpu the accuracy is lower and sometimes u can reach the boundaries
        lo_edge = jnp.zeros((), p.dtype)
        hi_edge = jnp.ones((), p.dtype)
        p = jnp.clip(
            p, jnp.nextafter(lo_edge, hi_edge), jnp.nextafter(hi_edge, lo_edge)
        )
    sample = ndtri(p)
    return jnp.where(reflect, -sample, sample)
271
+
272
+
273
+ def get_default_device() -> Device:
274
+ """Get the current default JAX device."""
275
+ with ensure_compile_time_eval():
276
+ return jnp.zeros(()).device
277
+
278
+
279
def get_device_count() -> int:
    """Return how many devices exist on the platform of the default device."""
    return device_count(get_default_device().platform)
283
+
284
+
285
def is_key(x: object) -> bool:
    """
    Tell whether `x` is a typed jax random key array.

    Only arrays whose dtype is a `prng_key` subtype qualify; raw ``uint32``
    key data (legacy keys) and non-array objects yield False.
    """
    if not isinstance(x, Array):
        return False
    return jnp.issubdtype(x.dtype, prng_key)