aspire-inference 0.1.0a6__py3-none-any.whl → 0.1.0a8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/flows/__init__.py +35 -7
- aspire/samples.py +20 -2
- aspire/transforms.py +2 -2
- {aspire_inference-0.1.0a6.dist-info → aspire_inference-0.1.0a8.dist-info}/METADATA +16 -2
- {aspire_inference-0.1.0a6.dist-info → aspire_inference-0.1.0a8.dist-info}/RECORD +8 -8
- {aspire_inference-0.1.0a6.dist-info → aspire_inference-0.1.0a8.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a6.dist-info → aspire_inference-0.1.0a8.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a6.dist-info → aspire_inference-0.1.0a8.dist-info}/top_level.txt +0 -0
aspire/flows/__init__.py
CHANGED
|
@@ -1,5 +1,31 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
logger = logging.getLogger(__name__)
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_flow_wrapper(
|
|
8
|
+
backend: str = "zuko", flow_matching: bool = False
|
|
9
|
+
) -> tuple[type, Any]:
|
|
10
|
+
"""Get the wrapper for the flow implementation.
|
|
11
|
+
|
|
12
|
+
Parameters
|
|
13
|
+
----------
|
|
14
|
+
backend : str
|
|
15
|
+
The backend to use. Options are "zuko" (PyTorch), "flowjax" (JAX), or
|
|
16
|
+
any other registered flow class via entry points. Default is "zuko".
|
|
17
|
+
flow_matching : bool, optional
|
|
18
|
+
Whether to use flow matching variant of the flow. Default is False.
|
|
19
|
+
|
|
20
|
+
Returns
|
|
21
|
+
-------
|
|
22
|
+
FlowClass : type
|
|
23
|
+
The flow class corresponding to the specified backend.
|
|
24
|
+
xp : module
|
|
25
|
+
The array API module corresponding to the specified backend.
|
|
26
|
+
"""
|
|
27
|
+
from importlib.metadata import entry_points
|
|
28
|
+
|
|
3
29
|
if backend == "zuko":
|
|
4
30
|
import array_api_compat.torch as torch_api
|
|
5
31
|
|
|
@@ -20,11 +46,12 @@ def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
|
|
|
20
46
|
)
|
|
21
47
|
return FlowJax, jnp
|
|
22
48
|
else:
|
|
23
|
-
|
|
24
|
-
|
|
49
|
+
if flow_matching:
|
|
50
|
+
logger.warning(
|
|
51
|
+
"Flow matching option is ignored for external backends."
|
|
52
|
+
)
|
|
25
53
|
eps = {
|
|
26
|
-
ep.name.lower(): ep
|
|
27
|
-
for ep in entry_points().get("aspire.flows", [])
|
|
54
|
+
ep.name.lower(): ep for ep in entry_points(group="aspire.flows")
|
|
28
55
|
}
|
|
29
56
|
if backend in eps:
|
|
30
57
|
FlowClass = eps[backend].load()
|
|
@@ -35,6 +62,7 @@ def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
|
|
|
35
62
|
)
|
|
36
63
|
return FlowClass, xp
|
|
37
64
|
else:
|
|
65
|
+
known_backends = ["zuko", "flowjax"] + list(eps.keys())
|
|
38
66
|
raise ValueError(
|
|
39
|
-
f"Unknown
|
|
67
|
+
f"Unknown backend '{backend}'. Available backends: {known_backends}"
|
|
40
68
|
)
|
aspire/samples.py
CHANGED
|
@@ -14,6 +14,7 @@ from array_api_compat import (
|
|
|
14
14
|
)
|
|
15
15
|
from array_api_compat import device as api_device
|
|
16
16
|
from array_api_compat.common._typing import Array
|
|
17
|
+
from matplotlib.figure import Figure
|
|
17
18
|
|
|
18
19
|
from .utils import (
|
|
19
20
|
asarray,
|
|
@@ -161,7 +162,24 @@ class BaseSamples:
|
|
|
161
162
|
|
|
162
163
|
return pd.DataFrame(self.to_dict(flat=flat))
|
|
163
164
|
|
|
164
|
-
def plot_corner(
|
|
165
|
+
def plot_corner(
|
|
166
|
+
self,
|
|
167
|
+
parameters: list[str] | None = None,
|
|
168
|
+
fig: Figure | None = None,
|
|
169
|
+
**kwargs,
|
|
170
|
+
):
|
|
171
|
+
"""Plot a corner plot of the samples.
|
|
172
|
+
|
|
173
|
+
Parameters
|
|
174
|
+
----------
|
|
175
|
+
parameters : list[str] | None
|
|
176
|
+
List of parameters to plot. If None, all parameters are plotted.
|
|
177
|
+
fig : matplotlib.figure.Figure | None
|
|
178
|
+
Figure to plot on. If None, a new figure is created.
|
|
179
|
+
**kwargs : dict
|
|
180
|
+
Additional keyword arguments to pass to corner.corner(). Kwargs
|
|
181
|
+
are deep-copied before use.
|
|
182
|
+
"""
|
|
165
183
|
import corner
|
|
166
184
|
|
|
167
185
|
kwargs = copy.deepcopy(kwargs)
|
|
@@ -173,7 +191,7 @@ class BaseSamples:
|
|
|
173
191
|
x = self.x[:, indices] if self.x.ndim > 1 else self.x[indices]
|
|
174
192
|
else:
|
|
175
193
|
x = self.x
|
|
176
|
-
fig = corner.corner(to_numpy(x), **kwargs)
|
|
194
|
+
fig = corner.corner(to_numpy(x), fig=fig, **kwargs)
|
|
177
195
|
return fig
|
|
178
196
|
|
|
179
197
|
def __str__(self):
|
aspire/transforms.py
CHANGED
|
@@ -684,7 +684,7 @@ class FlowPreconditioningTransform(BaseTransform):
|
|
|
684
684
|
self.flow_kwargs.setdefault("dtype", dtype)
|
|
685
685
|
self.fit_kwargs = dict(fit_kwargs or {})
|
|
686
686
|
|
|
687
|
-
FlowClass = get_flow_wrapper(
|
|
687
|
+
FlowClass, xp = get_flow_wrapper(
|
|
688
688
|
backend=flow_backend, flow_matching=flow_matching
|
|
689
689
|
)
|
|
690
690
|
transform = CompositeTransform(
|
|
@@ -695,7 +695,7 @@ class FlowPreconditioningTransform(BaseTransform):
|
|
|
695
695
|
bounded_transform=bounded_transform,
|
|
696
696
|
affine_transform=affine_transform,
|
|
697
697
|
device=device,
|
|
698
|
-
xp=
|
|
698
|
+
xp=xp,
|
|
699
699
|
eps=eps,
|
|
700
700
|
dtype=dtype,
|
|
701
701
|
)
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: aspire-inference
|
|
3
|
-
Version: 0.1.0a6
|
|
3
|
+
Version: 0.1.0a8
|
|
4
4
|
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
5
|
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
6
|
License: MIT
|
|
7
7
|
Project-URL: Homepage, https://github.com/mj-will/aspire
|
|
8
|
+
Project-URL: Documentation, https://aspire.readthedocs.io/
|
|
8
9
|
Classifier: Programming Language :: Python :: 3
|
|
9
10
|
Requires-Python: >=3.10
|
|
10
11
|
Description-Content-Type: text/markdown
|
|
@@ -38,7 +39,7 @@ Dynamic: license-file
|
|
|
38
39
|
|
|
39
40
|
# aspire: Accelerated Sequential Posterior Inference via REuse
|
|
40
41
|
|
|
41
|
-
aspire is a framework for reusing existing posterior samples to obtain new results at a reduced
|
|
42
|
+
aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
|
|
42
43
|
|
|
43
44
|
## Installation
|
|
44
45
|
|
|
@@ -50,3 +51,16 @@ pip install aspire-inference
|
|
|
50
51
|
|
|
51
52
|
**Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
|
|
52
53
|
the package can be imported and used as `aspire`.
|
|
54
|
+
|
|
55
|
+
## Documentation
|
|
56
|
+
|
|
57
|
+
See the [documentation on ReadTheDocs][docs].
|
|
58
|
+
|
|
59
|
+
## Citation
|
|
60
|
+
|
|
61
|
+
If you use `aspire` in your work please cite the [DOI][DOI] and [paper][paper].
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
[docs]: https://aspire.readthedocs.io/
|
|
65
|
+
[DOI]: https://doi.org/10.5281/zenodo.15658747
|
|
66
|
+
[paper]: https://arxiv.org/abs/2511.04218
|
|
@@ -2,10 +2,10 @@ aspire/__init__.py,sha256=45R0xWaLg0aJEPK5zoTK0aIek0KOwpHwQWS1jLCDhIE,365
|
|
|
2
2
|
aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
|
|
3
3
|
aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
|
|
4
4
|
aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
|
|
5
|
-
aspire/samples.py,sha256=
|
|
6
|
-
aspire/transforms.py,sha256=
|
|
5
|
+
aspire/samples.py,sha256=nVQULOr19lQVyGhitI8EgDdCDGx9sExPomQBJrV4rTc,19237
|
|
6
|
+
aspire/transforms.py,sha256=W2JL045zkczvdBVXXeSn2c8aKyX8ASxFKgWU2cJHufk,24922
|
|
7
7
|
aspire/utils.py,sha256=pj8O0chqfP6VS8bpW0wCw8W0P5JNQKvWRz1Rg9AYIhg,22525
|
|
8
|
-
aspire/flows/__init__.py,sha256=
|
|
8
|
+
aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
|
|
9
9
|
aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
|
|
10
10
|
aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
|
|
11
11
|
aspire/flows/jax/flows.py,sha256=1HnVgQ1GUXNcvxiZqEV19H2QI9Th5bWX_QbNfGaUhuA,6625
|
|
@@ -21,8 +21,8 @@ aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,1
|
|
|
21
21
|
aspire/samplers/smc/blackjax.py,sha256=4L4kgRKlaWl-knTWXXzdJTh-zZBh5BTpy5GaLDzT8Sc,11803
|
|
22
22
|
aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
|
|
23
23
|
aspire/samplers/smc/minipcn.py,sha256=ju1gcgyKHjodLEACPdL3eXA9ai8ZJ9_LwitD_Gmf1Rc,2765
|
|
24
|
-
aspire_inference-0.1.
|
|
25
|
-
aspire_inference-0.1.
|
|
26
|
-
aspire_inference-0.1.
|
|
27
|
-
aspire_inference-0.1.
|
|
28
|
-
aspire_inference-0.1.0a6.dist-info/RECORD,,
|
|
24
|
+
aspire_inference-0.1.0a8.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
|
|
25
|
+
aspire_inference-0.1.0a8.dist-info/METADATA,sha256=Or73qfnx3KsPNSeaNTJ70t56GME2r9qACM03aSW40t8,1965
|
|
26
|
+
aspire_inference-0.1.0a8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
27
|
+
aspire_inference-0.1.0a8.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
|
|
28
|
+
aspire_inference-0.1.0a8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|