aspire-inference 0.1.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +457 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +37 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +82 -0
- aspire/flows/jax/utils.py +54 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +276 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +92 -0
- aspire/samplers/importance.py +18 -0
- aspire/samplers/mcmc.py +158 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +312 -0
- aspire/samplers/smc/blackjax.py +330 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +476 -0
- aspire/transforms.py +491 -0
- aspire/utils.py +491 -0
- aspire_inference-0.1.0a2.dist-info/METADATA +48 -0
- aspire_inference-0.1.0a2.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a2.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a2.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a2.dist-info/top_level.txt +1 -0
aspire/utils.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import logging
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from functools import partial
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
import array_api_compat.numpy as np
|
|
11
|
+
import h5py
|
|
12
|
+
import wrapt
|
|
13
|
+
from array_api_compat import (
|
|
14
|
+
array_namespace,
|
|
15
|
+
is_jax_array,
|
|
16
|
+
is_torch_array,
|
|
17
|
+
is_torch_namespace,
|
|
18
|
+
to_device,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from multiprocessing import Pool
|
|
23
|
+
|
|
24
|
+
from array_api_compat.common._typing import Array
|
|
25
|
+
|
|
26
|
+
from .aspire import Aspire
|
|
27
|
+
|
|
28
|
+
# Module-level logger; used for debug messages (e.g. by PoolHandler).
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def configure_logger(
    log_level: str | int = "INFO",
    additional_loggers: list[str] | None = None,
    include_aspire_loggers: bool = True,
) -> logging.Logger:
    """Configure the ``aspire`` logger.

    Adds a stream handler to the ``aspire`` logger. A handler installed by a
    previous call to this function is replaced, so calling it repeatedly does
    not duplicate log output. Handlers added by the user are left in place.
    The handlers of the ``aspire`` logger are then shared with the additional
    loggers, which stop propagating to their parents.

    Parameters
    ----------
    log_level : str or int, optional
        The log level to use. Defaults to "INFO".
    additional_loggers : list of str, optional
        Additional loggers to configure. Defaults to None. The list passed in
        is not modified.
    include_aspire_loggers : bool, optional
        Whether to include all loggers that start with "aspire_" or "aspire-".
        Defaults to True.

    Returns
    -------
    logging.Logger
        The configured logger.
    """
    logger = logging.getLogger("aspire")
    logger.setLevel(log_level)

    # Remove the handler added by an earlier call (tagged below) so records
    # are not emitted once per call.
    for handler in list(logger.handlers):
        if getattr(handler, "_aspire_configured", False):
            logger.removeHandler(handler)

    ch = logging.StreamHandler()
    ch.setLevel(log_level)
    formatter = logging.Formatter(
        "%(asctime)s - aspire - %(levelname)s - %(message)s"
    )
    ch.setFormatter(formatter)
    ch._aspire_configured = True
    logger.addHandler(ch)

    # Work on a copy so the caller's list is never mutated.
    names = list(additional_loggers or [])
    if include_aspire_loggers:
        names.extend(
            name
            for name in logger.manager.loggerDict
            if name.startswith(("aspire_", "aspire-"))
        )

    for name in names:
        dep_logger = logging.getLogger(name)
        dep_logger.setLevel(log_level)
        dep_logger.handlers.clear()
        for handler in logger.handlers:
            dep_logger.addHandler(handler)
        # Handlers are attached directly, so avoid double output via parents.
        dep_logger.propagate = False

    return logger
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class PoolHandler:
    """Context manager to temporarily replace the log_likelihood method of an
    aspire instance with a version that uses a multiprocessing pool to
    parallelize computation.

    On entry, ``log_likelihood`` (and optionally ``log_prior``) is replaced by
    a :func:`functools.partial` that passes ``map_fn=pool.map``. On exit, the
    original callables are restored and the pool is optionally closed and
    joined.

    Parameters
    ----------
    aspire_instance : aspire
        The aspire instance to modify. The log_likelihood method of this
        instance must accept a :code:`map_fn` keyword argument.
    pool : multiprocessing.Pool or None
        The pool to use for parallel computation. If None, the methods are not
        replaced and no pool is closed on exit.
    close_pool : bool, optional
        Whether to close the pool when exiting the context manager.
        Defaults to True.
    parallelize_prior : bool, optional
        Whether to parallelize the log_prior method as well. Defaults to False.
        If True, the log_prior method of the aspire instance must also
        accept a :code:`map_fn` keyword argument.

    Raises
    ------
    ValueError
        If a method that will be replaced does not accept ``map_fn``.
    """

    def __init__(
        self,
        aspire_instance: Aspire,
        pool: Pool | None,
        close_pool: bool = True,
        parallelize_prior: bool = False,
    ):
        # Must be set before aspire_instance: the setter reads it.
        self.parallelize_prior = parallelize_prior
        self.aspire_instance = aspire_instance
        self.pool = pool
        self.close_pool = close_pool

    @property
    def aspire_instance(self):
        """The aspire instance whose methods are temporarily replaced."""
        return self._aspire_instance

    @aspire_instance.setter
    def aspire_instance(self, value: Aspire):
        signature = inspect.signature(value.log_likelihood)
        if "map_fn" not in signature.parameters:
            raise ValueError(
                "The log_likelihood method of the Aspire instance must accept a"
                " 'map_fn' keyword argument."
            )
        # Only inspect log_prior when it will be replaced, so priors without
        # an introspectable signature are accepted otherwise.
        if self.parallelize_prior:
            signature = inspect.signature(value.log_prior)
            if "map_fn" not in signature.parameters:
                raise ValueError(
                    "The log_prior method of the Aspire instance must accept a"
                    " 'map_fn' keyword argument if parallelize_prior is True."
                )
        self._aspire_instance = value

    def __enter__(self):
        self.original_log_likelihood = self.aspire_instance.log_likelihood
        self.original_log_prior = self.aspire_instance.log_prior
        if self.pool is not None:
            logger.debug("Updating map function in log-likelihood method")
            self.aspire_instance.log_likelihood = partial(
                self.original_log_likelihood, map_fn=self.pool.map
            )
            if self.parallelize_prior:
                logger.debug("Updating map function in log-prior method")
                self.aspire_instance.log_prior = partial(
                    self.original_log_prior, map_fn=self.pool.map
                )
        return self.pool

    def __exit__(self, exc_type, exc_value, traceback):
        self.aspire_instance.log_likelihood = self.original_log_likelihood
        self.aspire_instance.log_prior = self.original_log_prior
        if self.pool is None:
            # No pool was used, so there is nothing to close.
            return
        if self.close_pool:
            logger.debug("Closing pool")
            self.pool.close()
            self.pool.join()
        else:
            logger.debug("Not closing pool")
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def logit(x: Array, eps: float | None = None) -> tuple[Array, Array]:
    """Apply the logit transform and also return its log Jacobian determinant.

    Computes ``y = log(x) - log(1 - x)`` element-wise, together with
    ``sum(-log(x) - log(1 - x))`` over the last axis.

    Parameters
    ----------
    x : Array
        Array of values, expected to lie in [0, 1].
    eps : float, optional
        When truthy, inputs are first clamped to ``[eps, 1 - eps]``. If None
        (or 0), inputs are not clamped.

    Returns
    -------
    Array
        Rescaled values.
    Array
        Log Jacobian determinant, summed over the last axis.
    """
    xp = array_namespace(x)
    if eps:
        x = xp.clip(x, eps, 1 - eps)
    log_x = xp.log(x)
    log_one_minus_x = xp.log1p(-x)
    return log_x - log_one_minus_x, (-log_x - log_one_minus_x).sum(-1)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def sigmoid(x: Array) -> tuple[Array, Array]:
    """Sigmoid function that also returns log Jacobian determinant.

    The log Jacobian ``log(s) + log(1 - s)`` (with ``s = sigmoid(x)``) is
    computed through the identity
    ``log(s) + log(1 - s) = -(|x| + 2 * log1p(exp(-|x|)))``. This avoids
    ``-inf`` once ``s`` rounds to exactly 0 or 1 for large ``|x|``.

    Parameters
    ----------
    x : Array
        Array of values.

    Returns
    -------
    Array
        Rescaled values in [0, 1].
    Array
        Log Jacobian determinant, summed over the last axis.
    """
    xp = array_namespace(x)
    y = xp.divide(1, 1 + xp.exp(-x))
    abs_x = xp.abs(x)
    # Numerically stable form of log(y) + log1p(-y).
    log_j = -xp.sum(abs_x + 2 * xp.log1p(xp.exp(-abs_x)), axis=-1)
    return y, log_j
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def logsumexp(x: Array, axis: int | None = None) -> Array:
    """Implementation of logsumexp that works with array api.

    The maximum used for numerical stability is taken along ``axis`` (not
    globally), so slices with very different magnitudes are all computed
    accurately. Non-finite maxima are replaced by 0, so all ``-inf`` slices
    return ``-inf`` and ``+inf`` entries return ``+inf`` rather than NaN.

    This will be removed once the implementation in scipy is compatible.

    Parameters
    ----------
    x : Array
        Input array.
    axis : int, optional
        Axis to reduce over. If None, reduce over all elements.

    Returns
    -------
    Array
        ``log(sum(exp(x), axis))``; a 0-d array when ``axis`` is None.
    """
    xp = array_namespace(x)
    c = xp.max(x, axis=axis, keepdims=True)
    # Avoid inf - inf = nan when the maximum itself is infinite.
    c = xp.where(xp.isfinite(c), c, xp.zeros_like(c))
    out = xp.log(xp.sum(xp.exp(x - c), axis=axis, keepdims=True)) + c
    if axis is None:
        return xp.reshape(out, ())
    return xp.squeeze(out, axis=axis)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def to_numpy(x: Array, **kwargs) -> np.ndarray:
    """Return ``x`` as a NumPy array, moving it to the CPU first.

    If moving the array to the ``"cpu"`` device (or the subsequent conversion)
    raises ``ValueError``, the array is converted directly instead.

    Parameters
    ----------
    x : Array
        The array to convert.
    kwargs : dict
        Extra keyword arguments forwarded to numpy.asarray.
    """
    try:
        on_cpu = to_device(x, "cpu")
        return np.asarray(on_cpu, **kwargs)
    except ValueError:
        return np.asarray(x, **kwargs)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def asarray(x, xp: Any = None, **kwargs) -> Array:
    """Convert an array to the specified array API.

    JAX arrays converted to a PyTorch namespace go through DLPack to avoid a
    host round trip. Any ``kwargs`` are then applied with ``xp.asarray``.

    Parameters
    ----------
    x : Array
        The array to convert.
    xp : Any
        The array API to use for the conversion. If None, the array API
        is inferred from the input array.
    kwargs : dict
        Additional keyword arguments to pass to xp.asarray.

    Returns
    -------
    Array
        The converted array.
    """
    if xp is None:
        xp = array_namespace(x)
    if is_jax_array(x) and is_torch_namespace(xp):
        tensor = xp.utils.dlpack.from_dlpack(x)
        # Honour dtype/device etc. instead of silently dropping them.
        return xp.asarray(tensor, **kwargs) if kwargs else tensor
    return xp.asarray(x, **kwargs)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def copy_array(x, xp: Any = None) -> Array:
    """Copy an array based on the array API being used.

    This uses the most appropriate method to copy the array
    depending on the array API. The result never shares memory with ``x``.

    Parameters
    ----------
    x : Array
        The array to copy.
    xp : Any
        The array API to use for the copy. If None, it is inferred from ``x``.

    Returns
    -------
    Array
        The copied array.
    """
    if xp is None:
        xp = array_namespace(x)
    # torch does not play nicely since it complains about copying tensors,
    # so tensors are copied with clone.
    if is_torch_namespace(xp):
        if is_torch_array(x):
            return xp.clone(x)
        # as_tensor may share memory with its input (e.g. NumPy arrays), so
        # clone the result to guarantee an independent copy.
        return xp.clone(xp.as_tensor(x))
    try:
        return xp.copy(x)
    except (AttributeError, TypeError):
        # Fallback for array APIs that do not have a copy method
        return xp.array(x, copy=True)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def effective_sample_size(log_w: Array) -> float:
    """Kish effective sample size ``(sum w)**2 / sum(w**2)`` from log-weights.

    The ratio is evaluated in log space using ``logsumexp`` and then
    exponentiated.
    """
    xp = array_namespace(log_w)
    log_sum_w = logsumexp(log_w)
    log_sum_w_sq = logsumexp(log_w * 2)
    return xp.exp(xp.asarray(log_sum_w * 2 - log_sum_w_sq))
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@contextmanager
def disable_gradients(xp, inference: bool = True):
    """Switch off gradient tracking for the given array API while active.

    For a PyTorch namespace this enters ``xp.inference_mode()`` (default) or
    ``xp.no_grad()``. For any other namespace it does nothing.

    Usage:

    ```python
    with disable_gradients(xp):
        # Do something
    ```

    Parameters
    ----------
    xp : module
        The array API module to use.
    inference : bool, optional
        PyTorch only: use inference mode when True, otherwise ``no_grad``.
    """
    if not is_torch_namespace(xp):
        yield
        return
    grad_context = xp.inference_mode() if inference else xp.no_grad()
    with grad_context:
        yield
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def encode_for_hdf5(value: Any) -> Any:
    """Encode a value for storage in an HDF5 file.

    Special cases:
    - None is replaced with "__none__"
    - Empty dictionaries are replaced with "__empty_dict__"
    - CallHistory objects are converted with ``to_dict(list_to_dict=True)``
      and the resulting dictionary is encoded recursively
    - Lists and tuples become lists of encoded items, sets become sets of
      encoded items, and non-empty dictionaries are encoded value by value

    NumPy arrays, ints, floats, strings and any other values are returned
    unchanged.
    """
    if isinstance(value, CallHistory):
        # Encode the converted history too, so None arguments and empty
        # kwargs receive the same placeholders as any other value.
        return encode_for_hdf5(value.to_dict(list_to_dict=True))
    if isinstance(value, (np.ndarray, int, float, str)):
        return value
    if isinstance(value, (list, tuple)):
        return [encode_for_hdf5(v) for v in value]
    if isinstance(value, set):
        return {encode_for_hdf5(v) for v in value}
    if isinstance(value, dict):
        if not value:
            return "__empty_dict__"
        return {k: encode_for_hdf5(v) for k, v in value.items()}
    if value is None:
        return "__none__"
    return value
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def recursively_save_to_h5_file(h5_file, path, dictionary):
    """Recursively save a dictionary to an HDF5 file.

    Each value is encoded with ``encode_for_hdf5`` before deciding whether to
    recurse. As a result, empty dictionaries are stored as "__empty_dict__"
    instead of being silently dropped, and objects that encode to a dictionary
    (e.g. CallHistory) become groups.

    Parameters
    ----------
    h5_file : h5py.File or h5py.Group
        Object providing ``create_dataset``.
    path : str
        Group path under which the entries are written.
    dictionary : dict
        Mapping of keys to values to save.

    Raises
    ------
    RuntimeError
        If a value cannot be encoded or stored (wraps the original TypeError).
    """
    for key, value in dictionary.items():
        item_path = f"{path}/{key}"
        try:
            encoded = encode_for_hdf5(value)
            if not isinstance(encoded, dict):
                h5_file.create_dataset(item_path, data=encoded)
        except TypeError as error:
            raise RuntimeError(
                f"Cannot save key {key} with value {value} to HDF5 file."
            ) from error
        if isinstance(encoded, dict):
            recursively_save_to_h5_file(h5_file, item_path, encoded)
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def get_package_version(package_name: str) -> str:
    """Get the version of a package.

    The module is imported (dotted names resolve to the submodule) and its
    ``__version__`` is returned if present. Otherwise the installed
    distribution metadata with the same name is queried.

    Parameters
    ----------
    package_name : str
        The name of the package.

    Returns
    -------
    str
        The version of the package, "not installed" if it cannot be
        imported, or "unknown" if no version information is available.
    """
    import importlib
    import importlib.metadata

    try:
        module = importlib.import_module(package_name)
    except ImportError:
        return "not installed"
    version = getattr(module, "__version__", None)
    if version is not None:
        return version
    try:
        return importlib.metadata.version(package_name)
    except importlib.metadata.PackageNotFoundError:
        return "unknown"
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
class AspireFile(h5py.File):
    """A subclass of h5py.File that adds metadata to the file.

    When the file is opened with write intent, the installed aspire version
    is written to the ``aspire_version`` attribute of the root group. Files
    opened read-only are left untouched.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._set_aspire_metadata()

    def _set_aspire_metadata(self):
        # h5py reports mode "r" for read-only files; writing attributes to
        # them fails, so only stamp the version on writable files.
        if self.mode == "r":
            return
        from . import __version__ as aspire_version

        self.attrs["aspire_version"] = aspire_version
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def update_at_indices(x: Array, slc: Array, y: Array) -> Array:
    """Set ``x[slc] = y`` and return the updated array.

    This is a workaround for the fact that array API does not support
    advanced indexing with all backends. Arrays that allow item assignment
    are updated in place and returned. Arrays whose item assignment raises
    ``TypeError`` are updated with ``x.at[slc].set(y)``, which returns a new
    array.

    Parameters
    ----------
    x : Array
        The array to update.
    slc : Array
        The indices to update.
    y : Array
        The values to set at the indices.

    Returns
    -------
    Array
        The updated array.
    """
    try:
        x[slc] = y
    except TypeError:
        # Immutable arrays: use the functional .at[...].set(...) update.
        return x.at[slc].set(y)
    return x
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
@dataclass
class CallHistory:
    """Record of the arguments passed on each call of a tracked function.

    Attributes
    ----------
    args : list[tuple]
        Positional arguments, one entry per call.
    kwargs : list[dict]
        Keyword arguments, one entry per call.
    """

    args: list[tuple]
    kwargs: list[dict]

    def to_dict(self, list_to_dict: bool = False) -> dict[str, Any]:
        """Return the call history as a dictionary.

        Parameters
        ----------
        list_to_dict : bool
            If True, each of ``args`` and ``kwargs`` becomes a dictionary keyed
            by the call index as a string ("0", "1", ...), which is useful when
            encoding the history for HDF5. If False, they stay lists, with each
            positional-argument tuple converted to a list and each keyword
            mapping copied into a new dict.
        """
        if not list_to_dict:
            return {
                "args": [list(call_args) for call_args in self.args],
                "kwargs": [dict(call_kwargs) for call_kwargs in self.kwargs],
            }

        def _keyed_by_index(entries):
            return {str(position): entry for position, entry in enumerate(entries)}

        return {
            "args": _keyed_by_index(self.args),
            "kwargs": _keyed_by_index(self.kwargs),
        }
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def track_calls(wrapped=None):
    """Decorator to track calls to a function.

    The decorator adds a :code:`calls` attribute to the wrapped function,
    which is a :py:class:`CallHistory` object that stores the arguments and
    keyword arguments of each call. For methods, the history is stored on the
    underlying function (``__func__``), so it is shared by all instances.

    Can be used as ``@track_calls`` or ``@track_calls()``.
    """

    @wrapt.decorator
    def wrapper(wrapped_func, instance, args, kwargs):
        # Bound methods keep the history on their underlying function; plain
        # functions (and proxies) keep it on themselves. This avoids relying on
        # the truthiness of ``instance``, which is False for objects defining
        # a falsy ``__len__``/``__bool__``.
        target = getattr(wrapped_func, "__func__", wrapped_func)
        if not hasattr(target, "calls"):
            target.calls = CallHistory([], [])
        target.calls.args.append(args)
        target.calls.kwargs.append(kwargs)

        # Call the original wrapped function
        return wrapped_func(*args, **kwargs)

    return wrapper(wrapped) if wrapped is not None else wrapper
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: aspire-inference
|
|
3
|
+
Version: 0.1.0a2
|
|
4
|
+
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
|
+
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/mj-will/aspire
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Requires-Python: >=3.10
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Requires-Dist: matplotlib
|
|
13
|
+
Requires-Dist: numpy
|
|
14
|
+
Requires-Dist: array-api-compat
|
|
15
|
+
Requires-Dist: wrapt
|
|
16
|
+
Requires-Dist: h5py
|
|
17
|
+
Provides-Extra: jax
|
|
18
|
+
Requires-Dist: jax; extra == "jax"
|
|
19
|
+
Requires-Dist: jaxlib; extra == "jax"
|
|
20
|
+
Requires-Dist: flowjax; extra == "jax"
|
|
21
|
+
Provides-Extra: torch
|
|
22
|
+
Requires-Dist: torch; extra == "torch"
|
|
23
|
+
Requires-Dist: zuko; extra == "torch"
|
|
24
|
+
Provides-Extra: minipcn
|
|
25
|
+
Requires-Dist: minipcn; extra == "minipcn"
|
|
26
|
+
Provides-Extra: emcee
|
|
27
|
+
Requires-Dist: emcee; extra == "emcee"
|
|
28
|
+
Provides-Extra: blackjax
|
|
29
|
+
Requires-Dist: blackjax; extra == "blackjax"
|
|
30
|
+
Provides-Extra: test
|
|
31
|
+
Requires-Dist: pytest; extra == "test"
|
|
32
|
+
Requires-Dist: pytest-requires; extra == "test"
|
|
33
|
+
Dynamic: license-file
|
|
34
|
+
|
|
35
|
+
# aspire: Accelerated Sequential Posterior Inference via REuse
|
|
36
|
+
|
|
37
|
+
aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
|
|
38
|
+
|
|
39
|
+
## Installation
|
|
40
|
+
|
|
41
|
+
aspire can be installed from PyPI using `pip`
|
|
42
|
+
|
|
43
|
+
```
|
|
44
|
+
pip install aspire-inference
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
**Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
|
|
48
|
+
the package can be imported and used as `aspire`.
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
aspire/__init__.py,sha256=45R0xWaLg0aJEPK5zoTK0aIek0KOwpHwQWS1jLCDhIE,365
|
|
2
|
+
aspire/aspire.py,sha256=AEkFUuOCF4F_iXUqRNst_4mucxozYRK4fG4V2wGrT4Q,15762
|
|
3
|
+
aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
|
|
4
|
+
aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
|
|
5
|
+
aspire/samples.py,sha256=hMlONOtSuYE3bU6r_wQCZ8Z1dcc3Ch15bNMLO8fGU8g,16263
|
|
6
|
+
aspire/transforms.py,sha256=R_BNPlYxK8tvACkZMjgHayr9gUpHxUQiD8148jfTnmg,16407
|
|
7
|
+
aspire/utils.py,sha256=fQeLMauCN3vAogKbVTVg9jfjW7nTEFi7V6Ot-BYNfxE,14301
|
|
8
|
+
aspire/flows/__init__.py,sha256=3gGXF4HziMlZSmcEdJ_uHtrP-QEC6RXvylm4vtM-Xnk,1306
|
|
9
|
+
aspire/flows/base.py,sha256=scBhYvtaoa1x_gcrWs0nLfOKhWYu2bqivVVqbH4zSI8,860
|
|
10
|
+
aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
|
|
11
|
+
aspire/flows/jax/flows.py,sha256=jZ93fnc7U7ZhZLVixGUTwyeDb6Vz0UWpYkkVHwirNug,2896
|
|
12
|
+
aspire/flows/jax/utils.py,sha256=UlvXOOqC5fNsmVUnU4LSksliq7pLRm9NhOu0ZvVHqgc,1455
|
|
13
|
+
aspire/flows/torch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
+
aspire/flows/torch/flows.py,sha256=ZNnShj-FMr56ZbcY06fNQa0epolzMZBd8ok2TzKGZ8E,8996
|
|
15
|
+
aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
+
aspire/samplers/base.py,sha256=BZ5nY_wtvuOIpTaJWUYZflCFXPTDk24xB-qLirIn9qE,2835
|
|
17
|
+
aspire/samplers/importance.py,sha256=3mY6JEqzdunHwAF6l3-CN-tBEdC_8J0LkhxD57DyHoY,609
|
|
18
|
+
aspire/samplers/mcmc.py,sha256=uuCjHZeey5mqjntnYaisNytYBazIc0xuvRcXPHwtg0Y,5075
|
|
19
|
+
aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
20
|
+
aspire/samplers/smc/base.py,sha256=GePA6tm8Dno_AjCeNuRX3KOaKnoKSFHSRAb-QWx9wJE,10531
|
|
21
|
+
aspire/samplers/smc/blackjax.py,sha256=9w1ORzWTT1viwp99_ttLxnNMdgTO-VqAzsf-NhgG9vY,11722
|
|
22
|
+
aspire/samplers/smc/emcee.py,sha256=ZXXyN2l1Bz5ZsCPEcswg-Kakiw41nNa2jEW1N8zGjuc,2498
|
|
23
|
+
aspire/samplers/smc/minipcn.py,sha256=ZjeP4iHFR67G8WKEfMe0b1McrtPgQMNHyyy4vRx6WNE,2747
|
|
24
|
+
aspire_inference-0.1.0a2.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
|
|
25
|
+
aspire_inference-0.1.0a2.dist-info/METADATA,sha256=8s65XoHR6AJmpCDFAA1mqqCWZXd_m8skZgOotgNRO2U,1475
|
|
26
|
+
aspire_inference-0.1.0a2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
27
|
+
aspire_inference-0.1.0a2.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
|
|
28
|
+
aspire_inference-0.1.0a2.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Michael J. Williams
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
aspire
|