aspire-inference 0.1.0a5__py3-none-any.whl → 0.1.0a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/aspire.py +55 -6
- aspire/flows/base.py +37 -0
- aspire/flows/jax/flows.py +118 -4
- aspire/flows/jax/utils.py +4 -1
- aspire/flows/torch/flows.py +86 -18
- aspire/samplers/base.py +3 -1
- aspire/samplers/importance.py +5 -1
- aspire/samplers/mcmc.py +5 -3
- aspire/samplers/smc/base.py +11 -5
- aspire/samplers/smc/blackjax.py +4 -2
- aspire/samplers/smc/emcee.py +1 -1
- aspire/samplers/smc/minipcn.py +1 -1
- aspire/samples.py +88 -28
- aspire/transforms.py +297 -44
- aspire/utils.py +285 -16
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/METADATA +2 -1
- aspire_inference-0.1.0a6.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a5.dist-info/RECORD +0 -28
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/WHEEL +0 -0
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/licenses/LICENSE +0 -0
- {aspire_inference-0.1.0a5.dist-info → aspire_inference-0.1.0a6.dist-info}/top_level.txt +0 -0
aspire/samplers/smc/base.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable
+from typing import Any, Callable
 
 import array_api_compat.numpy as np
 
@@ -27,6 +27,7 @@ class SMCSampler(MCMCSampler):
         dims: int,
         prior_flow: Flow,
         xp: Callable,
+        dtype: Any | str | None = None,
         parameters: list[str] | None = None,
         rng: np.random.Generator | None = None,
         preconditioning_transform: Callable | None = None,
@@ -37,6 +38,7 @@ class SMCSampler(MCMCSampler):
             dims=dims,
             prior_flow=prior_flow,
             xp=xp,
+            dtype=dtype,
             parameters=parameters,
             preconditioning_transform=preconditioning_transform,
         )
@@ -158,7 +160,9 @@ class SMCSampler(MCMCSampler):
         n_final_samples: int | None = None,
     ) -> SMCSamples:
         samples = self.draw_initial_samples(n_samples)
-        samples = SMCSamples.from_samples(samples, xp=self.xp, beta=0.0)
+        samples = SMCSamples.from_samples(
+            samples, xp=self.xp, beta=0.0, dtype=self.dtype
+        )
         self.fit_preconditioning_transform(samples.x)
 
         if self.xp.isnan(samples.log_q).any():
@@ -271,7 +275,7 @@ class SMCSampler(MCMCSampler):
 
     def log_prob(self, z, beta=None):
         x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
-        samples = SMCSamples(x, xp=self.xp)
+        samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
         log_q = self.prior_flow.log_prob(samples.x)
         samples.log_q = samples.array_to_namespace(log_q)
         samples.log_prior = self.log_prior(samples)
@@ -294,6 +298,7 @@ class NumpySMCSampler(SMCSampler):
         dims,
         prior_flow,
         xp,
+        dtype=None,
         parameters=None,
         preconditioning_transform=None,
     ):
@@ -305,8 +310,9 @@ class NumpySMCSampler(SMCSampler):
             log_likelihood,
             log_prior,
             dims,
-            prior_flow,
-            xp,
+            prior_flow=prior_flow,
+            xp=xp,
+            dtype=dtype,
             parameters=parameters,
             preconditioning_transform=preconditioning_transform,
         )
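For context, a minimal sketch of the call that SMCSampler.sample() now makes when it wraps the initial samples: the sampler's dtype is forwarded to SMCSamples.from_samples so the initial tempered samples are cast consistently. Only the from_samples keywords are taken from the hunk above; the random data, the numpy namespace, and the explicit np.float64 are illustrative assumptions (per the docstrings below, a dtype of None falls back to the namespace default).

import array_api_compat.numpy as np

from aspire.samples import Samples, SMCSamples  # classes defined in aspire/samples.py

# Placeholder data standing in for the output of draw_initial_samples().
rng = np.random.default_rng(0)
initial = Samples(x=rng.normal(size=(128, 2)), xp=np)

# Mirrors the updated call in SMCSampler.sample(): beta starts at 0.0 and the
# sampler's dtype is now forwarded so the tempered samples are cast consistently.
smc_samples = SMCSamples.from_samples(initial, xp=np, beta=0.0, dtype=np.float64)
print(smc_samples.beta, smc_samples.x.dtype)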
aspire/samplers/smc/blackjax.py
CHANGED
@@ -20,6 +20,7 @@ class BlackJAXSMC(SMCSampler):
         dims,
         prior_flow,
         xp,
+        dtype=None,
         parameters=None,
         preconditioning_transform=None,
         rng: np.random.Generator | None = None,  # New parameter
@@ -31,6 +32,7 @@ class BlackJAXSMC(SMCSampler):
             dims=dims,
             prior_flow=prior_flow,
             xp=xp,
+            dtype=dtype,
             parameters=parameters,
             preconditioning_transform=preconditioning_transform,
         )
@@ -49,7 +51,7 @@ class BlackJAXSMC(SMCSampler):
         x_params, log_abs_det_jacobian = (
             self.preconditioning_transform.inverse(x_original)
         )
-        samples = SMCSamples(x_params, xp=self.xp)
+        samples = SMCSamples(x_params, xp=self.xp, dtype=self.dtype)
 
         # Compute log probabilities
         log_q = self.prior_flow.log_prob(samples.x)
@@ -306,7 +308,7 @@ class BlackJAXSMC(SMCSampler):
         self.history.mcmc_acceptance.append(float(mean_acceptance))
 
         # Create new samples
-        samples = SMCSamples(x_final, xp=self.xp, beta=beta)
+        samples = SMCSamples(x_final, xp=self.xp, beta=beta, dtype=self.dtype)
         samples.log_q = samples.array_to_namespace(
             self.prior_flow.log_prob(samples.x)
         )
aspire/samplers/smc/emcee.py
CHANGED
@@ -62,7 +62,7 @@ class EmceeSMC(NumpySMCSampler):
         )
         z = sampler.get_chain(flat=False)[-1, ...]
         x = self.preconditioning_transform.inverse(z)[0]
-        samples = SMCSamples(x, xp=self.xp, beta=beta)
+        samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
         samples.log_q = samples.array_to_namespace(
             self.prior_flow.log_prob(samples.x)
         )
aspire/samplers/smc/minipcn.py
CHANGED
@@ -69,7 +69,7 @@ class MiniPCNSMC(NumpySMCSampler):
 
         self.history.mcmc_acceptance.append(np.mean(history.acceptance_rate))
 
-        samples = SMCSamples(x, xp=self.xp, beta=beta)
+        samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
         samples.log_q = samples.array_to_namespace(
             self.prior_flow.log_prob(samples.x)
         )
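The emcee, minipcn, and blackjax kernels above all apply the same pattern: new tempered samples are constructed with the sampler's dtype, and array_to_namespace then casts derived arrays (such as the flow log-probabilities assigned to log_q) to that dtype. A standalone sketch of the pattern, with random numpy arrays standing in for the MCMC output and for prior_flow.log_prob; only the SMCSamples and array_to_namespace keywords come from the hunks, the rest is assumed for illustration.

import array_api_compat.numpy as np

from aspire.samples import SMCSamples  # defined in aspire/samples.py

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 2)).astype(np.float32)    # stand-in for the MCMC chain output
log_prob = rng.normal(size=64).astype(np.float32)  # stand-in for prior_flow.log_prob(x)

# dtype is now passed alongside beta, as in the updated kernels above.
samples = SMCSamples(x, xp=np, beta=0.3, dtype=np.float64)

# array_to_namespace applies the stored dtype, so log_q is cast to float64 here.
samples.log_q = samples.array_to_namespace(log_prob)
print(samples.x.dtype, samples.log_q.dtype)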
aspire/samples.py
CHANGED
@@ -15,7 +15,14 @@ from array_api_compat import (
 from array_api_compat import device as api_device
 from array_api_compat.common._typing import Array
 
-from .utils import asarray, logsumexp, recursively_save_to_h5_file, to_numpy
+from .utils import (
+    asarray,
+    convert_dtype,
+    logsumexp,
+    recursively_save_to_h5_file,
+    resolve_dtype,
+    to_numpy,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -29,12 +36,31 @@ class BaseSamples:
     """
 
     x: Array
+    """Array of samples, shape (n_samples, n_dims)."""
     log_likelihood: Array | None = None
+    """Log-likelihood values for the samples."""
     log_prior: Array | None = None
+    """Log-prior values for the samples."""
     log_q: Array | None = None
+    """Log-probability values under the proposal distribution."""
     parameters: list[str] | None = None
+    """List of parameter names."""
+    dtype: Any | str | None = None
+    """Data type of the samples.
+
+    If None, the default dtype for the array namespace will be used.
+    """
     xp: Callable | None = None
+    """
+    The array namespace to use for the samples.
+
+    If None, the array namespace will be inferred from the type of :code:`x`.
+    """
     device: Any = None
+    """Device to store the samples on.
+
+    If None, the device will be inferred from the array namespace of :code:`x`.
+    """
 
     def __post_init__(self):
         if self.xp is None:
@@ -42,15 +68,24 @@ class BaseSamples:
         # Numpy arrays need to be on the CPU before being converted
         if is_numpy_namespace(self.xp):
             self.device = "cpu"
-        self.x = self.array_to_namespace(self.x)
+        if self.dtype is not None:
+            self.dtype = resolve_dtype(self.dtype, self.xp)
+        else:
+            # Fall back to default dtype for the array namespace
+            self.dtype = None
+        self.x = self.array_to_namespace(self.x, dtype=self.dtype)
         if self.device is None:
             self.device = api_device(self.x)
         if self.log_likelihood is not None:
-            self.log_likelihood = self.array_to_namespace(self.log_likelihood)
+            self.log_likelihood = self.array_to_namespace(
+                self.log_likelihood, dtype=self.dtype
+            )
         if self.log_prior is not None:
-            self.log_prior = self.array_to_namespace(self.log_prior)
+            self.log_prior = self.array_to_namespace(
+                self.log_prior, dtype=self.dtype
+            )
         if self.log_q is not None:
-            self.log_q = self.array_to_namespace(self.log_q)
+            self.log_q = self.array_to_namespace(self.log_q, dtype=self.dtype)
 
         if self.parameters is None:
             self.parameters = [f"x_{i}" for i in range(self.dims)]
@@ -62,37 +97,48 @@ class BaseSamples:
             return 0
         return self.x.shape[1] if self.x.ndim > 1 else 1
 
-    def to_numpy(self):
+    def to_numpy(self, dtype: Any | str | None = None):
         logger.debug("Converting samples to numpy arrays")
+        import array_api_compat.numpy as np
+
+        if dtype is not None:
+            dtype = resolve_dtype(dtype, np)
+        else:
+            dtype = convert_dtype(self.dtype, np)
         return self.__class__(
-            x=to_numpy(self.x),
+            x=self.x,
             parameters=self.parameters,
-            log_likelihood=to_numpy(self.log_likelihood)
-            if self.log_likelihood is not None
-            else None,
-            log_prior=to_numpy(self.log_prior)
-            if self.log_prior is not None
-            else None,
-            log_q=to_numpy(self.log_q) if self.log_q is not None else None,
+            log_likelihood=self.log_likelihood,
+            log_prior=self.log_prior,
+            log_q=self.log_q,
+            xp=np,
         )
 
-    def to_namespace(self, xp):
+    def to_namespace(self, xp, dtype: Any | str | None = None):
+        if dtype is None:
+            dtype = convert_dtype(self.dtype, xp)
+        else:
+            dtype = resolve_dtype(dtype, xp)
         logger.debug("Converting samples to {} namespace", xp)
         return self.__class__(
-            x=asarray(self.x, xp),
+            x=self.x,
            parameters=self.parameters,
-            log_likelihood=asarray(self.log_likelihood, xp)
-            if self.log_likelihood is not None
-            else None,
-            log_prior=asarray(self.log_prior, xp)
-            if self.log_prior is not None
-            else None,
-            log_q=asarray(self.log_q, xp) if self.log_q is not None else None,
+            log_likelihood=self.log_likelihood,
+            log_prior=self.log_prior,
+            log_q=self.log_q,
+            xp=xp,
+            device=self.device,
+            dtype=dtype,
         )
 
-    def array_to_namespace(self, x):
+    def array_to_namespace(self, x, dtype=None):
         """Convert an array to the same namespace as the samples"""
-        x = asarray(x, self.xp)
+        kwargs = {}
+        if dtype is not None:
+            kwargs["dtype"] = resolve_dtype(dtype, self.xp)
+        else:
+            kwargs["dtype"] = self.dtype
+        x = asarray(x, self.xp, **kwargs)
         if self.device:
             x = to_device(x, self.device)
         return x
@@ -132,8 +178,7 @@ class BaseSamples:
 
     def __str__(self):
         out = (
-            f"No. samples: {len(self.x)}\n"
-            f"No. parameters: {len(self.parameters)}\n"
+            f"No. samples: {len(self.x)}\nNo. parameters: {self.x.shape[-1]}\n"
         )
         return out
 
@@ -169,6 +214,7 @@ class BaseSamples:
             else None,
            log_q=self.log_q[idx] if self.log_q is not None else None,
            parameters=self.parameters,
+            dtype=self.dtype,
        )
 
     def __setitem__(self, idx, value: BaseSamples):
@@ -183,6 +229,8 @@ class BaseSamples:
             raise ValueError("Parameters do not match")
         if not all(s.xp == samples[0].xp for s in samples):
             raise ValueError("Array namespaces do not match")
+        if not all(s.dtype == samples[0].dtype for s in samples):
+            raise ValueError("Dtypes do not match")
         xp = samples[0].xp
         return cls(
             x=xp.concatenate([s.x for s in samples], axis=0),
@@ -198,6 +246,7 @@ class BaseSamples:
             if all(s.log_q is not None for s in samples)
            else None,
            parameters=samples[0].parameters,
+            dtype=samples[0].dtype,
        )
 
     @classmethod
@@ -205,6 +254,11 @@ class BaseSamples:
         """Create a Samples object from a BaseSamples object."""
         xp = kwargs.pop("xp", samples.xp)
         device = kwargs.pop("device", samples.device)
+        dtype = kwargs.pop("dtype", samples.dtype)
+        if dtype is not None:
+            dtype = resolve_dtype(dtype, xp)
+        else:
+            dtype = convert_dtype(samples.dtype, xp)
         return cls(
             x=samples.x,
             log_likelihood=samples.log_likelihood,
@@ -308,6 +362,7 @@ class Samples(BaseSamples):
             x=self.x[accept],
             log_likelihood=self.log_likelihood[accept],
             log_prior=self.log_prior[accept],
+            dtype=self.dtype,
         )
 
     def to_dict(self, flat: bool = True):
@@ -400,8 +455,11 @@
 @dataclass
 class SMCSamples(BaseSamples):
     beta: float | None = None
-    log_evidence: float | None = None
     """Temperature parameter for the current samples."""
+    log_evidence: float | None = None
+    """Log evidence estimate for the current samples."""
+    log_evidence_error: float | None = None
+    """Log evidence error estimate for the current samples."""
 
     def log_p_t(self, beta):
         log_p_T = self.log_likelihood + self.log_prior
@@ -461,6 +519,7 @@ class SMCSamples(BaseSamples):
             log_prior=self.log_prior[idx],
             log_q=self.log_q[idx],
             beta=beta,
+            dtype=self.dtype,
         )
 
     def __str__(self):
@@ -487,4 +546,5 @@ class SMCSamples(BaseSamples):
             sliced,
             beta=self.beta,
             log_evidence=self.log_evidence,
+            log_evidence_error=self.log_evidence_error,
         )
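Putting the samples.py changes together: BaseSamples now stores a dtype, resolves it against the array namespace in __post_init__, and applies it to x, log_likelihood, log_prior, and log_q, while to_numpy() and to_namespace() gain an optional dtype override. A small usage sketch follows, assuming only that the wheel is installed and importable as aspire; the data and dtype choices are illustrative, not taken from the package's documentation.

import array_api_compat.numpy as np

from aspire.samples import Samples  # Samples subclasses BaseSamples in aspire/samples.py

rng = np.random.default_rng(1)
samples = Samples(
    x=rng.normal(size=(32, 3)),
    log_likelihood=rng.normal(size=32),
    xp=np,
    dtype=np.float32,  # resolved in __post_init__ and applied to every stored array
)
print(samples.x.dtype, samples.log_likelihood.dtype)

# to_namespace() accepts an optional dtype; when it is omitted the stored dtype
# is translated to the target namespace with convert_dtype.
promoted = samples.to_namespace(np, dtype=np.float64)
print(promoted.x.dtype)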