bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/_profiler.py ADDED
@@ -0,0 +1,318 @@
1
+ # bartz/src/bartz/_profiler.py
2
+ #
3
+ # Copyright (c) 2025-2026, The Bartz Contributors
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Module with utilities related to profiling bartz."""
26
+
27
+ from collections.abc import Callable, Iterator
28
+ from contextlib import contextmanager
29
+ from functools import wraps
30
+ from typing import Any, TypeVar
31
+
32
+ from jax import block_until_ready, debug, jit
33
+ from jax.lax import cond, scan
34
+ from jax.profiler import TraceAnnotation
35
+ from jaxtyping import Array, Bool
36
+
37
+ from bartz.mcmcstep._state import vmap_chains
38
+
39
# Module-level switch consulted (through `get_profile_mode`) by every
# ``*_if_profiling`` / ``*_if_not_profiling`` helper in this module.
PROFILE_MODE: bool = False

# Type variables used in the wrapper signatures: `T` is the return type of a
# wrapped callable, `Carry` the loop state of `scan_if_not_profiling`.
T, Carry = TypeVar('T'), TypeVar('Carry')
43
+
44
+
45
def get_profile_mode() -> bool:
    """Tell whether profile mode is currently switched on.

    Returns
    -------
    The current value of the module-level `PROFILE_MODE` flag: True when
    profile mode is enabled, False when it is disabled.
    """
    enabled: bool = PROFILE_MODE
    return enabled
53
+
54
+
55
def set_profile_mode(value: bool, /) -> None:
    """Switch profile mode on or off for the whole module.

    The value is stored in the module-level `PROFILE_MODE` flag, so it is what
    every later call to `get_profile_mode` reports until it is changed again.

    Parameters
    ----------
    value
        True turns profile mode on, False turns it off.
    """
    # Rebinding a module-level name from inside a function requires the
    # explicit `global` declaration.
    global PROFILE_MODE  # noqa: PLW0603

    PROFILE_MODE = value
65
+
66
+
67
@contextmanager
def profile_mode(value: bool, /) -> Iterator[None]:
    """Switch profile mode temporarily for the duration of a ``with`` block.

    On entry the current setting is saved and replaced by `value`. On exit,
    including when the body raises, the saved setting is restored, so nested
    uses compose correctly.

    Parameters
    ----------
    value
        Profile mode setting to use while the context is active.

    Examples
    --------
    >>> with profile_mode(True):
    ...     # Code runs with profile mode enabled
    ...     pass

    Notes
    -----
    In profiling mode, the MCMC loop is not compiled into a single function;
    instead it is compiled in smaller pieces. Those pieces are instrumented to
    show up in the jax tracer and in Python profiling statistics. Search for
    function names starting with 'jab' (see `jit_and_block_if_profiling`).

    This context manager does not enable jax tracing; if you want traces, you
    must start the tracer yourself. This context manager only makes the
    execution flow easier to interpret in the traces when the tracer is used.
    """
    saved = get_profile_mode()
    set_profile_mode(value)
    try:
        yield
    finally:
        # Undo the change even if the managed block raised.
        set_profile_mode(saved)
100
+
101
+
102
def jit_and_block_if_profiling(
    func: Callable[..., T], block_before: bool = False, **kwargs
) -> Callable[..., T]:
    """Apply JIT compilation and block if profiling is enabled.

    When profile mode is off, the function runs as-is, without JIT. When
    profile mode is on, the function is JIT compiled and the wrapper blocks
    until its outputs are ready, so the measured time covers the computation.

    Parameters
    ----------
    func
        Function to wrap. Callables without a ``__name__`` attribute, such as
        `functools.partial` objects, are accepted too; their type name is used
        in the trace event name.
    block_before
        If True, block on the inputs before passing them to the JIT-compiled
        function. This ensures that any pending computations are completed
        before entering the JIT-compiled function. This phase is not included
        in the trace event.
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.

    Notes
    -----
    Under profiling mode, each invocation creates a custom jax trace event
    named `jab[<func_name>]`. The Python profiling statistics attributed to the
    function itself will be off; the function `jab_inner_wrapper` represents
    the actual execution time. On the first call with a given input signature,
    that time also includes tracing and compilation, since they happen inside
    the annotated region.
    """
    jitted_func = jit(func, **kwargs)

    # Fall back on the type name for callables lacking __name__ (e.g.
    # functools.partial), which would otherwise raise AttributeError here.
    func_name = getattr(func, '__name__', type(func).__name__)
    event_name = f'jab[{func_name}]'

    # this wrapper is meant to measure the time spent executing the function
    def jab_inner_wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        with TraceAnnotation(event_name):
            result = jitted_func(*call_args, **call_kwargs)
            # jax dispatches asynchronously; wait for the result so that the
            # trace event spans the whole computation
            return block_until_ready(result)

    @wraps(func)
    def jab_outer_wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        if not get_profile_mode():
            return func(*call_args, **call_kwargs)
        if block_before:
            # finish pending work on the inputs outside of the trace event
            call_args, call_kwargs = block_until_ready((call_args, call_kwargs))
        return jab_inner_wrapper(*call_args, **call_kwargs)

    return jab_outer_wrapper
154
+
155
+
156
def jit_if_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Apply JIT compilation only when profiling.

    ``jax.jit(func, *args, **kwargs)`` is built once, up front. The returned
    wrapper checks the profile mode on every call: when it is on, the call goes
    to the jitted version; when it is off, `func` is called directly.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.
    """
    compiled = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        target = compiled if get_profile_mode() else func
        return target(*call_args, **call_kwargs)

    return wrapper
181
+
182
+
183
def jit_if_not_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Apply JIT compilation only when not profiling.

    ``jax.jit(func, *args, **kwargs)`` is built once, up front. The returned
    wrapper checks the profile mode on every call: when it is off, the call
    goes to the jitted version; when it is on, `func` runs as-is.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function. The undecorated `func` is also exposed on it as the
    ``_fun`` attribute.
    """
    compiled = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        target = func if get_profile_mode() else compiled
        return target(*call_args, **call_kwargs)

    # keep a handle on the undecorated function; read by run_mcmc
    wrapper._fun = func  # noqa: SLF001

    return wrapper
213
+
214
+
215
def scan_if_not_profiling(
    f: Callable[[Carry, None], tuple[Carry, None]],
    init: Carry,
    xs: None,
    length: int,
    /,
) -> tuple[Carry, None]:
    """Restricted replacement for `jax.lax.scan` that uses a Python loop when profiling.

    Outside profile mode this is ``jax.lax.scan(f, init, None, length)``. In
    profile mode, `f` is instead called `length` times from a plain Python
    loop, so each iteration executes separately.

    Parameters
    ----------
    f
        Scan body function with signature (carry, None) -> (carry, None).
    init
        Initial carry value.
    xs
        Must be None; scanning over input values is not supported.
    length
        Integer specifying the number of loop iterations.

    Returns
    -------
    Tuple of (final_carry, None) (stacked outputs not supported).

    Raises
    ------
    ValueError
        If `xs` is not None.
    """
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let a non-None `xs` be silently ignored.
    if xs is not None:
        msg = 'scan_if_not_profiling does not support scanning over xs; pass None'
        raise ValueError(msg)

    if not get_profile_mode():
        return scan(f, init, None, length)

    carry = init
    for _ in range(length):
        carry, _ = f(carry, None)
    return carry, None
248
+
249
+
250
def cond_if_not_profiling(
    pred: bool | Bool[Array, ''],
    true_fun: Callable[..., T],
    false_fun: Callable[..., T],
    /,
    *operands,
) -> T:
    """Restricted replacement for `jax.lax.cond` that uses a Python if when profiling.

    Parameters
    ----------
    pred
        Predicate selecting the branch to run.
    true_fun
        Branch called with `operands` when `pred` is true.
    false_fun
        Branch called with `operands` when `pred` is false.
    *operands
        Positional arguments forwarded to whichever branch runs.

    Returns
    -------
    The output of the selected branch.
    """
    if not get_profile_mode():
        return cond(pred, true_fun, false_fun, *operands)

    # Python truthiness is used here, so `pred` must be a concrete value,
    # and only the selected branch is executed.
    branch = true_fun if pred else false_fun
    return branch(*operands)
281
+
282
+
283
def callback_if_not_profiling(
    callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any
):
    """Stand-in for `jax.debug.callback` that calls `callback` directly when profiling.

    Parameters
    ----------
    callback
        Function to invoke; its return value is discarded.
    *args
    **kwargs
        Arguments forwarded to `callback`.
    ordered
        Forwarded to `jax.debug.callback` outside profile mode; unused in
        profile mode, where `callback` is called immediately.
    """
    if not get_profile_mode():
        debug.callback(callback, *args, ordered=ordered, **kwargs)
        return
    callback(*args, **kwargs)
291
+
292
+
293
def vmap_chains_if_profiling(fun: Callable[..., T], **kwargs) -> Callable[..., T]:
    """Apply `vmap_chains` only when profile mode is enabled.

    ``vmap_chains(fun, **kwargs)`` is built once, up front. On each call, the
    returned wrapper dispatches to it if profile mode is on, and to the bare
    `fun` otherwise.

    Parameters
    ----------
    fun
        Function to wrap.
    **kwargs
        Keyword arguments forwarded to `vmap_chains`.

    Returns
    -------
    Wrapped function.
    """
    mapped = vmap_chains(fun, **kwargs)

    @wraps(fun)
    def wrapper(*call_args, **call_kwargs):
        impl = mapped if get_profile_mode() else fun
        return impl(*call_args, **call_kwargs)

    return wrapper
305
+
306
+
307
def vmap_chains_if_not_profiling(fun: Callable[..., T], **kwargs) -> Callable[..., T]:
    """Apply `vmap_chains` only when profile mode is disabled.

    ``vmap_chains(fun, **kwargs)`` is built once, up front. On each call, the
    returned wrapper dispatches to the bare `fun` if profile mode is on, and to
    the mapped version otherwise.

    Parameters
    ----------
    fun
        Function to wrap.
    **kwargs
        Keyword arguments forwarded to `vmap_chains`.

    Returns
    -------
    Wrapped function.
    """
    mapped = vmap_chains(fun, **kwargs)

    @wraps(fun)
    def wrapper(*call_args, **call_kwargs):
        impl = fun if get_profile_mode() else mapped
        return impl(*call_args, **call_kwargs)

    return wrapper
bartz/_version.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.7.0'
1
+ __version__ = '0.8.0'