aspire-inference 0.1.0a8__py3-none-any.whl → 0.1.0a10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/__init__.py CHANGED
@@ -6,6 +6,7 @@ import logging
6
6
  from importlib.metadata import PackageNotFoundError, version
7
7
 
8
8
  from .aspire import Aspire
9
+ from .samples import Samples
9
10
 
10
11
  try:
11
12
  __version__ = version("aspire")
@@ -16,4 +17,5 @@ logging.getLogger(__name__).addHandler(logging.NullHandler())
16
17
 
17
18
  __all__ = [
18
19
  "Aspire",
20
+ "Samples",
19
21
  ]
aspire/samplers/base.py CHANGED
@@ -4,7 +4,7 @@ from typing import Any, Callable
4
4
  from ..flows.base import Flow
5
5
  from ..samples import Samples
6
6
  from ..transforms import IdentityTransform
7
- from ..utils import track_calls
7
+ from ..utils import asarray, track_calls
8
8
 
9
9
  logger = logging.getLogger(__name__)
10
10
 
@@ -56,7 +56,11 @@ class Sampler:
56
56
 
57
57
  def fit_preconditioning_transform(self, x):
58
58
  """Fit the data transform to the data."""
59
- x = self.preconditioning_transform.xp.asarray(x)
59
+ x = asarray(
60
+ x,
61
+ xp=self.preconditioning_transform.xp,
62
+ dtype=self.preconditioning_transform.dtype,
63
+ )
60
64
  return self.preconditioning_transform.fit(x)
61
65
 
62
66
  @track_calls
@@ -220,7 +220,9 @@ class BlackJAXSMC(SMCSampler):
220
220
 
221
221
  elif algorithm == "nuts":
222
222
  # Initialize step size and mass matrix if not provided
223
- inverse_mass_matrix = self.sampler_kwargs["inverse_mass_matrix"]
223
+ inverse_mass_matrix = self.sampler_kwargs.get(
224
+ "inverse_mass_matrix"
225
+ )
224
226
  if inverse_mass_matrix is None:
225
227
  inverse_mass_matrix = jax.numpy.eye(self.dims)
226
228
 
@@ -258,8 +260,13 @@ class BlackJAXSMC(SMCSampler):
258
260
  z_final = final_states.position
259
261
 
260
262
  # Calculate acceptance rates
261
- acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
262
- mean_acceptance = jax.numpy.mean(acceptance_rates)
263
+ try:
264
+ acceptance_rates = jax.numpy.mean(
265
+ all_infos.is_accepted, axis=1
266
+ )
267
+ mean_acceptance = jax.numpy.mean(acceptance_rates)
268
+ except AttributeError:
269
+ mean_acceptance = np.nan
263
270
 
264
271
  elif algorithm == "hmc":
265
272
  # Initialize HMC sampler
@@ -294,8 +301,13 @@ class BlackJAXSMC(SMCSampler):
294
301
 
295
302
  final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
296
303
  z_final = final_states.position
297
- acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
298
- mean_acceptance = jax.numpy.mean(acceptance_rates)
304
+ try:
305
+ acceptance_rates = jax.numpy.mean(
306
+ all_infos.is_accepted, axis=1
307
+ )
308
+ mean_acceptance = jax.numpy.mean(acceptance_rates)
309
+ except AttributeError:
310
+ mean_acceptance = np.nan
299
311
 
300
312
  else:
301
313
  raise ValueError(f"Unsupported algorithm: {algorithm}")
@@ -3,7 +3,11 @@ from functools import partial
3
3
  import numpy as np
4
4
 
5
5
  from ...samples import SMCSamples
6
- from ...utils import to_numpy, track_calls
6
+ from ...utils import (
7
+ asarray,
8
+ determine_backend_name,
9
+ track_calls,
10
+ )
7
11
  from .base import NumpySMCSampler
8
12
 
9
13
 
@@ -13,7 +17,7 @@ class MiniPCNSMC(NumpySMCSampler):
13
17
  rng = None
14
18
 
15
19
  def log_prob(self, x, beta=None):
16
- return to_numpy(super().log_prob(x, beta))
20
+ return super().log_prob(x, beta)
17
21
 
18
22
  @track_calls
19
23
  def sample(
@@ -29,11 +33,14 @@ class MiniPCNSMC(NumpySMCSampler):
29
33
  sampler_kwargs: dict | None = None,
30
34
  rng: np.random.Generator | None = None,
31
35
  ):
36
+ from orng import ArrayRNG
37
+
32
38
  self.sampler_kwargs = sampler_kwargs or {}
33
39
  self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
34
40
  self.sampler_kwargs.setdefault("target_acceptance_rate", 0.234)
35
41
  self.sampler_kwargs.setdefault("step_fn", "tpcn")
36
- self.rng = rng or np.random.default_rng()
42
+ self.backend_str = determine_backend_name(xp=self.xp)
43
+ self.rng = rng or ArrayRNG(backend=self.backend_str)
37
44
  return super().sample(
38
45
  n_samples,
39
46
  n_steps=n_steps,
@@ -58,9 +65,14 @@ class MiniPCNSMC(NumpySMCSampler):
58
65
  target_acceptance_rate=self.sampler_kwargs[
59
66
  "target_acceptance_rate"
60
67
  ],
68
+ xp=self.xp,
61
69
  )
62
70
  # Map to transformed dimension for sampling
63
- z = to_numpy(self.fit_preconditioning_transform(particles.x))
71
+ z = asarray(
72
+ self.fit_preconditioning_transform(particles.x),
73
+ xp=self.xp,
74
+ dtype=self.dtype,
75
+ )
64
76
  chain, history = sampler.sample(
65
77
  z,
66
78
  n_steps=n_steps or self.sampler_kwargs["n_steps"],
aspire/samples.py CHANGED
@@ -425,19 +425,23 @@ class Samples(BaseSamples):
425
425
 
426
426
  def to_namespace(self, xp):
427
427
  return self.__class__(
428
- x=asarray(self.x, xp),
428
+ x=asarray(self.x, xp, dtype=self.dtype),
429
429
  parameters=self.parameters,
430
- log_likelihood=asarray(self.log_likelihood, xp)
430
+ log_likelihood=asarray(self.log_likelihood, xp, dtype=self.dtype)
431
431
  if self.log_likelihood is not None
432
432
  else None,
433
- log_prior=asarray(self.log_prior, xp)
433
+ log_prior=asarray(self.log_prior, xp, dtype=self.dtype)
434
434
  if self.log_prior is not None
435
435
  else None,
436
- log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
437
- log_evidence=asarray(self.log_evidence, xp)
436
+ log_q=asarray(self.log_q, xp, dtype=self.dtype)
437
+ if self.log_q is not None
438
+ else None,
439
+ log_evidence=asarray(self.log_evidence, xp, dtype=self.dtype)
438
440
  if self.log_evidence is not None
439
441
  else None,
440
- log_evidence_error=asarray(self.log_evidence_error, xp)
442
+ log_evidence_error=asarray(
443
+ self.log_evidence_error, xp, dtype=self.dtype
444
+ )
441
445
  if self.log_evidence_error is not None
442
446
  else None,
443
447
  )
aspire/transforms.py CHANGED
@@ -332,6 +332,7 @@ class CompositeTransform(BaseTransform):
332
332
  prior_bounds=self.prior_bounds,
333
333
  bounded_to_unbounded=self.bounded_to_unbounded,
334
334
  bounded_transform=self.bounded_transform,
335
+ affine_transform=self.affine_transform,
335
336
  device=self.device,
336
337
  xp=xp or self.xp,
337
338
  eps=self.eps,
aspire/utils.py CHANGED
@@ -12,7 +12,13 @@ import h5py
12
12
  import wrapt
13
13
  from array_api_compat import (
14
14
  array_namespace,
15
+ is_cupy_namespace,
16
+ is_dask_namespace,
15
17
  is_jax_array,
18
+ is_jax_namespace,
19
+ is_ndonnx_namespace,
20
+ is_numpy_namespace,
21
+ is_pydata_sparse_namespace,
16
22
  is_torch_array,
17
23
  is_torch_namespace,
18
24
  to_device,
@@ -28,6 +34,17 @@ if TYPE_CHECKING:
28
34
  logger = logging.getLogger(__name__)
29
35
 
30
36
 
37
+ IS_NAMESPACE_FUNCTIONS = {
38
+ "numpy": is_numpy_namespace,
39
+ "torch": is_torch_namespace,
40
+ "jax": is_jax_namespace,
41
+ "cupy": is_cupy_namespace,
42
+ "dask": is_dask_namespace,
43
+ "pydata_sparse": is_pydata_sparse_namespace,
44
+ "ndonnx": is_ndonnx_namespace,
45
+ }
46
+
47
+
31
48
  def configure_logger(
32
49
  log_level: str | int = "INFO",
33
50
  additional_loggers: list[str] = None,
@@ -234,7 +251,7 @@ def to_numpy(x: Array, **kwargs) -> np.ndarray:
234
251
  return np.asarray(x, **kwargs)
235
252
 
236
253
 
237
- def asarray(x, xp: Any = None, **kwargs) -> Array:
254
+ def asarray(x, xp: Any = None, dtype: Any | None = None, **kwargs) -> Array:
238
255
  """Convert an array to the specified array API.
239
256
 
240
257
  Parameters
@@ -244,13 +261,51 @@ def asarray(x, xp: Any = None, **kwargs) -> Array:
244
261
  xp : Any
245
262
  The array API to use for the conversion. If None, the array API
246
263
  is inferred from the input array.
264
+ dtype : Any | str | None
265
+ The dtype to use for the conversion. If None, the dtype is not changed.
247
266
  kwargs : dict
248
267
  Additional keyword arguments to pass to xp.asarray.
249
268
  """
269
+ # Handle DLPack conversion from JAX to PyTorch to avoid shape issues when
270
+ # passing JAX arrays directly to torch.asarray.
250
271
  if is_jax_array(x) and is_torch_namespace(xp):
251
- return xp.utils.dlpack.from_dlpack(x)
252
- else:
253
- return xp.asarray(x, **kwargs)
272
+ tensor = xp.utils.dlpack.from_dlpack(x)
273
+ if dtype is not None:
274
+ tensor = tensor.to(resolve_dtype(dtype, xp=xp))
275
+ return tensor
276
+
277
+ if dtype is not None:
278
+ kwargs["dtype"] = resolve_dtype(dtype, xp=xp)
279
+ return xp.asarray(x, **kwargs)
280
+
281
+
282
+ def determine_backend_name(
283
+ x: Array | None = None, xp: Any | None = None
284
+ ) -> str:
285
+ """Determine the backend name from an array or array API module.
286
+
287
+ Parameters
288
+ ----------
289
+ x : Array or None
290
+ The array to infer the backend from. If None, xp must be provided.
291
+ xp : Any or None
292
+ The array API module to infer the backend from. If None, x must be provided.
293
+
294
+ Returns
295
+ -------
296
+ str
297
+ The name of the backend. If the backend cannot be determined, returns "unknown".
298
+ """
299
+ if x is not None:
300
+ xp = array_namespace(x)
301
+ if xp is None:
302
+ raise ValueError(
303
+ "Either x or xp must be provided to determine backend."
304
+ )
305
+ for name, is_namespace_fn in IS_NAMESPACE_FUNCTIONS.items():
306
+ if is_namespace_fn(xp):
307
+ return name
308
+ return "unknown"
254
309
 
255
310
 
256
311
  def resolve_dtype(dtype: Any | str | None, xp: Any) -> Any | None:
@@ -0,0 +1,130 @@
1
+ Metadata-Version: 2.4
2
+ Name: aspire-inference
3
+ Version: 0.1.0a10
4
+ Summary: Accelerate Sequential Posterior Inference via REuse
5
+ Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/mj-will/aspire
8
+ Project-URL: Documentation, https://aspire.readthedocs.io/
9
+ Classifier: Programming Language :: Python :: 3
10
+ Requires-Python: >=3.10
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Requires-Dist: matplotlib
14
+ Requires-Dist: numpy
15
+ Requires-Dist: array-api-compat
16
+ Requires-Dist: wrapt
17
+ Requires-Dist: h5py
18
+ Provides-Extra: scipy
19
+ Requires-Dist: scipy; extra == "scipy"
20
+ Provides-Extra: jax
21
+ Requires-Dist: jax; extra == "jax"
22
+ Requires-Dist: jaxlib; extra == "jax"
23
+ Requires-Dist: flowjax; extra == "jax"
24
+ Provides-Extra: torch
25
+ Requires-Dist: torch; extra == "torch"
26
+ Requires-Dist: zuko; extra == "torch"
27
+ Requires-Dist: tqdm; extra == "torch"
28
+ Provides-Extra: minipcn
29
+ Requires-Dist: minipcn[array-api]>=0.2.0a3; extra == "minipcn"
30
+ Requires-Dist: orng; extra == "minipcn"
31
+ Provides-Extra: emcee
32
+ Requires-Dist: emcee; extra == "emcee"
33
+ Provides-Extra: blackjax
34
+ Requires-Dist: blackjax; extra == "blackjax"
35
+ Provides-Extra: test
36
+ Requires-Dist: pytest; extra == "test"
37
+ Requires-Dist: pytest-requires; extra == "test"
38
+ Requires-Dist: pytest-cov; extra == "test"
39
+ Dynamic: license-file
40
+
41
+ # aspire: Accelerated Sequential Posterior Inference via REuse
42
+
43
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15658747.svg)](https://doi.org/10.5281/zenodo.15658747)
44
+ [![PyPI](https://img.shields.io/pypi/v/aspire-inference)](https://pypi.org/project/aspire-inference/)
45
+ [![Documentation Status](https://readthedocs.org/projects/aspire/badge/?version=latest)](https://aspire.readthedocs.io/en/latest/?badge=latest)
46
+ ![tests](https://github.com/mj-will/aspire/actions/workflows/tests.yml/badge.svg)
47
+
48
+
49
+ aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
50
+
51
+ ## Installation
52
+
53
+ aspire can be installed from PyPI using `pip`. The default installation does not include
54
+ a backend for the normalizing flows, so you also need to install either `torch` or `jax`.
55
+ We also recommend installing `minipcn` if using the `smc` sampler:
56
+
57
+
58
+ **Torch**
59
+
60
+ We recommend installing `torch` manually to ensure the correct CPU/CUDA versions are
61
+ installed. See the [PyTorch installation instructions](https://pytorch.org/)
62
+ for more details.
63
+
64
+ ```
65
+ pip install aspire-inference[torch,minipcn]
66
+ ```
67
+
68
+ **Jax**
69
+
70
+ We recommend installing `jax` manually to ensure the correct GPU/CUDA versions
71
+ are installed. See the [jax installation documentation](https://docs.jax.dev/en/latest/installation.html) for details.
72
+
73
+ ```
74
+ pip install aspire-inference[jax,minipcn]
75
+ ```
76
+
77
+ **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
78
+ the package can be imported and used as `aspire`.
79
+
80
+ ## Quickstart
81
+
82
+ ```python
83
+ import numpy as np
84
+ from aspire import Aspire, Samples
85
+
86
+ # Define a log-likelihood and log-prior
87
+ def log_likelihood(samples):
88
+ x = samples.x
89
+ return -0.5 * np.sum(x**2, axis=-1)
90
+
91
+ def log_prior(samples):
92
+ return -0.5 * np.sum(samples.x**2, axis=-1)
93
+
94
+ # Create the initial samples
95
+ init = Samples(np.random.normal(size=(2_000, 4)))
96
+
97
+ # Define the aspire object
98
+ aspire = Aspire(
99
+ log_likelihood=log_likelihood,
100
+ log_prior=log_prior,
101
+ dims=4,
102
+ parameters=[f"x{i}" for i in range(4)],
103
+ )
104
+
105
+ # Fit the normalizing flow
106
+ aspire.fit(init, n_epochs=20)
107
+
108
+ # Sample the posterior
109
+ posterior = aspire.sample_posterior(
110
+ sampler="smc",
111
+ n_samples=500,
112
+ sampler_kwargs=dict(n_steps=100),
113
+ )
114
+
115
+ # Plot the posterior distribution
116
+ posterior.plot_corner()
117
+ ```
118
+
119
+ ## Documentation
120
+
121
+ See the [documentation on ReadTheDocs][docs].
122
+
123
+ ## Citation
124
+
125
+ If you use `aspire` in your work, please cite the [DOI][DOI] and [paper][paper].
126
+
127
+
128
+ [docs]: https://aspire.readthedocs.io/
129
+ [DOI]: https://doi.org/10.5281/zenodo.15658747
130
+ [paper]: https://arxiv.org/abs/2511.04218
@@ -1,10 +1,10 @@
1
- aspire/__init__.py,sha256=45R0xWaLg0aJEPK5zoTK0aIek0KOwpHwQWS1jLCDhIE,365
1
+ aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
2
2
  aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
3
3
  aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
4
4
  aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
5
- aspire/samples.py,sha256=nVQULOr19lQVyGhitI8EgDdCDGx9sExPomQBJrV4rTc,19237
6
- aspire/transforms.py,sha256=W2JL045zkczvdBVXXeSn2c8aKyX8ASxFKgWU2cJHufk,24922
7
- aspire/utils.py,sha256=pj8O0chqfP6VS8bpW0wCw8W0P5JNQKvWRz1Rg9AYIhg,22525
5
+ aspire/samples.py,sha256=z5x5hpWuVFH1hYhltmROAe8pbWxGD2UvHi3vcc132dg,19399
6
+ aspire/transforms.py,sha256=CHrfPQHEyHQ9I0WWiAgUplWwxYypZsD7uCYIHUbSFtY,24974
7
+ aspire/utils.py,sha256=eKGmchpuoNL15Xbu-AGoeZ00PcQEykEQiDZMnnRyV6A,24234
8
8
  aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
9
9
  aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
10
10
  aspire/flows/jax/__init__.py,sha256=7cmiY_MbEC8RDA8Cmi8HVnNJm0sqFKlBsDethdsy5lA,52
@@ -13,16 +13,16 @@ aspire/flows/jax/utils.py,sha256=5T6UrgpARG9VywC9qmTl45LjyZWuEdkW3XUladE6xJE,151
13
13
  aspire/flows/torch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  aspire/flows/torch/flows.py,sha256=0_YkiMT49QolyQnEFsh28tfKLnURVF0Z6aTnaWLIUDI,11672
15
15
  aspire/samplers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- aspire/samplers/base.py,sha256=8slvgOBnacUrHXCVDAqo-3IZ_LB7-dS8wdMP55MI43Y,2907
16
+ aspire/samplers/base.py,sha256=VEHawyVA33jXHMo63p5hBHkp9k2qxU_bOxh5iaZSXew,3011
17
17
  aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M,676
18
18
  aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
19
19
  aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
20
  aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,10729
21
- aspire/samplers/smc/blackjax.py,sha256=4L4kgRKlaWl-knTWXXzdJTh-zZBh5BTpy5GaLDzT8Sc,11803
21
+ aspire/samplers/smc/blackjax.py,sha256=IcTguAETiPmgFofmVW2GN40P5HBIxkmyd2VR8AU8f4k,12115
22
22
  aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
23
- aspire/samplers/smc/minipcn.py,sha256=ju1gcgyKHjodLEACPdL3eXA9ai8ZJ9_LwitD_Gmf1Rc,2765
24
- aspire_inference-0.1.0a8.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
- aspire_inference-0.1.0a8.dist-info/METADATA,sha256=Or73qfnx3KsPNSeaNTJ70t56GME2r9qACM03aSW40t8,1965
26
- aspire_inference-0.1.0a8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
- aspire_inference-0.1.0a8.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
- aspire_inference-0.1.0a8.dist-info/RECORD,,
23
+ aspire/samplers/smc/minipcn.py,sha256=iQUBBwHZ_D4CxNjARMngklRvx6yTlEDKdeidyYCgqM4,3003
24
+ aspire_inference-0.1.0a10.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
+ aspire_inference-0.1.0a10.dist-info/METADATA,sha256=SdvlXKjQn0uJQPpWzoZKAH3oMJSKpZnvlUzxPsIwNlY,3869
26
+ aspire_inference-0.1.0a10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
+ aspire_inference-0.1.0a10.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
+ aspire_inference-0.1.0a10.dist-info/RECORD,,
@@ -1,66 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: aspire-inference
3
- Version: 0.1.0a8
4
- Summary: Accelerate Sequential Posterior Inference via REuse
5
- Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
- License: MIT
7
- Project-URL: Homepage, https://github.com/mj-will/aspire
8
- Project-URL: Documentation, https://aspire.readthedocs.io/
9
- Classifier: Programming Language :: Python :: 3
10
- Requires-Python: >=3.10
11
- Description-Content-Type: text/markdown
12
- License-File: LICENSE
13
- Requires-Dist: matplotlib
14
- Requires-Dist: numpy
15
- Requires-Dist: array-api-compat
16
- Requires-Dist: wrapt
17
- Requires-Dist: h5py
18
- Provides-Extra: scipy
19
- Requires-Dist: scipy; extra == "scipy"
20
- Provides-Extra: jax
21
- Requires-Dist: jax; extra == "jax"
22
- Requires-Dist: jaxlib; extra == "jax"
23
- Requires-Dist: flowjax; extra == "jax"
24
- Provides-Extra: torch
25
- Requires-Dist: torch; extra == "torch"
26
- Requires-Dist: zuko; extra == "torch"
27
- Requires-Dist: tqdm; extra == "torch"
28
- Provides-Extra: minipcn
29
- Requires-Dist: minipcn; extra == "minipcn"
30
- Provides-Extra: emcee
31
- Requires-Dist: emcee; extra == "emcee"
32
- Provides-Extra: blackjax
33
- Requires-Dist: blackjax; extra == "blackjax"
34
- Provides-Extra: test
35
- Requires-Dist: pytest; extra == "test"
36
- Requires-Dist: pytest-requires; extra == "test"
37
- Requires-Dist: pytest-cov; extra == "test"
38
- Dynamic: license-file
39
-
40
- # aspire: Accelerated Sequential Posterior Inference via REuse
41
-
42
- aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
43
-
44
- ## Installation
45
-
46
- aspire can be installed from PyPI using `pip`
47
-
48
- ```
49
- pip install aspire-inference
50
- ```
51
-
52
- **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
53
- the package can be imported and used as `aspire`.
54
-
55
- ## Documentation
56
-
57
- See the [documentation on ReadTheDocs][docs].
58
-
59
- ## Citation
60
-
61
- If you use `aspire` in your work please cite the [DOI][DOI] and [paper][paper].
62
-
63
-
64
- [docs]: https://aspire.readthedocs.io/
65
- [DOI]: https://doi.org/10.5281/zenodo.15658747
66
- [paper]: https://arxiv.org/abs/2511.04218