aspire-inference 0.1.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/samples.py ADDED
@@ -0,0 +1,476 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import logging
5
+ import math
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Callable
8
+
9
+ import numpy as np
10
+ from array_api_compat import (
11
+ array_namespace,
12
+ is_numpy_namespace,
13
+ to_device,
14
+ )
15
+ from array_api_compat import device as api_device
16
+ from array_api_compat.common._typing import Array
17
+
18
+ from .utils import asarray, logsumexp, recursively_save_to_h5_file, to_numpy
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclass
24
+ class BaseSamples:
25
+ """Class for storing samples and corresponding weights.
26
+
27
+ If :code:`xp` is not specified, all inputs will be converted to match
28
+ the array type of :code:`x`.
29
+ """
30
+
31
+ x: Array
32
+ log_likelihood: Array | None = None
33
+ log_prior: Array | None = None
34
+ log_q: Array | None = None
35
+ parameters: list[str] | None = None
36
+ xp: Callable | None = None
37
+ device: Any = None
38
+
39
+ def __post_init__(self):
40
+ if self.xp is None:
41
+ self.xp = array_namespace(self.x)
42
+ # Numpy arrays need to be on the CPU before being converted
43
+ if is_numpy_namespace(self.xp):
44
+ self.device = "cpu"
45
+ self.x = self.array_to_namespace(self.x)
46
+ if self.device is None:
47
+ self.device = api_device(self.x)
48
+ if self.log_likelihood is not None:
49
+ self.log_likelihood = self.array_to_namespace(self.log_likelihood)
50
+ if self.log_prior is not None:
51
+ self.log_prior = self.array_to_namespace(self.log_prior)
52
+ if self.log_q is not None:
53
+ self.log_q = self.array_to_namespace(self.log_q)
54
+
55
+ if self.parameters is None:
56
+ self.parameters = [f"x_{i}" for i in range(self.dims)]
57
+
58
+ @property
59
+ def dims(self):
60
+ """Number of dimensions (parameters) in the samples."""
61
+ if self.x is None:
62
+ return 0
63
+ return self.x.shape[1] if self.x.ndim > 1 else 1
64
+
65
+ def to_numpy(self):
66
+ logger.debug("Converting samples to numpy arrays")
67
+ return self.__class__(
68
+ x=to_numpy(self.x),
69
+ parameters=self.parameters,
70
+ log_likelihood=to_numpy(self.log_likelihood)
71
+ if self.log_likelihood is not None
72
+ else None,
73
+ log_prior=to_numpy(self.log_prior)
74
+ if self.log_prior is not None
75
+ else None,
76
+ log_q=to_numpy(self.log_q) if self.log_q is not None else None,
77
+ )
78
+
79
+ def to_namespace(self, xp):
80
+ logger.debug("Converting samples to {} namespace", xp)
81
+ return self.__class__(
82
+ x=asarray(self.x, xp),
83
+ parameters=self.parameters,
84
+ log_likelihood=asarray(self.log_likelihood, xp)
85
+ if self.log_likelihood is not None
86
+ else None,
87
+ log_prior=asarray(self.log_prior, xp)
88
+ if self.log_prior is not None
89
+ else None,
90
+ log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
91
+ )
92
+
93
+ def array_to_namespace(self, x):
94
+ """Convert an array to the same namespace as the samples"""
95
+ x = asarray(x, self.xp)
96
+ if self.device:
97
+ x = to_device(x, self.device)
98
+ return x
99
+
100
+ def to_dict(self, flat: bool = True):
101
+ samples = dict(zip(self.parameters, self.x.T, strict=True))
102
+ out = {
103
+ "log_likelihood": self.log_likelihood,
104
+ "log_prior": self.log_prior,
105
+ "log_q": self.log_q,
106
+ }
107
+ if flat:
108
+ out.update(samples)
109
+ else:
110
+ out["samples"] = samples
111
+ return out
112
+
113
+ def to_dataframe(self, flat: bool = True):
114
+ import pandas as pd
115
+
116
+ return pd.DataFrame(self.to_dict(flat=flat))
117
+
118
+ def plot_corner(self, parameters: list[str] | None = None, **kwargs):
119
+ import corner
120
+
121
+ kwargs = copy.deepcopy(kwargs)
122
+ kwargs.setdefault("labels", self.parameters)
123
+
124
+ if parameters is not None:
125
+ indices = [self.parameters.index(p) for p in parameters]
126
+ kwargs["labels"] = parameters
127
+ x = self.x[:, indices] if self.x.ndim > 1 else self.x[indices]
128
+ else:
129
+ x = self.x
130
+ fig = corner.corner(to_numpy(x), **kwargs)
131
+ return fig
132
+
133
+ def __str__(self):
134
+ out = (
135
+ f"No. samples: {len(self.x)}\n"
136
+ f"No. parameters: {len(self.parameters)}\n"
137
+ )
138
+ return out
139
+
140
+ def save(self, h5_file, path="samples", flat=False):
141
+ """Save the samples to an HDF5 file.
142
+
143
+ This converts the samples to numpy and then to a dictionary.
144
+
145
+ Parameters
146
+ ----------
147
+ h5_file : h5py.File
148
+ The HDF5 file to save to.
149
+ path : str
150
+ The path in the HDF5 file to save to.
151
+ flat : bool
152
+ If True, save the samples as a flat dictionary.
153
+ If False, save the samples as a nested dictionary.
154
+ """
155
+ dictionary = self.to_numpy().to_dict(flat=flat)
156
+ recursively_save_to_h5_file(h5_file, path, dictionary)
157
+
158
+ def __len__(self):
159
+ return len(self.x)
160
+
161
+ def __getitem__(self, idx) -> BaseSamples:
162
+ return self.__class__(
163
+ x=self.x[idx],
164
+ log_likelihood=self.log_likelihood[idx]
165
+ if self.log_likelihood is not None
166
+ else None,
167
+ log_prior=self.log_prior[idx]
168
+ if self.log_prior is not None
169
+ else None,
170
+ log_q=self.log_q[idx] if self.log_q is not None else None,
171
+ parameters=self.parameters,
172
+ )
173
+
174
+ def __setitem__(self, idx, value: BaseSamples):
175
+ raise NotImplementedError("Setting items is not supported")
176
+
177
+ @classmethod
178
+ def concatenate(cls, samples: list[BaseSamples]) -> BaseSamples:
179
+ """Concatenate multiple Samples objects."""
180
+ if not samples:
181
+ raise ValueError("No samples to concatenate")
182
+ if not all(s.parameters == samples[0].parameters for s in samples):
183
+ raise ValueError("Parameters do not match")
184
+ if not all(s.xp == samples[0].xp for s in samples):
185
+ raise ValueError("Array namespaces do not match")
186
+ xp = samples[0].xp
187
+ return cls(
188
+ x=xp.concatenate([s.x for s in samples], axis=0),
189
+ log_likelihood=xp.concatenate(
190
+ [s.log_likelihood for s in samples], axis=0
191
+ )
192
+ if all(s.log_likelihood is not None for s in samples)
193
+ else None,
194
+ log_prior=xp.concatenate([s.log_prior for s in samples], axis=0)
195
+ if all(s.log_prior is not None for s in samples)
196
+ else None,
197
+ log_q=xp.concatenate([s.log_q for s in samples], axis=0)
198
+ if all(s.log_q is not None for s in samples)
199
+ else None,
200
+ parameters=samples[0].parameters,
201
+ )
202
+
203
+ @classmethod
204
+ def from_samples(cls, samples: BaseSamples, **kwargs) -> BaseSamples:
205
+ """Create a Samples object from a BaseSamples object."""
206
+ xp = kwargs.pop("xp", samples.xp)
207
+ device = kwargs.pop("device", samples.device)
208
+ return cls(
209
+ x=samples.x,
210
+ log_likelihood=samples.log_likelihood,
211
+ log_prior=samples.log_prior,
212
+ log_q=samples.log_q,
213
+ parameters=samples.parameters,
214
+ xp=xp,
215
+ device=device,
216
+ **kwargs,
217
+ )
218
+
219
+
220
@dataclass
class Samples(BaseSamples):
    """Class for storing samples and corresponding weights.

    If :code:`xp` is not specified, all inputs will be converted to match
    the array type of :code:`x`.

    When :code:`log_likelihood`, :code:`log_prior` and :code:`log_q` are all
    provided, the importance weights, evidence estimate and effective sample
    size are computed automatically in :code:`__post_init__`.
    """

    # Derived quantities, populated by compute_weights (or set to None
    # when log_q is missing).
    log_w: Array = field(init=False)
    weights: Array = field(init=False)
    evidence: float = field(init=False)
    evidence_error: float = field(init=False)
    log_evidence: float | None = None
    log_evidence_error: float | None = None
    effective_sample_size: float = field(init=False)

    def __post_init__(self):
        super().__post_init__()

        if all(
            x is not None
            for x in [self.log_likelihood, self.log_prior, self.log_q]
        ):
            self.compute_weights()
        else:
            self.log_w = None
            self.weights = None
            self.evidence = None
            self.evidence_error = None
            self.effective_sample_size = None

    @property
    def efficiency(self):
        """Efficiency of the weighted samples.

        Defined as ESS / number of samples.
        """
        if self.log_w is None:
            raise RuntimeError("Samples do not contain weights!")
        return self.effective_sample_size / len(self.x)

    def compute_weights(self):
        """Compute the posterior weights, evidence estimate and ESS."""
        self.log_w = self.log_likelihood + self.log_prior - self.log_q
        self.log_evidence = asarray(logsumexp(self.log_w), self.xp) - math.log(
            len(self.x)
        )
        self.weights = self.xp.exp(self.log_w)
        self.evidence = self.xp.exp(self.log_evidence)
        n = len(self.x)
        # Standard error of the mean of the weights
        self.evidence_error = self.xp.sqrt(
            self.xp.sum((self.weights - self.evidence) ** 2) / (n * (n - 1))
        )
        # First-order (relative-error) approximation of the log-evidence error
        self.log_evidence_error = self.xp.abs(
            self.evidence_error / self.evidence
        )
        # Shift by the maximum log-weight for numerical stability; the shift
        # cancels in the ESS ratio.
        log_w = self.log_w - self.xp.max(self.log_w)
        self.effective_sample_size = self.xp.exp(
            asarray(logsumexp(log_w) * 2 - logsumexp(log_w * 2), self.xp)
        )

    @property
    def scaled_weights(self):
        """Weights rescaled so that the maximum weight is one."""
        return self.xp.exp(self.log_w - self.xp.max(self.log_w))

    def rejection_sample(self, rng=None):
        """Rejection sample using the posterior weights.

        Parameters
        ----------
        rng : numpy.random.Generator | None
            Random number generator; a fresh default generator is used if
            None.

        Returns
        -------
        Samples
            The accepted samples. :code:`log_q` is intentionally dropped, so
            the result carries no weights.
        """
        if rng is None:
            rng = np.random.default_rng()
        log_u = asarray(
            np.log(rng.uniform(size=len(self.x))), self.xp, device=self.device
        )
        log_w = self.log_w - self.xp.max(self.log_w)
        accept = log_w > log_u
        # Bug fix: propagate the parameter names and evidence estimates,
        # which were previously dropped from the returned object.
        return self.__class__(
            x=self.x[accept],
            log_likelihood=self.log_likelihood[accept],
            log_prior=self.log_prior[accept],
            parameters=self.parameters,
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
        )

    def to_dict(self, flat: bool = True):
        """Convert to a dictionary including weights and evidence values."""
        # The base implementation already includes the per-parameter samples,
        # so they are not rebuilt here (previously they were recomputed and
        # re-applied redundantly).
        out = super().to_dict(flat=flat)
        out.update(
            {
                "log_w": self.log_w,
                "weights": self.weights,
                "evidence": self.evidence,
                "log_evidence": self.log_evidence,
                "evidence_error": self.evidence_error,
                "log_evidence_error": self.log_evidence_error,
                "effective_sample_size": self.effective_sample_size,
            }
        )
        return out

    def plot_corner(self, include_weights: bool = True, **kwargs):
        """Corner plot of the samples, weighted by default.

        Parameters
        ----------
        include_weights : bool
            If True and weights are available, pass the scaled weights to
            :code:`corner.corner` unless the caller supplied their own.
        """
        kwargs = copy.deepcopy(kwargs)
        if (
            include_weights
            and self.weights is not None
            and "weights" not in kwargs
        ):
            kwargs["weights"] = to_numpy(self.scaled_weights)
        return super().plot_corner(**kwargs)

    def __str__(self):
        out = super().__str__()
        if self.log_evidence is not None:
            out += f"Log evidence: {self.log_evidence:.2f} +/- {self.log_evidence_error:.2f}\n"
        if self.log_w is not None:
            out += (
                f"Effective sample size: {self.effective_sample_size:.1f}\n"
                f"Efficiency: {self.efficiency:.2f}\n"
            )
        return out

    def to_namespace(self, xp):
        """Return a new instance with all arrays converted to namespace ``xp``."""
        return self.__class__(
            x=asarray(self.x, xp),
            parameters=self.parameters,
            log_likelihood=asarray(self.log_likelihood, xp)
            if self.log_likelihood is not None
            else None,
            log_prior=asarray(self.log_prior, xp)
            if self.log_prior is not None
            else None,
            log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
            log_evidence=asarray(self.log_evidence, xp)
            if self.log_evidence is not None
            else None,
            log_evidence_error=asarray(self.log_evidence_error, xp)
            if self.log_evidence_error is not None
            else None,
        )

    def to_numpy(self):
        """Return a new instance with all arrays converted to numpy."""
        return self.__class__(
            x=to_numpy(self.x),
            parameters=self.parameters,
            log_likelihood=to_numpy(self.log_likelihood)
            if self.log_likelihood is not None
            else None,
            log_prior=to_numpy(self.log_prior)
            if self.log_prior is not None
            else None,
            log_q=to_numpy(self.log_q) if self.log_q is not None else None,
            log_evidence=self.log_evidence
            if self.log_evidence is not None
            else None,
            log_evidence_error=self.log_evidence_error
            if self.log_evidence_error is not None
            else None,
        )

    def __getitem__(self, idx):
        """Index the samples, preserving the evidence estimates."""
        sliced = super().__getitem__(idx)
        return self.__class__.from_samples(
            sliced,
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
        )
384
+
385
+
386
@dataclass
class SMCSamples(BaseSamples):
    """Samples for SMC with an associated inverse temperature."""

    # Temperature parameter for the current samples.
    # (Previously documented via a misplaced attribute docstring.)
    beta: float | None = None
    # Current estimate of the log-evidence.
    log_evidence: float | None = None
    # Error on the log-evidence estimate. Bug fix: this field did not exist
    # but was read in to_standard_samples, which always raised
    # AttributeError.
    log_evidence_error: float | None = None

    def log_p_t(self, beta):
        """Tempered log-target at inverse temperature ``beta``."""
        log_p_T = self.log_likelihood + self.log_prior
        return (1 - beta) * self.log_q + beta * log_p_T

    def unnormalized_log_weights(self, beta: float) -> Array:
        """Unnormalized log-weights for moving from ``self.beta`` to ``beta``."""
        return (self.beta - beta) * self.log_q + (beta - self.beta) * (
            self.log_likelihood + self.log_prior
        )

    def log_evidence_ratio(self, beta: float) -> float:
        """Monte Carlo estimate of the log-evidence ratio between temperatures."""
        log_w = self.unnormalized_log_weights(beta)
        return logsumexp(log_w) - math.log(len(self.x))

    def log_evidence_ratio_variance(self, beta: float) -> float:
        """Estimate the variance of the log evidence ratio using the delta method.

        Defined as Var(log Z) = Var(w) / (E[w])^2 where w are the unnormalized weights.
        """
        log_w = self.unnormalized_log_weights(beta)
        # Shift by the maximum for numerical stability; the shift cancels in
        # the variance-to-mean ratio.
        m = self.xp.max(log_w)
        u = self.xp.exp(log_w - m)
        mean_w = self.xp.mean(u)
        var_w = self.xp.var(u)
        return (
            var_w / (len(self) * (mean_w**2)) if mean_w != 0 else self.xp.nan
        )

    def log_weights(self, beta: float) -> Array:
        """Log-weights at ``beta``, shifted by the log-evidence ratio.

        Raises
        ------
        ValueError
            If any of the log-weights are NaN.
        """
        log_w = self.unnormalized_log_weights(beta)
        if self.xp.isnan(log_w).any():
            raise ValueError(f"Log weights contain NaN values for beta={beta}")
        log_evidence_ratio = logsumexp(log_w) - math.log(len(self.x))
        return log_w + log_evidence_ratio

    def resample(
        self,
        beta,
        n_samples: int | None = None,
        rng: np.random.Generator | None = None,
    ) -> "SMCSamples":
        """Resample with replacement according to the weights at ``beta``.

        Parameters
        ----------
        beta
            Target inverse temperature.
        n_samples : int | None
            Number of samples to draw; defaults to the current size.
        rng : numpy.random.Generator | None
            Random number generator; a fresh default generator is used if
            None.
        """
        if beta == self.beta and n_samples is None:
            logger.warning(
                "Resampling with the same beta value, returning identical samples"
            )
            return self
        if rng is None:
            rng = np.random.default_rng()
        if n_samples is None:
            n_samples = len(self.x)
        log_w = self.log_weights(beta)
        # Normalize to probabilities for numpy's choice
        w = to_numpy(self.xp.exp(log_w - logsumexp(log_w)))
        idx = rng.choice(len(self.x), size=n_samples, replace=True, p=w)
        # Bug fix: propagate the parameter names, previously dropped.
        return self.__class__(
            x=self.x[idx],
            log_likelihood=self.log_likelihood[idx],
            log_prior=self.log_prior[idx],
            log_q=self.log_q[idx],
            parameters=self.parameters,
            beta=beta,
        )

    def __str__(self):
        out = super().__str__()
        if self.log_evidence is not None:
            out += f"Log evidence: {self.log_evidence:.2f}\n"
        return out

    def to_standard_samples(self):
        """Convert the samples to standard samples."""
        return Samples(
            x=self.x,
            log_likelihood=self.log_likelihood,
            log_prior=self.log_prior,
            xp=self.xp,
            parameters=self.parameters,
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
        )

    def __getitem__(self, idx):
        """Index the samples, preserving beta and the evidence estimates."""
        sliced = super().__getitem__(idx)
        return self.__class__.from_samples(
            sliced,
            beta=self.beta,
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
        )
476
+ )