aspire-inference 0.1.0a7__py3-none-any.whl → 0.1.0a9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +2 -0
- aspire/flows/__init__.py +35 -7
- aspire/samplers/smc/blackjax.py +17 -5
- aspire/transforms.py +3 -2
- aspire_inference-0.1.0a9.dist-info/METADATA +111 -0
- {aspire_inference-0.1.0a7.dist-info → aspire_inference-0.1.0a9.dist-info}/RECORD +9 -9
- aspire_inference-0.1.0a7.dist-info/METADATA +0 -52
- {aspire_inference-0.1.0a7.dist-info → aspire_inference-0.1.0a9.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a7.dist-info → aspire_inference-0.1.0a9.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a7.dist-info → aspire_inference-0.1.0a9.dist-info}/top_level.txt +0 -0
aspire/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ import logging
|
|
|
6
6
|
from importlib.metadata import PackageNotFoundError, version
|
|
7
7
|
|
|
8
8
|
from .aspire import Aspire
|
|
9
|
+
from .samples import Samples
|
|
9
10
|
|
|
10
11
|
try:
|
|
11
12
|
__version__ = version("aspire")
|
|
@@ -16,4 +17,5 @@ logging.getLogger(__name__).addHandler(logging.NullHandler())
|
|
|
16
17
|
|
|
17
18
|
__all__ = [
|
|
18
19
|
"Aspire",
|
|
20
|
+
"Samples",
|
|
19
21
|
]
|
aspire/flows/__init__.py
CHANGED
|
@@ -1,5 +1,31 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
logger = logging.getLogger(__name__)
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_flow_wrapper(
|
|
8
|
+
backend: str = "zuko", flow_matching: bool = False
|
|
9
|
+
) -> tuple[type, Any]:
|
|
10
|
+
"""Get the wrapper for the flow implementation.
|
|
11
|
+
|
|
12
|
+
Parameters
|
|
13
|
+
----------
|
|
14
|
+
backend : str
|
|
15
|
+
The backend to use. Options are "zuko" (PyTorch), "flowjax" (JAX), or
|
|
16
|
+
any other registered flow class via entry points. Default is "zuko".
|
|
17
|
+
flow_matching : bool, optional
|
|
18
|
+
Whether to use flow matching variant of the flow. Default is False.
|
|
19
|
+
|
|
20
|
+
Returns
|
|
21
|
+
-------
|
|
22
|
+
FlowClass : type
|
|
23
|
+
The flow class corresponding to the specified backend.
|
|
24
|
+
xp : module
|
|
25
|
+
The array API module corresponding to the specified backend.
|
|
26
|
+
"""
|
|
27
|
+
from importlib.metadata import entry_points
|
|
28
|
+
|
|
3
29
|
if backend == "zuko":
|
|
4
30
|
import array_api_compat.torch as torch_api
|
|
5
31
|
|
|
@@ -20,11 +46,12 @@ def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
|
|
|
20
46
|
)
|
|
21
47
|
return FlowJax, jnp
|
|
22
48
|
else:
|
|
23
|
-
|
|
24
|
-
|
|
49
|
+
if flow_matching:
|
|
50
|
+
logger.warning(
|
|
51
|
+
"Flow matching option is ignored for external backends."
|
|
52
|
+
)
|
|
25
53
|
eps = {
|
|
26
|
-
ep.name.lower(): ep
|
|
27
|
-
for ep in entry_points().get("aspire.flows", [])
|
|
54
|
+
ep.name.lower(): ep for ep in entry_points(group="aspire.flows")
|
|
28
55
|
}
|
|
29
56
|
if backend in eps:
|
|
30
57
|
FlowClass = eps[backend].load()
|
|
@@ -35,6 +62,7 @@ def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
|
|
|
35
62
|
)
|
|
36
63
|
return FlowClass, xp
|
|
37
64
|
else:
|
|
65
|
+
known_backends = ["zuko", "flowjax"] + list(eps.keys())
|
|
38
66
|
raise ValueError(
|
|
39
|
-
f"Unknown
|
|
67
|
+
f"Unknown backend '{backend}'. Available backends: {known_backends}"
|
|
40
68
|
)
|
aspire/samplers/smc/blackjax.py
CHANGED
|
@@ -220,7 +220,9 @@ class BlackJAXSMC(SMCSampler):
|
|
|
220
220
|
|
|
221
221
|
elif algorithm == "nuts":
|
|
222
222
|
# Initialize step size and mass matrix if not provided
|
|
223
|
-
inverse_mass_matrix = self.sampler_kwargs
|
|
223
|
+
inverse_mass_matrix = self.sampler_kwargs.get(
|
|
224
|
+
"inverse_mass_matrix"
|
|
225
|
+
)
|
|
224
226
|
if inverse_mass_matrix is None:
|
|
225
227
|
inverse_mass_matrix = jax.numpy.eye(self.dims)
|
|
226
228
|
|
|
@@ -258,8 +260,13 @@ class BlackJAXSMC(SMCSampler):
|
|
|
258
260
|
z_final = final_states.position
|
|
259
261
|
|
|
260
262
|
# Calculate acceptance rates
|
|
261
|
-
|
|
262
|
-
|
|
263
|
+
try:
|
|
264
|
+
acceptance_rates = jax.numpy.mean(
|
|
265
|
+
all_infos.is_accepted, axis=1
|
|
266
|
+
)
|
|
267
|
+
mean_acceptance = jax.numpy.mean(acceptance_rates)
|
|
268
|
+
except AttributeError:
|
|
269
|
+
mean_acceptance = np.nan
|
|
263
270
|
|
|
264
271
|
elif algorithm == "hmc":
|
|
265
272
|
# Initialize HMC sampler
|
|
@@ -294,8 +301,13 @@ class BlackJAXSMC(SMCSampler):
|
|
|
294
301
|
|
|
295
302
|
final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
|
|
296
303
|
z_final = final_states.position
|
|
297
|
-
|
|
298
|
-
|
|
304
|
+
try:
|
|
305
|
+
acceptance_rates = jax.numpy.mean(
|
|
306
|
+
all_infos.is_accepted, axis=1
|
|
307
|
+
)
|
|
308
|
+
mean_acceptance = jax.numpy.mean(acceptance_rates)
|
|
309
|
+
except AttributeError:
|
|
310
|
+
mean_acceptance = np.nan
|
|
299
311
|
|
|
300
312
|
else:
|
|
301
313
|
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
aspire/transforms.py
CHANGED
|
@@ -332,6 +332,7 @@ class CompositeTransform(BaseTransform):
|
|
|
332
332
|
prior_bounds=self.prior_bounds,
|
|
333
333
|
bounded_to_unbounded=self.bounded_to_unbounded,
|
|
334
334
|
bounded_transform=self.bounded_transform,
|
|
335
|
+
affine_transform=self.affine_transform,
|
|
335
336
|
device=self.device,
|
|
336
337
|
xp=xp or self.xp,
|
|
337
338
|
eps=self.eps,
|
|
@@ -684,7 +685,7 @@ class FlowPreconditioningTransform(BaseTransform):
|
|
|
684
685
|
self.flow_kwargs.setdefault("dtype", dtype)
|
|
685
686
|
self.fit_kwargs = dict(fit_kwargs or {})
|
|
686
687
|
|
|
687
|
-
FlowClass = get_flow_wrapper(
|
|
688
|
+
FlowClass, xp = get_flow_wrapper(
|
|
688
689
|
backend=flow_backend, flow_matching=flow_matching
|
|
689
690
|
)
|
|
690
691
|
transform = CompositeTransform(
|
|
@@ -695,7 +696,7 @@ class FlowPreconditioningTransform(BaseTransform):
|
|
|
695
696
|
bounded_transform=bounded_transform,
|
|
696
697
|
affine_transform=affine_transform,
|
|
697
698
|
device=device,
|
|
698
|
-
xp=
|
|
699
|
+
xp=xp,
|
|
699
700
|
eps=eps,
|
|
700
701
|
dtype=dtype,
|
|
701
702
|
)
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: aspire-inference
|
|
3
|
+
Version: 0.1.0a9
|
|
4
|
+
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
|
+
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/mj-will/aspire
|
|
8
|
+
Project-URL: Documentation, https://aspire.readthedocs.io/
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Requires-Python: >=3.10
|
|
11
|
+
Description-Content-Type: text/markdown
|
|
12
|
+
License-File: LICENSE
|
|
13
|
+
Requires-Dist: matplotlib
|
|
14
|
+
Requires-Dist: numpy
|
|
15
|
+
Requires-Dist: array-api-compat
|
|
16
|
+
Requires-Dist: wrapt
|
|
17
|
+
Requires-Dist: h5py
|
|
18
|
+
Provides-Extra: scipy
|
|
19
|
+
Requires-Dist: scipy; extra == "scipy"
|
|
20
|
+
Provides-Extra: jax
|
|
21
|
+
Requires-Dist: jax; extra == "jax"
|
|
22
|
+
Requires-Dist: jaxlib; extra == "jax"
|
|
23
|
+
Requires-Dist: flowjax; extra == "jax"
|
|
24
|
+
Provides-Extra: torch
|
|
25
|
+
Requires-Dist: torch; extra == "torch"
|
|
26
|
+
Requires-Dist: zuko; extra == "torch"
|
|
27
|
+
Requires-Dist: tqdm; extra == "torch"
|
|
28
|
+
Provides-Extra: minipcn
|
|
29
|
+
Requires-Dist: minipcn; extra == "minipcn"
|
|
30
|
+
Provides-Extra: emcee
|
|
31
|
+
Requires-Dist: emcee; extra == "emcee"
|
|
32
|
+
Provides-Extra: blackjax
|
|
33
|
+
Requires-Dist: blackjax; extra == "blackjax"
|
|
34
|
+
Provides-Extra: test
|
|
35
|
+
Requires-Dist: pytest; extra == "test"
|
|
36
|
+
Requires-Dist: pytest-requires; extra == "test"
|
|
37
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
38
|
+
Dynamic: license-file
|
|
39
|
+
|
|
40
|
+
# aspire: Accelerated Sequential Posterior Inference via REuse
|
|
41
|
+
|
|
42
|
+
[DOI](https://doi.org/10.5281/zenodo.15658747)
|
|
43
|
+
[PyPI](https://pypi.org/project/aspire-inference/)
|
|
44
|
+
[Documentation Status](https://aspire.readthedocs.io/en/latest/?badge=latest)
|
|
45
|
+

|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
|
|
49
|
+
|
|
50
|
+
## Installation
|
|
51
|
+
|
|
52
|
+
aspire can be installed from PyPI using `pip`:
|
|
53
|
+
|
|
54
|
+
```
|
|
55
|
+
pip install aspire-inference
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
**Important:** the name of `aspire` on PyPI is `aspire-inference`, but once installed
|
|
59
|
+
the package can be imported and used as `aspire`.
|
|
60
|
+
|
|
61
|
+
## Quickstart
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
import numpy as np
|
|
65
|
+
from aspire import Aspire, Samples
|
|
66
|
+
|
|
67
|
+
# Define a log-likelihood and log-prior
|
|
68
|
+
def log_likelihood(samples):
|
|
69
|
+
x = samples.x
|
|
70
|
+
return -0.5 * np.sum(x**2, axis=-1)
|
|
71
|
+
|
|
72
|
+
def log_prior(samples):
|
|
73
|
+
return -0.5 * np.sum(samples.x**2, axis=-1)
|
|
74
|
+
|
|
75
|
+
# Create the initial samples
|
|
76
|
+
init = Samples(np.random.normal(size=(2_000, 4)))
|
|
77
|
+
|
|
78
|
+
# Define the aspire object
|
|
79
|
+
aspire = Aspire(
|
|
80
|
+
log_likelihood=log_likelihood,
|
|
81
|
+
log_prior=log_prior,
|
|
82
|
+
dims=4,
|
|
83
|
+
parameters=[f"x{i}" for i in range(4)],
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
# Fit the normalizing flow
|
|
87
|
+
aspire.fit(init, n_epochs=20)
|
|
88
|
+
|
|
89
|
+
# Sample the posterior
|
|
90
|
+
posterior = aspire.sample_posterior(
|
|
91
|
+
sampler="smc",
|
|
92
|
+
n_samples=500,
|
|
93
|
+
sampler_kwargs=dict(n_steps=100),
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
# Plot the posterior distribution
|
|
97
|
+
posterior.plot_corner()
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
## Documentation
|
|
101
|
+
|
|
102
|
+
See the [documentation on ReadTheDocs][docs].
|
|
103
|
+
|
|
104
|
+
## Citation
|
|
105
|
+
|
|
106
|
+
If you use `aspire` in your work, please cite the [DOI][DOI] and [paper][paper].
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
[docs]: https://aspire.readthedocs.io/
|
|
110
|
+
[DOI]: https://doi.org/10.5281/zenodo.15658747
|
|
111
|
+
[paper]: https://arxiv.org/abs/2511.04218
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
aspire/__init__.py,sha256=
|
|
1
|
+
aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
|
|
2
2
|
aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
|
|
3
3
|
aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
|
|
4
4
|
aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
|
|
5
5
|
aspire/samples.py,sha256=nVQULOr19lQVyGhitI8EgDdCDGx9sExPomQBJrV4rTc,19237
|
|
6
|
-
aspire/transforms.py,sha256=
|
|
6
|
+
aspire/transforms.py,sha256=CHrfPQHEyHQ9I0WWiAgUplWwxYypZsD7uCYIHUbSFtY,24974
|
|
7
7
|
aspire/utils.py,sha256=pj8O0chqfP6VS8bpW0wCw8W0P5JNQKvWRz1Rg9AYIhg,22525
|
|
8
|
-
aspire/flows/__init__.py,sha256=
|
|
8
|
+
aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
|
|
9
9
|
aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
|
|
10
10
|
aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
|
|
11
11
|
aspire/flows/jax/flows.py,sha256=1HnVgQ1GUXNcvxiZqEV19H2QI9Th5bWX_QbNfGaUhuA,6625
|
|
@@ -18,11 +18,11 @@ aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M
|
|
|
18
18
|
aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
|
|
19
19
|
aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
20
20
|
aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,10729
|
|
21
|
-
aspire/samplers/smc/blackjax.py,sha256=
|
|
21
|
+
aspire/samplers/smc/blackjax.py,sha256=IcTguAETiPmgFofmVW2GN40P5HBIxkmyd2VR8AU8f4k,12115
|
|
22
22
|
aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
|
|
23
23
|
aspire/samplers/smc/minipcn.py,sha256=ju1gcgyKHjodLEACPdL3eXA9ai8ZJ9_LwitD_Gmf1Rc,2765
|
|
24
|
-
aspire_inference-0.1.
|
|
25
|
-
aspire_inference-0.1.
|
|
26
|
-
aspire_inference-0.1.
|
|
27
|
-
aspire_inference-0.1.
|
|
28
|
-
aspire_inference-0.1.
|
|
24
|
+
aspire_inference-0.1.0a9.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
|
|
25
|
+
aspire_inference-0.1.0a9.dist-info/METADATA,sha256=DCcTGComTNt9UFlbUwGPf58itBX6Y-1-edVys0pb2RQ,3187
|
|
26
|
+
aspire_inference-0.1.0a9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
27
|
+
aspire_inference-0.1.0a9.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
|
|
28
|
+
aspire_inference-0.1.0a9.dist-info/RECORD,,
|
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.4
|
|
2
|
-
Name: aspire-inference
|
|
3
|
-
Version: 0.1.0a7
|
|
4
|
-
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
|
-
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
|
-
License: MIT
|
|
7
|
-
Project-URL: Homepage, https://github.com/mj-will/aspire
|
|
8
|
-
Classifier: Programming Language :: Python :: 3
|
|
9
|
-
Requires-Python: >=3.10
|
|
10
|
-
Description-Content-Type: text/markdown
|
|
11
|
-
License-File: LICENSE
|
|
12
|
-
Requires-Dist: matplotlib
|
|
13
|
-
Requires-Dist: numpy
|
|
14
|
-
Requires-Dist: array-api-compat
|
|
15
|
-
Requires-Dist: wrapt
|
|
16
|
-
Requires-Dist: h5py
|
|
17
|
-
Provides-Extra: scipy
|
|
18
|
-
Requires-Dist: scipy; extra == "scipy"
|
|
19
|
-
Provides-Extra: jax
|
|
20
|
-
Requires-Dist: jax; extra == "jax"
|
|
21
|
-
Requires-Dist: jaxlib; extra == "jax"
|
|
22
|
-
Requires-Dist: flowjax; extra == "jax"
|
|
23
|
-
Provides-Extra: torch
|
|
24
|
-
Requires-Dist: torch; extra == "torch"
|
|
25
|
-
Requires-Dist: zuko; extra == "torch"
|
|
26
|
-
Requires-Dist: tqdm; extra == "torch"
|
|
27
|
-
Provides-Extra: minipcn
|
|
28
|
-
Requires-Dist: minipcn; extra == "minipcn"
|
|
29
|
-
Provides-Extra: emcee
|
|
30
|
-
Requires-Dist: emcee; extra == "emcee"
|
|
31
|
-
Provides-Extra: blackjax
|
|
32
|
-
Requires-Dist: blackjax; extra == "blackjax"
|
|
33
|
-
Provides-Extra: test
|
|
34
|
-
Requires-Dist: pytest; extra == "test"
|
|
35
|
-
Requires-Dist: pytest-requires; extra == "test"
|
|
36
|
-
Requires-Dist: pytest-cov; extra == "test"
|
|
37
|
-
Dynamic: license-file
|
|
38
|
-
|
|
39
|
-
# aspire: Accelerated Sequential Posterior Inference via REuse
|
|
40
|
-
|
|
41
|
-
aspire is a framework for reusing existing posterior samples to obtain new results at a reduced code.
|
|
42
|
-
|
|
43
|
-
## Installation
|
|
44
|
-
|
|
45
|
-
aspire can be installed from PyPI using `pip`
|
|
46
|
-
|
|
47
|
-
```
|
|
48
|
-
pip install aspire-inference
|
|
49
|
-
```
|
|
50
|
-
|
|
51
|
-
**Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
|
|
52
|
-
the package can be imported and used as `aspire`.
|
|
File without changes
|
|
File without changes
|
|
File without changes
|