bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/jaxext/__init__.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
# bartz/src/bartz/jaxext/__init__.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2024-2026, The Bartz Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Additions to jax."""
|
|
26
|
+
|
|
27
|
+
import math
|
|
28
|
+
from collections.abc import Sequence
|
|
29
|
+
from contextlib import nullcontext
|
|
30
|
+
from functools import partial
|
|
31
|
+
|
|
32
|
+
import jax
|
|
33
|
+
from jax import (
|
|
34
|
+
Device,
|
|
35
|
+
debug_key_reuse,
|
|
36
|
+
device_count,
|
|
37
|
+
ensure_compile_time_eval,
|
|
38
|
+
jit,
|
|
39
|
+
random,
|
|
40
|
+
vmap,
|
|
41
|
+
)
|
|
42
|
+
from jax import numpy as jnp
|
|
43
|
+
from jax.dtypes import prng_key
|
|
44
|
+
from jax.lax import scan
|
|
45
|
+
from jax.scipy.special import ndtr
|
|
46
|
+
from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped
|
|
47
|
+
|
|
48
|
+
from bartz.jaxext._autobatch import autobatch # noqa: F401
|
|
49
|
+
from bartz.jaxext.scipy.special import ndtri
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def vmap_nodoc(fun, *args, **kw):
    """
    Vectorize `fun` with `jax.vmap`, keeping its original docstring as is.

    Handy when the docstring of `fun` is already written in terms of the
    batched (vmapped) arguments, so the text `jax.vmap` would otherwise attach
    is not wanted.

    Parameters
    ----------
    fun
        The function to vectorize.
    *args, **kw
        Forwarded unchanged to `jax.vmap`.

    Returns
    -------
    The vmapped function, whose `__doc__` is the `__doc__` of `fun`.
    """
    original_doc = fun.__doc__
    batched = jax.vmap(fun, *args, **kw)
    batched.__doc__ = original_doc
    return batched
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def minimal_unsigned_dtype(value):
    """
    Return the smallest unsigned integer dtype that can represent `value`.

    Parameters
    ----------
    value
        The integer to represent. It is expected to be non-negative; this is
        not checked, and negative values map to `jnp.uint8`.

    Returns
    -------
    One of `jnp.uint8`, `jnp.uint16`, `jnp.uint32`, `jnp.uint64`.

    Raises
    ------
    ValueError
        If `value` does not fit in 64 bits, so no unsigned dtype can hold it.
    """
    if value < 2**8:
        return jnp.uint8
    if value < 2**16:
        return jnp.uint16
    if value < 2**32:
        return jnp.uint32
    if value < 2**64:
        return jnp.uint64
    # previously uint64 was returned here, silently overflowing on cast
    msg = f'{value} is too large to be represented by an unsigned integer dtype'
    raise ValueError(msg)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@partial(jax.jit, static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: Scalar
) -> tuple[Shaped[Array, ' {size}'], int]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        ones are kept.
    actual_length : int
        The number of used values in `out`, i.e., the number of unique values
        in `x`, capped at `size`.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, value):
        i_out, last, out = carry
        # advance the output cursor only when a new distinct value shows up
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # once more than `size` distinct values have been seen the index is
        # out of bounds, and the write is discarded
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    # Scan the whole sorted input: with duplicates, the first `size` elements
    # alone may contain fewer than `size` distinct values.
    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    (last_index, _, out), _ = scan(loop, carry, x)
    return out, jnp.minimum(last_index + 1, size)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class split:
    """
    Split a key into `num` keys that can then be popped one at a time.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.

    Notes
    -----
    Unlike `jax.random.split`, a vector of keys is also accepted. The object
    then acts as though everything were vmapped: every key returned by
    `keys.pop` carries an extra leading axis, one entry per input key, and the
    deterministic dependency of the output on the input follows that axis.
    """

    _keys: tuple[Key[Array, '*batch'], ...]
    _num_used: int

    def __init__(self, key: Key[Array, '*batch'], num: int = 2):
        # For batched input the jitted, vmapped split appears to trigger a
        # false positive in jax's key reuse checker, so the checker is turned
        # off in that case only.
        guard = debug_key_reuse(False) if key.ndim else nullcontext()
        with guard:
            self._keys = _split_unpack(key, num)
        self._num_used = 0

    def __len__(self):
        """Return how many keys have not been popped yet."""
        total = len(self._keys)
        return total - self._num_used

    def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, '*batch {shape}']:
        """
        Take the next unused key.

        Parameters
        ----------
        shape
            Target shape for the returned keys. With the default empty shape,
            the next key is returned unchanged. Otherwise that key is further
            split and reshaped to `shape`. An int is treated as a 1-tuple.

        Returns
        -------
        The popped key(s), as a jax array of the requested shape.

        Raises
        ------
        IndexError
            If every key has already been popped.
        """
        if not len(self):
            msg = 'No keys left to pop'
            raise IndexError(msg)
        target = shape if isinstance(shape, tuple) else (shape,)
        next_key = self._keys[self._num_used]
        self._num_used += 1
        return _split_shaped(next_key, target) if target else next_key
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@partial(jit, static_argnums=(1,))
def _split_unpack(
    key: Key[Array, '*batch'], num: int
) -> tuple[Key[Array, '*batch'], ...]:
    """
    Split `key` into `num` keys and return them as a tuple.

    Each tuple element has the same shape as `key`. With a batch of keys
    (any number of batch axes), each key in the batch is split independently.
    """
    if key.ndim == 0:
        keys = random.split(key, num)
    else:
        # Previously only ndim == 1 was handled, and ndim > 1 left `keys`
        # unbound. Flatten the batch axes, vmap the split over them (the new
        # `num` axis goes first), then restore the batch shape.
        flat = key.reshape(-1)
        keys = vmap(random.split, in_axes=(0, None), out_axes=1)(flat, num)
        keys = keys.reshape(num, *key.shape)
    return tuple(keys)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@partial(jit, static_argnums=(1,))
def _split_shaped(
    key: Key[Array, '*batch'], shape: tuple[int, ...]
) -> Key[Array, '*batch {shape}']:
    """
    Split `key` into `prod(shape)` keys arranged with shape `shape`.

    With a batch of keys (any number of batch axes), each key is split
    independently and the output has shape `(*key.shape, *shape)`.
    """
    num = math.prod(shape)
    if key.ndim == 0:
        keys = random.split(key, num)
    else:
        # Previously only ndim == 1 was handled, and ndim > 1 left `keys`
        # unbound. Flatten the batch axes and vmap the split over them; the
        # row-major reshape below restores the batch axes in front.
        flat = key.reshape(-1)
        keys = vmap(random.split, in_axes=(0, None))(flat, num)
    return keys.reshape(*key.shape, *shape)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '*'],
    bound: Float32[Array, '*'],
    *,
    clip: bool = True,
) -> Float32[Array, '*']:
    """
    Draw samples from a standard normal truncated on one side.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Requested output shape, broadcast together with the shapes of `upper`
        and `bound`.
    upper
        Where True the support is (-∞, bound]; where False it is [bound, ∞).
    bound
        The truncation point.
    clip
        If True (the default), clamp the truncated uniform draws to (0, 1)
        before mapping them through the inverse normal CDF. Turning it off is
        meant for debugging.

    Returns
    -------
    Samples from the truncated normal distribution, with the broadcast shape.
    """
    # Inverse-CDF sampling, with u ~ uniform(0, 1):
    #
    #   upper, bound < 0:      ndtri(uniform(0, ndtr(bound)))
    #                        = ndtri(ndtr(bound) * u)
    #   upper, bound > 0:     -ndtri(uniform(ndtr(-bound), 1))
    #                        = -ndtri(ndtr(-bound) + ndtr(bound) * (1 - u))
    #   not upper, bound < 0:  ndtri(uniform(ndtr(bound), 1))
    #                        = ndtri(ndtr(bound) + ndtr(-bound) * (1 - u))
    #   not upper, bound > 0: -ndtri(uniform(0, ndtr(-bound)))
    #                        = -ndtri(ndtr(-bound) * u)
    out_shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    positive = bound > 0
    cdf_at_bound = ndtr(bound)
    cdf_at_minus_bound = ndtr(-bound)

    # width of the admissible uniform interval, and its left end when the
    # interval is anchored at 1 rather than 0
    width = jnp.where(upper, cdf_at_bound, cdf_at_minus_bound)
    offset = jnp.where(upper, cdf_at_minus_bound, cdf_at_bound)

    u = random.uniform(key, out_shape)
    anchored_at_zero = width * (1 - u)  # ~ uniform in (0, width]
    anchored_at_one = offset + width * u  # ~ uniform in [offset, 1)
    use_zero_anchor = upper ^ positive
    v = jnp.where(use_zero_anchor, anchored_at_zero, anchored_at_one)

    if clip:
        # on GPU the accuracy is lower, and the uniform draw can sometimes
        # reach the boundaries 0 or 1
        zero = jnp.zeros((), v.dtype)
        one = jnp.ones((), v.dtype)
        lowest = jnp.nextafter(zero, one)
        highest = jnp.nextafter(one, zero)
        v = jnp.clip(v, lowest, highest)

    z = ndtri(v)
    return jnp.where(positive, -z, z)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def get_default_device() -> Device:
|
|
274
|
+
"""Get the current default JAX device."""
|
|
275
|
+
with ensure_compile_time_eval():
|
|
276
|
+
return jnp.zeros(()).device
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def get_device_count() -> int:
    """Return how many devices exist on the platform of the default device."""
    platform = get_default_device().platform
    return device_count(platform)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def is_key(x: object) -> bool:
    """Tell whether `x` is a jax array whose dtype is a PRNG key dtype."""
    if not isinstance(x, Array):
        return False
    return jnp.issubdtype(x.dtype, prng_key)
|