aspire-inference 0.1.0a7__py3-none-any.whl → 0.1.0a9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/__init__.py CHANGED
@@ -6,6 +6,7 @@ import logging
6
6
  from importlib.metadata import PackageNotFoundError, version
7
7
 
8
8
  from .aspire import Aspire
9
+ from .samples import Samples
9
10
 
10
11
  try:
11
12
  __version__ = version("aspire")
@@ -16,4 +17,5 @@ logging.getLogger(__name__).addHandler(logging.NullHandler())
16
17
 
17
18
  __all__ = [
18
19
  "Aspire",
20
+ "Samples",
19
21
  ]
aspire/flows/__init__.py CHANGED
@@ -1,5 +1,31 @@
1
- def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
2
- """Get the wrapper for the flow implementation."""
1
+ import logging
2
+ from typing import Any
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+
7
+ def get_flow_wrapper(
8
+ backend: str = "zuko", flow_matching: bool = False
9
+ ) -> tuple[type, Any]:
10
+ """Get the wrapper for the flow implementation.
11
+
12
+ Parameters
13
+ ----------
14
+ backend : str
15
+ The backend to use. Options are "zuko" (PyTorch), "flowjax" (JAX), or
16
+ any other registered flow class via entry points. Default is "zuko".
17
+ flow_matching : bool, optional
18
+ Whether to use flow matching variant of the flow. Default is False.
19
+
20
+ Returns
21
+ -------
22
+ FlowClass : type
23
+ The flow class corresponding to the specified backend.
24
+ xp : module
25
+ The array API module corresponding to the specified backend.
26
+ """
27
+ from importlib.metadata import entry_points
28
+
3
29
  if backend == "zuko":
4
30
  import array_api_compat.torch as torch_api
5
31
 
@@ -20,11 +46,12 @@ def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
20
46
  )
21
47
  return FlowJax, jnp
22
48
  else:
23
- from importlib.metadata import entry_points
24
-
49
+ if flow_matching:
50
+ logger.warning(
51
+ "Flow matching option is ignored for external backends."
52
+ )
25
53
  eps = {
26
- ep.name.lower(): ep
27
- for ep in entry_points().get("aspire.flows", [])
54
+ ep.name.lower(): ep for ep in entry_points(group="aspire.flows")
28
55
  }
29
56
  if backend in eps:
30
57
  FlowClass = eps[backend].load()
@@ -35,6 +62,7 @@ def get_flow_wrapper(backend: str = "zuko", flow_matching: bool = False):
35
62
  )
36
63
  return FlowClass, xp
37
64
  else:
65
+ known_backends = ["zuko", "flowjax"] + list(eps.keys())
38
66
  raise ValueError(
39
- f"Unknown flow class: {backend}. Available classes: {list(eps.keys())}"
67
+ f"Unknown backend '{backend}'. Available backends: {known_backends}"
40
68
  )
@@ -220,7 +220,9 @@ class BlackJAXSMC(SMCSampler):
220
220
 
221
221
  elif algorithm == "nuts":
222
222
  # Initialize step size and mass matrix if not provided
223
- inverse_mass_matrix = self.sampler_kwargs["inverse_mass_matrix"]
223
+ inverse_mass_matrix = self.sampler_kwargs.get(
224
+ "inverse_mass_matrix"
225
+ )
224
226
  if inverse_mass_matrix is None:
225
227
  inverse_mass_matrix = jax.numpy.eye(self.dims)
226
228
 
@@ -258,8 +260,13 @@ class BlackJAXSMC(SMCSampler):
258
260
  z_final = final_states.position
259
261
 
260
262
  # Calculate acceptance rates
261
- acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
262
- mean_acceptance = jax.numpy.mean(acceptance_rates)
263
+ try:
264
+ acceptance_rates = jax.numpy.mean(
265
+ all_infos.is_accepted, axis=1
266
+ )
267
+ mean_acceptance = jax.numpy.mean(acceptance_rates)
268
+ except AttributeError:
269
+ mean_acceptance = np.nan
263
270
 
264
271
  elif algorithm == "hmc":
265
272
  # Initialize HMC sampler
@@ -294,8 +301,13 @@ class BlackJAXSMC(SMCSampler):
294
301
 
295
302
  final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
296
303
  z_final = final_states.position
297
- acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
298
- mean_acceptance = jax.numpy.mean(acceptance_rates)
304
+ try:
305
+ acceptance_rates = jax.numpy.mean(
306
+ all_infos.is_accepted, axis=1
307
+ )
308
+ mean_acceptance = jax.numpy.mean(acceptance_rates)
309
+ except AttributeError:
310
+ mean_acceptance = np.nan
299
311
 
300
312
  else:
301
313
  raise ValueError(f"Unsupported algorithm: {algorithm}")
aspire/transforms.py CHANGED
@@ -332,6 +332,7 @@ class CompositeTransform(BaseTransform):
332
332
  prior_bounds=self.prior_bounds,
333
333
  bounded_to_unbounded=self.bounded_to_unbounded,
334
334
  bounded_transform=self.bounded_transform,
335
+ affine_transform=self.affine_transform,
335
336
  device=self.device,
336
337
  xp=xp or self.xp,
337
338
  eps=self.eps,
@@ -684,7 +685,7 @@ class FlowPreconditioningTransform(BaseTransform):
684
685
  self.flow_kwargs.setdefault("dtype", dtype)
685
686
  self.fit_kwargs = dict(fit_kwargs or {})
686
687
 
687
- FlowClass = get_flow_wrapper(
688
+ FlowClass, xp = get_flow_wrapper(
688
689
  backend=flow_backend, flow_matching=flow_matching
689
690
  )
690
691
  transform = CompositeTransform(
@@ -695,7 +696,7 @@ class FlowPreconditioningTransform(BaseTransform):
695
696
  bounded_transform=bounded_transform,
696
697
  affine_transform=affine_transform,
697
698
  device=device,
698
- xp=FlowClass.xp,
699
+ xp=xp,
699
700
  eps=eps,
700
701
  dtype=dtype,
701
702
  )
@@ -0,0 +1,111 @@
1
+ Metadata-Version: 2.4
2
+ Name: aspire-inference
3
+ Version: 0.1.0a9
4
+ Summary: Accelerated Sequential Posterior Inference via REuse
5
+ Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/mj-will/aspire
8
+ Project-URL: Documentation, https://aspire.readthedocs.io/
9
+ Classifier: Programming Language :: Python :: 3
10
+ Requires-Python: >=3.10
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Requires-Dist: matplotlib
14
+ Requires-Dist: numpy
15
+ Requires-Dist: array-api-compat
16
+ Requires-Dist: wrapt
17
+ Requires-Dist: h5py
18
+ Provides-Extra: scipy
19
+ Requires-Dist: scipy; extra == "scipy"
20
+ Provides-Extra: jax
21
+ Requires-Dist: jax; extra == "jax"
22
+ Requires-Dist: jaxlib; extra == "jax"
23
+ Requires-Dist: flowjax; extra == "jax"
24
+ Provides-Extra: torch
25
+ Requires-Dist: torch; extra == "torch"
26
+ Requires-Dist: zuko; extra == "torch"
27
+ Requires-Dist: tqdm; extra == "torch"
28
+ Provides-Extra: minipcn
29
+ Requires-Dist: minipcn; extra == "minipcn"
30
+ Provides-Extra: emcee
31
+ Requires-Dist: emcee; extra == "emcee"
32
+ Provides-Extra: blackjax
33
+ Requires-Dist: blackjax; extra == "blackjax"
34
+ Provides-Extra: test
35
+ Requires-Dist: pytest; extra == "test"
36
+ Requires-Dist: pytest-requires; extra == "test"
37
+ Requires-Dist: pytest-cov; extra == "test"
38
+ Dynamic: license-file
39
+
40
+ # aspire: Accelerated Sequential Posterior Inference via REuse
41
+
42
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15658747.svg)](https://doi.org/10.5281/zenodo.15658747)
43
+ [![PyPI](https://img.shields.io/pypi/v/aspire-inference)](https://pypi.org/project/aspire-inference/)
44
+ [![Documentation Status](https://readthedocs.org/projects/aspire/badge/?version=latest)](https://aspire.readthedocs.io/en/latest/?badge=latest)
45
+ ![tests](https://github.com/mj-will/aspire/actions/workflows/tests.yml/badge.svg)
46
+
47
+
48
+ aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
49
+
50
+ ## Installation
51
+
52
+ aspire can be installed from PyPI using `pip`
53
+
54
+ ```
55
+ pip install aspire-inference
56
+ ```
57
+
58
+ **Important:** the name of `aspire` on PyPI is `aspire-inference`, but once installed
59
+ the package can be imported and used as `aspire`.
60
+
61
+ ## Quickstart
62
+
63
+ ```python
64
+ import numpy as np
65
+ from aspire import Aspire, Samples
66
+
67
+ # Define a log-likelihood and log-prior
68
+ def log_likelihood(samples):
69
+ x = samples.x
70
+ return -0.5 * np.sum(x**2, axis=-1)
71
+
72
+ def log_prior(samples):
73
+ return -0.5 * np.sum(samples.x**2, axis=-1)
74
+
75
+ # Create the initial samples
76
+ init = Samples(np.random.normal(size=(2_000, 4)))
77
+
78
+ # Define the aspire object
79
+ aspire = Aspire(
80
+ log_likelihood=log_likelihood,
81
+ log_prior=log_prior,
82
+ dims=4,
83
+ parameters=[f"x{i}" for i in range(4)],
84
+ )
85
+
86
+ # Fit the normalizing flow
87
+ aspire.fit(init, n_epochs=20)
88
+
89
+ # Sample the posterior
90
+ posterior = aspire.sample_posterior(
91
+ sampler="smc",
92
+ n_samples=500,
93
+ sampler_kwargs=dict(n_steps=100),
94
+ )
95
+
96
+ # Plot the posterior distribution
97
+ posterior.plot_corner()
98
+ ```
99
+
100
+ ## Documentation
101
+
102
+ See the [documentation on ReadTheDocs][docs].
103
+
104
+ ## Citation
105
+
106
+ If you use `aspire` in your work, please cite the [DOI][DOI] and [paper][paper].
107
+
108
+
109
+ [docs]: https://aspire.readthedocs.io/
110
+ [DOI]: https://doi.org/10.5281/zenodo.15658747
111
+ [paper]: https://arxiv.org/abs/2511.04218
@@ -1,11 +1,11 @@
1
- aspire/__init__.py,sha256=45R0xWaLg0aJEPK5zoTK0aIek0KOwpHwQWS1jLCDhIE,365
1
+ aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
2
2
  aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
3
3
  aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
4
4
  aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
5
5
  aspire/samples.py,sha256=nVQULOr19lQVyGhitI8EgDdCDGx9sExPomQBJrV4rTc,19237
6
- aspire/transforms.py,sha256=XMbf5MxK49elQeKDsmFraHN-0JeO1AciljdTk7k2ujk,24928
6
+ aspire/transforms.py,sha256=CHrfPQHEyHQ9I0WWiAgUplWwxYypZsD7uCYIHUbSFtY,24974
7
7
  aspire/utils.py,sha256=pj8O0chqfP6VS8bpW0wCw8W0P5JNQKvWRz1Rg9AYIhg,22525
8
- aspire/flows/__init__.py,sha256=3gGXF4HziMlZSmcEdJ_uHtrP-QEC6RXvylm4vtM-Xnk,1306
8
+ aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
9
9
  aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
10
10
  aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
11
11
  aspire/flows/jax/flows.py,sha256=1HnVgQ1GUXNcvxiZqEV19H2QI9Th5bWX_QbNfGaUhuA,6625
@@ -18,11 +18,11 @@ aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M
18
18
  aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
19
19
  aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
20
  aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,10729
21
- aspire/samplers/smc/blackjax.py,sha256=4L4kgRKlaWl-knTWXXzdJTh-zZBh5BTpy5GaLDzT8Sc,11803
21
+ aspire/samplers/smc/blackjax.py,sha256=IcTguAETiPmgFofmVW2GN40P5HBIxkmyd2VR8AU8f4k,12115
22
22
  aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
23
23
  aspire/samplers/smc/minipcn.py,sha256=ju1gcgyKHjodLEACPdL3eXA9ai8ZJ9_LwitD_Gmf1Rc,2765
24
- aspire_inference-0.1.0a7.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
- aspire_inference-0.1.0a7.dist-info/METADATA,sha256=9TJxm0A3UrOIybNTr2CdJVVX9bbeztLqvFkhVO_pdu0,1617
26
- aspire_inference-0.1.0a7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
- aspire_inference-0.1.0a7.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
- aspire_inference-0.1.0a7.dist-info/RECORD,,
24
+ aspire_inference-0.1.0a9.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
+ aspire_inference-0.1.0a9.dist-info/METADATA,sha256=DCcTGComTNt9UFlbUwGPf58itBX6Y-1-edVys0pb2RQ,3187
26
+ aspire_inference-0.1.0a9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
+ aspire_inference-0.1.0a9.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
+ aspire_inference-0.1.0a9.dist-info/RECORD,,
@@ -1,52 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: aspire-inference
3
- Version: 0.1.0a7
4
- Summary: Accelerate Sequential Posterior Inference via REuse
5
- Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
- License: MIT
7
- Project-URL: Homepage, https://github.com/mj-will/aspire
8
- Classifier: Programming Language :: Python :: 3
9
- Requires-Python: >=3.10
10
- Description-Content-Type: text/markdown
11
- License-File: LICENSE
12
- Requires-Dist: matplotlib
13
- Requires-Dist: numpy
14
- Requires-Dist: array-api-compat
15
- Requires-Dist: wrapt
16
- Requires-Dist: h5py
17
- Provides-Extra: scipy
18
- Requires-Dist: scipy; extra == "scipy"
19
- Provides-Extra: jax
20
- Requires-Dist: jax; extra == "jax"
21
- Requires-Dist: jaxlib; extra == "jax"
22
- Requires-Dist: flowjax; extra == "jax"
23
- Provides-Extra: torch
24
- Requires-Dist: torch; extra == "torch"
25
- Requires-Dist: zuko; extra == "torch"
26
- Requires-Dist: tqdm; extra == "torch"
27
- Provides-Extra: minipcn
28
- Requires-Dist: minipcn; extra == "minipcn"
29
- Provides-Extra: emcee
30
- Requires-Dist: emcee; extra == "emcee"
31
- Provides-Extra: blackjax
32
- Requires-Dist: blackjax; extra == "blackjax"
33
- Provides-Extra: test
34
- Requires-Dist: pytest; extra == "test"
35
- Requires-Dist: pytest-requires; extra == "test"
36
- Requires-Dist: pytest-cov; extra == "test"
37
- Dynamic: license-file
38
-
39
- # aspire: Accelerated Sequential Posterior Inference via REuse
40
-
41
- aspire is a framework for reusing existing posterior samples to obtain new results at a reduced code.
42
-
43
- ## Installation
44
-
45
- aspire can be installed from PyPI using `pip`
46
-
47
- ```
48
- pip install aspire-inference
49
- ```
50
-
51
- **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
52
- the package can be imported and used as `aspire`.