aspire-inference 0.1.0a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +506 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +84 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +196 -0
- aspire/flows/jax/utils.py +57 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +344 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +94 -0
- aspire/samplers/importance.py +22 -0
- aspire/samplers/mcmc.py +160 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +318 -0
- aspire/samplers/smc/blackjax.py +332 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +568 -0
- aspire/transforms.py +751 -0
- aspire/utils.py +760 -0
- aspire_inference-0.1.0a7.dist-info/METADATA +52 -0
- aspire_inference-0.1.0a7.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a7.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a7.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a7.dist-info/top_level.txt +1 -0
aspire/utils.py
ADDED
@@ -0,0 +1,760 @@
from __future__ import annotations

import inspect
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any

import array_api_compat.numpy as np
import h5py
import wrapt
from array_api_compat import (
    array_namespace,
    is_jax_array,
    is_torch_array,
    is_torch_namespace,
    to_device,
)

if TYPE_CHECKING:
    from multiprocessing import Pool

    from array_api_compat.common._typing import Array

    from .aspire import Aspire

logger = logging.getLogger(__name__)


def configure_logger(
    log_level: str | int = "INFO",
    additional_loggers: list[str] | None = None,
    include_aspire_loggers: bool = True,
) -> logging.Logger:
    """Configure the logger.

    Adds a stream handler to the logger.

    Parameters
    ----------
    log_level : str or int, optional
        The log level to use. Defaults to "INFO".
    additional_loggers : list of str, optional
        Additional loggers to configure. Defaults to None.
    include_aspire_loggers : bool, optional
        Whether to include all loggers that start with "aspire_" or "aspire-".
        Defaults to True.

    Returns
    -------
    logging.Logger
        The configured logger.
    """
    logger = logging.getLogger("aspire")
    logger.setLevel(log_level)
    ch = logging.StreamHandler()
    ch.setLevel(log_level)
    formatter = logging.Formatter(
        "%(asctime)s - aspire - %(levelname)s - %(message)s"
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    additional_loggers = additional_loggers or []
    for name in logger.manager.loggerDict:
        if include_aspire_loggers and (
            name.startswith("aspire_") or name.startswith("aspire-")
        ):
            additional_loggers.append(name)

    for name in additional_loggers:
        dep_logger = logging.getLogger(name)
        dep_logger.setLevel(log_level)
        dep_logger.handlers.clear()
        for handler in logger.handlers:
            dep_logger.addHandler(handler)
        dep_logger.propagate = False

    return logger
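
A minimal usage sketch for the function above; "sampler-worker" stands in for any extra logger name and is purely illustrative.

# --- usage sketch, not part of utils.py ---
logger = configure_logger(log_level="DEBUG", additional_loggers=["sampler-worker"])
logger.info("Sampling started")
# emits: <timestamp> - aspire - INFO - Sampling started

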
class PoolHandler:
    """Context manager to temporarily replace the log_likelihood method of an
    Aspire instance with a version that uses a multiprocessing pool to
    parallelize computation.

    Parameters
    ----------
    aspire_instance : Aspire
        The Aspire instance to modify. The log_likelihood method of this
        instance must accept a :code:`map_fn` keyword argument.
    pool : multiprocessing.Pool
        The pool to use for parallel computation.
    close_pool : bool, optional
        Whether to close the pool when exiting the context manager.
        Defaults to True.
    parallelize_prior : bool, optional
        Whether to parallelize the log_prior method as well. Defaults to False.
        If True, the log_prior method of the Aspire instance must also
        accept a :code:`map_fn` keyword argument.
    """

    def __init__(
        self,
        aspire_instance: Aspire,
        pool: Pool,
        close_pool: bool = True,
        parallelize_prior: bool = False,
    ):
        self.parallelize_prior = parallelize_prior
        self.aspire_instance = aspire_instance
        self.pool = pool
        self.close_pool = close_pool

    @property
    def aspire_instance(self):
        return self._aspire_instance

    @aspire_instance.setter
    def aspire_instance(self, value: Aspire):
        signature = inspect.signature(value.log_likelihood)
        if "map_fn" not in signature.parameters:
            raise ValueError(
                "The log_likelihood method of the Aspire instance must accept a"
                " 'map_fn' keyword argument."
            )
        signature = inspect.signature(value.log_prior)
        if "map_fn" not in signature.parameters and self.parallelize_prior:
            raise ValueError(
                "The log_prior method of the Aspire instance must accept a"
                " 'map_fn' keyword argument if parallelize_prior is True."
            )
        self._aspire_instance = value

    def __enter__(self):
        self.original_log_likelihood = self.aspire_instance.log_likelihood
        self.original_log_prior = self.aspire_instance.log_prior
        if self.pool is not None:
            logger.debug("Updating map function in log-likelihood method")
            self.aspire_instance.log_likelihood = partial(
                self.original_log_likelihood, map_fn=self.pool.map
            )
            if self.parallelize_prior:
                logger.debug("Updating map function in log-prior method")
                self.aspire_instance.log_prior = partial(
                    self.original_log_prior, map_fn=self.pool.map
                )
        return self.pool

    def __exit__(self, exc_type, exc_value, traceback):
        self.aspire_instance.log_likelihood = self.original_log_likelihood
        self.aspire_instance.log_prior = self.original_log_prior
        if self.close_pool:
            logger.debug("Closing pool")
            self.pool.close()
            self.pool.join()
        else:
            logger.debug("Not closing pool")
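
A usage sketch, assuming `aspire_instance` is an existing Aspire object whose `log_likelihood` accepts the required `map_fn` keyword and `samples` is a pre-existing batch of points; both names are placeholders.

# --- usage sketch, not part of utils.py ---
from multiprocessing import Pool

with PoolHandler(aspire_instance, Pool(processes=4)):
    # Inside the block, log_likelihood dispatches work through pool.map.
    log_l = aspire_instance.log_likelihood(samples)
# On exit the original methods are restored and the pool is closed and joined.

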
def logit(x: Array, eps: float | None = None) -> tuple[Array, Array]:
    """Logit function that also returns log Jacobian determinant.

    Parameters
    ----------
    x : float or ndarray
        Array of values
    eps : float, optional
        Epsilon value used to clamp inputs to [eps, 1 - eps]. If None, then
        inputs are not clamped.

    Returns
    -------
    float or ndarray
        Rescaled values.
    float or ndarray
        Log Jacobian determinant.
    """
    xp = array_namespace(x)
    if eps:
        x = xp.clip(x, eps, 1 - eps)
    y = xp.log(x) - xp.log1p(-x)
    log_j = (-xp.log(x) - xp.log1p(-x)).sum(-1)
    return y, log_j


def sigmoid(x: Array) -> tuple[Array, Array]:
    """Sigmoid function that also returns log Jacobian determinant.

    Parameters
    ----------
    x : float or ndarray
        Array of values

    Returns
    -------
    float or ndarray
        Rescaled values.
    float or ndarray
        Log Jacobian determinant.
    """
    xp = array_namespace(x)
    x = xp.divide(1, 1 + xp.exp(-x))
    log_j = (xp.log(x) + xp.log1p(-x)).sum(-1)
    return x, log_j
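
The two transforms above are exact inverses, so their log-Jacobian terms cancel; a quick self-contained check with plain NumPy:

# --- usage sketch, not part of utils.py ---
import numpy as np

x = np.array([[0.2, 0.5, 0.9]])
y, log_j = logit(x)
x_back, log_j_inv = sigmoid(y)
assert np.allclose(x, x_back)
assert np.allclose(log_j + log_j_inv, 0.0)

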
def logsumexp(x: Array, axis: int | None = None) -> Array:
    """Implementation of logsumexp that works with the array API.

    This will be removed once the implementation in scipy is compatible.
    """
    xp = array_namespace(x)
    c = x.max()
    return c + xp.log(xp.sum(xp.exp(x - c), axis=axis))


def to_numpy(x: Array, **kwargs) -> np.ndarray:
    """Convert an array to a numpy array.

    This automatically moves the array to the CPU.

    Parameters
    ----------
    x : Array
        The array to convert.
    kwargs : dict
        Additional keyword arguments to pass to numpy.asarray.
    """
    try:
        return np.asarray(to_device(x, "cpu"), **kwargs)
    except (ValueError, NotImplementedError):
        return np.asarray(x, **kwargs)


def asarray(x, xp: Any = None, **kwargs) -> Array:
    """Convert an array to the specified array API.

    Parameters
    ----------
    x : Array
        The array to convert.
    xp : Any
        The array API to use for the conversion. If None, the array API
        is inferred from the input array.
    kwargs : dict
        Additional keyword arguments to pass to xp.asarray.
    """
    if xp is None:
        # Infer the namespace from the input, as documented above.
        xp = array_namespace(x)
    if is_jax_array(x) and is_torch_namespace(xp):
        return xp.utils.dlpack.from_dlpack(x)
    else:
        return xp.asarray(x, **kwargs)
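
A quick check of the numerical stabilisation in `logsumexp` above, again with plain NumPy:

# --- usage sketch, not part of utils.py ---
import numpy as np

x = np.array([1000.0, 1000.0])
print(logsumexp(x))               # ~1000.6931 (= 1000 + log(2))
print(np.log(np.sum(np.exp(x))))  # inf: exp(1000) overflows

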
def resolve_dtype(dtype: Any | str | None, xp: Any) -> Any | None:
    """Resolve a dtype specification into an XP-specific dtype.

    Parameters
    ----------
    dtype : Any | str | None
        The dtype specification. Can be None, a string, or a dtype-like object.
    xp : module
        The array API module that should interpret the dtype.

    Returns
    -------
    Any | None
        The resolved dtype object compatible with ``xp`` (or None if unspecified).
    """
    if dtype is None or xp is None:
        return dtype

    if isinstance(dtype, str):
        dtype_name = _dtype_to_name(dtype)
        if is_torch_namespace(xp):
            resolved = getattr(xp, dtype_name, None)
            if resolved is None:
                raise ValueError(
                    f"Unknown dtype '{dtype}' for namespace {xp.__name__}"
                )
            return resolved
        try:
            return xp.dtype(dtype_name)
        except (AttributeError, TypeError, ValueError):
            resolved = getattr(xp, dtype_name, None)
            if resolved is not None:
                return resolved
            raise ValueError(
                f"Unknown dtype '{dtype}' for namespace {getattr(xp, '__name__', xp)}"
            )

    if is_torch_namespace(xp):
        return dtype

    try:
        return xp.dtype(dtype)
    except (AttributeError, TypeError, ValueError):
        return dtype


def _dtype_to_name(dtype: Any | str | None) -> str | None:
    """Extract a canonical (lowercase) name for a dtype-like object."""
    if dtype is None:
        return None
    if isinstance(dtype, str):
        name = dtype
    elif hasattr(dtype, "name") and getattr(dtype, "name"):
        name = dtype.name
    elif hasattr(dtype, "__name__"):
        name = dtype.__name__
    else:
        text = str(dtype)
        if text.startswith("<class '") and text.endswith("'>"):
            text = text.split("'")[1]
        if text.startswith("dtype(") and text.endswith(")"):
            inner = text[6:-1].strip("'\" ")
            text = inner or text
        name = text
    name = name.split(".")[-1]
    return name.strip(" '\"<>").lower()


def convert_dtype(
    dtype: Any | str | None,
    target_xp: Any,
    *,
    source_xp: Any | None = None,
) -> Any | None:
    """Convert a dtype between array API namespaces.

    Parameters
    ----------
    dtype : Any | str | None
        The dtype to convert. Can be a dtype object, string, or None.
    target_xp : module
        The target array API namespace to convert the dtype into.
    source_xp : module, optional
        The source namespace of the dtype. Provided for API symmetry and future
        use; currently unused but accepted.

    Returns
    -------
    Any | None
        The dtype object compatible with ``target_xp`` (or None if ``dtype`` is None).
    """
    if dtype is None:
        return None
    if target_xp is None:
        raise ValueError("target_xp must be provided to convert dtype.")

    target_name = getattr(target_xp, "__name__", "")
    dtype_module = getattr(dtype, "__module__", "")
    if dtype_module.startswith(target_name):
        return dtype
    if is_torch_namespace(target_xp) and str(dtype).startswith("torch."):
        return dtype

    name = _dtype_to_name(dtype)
    if not name:
        raise ValueError(f"Could not infer dtype name from {dtype!r}")

    candidates = dict.fromkeys(
        [name, name.lower(), name.upper(), name.capitalize()]
    )
    last_error: Exception | None = None
    for candidate in candidates:
        try:
            return resolve_dtype(candidate, target_xp)
        except ValueError as exc:
            last_error = exc

    # Fallback to direct attribute lookup
    attr = getattr(target_xp, name, None) or getattr(
        target_xp, name.lower(), None
    )
    if attr is not None:
        return attr

    raise ValueError(
        f"Unable to convert dtype {dtype!r} to namespace {target_name}"
    ) from last_error
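
A sketch of a dtype round trip between namespaces; this assumes PyTorch is installed and uses the array_api_compat wrapper modules:

# --- usage sketch, not part of utils.py ---
import array_api_compat.numpy as np
import array_api_compat.torch as torch

dt = convert_dtype(np.float32, target_xp=torch)  # torch.float32
back = convert_dtype(dt, target_xp=np)           # numpy dtype('float32')

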
def copy_array(x, xp: Any = None) -> Array:
    """Copy an array based on the array API being used.

    This uses the most appropriate method to copy the array
    depending on the array API.

    Parameters
    ----------
    x : Array
        The array to copy.
    xp : Any
        The array API to use for the copy.

    Returns
    -------
    Array
        The copied array.
    """
    if xp is None:
        xp = array_namespace(x)
    # torch does not play nicely since it complains about copying tensors
    if is_torch_namespace(xp):
        if is_torch_array(x):
            return xp.clone(x)
        else:
            return xp.as_tensor(x)
    else:
        try:
            return xp.copy(x)
        except (AttributeError, TypeError):
            # Fallback for array APIs that do not have a copy method
            return xp.array(x, copy=True)


def effective_sample_size(log_w: Array) -> float:
    """Compute the effective sample size (Kish's ESS) from log-weights."""
    xp = array_namespace(log_w)
    return xp.exp(xp.asarray(logsumexp(log_w) * 2 - logsumexp(log_w * 2)))
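
`effective_sample_size` evaluates (sum w)^2 / sum(w^2) in log space, so n equal weights give an ESS of exactly n:

# --- usage sketch, not part of utils.py ---
import numpy as np

log_w = np.zeros(100)  # 100 identical weights
print(float(effective_sample_size(log_w)))  # 100.0

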
@contextmanager
def disable_gradients(xp, inference: bool = True):
    """Disable gradients for a specific array API.

    Usage:

    ```python
    with disable_gradients(xp):
        # Do something
    ```

    Parameters
    ----------
    xp : module
        The array API module to use.
    inference : bool, optional
        When using PyTorch, set to True to enable inference mode.
    """
    if is_torch_namespace(xp):
        if inference:
            with xp.inference_mode():
                yield
        else:
            with xp.no_grad():
                yield
    else:
        yield


def encode_dtype(xp, dtype):
    """Encode a dtype for storage in an HDF5 file.

    Parameters
    ----------
    xp : module
        The array API module to use.
    dtype : dtype
        The dtype to encode.

    Returns
    -------
    dict or None
        The encoded dtype.
    """
    if dtype is None:
        return None
    return {
        "__dtype__": True,
        "xp": xp.__name__,
        "dtype": _dtype_to_name(dtype),
    }


def decode_dtype(xp, encoded_dtype):
    """Decode a dtype from an HDF5 file.

    Parameters
    ----------
    xp : module
        The array API module to use.
    encoded_dtype : dict
        The encoded dtype.

    Returns
    -------
    dtype
        The decoded dtype.
    """
    if isinstance(encoded_dtype, dict) and encoded_dtype.get("__dtype__"):
        if encoded_dtype["xp"] != xp.__name__:
            raise ValueError(
                f"Encoded dtype xp {encoded_dtype['xp']} does not match "
                f"current xp {xp.__name__}"
            )
        if is_torch_namespace(xp):
            return getattr(xp, encoded_dtype["dtype"].split(".")[-1])
        else:
            return xp.dtype(encoded_dtype["dtype"].split(".")[-1])
    else:
        return encoded_dtype
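
A round trip of the two dtype helpers with the NumPy namespace:

# --- usage sketch, not part of utils.py ---
import array_api_compat.numpy as np

encoded = encode_dtype(np, np.float64)
# {'__dtype__': True, 'xp': 'array_api_compat.numpy', 'dtype': 'float64'}
assert decode_dtype(np, encoded) == np.float64

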
def encode_for_hdf5(value: Any) -> Any:
    """Encode a value for storage in an HDF5 file.

    Special cases:
    - None is replaced with "__none__"
    - Empty dictionaries are replaced with "__empty_dict__"
    """
    if is_jax_array(value) or is_torch_array(value):
        return to_numpy(value)
    if isinstance(value, CallHistory):
        return value.to_dict(list_to_dict=True)
    if isinstance(value, np.ndarray):
        return value
    if isinstance(value, (int, float, str)):
        return value
    if isinstance(value, (list, tuple)):
        if all(isinstance(v, str) for v in value):
            dt = h5py.string_dtype(encoding="utf-8")
            return np.array(value, dtype=dt)
        return [encode_for_hdf5(v) for v in value]
    if isinstance(value, set):
        return {encode_for_hdf5(v) for v in value}
    if isinstance(value, dict):
        if not value:
            return "__empty_dict__"
        else:
            return {k: encode_for_hdf5(v) for k, v in value.items()}
    if value is None:
        return "__none__"

    return value


def decode_from_hdf5(value: Any) -> Any:
    """Decode a value loaded from an HDF5 file, reversing encode_for_hdf5."""
    if isinstance(value, bytes):  # HDF5 may store strings as bytes
        value = value.decode("utf-8")

    if isinstance(value, str):
        if value == "__none__":
            return None
        if value == "__empty_dict__":
            return {}

    if isinstance(value, np.ndarray):
        # Try to collapse 0-D arrays into scalars
        if value.shape == ():
            return value.item()
        if value.dtype.kind in {"S", "O"}:
            try:
                return value.astype(str).tolist()
            except Exception:
                # fallback: leave as ndarray
                return value
        return value

    if isinstance(value, list):
        return [decode_from_hdf5(v) for v in value]
    if isinstance(value, tuple):
        return tuple(decode_from_hdf5(v) for v in value)
    if isinstance(value, set):
        return {decode_from_hdf5(v) for v in value}
    if isinstance(value, dict):
        return {
            k.decode("utf-8"): decode_from_hdf5(v) for k, v in value.items()
        }

    # Fallback for ints, floats, strs, etc.
    return value


def recursively_save_to_h5_file(h5_file, path, dictionary):
    """Save a dictionary to an HDF5 file with flattened keys under a given group path."""
    # Ensure the group exists (or open it if already present)
    group = h5_file.require_group(path)

    def _save_flattened(g, prefix, d):
        for key, value in d.items():
            full_key = f"{prefix}.{key}" if prefix else key
            if isinstance(value, dict) and value:
                # Recurse into non-empty dicts; empty dicts fall through so
                # the "__empty_dict__" placeholder from encode_for_hdf5 is
                # actually stored instead of being silently dropped.
                _save_flattened(g, full_key, value)
            else:
                try:
                    g.create_dataset(full_key, data=encode_for_hdf5(value))
                except TypeError as error:
                    try:
                        # Try saving as a string
                        dt = h5py.string_dtype(encoding="utf-8")
                        g.create_dataset(
                            full_key, data=np.array(str(value), dtype=dt)
                        )
                    except Exception:
                        raise RuntimeError(
                            f"Cannot save key {full_key} with value {value} to HDF5 file."
                        ) from error

    _save_flattened(group, "", dictionary)


def load_from_h5_file(h5_file, path):
    """Load a flattened dictionary from an HDF5 group and rebuild nesting."""
    group = h5_file[path]
    result = {}

    for key, dataset in group.items():
        parts = key.split(".")
        d = result
        for part in parts[:-1]:
            d = d.setdefault(part, {})
        d[parts[-1]] = decode_from_hdf5(dataset[()])

    return result
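
A round trip through the two functions above, using an in-memory HDF5 file (`driver="core"` with `backing_store=False` keeps everything in RAM); the config dictionary is made up for illustration:

# --- usage sketch, not part of utils.py ---
import h5py

config = {"n_samples": 1000, "prior": {"name": None, "bounds": [0.0, 1.0]}}
with h5py.File("example.h5", "w", driver="core", backing_store=False) as f:
    recursively_save_to_h5_file(f, "config", config)
    restored = load_from_h5_file(f, "config")

assert restored["prior"]["name"] is None                # "__none__" decoded
assert list(restored["prior"]["bounds"]) == [0.0, 1.0]  # stored as a dataset

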
def get_package_version(package_name: str) -> str:
    """Get the version of a package.

    Parameters
    ----------
    package_name : str
        The name of the package.

    Returns
    -------
    str
        The version of the package.
    """
    try:
        module = __import__(package_name)
        return module.__version__
    except ImportError:
        return "not installed"


class AspireFile(h5py.File):
    """A subclass of h5py.File that adds metadata to the file."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._set_aspire_metadata()

    def _set_aspire_metadata(self):
        from . import __version__ as aspire_version

        if self.mode in {"w", "w-", "a", "r+"}:
            self.attrs["aspire_version"] = aspire_version
        else:
            # Use a separate name so the file's version does not shadow the
            # current package version in the warning below.
            file_version = self.attrs.get("aspire_version", "unknown")
            if file_version != "unknown" and file_version != aspire_version:
                logger.warning(
                    f"Opened Aspire file created with version {file_version}. "
                    f"Current version is {aspire_version}."
                )


def update_at_indices(x: Array, slc: Any, y: Array) -> Array:
    """Update an array at specific indices.

    This is a workaround for the fact that the array API does not support
    advanced indexing with all backends.

    Examples
    --------
    >>> x = xp.array([[1, 2], [3, 4], [5, 6]])
    >>> update_at_indices(x, (slice(None), 0), xp.array([10, 20, 30]))
    [[10 2]
     [20 4]
     [30 6]]

    Parameters
    ----------
    x : Array
        The array to update.
    slc : slice, index array, or tuple of these
        The indices to update.
    y : Array
        The values to set at the indices.

    Returns
    -------
    Array
        The updated array.
    """
    try:
        # In-place assignment for backends that support it (e.g. numpy, torch).
        x[slc] = y
        return x
    except TypeError:
        # JAX arrays are immutable; use the functional .at[].set() update.
        return x.at[slc].set(y)


@dataclass
class CallHistory:
    """Class to store the history of calls to a function.

    Attributes
    ----------
    args : list[tuple]
        The positional arguments of each call.
    kwargs : list[dict]
        The keyword arguments of each call.
    """

    args: list[tuple]
    kwargs: list[dict]

    def to_dict(self, list_to_dict: bool = False) -> dict[str, Any]:
        """Convert the call history to a dictionary.

        Parameters
        ----------
        list_to_dict : bool
            If True, convert the lists of args and kwargs to dictionaries
            with string keys. If False, keep them as lists. This is useful
            when encoding the history for HDF5.
        """
        if list_to_dict:
            return {
                "args": {str(i): v for i, v in enumerate(self.args)},
                "kwargs": {str(i): v for i, v in enumerate(self.kwargs)},
            }
        else:
            return {
                "args": [list(arg) for arg in self.args],
                "kwargs": [dict(kwarg) for kwarg in self.kwargs],
            }


def track_calls(wrapped=None):
    """Decorator to track calls to a function.

    The decorator adds a :code:`calls` attribute to the wrapped function,
    which is a :py:class:`CallHistory` object that stores the arguments and
    keyword arguments of each call.
    """

    @wrapt.decorator
    def wrapper(wrapped_func, instance, args, kwargs):
        # If instance is provided, we're dealing with a method.
        if instance:
            # Attach `calls` to the method's `__func__`, the original function
            if not hasattr(wrapped_func.__func__, "calls"):
                wrapped_func.__func__.calls = CallHistory([], [])
            wrapped_func.__func__.calls.args.append(args)
            wrapped_func.__func__.calls.kwargs.append(kwargs)
        else:
            # For standalone functions, attach `calls` directly to the function
            if not hasattr(wrapped_func, "calls"):
                wrapped_func.calls = CallHistory([], [])
            wrapped_func.calls.args.append(args)
            wrapped_func.calls.kwargs.append(kwargs)

        # Call the original wrapped function
        return wrapped_func(*args, **kwargs)

    return wrapper(wrapped) if wrapped else wrapper
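
A quick sketch of the decorator on a standalone function; `quadratic` is a made-up example:

# --- usage sketch, not part of utils.py ---
@track_calls
def quadratic(x, scale=1.0):
    return scale * x**2

quadratic(1.0)
quadratic(2.0, scale=0.5)
print(quadratic.calls.args)    # [(1.0,), (2.0,)]
print(quadratic.calls.kwargs)  # [{}, {'scale': 0.5}]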