aspire-inference 0.1.0a7__py3-none-any.whl → 0.1.0a8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/flows/__init__.py +35 -7
- aspire/transforms.py +2 -2
- {aspire_inference-0.1.0a7.dist-info → aspire_inference-0.1.0a8.dist-info}/METADATA +16 -2
- {aspire_inference-0.1.0a7.dist-info → aspire_inference-0.1.0a8.dist-info}/RECORD +7 -7
- {aspire_inference-0.1.0a7.dist-info → aspire_inference-0.1.0a8.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a7.dist-info → aspire_inference-0.1.0a8.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a7.dist-info → aspire_inference-0.1.0a8.dist-info}/top_level.txt +0 -0
aspire/flows/__init__.py
CHANGED
|
@@ -1,5 +1,31 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
logger = logging.getLogger(__name__)
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_flow_wrapper(
|
|
8
|
+
backend: str = "zuko", flow_matching: bool = False
|
|
9
|
+
) -> tuple[type, Any]:
|
|
10
|
+
"""Get the wrapper for the flow implementation.
|
|
11
|
+
|
|
12
|
+
Parameters
|
|
13
|
+
----------
|
|
14
|
+
backend : str
|
|
15
|
+
The backend to use. Options are "zuko" (PyTorch), "flowjax" (JAX), or
|
|
16
|
+
any other registered flow class via entry points. Default is "zuko".
|
|
17
|
+
flow_matching : bool, optional
|
|
18
|
+
Whether to use flow matching variant of the flow. Default is False.
|
|
19
|
+
|
|
20
|
+
Returns
|
|
21
|
+
-------
|
|
22
|
+
FlowClass : type
|
|
23
|
+
The flow class corresponding to the specified backend.
|
|
24
|
+
xp : module
|
|
25
|
+
The array API module corresponding to the specified backend.
|
|
26
|
+
"""
|
|
27
|
+
from importlib.metadata import entry_points
|
|
28
|
+
|
|
3
29
|
if backend == "zuko":
|
|
4
30
|
import array_api_compat.torch as torch_api
|
|
5
31
|
|
|
@@ -20,11 +46,12 @@ def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
|
|
|
20
46
|
)
|
|
21
47
|
return FlowJax, jnp
|
|
22
48
|
else:
|
|
23
|
-
|
|
24
|
-
|
|
49
|
+
if flow_matching:
|
|
50
|
+
logger.warning(
|
|
51
|
+
"Flow matching option is ignored for external backends."
|
|
52
|
+
)
|
|
25
53
|
eps = {
|
|
26
|
-
ep.name.lower(): ep
|
|
27
|
-
for ep in entry_points().get("aspire.flows", [])
|
|
54
|
+
ep.name.lower(): ep for ep in entry_points(group="aspire.flows")
|
|
28
55
|
}
|
|
29
56
|
if backend in eps:
|
|
30
57
|
FlowClass = eps[backend].load()
|
|
@@ -35,6 +62,7 @@ def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
|
|
|
35
62
|
)
|
|
36
63
|
return FlowClass, xp
|
|
37
64
|
else:
|
|
65
|
+
known_backends = ["zuko", "flowjax"] + list(eps.keys())
|
|
38
66
|
raise ValueError(
|
|
39
|
-
f"Unknown
|
|
67
|
+
f"Unknown backend '{backend}'. Available backends: {known_backends}"
|
|
40
68
|
)
|
aspire/transforms.py
CHANGED
|
@@ -684,7 +684,7 @@ class FlowPreconditioningTransform(BaseTransform):
|
|
|
684
684
|
self.flow_kwargs.setdefault("dtype", dtype)
|
|
685
685
|
self.fit_kwargs = dict(fit_kwargs or {})
|
|
686
686
|
|
|
687
|
-
FlowClass = get_flow_wrapper(
|
|
687
|
+
FlowClass, xp = get_flow_wrapper(
|
|
688
688
|
backend=flow_backend, flow_matching=flow_matching
|
|
689
689
|
)
|
|
690
690
|
transform = CompositeTransform(
|
|
@@ -695,7 +695,7 @@ class FlowPreconditioningTransform(BaseTransform):
|
|
|
695
695
|
bounded_transform=bounded_transform,
|
|
696
696
|
affine_transform=affine_transform,
|
|
697
697
|
device=device,
|
|
698
|
-
xp=
|
|
698
|
+
xp=xp,
|
|
699
699
|
eps=eps,
|
|
700
700
|
dtype=dtype,
|
|
701
701
|
)
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: aspire-inference
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.0a8
|
|
4
4
|
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
5
|
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
6
|
License: MIT
|
|
7
7
|
Project-URL: Homepage, https://github.com/mj-will/aspire
|
|
8
|
+
Project-URL: Documentation, https://aspire.readthedocs.io/
|
|
8
9
|
Classifier: Programming Language :: Python :: 3
|
|
9
10
|
Requires-Python: >=3.10
|
|
10
11
|
Description-Content-Type: text/markdown
|
|
@@ -38,7 +39,7 @@ Dynamic: license-file
|
|
|
38
39
|
|
|
39
40
|
# aspire: Accelerated Sequential Posterior Inference via REuse
|
|
40
41
|
|
|
41
|
-
aspire is a framework for reusing existing posterior samples to obtain new results at a reduced
|
|
42
|
+
aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
|
|
42
43
|
|
|
43
44
|
## Installation
|
|
44
45
|
|
|
@@ -50,3 +51,16 @@ pip install aspire-inference
|
|
|
50
51
|
|
|
51
52
|
**Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
|
|
52
53
|
the package can be imported and used as `aspire`.
|
|
54
|
+
|
|
55
|
+
## Documentation
|
|
56
|
+
|
|
57
|
+
See the [documentation on ReadTheDocs][docs].
|
|
58
|
+
|
|
59
|
+
## Citation
|
|
60
|
+
|
|
61
|
+
If you use `aspire` in your work please cite the [DOI][DOI] and [paper][paper].
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
[docs]: https://aspire.readthedocs.io/
|
|
65
|
+
[DOI]: https://doi.org/10.5281/zenodo.15658747
|
|
66
|
+
[paper]: https://arxiv.org/abs/2511.04218
|
|
@@ -3,9 +3,9 @@ aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
|
|
|
3
3
|
aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
|
|
4
4
|
aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
|
|
5
5
|
aspire/samples.py,sha256=nVQULOr19lQVyGhitI8EgDdCDGx9sExPomQBJrV4rTc,19237
|
|
6
|
-
aspire/transforms.py,sha256=
|
|
6
|
+
aspire/transforms.py,sha256=W2JL045zkczvdBVXXeSn2c8aKyX8ASxFKgWU2cJHufk,24922
|
|
7
7
|
aspire/utils.py,sha256=pj8O0chqfP6VS8bpW0wCw8W0P5JNQKvWRz1Rg9AYIhg,22525
|
|
8
|
-
aspire/flows/__init__.py,sha256=
|
|
8
|
+
aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
|
|
9
9
|
aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
|
|
10
10
|
aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
|
|
11
11
|
aspire/flows/jax/flows.py,sha256=1HnVgQ1GUXNcvxiZqEV19H2QI9Th5bWX_QbNfGaUhuA,6625
|
|
@@ -21,8 +21,8 @@ aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,1
|
|
|
21
21
|
aspire/samplers/smc/blackjax.py,sha256=4L4kgRKlaWl-knTWXXzdJTh-zZBh5BTpy5GaLDzT8Sc,11803
|
|
22
22
|
aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
|
|
23
23
|
aspire/samplers/smc/minipcn.py,sha256=ju1gcgyKHjodLEACPdL3eXA9ai8ZJ9_LwitD_Gmf1Rc,2765
|
|
24
|
-
aspire_inference-0.1.
|
|
25
|
-
aspire_inference-0.1.
|
|
26
|
-
aspire_inference-0.1.
|
|
27
|
-
aspire_inference-0.1.
|
|
28
|
-
aspire_inference-0.1.
|
|
24
|
+
aspire_inference-0.1.0a8.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
|
|
25
|
+
aspire_inference-0.1.0a8.dist-info/METADATA,sha256=Or73qfnx3KsPNSeaNTJ70t56GME2r9qACM03aSW40t8,1965
|
|
26
|
+
aspire_inference-0.1.0a8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
27
|
+
aspire_inference-0.1.0a8.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
|
|
28
|
+
aspire_inference-0.1.0a8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|