aspire-inference 0.1.0a7__py3-none-any.whl
This diff shows the content of this publicly released package version as it appears in its public registry. It is provided for informational purposes only.
- aspire/__init__.py +19 -0
- aspire/aspire.py +506 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +84 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +196 -0
- aspire/flows/jax/utils.py +57 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +344 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +94 -0
- aspire/samplers/importance.py +22 -0
- aspire/samplers/mcmc.py +160 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +318 -0
- aspire/samplers/smc/blackjax.py +332 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +568 -0
- aspire/transforms.py +751 -0
- aspire/utils.py +760 -0
- aspire_inference-0.1.0a7.dist-info/METADATA +52 -0
- aspire_inference-0.1.0a7.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a7.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a7.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a7.dist-info/top_level.txt +1 -0
aspire/samplers/smc/base.py
@@ -0,0 +1,318 @@
import logging
from typing import Any, Callable

import array_api_compat.numpy as np

from ...flows.base import Flow
from ...history import SMCHistory
from ...samples import SMCSamples
from ...utils import (
    asarray,
    effective_sample_size,
    track_calls,
    update_at_indices,
)
from ..mcmc import MCMCSampler

logger = logging.getLogger(__name__)


class SMCSampler(MCMCSampler):
    """Base class for Sequential Monte Carlo samplers."""

    def __init__(
        self,
        log_likelihood: Callable,
        log_prior: Callable,
        dims: int,
        prior_flow: Flow,
        xp: Callable,
        dtype: Any | str | None = None,
        parameters: list[str] | None = None,
        rng: np.random.Generator | None = None,
        preconditioning_transform: Callable | None = None,
    ):
        super().__init__(
            log_likelihood=log_likelihood,
            log_prior=log_prior,
            dims=dims,
            prior_flow=prior_flow,
            xp=xp,
            dtype=dtype,
            parameters=parameters,
            preconditioning_transform=preconditioning_transform,
        )
        self.rng = rng or np.random.default_rng()
        self._adapative_target_efficiency = False

    @property
    def target_efficiency(self):
        return self._target_efficiency

    @target_efficiency.setter
    def target_efficiency(self, value: float | tuple):
        """Set the target efficiency.

        Parameters
        ----------
        value : float or tuple
            If a float, the target efficiency to use for all iterations.
            If a tuple of two floats, the target efficiency will adapt from
            the first value to the second value over the course of the SMC
            iterations. See `target_efficiency_rate` for details.
        """
        if isinstance(value, float):
            if not (0 < value < 1):
                raise ValueError("target_efficiency must be in (0, 1)")
            self._target_efficiency = value
            self._adapative_target_efficiency = False
        elif len(value) != 2:
            raise ValueError(
                "target_efficiency must be a float or tuple of two floats"
            )
        else:
            value = tuple(map(float, value))
            if not (0 < value[0] < value[1] < 1):
                raise ValueError(
                    "target_efficiency tuple must be in (0, 1) and increasing"
                )
            self._target_efficiency = value
            self._adapative_target_efficiency = True

    def current_target_efficiency(self, beta: float) -> float:
        """Get the current target efficiency based on beta."""
        if self._adapative_target_efficiency:
            return self._target_efficiency[0] + (
                self._target_efficiency[1] - self._target_efficiency[0]
            ) * (beta**self.target_efficiency_rate)
        else:
            return self._target_efficiency

    def determine_beta(
        self,
        samples: SMCSamples,
        beta: float,
        beta_step: float,
        min_step: float,
    ) -> tuple[float, float]:
        """Determine the next beta value.

        Parameters
        ----------
        samples : SMCSamples
            The current samples.
        beta : float
            The current beta value.
        beta_step : float
            The fixed beta step size if not adaptive.
        min_step : float
            The minimum beta step size.

        Returns
        -------
        beta : float
            The new beta value.
        min_step : float
            The new minimum step size if adaptive_min_step is True.
        """
        if not self.adaptive:
            beta += beta_step
            if beta >= 1.0:
                beta = 1.0
        else:
            beta_prev = beta
            beta_min = beta_prev
            beta_max = 1.0
            tol = 1e-5
            eff_beta_max = effective_sample_size(
                samples.log_weights(beta_max)
            ) / len(samples)
            if eff_beta_max >= self.current_target_efficiency(beta_prev):
                beta_min = 1.0
            target_eff = self.current_target_efficiency(beta_prev)
            while beta_max - beta_min > tol:
                beta_try = 0.5 * (beta_max + beta_min)
                eff = effective_sample_size(
                    samples.log_weights(beta_try)
                ) / len(samples)
                if eff >= target_eff:
                    beta_min = beta_try
                else:
                    beta_max = beta_try
            beta_star = beta_min

            if self.adaptive_min_step:
                min_step = min_step * (1 - beta_prev) / (1 - beta_star)
            beta = max(beta_star, beta_prev + min_step)
            beta = min(beta, 1.0)
        return beta, min_step

    @track_calls
    def sample(
        self,
        n_samples: int,
        n_steps: int | None = None,
        adaptive: bool = True,
        min_step: float | None = None,
        max_n_steps: int | None = None,
        target_efficiency: float = 0.5,
        target_efficiency_rate: float = 1.0,
        n_final_samples: int | None = None,
    ) -> SMCSamples:
        samples = self.draw_initial_samples(n_samples)
        samples = SMCSamples.from_samples(
            samples, xp=self.xp, beta=0.0, dtype=self.dtype
        )
        self.fit_preconditioning_transform(samples.x)

        if self.xp.isnan(samples.log_q).any():
            raise ValueError("Log proposal contains NaN values")
        if self.xp.isnan(samples.log_prior).any():
            raise ValueError("Log prior contains NaN values")
        if self.xp.isnan(samples.log_likelihood).any():
            raise ValueError("Log likelihood contains NaN values")

        logger.debug(f"Initial sample summary: {samples}")

        # Remove the n_final_steps from sampler_kwargs if present
        self.sampler_kwargs = self.sampler_kwargs or {}
        n_final_steps = self.sampler_kwargs.pop("n_final_steps", None)

        self.history = SMCHistory()

        self.target_efficiency = target_efficiency
        self.target_efficiency_rate = target_efficiency_rate

        if n_steps is not None:
            beta_step = 1 / n_steps
        elif not adaptive:
            raise ValueError("Either n_steps or adaptive=True must be set")
        else:
            beta_step = np.nan
        self.adaptive = adaptive
        beta = 0.0

        if min_step is None:
            if max_n_steps is None:
                min_step = 0.0
                self.adaptive_min_step = False
            else:
                min_step = 1 / max_n_steps
                self.adaptive_min_step = True
        else:
            self.adaptive_min_step = False

        iterations = 0
        while True:
            iterations += 1

            beta, min_step = self.determine_beta(
                samples,
                beta,
                beta_step,
                min_step,
            )
            self.history.eff_target.append(
                self.current_target_efficiency(beta)
            )

            logger.info(f"it {iterations} - beta: {beta}")
            self.history.beta.append(beta)

            ess = effective_sample_size(samples.log_weights(beta))
            eff = ess / len(samples)
            if eff < 0.1:
                logger.warning(
                    f"it {iterations} - Low sample efficiency: {eff:.2f}"
                )
            self.history.ess.append(ess)
            logger.info(
                f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
            )
            self.history.ess_target.append(
                effective_sample_size(samples.log_weights(1.0))
            )

            log_evidence_ratio = samples.log_evidence_ratio(beta)
            log_evidence_ratio_var = samples.log_evidence_ratio_variance(beta)
            self.history.log_norm_ratio.append(log_evidence_ratio)
            self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
            logger.info(
                f"it {iterations} - Log evidence ratio: {log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
            )

            samples = samples.resample(beta, rng=self.rng)

            samples = self.mutate(samples, beta)
            if beta == 1.0 or (
                max_n_steps is not None and iterations >= max_n_steps
            ):
                break

        # If n_final_samples is not None, perform additional mutations steps
        if n_final_samples is not None:
            logger.info(f"Generating {n_final_samples} final samples")
            final_samples = samples.resample(
                1.0, n_samples=n_final_samples, rng=self.rng
            )
            samples = self.mutate(final_samples, 1.0, n_steps=n_final_steps)

        samples.log_evidence = samples.xp.sum(
            asarray(self.history.log_norm_ratio, self.xp)
        )
        samples.log_evidence_error = samples.xp.sqrt(
            samples.xp.sum(asarray(self.history.log_norm_ratio_var, self.xp))
        )

        final_samples = samples.to_standard_samples()
        logger.info(
            f"Log evidence: {final_samples.log_evidence:.2f} +/- {final_samples.log_evidence_error:.2f}"
        )
        return final_samples

    def mutate(self, particles):
        raise NotImplementedError

    def log_prob(self, z, beta=None):
        x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
        samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
        log_q = self.prior_flow.log_prob(samples.x)
        samples.log_q = samples.array_to_namespace(log_q)
        samples.log_prior = self.log_prior(samples)
        samples.log_likelihood = self.log_likelihood(samples)
        log_prob = samples.log_p_t(
            beta=beta
        ).flatten() + samples.array_to_namespace(log_abs_det_jacobian)

        log_prob = update_at_indices(
            log_prob, self.xp.isnan(log_prob), -self.xp.inf
        )
        return log_prob


class NumpySMCSampler(SMCSampler):
    def __init__(
        self,
        log_likelihood,
        log_prior,
        dims,
        prior_flow,
        xp,
        dtype=None,
        parameters=None,
        preconditioning_transform=None,
    ):
        if preconditioning_transform is not None:
            preconditioning_transform = preconditioning_transform.new_instance(
                xp=np
            )
        super().__init__(
            log_likelihood,
            log_prior,
            dims,
            prior_flow=prior_flow,
            xp=xp,
            dtype=dtype,
            parameters=parameters,
            preconditioning_transform=preconditioning_transform,
        )
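
The adaptive tempering in `determine_beta` above bisects on beta, accepting the largest value whose importance-weight efficiency (ESS divided by the number of particles) still meets the current target. The sketch below is a minimal, standalone illustration of that bisection, not the package's code; the incremental-weight calculation `(beta - beta_prev) * log_likelihood` is a toy stand-in for `SMCSamples.log_weights(beta)`.

# Minimal standalone sketch (not part of aspire) of the bisection used in
# SMCSampler.determine_beta: find the largest next beta whose weight
# efficiency still meets the target.
import numpy as np

def effective_sample_size(log_weights):
    # Kish effective sample size from unnormalised log-weights.
    w = np.exp(log_weights - np.max(log_weights))
    return np.sum(w) ** 2 / np.sum(w**2)

def next_beta(log_likelihood, beta_prev, target_efficiency, tol=1e-5):
    n = len(log_likelihood)

    def efficiency(beta):
        # Toy incremental weights for tempering from beta_prev to beta.
        return effective_sample_size((beta - beta_prev) * log_likelihood) / n

    beta_min, beta_max = beta_prev, 1.0
    # If even beta = 1 keeps enough effective samples, jump straight to it.
    if efficiency(beta_max) >= target_efficiency:
        return 1.0
    while beta_max - beta_min > tol:
        beta_try = 0.5 * (beta_max + beta_min)
        if efficiency(beta_try) >= target_efficiency:
            beta_min = beta_try
        else:
            beta_max = beta_try
    return beta_min

rng = np.random.default_rng(0)
toy_log_likelihood = rng.normal(size=1000)  # toy per-particle log-likelihoods
print(next_beta(toy_log_likelihood, beta_prev=0.0, target_efficiency=0.5))

Because the efficiency typically decreases as beta grows, the bisection converges to the crossing point within `tol`, mirroring the loop in `determine_beta`.
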
aspire/samplers/smc/blackjax.py
@@ -0,0 +1,332 @@
import logging
from functools import partial

import numpy as np

from ...samples import SMCSamples
from ...utils import asarray, to_numpy, track_calls
from .base import SMCSampler

logger = logging.getLogger(__name__)


class BlackJAXSMC(SMCSampler):
    """BlackJAX SMC sampler."""

    def __init__(
        self,
        log_likelihood,
        log_prior,
        dims,
        prior_flow,
        xp,
        dtype=None,
        parameters=None,
        preconditioning_transform=None,
        rng: np.random.Generator | None = None,  # New parameter
    ):
        # For JAX compatibility, we'll keep the original xp
        super().__init__(
            log_likelihood=log_likelihood,
            log_prior=log_prior,
            dims=dims,
            prior_flow=prior_flow,
            xp=xp,
            dtype=dtype,
            parameters=parameters,
            preconditioning_transform=preconditioning_transform,
        )
        self.key = None
        self.rng = rng or np.random.default_rng()

    def log_prob(self, x, beta=None):
        """Log probability function compatible with BlackJAX."""
        # Convert to original xp format for computation
        if hasattr(x, "__array__"):
            x_original = asarray(x, self.xp)
        else:
            x_original = x

        # Transform back to parameter space
        x_params, log_abs_det_jacobian = (
            self.preconditioning_transform.inverse(x_original)
        )
        samples = SMCSamples(x_params, xp=self.xp, dtype=self.dtype)

        # Compute log probabilities
        log_q = self.prior_flow.log_prob(samples.x)
        samples.log_q = samples.array_to_namespace(log_q)
        samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )

        # Compute target log probability
        log_prob = samples.log_p_t(
            beta=beta
        ).flatten() + samples.array_to_namespace(log_abs_det_jacobian)

        # Handle NaN values
        log_prob = self.xp.where(
            self.xp.isnan(log_prob), -self.xp.inf, log_prob
        )

        return log_prob

    @track_calls
    def sample(
        self,
        n_samples: int,
        n_steps: int = None,
        adaptive: bool = True,
        target_efficiency: float = 0.5,
        target_efficiency_rate: float = 1.0,
        n_final_samples: int | None = None,
        sampler_kwargs: dict | None = None,
        rng_key=None,
    ):
        """Sample using BlackJAX SMC.

        Parameters
        ----------
        n_samples : int
            Number of samples to draw.
        n_steps : int
            Number of SMC steps.
        adaptive : bool
            Whether to use adaptive tempering.
        target_efficiency : float
            Target efficiency for adaptive tempering.
        n_final_samples : int | None
            Number of final samples to return.
        sampler_kwargs : dict | None
            Additional arguments for the BlackJAX sampler.
            - algorithm: str, one of "nuts", "hmc", "rwmh", "random_walk"
            - n_steps: int, number of MCMC steps per mutation
            - step_size: float, step size for HMC/NUTS
            - inverse_mass_matrix: array, inverse mass matrix
            - sigma: float or array, proposal covariance for random walk MH
            - num_integration_steps: int, integration steps for HMC
        rng_key : jax.random.key | None
            JAX random key for reproducibility.
        """
        self.sampler_kwargs = sampler_kwargs or {}
        self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
        self.sampler_kwargs.setdefault("algorithm", "nuts")
        self.sampler_kwargs.setdefault("step_size", 1e-3)
        self.sampler_kwargs.setdefault("inverse_mass_matrix", None)
        self.sampler_kwargs.setdefault("sigma", 0.1)  # For random walk MH

        # Initialize JAX random key
        if rng_key is None:
            import jax

            self.key = jax.random.key(42)
        else:
            self.key = rng_key

        return super().sample(
            n_samples,
            n_steps=n_steps,
            adaptive=adaptive,
            target_efficiency=target_efficiency,
            target_efficiency_rate=target_efficiency_rate,
            n_final_samples=n_final_samples,
        )

    def mutate(self, particles, beta, n_steps=None):
        """Mutate particles using BlackJAX MCMC."""
        import blackjax
        import jax

        logger.debug("Mutating particles with BlackJAX")

        # Split the random key
        self.key, subkey = jax.random.split(self.key)

        # Transform particles to latent space
        z = self.fit_preconditioning_transform(particles.x)

        # Convert to JAX arrays
        z_jax = jax.numpy.asarray(to_numpy(z))

        # Create log probability function for this beta
        log_prob_fn = partial(self._jax_log_prob, beta=beta)

        # Choose BlackJAX algorithm
        algorithm = self.sampler_kwargs["algorithm"].lower()

        n_steps = n_steps or self.sampler_kwargs["n_steps"]

        if algorithm == "rwmh" or algorithm == "random_walk":
            # Initialize Random Walk Metropolis-Hastings sampler
            sigma = self.sampler_kwargs.get("sigma", 0.1)

            # BlackJAX RMH expects a transition function, not a covariance
            if isinstance(sigma, (int, float)):
                # Create a multivariate normal proposal function
                def proposal_fn(key, position):
                    return position + sigma * jax.random.normal(
                        key, position.shape
                    )
            else:
                # For more complex covariance structures
                if len(sigma) == self.dims:
                    # Diagonal covariance
                    sigma_diag = jax.numpy.array(sigma)

                    def proposal_fn(key, position):
                        return position + sigma_diag * jax.random.normal(
                            key, position.shape
                        )
                else:
                    # Full covariance matrix
                    sigma_matrix = jax.numpy.array(sigma)

                    def proposal_fn(key, position):
                        return position + jax.random.multivariate_normal(
                            key, jax.numpy.zeros(self.dims), sigma_matrix
                        )

            rwmh = blackjax.rmh(log_prob_fn, proposal_fn)

            # Initialize states for each particle
            n_particles = z_jax.shape[0]
            keys = jax.random.split(subkey, n_particles)

            # Vectorized initialization and sampling
            def init_and_sample(key, z_init):
                state = rwmh.init(z_init)

                def one_step(state, key):
                    state, info = rwmh.step(key, state)
                    return state, (state, info)

                keys = jax.random.split(key, n_steps)
                final_state, (states, infos) = jax.lax.scan(
                    one_step, state, keys
                )
                return final_state, infos

            # Vectorize over particles
            final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)

            # Extract final positions
            z_final = final_states.position

            # Calculate acceptance rates
            acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
            mean_acceptance = jax.numpy.mean(acceptance_rates)

        elif algorithm == "nuts":
            # Initialize step size and mass matrix if not provided
            inverse_mass_matrix = self.sampler_kwargs["inverse_mass_matrix"]
            if inverse_mass_matrix is None:
                inverse_mass_matrix = jax.numpy.eye(self.dims)

            step_size = self.sampler_kwargs["step_size"]

            # Initialize NUTS sampler
            nuts = blackjax.nuts(
                log_prob_fn,
                step_size=step_size,
                inverse_mass_matrix=inverse_mass_matrix,
            )

            # Initialize states for each particle
            n_particles = z_jax.shape[0]
            keys = jax.random.split(subkey, n_particles)

            # Vectorized initialization and sampling
            def init_and_sample(key, z_init):
                state = nuts.init(z_init)

                def one_step(state, key):
                    state, info = nuts.step(key, state)
                    return state, (state, info)

                keys = jax.random.split(key, self.sampler_kwargs["n_steps"])
                final_state, (states, infos) = jax.lax.scan(
                    one_step, state, keys
                )
                return final_state, infos

            # Vectorize over particles
            final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)

            # Extract final positions
            z_final = final_states.position

            # Calculate acceptance rates
            acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
            mean_acceptance = jax.numpy.mean(acceptance_rates)

        elif algorithm == "hmc":
            # Initialize HMC sampler
            hmc = blackjax.hmc(
                log_prob_fn,
                step_size=self.sampler_kwargs["step_size"],
                num_integration_steps=self.sampler_kwargs.get(
                    "num_integration_steps", 10
                ),
                inverse_mass_matrix=(
                    self.sampler_kwargs["inverse_mass_matrix"]
                    or jax.numpy.eye(self.dims)
                ),
            )

            # Similar vectorized sampling as NUTS
            n_particles = z_jax.shape[0]
            keys = jax.random.split(subkey, n_particles)

            def init_and_sample(key, z_init):
                state = hmc.init(z_init)

                def one_step(state, key):
                    state, info = hmc.step(key, state)
                    return state, (state, info)

                keys = jax.random.split(key, self.sampler_kwargs["n_steps"])
                final_state, (states, infos) = jax.lax.scan(
                    one_step, state, keys
                )
                return final_state, infos

            final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
            z_final = final_states.position
            acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
            mean_acceptance = jax.numpy.mean(acceptance_rates)

        else:
            raise ValueError(f"Unsupported algorithm: {algorithm}")

        # Convert back to parameter space
        z_final_np = to_numpy(z_final)
        x_final = self.preconditioning_transform.inverse(z_final_np)[0]

        # Store MCMC diagnostics
        self.history.mcmc_acceptance.append(float(mean_acceptance))

        # Create new samples
        samples = SMCSamples(x_final, xp=self.xp, beta=beta, dtype=self.dtype)
        samples.log_q = samples.array_to_namespace(
            self.prior_flow.log_prob(samples.x)
        )
        samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )

        if samples.xp.isnan(samples.log_q).any():
            raise ValueError("Log proposal contains NaN values")

        return samples

    def _jax_log_prob(self, z, beta):
        """JAX-compatible log probability function."""
        import jax.numpy as jnp

        # Single particle version for JAX
        z_expanded = jnp.expand_dims(z, 0)  # Add batch dimension
        log_prob = self.log_prob(z_expanded, beta=beta)
        return log_prob[0]  # Remove batch dimension
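
The `mutate` method above runs one independent MCMC chain per particle by wrapping a `jax.lax.scan` over single kernel steps inside `jax.vmap`. The sketch below reproduces that vmap-over-scan pattern with a hand-written random-walk Metropolis step on a standard normal target; it is illustrative only and does not use BlackJAX or any of the package's classes.

# Standalone sketch (not part of aspire) of the vmap-over-scan pattern used
# in BlackJAXSMC.mutate: one MCMC chain per particle, vectorised with vmap.
import jax
import jax.numpy as jnp

def log_prob(z):
    # Standard normal log-density (up to a constant), standing in for the
    # tempered posterior targeted by the sampler.
    return -0.5 * jnp.sum(z**2)

def rwm_step(state, key, step_size=0.5):
    # One random-walk Metropolis step; `state` is (position, log_prob).
    z, log_p = state
    key_prop, key_acc = jax.random.split(key)
    z_new = z + step_size * jax.random.normal(key_prop, z.shape)
    log_p_new = log_prob(z_new)
    accept = jnp.log(jax.random.uniform(key_acc)) < (log_p_new - log_p)
    z = jnp.where(accept, z_new, z)
    log_p = jnp.where(accept, log_p_new, log_p)
    return (z, log_p), accept

def run_chain(key, z_init, n_steps=50):
    # Scan the step function over a chain of keys (cf. init_and_sample above).
    keys = jax.random.split(key, n_steps)
    (z_final, _), accepted = jax.lax.scan(
        rwm_step, (z_init, log_prob(z_init)), keys
    )
    return z_final, accepted.mean()

key = jax.random.key(0)
n_particles, dims = 128, 3
z0 = jax.random.normal(key, (n_particles, dims))
keys = jax.random.split(key, n_particles)
# Vectorise the whole chain over particles, as mutate does via jax.vmap.
z_final, acceptance = jax.vmap(run_chain)(keys, z0)
print(z_final.shape, float(acceptance.mean()))

BlackJAXSMC applies the same structure to the rmh, nuts, and hmc kernels: initialise a state per particle, scan the step function over split keys, and read the final positions and acceptance information from the scanned outputs.
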