aspire-inference 0.1.0a8__tar.gz → 0.1.0a9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/PKG-INFO +46 -1
  2. aspire_inference-0.1.0a9/README.md +72 -0
  3. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/aspire_inference.egg-info/PKG-INFO +46 -1
  4. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/aspire_inference.egg-info/SOURCES.txt +3 -1
  5. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/docs/conf.py +14 -0
  6. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/docs/index.rst +3 -3
  7. aspire_inference-0.1.0a9/docs/multiprocessing.rst +70 -0
  8. aspire_inference-0.1.0a9/docs/recipes.rst +70 -0
  9. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/docs/requirements.txt +1 -0
  10. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/docs/user_guide.rst +51 -7
  11. aspire_inference-0.1.0a9/examples/blackjax_smc_example.py +158 -0
  12. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/__init__.py +2 -0
  13. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/samplers/smc/blackjax.py +17 -5
  14. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/transforms.py +1 -0
  15. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/tests/integration_tests/conftest.py +17 -4
  16. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/tests/integration_tests/test_integration.py +2 -2
  17. aspire_inference-0.1.0a8/README.md +0 -27
  18. aspire_inference-0.1.0a8/docs/api.rst +0 -28
  19. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/.github/workflows/lint.yml +0 -0
  20. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/.github/workflows/publish.yml +0 -0
  21. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/.github/workflows/tests.yml +0 -0
  22. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/.gitignore +0 -0
  23. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/.pre-commit-config.yaml +0 -0
  24. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/LICENSE +0 -0
  25. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/aspire_inference.egg-info/dependency_links.txt +0 -0
  26. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/aspire_inference.egg-info/requires.txt +0 -0
  27. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/aspire_inference.egg-info/top_level.txt +0 -0
  28. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/docs/Makefile +0 -0
  29. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/docs/entry_points.rst +0 -0
  30. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/docs/examples.rst +0 -0
  31. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/docs/installation.rst +0 -0
  32. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/examples/basic_example.py +0 -0
  33. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/examples/smc_example.py +0 -0
  34. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/pyproject.toml +0 -0
  35. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/readthedocs.yml +0 -0
  36. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/setup.cfg +0 -0
  37. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/aspire.py +0 -0
  38. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/flows/__init__.py +0 -0
  39. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/flows/base.py +0 -0
  40. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/flows/jax/__init__.py +0 -0
  41. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/flows/jax/flows.py +0 -0
  42. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/flows/jax/utils.py +0 -0
  43. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/flows/torch/__init__.py +0 -0
  44. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/flows/torch/flows.py +0 -0
  45. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/history.py +0 -0
  46. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/plot.py +0 -0
  47. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/samplers/__init__.py +0 -0
  48. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/samplers/base.py +0 -0
  49. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/samplers/importance.py +0 -0
  50. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/samplers/mcmc.py +0 -0
  51. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/samplers/smc/__init__.py +0 -0
  52. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/samplers/smc/base.py +0 -0
  53. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/samplers/smc/emcee.py +0 -0
  54. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/samplers/smc/minipcn.py +0 -0
  55. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/samples.py +0 -0
  56. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/src/aspire/utils.py +0 -0
  57. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/tests/conftest.py +0 -0
  58. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/tests/test_flows/test_flows_core.py +0 -0
  59. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/tests/test_flows/test_jax_flows/test_flowjax_flows.py +0 -0
  60. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/tests/test_flows/test_torch_flows/test_zuko_flows.py +0 -0
  61. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/tests/test_samples.py +0 -0
  62. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/tests/test_transforms.py +0 -0
  63. {aspire_inference-0.1.0a8 → aspire_inference-0.1.0a9}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a8
3
+ Version: 0.1.0a9
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -39,6 +39,12 @@ Dynamic: license-file
39
39
 
40
40
  # aspire: Accelerated Sequential Posterior Inference via REuse
41
41
 
42
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15658747.svg)](https://doi.org/10.5281/zenodo.15658747)
43
+ [![PyPI](https://img.shields.io/pypi/v/aspire-inference)](https://pypi.org/project/aspire-inference/)
44
+ [![Documentation Status](https://readthedocs.org/projects/aspire/badge/?version=latest)](https://aspire.readthedocs.io/en/latest/?badge=latest)
45
+ ![tests](https://github.com/mj-will/aspire/actions/workflows/tests.yml/badge.svg)
46
+
47
+
42
48
  aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
43
49
 
44
50
  ## Installation
@@ -52,6 +58,45 @@ pip install aspire-inference
52
58
  **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
53
59
  the package can be imported and used as `aspire`.
54
60
 
61
+ ## Quickstart
62
+
63
+ ```python
64
+ import numpy as np
65
+ from aspire import Aspire, Samples
66
+
67
+ # Define a log-likelihood and log-prior
68
+ def log_likelihood(samples):
69
+ x = samples.x
70
+ return -0.5 * np.sum(x**2, axis=-1)
71
+
72
+ def log_prior(samples):
73
+ return -0.5 * np.sum(samples.x**2, axis=-1)
74
+
75
+ # Create the initial samples
76
+ init = Samples(np.random.normal(size=(2_000, 4)))
77
+
78
+ # Define the aspire object
79
+ aspire = Aspire(
80
+ log_likelihood=log_likelihood,
81
+ log_prior=log_prior,
82
+ dims=4,
83
+ parameters=[f"x{i}" for i in range(4)],
84
+ )
85
+
86
+ # Fit the normalizing flow
87
+ aspire.fit(init, n_epochs=20)
88
+
89
+ # Sample the posterior
90
+ posterior = aspire.sample_posterior(
91
+ sampler="smc",
92
+ n_samples=500,
93
+ sampler_kwargs=dict(n_steps=100),
94
+ )
95
+
96
+ # Plot the posterior distribution
97
+ posterior.plot_corner()
98
+ ```
99
+
55
100
  ## Documentation
56
101
 
57
102
  See the [documentation on ReadTheDocs][docs].
@@ -0,0 +1,72 @@
1
+ # aspire: Accelerated Sequential Posterior Inference via REuse
2
+
3
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15658747.svg)](https://doi.org/10.5281/zenodo.15658747)
4
+ [![PyPI](https://img.shields.io/pypi/v/aspire-inference)](https://pypi.org/project/aspire-inference/)
5
+ [![Documentation Status](https://readthedocs.org/projects/aspire/badge/?version=latest)](https://aspire.readthedocs.io/en/latest/?badge=latest)
6
+ ![tests](https://github.com/mj-will/aspire/actions/workflows/tests.yml/badge.svg)
7
+
8
+
9
+ aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
10
+
11
+ ## Installation
12
+
13
+ aspire can be installed from PyPI using `pip`
14
+
15
+ ```
16
+ pip install aspire-inference
17
+ ```
18
+
19
+ **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
20
+ the package can be imported and used as `aspire`.
21
+
22
+ ## Quickstart
23
+
24
+ ```python
25
+ import numpy as np
26
+ from aspire import Aspire, Samples
27
+
28
+ # Define a log-likelihood and log-prior
29
+ def log_likelihood(samples):
30
+ x = samples.x
31
+ return -0.5 * np.sum(x**2, axis=-1)
32
+
33
+ def log_prior(samples):
34
+ return -0.5 * np.sum(samples.x**2, axis=-1)
35
+
36
+ # Create the initial samples
37
+ init = Samples(np.random.normal(size=(2_000, 4)))
38
+
39
+ # Define the aspire object
40
+ aspire = Aspire(
41
+ log_likelihood=log_likelihood,
42
+ log_prior=log_prior,
43
+ dims=4,
44
+ parameters=[f"x{i}" for i in range(4)],
45
+ )
46
+
47
+ # Fit the normalizing flow
48
+ aspire.fit(init, n_epochs=20)
49
+
50
+ # Sample the posterior
51
+ posterior = aspire.sample_posterior(
52
+ sampler="smc",
53
+ n_samples=500,
54
+ sampler_kwargs=dict(n_steps=100),
55
+ )
56
+
57
+ # Plot the posterior distribution
58
+ posterior.plot_corner()
59
+ ```
60
+
61
+ ## Documentation
62
+
63
+ See the [documentation on ReadTheDocs][docs].
64
+
65
+ ## Citation
66
+
67
+ If you use `aspire` in your work please cite the [DOI][DOI] and [paper][paper].
68
+
69
+
70
+ [docs]: https://aspire.readthedocs.io/
71
+ [DOI]: https://doi.org/10.5281/zenodo.15658747
72
+ [paper]: https://arxiv.org/abs/2511.04218
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a8
3
+ Version: 0.1.0a9
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -39,6 +39,12 @@ Dynamic: license-file
39
39
 
40
40
  # aspire: Accelerated Sequential Posterior Inference via REuse
41
41
 
42
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15658747.svg)](https://doi.org/10.5281/zenodo.15658747)
43
+ [![PyPI](https://img.shields.io/pypi/v/aspire-inference)](https://pypi.org/project/aspire-inference/)
44
+ [![Documentation Status](https://readthedocs.org/projects/aspire/badge/?version=latest)](https://aspire.readthedocs.io/en/latest/?badge=latest)
45
+ ![tests](https://github.com/mj-will/aspire/actions/workflows/tests.yml/badge.svg)
46
+
47
+
42
48
  aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
43
49
 
44
50
  ## Installation
@@ -52,6 +58,45 @@ pip install aspire-inference
52
58
  **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
53
59
  the package can be imported and used as `aspire`.
54
60
 
61
+ ## Quickstart
62
+
63
+ ```python
64
+ import numpy as np
65
+ from aspire import Aspire, Samples
66
+
67
+ # Define a log-likelihood and log-prior
68
+ def log_likelihood(samples):
69
+ x = samples.x
70
+ return -0.5 * np.sum(x**2, axis=-1)
71
+
72
+ def log_prior(samples):
73
+ return -0.5 * np.sum(samples.x**2, axis=-1)
74
+
75
+ # Create the initial samples
76
+ init = Samples(np.random.normal(size=(2_000, 4)))
77
+
78
+ # Define the aspire object
79
+ aspire = Aspire(
80
+ log_likelihood=log_likelihood,
81
+ log_prior=log_prior,
82
+ dims=4,
83
+ parameters=[f"x{i}" for i in range(4)],
84
+ )
85
+
86
+ # Fit the normalizing flow
87
+ aspire.fit(init, n_epochs=20)
88
+
89
+ # Sample the posterior
90
+ posterior = aspire.sample_posterior(
91
+ sampler="smc",
92
+ n_samples=500,
93
+ sampler_kwargs=dict(n_steps=100),
94
+ )
95
+
96
+ # Plot the posterior distribution
97
+ posterior.plot_corner()
98
+ ```
99
+
55
100
  ## Documentation
56
101
 
57
102
  See the [documentation on ReadTheDocs][docs].
@@ -13,15 +13,17 @@ aspire_inference.egg-info/dependency_links.txt
13
13
  aspire_inference.egg-info/requires.txt
14
14
  aspire_inference.egg-info/top_level.txt
15
15
  docs/Makefile
16
- docs/api.rst
17
16
  docs/conf.py
18
17
  docs/entry_points.rst
19
18
  docs/examples.rst
20
19
  docs/index.rst
21
20
  docs/installation.rst
21
+ docs/multiprocessing.rst
22
+ docs/recipes.rst
22
23
  docs/requirements.txt
23
24
  docs/user_guide.rst
24
25
  examples/basic_example.py
26
+ examples/blackjax_smc_example.py
25
27
  examples/smc_example.py
26
28
  src/aspire/__init__.py
27
29
  src/aspire/aspire.py
@@ -29,6 +29,7 @@ extensions = [
29
29
  "sphinx.ext.autosummary",
30
30
  "sphinx.ext.napoleon",
31
31
  "sphinx.ext.viewcode",
32
+ "autoapi.extension",
32
33
  ]
33
34
 
34
35
  autodoc_typehints = "description"
@@ -41,6 +42,19 @@ napoleon_preprocess_types = True
41
42
  templates_path = ["_templates"]
42
43
  exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
43
44
 
45
+ # -- Configure autoapi -------------------------------------------------------
46
+
47
+ autoapi_type = "python"
48
+ autoapi_dirs = ["../src/aspire/"]
49
+ autoapi_add_toctree_entry = True
50
+ autoapi_options = [
51
+ "members",
52
+ "imported-members",
53
+ "show-inheritance",
54
+ "show-module-summary",
55
+ "undoc-members",
56
+ ]
57
+
44
58
  # -- Options for HTML output -------------------------------------------------
45
59
 
46
60
  html_theme = "sphinx_book_theme"
@@ -22,8 +22,7 @@ Quick start
22
22
  .. code-block:: python
23
23
 
24
24
  import numpy as np
25
- from aspire import Aspire
26
- from aspire.samples import Samples
25
+ from aspire import Aspire, Samples
27
26
 
28
27
  def log_likelihood(samples):
29
28
  x = samples.x
@@ -58,6 +57,7 @@ examples, and the complete API reference.
58
57
 
59
58
  installation
60
59
  user_guide
60
+ recipes
61
+ multiprocessing
61
62
  examples
62
63
  entry_points
63
- api
@@ -0,0 +1,70 @@
1
+ Multiprocessing
2
+ ===============
3
+
4
+ Use :meth:`aspire.Aspire.enable_pool` to run your likelihood (and optionally
5
+ prior) in parallel across a :class:`multiprocessing.Pool`. The helper swaps the
6
+ ``map_fn`` argument expected by your log-likelihood / log-prior for
7
+ ``pool.map`` while the context is active, then restores the original methods.
8
+
9
+ Prepare a map-aware likelihood
10
+ ------------------------------
11
+
12
+ Your likelihood must accept ``map_fn``. A minimal
13
+ pattern:
14
+
15
+ .. code-block:: python
16
+
17
+ import numpy as np
18
+
19
+
20
+ def _global_log_likelihood(x):
21
+ # Expensive likelihood computation for a single sample `x`
22
+ return -np.sum(x**2) # Example likelihood
23
+
24
+ def log_likelihood(samples, map_fn=map):
25
+ logl = -np.inf * np.ones(len(samples.x))
26
+ if samples.log_prior is None:
27
+ raise RuntimeError("log-prior has not been evaluated!")
28
+ mask = np.isfinite(samples.log_prior, dtype=bool)
29
+ x = np.asarray(samples.x[mask, :], dtype=float)
30
+ logl[mask] = np.fromiter(
31
+ map_fn(_global_log_likelihood, x),
32
+ dtype=float,
33
+ )
34
+ return logl
35
+
36
+ Swap in a multiprocessing pool
37
+ ------------------------------
38
+
39
+ Wrap your sampling call inside ``enable_pool`` to parallelize the map step:
40
+
41
+ .. code-block:: python
42
+
43
+ import multiprocessing as mp
44
+ from aspire import Aspire
45
+
46
+ aspire = Aspire(
47
+ log_likelihood=log_likelihood,
48
+ log_prior=log_prior, # must also accept map_fn if parallelize_prior=True
49
+ dims=4,
50
+ parameters=["a", "b", "c", "d"],
51
+ )
52
+
53
+ with mp.Pool() as pool, aspire.enable_pool(pool):
54
+ samples, history = aspire.sample_posterior(
55
+ sampler="smc",
56
+ n_samples=1_000,
57
+ return_history=True,
58
+ )
59
+
60
+ Notes
61
+ -----
62
+
63
+ - By default only the likelihood is parallelized; set
64
+ ``aspire.enable_pool(pool, parallelize_prior=True)`` if your prior also
65
+ accepts ``map_fn``.
66
+ - ``enable_pool`` closes the pool on exit unless you pass ``close_pool=False``.
67
+ - The context manager itself is implemented by
68
+ :class:`aspire.utils.PoolHandler`; if you need finer control (for example,
69
+ reusing the same pool across multiple ``Aspire`` instances) you can
70
+ instantiate it directly.
@@ -0,0 +1,70 @@
1
+ Practical recipes
2
+ =================
3
+
4
+ Checking the prior when evaluating the likelihood
5
+ -------------------------------------------------
6
+
7
+ By default, Aspire samplers evaluate the log-prior before the
8
+ log-likelihood. This allows users to check the prior support and skip
9
+ likelihood evaluations for samples that lie outside the prior bounds.
10
+
11
+ .. code-block:: python
12
+
13
+ import aspire
14
+ import numpy as np
15
+
16
+
17
+ def log_likelihood(samples: aspire.Samples) -> np.ndarray:
18
+ if samples.log_prior is None:
19
+ raise RuntimeError("log-prior has not been evaluated!")
20
+ # Return -inf for samples with invalid prior
21
+ logl = np.full(samples.n_samples, -np.inf, dtype=float)
22
+ # Only evaluate the likelihood where the prior is finite
23
+ mask = np.isfinite(samples.log_prior, dtype=bool)
24
+ # Valid samples
25
+ x = samples.x[mask, :]
26
+ logl[mask] = -np.sum(x**2, axis=1) # Example likelihood
27
+ return logl
28
+
29
+
30
+ Checking the flow distribution
31
+ ------------------------------
32
+
33
+ It can be useful to inspect the flow-based proposal distribution before sampling
34
+ from the posterior. You can do this by drawing samples from the flow after fitting
35
+ and comparing them to the initial samples:
36
+
37
+
38
+ .. code-block:: python
39
+
40
+ from aspire import Aspire, Samples
41
+ from aspire.plot import plot_comparison
42
+
43
+ # Define the initial samples
44
+ initial_samples = Samples(...)
45
+
46
+ # Define the Aspire instance
47
+ aspire = Aspire(
48
+ log_likelihood=log_likelihood,
49
+ log_prior=log_prior,
50
+ ...
51
+ )
52
+
53
+ # Fit the flow to the initial samples
54
+ fit_history = aspire.fit(initial_samples)
55
+
56
+ # Draw samples from the flow
57
+ flow_samples = aspire.sample_flow(10_000)
58
+
59
+ # Plot a comparison between initial samples and flow samples
60
+ fig = plot_comparison(
61
+ initial_samples,
62
+ flow_samples,
63
+ per_samples_kwargs=[
64
+ dict(include_weights=False, color="C0"),
65
+ dict(include_weights=False, color="C1"),
66
+ ],
67
+ labels=["Initial samples", "Flow samples"],
68
+ )
69
+ # Save or show the figure
70
+ fig.savefig("flow_comparison.png")
@@ -1,2 +1,3 @@
1
1
  sphinx>=7.2
2
2
  sphinx-book-theme>=1.1
3
+ sphinx-autoapi>=3.2
@@ -44,8 +44,8 @@ switch namespaces or merge multiple runs with
44
44
  Flows and transforms
45
45
  --------------------
46
46
 
47
- Aspire can work with any proposal that implements ``sample_and_log_prob`` and
48
- ``log_prob``; normalising flows remain the default. Flows are defined via
47
+ Aspire can work with any flow that implements ``sample_and_log_prob`` and
48
+ ``log_prob``. Flows are defined via
49
49
  :class:`aspire.flows.base.Flow` and instantiated by
50
50
  :meth:`aspire.Aspire.init_flow`. By default Aspire uses the ``zuko``
51
51
  implementation of Masked Autoregressive Flows on top of PyTorch. The flow is
@@ -61,6 +61,55 @@ estimator (requires the `zuko` backend).
61
61
  External flow implementations can be plugged in via the
62
62
  ``aspire.flows`` entry point group. See :ref:`custom_flows` for details.
63
63
 
64
+ Transform mechanics
65
+ ~~~~~~~~~~~~~~~~~~~
66
+
67
+ Aspire keeps a clear separation between your native parameters and the space
68
+ where flows or kernels operate:
69
+
70
+ * :class:`aspire.transforms.FlowTransform` is attached to every flow created by
71
+ :meth:`aspire.Aspire.init_flow`. By default, it maps bounded parameters to the real line (``probit`` or
72
+ ``logit``), and recentres / rescales dimensions with an affine
73
+ transform learned from the training samples. Log-Jacobian terms are tracked so
74
+ calls to ``log_prob`` or ``sample_and_log_prob`` remain properly normalised.
75
+ ``bounded_to_unbounded`` and ``affine_transform`` can be specified when creating
76
+ the Aspire instance to control this behaviour.
77
+ * The same components are exposed via :class:`aspire.transforms.CompositeTransform`
78
+ if you want to opt out of the bounded-to-unbounded step or the affine
79
+ whitening when building custom transports.
80
+
81
+ Preconditioning inside samplers
82
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
83
+
84
+ SMC and MCMC samplers also work in a transformed space. They fit the chosen
85
+ ``preconditioning`` transform to the initial particles, perform moves there, and
86
+ then call ``inverse(...)`` (including the log-Jacobian) whenever the likelihood
87
+ or prior is evaluated. Configure it via
88
+ :meth:`aspire.Aspire.sample_posterior`:
89
+
90
+ * ``"default"`` / ``"standard"`` uses :class:`aspire.transforms.CompositeTransform`
91
+ with bounded-to-unbounded and affine scaling turned off by default; periodic
92
+ wrapping still applies. To whiten dimensions or map bounds to the real line,
93
+ pass ``preconditioning_kwargs={"affine_transform": True, "bounded_to_unbounded": True}``.
94
+ * ``"flow"`` fits a lightweight :class:`aspire.transforms.FlowPreconditioningTransform`
95
+ to the current particles and treats it as a transport map during SMC/MCMC
96
+ updates. This reuses the same bounded / periodic handling while providing a
97
+ richer geometry for the kernels.
98
+ * ``None`` leaves the sampler in the original parameterisation with an identity
99
+ transform. The importance sampler defaults to this; other samplers default to
100
+ ``"standard"`` so periodic parameters are at least kept consistent with their
101
+ bounds.
102
+
103
+
104
+ .. note::
105
+
106
+ By default, the preconditioning transform does not include bounded-to-unbounded
107
+ steps. This means your log-prior and log-likelihood must handle points that
108
+ lie outside the specified bounds (e.g. by returning ``-inf``). If you want
109
+ the sampler to automatically map bounded parameters to an unconstrained
110
+ space, enable the ``bounded_to_unbounded`` option in
111
+ ``preconditioning_kwargs``.
112
+
64
113
  Sampling strategies
65
114
  -------------------
66
115
 
@@ -101,11 +150,6 @@ Sequential Monte Carlo
101
150
  Replaces the internal MCMC move with the ``emcee`` ensemble sampler,
102
151
  providing a gradient-free option that still benefits from SMC tempering.
103
152
 
104
- You can plug in custom preconditioning by setting ``preconditioning`` to
105
- ``"standard"`` (affine normalisation based on current samples), ``"flow"``
106
- (use the fitted flow as a transport map), or ``None`` to disable additional
107
- transforms.
108
-
109
153
  History, diagnostics, and persistence
110
154
  -------------------------------------
111
155
 
@@ -0,0 +1,158 @@
1
+ """Example using sequential posterior inference with SMC.
2
+
3
+ This example uses JAX for computations and BlackJAX for the MCMC sampling
4
+ in the SMC step.
5
+ """
6
+
7
+ from pathlib import Path
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+
12
+ from aspire import Aspire
13
+ from aspire.plot import plot_comparison
14
+ from aspire.samples import Samples
15
+ from aspire.utils import configure_logger
16
+
17
+ # RNG for generating initial samples
18
+ key = jax.random.key(42)
19
+
20
+ # Output directory
21
+ outdir = Path("outdir") / "blackjax_smc_example"
22
+ outdir.mkdir(parents=True, exist_ok=True)
23
+
24
+ # Configure logger to show INFO level messages
25
+ configure_logger()
26
+
27
+ # Number of dimensions
28
+ dims = 4
29
+ # Parameter names
30
+ parameters = [f"x{i}" for i in range(dims)]
31
+ # Prior bounds
32
+ prior_bounds = {param: (-5, 5) for param in parameters}
33
+
34
+ # Means and covariances of the two Gaussian components
35
+ mu1 = 2 * jnp.ones(dims)
36
+ mu2 = -2 * jnp.ones(dims)
37
+ cov1 = 0.5 * jnp.eye(dims)
38
+ cov2 = jnp.eye(dims)
39
+
40
+
41
+ def log_likelihood(samples):
42
+ """Log-likelihood of a mixture of two Gaussians"""
43
+ x = samples.x
44
+ comp1 = (
45
+ -0.5 * ((x - mu1) @ jnp.linalg.inv(cov1) * (x - mu1)).sum(axis=-1)
46
+ - 0.5 * dims * jnp.log(2 * jnp.pi)
47
+ - 0.5 * jnp.linalg.slogdet(cov1)[1]
48
+ )
49
+ comp2 = (
50
+ -0.5 * ((x - mu2) @ jnp.linalg.inv(cov2) * (x - mu2)).sum(axis=-1)
51
+ - 0.5 * dims * jnp.log(2 * jnp.pi)
52
+ - 0.5 * jnp.linalg.slogdet(cov2)[1]
53
+ )
54
+ return jnp.logaddexp(comp1, comp2) # Log-sum-exp for numerical stability
55
+
56
+
57
+ def log_prior(samples):
58
+ """Uniform prior between -5 and 5 in each dimension"""
59
+ x = samples.x
60
+ in_bounds = jnp.all((x >= -5) & (x <= 5), axis=-1)
61
+ logp = jnp.where(in_bounds, -dims * jnp.log(10), -jnp.inf)
62
+ return logp
63
+
64
+
65
+ # Generate prior samples for comparison, these are not used in SMC
66
+ key, prior_key = jax.random.split(key)
67
+ prior_samples = Samples(
68
+ jax.random.uniform(prior_key, shape=(5000, dims), minval=-5, maxval=5),
69
+ parameters=parameters,
70
+ )
71
+
72
+ # True posterior samples for comparison
73
+ key, post_key0, post_key1 = jax.random.split(key, 3)
74
+ true_posterior_samples = Samples(
75
+ jnp.concatenate(
76
+ [
77
+ jax.random.multivariate_normal(
78
+ post_key0, mu1, cov1, shape=(2500,)
79
+ ),
80
+ jax.random.multivariate_normal(
81
+ post_key1, mu2, cov2, shape=(2500,)
82
+ ),
83
+ ],
84
+ axis=0,
85
+ ),
86
+ parameters=parameters,
87
+ )
88
+
89
+ # We draw initial samples from two Gaussians centered away from the true modes
90
+ # to demonstrate the ability of SMC to explore the posterior
91
+ key, offset_key1, offset_key2 = jax.random.split(key, 3)
92
+ offset_1 = jax.random.uniform(offset_key1, shape=(dims,), minval=-3, maxval=3)
93
+ offset_2 = jax.random.uniform(offset_key2, shape=(dims,), minval=-3, maxval=3)
94
+ key, init_key1, init_key2 = jax.random.split(key, 3)
95
+ initial_samples = jnp.concatenate(
96
+ [
97
+ jax.random.normal(init_key1, shape=(2500, dims)) + (mu1 - offset_1),
98
+ jax.random.normal(init_key2, shape=(2500, dims)) + (mu2 - offset_2),
99
+ ],
100
+ axis=0,
101
+ )
102
+ initial_samples = Samples(initial_samples, parameters=parameters)
103
+
104
+ # Initialize Aspire with the log-likelihood and log-prior
105
+ key, aspire_key = jax.random.split(key)
106
+ aspire = Aspire(
107
+ log_likelihood=log_likelihood,
108
+ log_prior=log_prior,
109
+ dims=dims,
110
+ flow_backend="flowjax", # Use Flowjax as the normalizing flow backend
111
+ prior_bounds=prior_bounds, # Specify prior bounds
112
+ parameters=parameters,
113
+ key=aspire_key,
114
+ )
115
+
116
+ # Fit the normalizing flow to the initial samples
117
+ fit_history = aspire.fit(initial_samples, max_epochs=30)
118
+
119
+ # Plot loss
120
+ fit_history.plot_loss().savefig(outdir / "loss.png")
121
+
122
+ # Sample from the posterior using SMC
123
+ # We use BlackJAX's NUTS as the MCMC kernel within SMC
124
+ # We enable the bounded to unbounded transform in the preconditioning to avoid
125
+ # issues with NUTS on bounded spaces
126
+ samples, history = aspire.sample_posterior(
127
+ sampler="blackjax_smc", # use the BlackJAX SMC sampler
128
+ n_samples=500, # Number of particles in SMC
129
+ n_final_samples=5000, # Number of samples to draw from the final distribution
130
+ adaptive=True,
131
+ target_efficiency=0.8,
132
+ sampler_kwargs=dict( # Keyword arguments for the specific sampler
133
+ algorithm="nuts", # Use NUTS within SMC
134
+ step_size=0.1, # Step size for NUTS, this will need tuning
135
+ n_steps=20, # Number of NUTS (MCMC) steps per SMC iteration
136
+ ),
137
+ preconditioning_kwargs=dict(
138
+ affine_transform=True, # Use affine transform preconditioning
139
+ bounded_to_unbounded=True, # Transform bounded parameters to unbounded space
140
+ ),
141
+ return_history=True, # To return the SMC history (e.g., ESS, betas)
142
+ )
143
+ # Plot SMC diagnostics
144
+ history.plot().savefig(outdir / "smc_diagnostics.png")
145
+
146
+ # Plot corner plot of the samples
147
+ # Include initial samples and true posterior samples for comparison
148
+ plot_comparison(
149
+ initial_samples,
150
+ true_posterior_samples,
151
+ samples,
152
+ labels=["Initial Samples", "True Posterior Samples", "SMC Samples"],
153
+ per_samples_kwargs=[
154
+ {"color": "grey"},
155
+ {"color": "k"},
156
+ {"color": "C1"},
157
+ ],
158
+ ).savefig(outdir / "posterior.png")
@@ -6,6 +6,7 @@ import logging
6
6
  from importlib.metadata import PackageNotFoundError, version
7
7
 
8
8
  from .aspire import Aspire
9
+ from .samples import Samples
9
10
 
10
11
  try:
11
12
  __version__ = version("aspire")
@@ -16,4 +17,5 @@ logging.getLogger(__name__).addHandler(logging.NullHandler())
16
17
 
17
18
  __all__ = [
18
19
  "Aspire",
20
+ "Samples",
19
21
  ]
@@ -220,7 +220,9 @@ class BlackJAXSMC(SMCSampler):
220
220
 
221
221
  elif algorithm == "nuts":
222
222
  # Initialize step size and mass matrix if not provided
223
- inverse_mass_matrix = self.sampler_kwargs["inverse_mass_matrix"]
223
+ inverse_mass_matrix = self.sampler_kwargs.get(
224
+ "inverse_mass_matrix"
225
+ )
224
226
  if inverse_mass_matrix is None:
225
227
  inverse_mass_matrix = jax.numpy.eye(self.dims)
226
228
 
@@ -258,8 +260,13 @@ class BlackJAXSMC(SMCSampler):
258
260
  z_final = final_states.position
259
261
 
260
262
  # Calculate acceptance rates
261
- acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
262
- mean_acceptance = jax.numpy.mean(acceptance_rates)
263
+ try:
264
+ acceptance_rates = jax.numpy.mean(
265
+ all_infos.is_accepted, axis=1
266
+ )
267
+ mean_acceptance = jax.numpy.mean(acceptance_rates)
268
+ except AttributeError:
269
+ mean_acceptance = np.nan
263
270
 
264
271
  elif algorithm == "hmc":
265
272
  # Initialize HMC sampler
@@ -294,8 +301,13 @@ class BlackJAXSMC(SMCSampler):
294
301
 
295
302
  final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
296
303
  z_final = final_states.position
297
- acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
298
- mean_acceptance = jax.numpy.mean(acceptance_rates)
304
+ try:
305
+ acceptance_rates = jax.numpy.mean(
306
+ all_infos.is_accepted, axis=1
307
+ )
308
+ mean_acceptance = jax.numpy.mean(acceptance_rates)
309
+ except AttributeError:
310
+ mean_acceptance = np.nan
299
311
 
300
312
  else:
301
313
  raise ValueError(f"Unsupported algorithm: {algorithm}")
@@ -332,6 +332,7 @@ class CompositeTransform(BaseTransform):
332
332
  prior_bounds=self.prior_bounds,
333
333
  bounded_to_unbounded=self.bounded_to_unbounded,
334
334
  bounded_transform=self.bounded_transform,
335
+ affine_transform=self.affine_transform,
335
336
  device=self.device,
336
337
  xp=xp or self.xp,
337
338
  eps=self.eps,
@@ -116,7 +116,8 @@ def log_prior(dims, xp):
116
116
  "smc",
117
117
  "emcee",
118
118
  "minipcn",
119
- "blackjax_smc",
119
+ "blackjax_smc_rwmh",
120
+ "blackjax_smc_nuts",
120
121
  ]
121
122
  )
122
123
  def sampler_config(request):
@@ -126,7 +127,7 @@ def sampler_config(request):
126
127
  return SamplerConfig(
127
128
  sampler="emcee",
128
129
  sampler_kwargs={
129
- "nsteps": 500,
130
+ "nsteps": 100,
130
131
  },
131
132
  )
132
133
  elif request.param == "minipcn":
@@ -157,7 +158,7 @@ def sampler_config(request):
157
158
  },
158
159
  },
159
160
  )
160
- elif request.param == "blackjax_smc":
161
+ elif request.param == "blackjax_smc_rwmh":
161
162
  return SamplerConfig(
162
163
  sampler="blackjax_smc",
163
164
  sampler_kwargs={
@@ -165,7 +166,19 @@ def sampler_config(request):
165
166
  "sampler_kwargs": {
166
167
  "algorithm": "rwmh",
167
168
  "sigma": 0.1,
168
- "n_steps": 500,
169
+ "n_steps": 10,
170
+ },
171
+ },
172
+ )
173
+ elif request.param == "blackjax_smc_nuts":
174
+ return SamplerConfig(
175
+ sampler="blackjax_smc",
176
+ sampler_kwargs={
177
+ "adaptive": True,
178
+ "sampler_kwargs": {
179
+ "algorithm": "nuts",
180
+ "step_size": 0.1,
181
+ "n_steps": 10,
169
182
  },
170
183
  },
171
184
  )
@@ -21,7 +21,7 @@ def test_integration_zuko(
21
21
  dtype,
22
22
  tmp_path,
23
23
  ):
24
- if sampler_config.sampler == "blackjax_smc":
24
+ if "blackjax_smc" in sampler_config.sampler:
25
25
  pytest.xfail(reason="BlackJAX requires JAX arrays.")
26
26
 
27
27
  aspire = Aspire(
@@ -63,7 +63,7 @@ def test_integration_flowjax(
63
63
  ):
64
64
  import jax
65
65
 
66
- if sampler_config.sampler == "blackjax_smc":
66
+ if "blackjax_smc" in sampler_config.sampler:
67
67
  if samples_backend != "jax":
68
68
  pytest.xfail(reason="BlackJAX requires JAX arrays.")
69
69
  if dtype == "float32":
@@ -1,27 +0,0 @@
1
- # aspire: Accelerated Sequential Posterior Inference via REuse
2
-
3
- aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
4
-
5
- ## Installation
6
-
7
- aspire can be installed from PyPI using `pip`
8
-
9
- ```
10
- pip install aspire-inference
11
- ```
12
-
13
- **Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
14
- the package can be imported and used as `aspire`.
15
-
16
- ## Documentation
17
-
18
- See the [documentation on ReadTheDocs][docs].
19
-
20
- ## Citation
21
-
22
- If you use `aspire` in your work please cite the [DOI][DOI] and [paper][paper].
23
-
24
-
25
- [docs]: https://aspire.readthedocs.io/
26
- [DOI]: https://doi.org/10.5281/zenodo.15658747
27
- [paper]: https://arxiv.org/abs/2511.04218
@@ -1,28 +0,0 @@
1
- API Reference
2
- =============
3
-
4
- This section documents the main public classes.
5
-
6
- Aspire interface
7
- ----------------
8
-
9
- .. autoclass:: aspire.Aspire
10
- :members:
11
- :undoc-members:
12
- :show-inheritance:
13
-
14
- Samples utilities
15
- -----------------
16
-
17
- .. automodule:: aspire.samples
18
- :members: Samples, SMCSamples
19
- :undoc-members:
20
- :show-inheritance:
21
-
22
- History objects
23
- ---------------
24
-
25
- .. automodule:: aspire.history
26
- :members: History, FlowHistory, SMCHistory
27
- :undoc-members:
28
- :show-inheritance: