aspire-inference 0.1.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +457 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +37 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +82 -0
- aspire/flows/jax/utils.py +54 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +276 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +92 -0
- aspire/samplers/importance.py +18 -0
- aspire/samplers/mcmc.py +158 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +312 -0
- aspire/samplers/smc/blackjax.py +330 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +476 -0
- aspire/transforms.py +491 -0
- aspire/utils.py +491 -0
- aspire_inference-0.1.0a2.dist-info/METADATA +48 -0
- aspire_inference-0.1.0a2.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a2.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a2.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a2.dist-info/top_level.txt +1 -0
aspire/samples.py
ADDED
@@ -0,0 +1,476 @@
from __future__ import annotations

import copy
import logging
import math
from dataclasses import dataclass, field
from typing import Any, Callable

import numpy as np
from array_api_compat import (
    array_namespace,
    is_numpy_namespace,
    to_device,
)
from array_api_compat import device as api_device
from array_api_compat.common._typing import Array

from .utils import asarray, logsumexp, recursively_save_to_h5_file, to_numpy

logger = logging.getLogger(__name__)


@dataclass
class BaseSamples:
    """Class for storing samples and corresponding weights.

    If :code:`xp` is not specified, all inputs will be converted to match
    the array type of :code:`x`.
    """

    x: Array
    log_likelihood: Array | None = None
    log_prior: Array | None = None
    log_q: Array | None = None
    parameters: list[str] | None = None
    xp: Callable | None = None
    device: Any = None

    def __post_init__(self):
        if self.xp is None:
            self.xp = array_namespace(self.x)
        # Numpy arrays need to be on the CPU before being converted
        if is_numpy_namespace(self.xp):
            self.device = "cpu"
        self.x = self.array_to_namespace(self.x)
        if self.device is None:
            self.device = api_device(self.x)
        if self.log_likelihood is not None:
            self.log_likelihood = self.array_to_namespace(self.log_likelihood)
        if self.log_prior is not None:
            self.log_prior = self.array_to_namespace(self.log_prior)
        if self.log_q is not None:
            self.log_q = self.array_to_namespace(self.log_q)

        if self.parameters is None:
            self.parameters = [f"x_{i}" for i in range(self.dims)]

    @property
    def dims(self):
        """Number of dimensions (parameters) in the samples."""
        if self.x is None:
            return 0
        return self.x.shape[1] if self.x.ndim > 1 else 1

    def to_numpy(self):
        logger.debug("Converting samples to numpy arrays")
        return self.__class__(
            x=to_numpy(self.x),
            parameters=self.parameters,
            log_likelihood=to_numpy(self.log_likelihood)
            if self.log_likelihood is not None
            else None,
            log_prior=to_numpy(self.log_prior)
            if self.log_prior is not None
            else None,
            log_q=to_numpy(self.log_q) if self.log_q is not None else None,
        )

    def to_namespace(self, xp):
        logger.debug("Converting samples to %s namespace", xp)
        return self.__class__(
            x=asarray(self.x, xp),
            parameters=self.parameters,
            log_likelihood=asarray(self.log_likelihood, xp)
            if self.log_likelihood is not None
            else None,
            log_prior=asarray(self.log_prior, xp)
            if self.log_prior is not None
            else None,
            log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
        )

    def array_to_namespace(self, x):
        """Convert an array to the same namespace as the samples."""
        x = asarray(x, self.xp)
        if self.device:
            x = to_device(x, self.device)
        return x

    def to_dict(self, flat: bool = True):
        samples = dict(zip(self.parameters, self.x.T, strict=True))
        out = {
            "log_likelihood": self.log_likelihood,
            "log_prior": self.log_prior,
            "log_q": self.log_q,
        }
        if flat:
            out.update(samples)
        else:
            out["samples"] = samples
        return out

    def to_dataframe(self, flat: bool = True):
        import pandas as pd

        return pd.DataFrame(self.to_dict(flat=flat))

    def plot_corner(self, parameters: list[str] | None = None, **kwargs):
        import corner

        kwargs = copy.deepcopy(kwargs)
        kwargs.setdefault("labels", self.parameters)

        if parameters is not None:
            indices = [self.parameters.index(p) for p in parameters]
            kwargs["labels"] = parameters
            x = self.x[:, indices] if self.x.ndim > 1 else self.x[indices]
        else:
            x = self.x
        fig = corner.corner(to_numpy(x), **kwargs)
        return fig

    def __str__(self):
        out = (
            f"No. samples: {len(self.x)}\n"
            f"No. parameters: {len(self.parameters)}\n"
        )
        return out

    def save(self, h5_file, path="samples", flat=False):
        """Save the samples to an HDF5 file.

        This converts the samples to numpy and then to a dictionary.

        Parameters
        ----------
        h5_file : h5py.File
            The HDF5 file to save to.
        path : str
            The path in the HDF5 file to save to.
        flat : bool
            If True, save the samples as a flat dictionary.
            If False, save the samples as a nested dictionary.
        """
        dictionary = self.to_numpy().to_dict(flat=flat)
        recursively_save_to_h5_file(h5_file, path, dictionary)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx) -> BaseSamples:
        return self.__class__(
            x=self.x[idx],
            log_likelihood=self.log_likelihood[idx]
            if self.log_likelihood is not None
            else None,
            log_prior=self.log_prior[idx]
            if self.log_prior is not None
            else None,
            log_q=self.log_q[idx] if self.log_q is not None else None,
            parameters=self.parameters,
        )

    def __setitem__(self, idx, value: BaseSamples):
        raise NotImplementedError("Setting items is not supported")

    @classmethod
    def concatenate(cls, samples: list[BaseSamples]) -> BaseSamples:
        """Concatenate multiple Samples objects."""
        if not samples:
            raise ValueError("No samples to concatenate")
        if not all(s.parameters == samples[0].parameters for s in samples):
            raise ValueError("Parameters do not match")
        if not all(s.xp == samples[0].xp for s in samples):
            raise ValueError("Array namespaces do not match")
        xp = samples[0].xp
        return cls(
            x=xp.concatenate([s.x for s in samples], axis=0),
            log_likelihood=xp.concatenate(
                [s.log_likelihood for s in samples], axis=0
            )
            if all(s.log_likelihood is not None for s in samples)
            else None,
            log_prior=xp.concatenate([s.log_prior for s in samples], axis=0)
            if all(s.log_prior is not None for s in samples)
            else None,
            log_q=xp.concatenate([s.log_q for s in samples], axis=0)
            if all(s.log_q is not None for s in samples)
            else None,
            parameters=samples[0].parameters,
        )

    @classmethod
    def from_samples(cls, samples: BaseSamples, **kwargs) -> BaseSamples:
        """Create a new instance from an existing BaseSamples object."""
        xp = kwargs.pop("xp", samples.xp)
        device = kwargs.pop("device", samples.device)
        return cls(
            x=samples.x,
            log_likelihood=samples.log_likelihood,
            log_prior=samples.log_prior,
            log_q=samples.log_q,
            parameters=samples.parameters,
            xp=xp,
            device=device,
            **kwargs,
        )
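

# Usage sketch (illustrative only; this comment block is not part of the
# released module). With no explicit ``parameters``, ``xp`` or ``device``,
# ``__post_init__`` infers them from ``x``:
#
#     >>> import numpy as np
#     >>> s = BaseSamples(x=np.random.randn(100, 3))
#     >>> s.parameters
#     ['x_0', 'x_1', 'x_2']
#     >>> s.dims, len(s)
#     (3, 100)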


@dataclass
class Samples(BaseSamples):
    """Class for storing samples and the corresponding posterior weights.

    If :code:`xp` is not specified, all inputs will be converted to match
    the array type of :code:`x`.
    """

    log_w: Array = field(init=False)
    weights: Array = field(init=False)
    evidence: float = field(init=False)
    evidence_error: float = field(init=False)
    log_evidence: float | None = None
    log_evidence_error: float | None = None
    effective_sample_size: float = field(init=False)

    def __post_init__(self):
        super().__post_init__()

        if all(
            x is not None
            for x in [self.log_likelihood, self.log_prior, self.log_q]
        ):
            self.compute_weights()
        else:
            self.log_w = None
            self.weights = None
            self.evidence = None
            self.evidence_error = None
            self.effective_sample_size = None

    @property
    def efficiency(self):
        """Efficiency of the weighted samples.

        Defined as ESS / number of samples.
        """
        if self.log_w is None:
            raise RuntimeError("Samples do not contain weights!")
        return self.effective_sample_size / len(self.x)

    def compute_weights(self):
        """Compute the posterior weights."""
        self.log_w = self.log_likelihood + self.log_prior - self.log_q
        self.log_evidence = asarray(logsumexp(self.log_w), self.xp) - math.log(
            len(self.x)
        )
        self.weights = self.xp.exp(self.log_w)
        self.evidence = self.xp.exp(self.log_evidence)
        n = len(self.x)
        self.evidence_error = self.xp.sqrt(
            self.xp.sum((self.weights - self.evidence) ** 2) / (n * (n - 1))
        )
        self.log_evidence_error = self.xp.abs(
            self.evidence_error / self.evidence
        )
        log_w = self.log_w - self.xp.max(self.log_w)
        self.effective_sample_size = self.xp.exp(
            asarray(logsumexp(log_w) * 2 - logsumexp(log_w * 2), self.xp)
        )

    @property
    def scaled_weights(self):
        return self.xp.exp(self.log_w - self.xp.max(self.log_w))

    def rejection_sample(self, rng=None):
        if rng is None:
            rng = np.random.default_rng()
        log_u = asarray(
            np.log(rng.uniform(size=len(self.x))), self.xp, device=self.device
        )
        log_w = self.log_w - self.xp.max(self.log_w)
        accept = log_w > log_u
        return self.__class__(
            x=self.x[accept],
            log_likelihood=self.log_likelihood[accept],
            log_prior=self.log_prior[accept],
            parameters=self.parameters,
        )

    def to_dict(self, flat: bool = True):
        out = super().to_dict(flat=flat)
        other = {
            "log_w": self.log_w,
            "weights": self.weights,
            "evidence": self.evidence,
            "log_evidence": self.log_evidence,
            "evidence_error": self.evidence_error,
            "log_evidence_error": self.log_evidence_error,
            "effective_sample_size": self.effective_sample_size,
        }
        out.update(other)
        return out

    def plot_corner(self, include_weights: bool = True, **kwargs):
        kwargs = copy.deepcopy(kwargs)
        if (
            include_weights
            and self.weights is not None
            and "weights" not in kwargs
        ):
            kwargs["weights"] = to_numpy(self.scaled_weights)
        return super().plot_corner(**kwargs)

    def __str__(self):
        out = super().__str__()
        if self.log_evidence is not None:
            out += (
                f"Log evidence: {self.log_evidence:.2f}"
                f" +/- {self.log_evidence_error:.2f}\n"
            )
        if self.log_w is not None:
            out += (
                f"Effective sample size: {self.effective_sample_size:.1f}\n"
                f"Efficiency: {self.efficiency:.2f}\n"
            )
        return out

    def to_namespace(self, xp):
        return self.__class__(
            x=asarray(self.x, xp),
            parameters=self.parameters,
            log_likelihood=asarray(self.log_likelihood, xp)
            if self.log_likelihood is not None
            else None,
            log_prior=asarray(self.log_prior, xp)
            if self.log_prior is not None
            else None,
            log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
            log_evidence=asarray(self.log_evidence, xp)
            if self.log_evidence is not None
            else None,
            log_evidence_error=asarray(self.log_evidence_error, xp)
            if self.log_evidence_error is not None
            else None,
        )

    def to_numpy(self):
        return self.__class__(
            x=to_numpy(self.x),
            parameters=self.parameters,
            log_likelihood=to_numpy(self.log_likelihood)
            if self.log_likelihood is not None
            else None,
            log_prior=to_numpy(self.log_prior)
            if self.log_prior is not None
            else None,
            log_q=to_numpy(self.log_q) if self.log_q is not None else None,
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
        )

    def __getitem__(self, idx):
        sliced = super().__getitem__(idx)
        return self.__class__.from_samples(
            sliced,
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
        )
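

# Worked sketch of the identities implemented by Samples.compute_weights
# (illustrative only; this comment block is not part of the released module).
# Internally: log w = log L + log pi - log q, the evidence estimate is
# log Z = logsumexp(log w) - log n, and the effective sample size is the
# Kish estimate ESS = exp(2 * logsumexp(log w) - logsumexp(2 * log w)).
#
#     >>> import numpy as np
#     >>> rng = np.random.default_rng(0)
#     >>> s = Samples(
#     ...     x=rng.normal(size=(1000, 2)),
#     ...     log_likelihood=rng.normal(size=1000),
#     ...     log_prior=np.zeros(1000),
#     ...     log_q=np.zeros(1000),
#     ... )
#     >>> s.log_evidence, s.effective_sample_size, s.efficiency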


@dataclass
class SMCSamples(BaseSamples):
    beta: float | None = None
    """Temperature parameter for the current samples."""
    log_evidence: float | None = None
    log_evidence_error: float | None = None

    def log_p_t(self, beta):
        log_p_T = self.log_likelihood + self.log_prior
        return (1 - beta) * self.log_q + beta * log_p_T

    def unnormalized_log_weights(self, beta: float) -> Array:
        return (self.beta - beta) * self.log_q + (beta - self.beta) * (
            self.log_likelihood + self.log_prior
        )

    def log_evidence_ratio(self, beta: float) -> float:
        log_w = self.unnormalized_log_weights(beta)
        return logsumexp(log_w) - math.log(len(self.x))

    def log_evidence_ratio_variance(self, beta: float) -> float:
        """Estimate the variance of the log evidence ratio using the delta method.

        Defined as Var(log Z) = Var(w) / (n * E[w]^2), where w are the
        unnormalized weights and n is the number of samples.
        """
        log_w = self.unnormalized_log_weights(beta)
        m = self.xp.max(log_w)
        u = self.xp.exp(log_w - m)
        mean_w = self.xp.mean(u)
        var_w = self.xp.var(u)
        return (
            var_w / (len(self) * (mean_w**2)) if mean_w != 0 else self.xp.nan
        )

    def log_weights(self, beta: float) -> Array:
        log_w = self.unnormalized_log_weights(beta)
        if self.xp.isnan(log_w).any():
            raise ValueError(f"Log weights contain NaN values for beta={beta}")
        log_evidence_ratio = logsumexp(log_w) - math.log(len(self.x))
        return log_w + log_evidence_ratio

    def resample(
        self,
        beta,
        n_samples: int | None = None,
        rng: np.random.Generator | None = None,
    ) -> "SMCSamples":
        if beta == self.beta and n_samples is None:
            logger.warning(
                "Resampling with the same beta value, returning identical samples"
            )
            return self
        if rng is None:
            rng = np.random.default_rng()
        if n_samples is None:
            n_samples = len(self.x)
        log_w = self.log_weights(beta)
        w = to_numpy(self.xp.exp(log_w - logsumexp(log_w)))
        idx = rng.choice(len(self.x), size=n_samples, replace=True, p=w)
        return self.__class__(
            x=self.x[idx],
            log_likelihood=self.log_likelihood[idx],
            log_prior=self.log_prior[idx],
            log_q=self.log_q[idx],
            parameters=self.parameters,
            beta=beta,
        )

    def __str__(self):
        out = super().__str__()
        if self.log_evidence is not None:
            out += f"Log evidence: {self.log_evidence:.2f}\n"
        return out

    def to_standard_samples(self):
        """Convert the samples to standard samples."""
        return Samples(
            x=self.x,
            log_likelihood=self.log_likelihood,
            log_prior=self.log_prior,
            xp=self.xp,
            parameters=self.parameters,
            log_evidence=self.log_evidence,
            log_evidence_error=self.log_evidence_error,
        )

    def __getitem__(self, idx):
        sliced = super().__getitem__(idx)
        return self.__class__.from_samples(
            sliced,
            beta=self.beta,
            log_evidence=self.log_evidence,
        )
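
SMCSamples carries the tempering bookkeeping used by the SMC samplers: particles at inverse temperature beta target log p_beta = (1 - beta) * log_q + beta * (log_likelihood + log_prior) (see log_p_t), so the incremental weights between two temperatures reduce to (beta' - beta) * (log_likelihood + log_prior - log_q). A minimal sketch of one tempering step, using only the methods shown above (values and settings are illustrative, not from the package):

    import numpy as np

    n = 500
    rng = np.random.default_rng(1)
    particles = SMCSamples(
        x=rng.normal(size=(n, 2)),
        log_likelihood=rng.normal(size=n),
        log_prior=np.zeros(n),
        log_q=np.zeros(n),
        beta=0.0,
    )
    # Evidence increment and resampling for the move from beta=0.0 to beta=0.5.
    delta_log_z = particles.log_evidence_ratio(0.5)
    next_generation = particles.resample(0.5, rng=rng)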