aspire-inference 0.1.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +457 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +37 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +82 -0
- aspire/flows/jax/utils.py +54 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +276 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +92 -0
- aspire/samplers/importance.py +18 -0
- aspire/samplers/mcmc.py +158 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +312 -0
- aspire/samplers/smc/blackjax.py +330 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +476 -0
- aspire/transforms.py +491 -0
- aspire/utils.py +491 -0
- aspire_inference-0.1.0a2.dist-info/METADATA +48 -0
- aspire_inference-0.1.0a2.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a2.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a2.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Callable
|
|
4
|
+
|
|
5
|
+
import array_api_compat.torch as torch_api
|
|
6
|
+
import torch
|
|
7
|
+
import tqdm
|
|
8
|
+
import zuko
|
|
9
|
+
from array_api_compat import is_numpy_namespace, is_torch_array
|
|
10
|
+
|
|
11
|
+
from ...history import FlowHistory
|
|
12
|
+
from ..base import Flow
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BaseTorchFlow(Flow):
    """Common base class for normalizing flows implemented with PyTorch.

    Sub-classes build a torch module and assign it to :attr:`flow`; the
    setter moves the module to ``self.device`` and compiles it in place.
    """

    # Wrapped torch flow module; ``None`` until a sub-class assigns ``flow``.
    _flow = None
    # Array-API namespace associated with this flow family.
    xp = torch_api

    def __init__(
        self,
        dims: int,
        seed: int = 1234,
        device: str = "cpu",
        data_transform=None,
    ):
        torch_device = torch.device(device or "cpu")
        super().__init__(
            dims,
            device=torch_device,
            data_transform=data_transform,
        )
        # NOTE: this seeds torch's *global* RNG, not a per-flow generator.
        torch.manual_seed(seed)
        self.loc = None
        self.scale = None

    @property
    def flow(self):
        """The underlying torch flow module."""
        return self._flow

    @flow.setter
    def flow(self, flow):
        module = flow
        self._flow = module
        # Move to the configured device, then compile the module in place.
        module.to(self.device)
        module.compile()

    def fit(self, x) -> FlowHistory:
        """Train the flow on ``x``; must be implemented by sub-classes."""
        raise NotImplementedError()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ZukoFlow(BaseTorchFlow):
    """Normalizing flow built from a ``zuko`` flow class.

    Parameters
    ----------
    dims : int
        Number of dimensions of the data.
    flow_class : str | Callable
        Name of a class in ``zuko.flows`` (e.g. ``"MAF"``) or a callable that
        is invoked as ``flow_class(dims, 0, **kwargs)``.
    data_transform
        Optional data transform forwarded to the base class.
    seed : int
        Seed forwarded to the base class.
    device : str
        Torch device forwarded to the base class.
    **kwargs
        Extra keyword arguments for the flow class. ``hidden_features`` is
        coerced to a list of ints.
    """

    def __init__(
        self,
        dims,
        flow_class: str | Callable = "MAF",
        data_transform=None,
        seed=1234,
        device: str = "cpu",
        **kwargs,
    ):
        super().__init__(
            dims,
            device=device,
            data_transform=data_transform,
            seed=seed,
        )

        if isinstance(flow_class, str):
            FlowClass = getattr(zuko.flows, flow_class)
        else:
            FlowClass = flow_class

        # Ints are sometimes passed as strings (e.g. from config files), so
        # convert each hidden-layer width explicitly.
        if hidden_features := kwargs.pop("hidden_features", None):
            kwargs["hidden_features"] = list(map(int, hidden_features))

        # zuko flows take ``(features, context, ...)``; 0 context features
        # presumably requests an unconditional flow (zuko convention).
        self.flow = FlowClass(self.dims, 0, **kwargs)
        logger.info(f"Initialized normalizing flow: \n {self.flow}\n")

    @staticmethod
    def _tensor_to_namespace(array, xp):
        """Convert a torch tensor to the array namespace ``xp``.

        For NumPy the tensor is detached and moved to the CPU first, because
        ``Tensor.numpy`` rejects tensors that require grad or that live on an
        accelerator. Other namespaces receive the tensor via ``xp.asarray``.
        """
        if is_numpy_namespace(xp):
            array = array.detach().cpu().numpy()
        return xp.asarray(array)

    def loss_fn(self, x):
        """Negative mean log-likelihood of ``x`` under the flow."""
        return -self.flow().log_prob(x).mean()

    def fit(
        self,
        x,
        n_epochs: int = 100,
        lr: float = 1e-3,
        batch_size: int = 500,
        validation_fraction: float = 0.2,
        clip_grad: float | None = None,
        lr_annealing: bool = False,
    ):
        """Fit the flow to ``x`` by minimising ``loss_fn``.

        The data are passed through ``fit_data_transform``, shuffled and split
        into training and validation sets. After training, the weights with
        the lowest validation loss are restored.

        Parameters
        ----------
        x : array-like
            Samples to fit; converted to a tensor with the default torch dtype
            on ``self.device``.
        n_epochs : int
            Number of passes over the training set.
        lr : float
            Adam learning rate.
        batch_size : int
            Mini-batch size for training and validation.
        validation_fraction : float
            Fraction of samples held out for validation. If this rounds down
            to zero samples, validation (and best-model selection) is skipped
            and the final weights are kept.
        clip_grad : float | None
            If given, maximum gradient norm passed to ``clip_grad_norm_``.
        lr_annealing : bool
            If True, decay the learning rate with a cosine schedule over
            ``n_epochs`` epochs.

        Returns
        -------
        FlowHistory
            Per-epoch training (and, when used, validation) losses.

        Raises
        ------
        ValueError
            If no training samples remain after the split, or if the training
            or validation data contain NaN or infinite values.
        """
        dtype = torch.get_default_dtype()
        if not is_torch_array(x):
            x = torch.tensor(x, dtype=dtype, device=self.device)
        else:
            x = torch.clone(x).type(dtype)
        x = x.to(self.device)
        x_prime = self.fit_data_transform(x)
        # Shuffle before splitting so the validation set is a random subset.
        indices = torch.randperm(x_prime.shape[0])
        x_prime = x_prime[indices, ...]

        n = x_prime.shape[0]
        n_val = max(int(validation_fraction * n), 0)
        n_train = n - n_val
        if n_train <= 0:
            raise ValueError(
                "No training samples left after the validation split; "
                "reduce validation_fraction or provide more samples."
            )
        # Split by explicit counts: a negative-index slice such as
        # ``x[:-n_val]`` is empty when ``n_val`` rounds down to 0.
        x_train = torch.as_tensor(
            x_prime[:n_train], dtype=dtype, device=self.device
        )
        x_val = torch.as_tensor(
            x_prime[n_train:], dtype=dtype, device=self.device
        )

        logger.info(
            f"Training on {x_train.shape[0]} samples, "
            f"validating on {x_prime.shape[0] - x_train.shape[0]} samples."
        )

        if torch.isnan(x_train).any():
            raise ValueError("Training data contains NaN values.")
        if not torch.isfinite(x_train).all():
            raise ValueError("Training data contains infinite values.")
        if torch.isnan(x_val).any():
            raise ValueError("Validation data contains NaN values.")
        if not torch.isfinite(x_val).all():
            raise ValueError("Validation data contains infinite values.")

        dataset = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(x_train),
            shuffle=True,
            batch_size=batch_size,
        )
        val_dataset = None
        if n_val > 0:
            val_dataset = torch.utils.data.DataLoader(
                torch.utils.data.TensorDataset(x_val),
                shuffle=False,
                batch_size=batch_size,
            )
        else:
            logger.warning(
                "No validation samples; skipping validation and keeping "
                "the final model."
            )

        # Train to maximize the log-likelihood
        optimizer = torch.optim.Adam(self._flow.parameters(), lr=lr)
        scheduler = None
        if lr_annealing:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, n_epochs
            )
        history = FlowHistory()

        best_val_loss = float("inf")
        best_flow_state = None

        with tqdm.tqdm(range(n_epochs), desc="Epochs") as pbar:
            for _ in pbar:
                self.flow.train()
                loss_epoch = 0.0
                for (x_batch,) in dataset:
                    loss = self.loss_fn(x_batch)
                    optimizer.zero_grad()
                    loss.backward()
                    if clip_grad is not None:
                        torch.nn.utils.clip_grad_norm_(
                            self.flow.parameters(), clip_grad
                        )
                    optimizer.step()
                    loss_epoch += loss.item()
                # The cosine schedule has period ``n_epochs``, so step per epoch.
                if scheduler is not None:
                    scheduler.step()
                avg_train_loss = loss_epoch / len(dataset)
                history.training_loss.append(avg_train_loss)

                postfix = {"train_loss": f"{avg_train_loss:.4f}"}
                if val_dataset is not None:
                    self.flow.eval()
                    val_loss = 0.0
                    for (x_batch,) in val_dataset:
                        with torch.no_grad():
                            val_loss += self.loss_fn(x_batch).item()
                    avg_val_loss = val_loss / len(val_dataset)
                    if avg_val_loss < best_val_loss:
                        best_val_loss = avg_val_loss
                        best_flow_state = copy.deepcopy(self.flow.state_dict())

                    history.validation_loss.append(avg_val_loss)
                    postfix["val_loss"] = f"{avg_val_loss:.4f}"
                pbar.set_postfix(**postfix)

        if best_flow_state is not None:
            self.flow.load_state_dict(best_flow_state)
            logger.info(f"Loaded best model with val loss {best_val_loss:.4f}")

        self.flow.eval()
        return history

    def sample_and_log_prob(self, n_samples: int, xp=torch_api):
        """Draw ``n_samples`` samples and their log density.

        The flow samples in the rescaled space; ``inverse_rescale`` maps them
        back and its log-Jacobian is subtracted from the flow log density.
        """
        with torch.no_grad():
            x_prime, log_prob = self.flow().rsample_and_log_prob((n_samples,))
            x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
            return (
                self._tensor_to_namespace(x, xp),
                self._tensor_to_namespace(log_prob - log_abs_det_jacobian, xp),
            )

    def sample(self, n_samples: int, xp=torch_api):
        """Draw ``n_samples`` samples mapped back through ``inverse_rescale``."""
        with torch.no_grad():
            x_prime = self.flow().rsample((n_samples,))
            x = self.inverse_rescale(x_prime)[0]
            return self._tensor_to_namespace(x, xp)

    def log_prob(self, x, xp=torch_api):
        """Log density of ``x``, including the ``rescale`` log-Jacobian.

        Gradients are kept for torch outputs; NumPy outputs are detached.
        """
        x = torch.as_tensor(
            x, dtype=torch.get_default_dtype(), device=self.device
        )
        x_prime, log_abs_det_jacobian = self.rescale(x)
        return self._tensor_to_namespace(
            self._flow().log_prob(x_prime) + log_abs_det_jacobian, xp
        )

    def forward(self, x, xp=torch_api):
        """Map ``x`` to the latent space; return ``(z, total log-Jacobian)``."""
        x = torch.as_tensor(
            x, dtype=torch.get_default_dtype(), device=self.device
        )
        x_prime, log_j_rescale = self.rescale(x)
        z, log_abs_det_jacobian = self._flow().transform.call_and_ladj(x_prime)
        return (
            self._tensor_to_namespace(z, xp),
            self._tensor_to_namespace(log_abs_det_jacobian + log_j_rescale, xp),
        )

    def inverse(self, z, xp=torch_api):
        """Map latent ``z`` back to data space; return ``(x, total log-Jacobian)``."""
        z = torch.as_tensor(
            z, dtype=torch.get_default_dtype(), device=self.device
        )
        with torch.no_grad():
            x_prime, log_abs_det_jacobian = (
                self._flow().transform.inv.call_and_ladj(z)
            )
        x, log_j_rescale = self.inverse_rescale(x_prime)
        return (
            self._tensor_to_namespace(x, xp),
            self._tensor_to_namespace(log_j_rescale + log_abs_det_jacobian, xp),
        )
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class ZukoFlowMatching(ZukoFlow):
    """Continuous normalizing flow (zuko ``CNF``) trained by flow matching.

    Parameters
    ----------
    dims : int
        Number of dimensions of the data.
    data_transform
        Optional data transform forwarded to the base class.
    seed : int
        Seed forwarded to the base class.
    device : str
        Torch device forwarded to the base class.
    eta : float
        Small offset added to the noise coefficient in the interpolation used
        by ``loss_fn``.
    **kwargs
        Extra keyword arguments forwarded to the ``CNF`` constructor;
        ``hidden_features`` defaults to ``4 * [100]``.
    """

    def __init__(
        self,
        dims,
        data_transform=None,
        seed=1234,
        device="cpu",
        eta: float = 1e-3,
        **kwargs,
    ):
        kwargs.setdefault("hidden_features", 4 * [100])
        # The flow class is fixed to zuko's CNF; discard a caller-supplied
        # ``flow_class`` so it cannot clash with the keyword passed below.
        ignored_flow_class = kwargs.pop("flow_class", None)
        if ignored_flow_class not in (None, "CNF"):
            logger.warning(
                "ZukoFlowMatching always uses 'CNF'; "
                f"ignoring flow_class={ignored_flow_class!r}."
            )
        super().__init__(
            dims,
            seed=seed,
            device=device,
            data_transform=data_transform,
            flow_class="CNF",
            **kwargs,
        )
        self.eta = eta

    def loss_fn(self, theta: torch.Tensor):
        """Flow-matching regression loss for a batch ``theta``.

        Draws ``t ~ U(0, 1)`` per sample and Gaussian noise ``eps``, forms
        ``theta_t = (1 - t) * theta + (t + eta) * eps`` and regresses the CNF
        vector field ``f(t, theta_t)`` onto ``eps - theta`` (mean squared error).
        """
        t = torch.rand(
            theta.shape[:-1], dtype=theta.dtype, device=theta.device
        )
        t_ = t[..., None]
        eps = torch.randn_like(theta)
        theta_prime = (1 - t_) * theta + (t_ + self.eta) * eps
        v = eps - theta
        return (self._flow.transform.f(t, theta_prime) - v).square().mean()
|
aspire/history.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
from matplotlib.figure import Figure
|
|
8
|
+
|
|
9
|
+
from .utils import recursively_save_to_h5_file
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class History:
    """Base class for sampler histories that can be written to HDF5."""

    def save(self, h5_file, path="history"):
        """Write a deep copy of the instance attributes to ``h5_file`` at ``path``."""
        snapshot = copy.deepcopy(vars(self))
        recursively_save_to_h5_file(h5_file, path, snapshot)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class FlowHistory(History):
    """Loss curves recorded while training a normalizing flow."""

    # Training loss, one entry appended per epoch.
    training_loss: list[float] = field(default_factory=list)
    # Validation loss, one entry appended per epoch.
    validation_loss: list[float] = field(default_factory=list)

    def plot_loss(self) -> Figure:
        """Plot training and validation loss against epoch on a new figure."""
        fig = plt.figure()
        ax = fig.add_subplot()
        ax.plot(self.training_loss, label="Training loss")
        ax.plot(self.validation_loss, label="Validation loss")
        ax.legend()
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        return fig

    def save(self, h5_file, path="flow_history"):
        """Save the history to an HDF5 file under ``path``."""
        super().save(h5_file, path=path)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class SMCHistory(History):
    """Per-iteration diagnostics recorded by an SMC sampler."""

    log_norm_ratio: list[float] = field(default_factory=list)
    log_norm_ratio_var: list[float] = field(default_factory=list)
    beta: list[float] = field(default_factory=list)
    ess: list[float] = field(default_factory=list)
    ess_target: list[float] = field(default_factory=list)
    eff_target: list[float] = field(default_factory=list)
    mcmc_autocorr: list[float] = field(default_factory=list)
    mcmc_acceptance: list[float] = field(default_factory=list)

    def save(self, h5_file, path="smc_history"):
        """Save the history to an HDF5 file under ``path``."""
        super().save(h5_file, path=path)

    def _plot_series(self, values, ylabel, ax=None) -> Figure | None:
        """Plot ``values`` against iteration.

        Creates a new figure only when ``ax`` is None and returns it;
        otherwise draws on ``ax`` and returns None.
        """
        fig = None
        if ax is None:
            fig, ax = plt.subplots()
        ax.plot(values)
        ax.set_xlabel("Iteration")
        ax.set_ylabel(ylabel)
        return fig

    def plot_beta(self, ax=None) -> Figure | None:
        """Plot ``beta`` per iteration."""
        return self._plot_series(self.beta, r"$\beta$", ax)

    def plot_log_norm_ratio(self, ax=None) -> Figure | None:
        """Plot ``log_norm_ratio`` per iteration."""
        return self._plot_series(self.log_norm_ratio, "Log evidence ratio", ax)

    def plot_ess(self, ax=None) -> Figure | None:
        """Plot the effective sample size per iteration."""
        return self._plot_series(self.ess, "ESS", ax)

    def plot_ess_target(self, ax=None) -> Figure | None:
        """Plot ``ess_target`` per iteration."""
        return self._plot_series(self.ess_target, "ESS target", ax)

    def plot_eff_target(self, ax=None) -> Figure | None:
        """Plot ``eff_target`` per iteration."""
        return self._plot_series(self.eff_target, "Efficiency target", ax)

    def plot_mcmc_acceptance(self, ax=None) -> Figure | None:
        """Plot ``mcmc_acceptance`` per iteration."""
        return self._plot_series(self.mcmc_acceptance, "MCMC Acceptance", ax)

    def plot_mcmc_autocorr(self, ax=None) -> Figure | None:
        """Plot ``mcmc_autocorr`` per iteration."""
        return self._plot_series(self.mcmc_autocorr, "MCMC Autocorr", ax)

    def plot(self, fig: Figure | None = None) -> Figure:
        """Plot the main diagnostics on vertically stacked axes.

        If ``fig`` is given, its existing axes are reused in order (e.g. to
        overlay several histories); otherwise a new figure with shared x-axes
        is created.
        """
        # NOTE(review): ``plot_mcmc_autocorr`` is not part of this panel set;
        # confirm whether that is intentional.
        panels = [
            self.plot_beta,
            self.plot_log_norm_ratio,
            self.plot_ess,
            self.plot_ess_target,
            self.plot_eff_target,
            self.plot_mcmc_acceptance,
        ]

        if fig is not None:
            axs = fig.axes
        else:
            fig, axs = plt.subplots(len(panels), 1, sharex=True)

        for draw, ax in zip(panels, axs):
            draw(ax)

        # Only the bottom panel keeps its x-axis label.
        for ax in axs[:-1]:
            ax.set_xlabel("")

        return fig
|
aspire/plot.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def plot_comparison(
    *samples, parameters=None, per_samples_kwargs=None, labels=None, **kwargs
):
    """
    Plot a comparison of multiple samples on a shared corner plot.

    Parameters
    ----------
    *samples
        Objects exposing ``plot_corner(fig=..., parameters=..., **kwargs)``;
        each is drawn onto the figure returned by the previous one.
    parameters : list[str] | None
        Parameters to plot, forwarded to every ``plot_corner`` call.
    per_samples_kwargs : list[dict | None] | None
        Optional keyword arguments for each sample, applied after the shared
        ones. A ``"color"`` entry sets both the corner and histogram color;
        by default sample ``i`` uses ``f"C{i}"``. The dicts are not modified.
    labels : list[str] | None
        If given, a figure legend with these labels is added.
    **kwargs
        Shared keyword arguments that override the defaults below.

    Returns
    -------
    The figure returned by the last ``plot_corner`` call (``None`` when no
    samples are given).
    """
    default_kwargs = dict(
        density=True,
        bins=30,
        color="C0",
        smooth=1.0,
        plot_datapoints=True,
        plot_density=False,
        hist_kwargs=dict(density=True, color="C0"),
    )
    default_kwargs.update(kwargs)

    if per_samples_kwargs is None:
        per_samples_kwargs = [{}] * len(samples)

    fig = None
    for i, sample in enumerate(samples):
        kwds = copy.deepcopy(default_kwargs)
        # Work on a copy so popping "color" never mutates the caller's dict
        # (which would lose the color on a repeated call).
        sample_kwargs = dict(per_samples_kwargs[i] or {})
        color = sample_kwargs.pop("color", f"C{i}")
        kwds["color"] = color
        kwds["hist_kwargs"]["color"] = color
        kwds.update(sample_kwargs)
        fig = sample.plot_corner(fig=fig, parameters=parameters, **kwds)

    if labels:
        fig.legend(
            labels=labels,
            loc="upper right",
            bbox_to_anchor=(0.9, 0.9),
            bbox_transform=fig.transFigure,
        )
    return fig
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def plot_history_comparison(*histories):
    """Overlay several histories of the same type on one figure.

    The first history creates the figure via ``plot()``; each subsequent one
    draws onto it with ``plot(fig=fig)``.

    Raises
    ------
    ValueError
        If no histories are given, or if they are not all instances of the
        first history's class.
    """
    if not histories:
        raise ValueError("At least one history must be provided")
    first, *rest = histories
    # All histories must share the first one's type so their axes line up.
    if not all(isinstance(h, type(first)) for h in histories):
        raise ValueError("All histories must be of the same type")
    fig = first.plot()
    for history in rest:
        fig = history.plot(fig=fig)
    return fig
|
|
File without changes
|
aspire/samplers/base.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
from ..flows.base import Flow
|
|
5
|
+
from ..samples import Samples
|
|
6
|
+
from ..transforms import IdentityTransform
|
|
7
|
+
from ..utils import track_calls
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Sampler:
    """Base class for all samplers.

    Parameters
    ----------
    log_likelihood : Callable
        The log likelihood function; it is called with a ``Samples`` object.
    log_prior : Callable
        The log prior function.
    dims : int
        The number of dimensions.
    prior_flow : Flow
        The flow used by the sampler (e.g. ``ImportanceSampler`` draws its
        proposals from it).
    xp : Callable
        The array namespace (backend) to use.
    parameters : list[str] | None
        The list of parameter names. If None, any samples objects will not
        have the parameters names specified.
    preconditioning_transform : Callable | None
        Transform used to precondition the data. If None, an
        ``IdentityTransform`` using ``xp`` is created.
    """

    def __init__(
        self,
        log_likelihood: Callable,
        log_prior: Callable,
        dims: int,
        prior_flow: Flow,
        xp: Callable,
        parameters: list[str] | None = None,
        preconditioning_transform: Callable | None = None,
    ):
        self.prior_flow = prior_flow
        self._log_likelihood = log_likelihood
        self.log_prior = log_prior
        self.dims = dims
        self.xp = xp
        self.parameters = parameters
        self.history = None
        # Incremented by ``log_likelihood`` with the size of each batch.
        self.n_likelihood_evaluations = 0
        if preconditioning_transform is None:
            self.preconditioning_transform = IdentityTransform(xp=self.xp)
        else:
            self.preconditioning_transform = preconditioning_transform

    def fit_preconditioning_transform(self, x):
        """Fit the preconditioning transform to ``x``.

        ``x`` is first converted to the transform's own array namespace;
        returns whatever the transform's ``fit`` returns.
        """
        x = self.preconditioning_transform.xp.asarray(x)
        return self.preconditioning_transform.fit(x)

    @track_calls
    def sample(self, n_samples: int) -> Samples:
        """Draw ``n_samples`` samples; must be implemented by sub-classes."""
        raise NotImplementedError

    def log_likelihood(self, samples: Samples):
        """Compute the log likelihood of ``samples``.

        Delegates to the user-supplied log likelihood and increments
        ``n_likelihood_evaluations`` by ``len(samples)``.

        Returns
        -------
        Whatever the user-supplied log likelihood returns (presumably an array
        with one value per sample, not a ``Samples`` object).
        """
        self.n_likelihood_evaluations += len(samples)
        return self._log_likelihood(samples)

    def config_dict(self, include_sample_calls: bool = True) -> dict:
        """
        Returns a dictionary with the configuration of the sampler.

        Parameters
        ----------
        include_sample_calls : bool
            Whether to include the sample calls in the configuration.
            Default is True.
        """
        config = {}
        if include_sample_calls:
            # ``calls`` is expected to be attached by the ``track_calls``
            # decorator on ``sample``.
            if hasattr(self, "sample") and hasattr(self.sample, "calls"):
                config["sample_calls"] = self.sample.calls.to_dict(
                    list_to_dict=True
                )
            else:
                logger.warning(
                    "Sampler does not have a sample method with calls attribute."
                )
        return config
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from ..samples import Samples
|
|
2
|
+
from ..utils import track_calls
|
|
3
|
+
from .base import Sampler
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ImportanceSampler(Sampler):
    """Sampler that draws proposals from ``prior_flow`` and reweights them."""

    @track_calls
    def sample(self, n_samples: int) -> Samples:
        """Draw ``n_samples`` proposals and evaluate them.

        The proposal log density from the flow is stored as ``log_q``; the log
        prior and log likelihood are then evaluated (in that order) and
        ``compute_weights`` is called on the resulting ``Samples``.
        """
        draws, proposal_log_density = self.prior_flow.sample_and_log_prob(
            n_samples
        )
        result = Samples(
            draws,
            log_q=proposal_log_density,
            xp=self.xp,
            parameters=self.parameters,
        )
        # The prior is stored before the likelihood is evaluated on the same
        # ``Samples`` object.
        prior_values = self.log_prior(result)
        result.log_prior = result.array_to_namespace(prior_values)
        likelihood_values = self.log_likelihood(result)
        result.log_likelihood = result.array_to_namespace(likelihood_values)
        result.compute_weights()
        return result
|