aspire-inference 0.1.0a8__py3-none-any.whl → 0.1.0a9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +2 -0
- aspire/samplers/smc/blackjax.py +17 -5
- aspire/transforms.py +1 -0
- {aspire_inference-0.1.0a8.dist-info → aspire_inference-0.1.0a9.dist-info}/METADATA +46 -1
- {aspire_inference-0.1.0a8.dist-info → aspire_inference-0.1.0a9.dist-info}/RECORD +8 -8
- {aspire_inference-0.1.0a8.dist-info → aspire_inference-0.1.0a9.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a8.dist-info → aspire_inference-0.1.0a9.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a8.dist-info → aspire_inference-0.1.0a9.dist-info}/top_level.txt +0 -0
aspire/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ import logging
|
|
|
6
6
|
from importlib.metadata import PackageNotFoundError, version
|
|
7
7
|
|
|
8
8
|
from .aspire import Aspire
|
|
9
|
+
from .samples import Samples
|
|
9
10
|
|
|
10
11
|
try:
|
|
11
12
|
__version__ = version("aspire")
|
|
@@ -16,4 +17,5 @@ logging.getLogger(__name__).addHandler(logging.NullHandler())
|
|
|
16
17
|
|
|
17
18
|
__all__ = [
|
|
18
19
|
"Aspire",
|
|
20
|
+
"Samples",
|
|
19
21
|
]
|
aspire/samplers/smc/blackjax.py
CHANGED
|
@@ -220,7 +220,9 @@ class BlackJAXSMC(SMCSampler):
|
|
|
220
220
|
|
|
221
221
|
elif algorithm == "nuts":
|
|
222
222
|
# Initialize step size and mass matrix if not provided
|
|
223
|
-
inverse_mass_matrix = self.sampler_kwargs
|
|
223
|
+
inverse_mass_matrix = self.sampler_kwargs.get(
|
|
224
|
+
"inverse_mass_matrix"
|
|
225
|
+
)
|
|
224
226
|
if inverse_mass_matrix is None:
|
|
225
227
|
inverse_mass_matrix = jax.numpy.eye(self.dims)
|
|
226
228
|
|
|
@@ -258,8 +260,13 @@ class BlackJAXSMC(SMCSampler):
|
|
|
258
260
|
z_final = final_states.position
|
|
259
261
|
|
|
260
262
|
# Calculate acceptance rates
|
|
261
|
-
|
|
262
|
-
|
|
263
|
+
try:
|
|
264
|
+
acceptance_rates = jax.numpy.mean(
|
|
265
|
+
all_infos.is_accepted, axis=1
|
|
266
|
+
)
|
|
267
|
+
mean_acceptance = jax.numpy.mean(acceptance_rates)
|
|
268
|
+
except AttributeError:
|
|
269
|
+
mean_acceptance = np.nan
|
|
263
270
|
|
|
264
271
|
elif algorithm == "hmc":
|
|
265
272
|
# Initialize HMC sampler
|
|
@@ -294,8 +301,13 @@ class BlackJAXSMC(SMCSampler):
|
|
|
294
301
|
|
|
295
302
|
final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
|
|
296
303
|
z_final = final_states.position
|
|
297
|
-
|
|
298
|
-
|
|
304
|
+
try:
|
|
305
|
+
acceptance_rates = jax.numpy.mean(
|
|
306
|
+
all_infos.is_accepted, axis=1
|
|
307
|
+
)
|
|
308
|
+
mean_acceptance = jax.numpy.mean(acceptance_rates)
|
|
309
|
+
except AttributeError:
|
|
310
|
+
mean_acceptance = np.nan
|
|
299
311
|
|
|
300
312
|
else:
|
|
301
313
|
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
aspire/transforms.py
CHANGED
|
@@ -332,6 +332,7 @@ class CompositeTransform(BaseTransform):
|
|
|
332
332
|
prior_bounds=self.prior_bounds,
|
|
333
333
|
bounded_to_unbounded=self.bounded_to_unbounded,
|
|
334
334
|
bounded_transform=self.bounded_transform,
|
|
335
|
+
affine_transform=self.affine_transform,
|
|
335
336
|
device=self.device,
|
|
336
337
|
xp=xp or self.xp,
|
|
337
338
|
eps=self.eps,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: aspire-inference
|
|
3
|
-
Version: 0.1.0a8
|
|
3
|
+
Version: 0.1.0a9
|
|
4
4
|
Summary: Accelerate Sequential Posterior Inference via REuse
|
|
5
5
|
Author-email: "Michael J. Williams" <michaeljw1@googlemail.com>
|
|
6
6
|
License: MIT
|
|
@@ -39,6 +39,12 @@ Dynamic: license-file
|
|
|
39
39
|
|
|
40
40
|
# aspire: Accelerated Sequential Posterior Inference via REuse
|
|
41
41
|
|
|
42
|
+
[](https://doi.org/10.5281/zenodo.15658747)
|
|
43
|
+
[](https://pypi.org/project/aspire-inference/)
|
|
44
|
+
[](https://aspire.readthedocs.io/en/latest/?badge=latest)
|
|
45
|
+

|
|
46
|
+
|
|
47
|
+
|
|
42
48
|
aspire is a framework for reusing existing posterior samples to obtain new results at a reduced cost.
|
|
43
49
|
|
|
44
50
|
## Installation
|
|
@@ -52,6 +58,45 @@ pip install aspire-inference
|
|
|
52
58
|
**Important:** the name of `aspire` on PyPI is `aspire-inference` but once installed
|
|
53
59
|
the package can be imported and used as `aspire`.
|
|
54
60
|
|
|
61
|
+
## Quickstart
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
import numpy as np
|
|
65
|
+
from aspire import Aspire, Samples
|
|
66
|
+
|
|
67
|
+
# Define a log-likelihood and log-prior
|
|
68
|
+
def log_likelihood(samples):
|
|
69
|
+
x = samples.x
|
|
70
|
+
return -0.5 * np.sum(x**2, axis=-1)
|
|
71
|
+
|
|
72
|
+
def log_prior(samples):
|
|
73
|
+
return -0.5 * np.sum(samples.x**2, axis=-1)
|
|
74
|
+
|
|
75
|
+
# Create the initial samples
|
|
76
|
+
init = Samples(np.random.normal(size=(2_000, 4)))
|
|
77
|
+
|
|
78
|
+
# Define the aspire object
|
|
79
|
+
aspire = Aspire(
|
|
80
|
+
log_likelihood=log_likelihood,
|
|
81
|
+
log_prior=log_prior,
|
|
82
|
+
dims=4,
|
|
83
|
+
parameters=[f"x{i}" for i in range(4)],
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
# Fit the normalizing flow
|
|
87
|
+
aspire.fit(init, n_epochs=20)
|
|
88
|
+
|
|
89
|
+
# Sample the posterior
|
|
90
|
+
posterior = aspire.sample_posterior(
|
|
91
|
+
sampler="smc",
|
|
92
|
+
n_samples=500,
|
|
93
|
+
sampler_kwargs=dict(n_steps=100),
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
# Plot the posterior distribution
|
|
97
|
+
posterior.plot_corner()
|
|
98
|
+
```
|
|
99
|
+
|
|
55
100
|
## Documentation
|
|
56
101
|
|
|
57
102
|
See the [documentation on ReadTheDocs][docs].
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
aspire/__init__.py,sha256=
|
|
1
|
+
aspire/__init__.py,sha256=B2IETvlpB0oBh57prRYLdi8jB5yFGw8qVviGdf1NcnE,409
|
|
2
2
|
aspire/aspire.py,sha256=M5o-QxLthE_5daa1trgUfWxPz-g4rmpEUKimKosw4lw,17400
|
|
3
3
|
aspire/history.py,sha256=l_j-riZKbTWK7Wz9zvvD_mTk9psNCKItiveYhr_pYv8,4313
|
|
4
4
|
aspire/plot.py,sha256=oXwUDOb_953_ADm2KLk41JIfpE3JeiiQiSYKvUVwLqw,1423
|
|
5
5
|
aspire/samples.py,sha256=nVQULOr19lQVyGhitI8EgDdCDGx9sExPomQBJrV4rTc,19237
|
|
6
|
-
aspire/transforms.py,sha256=
|
|
6
|
+
aspire/transforms.py,sha256=CHrfPQHEyHQ9I0WWiAgUplWwxYypZsD7uCYIHUbSFtY,24974
|
|
7
7
|
aspire/utils.py,sha256=pj8O0chqfP6VS8bpW0wCw8W0P5JNQKvWRz1Rg9AYIhg,22525
|
|
8
8
|
aspire/flows/__init__.py,sha256=GUZToPVNJoTwULpbeW10UijfQukNrILoAQ_ubeq7G3w,2110
|
|
9
9
|
aspire/flows/base.py,sha256=5UWKAiXDXLJ6Sg6a380ajLrGFaZSQyOnFEihQiiA4ko,2237
|
|
@@ -18,11 +18,11 @@ aspire/samplers/importance.py,sha256=opn_jY-V8snUz0JztLBtnaTT3WfrZ5OSpHBV5WAuM3M
|
|
|
18
18
|
aspire/samplers/mcmc.py,sha256=ihHgzqvSoy1oxdFBjyqNUbCuRX7CqWjlshCUZcgEL5E,5151
|
|
19
19
|
aspire/samplers/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
20
20
|
aspire/samplers/smc/base.py,sha256=66f_ORUvcKRqMIW35qjhUc-c0PFuY87lJa91MpSaTZI,10729
|
|
21
|
-
aspire/samplers/smc/blackjax.py,sha256=
|
|
21
|
+
aspire/samplers/smc/blackjax.py,sha256=IcTguAETiPmgFofmVW2GN40P5HBIxkmyd2VR8AU8f4k,12115
|
|
22
22
|
aspire/samplers/smc/emcee.py,sha256=Wm0vvAlCcRhJMBt7_fU2ZnjDb8SN8jgUOTXLzNstRpA,2516
|
|
23
23
|
aspire/samplers/smc/minipcn.py,sha256=ju1gcgyKHjodLEACPdL3eXA9ai8ZJ9_LwitD_Gmf1Rc,2765
|
|
24
|
-
aspire_inference-0.1.
|
|
25
|
-
aspire_inference-0.1.
|
|
26
|
-
aspire_inference-0.1.
|
|
27
|
-
aspire_inference-0.1.
|
|
28
|
-
aspire_inference-0.1.
|
|
24
|
+
aspire_inference-0.1.0a9.dist-info/licenses/LICENSE,sha256=DN-eRtBfS9dZyT0Ds0Mdn2Y4nb-ZQ7h71vpASYBm5k4,1076
|
|
25
|
+
aspire_inference-0.1.0a9.dist-info/METADATA,sha256=DCcTGComTNt9UFlbUwGPf58itBX6Y-1-edVys0pb2RQ,3187
|
|
26
|
+
aspire_inference-0.1.0a9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
27
|
+
aspire_inference-0.1.0a9.dist-info/top_level.txt,sha256=9FRIYEl2xwVFG7jSOBHsElHQ0y3_4fq01Cf4_OyMQn8,7
|
|
28
|
+
aspire_inference-0.1.0a9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|