aspire-inference 0.1.0a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +506 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +84 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +196 -0
- aspire/flows/jax/utils.py +57 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +344 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +94 -0
- aspire/samplers/importance.py +22 -0
- aspire/samplers/mcmc.py +160 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +318 -0
- aspire/samplers/smc/blackjax.py +332 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +568 -0
- aspire/transforms.py +751 -0
- aspire/utils.py +760 -0
- aspire_inference-0.1.0a7.dist-info/METADATA +52 -0
- aspire_inference-0.1.0a7.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a7.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a7.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a7.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from ...samples import SMCSamples
|
|
7
|
+
from ...utils import track_calls
|
|
8
|
+
from .base import NumpySMCSampler
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class EmceeSMC(NumpySMCSampler):
    """SMC sampler that mutates particles with the emcee ensemble sampler."""

    @track_calls
    def sample(
        self,
        n_samples: int,
        n_steps: int | None = None,
        adaptive: bool = True,
        target_efficiency: float = 0.5,
        target_efficiency_rate: float = 1.0,
        sampler_kwargs: dict | None = None,
        n_final_samples: int | None = None,
    ):
        """Run the SMC sampler using emcee for the mutation steps.

        Parameters
        ----------
        n_samples
            Number of particles to draw.
        n_steps
            Number of MCMC steps per mutation; overrides the
            ``nsteps`` entry of ``sampler_kwargs`` when given.
        adaptive, target_efficiency, target_efficiency_rate, n_final_samples
            Forwarded to the base-class :meth:`sample` — see
            ``NumpySMCSampler.sample`` for their semantics.
        sampler_kwargs
            Extra keyword arguments for ``emcee``. ``nsteps`` defaults to
            ``5 * self.dims`` and ``progress`` to ``True``. A ``moves``
            entry is extracted and passed to the ``EnsembleSampler``
            constructor rather than ``run_mcmc``.
        """
        # Copy so the setdefault/pop calls below cannot mutate the
        # caller's dictionary.
        self.sampler_kwargs = dict(sampler_kwargs) if sampler_kwargs else {}
        self.sampler_kwargs.setdefault("nsteps", 5 * self.dims)
        self.sampler_kwargs.setdefault("progress", True)
        # emcee accepts `moves` on the sampler constructor, not run_mcmc.
        self.emcee_moves = self.sampler_kwargs.pop("moves", None)
        return super().sample(
            n_samples,
            n_steps=n_steps,
            adaptive=adaptive,
            target_efficiency=target_efficiency,
            target_efficiency_rate=target_efficiency_rate,
            n_final_samples=n_final_samples,
        )

    def mutate(self, particles, beta, n_steps=None):
        """Mutate ``particles`` with an emcee run at inverse temperature ``beta``.

        Returns an :class:`SMCSamples` built from the final state of the
        chains, with ``log_q``, ``log_prior`` and ``log_likelihood`` set.

        Raises
        ------
        ValueError
            If the proposal log-probability contains NaN values.
        """
        import emcee

        logger.info("Mutating particles")
        sampler = emcee.EnsembleSampler(
            len(particles.x),  # one walker per particle
            self.dims,
            self.log_prob,
            args=(beta,),
            vectorize=True,
            moves=self.emcee_moves,
        )
        # Sample in the preconditioned (transformed) space.
        z = self.fit_preconditioning_transform(particles.x)
        kwargs = copy.deepcopy(self.sampler_kwargs)
        if n_steps is not None:
            kwargs["nsteps"] = n_steps
        sampler.run_mcmc(z, **kwargs)
        self.history.mcmc_acceptance.append(
            np.mean(sampler.acceptance_fraction)
        )
        # Discard 20% burn-in based on the number of steps actually run
        # (kwargs["nsteps"]), not the stored default — n_steps may have
        # overridden it above.
        self.history.mcmc_autocorr.append(
            sampler.get_autocorr_time(
                quiet=True, discard=int(0.2 * kwargs["nsteps"])
            )
        )
        # Final state of every walker, mapped back to the original space.
        z = sampler.get_chain(flat=False)[-1, ...]
        x = self.preconditioning_transform.inverse(z)[0]
        samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
        samples.log_q = samples.array_to_namespace(
            self.prior_flow.log_prob(samples.x)
        )
        samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )
        if samples.xp.isnan(samples.log_q).any():
            raise ValueError("Log proposal contains NaN values")
        return samples
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ...samples import SMCSamples
|
|
6
|
+
from ...utils import to_numpy, track_calls
|
|
7
|
+
from .base import NumpySMCSampler
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MiniPCNSMC(NumpySMCSampler):
    """MiniPCN SMC sampler."""

    # Random generator handed to the minipcn sampler; set in :meth:`sample`.
    rng = None

    def log_prob(self, x, beta=None):
        """Tempered log-probability converted to NumPy (minipcn works in numpy)."""
        return to_numpy(super().log_prob(x, beta))

    @track_calls
    def sample(
        self,
        n_samples: int,
        n_steps: int | None = None,
        min_step: float | None = None,
        max_n_steps: int | None = None,
        adaptive: bool = True,
        target_efficiency: float = 0.5,
        target_efficiency_rate: float = 1.0,
        n_final_samples: int | None = None,
        sampler_kwargs: dict | None = None,
        rng: np.random.Generator | None = None,
    ):
        """Run the SMC sampler using minipcn for the mutation steps.

        Parameters
        ----------
        n_samples
            Number of particles to draw.
        n_steps
            Number of MCMC steps per mutation; overrides the ``n_steps``
            entry of ``sampler_kwargs`` when given.
        min_step, max_n_steps, adaptive, target_efficiency,
        target_efficiency_rate, n_final_samples
            Forwarded to the base-class :meth:`sample` — see
            ``NumpySMCSampler.sample`` for their semantics.
        sampler_kwargs
            Extra options for minipcn. Defaults: ``n_steps=5 * self.dims``,
            ``target_acceptance_rate=0.234``, ``step_fn="tpcn"``.
        rng
            NumPy random generator; a fresh default generator is created
            when omitted.
        """
        # Copy so the setdefault calls below cannot mutate the caller's dict.
        self.sampler_kwargs = dict(sampler_kwargs) if sampler_kwargs else {}
        self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
        self.sampler_kwargs.setdefault("target_acceptance_rate", 0.234)
        self.sampler_kwargs.setdefault("step_fn", "tpcn")
        self.rng = rng or np.random.default_rng()
        return super().sample(
            n_samples,
            n_steps=n_steps,
            adaptive=adaptive,
            target_efficiency=target_efficiency,
            target_efficiency_rate=target_efficiency_rate,
            n_final_samples=n_final_samples,
            min_step=min_step,
            max_n_steps=max_n_steps,
        )

    def mutate(self, particles, beta, n_steps=None):
        """Mutate ``particles`` with a minipcn run at inverse temperature ``beta``.

        Returns an :class:`SMCSamples` built from the final state of the
        chain, with ``log_q``, ``log_prior`` and ``log_likelihood`` set.

        Raises
        ------
        ValueError
            If the proposal log-probability contains NaN values.
        """
        from minipcn import Sampler

        log_prob_fn = partial(self.log_prob, beta=beta)

        sampler = Sampler(
            log_prob_fn=log_prob_fn,
            step_fn=self.sampler_kwargs["step_fn"],
            rng=self.rng,
            dims=self.dims,
            target_acceptance_rate=self.sampler_kwargs[
                "target_acceptance_rate"
            ],
        )
        # Map to transformed dimension for sampling
        z = to_numpy(self.fit_preconditioning_transform(particles.x))
        chain, history = sampler.sample(
            z,
            n_steps=n_steps or self.sampler_kwargs["n_steps"],
        )
        # Final chain state, mapped back to the original space.
        x = self.preconditioning_transform.inverse(chain[-1])[0]

        self.history.mcmc_acceptance.append(np.mean(history.acceptance_rate))

        samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
        samples.log_q = samples.array_to_namespace(
            self.prior_flow.log_prob(samples.x)
        )
        samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )
        if samples.xp.isnan(samples.log_q).any():
            raise ValueError("Log proposal contains NaN values")
        return samples
|