aspire-inference 0.1.0a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aspire/samples.py ADDED
@@ -0,0 +1,568 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import logging
5
+ import math
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Callable
8
+
9
+ import numpy as np
10
+ from array_api_compat import (
11
+ array_namespace,
12
+ is_numpy_namespace,
13
+ to_device,
14
+ )
15
+ from array_api_compat import device as api_device
16
+ from array_api_compat.common._typing import Array
17
+ from matplotlib.figure import Figure
18
+
19
+ from .utils import (
20
+ asarray,
21
+ convert_dtype,
22
+ logsumexp,
23
+ recursively_save_to_h5_file,
24
+ resolve_dtype,
25
+ to_numpy,
26
+ )
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
@dataclass
class BaseSamples:
    """Class for storing samples and corresponding weights.

    If :code:`xp` is not specified, all inputs will be converted to match
    the array type of :code:`x`.
    """

    x: Array
    """Array of samples, shape (n_samples, n_dims)."""
    log_likelihood: Array | None = None
    """Log-likelihood values for the samples."""
    log_prior: Array | None = None
    """Log-prior values for the samples."""
    log_q: Array | None = None
    """Log-probability values under the proposal distribution."""
    parameters: list[str] | None = None
    """List of parameter names."""
    dtype: Any | str | None = None
    """Data type of the samples.

    If None, the default dtype for the array namespace will be used.
    """
    xp: Callable | None = None
    """
    The array namespace to use for the samples.

    If None, the array namespace will be inferred from the type of :code:`x`.
    """
    device: Any = None
    """Device to store the samples on.

    If None, the device will be inferred from the array namespace of :code:`x`.
    """

    def __post_init__(self):
        if self.xp is None:
            self.xp = array_namespace(self.x)
        # Numpy arrays need to be on the CPU before being converted
        if is_numpy_namespace(self.xp):
            self.device = "cpu"
        if self.dtype is not None:
            self.dtype = resolve_dtype(self.dtype, self.xp)
        else:
            # Fall back to default dtype for the array namespace
            self.dtype = None
        self.x = self.array_to_namespace(self.x, dtype=self.dtype)
        if self.device is None:
            self.device = api_device(self.x)
        if self.log_likelihood is not None:
            self.log_likelihood = self.array_to_namespace(
                self.log_likelihood, dtype=self.dtype
            )
        if self.log_prior is not None:
            self.log_prior = self.array_to_namespace(
                self.log_prior, dtype=self.dtype
            )
        if self.log_q is not None:
            self.log_q = self.array_to_namespace(self.log_q, dtype=self.dtype)

        if self.parameters is None:
            self.parameters = [f"x_{i}" for i in range(self.dims)]

    @property
    def dims(self):
        """Number of dimensions (parameters) in the samples."""
        if self.x is None:
            return 0
        return self.x.shape[1] if self.x.ndim > 1 else 1

    def to_numpy(self, dtype: Any | str | None = None):
        """Return a copy of the samples backed by numpy arrays.

        Parameters
        ----------
        dtype : Any | str | None
            Target dtype. If None, the current dtype is converted to its
            numpy equivalent.
        """
        logger.debug("Converting samples to numpy arrays")
        import array_api_compat.numpy as np

        if dtype is not None:
            dtype = resolve_dtype(dtype, np)
        else:
            dtype = convert_dtype(self.dtype, np)
        return self.__class__(
            x=self.x,
            parameters=self.parameters,
            log_likelihood=self.log_likelihood,
            log_prior=self.log_prior,
            log_q=self.log_q,
            xp=np,
            # Bug fix: the resolved dtype was previously discarded, so the
            # dtype argument had no effect on the returned samples.
            dtype=dtype,
        )

    def to_namespace(self, xp, dtype: Any | str | None = None):
        """Return a copy of the samples converted to the given namespace.

        Parameters
        ----------
        xp : Callable
            The target array namespace.
        dtype : Any | str | None
            Target dtype. If None, the current dtype is converted to its
            equivalent in the target namespace.
        """
        if dtype is None:
            dtype = convert_dtype(self.dtype, xp)
        else:
            dtype = resolve_dtype(dtype, xp)
        # Bug fix: stdlib logging uses %-style lazy formatting; the previous
        # '{}' placeholder was never interpolated.
        logger.debug("Converting samples to %s namespace", xp)
        return self.__class__(
            x=self.x,
            parameters=self.parameters,
            log_likelihood=self.log_likelihood,
            log_prior=self.log_prior,
            log_q=self.log_q,
            xp=xp,
            device=self.device,
            dtype=dtype,
        )

    def array_to_namespace(self, x, dtype=None):
        """Convert an array to the same namespace (and device) as the samples."""
        kwargs = {}
        if dtype is not None:
            kwargs["dtype"] = resolve_dtype(dtype, self.xp)
        else:
            kwargs["dtype"] = self.dtype
        x = asarray(x, self.xp, **kwargs)
        if self.device:
            x = to_device(x, self.device)
        return x

    def to_dict(self, flat: bool = True):
        """Convert the samples to a dictionary.

        Parameters
        ----------
        flat : bool
            If True, the per-parameter arrays are stored at the top level.
            If False, they are nested under a 'samples' key.
        """
        samples = dict(zip(self.parameters, self.x.T, strict=True))
        out = {
            "log_likelihood": self.log_likelihood,
            "log_prior": self.log_prior,
            "log_q": self.log_q,
        }
        if flat:
            out.update(samples)
        else:
            out["samples"] = samples
        return out

    def to_dataframe(self, flat: bool = True):
        """Convert the samples to a pandas DataFrame."""
        import pandas as pd

        return pd.DataFrame(self.to_dict(flat=flat))

    def plot_corner(
        self,
        parameters: list[str] | None = None,
        fig: Figure | None = None,
        **kwargs,
    ):
        """Plot a corner plot of the samples.

        Parameters
        ----------
        parameters : list[str] | None
            List of parameters to plot. If None, all parameters are plotted.
        fig : matplotlib.figure.Figure | None
            Figure to plot on. If None, a new figure is created.
        **kwargs : dict
            Additional keyword arguments to pass to corner.corner(). Kwargs
            are deep-copied before use.
        """
        import corner

        kwargs = copy.deepcopy(kwargs)
        kwargs.setdefault("labels", self.parameters)

        if parameters is not None:
            indices = [self.parameters.index(p) for p in parameters]
            kwargs["labels"] = parameters
            x = self.x[:, indices] if self.x.ndim > 1 else self.x[indices]
        else:
            x = self.x
        fig = corner.corner(to_numpy(x), fig=fig, **kwargs)
        return fig

    def __str__(self):
        # Bug fix: use the dims property so a 1-d sample array reports one
        # parameter instead of the number of samples (x.shape[-1]).
        out = (
            f"No. samples: {len(self.x)}\nNo. parameters: {self.dims}\n"
        )
        return out

    def save(self, h5_file, path="samples", flat=False):
        """Save the samples to an HDF5 file.

        This converts the samples to numpy and then to a dictionary.

        Parameters
        ----------
        h5_file : h5py.File
            The HDF5 file to save to.
        path : str
            The path in the HDF5 file to save to.
        flat : bool
            If True, save the samples as a flat dictionary.
            If False, save the samples as a nested dictionary.
        """
        dictionary = self.to_numpy().to_dict(flat=flat)
        recursively_save_to_h5_file(h5_file, path, dictionary)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx) -> BaseSamples:
        """Return a new instance containing the samples selected by idx."""
        return self.__class__(
            x=self.x[idx],
            log_likelihood=self.log_likelihood[idx]
            if self.log_likelihood is not None
            else None,
            log_prior=self.log_prior[idx]
            if self.log_prior is not None
            else None,
            log_q=self.log_q[idx] if self.log_q is not None else None,
            parameters=self.parameters,
            dtype=self.dtype,
        )

    def __setitem__(self, idx, value: BaseSamples):
        raise NotImplementedError("Setting items is not supported")

    @classmethod
    def concatenate(cls, samples: list[BaseSamples]) -> BaseSamples:
        """Concatenate multiple Samples objects.

        All inputs must share the same parameters, array namespace and dtype.
        Optional fields (log_likelihood, log_prior, log_q) are concatenated
        only when present on every input; otherwise they are dropped.
        """
        if not samples:
            raise ValueError("No samples to concatenate")
        if not all(s.parameters == samples[0].parameters for s in samples):
            raise ValueError("Parameters do not match")
        if not all(s.xp == samples[0].xp for s in samples):
            raise ValueError("Array namespaces do not match")
        if not all(s.dtype == samples[0].dtype for s in samples):
            raise ValueError("Dtypes do not match")
        xp = samples[0].xp
        return cls(
            x=xp.concatenate([s.x for s in samples], axis=0),
            log_likelihood=xp.concatenate(
                [s.log_likelihood for s in samples], axis=0
            )
            if all(s.log_likelihood is not None for s in samples)
            else None,
            log_prior=xp.concatenate([s.log_prior for s in samples], axis=0)
            if all(s.log_prior is not None for s in samples)
            else None,
            log_q=xp.concatenate([s.log_q for s in samples], axis=0)
            if all(s.log_q is not None for s in samples)
            else None,
            parameters=samples[0].parameters,
            dtype=samples[0].dtype,
        )

    @classmethod
    def from_samples(cls, samples: BaseSamples, **kwargs) -> BaseSamples:
        """Create a Samples object from a BaseSamples object.

        Additional keyword arguments are forwarded to the constructor;
        ``xp``, ``device`` and ``dtype`` default to the values of the input.
        """
        xp = kwargs.pop("xp", samples.xp)
        device = kwargs.pop("device", samples.device)
        dtype = kwargs.pop("dtype", samples.dtype)
        if dtype is not None:
            dtype = resolve_dtype(dtype, xp)
        else:
            dtype = convert_dtype(samples.dtype, xp)
        return cls(
            x=samples.x,
            log_likelihood=samples.log_likelihood,
            log_prior=samples.log_prior,
            log_q=samples.log_q,
            parameters=samples.parameters,
            xp=xp,
            device=device,
            # Bug fix: the resolved dtype was previously discarded, so the
            # dtype of the input (or the dtype kwarg) was never propagated.
            dtype=dtype,
            **kwargs,
        )

    def __getstate__(self):
        state = self.__dict__.copy()
        # replace xp (callable) with module name string so instances pickle
        if self.xp is not None:
            state["xp"] = (
                self.xp.__name__ if hasattr(self.xp, "__name__") else None
            )
        return state

    def __setstate__(self, state):
        # Restore xp by checking the namespace of x
        state["xp"] = array_namespace(state["x"])
        self.__dict__.update(state)
304
+
305
+
306
@dataclass
class Samples(BaseSamples):
    """Class for storing samples and corresponding weights.

    If :code:`xp` is not specified, all inputs will be converted to match
    the array type of :code:`x`.
    """

    log_w: Array = field(init=False)
    """Posterior log-weights, log L + log pi - log q."""
    weights: Array = field(init=False)
    """Posterior weights, exp(log_w)."""
    evidence: float = field(init=False)
    """Evidence estimate (mean of the weights)."""
    evidence_error: float = field(init=False)
    """Standard error of the evidence estimate."""
    log_evidence: float | None = None
    """Log-evidence estimate."""
    log_evidence_error: float | None = None
    """Error estimate on the log-evidence."""
    effective_sample_size: float = field(init=False)
    """Effective sample size of the weights, (sum w)^2 / sum w^2."""

    def __post_init__(self):
        super().__post_init__()

        # Weights can only be computed when all three log-probabilities
        # are available.
        if all(
            x is not None
            for x in [self.log_likelihood, self.log_prior, self.log_q]
        ):
            self.compute_weights()
        else:
            self.log_w = None
            self.weights = None
            self.evidence = None
            self.evidence_error = None
            self.effective_sample_size = None

    @property
    def efficiency(self):
        """Efficiency of the weighted samples.

        Defined as ESS / number of samples.
        """
        if self.log_w is None:
            raise RuntimeError("Samples do not contain weights!")
        return self.effective_sample_size / len(self.x)

    def compute_weights(self):
        """Compute the posterior weights, evidence and effective sample size."""
        self.log_w = self.log_likelihood + self.log_prior - self.log_q
        self.log_evidence = asarray(logsumexp(self.log_w), self.xp) - math.log(
            len(self.x)
        )
        self.weights = self.xp.exp(self.log_w)
        self.evidence = self.xp.exp(self.log_evidence)
        n = len(self.x)
        # Standard error of the mean of the weights.
        self.evidence_error = self.xp.sqrt(
            self.xp.sum((self.weights - self.evidence) ** 2) / (n * (n - 1))
        )
        # Delta-method estimate: sigma(log Z) ~= sigma(Z) / Z.
        self.log_evidence_error = self.xp.abs(
            self.evidence_error / self.evidence
        )
        # Shift by the maximum for numerical stability before the ESS.
        log_w = self.log_w - self.xp.max(self.log_w)
        self.effective_sample_size = self.xp.exp(
            asarray(logsumexp(log_w) * 2 - logsumexp(log_w * 2), self.xp)
        )

    @property
    def scaled_weights(self):
        """Weights rescaled so the maximum weight is one."""
        return self.xp.exp(self.log_w - self.xp.max(self.log_w))

    def rejection_sample(self, rng=None):
        """Rejection sample to obtain equally weighted samples.

        Parameters
        ----------
        rng : numpy.random.Generator | None
            Random number generator. If None, a new default generator is
            created.
        """
        if rng is None:
            rng = np.random.default_rng()
        log_u = asarray(
            np.log(rng.uniform(size=len(self.x))), self.xp, device=self.device
        )
        log_w = self.log_w - self.xp.max(self.log_w)
        accept = log_w > log_u
        return self.__class__(
            x=self.x[accept],
            log_likelihood=self.log_likelihood[accept],
            log_prior=self.log_prior[accept],
            # Bug fix: propagate the parameter names; previously they were
            # reset to the default x_i labels.
            parameters=self.parameters,
            dtype=self.dtype,
        )

    def to_dict(self, flat: bool = True):
        """Convert the samples, weights and evidence summaries to a dict."""
        # The base class already handles the per-parameter arrays; only the
        # weight- and evidence-related quantities need to be added.
        out = super().to_dict(flat=flat)
        out.update(
            {
                "log_w": self.log_w,
                "weights": self.weights,
                "evidence": self.evidence,
                "log_evidence": self.log_evidence,
                "evidence_error": self.evidence_error,
                "log_evidence_error": self.log_evidence_error,
                "effective_sample_size": self.effective_sample_size,
            }
        )
        return out

    def plot_corner(self, include_weights: bool = True, **kwargs):
        """Plot a corner plot, weighting the samples by default.

        Parameters
        ----------
        include_weights : bool
            If True and weights are available, pass the scaled weights to
            corner unless the caller already provided 'weights'.
        **kwargs : dict
            Additional keyword arguments passed to the base implementation.
        """
        kwargs = copy.deepcopy(kwargs)
        if (
            include_weights
            and self.weights is not None
            and "weights" not in kwargs
        ):
            kwargs["weights"] = to_numpy(self.scaled_weights)
        return super().plot_corner(**kwargs)

    def __str__(self):
        out = super().__str__()
        if self.log_evidence is not None:
            out += f"Log evidence: {self.log_evidence:.2f}"
            # Bug fix: guard against a missing error estimate; previously
            # formatting None with :.2f raised a TypeError.
            if self.log_evidence_error is not None:
                out += f" +/- {self.log_evidence_error:.2f}"
            out += "\n"
        if self.log_w is not None:
            out += (
                f"Effective sample size: {self.effective_sample_size:.1f}\n"
                f"Efficiency: {self.efficiency:.2f}\n"
            )
        return out

    def to_namespace(self, xp, dtype: Any | str | None = None):
        """Convert all arrays to the given namespace.

        Parameters
        ----------
        xp : Callable
            The target array namespace.
        dtype : Any | str | None
            Optional target dtype, resolved by the constructor. Added for
            compatibility with the base-class signature.
        """
        return self.__class__(
            x=asarray(self.x, xp),
            parameters=self.parameters,
            log_likelihood=asarray(self.log_likelihood, xp)
            if self.log_likelihood is not None
            else None,
            log_prior=asarray(self.log_prior, xp)
            if self.log_prior is not None
            else None,
            log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
            log_evidence=asarray(self.log_evidence, xp)
            if self.log_evidence is not None
            else None,
            log_evidence_error=asarray(self.log_evidence_error, xp)
            if self.log_evidence_error is not None
            else None,
            dtype=dtype,
        )

    def to_numpy(self, dtype: Any | str | None = None):
        """Convert all arrays to numpy.

        Parameters
        ----------
        dtype : Any | str | None
            Optional target dtype, resolved by the constructor. Added for
            compatibility with the base-class signature.
        """
        return self.__class__(
            x=to_numpy(self.x),
            parameters=self.parameters,
            log_likelihood=to_numpy(self.log_likelihood)
            if self.log_likelihood is not None
            else None,
            log_prior=to_numpy(self.log_prior)
            if self.log_prior is not None
            else None,
            log_q=to_numpy(self.log_q) if self.log_q is not None else None,
            # 'x if x is not None else None' simplifies to x.
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
            dtype=dtype,
        )

    def __getitem__(self, idx):
        """Slice the samples, preserving the evidence estimates."""
        sliced = super().__getitem__(idx)
        return self.__class__.from_samples(
            sliced,
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
        )
471
+
472
+
473
@dataclass
class SMCSamples(BaseSamples):
    """Samples for tempered (SMC-style) inference.

    The samples carry a temperature :code:`beta` that interpolates between
    the proposal (beta = 0) and the target (beta = 1).
    """

    beta: float | None = None
    """Temperature parameter for the current samples."""
    log_evidence: float | None = None
    """Log evidence estimate for the current samples."""
    log_evidence_error: float | None = None
    """Log evidence error estimate for the current samples."""

    def log_p_t(self, beta):
        """Tempered log-target at the given beta.

        Defined as (1 - beta) * log_q + beta * (log_likelihood + log_prior).
        """
        log_p_T = self.log_likelihood + self.log_prior
        return (1 - beta) * self.log_q + beta * log_p_T

    def unnormalized_log_weights(self, beta: float) -> Array:
        """Unnormalized importance log-weights from self.beta to beta."""
        return (self.beta - beta) * self.log_q + (beta - self.beta) * (
            self.log_likelihood + self.log_prior
        )

    def log_evidence_ratio(self, beta: float) -> float:
        """Monte Carlo estimate of the log evidence ratio between beta and self.beta."""
        log_w = self.unnormalized_log_weights(beta)
        return logsumexp(log_w) - math.log(len(self.x))

    def log_evidence_ratio_variance(self, beta: float) -> float:
        """Estimate the variance of the log evidence ratio using the delta method.

        Defined as Var(log Z) = Var(w) / (E[w])^2 where w are the unnormalized weights.
        """
        log_w = self.unnormalized_log_weights(beta)
        # Rescale by the maximum weight for numerical stability; the ratio
        # Var(u) / E[u]^2 is invariant under this rescaling.
        m = self.xp.max(log_w)
        u = self.xp.exp(log_w - m)
        mean_w = self.xp.mean(u)
        var_w = self.xp.var(u)
        return (
            var_w / (len(self) * (mean_w**2)) if mean_w != 0 else self.xp.nan
        )

    def log_weights(self, beta: float) -> Array:
        """Log-weights for beta, shifted by the log evidence ratio.

        Raises
        ------
        ValueError
            If the unnormalized log-weights contain NaN values.
        """
        log_w = self.unnormalized_log_weights(beta)
        if self.xp.isnan(log_w).any():
            raise ValueError(f"Log weights contain NaN values for beta={beta}")
        log_evidence_ratio = logsumexp(log_w) - math.log(len(self.x))
        return log_w + log_evidence_ratio

    def resample(
        self,
        beta,
        n_samples: int | None = None,
        rng: np.random.Generator | None = None,
    ) -> "SMCSamples":
        """Multinomial resample of the samples at a new beta.

        Parameters
        ----------
        beta : float
            The new temperature parameter.
        n_samples : int | None
            Number of samples to draw. If None, the current number of
            samples is used.
        rng : numpy.random.Generator | None
            Random number generator. If None, a new default generator is
            created.
        """
        if beta == self.beta and n_samples is None:
            logger.warning(
                "Resampling with the same beta value, returning identical samples"
            )
            return self
        if rng is None:
            rng = np.random.default_rng()
        if n_samples is None:
            n_samples = len(self.x)
        log_w = self.log_weights(beta)
        # Normalize to probabilities for numpy's choice.
        w = to_numpy(self.xp.exp(log_w - logsumexp(log_w)))
        idx = rng.choice(len(self.x), size=n_samples, replace=True, p=w)
        return self.__class__(
            x=self.x[idx],
            log_likelihood=self.log_likelihood[idx],
            log_prior=self.log_prior[idx],
            log_q=self.log_q[idx],
            beta=beta,
            # Bug fix: propagate the parameter names; previously they were
            # reset to the default x_i labels.
            parameters=self.parameters,
            dtype=self.dtype,
        )

    def __str__(self):
        out = super().__str__()
        if self.log_evidence is not None:
            out += f"Log evidence: {self.log_evidence:.2f}\n"
        return out

    def to_standard_samples(self):
        """Convert the samples to standard samples (drops log_q and beta)."""
        return Samples(
            x=self.x,
            log_likelihood=self.log_likelihood,
            log_prior=self.log_prior,
            xp=self.xp,
            parameters=self.parameters,
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
        )

    def __getitem__(self, idx):
        """Slice the samples, preserving beta and the evidence estimates."""
        sliced = super().__getitem__(idx)
        return self.__class__.from_samples(
            sliced,
            beta=self.beta,
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
        )