aspire-inference 0.1.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,276 @@
+ import copy
+ import logging
+ from typing import Callable
+
+ import array_api_compat.torch as torch_api
+ import torch
+ import tqdm
+ import zuko
+ from array_api_compat import is_numpy_namespace, is_torch_array
+
+ from ...history import FlowHistory
+ from ..base import Flow
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseTorchFlow(Flow):
+     _flow = None
+     xp = torch_api
+
+     def __init__(
+         self,
+         dims: int,
+         seed: int = 1234,
+         device: str = "cpu",
+         data_transform=None,
+     ):
+         super().__init__(
+             dims,
+             device=torch.device(device or "cpu"),
+             data_transform=data_transform,
+         )
+         torch.manual_seed(seed)
+         self.loc = None
+         self.scale = None
+
+     @property
+     def flow(self):
+         return self._flow
+
+     @flow.setter
+     def flow(self, flow):
+         self._flow = flow
+         self._flow.to(self.device)
+         self._flow.compile()
+
+     def fit(self, x) -> FlowHistory:
+         raise NotImplementedError()
+
+
+ class ZukoFlow(BaseTorchFlow):
+     def __init__(
+         self,
+         dims,
+         flow_class: str | Callable = "MAF",
+         data_transform=None,
+         seed=1234,
+         device: str = "cpu",
+         **kwargs,
+     ):
+         super().__init__(
+             dims,
+             device=device,
+             data_transform=data_transform,
+             seed=seed,
+         )
+
+         if isinstance(flow_class, str):
+             FlowClass = getattr(zuko.flows, flow_class)
+         else:
+             FlowClass = flow_class
+
+         # Ints are sometimes passed as strings, so we convert them
+         if hidden_features := kwargs.pop("hidden_features", None):
+             kwargs["hidden_features"] = list(map(int, hidden_features))
+
+         self.flow = FlowClass(self.dims, 0, **kwargs)
+         logger.info(f"Initialized normalizing flow: \n {self.flow}\n")
+
+     def loss_fn(self, x):
+         return -self.flow().log_prob(x).mean()
+
+     def fit(
+         self,
+         x,
+         n_epochs: int = 100,
+         lr: float = 1e-3,
+         batch_size: int = 500,
+         validation_fraction: float = 0.2,
+         clip_grad: float | None = None,
+         lr_annealing: bool = False,
+     ):
+         from ...history import FlowHistory
+
+         if not is_torch_array(x):
+             x = torch.tensor(
+                 x, dtype=torch.get_default_dtype(), device=self.device
+             )
+         else:
+             x = torch.clone(x)
+             x = x.type(torch.get_default_dtype())
+             x = x.to(self.device)
+         x_prime = self.fit_data_transform(x)
+         # Shuffle and split into training and validation sets
+         indices = torch.randperm(x_prime.shape[0])
+         x_prime = x_prime[indices, ...]
+
+         n = x_prime.shape[0]
+         x_train = torch.as_tensor(
+             x_prime[: -int(validation_fraction * n)],
+             dtype=torch.get_default_dtype(),
+             device=self.device,
+         )
+
+         logger.info(
+             f"Training on {x_train.shape[0]} samples, "
+             f"validating on {x_prime.shape[0] - x_train.shape[0]} samples."
+         )
+
+         if torch.isnan(x_train).any():
+             raise ValueError("Training data contains NaN values.")
+         if not torch.isfinite(x_train).all():
+             raise ValueError("Training data contains infinite values.")
+
+         x_val = torch.as_tensor(
+             x_prime[-int(validation_fraction * n) :],
+             dtype=torch.get_default_dtype(),
+             device=self.device,
+         )
+         if torch.isnan(x_val).any():
+             raise ValueError("Validation data contains NaN values.")
+
+         if not torch.isfinite(x_val).all():
+             raise ValueError("Validation data contains infinite values.")
+
+         dataset = torch.utils.data.DataLoader(
+             torch.utils.data.TensorDataset(x_train),
+             shuffle=True,
+             batch_size=batch_size,
+         )
+         val_dataset = torch.utils.data.DataLoader(
+             torch.utils.data.TensorDataset(x_val),
+             shuffle=False,
+             batch_size=batch_size,
+         )
+
+         # Train to maximize the log-likelihood
+         optimizer = torch.optim.Adam(self._flow.parameters(), lr=lr)
+         if lr_annealing:
+             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                 optimizer, n_epochs
+             )
+         history = FlowHistory()
+
+         # Track the best model by validation loss
+         best_val_loss = float("inf")
+         best_flow_state = None
+
+         with tqdm.tqdm(range(n_epochs), desc="Epochs") as pbar:
+             for _ in pbar:
+                 self.flow.train()
+                 loss_epoch = 0.0
+                 for (x_batch,) in dataset:
+                     loss = self.loss_fn(x_batch)
+                     optimizer.zero_grad()
+                     loss.backward()
+                     if clip_grad is not None:
+                         torch.nn.utils.clip_grad_norm_(
+                             self.flow.parameters(), clip_grad
+                         )
+                     optimizer.step()
+                     loss_epoch += loss.item()
+                 if lr_annealing:
+                     scheduler.step()
+                 avg_train_loss = loss_epoch / len(dataset)
+                 history.training_loss.append(avg_train_loss)
+                 self.flow.eval()
+                 val_loss = 0.0
+                 for (x_batch,) in val_dataset:
+                     with torch.no_grad():
+                         val_loss += self.loss_fn(x_batch).item()
+                 avg_val_loss = val_loss / len(val_dataset)
+                 if avg_val_loss < best_val_loss:
+                     best_val_loss = avg_val_loss
+                     best_flow_state = copy.deepcopy(self.flow.state_dict())
+
+                 history.validation_loss.append(avg_val_loss)
+                 pbar.set_postfix(
+                     train_loss=f"{avg_train_loss:.4f}",
+                     val_loss=f"{avg_val_loss:.4f}",
+                 )
+         if best_flow_state is not None:
+             self.flow.load_state_dict(best_flow_state)
+             logger.info(f"Loaded best model with val loss {best_val_loss:.4f}")
+
+         self.flow.eval()
+         return history
+
+     def sample_and_log_prob(self, n_samples: int, xp=torch_api):
+         with torch.no_grad():
+             x_prime, log_prob = self.flow().rsample_and_log_prob((n_samples,))
+             x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
+         return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)
+
+     def sample(self, n_samples: int, xp=torch_api):
+         with torch.no_grad():
+             x_prime = self.flow().rsample((n_samples,))
+             x = self.inverse_rescale(x_prime)[0]
+         return xp.asarray(x)
+
+     def log_prob(self, x, xp=torch_api):
+         x = torch.as_tensor(
+             x, dtype=torch.get_default_dtype(), device=self.device
+         )
+         x_prime, log_abs_det_jacobian = self.rescale(x)
+         return xp.asarray(
+             self._flow().log_prob(x_prime) + log_abs_det_jacobian
+         )
+
+     def forward(self, x, xp=torch_api):
+         x = torch.as_tensor(
+             x, dtype=torch.get_default_dtype(), device=self.device
+         )
+         x_prime, log_j_rescale = self.rescale(x)
+         z, log_abs_det_jacobian = self._flow().transform.call_and_ladj(x_prime)
+         if is_numpy_namespace(xp):
+             # Convert to the numpy namespace if needed (move to CPU first)
+             z = z.detach().cpu().numpy()
+             log_abs_det_jacobian = log_abs_det_jacobian.detach().cpu().numpy()
+             log_j_rescale = log_j_rescale.detach().cpu().numpy()
+         return xp.asarray(z), xp.asarray(log_abs_det_jacobian + log_j_rescale)
+
+     def inverse(self, z, xp=torch_api):
+         z = torch.as_tensor(
+             z, dtype=torch.get_default_dtype(), device=self.device
+         )
+         with torch.no_grad():
+             x_prime, log_abs_det_jacobian = (
+                 self._flow().transform.inv.call_and_ladj(z)
+             )
+         x, log_j_rescale = self.inverse_rescale(x_prime)
+         if is_numpy_namespace(xp):
+             # Convert to the numpy namespace if needed (move to CPU first)
+             x = x.detach().cpu().numpy()
+             log_abs_det_jacobian = log_abs_det_jacobian.detach().cpu().numpy()
+             log_j_rescale = log_j_rescale.detach().cpu().numpy()
+         return xp.asarray(x), xp.asarray(log_j_rescale + log_abs_det_jacobian)
+
+
+ class ZukoFlowMatching(ZukoFlow):
+     def __init__(
+         self,
+         dims,
+         data_transform=None,
+         seed=1234,
+         device="cpu",
+         eta: float = 1e-3,
+         **kwargs,
+     ):
+         kwargs.setdefault("hidden_features", 4 * [100])
+         super().__init__(
+             dims,
+             seed=seed,
+             device=device,
+             data_transform=data_transform,
+             flow_class="CNF",
+             **kwargs,
+         )
+         self.eta = eta
+
+     def loss_fn(self, theta: torch.Tensor):
+         # Flow-matching loss: interpolate linearly between data (t=0) and
+         # noise (t=1) and regress the learned vector field onto the target
+         # velocity v = eps - theta.
+         t = torch.rand(
+             theta.shape[:-1], dtype=theta.dtype, device=theta.device
+         )
+         t_ = t[..., None]
+         eps = torch.randn_like(theta)
+         theta_prime = (1 - t_) * theta + (t_ + self.eta) * eps
+         v = eps - theta
+         return (self._flow.transform.f(t, theta_prime) - v).square().mean()
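A minimal usage sketch for ZukoFlow follows. It is not part of the package: the import path is guessed from the file layout shown in this diff, the toy data and variable names are hypothetical, and the behaviour of the unshown Flow base class (data transform, rescale/inverse_rescale) is assumed to default to an identity transform.

# Example (not part of the package): fitting a ZukoFlow to toy 2D data.
import numpy as np

from aspire.flows.torch.zuko import ZukoFlow  # assumed module path

rng = np.random.default_rng(0)
data = rng.multivariate_normal(
    mean=[0.0, 0.0], cov=[[1.0, 0.8], [0.8, 1.0]], size=5000
)

flow = ZukoFlow(dims=2, flow_class="MAF", hidden_features=[64, 64], seed=42)
history = flow.fit(data, n_epochs=50, batch_size=256, validation_fraction=0.2)

# Draw new samples and their log-probabilities (torch tensors by default).
samples, log_prob = flow.sample_and_log_prob(1000)
fig = history.plot_loss()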
aspire/history.py ADDED
@@ -0,0 +1,148 @@
+ from __future__ import annotations
+
+ import copy
+ from dataclasses import dataclass, field
+
+ import matplotlib.pyplot as plt
+ from matplotlib.figure import Figure
+
+ from .utils import recursively_save_to_h5_file
+
+
+ @dataclass
+ class History:
+     """Base class for storing history of a sampler."""
+
+     def save(self, h5_file, path="history"):
+         """Save the history to an HDF5 file."""
+         dictionary = copy.deepcopy(self.__dict__)
+         recursively_save_to_h5_file(h5_file, path, dictionary)
+
+
+ @dataclass
+ class FlowHistory(History):
+     training_loss: list[float] = field(default_factory=list)
+     validation_loss: list[float] = field(default_factory=list)
+
+     def plot_loss(self) -> Figure:
+         """Plot the training and validation loss."""
+         fig = plt.figure()
+         plt.plot(self.training_loss, label="Training loss")
+         plt.plot(self.validation_loss, label="Validation loss")
+         plt.legend()
+         plt.xlabel("Epoch")
+         plt.ylabel("Loss")
+         return fig
+
+     def save(self, h5_file, path="flow_history"):
+         """Save the history to an HDF5 file."""
+         super().save(h5_file, path=path)
+
+
+ @dataclass
+ class SMCHistory(History):
+     log_norm_ratio: list[float] = field(default_factory=list)
+     log_norm_ratio_var: list[float] = field(default_factory=list)
+     beta: list[float] = field(default_factory=list)
+     ess: list[float] = field(default_factory=list)
+     ess_target: list[float] = field(default_factory=list)
+     eff_target: list[float] = field(default_factory=list)
+     mcmc_autocorr: list[float] = field(default_factory=list)
+     mcmc_acceptance: list[float] = field(default_factory=list)
+
+     def save(self, h5_file, path="smc_history"):
+         """Save the history to an HDF5 file."""
+         super().save(h5_file, path=path)
+
+     def plot_beta(self, ax=None) -> Figure | None:
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.beta)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel(r"$\beta$")
+         return fig
+
+     def plot_log_norm_ratio(self, ax=None) -> Figure | None:
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.log_norm_ratio)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("Log evidence ratio")
+         return fig
+
+     def plot_ess(self, ax=None) -> Figure | None:
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.ess)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("ESS")
+         return fig
+
+     def plot_ess_target(self, ax=None) -> Figure | None:
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.ess_target)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("ESS target")
+         return fig
+
+     def plot_eff_target(self, ax=None) -> Figure | None:
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.eff_target)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("Efficiency target")
+         return fig
+
+     def plot_mcmc_acceptance(self, ax=None) -> Figure | None:
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.mcmc_acceptance)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("MCMC Acceptance")
+         return fig
+
+     def plot_mcmc_autocorr(self, ax=None) -> Figure | None:
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.mcmc_autocorr)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("MCMC Autocorr")
+         return fig
+
+     def plot(self, fig: Figure | None = None) -> Figure:
+         methods = [
+             self.plot_beta,
+             self.plot_log_norm_ratio,
+             self.plot_ess,
+             self.plot_ess_target,
+             self.plot_eff_target,
+             self.plot_mcmc_acceptance,
+         ]
+
+         if fig is None:
+             fig, axs = plt.subplots(len(methods), 1, sharex=True)
+         else:
+             axs = fig.axes
+
+         for method, ax in zip(methods, axs):
+             method(ax)
+
+         for ax in axs[:-1]:
+             ax.set_xlabel("")
+
+         return fig
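A short, hypothetical sketch of how FlowHistory might be used on its own. It assumes that recursively_save_to_h5_file (imported from aspire.utils, which is not shown in this diff) accepts an open h5py file handle; the file names and loss values are made up.

# Example (not part of the package): recording, plotting and saving a history.
import h5py

from aspire.history import FlowHistory

history = FlowHistory()
history.training_loss.extend([1.20, 0.95, 0.81])
history.validation_loss.extend([1.25, 1.01, 0.90])

fig = history.plot_loss()
fig.savefig("flow_loss.png")

# Assumes the utils helper can write into an open h5py file handle.
with h5py.File("result.h5", "w") as h5_file:
    history.save(h5_file, path="flow_history")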
aspire/plot.py ADDED
@@ -0,0 +1,50 @@
+ import copy
+
+
+ def plot_comparison(
+     *samples, parameters=None, per_samples_kwargs=None, labels=None, **kwargs
+ ):
+     """
+     Plot a comparison of multiple sets of samples.
+     """
+     default_kwargs = dict(
+         density=True,
+         bins=30,
+         color="C0",
+         smooth=1.0,
+         plot_datapoints=True,
+         plot_density=False,
+         hist_kwargs=dict(density=True, color="C0"),
+     )
+     default_kwargs.update(kwargs)
+
+     if per_samples_kwargs is None:
+         per_samples_kwargs = [{}] * len(samples)
+
+     fig = None
+     for i, sample in enumerate(samples):
+         kwds = copy.deepcopy(default_kwargs)
+         color = per_samples_kwargs[i].pop("color", f"C{i}")
+         kwds["color"] = color
+         kwds["hist_kwargs"]["color"] = color
+         kwds.update(per_samples_kwargs[i])
+         fig = sample.plot_corner(fig=fig, parameters=parameters, **kwds)
+
+     if labels:
+         fig.legend(
+             labels=labels,
+             loc="upper right",
+             bbox_to_anchor=(0.9, 0.9),
+             bbox_transform=fig.transFigure,
+         )
+     return fig
+
+
+ def plot_history_comparison(*histories):
+     # Check that all histories are of the same type
+     if not all(isinstance(h, histories[0].__class__) for h in histories):
+         raise ValueError("All histories must be of the same type")
+     fig = histories[0].plot()
+     for history in histories[1:]:
+         fig = history.plot(fig=fig)
+     return fig
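A hypothetical call to plot_comparison is sketched below; posterior_a and posterior_b stand in for objects (for example the package's Samples) that provide the plot_corner(fig=..., parameters=..., **kwargs) method this function relies on, and the parameter names and colours are made up.

# Example (not part of the package): overlaying two corner plots.
from aspire.plot import plot_comparison

fig = plot_comparison(
    posterior_a,
    posterior_b,
    parameters=["x", "y"],
    labels=["Run A", "Run B"],
    per_samples_kwargs=[{"color": "C0"}, {"color": "C3"}],
)
fig.savefig("comparison.png")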
File without changes
@@ -0,0 +1,92 @@
+ import logging
+ from typing import Callable
+
+ from ..flows.base import Flow
+ from ..samples import Samples
+ from ..transforms import IdentityTransform
+ from ..utils import track_calls
+
+ logger = logging.getLogger(__name__)
+
+
+ class Sampler:
+     """Base class for all samplers.
+
+     Parameters
+     ----------
+     log_likelihood : Callable
+         The log-likelihood function.
+     log_prior : Callable
+         The log-prior function.
+     dims : int
+         The number of dimensions.
+     prior_flow : Flow
+         The prior flow object.
+     xp : Callable
+         The array backend to use.
+     parameters : list[str] | None
+         The list of parameter names. If None, any samples objects will not
+         have their parameter names specified.
+     preconditioning_transform : Callable | None
+         Optional preconditioning transform. If None, an identity transform
+         is used.
+     """
+
+     def __init__(
+         self,
+         log_likelihood: Callable,
+         log_prior: Callable,
+         dims: int,
+         prior_flow: Flow,
+         xp: Callable,
+         parameters: list[str] | None = None,
+         preconditioning_transform: Callable | None = None,
+     ):
+         self.prior_flow = prior_flow
+         self._log_likelihood = log_likelihood
+         self.log_prior = log_prior
+         self.dims = dims
+         self.xp = xp
+         self.parameters = parameters
+         self.history = None
+         self.n_likelihood_evaluations = 0
+         if preconditioning_transform is None:
+             self.preconditioning_transform = IdentityTransform(xp=self.xp)
+         else:
+             self.preconditioning_transform = preconditioning_transform
+
+     def fit_preconditioning_transform(self, x):
+         """Fit the preconditioning transform to the data."""
+         x = self.preconditioning_transform.xp.asarray(x)
+         return self.preconditioning_transform.fit(x)
+
+     @track_calls
+     def sample(self, n_samples: int) -> Samples:
+         raise NotImplementedError
+
+     def log_likelihood(self, samples: Samples) -> Samples:
+         """Compute the log-likelihood of the samples.
+
+         Also tracks the number of likelihood evaluations.
+         """
+         self.n_likelihood_evaluations += len(samples)
+         return self._log_likelihood(samples)
+
+     def config_dict(self, include_sample_calls: bool = True) -> dict:
+         """
+         Return a dictionary with the configuration of the sampler.
+
+         Parameters
+         ----------
+         include_sample_calls : bool
+             Whether to include the sample calls in the configuration.
+             Default is True.
+         """
+         config = {}
+         if include_sample_calls:
+             if hasattr(self, "sample") and hasattr(self.sample, "calls"):
+                 config["sample_calls"] = self.sample.calls.to_dict(
+                     list_to_dict=True
+                 )
+             else:
+                 logger.warning(
+                     "Sampler does not have a sample method with calls attribute."
+                 )
+         return config
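The log_likelihood wrapper above counts evaluations, so n_likelihood_evaluations accumulates across calls to a concrete sampler. A brief hypothetical illustration (not part of the package; "sampler" is assumed to be an instance of a concrete subclass such as the ImportanceSampler defined later in this diff, and the printed count assumes len(samples) equals the requested number):

# Example (not part of the package): tracking likelihood evaluations.
samples = sampler.sample(1000)
print(sampler.n_likelihood_evaluations)  # -> 1000 after one call
config = sampler.config_dict()           # includes the recorded sample() calls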
@@ -0,0 +1,18 @@
+ from ..samples import Samples
+ from ..utils import track_calls
+ from .base import Sampler
+
+
+ class ImportanceSampler(Sampler):
+     @track_calls
+     def sample(self, n_samples: int) -> Samples:
+         x, log_q = self.prior_flow.sample_and_log_prob(n_samples)
+         samples = Samples(
+             x, log_q=log_q, xp=self.xp, parameters=self.parameters
+         )
+         samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
+         samples.log_likelihood = samples.array_to_namespace(
+             self.log_likelihood(samples)
+         )
+         samples.compute_weights()
+         return samples