aspire-inference 0.1.0a8__py3-none-any.whl → 0.1.0a9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/__init__.py CHANGED
@@ -6,6 +6,7 @@ import logging
6
6
  from importlib.metadata import PackageNotFoundError, version
7
7
 
8
8
  from .aspire import Aspire
9
+ from .samples import Samples
9
10
 
10
11
  try:
11
12
  __version__ = version("aspire")
@@ -16,4 +17,5 @@ logging.getLogger(__name__).addHandler(logging.NullHandler())
16
17
 
17
18
  __all__ = [
18
19
  "Aspire",
20
+ "Samples",
19
21
  ]
@@ -220,7 +220,9 @@ class BlackJAXSMC(SMCSampler):
220
220
 
221
221
  elif algorithm == "nuts":
222
222
  # Initialize step size and mass matrix if not provided
223
- inverse_mass_matrix = self.sampler_kwargs["inverse_mass_matrix"]
223
+ inverse_mass_matrix = self.sampler_kwargs.get(
224
+ "inverse_mass_matrix"
225
+ )
224
226
  if inverse_mass_matrix is None:
225
227
  inverse_mass_matrix = jax.numpy.eye(self.dims)
226
228
 
@@ -258,8 +260,13 @@ class BlackJAXSMC(SMCSampler):
258
260
  z_final = final_states.position
259
261
 
260
262
  # Calculate acceptance rates
261
- acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
262
- mean_acceptance = jax.numpy.mean(acceptance_rates)
263
+ try:
264
+ acceptance_rates = jax.numpy.mean(
265
+ all_infos.is_accepted, axis=1
266
+ )
267
+ mean_acceptance = jax.numpy.mean(acceptance_rates)
268
+ except AttributeError:
269
+ mean_acceptance = np.nan
263
270
 
264
271
  elif algorithm == "hmc":
265
272
  # Initialize HMC sampler
@@ -294,8 +301,13 @@ class BlackJAXSMC(SMCSampler):
294
301
 
295
302
  final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
296
303
  z_final = final_states.position
297
- acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
298
- mean_acceptance = jax.numpy.mean(acceptance_rates)
304
+ try:
305
+ acceptance_rates = jax.numpy.mean(
306
+ all_infos.is_accepted, axis=1
307
+ )
308
+ mean_acceptance = jax.numpy.mean(acceptance_rates)
309
+ except AttributeError:
310
+ mean_acceptance = np.nan
299
311
 
300
312
  else:
301
313
  raise ValueError(f"Unsupported algorithm: {algorithm}")
aspire/transforms.py CHANGED
@@ -332,6 +332,7 @@ class CompositeTransform(BaseTransform):
332
332
  prior_bounds=self.prior_bounds,
333
333
  bounded_to_unbounded=self.bounded_to_unbounded,
334
334
  bounded_transform=self.bounded_transform,
335
+ affine_transform=self.affine_transform,
335
336
  device=self.device,
336
337
  xp=xp or self.xp,
337
338
  eps=self.eps,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aspire-inference
3
- Version: 0.1.0a8
3
+ Version: 0.1.0a9
4
4
  Summary: Accelerate Sequential Posterior Inference via REuse
5
5
  Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
6
6
  License: MIT
@@ -39,6 +39,12 @@ Dynamic: license-file
39
39
 
40
40
  # aspire: Accelerated Sequential Posterior Inference via REuse
41
41
 
42
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15658747.svg)](https://doi.org/10.5281/zenodo.15658747)
43
+ [![PyPI](https://img.shields.io/pypi/v/aspire-inference)](https://pypi.org/project/aspire-inference/)
44
+ [![Documentation Status](https://readthedocs.org/projects/aspire/badge/?version=latest)](https://aspire.readthedocs.io/en/latest/?badge=latest)
45
+ ![tests](https://github.com/mj-will/aspire/actions/workflows/tests.yml/badge.svg)
46
+
47
+
42
48
  aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
43
49
 
44
50
  ## Installation
@@ -52,6 +58,45 @@ pip install aspire-inference
52
58
  **Important:** the name of `aspire` on PyPI is `aspire-inference`, but once installed
53
59
  the package can be imported and used as `aspire`.
54
60
 
61
+ ## Quickstart
62
+
63
+ ```python
64
+ import numpy as np
65
+ from aspire import Aspire, Samples
66
+
67
+ # Define a log-likelihood and log-prior
68
+ def log_likelihood(samples):
69
+ x = samples.x
70
+ return -0.5 * np.sum(x**2, axis=-1)
71
+
72
+ def log_prior(samples):
73
+ return -0.5 * np.sum(samples.x**2, axis=-1)
74
+
75
+ # Create the initial samples
76
+ init = Samples(np.random.normal(size=(2_000, 4)))
77
+
78
+ # Define the aspire object
79
+ aspire = Aspire(
80
+ log_likelihood=log_likelihood,
81
+ log_prior=log_prior,
82
+ dims=4,
83
+ parameters=[f"x{i}" for i in range(4)],
84
+ )
85
+
86
+ # Fit the normalizing flow
87
+ aspire.fit(init, n_epochs=20)
88
+
89
+ # Sample the posterior
90
+ posterior = aspire.sample_posterior(
91
+ sampler="smc",
92
+ n_samples=500,
93
+ sampler_kwargs=dict(n_steps=100),
94
+ )
95
+
96
+ # Plot the posterior distribution
97
+ posterior.plot_corner()
98
+ ```
99
+
55
100
  ## Documentation
56
101
 
57
102
  See the [documentation on ReadTheDocs][docs].
@@ -1,9 +1,9 @@
1
- aspire/__init__.py,sha256=45R0xWaLg0aJEPK5zoTK0aIek0KOwpHwQWS1jLCDhIE,365
1
+ aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
2
2
  aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
3
3
  aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
4
4
  aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
5
5
  aspire/samples.py,sha256=nVQULOr19lQVyGhitI8EgDdCDGx9sExPomQBJrV4rTc,19237
6
- aspire/transforms.py,sha256=W2JL045zkczvdBVXXeSn2c8aKyX8ASxFKgWU2cJHufk,24922
6
+ aspire/transforms.py,sha256=CHrfPQHEyHQ9I0WWiAgUplWwxYypZsD7uCYIHUbSFtY,24974
7
7
  aspire/utils.py,sha256=pj8O0chqfP6VS8bpW0wCw8W0P5JNQKvWRz1Rg9AYIhg,22525
8
8
  aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
9
9
  aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
@@ -18,11 +18,11 @@ aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M
18
18
  aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
19
19
  aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
20
  aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,10729
21
- aspire/samplers/smc/blackjax.py,sha256=4L4kgRKlaWl-knTWXXzdJTh-zZBh5BTpy5GaLDzT8Sc,11803
21
+ aspire/samplers/smc/blackjax.py,sha256=IcTguAETiPmgFofmVW2GN40P5HBIxkmyd2VR8AU8f4k,12115
22
22
  aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
23
23
  aspire/samplers/smc/minipcn.py,sha256=ju1gcgyKHjodLEACPdL3eXA9ai8ZJ9_LwitD_Gmf1Rc,2765
24
- aspire_inference-0.1.0a8.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
- aspire_inference-0.1.0a8.dist-info/METADATA,sha256=Or73qfnx3KsPNSeaNTJ70t56GME2r9qACM03aSW40t8,1965
26
- aspire_inference-0.1.0a8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
- aspire_inference-0.1.0a8.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
- aspire_inference-0.1.0a8.dist-info/RECORD,,
24
+ aspire_inference-0.1.0a9.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
25
+ aspire_inference-0.1.0a9.dist-info/METADATA,sha256=DCcTGComTNt9UFlbUwGPf58itBX6Y-1-edVys0pb2RQ,3187
26
+ aspire_inference-0.1.0a9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
+ aspire_inference-0.1.0a9.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
28
+ aspire_inference-0.1.0a9.dist-info/RECORD,,