aspire-inference 0.1.0a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +506 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +84 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +196 -0
- aspire/flows/jax/utils.py +57 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +344 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +94 -0
- aspire/samplers/importance.py +22 -0
- aspire/samplers/mcmc.py +160 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +318 -0
- aspire/samplers/smc/blackjax.py +332 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +568 -0
- aspire/transforms.py +751 -0
- aspire/utils.py +760 -0
- aspire_inference-0.1.0a7.dist-info/METADATA +52 -0
- aspire_inference-0.1.0a7.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a7.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a7.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a7.dist-info/top_level.txt +1 -0
aspire/history.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
from matplotlib.figure import Figure
|
|
8
|
+
|
|
9
|
+
from .utils import recursively_save_to_h5_file
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class History:
    """Base class for storing the history of a sampler.

    Subclasses declare the recorded quantities as dataclass fields; ``save``
    writes all instance attributes to an HDF5 file.
    """

    def save(self, h5_file, path="history"):
        """Write the history to an HDF5 file.

        Parameters
        ----------
        h5_file
            Open HDF5 file (or group) handed to ``recursively_save_to_h5_file``.
        path : str
            Group path under which the attributes are stored.
        """
        # A deep copy of the attributes is handed to the writer rather than
        # the live ``__dict__``.
        snapshot = copy.deepcopy(vars(self))
        recursively_save_to_h5_file(h5_file, path, snapshot)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class FlowHistory(History):
    """Loss curves recorded while training a flow.

    ``training_loss`` and ``validation_loss`` hold one value per epoch
    (the x-axis of ``plot_loss`` is labelled "Epoch").
    """

    training_loss: list[float] = field(default_factory=list)
    validation_loss: list[float] = field(default_factory=list)

    def plot_loss(self) -> Figure:
        """Create a new figure showing the training and validation loss."""
        fig = plt.figure()
        ax = fig.gca()
        curves = (
            (self.training_loss, "Training loss"),
            (self.validation_loss, "Validation loss"),
        )
        for values, label in curves:
            ax.plot(values, label=label)
        ax.legend()
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        return fig

    def save(self, h5_file, path="flow_history"):
        """Write the loss history to an HDF5 file under ``path``."""
        super().save(h5_file, path=path)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class SMCHistory(History):
    """Diagnostics recorded by an SMC sampler.

    Every field is a list of per-iteration values (presumably appended by the
    SMC sampler each iteration -- confirm in the sampler code).
    """

    log_norm_ratio: list[float] = field(default_factory=list)
    log_norm_ratio_var: list[float] = field(default_factory=list)
    beta: list[float] = field(default_factory=list)
    ess: list[float] = field(default_factory=list)
    ess_target: list[float] = field(default_factory=list)
    eff_target: list[float] = field(default_factory=list)
    mcmc_autocorr: list[float] = field(default_factory=list)
    mcmc_acceptance: list[float] = field(default_factory=list)

    def save(self, h5_file, path="smc_history"):
        """Write the SMC history to an HDF5 file under ``path``."""
        super().save(h5_file, path=path)

    def _plot_series(self, values, ylabel, ax=None) -> Figure | None:
        """Plot ``values`` against iteration index.

        When ``ax`` is None a new figure is created and returned; otherwise
        the series is drawn onto ``ax`` and None is returned.
        """
        new_fig = None
        if ax is None:
            new_fig, ax = plt.subplots()
        ax.plot(values)
        ax.set_xlabel("Iteration")
        ax.set_ylabel(ylabel)
        return new_fig

    def plot_beta(self, ax=None) -> Figure | None:
        """Plot the recorded ``beta`` values per iteration."""
        return self._plot_series(self.beta, r"$\beta$", ax)

    def plot_log_norm_ratio(self, ax=None) -> Figure | None:
        """Plot the recorded log normalisation ratio per iteration."""
        return self._plot_series(self.log_norm_ratio, "Log evidence ratio", ax)

    def plot_ess(self, ax=None) -> Figure | None:
        """Plot the recorded ESS per iteration."""
        return self._plot_series(self.ess, "ESS", ax)

    def plot_ess_target(self, ax=None) -> Figure | None:
        """Plot the recorded ESS target per iteration."""
        return self._plot_series(self.ess_target, "ESS target", ax)

    def plot_eff_target(self, ax=None) -> Figure | None:
        """Plot the recorded efficiency target per iteration."""
        return self._plot_series(self.eff_target, "Efficiency target", ax)

    def plot_mcmc_acceptance(self, ax=None) -> Figure | None:
        """Plot the recorded MCMC acceptance per iteration."""
        return self._plot_series(self.mcmc_acceptance, "MCMC Acceptance", ax)

    def plot_mcmc_autocorr(self, ax=None) -> Figure | None:
        """Plot the recorded MCMC autocorrelation per iteration."""
        return self._plot_series(self.mcmc_autocorr, "MCMC Autocorr", ax)

    def plot(self, fig: Figure | None = None) -> Figure:
        """Draw the SMC diagnostics as a vertical stack of panels.

        If ``fig`` is given, its existing axes are reused in order (``zip``
        stops at whichever of panels/axes runs out first); otherwise a new
        figure with one shared-x row per panel is created. The
        ``mcmc_autocorr`` series is not part of this summary.
        """
        panels = (
            self.plot_beta,
            self.plot_log_norm_ratio,
            self.plot_ess,
            self.plot_ess_target,
            self.plot_eff_target,
            self.plot_mcmc_acceptance,
        )

        if fig is not None:
            axes = fig.axes
        else:
            fig, axes = plt.subplots(len(panels), 1, sharex=True)

        for draw_panel, ax in zip(panels, axes):
            draw_panel(ax)

        # Clear the x-label on every axis except the last one.
        for ax in axes[:-1]:
            ax.set_xlabel("")

        return fig
|
aspire/plot.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def plot_comparison(
    *samples, parameters=None, per_samples_kwargs=None, labels=None, **kwargs
):
    """
    Plot a comparison of multiple samples on a shared corner plot.

    Parameters
    ----------
    *samples
        Objects exposing ``plot_corner(fig=..., parameters=..., **kwargs)``.
        Each one is drawn onto the figure returned by the previous call.
    parameters : list[str] | None
        Forwarded unchanged to every ``plot_corner`` call.
    per_samples_kwargs : list[dict] | None
        One dict per sample whose entries override the shared keyword
        arguments. A ``"color"`` entry sets both the main and histogram
        colour for that sample (default ``f"C{i}"``). The caller's dicts are
        not modified.
    labels : list[str] | None
        If truthy, a figure legend with these labels is added.
    **kwargs
        Shared keyword arguments for ``plot_corner``; they override the
        defaults below.

    Returns
    -------
    The figure returned by the last ``plot_corner`` call (None when no
    samples are given).
    """
    default_kwargs = dict(
        density=True,
        bins=30,
        color="C0",
        smooth=1.0,
        plot_datapoints=True,
        plot_density=False,
        hist_kwargs=dict(density=True, color="C0"),
    )
    default_kwargs.update(kwargs)

    if per_samples_kwargs is None:
        # Independent dicts rather than ``[{}] * n`` (one shared dict).
        per_samples_kwargs = [{} for _ in samples]

    fig = None
    for i, sample in enumerate(samples):
        kwds = copy.deepcopy(default_kwargs)
        # Work on a copy: popping "color" from the caller's dict would make a
        # second call with the same ``per_samples_kwargs`` lose the colour.
        sample_kwargs = dict(per_samples_kwargs[i])
        color = sample_kwargs.pop("color", f"C{i}")
        kwds["color"] = color
        kwds["hist_kwargs"]["color"] = color
        kwds.update(sample_kwargs)
        fig = sample.plot_corner(fig=fig, parameters=parameters, **kwds)

    if labels:
        fig.legend(
            labels=labels,
            loc="upper right",
            bbox_to_anchor=(0.9, 0.9),
            bbox_transform=fig.transFigure,
        )
    return fig
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def plot_history_comparison(*histories):
    """Overlay several sampler histories on one figure.

    The first history creates the figure via ``plot()``; each following
    history draws onto it via ``plot(fig=fig)``.

    Parameters
    ----------
    *histories
        History objects exposing ``plot(fig=None)``.

    Returns
    -------
    The figure returned by the last ``plot`` call.

    Raises
    ------
    ValueError
        If no histories are given, or if any history is not an instance of
        the first history's class.
    """
    if not histories:
        raise ValueError("At least one history must be provided")
    first, *rest = histories
    # isinstance: instances of subclasses of the first history's class pass.
    history_cls = type(first)
    if not all(isinstance(h, history_cls) for h in rest):
        raise ValueError("All histories must be of the same type")
    fig = first.plot()
    for history in rest:
        fig = history.plot(fig=fig)
    return fig
|
|
File without changes
|
aspire/samplers/base.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any, Callable
|
|
3
|
+
|
|
4
|
+
from ..flows.base import Flow
|
|
5
|
+
from ..samples import Samples
|
|
6
|
+
from ..transforms import IdentityTransform
|
|
7
|
+
from ..utils import track_calls
|
|
8
|
+
|
|
9
|
+
# Module-level logger; Sampler.config_dict uses it to warn when the sample
# method has no ``calls`` attribute.
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Sampler:
    """Base class for all samplers.

    Parameters
    ----------
    log_likelihood : Callable
        The log likelihood function; called with a ``Samples`` object.
    log_prior : Callable
        The log prior function; called with a ``Samples`` object.
    dims : int
        The number of dimensions.
    prior_flow : Flow
        The flow object samples are drawn from.
    xp : Callable
        The array namespace to use.
    dtype : Any | str | None
        Stored as ``self.dtype``; subclasses pass it on when building
        ``Samples`` objects.
    parameters : list[str] | None
        The list of parameter names. If None, any samples objects will not
        have the parameter names specified.
    preconditioning_transform : Callable | None
        Transform exposing ``fit``/``inverse``. Defaults to
        ``IdentityTransform(xp=xp)``.
    """

    def __init__(
        self,
        log_likelihood: Callable,
        log_prior: Callable,
        dims: int,
        prior_flow: Flow,
        xp: Callable,
        dtype: Any | str | None = None,
        parameters: list[str] | None = None,
        preconditioning_transform: Callable | None = None,
    ):
        self.prior_flow = prior_flow
        self._log_likelihood = log_likelihood
        self.log_prior = log_prior
        self.dims = dims
        self.xp = xp
        self.dtype = dtype
        self.parameters = parameters
        self.history = None
        self.n_likelihood_evaluations = 0
        self.preconditioning_transform = (
            IdentityTransform(xp=self.xp)
            if preconditioning_transform is None
            else preconditioning_transform
        )

    def fit_preconditioning_transform(self, x):
        """Fit the preconditioning transform to ``x`` and return its result.

        ``x`` is first converted to the transform's own array namespace.
        """
        transform = self.preconditioning_transform
        return transform.fit(transform.xp.asarray(x))

    @track_calls
    def sample(self, n_samples: int) -> Samples:
        raise NotImplementedError

    def log_likelihood(self, samples: Samples) -> Samples:
        """Evaluate the user log likelihood on ``samples``.

        The evaluation counter ``n_likelihood_evaluations`` is increased by
        ``len(samples)`` before the call.
        """
        self.n_likelihood_evaluations += len(samples)
        return self._log_likelihood(samples)

    def config_dict(self, include_sample_calls: bool = True) -> dict:
        """
        Build a dictionary describing the sampler configuration.

        Parameters
        ----------
        include_sample_calls : bool
            If True (default), add the recorded ``sample`` calls under the
            ``"sample_calls"`` key, or log a warning if none are recorded.
        """
        config = {}
        if not include_sample_calls:
            return config

        has_calls = hasattr(self, "sample") and hasattr(self.sample, "calls")
        if not has_calls:
            logger.warning(
                "Sampler does not have a sample method with calls attribute."
            )
            return config

        config["sample_calls"] = self.sample.calls.to_dict(list_to_dict=True)
        return config
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from ..samples import Samples
|
|
2
|
+
from ..utils import track_calls
|
|
3
|
+
from .base import Sampler
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ImportanceSampler(Sampler):
    """Sampler that draws directly from the prior flow and weights the draws."""

    @track_calls
    def sample(self, n_samples: int) -> Samples:
        """Draw ``n_samples`` from the prior flow and compute their weights.

        The returned ``Samples`` carry ``log_q`` from the flow, plus
        ``log_prior`` and ``log_likelihood`` converted to the samples' array
        namespace, and have had ``compute_weights()`` called on them.
        """
        draws, log_q = self.prior_flow.sample_and_log_prob(n_samples)
        result = Samples(
            draws,
            log_q=log_q,
            xp=self.xp,
            parameters=self.parameters,
            dtype=self.dtype,
        )
        prior_values = self.log_prior(result)
        result.log_prior = result.array_to_namespace(prior_values)
        likelihood_values = self.log_likelihood(result)
        result.log_likelihood = result.array_to_namespace(likelihood_values)
        result.compute_weights()
        return result
|
aspire/samplers/mcmc.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from ..samples import Samples, to_numpy
|
|
4
|
+
from ..utils import track_calls
|
|
5
|
+
from .base import Sampler
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MCMCSampler(Sampler):
    """Base class for MCMC samplers initialised from the prior flow.

    Chains run in the space produced by the preconditioning transform;
    ``log_prob`` maps points back and evaluates the target density there.
    """

    def draw_initial_samples(self, n_samples: int) -> Samples:
        """Draw ``n_samples`` initial samples from the prior flow.

        The flow may propose samples outside the prior bounds, so draws are
        repeated until ``n_samples`` samples with a finite log-prior have been
        collected; the others are discarded. The returned samples have
        ``log_prior`` and ``log_likelihood`` set.

        Note: this loops indefinitely if the flow never proposes a sample
        with a finite log-prior.
        """
        n_samples_drawn = 0
        samples = None
        while n_samples_drawn < n_samples:
            n_to_draw = n_samples - n_samples_drawn
            x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
            new_samples = Samples(x, xp=self.xp, log_q=log_q, dtype=self.dtype)
            new_samples.log_prior = new_samples.array_to_namespace(
                self.log_prior(new_samples)
            )
            valid = self.xp.isfinite(new_samples.log_prior)
            # xp.sum returns a 0-d array; convert to a Python int so the
            # counter and the next draw size stay plain integers.
            n_valid = int(self.xp.sum(valid))
            if n_valid > 0:
                if samples is None:
                    samples = new_samples[valid]
                else:
                    samples = Samples.concatenate(
                        [samples, new_samples[valid]]
                    )
                n_samples_drawn += n_valid

        # Defensive: only triggers if the flow returned more than requested.
        if n_samples_drawn > n_samples:
            samples = samples[:n_samples]

        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )
        return samples

    def log_prob(self, z):
        """Compute the log probability of points ``z``.

        ``z`` is in the preconditioned (transformed) space. It is mapped back
        with ``preconditioning_transform.inverse`` and the result is
        log-likelihood + log-prior + log|det J| of the inverse transform,
        returned as a flattened NumPy array.
        """
        x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
        samples = Samples(x, xp=self.xp, dtype=self.dtype)
        # Convert to the samples' namespace, as every other call site does,
        # so the sum below never mixes array backends.
        samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )
        log_prob = (
            samples.log_likelihood
            + samples.log_prior
            + samples.array_to_namespace(log_abs_det_jacobian)
        )
        return to_numpy(log_prob).flatten()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Emcee(MCMCSampler):
    """MCMC sampler that runs ``emcee.EnsembleSampler`` in the preconditioned space."""

    @track_calls
    def sample(
        self,
        n_samples: int,
        nwalkers: int | None = None,
        nsteps: int = 500,
        rng=None,
        discard=0,
        **kwargs,
    ) -> Samples:
        """Run emcee and return the resulting samples.

        Parameters
        ----------
        n_samples : int
            Number of walkers when ``nwalkers`` is not given, and number of
            prior-flow draws used for the evidence estimate.
        nwalkers : int | None
            Number of emcee walkers; falls back to ``n_samples`` if falsy.
        nsteps : int
            Number of emcee steps.
        rng : numpy.random.Generator | int | None
            If given, it seeds emcee's internal random state. Draws from the
            prior flow are not controlled by it.
        discard : int
            Number of initial steps dropped by ``get_chain``.
        **kwargs
            Forwarded to ``EnsembleSampler.run_mcmc``.

        Returns
        -------
        Samples
            The flattened chain with ``log_prior`` and ``log_likelihood`` set.
            ``log_evidence``/``log_evidence_error`` are taken from a separate
            batch of prior-flow draws after ``compute_weights()``.
        """
        from emcee import EnsembleSampler

        nwalkers = nwalkers or n_samples
        self.sampler = EnsembleSampler(
            nwalkers,
            self.dims,
            log_prob_fn=self.log_prob,
            vectorize=True,
        )

        if rng is not None:
            # emcee draws from its own legacy RandomState, so the previously
            # ignored ``rng`` is used to seed that state.
            seed = np.random.default_rng(rng).integers(2**32)
            self.sampler.random_state = np.random.RandomState(seed).get_state()

        samples = self.draw_initial_samples(nwalkers)
        p0 = samples.x

        z0 = to_numpy(self.preconditioning_transform.fit(p0))

        self.sampler.run_mcmc(z0, nsteps, **kwargs)

        z = self.sampler.get_chain(flat=True, discard=discard)
        x = self.preconditioning_transform.inverse(z)[0]

        # Evidence estimate from fresh prior-flow draws, built the same way
        # as ImportanceSampler (dtype, parameters, namespace conversion).
        x_evidence, log_q = self.prior_flow.sample_and_log_prob(n_samples)
        samples_evidence = Samples(
            x_evidence,
            log_q=log_q,
            xp=self.xp,
            parameters=self.parameters,
            dtype=self.dtype,
        )
        samples_evidence.log_prior = samples_evidence.array_to_namespace(
            self.log_prior(samples_evidence)
        )
        samples_evidence.log_likelihood = samples_evidence.array_to_namespace(
            self.log_likelihood(samples_evidence)
        )
        samples_evidence.compute_weights()

        samples_mcmc = Samples(
            x, xp=self.xp, parameters=self.parameters, dtype=self.dtype
        )
        samples_mcmc.log_prior = samples_mcmc.array_to_namespace(
            self.log_prior(samples_mcmc)
        )
        samples_mcmc.log_likelihood = samples_mcmc.array_to_namespace(
            self.log_likelihood(samples_mcmc)
        )
        samples_mcmc.log_evidence = samples_mcmc.array_to_namespace(
            samples_evidence.log_evidence
        )
        samples_mcmc.log_evidence_error = samples_mcmc.array_to_namespace(
            samples_evidence.log_evidence_error
        )

        return samples_mcmc
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class MiniPCN(MCMCSampler):
    """MCMC sampler backed by ``minipcn.Sampler``, run in the preconditioned space."""

    @track_calls
    def sample(
        self,
        n_samples,
        rng=None,
        target_acceptance_rate=0.234,
        n_steps=100,
        thin=1,
        burnin=0,
        last_step_only=False,
        step_fn="tpcn",
    ):
        """Run minipcn starting from ``n_samples`` prior-flow draws.

        Parameters
        ----------
        n_samples : int
            Number of initial points drawn from the prior flow.
        rng : numpy.random.Generator | None
            Passed to ``minipcn.Sampler``; a fresh default generator is used
            if not given.
        target_acceptance_rate : float
            Passed to ``minipcn.Sampler``.
        n_steps : int
            Number of sampler steps.
        thin, burnin : int
            Applied as ``chain[burnin::thin]`` when ``last_step_only`` is False.
        last_step_only : bool
            If True, only ``chain[-1]`` is returned (``thin``/``burnin`` are
            ignored).
        step_fn : str
            Step function name passed to ``minipcn.Sampler``.

        Returns
        -------
        Samples
            Chain points mapped back through the preconditioning transform,
            with ``log_prior`` and ``log_likelihood`` set.
        """
        from minipcn import Sampler

        rng = rng or np.random.default_rng()
        p0 = self.draw_initial_samples(n_samples).x

        z0 = to_numpy(self.preconditioning_transform.fit(p0))

        self.sampler = Sampler(
            log_prob_fn=self.log_prob,
            step_fn=step_fn,
            rng=rng,
            dims=self.dims,
            target_acceptance_rate=target_acceptance_rate,
        )

        # The second return value (sampler history) is not used here.
        chain, _history = self.sampler.sample(z0, n_steps=n_steps)

        # Assumes the chain is indexed by step first -- TODO confirm against
        # minipcn's return layout.
        if last_step_only:
            z = chain[-1]
        else:
            z = chain[burnin::thin].reshape(-1, self.dims)

        x = self.preconditioning_transform.inverse(z)[0]

        # Pass dtype like the other samplers so outputs share the configured dtype.
        samples_mcmc = Samples(
            x, xp=self.xp, parameters=self.parameters, dtype=self.dtype
        )
        samples_mcmc.log_prior = samples_mcmc.array_to_namespace(
            self.log_prior(samples_mcmc)
        )
        samples_mcmc.log_likelihood = samples_mcmc.array_to_namespace(
            self.log_likelihood(samples_mcmc)
        )
        return samples_mcmc
|
|
File without changes
|