bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/_profiler.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
# bartz/src/bartz/_profiler.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2025-2026, The Bartz Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Module with utilities related to profiling bartz."""
|
|
26
|
+
|
|
27
|
+
from collections.abc import Callable, Iterator
|
|
28
|
+
from contextlib import contextmanager
|
|
29
|
+
from functools import wraps
|
|
30
|
+
from typing import Any, TypeVar
|
|
31
|
+
|
|
32
|
+
from jax import block_until_ready, debug, jit
|
|
33
|
+
from jax.lax import cond, scan
|
|
34
|
+
from jax.profiler import TraceAnnotation
|
|
35
|
+
from jaxtyping import Array, Bool
|
|
36
|
+
|
|
37
|
+
from bartz.mcmcstep._state import vmap_chains
|
|
38
|
+
|
|
39
|
+
# Global switch consulted (via `get_profile_mode`) by every wrapper in this
# module at call time; change it with `set_profile_mode` or `profile_mode`.
PROFILE_MODE: bool = False

# Generic type parameters for the wrapper signatures in this module.
T = TypeVar('T')  # result type of a wrapped function
Carry = TypeVar('Carry')  # loop-carried state of `scan_if_not_profiling`
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_profile_mode() -> bool:
    """Report whether profile mode is currently active.

    Returns
    -------
    The current value of the module-level `PROFILE_MODE` flag: True when
    profiling instrumentation is on, False when it is off.
    """
    enabled = PROFILE_MODE
    return enabled
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def set_profile_mode(value: bool, /) -> None:
    """Turn profile mode on or off for the whole module.

    Parameters
    ----------
    value
        New value stored in the module-level `PROFILE_MODE` flag; True turns
        profiling instrumentation on, False turns it off.
    """
    # The flag lives at module scope and the wrappers in this module read it
    # at call time, so the change also affects wrappers created earlier.
    global PROFILE_MODE  # noqa: PLW0603
    PROFILE_MODE = value
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@contextmanager
def profile_mode(value: bool, /) -> Iterator[None]:
    """Switch profile mode to `value` for the duration of a ``with`` block.

    Parameters
    ----------
    value
        Profile mode setting in effect inside the context. The previous
        setting is restored on exit, also when the body raises.

    Examples
    --------
    >>> with profile_mode(True):
    ...     pass  # code here runs with profile mode enabled

    Notes
    -----
    In profiling mode, the MCMC loop is not compiled into a single function,
    but instead compiled in smaller pieces that are instrumented to show up in
    the jax tracer and in Python profiling statistics. Search for function
    names starting with 'jab' (see `jit_and_block_if_profiling`).

    This context manager does not enable jax tracing; if tracing is wanted it
    must be set up separately by the user. This context manager only makes the
    execution flow easier to interpret in the traces when the tracer is used.
    """
    previous = get_profile_mode()
    set_profile_mode(value)
    try:
        yield
    finally:
        # always put back whatever was active before entering
        set_profile_mode(previous)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def jit_and_block_if_profiling(
    func: Callable[..., T], block_before: bool = False, **kwargs
) -> Callable[..., T]:
    """Apply JIT compilation and block on the result if profiling is enabled.

    When profile mode is off, the function runs as-is, without JIT. When
    profile mode is on, the function is JIT compiled and its outputs are
    blocked on, so the measured time covers the actual execution.

    Parameters
    ----------
    func
        Function to wrap. Any callable accepted by `jax.jit` works, including
        objects without a ``__name__`` such as `functools.partial`.
    block_before
        If True, block on the inputs before passing them to the JIT-compiled
        function. This ensures that any pending computations are completed
        before entering the JIT-compiled function. This phase is not included
        in the trace event.
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function. Profile mode is checked at every call.

    Notes
    -----
    Under profiling mode, the function invocation is handled such that a
    custom jax trace event with name `jab[<func_name>]` is created. Python
    profiler statistics attributed to the wrapped function itself will not
    reflect execution time; the function `jab_inner_wrapper` does.
    """
    jitted_func = jit(func, **kwargs)

    # Not every callable has a `__name__` (e.g. `functools.partial`), so fall
    # back to its repr instead of raising AttributeError at wrap time.
    func_name = getattr(func, '__name__', None) or repr(func)
    event_name = f'jab[{func_name}]'

    # this wrapper is meant to measure the time spent executing the function
    def jab_inner_wrapper(*args: Any, **call_kwargs: Any) -> T:
        with TraceAnnotation(event_name):
            result = jitted_func(*args, **call_kwargs)
            # block inside the annotation so the event spans the execution
            return block_until_ready(result)

    @wraps(func)
    def jab_outer_wrapper(*args: Any, **call_kwargs: Any) -> T:
        if not get_profile_mode():
            return func(*args, **call_kwargs)
        if block_before:
            # finish pending work on the inputs outside the trace event
            args, call_kwargs = block_until_ready((args, call_kwargs))
        return jab_inner_wrapper(*args, **call_kwargs)

    return jab_outer_wrapper
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def jit_if_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Use a `jax.jit`-compiled version of `func` only while profiling.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function that, on each call, dispatches to the compiled version
    when profile mode is on and to the plain `func` when it is off.
    """
    compiled = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        target = compiled if get_profile_mode() else func
        return target(*call_args, **call_kwargs)

    return wrapper
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def jit_if_not_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Use a `jax.jit`-compiled version of `func` except while profiling.

    When profile mode is off, calls go to the JIT-compiled function. When
    profile mode is on, `func` runs as-is.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function that picks the implementation on each call. The original
    `func` is also exposed as the private attribute ``_fun``.
    """
    compiled = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        target = func if get_profile_mode() else compiled
        return target(*call_args, **call_kwargs)

    wrapper._fun = func  # used by run_mcmc # noqa: SLF001

    return wrapper
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def scan_if_not_profiling(
    f: Callable[[Carry, None], tuple[Carry, None]],
    init: Carry,
    xs: None,
    length: int,
    /,
) -> tuple[Carry, None]:
    """Restricted replacement for `jax.lax.scan` that uses a Python loop when profiling.

    Parameters
    ----------
    f
        Scan body function with signature (carry, None) -> (carry, None).
    init
        Initial carry value.
    xs
        Input values to scan over. Only None is supported.
    length
        Integer specifying the number of loop iterations.

    Returns
    -------
    Tuple of (final_carry, None) (stacked outputs not supported).

    Raises
    ------
    ValueError
        If `xs` is not None.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and a non-None `xs` would then be silently ignored.
    if xs is not None:
        msg = 'scan_if_not_profiling only supports xs=None'
        raise ValueError(msg)

    if not get_profile_mode():
        return scan(f, init, None, length)

    # Plain Python loop: each call to `f` executes separately instead of
    # inside one compiled `scan`; per-step outputs are discarded.
    carry = init
    for _i in range(length):
        carry, _ = f(carry, None)
    return carry, None
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def cond_if_not_profiling(
    pred: bool | Bool[Array, ''],
    true_fun: Callable[..., T],
    false_fun: Callable[..., T],
    /,
    *operands,
) -> T:
    """Restricted replacement for `jax.lax.cond` that uses a Python if when profiling.

    Parameters
    ----------
    pred
        Boolean predicate to choose which function to execute. In profile
        mode it is evaluated with Python truthiness.
    true_fun
        Function to execute if `pred` is True.
    false_fun
        Function to execute if `pred` is False.
    *operands
        Arguments passed to `true_fun` and `false_fun`.

    Returns
    -------
    Result of either `true_fun(*operands)` or `false_fun(*operands)`.
    """
    if not get_profile_mode():
        return cond(pred, true_fun, false_fun, *operands)
    branch = true_fun if pred else false_fun
    return branch(*operands)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def callback_if_not_profiling(
    callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any
):
    """Restricted replacement for `jax.debug.callback`.

    Parameters
    ----------
    callback
        Function to invoke with ``*args`` and ``**kwargs``.
    *args
    **kwargs
        Arguments forwarded to `callback`.
    ordered
        Forwarded to `jax.debug.callback`; unused in profile mode.

    Notes
    -----
    In profile mode the callback is called directly; otherwise it is
    scheduled through `jax.debug.callback`.
    """
    if not get_profile_mode():
        debug.callback(callback, *args, ordered=ordered, **kwargs)
        return
    callback(*args, **kwargs)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def vmap_chains_if_profiling(fun: Callable[..., T], **kwargs) -> Callable[..., T]:
    """Apply `vmap_chains` only when profile mode is enabled.

    Parameters
    ----------
    fun
        Function to wrap.
    **kwargs
        Forwarded to `vmap_chains`.

    Returns
    -------
    Wrapper that, on each call, uses the `vmap_chains`-transformed function
    in profile mode and `fun` unchanged otherwise. The transformed function
    is built once, when this factory is called.
    """
    vmapped = vmap_chains(fun, **kwargs)

    @wraps(fun)
    def wrapper(*args, **kwargs):
        target = vmapped if get_profile_mode() else fun
        return target(*args, **kwargs)

    return wrapper
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def vmap_chains_if_not_profiling(fun: Callable[..., T], **kwargs) -> Callable[..., T]:
    """Apply `vmap_chains` only when profile mode is disabled.

    Parameters
    ----------
    fun
        Function to wrap.
    **kwargs
        Forwarded to `vmap_chains`.

    Returns
    -------
    Wrapper that, on each call, uses `fun` unchanged in profile mode and the
    `vmap_chains`-transformed function otherwise. The transformed function
    is built once, when this factory is called.
    """
    vmapped = vmap_chains(fun, **kwargs)

    @wraps(fun)
    def wrapper(*args, **kwargs):
        target = fun if get_profile_mode() else vmapped
        return target(*args, **kwargs)

    return wrapper
|
bartz/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '0.
|
|
1
|
+
__version__ = '0.8.0'
|