aspire-inference 0.1.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +457 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +37 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +82 -0
- aspire/flows/jax/utils.py +54 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +276 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +92 -0
- aspire/samplers/importance.py +18 -0
- aspire/samplers/mcmc.py +158 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +312 -0
- aspire/samplers/smc/blackjax.py +330 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +476 -0
- aspire/transforms.py +491 -0
- aspire/utils.py +491 -0
- aspire_inference-0.1.0a2.dist-info/METADATA +48 -0
- aspire_inference-0.1.0a2.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a2.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a2.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from ...samples import SMCSamples
|
|
7
|
+
from ...utils import asarray, to_numpy, track_calls
|
|
8
|
+
from .base import SMCSampler
|
|
9
|
+
|
|
10
|
+
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BlackJAXSMC(SMCSampler):
    """SMC sampler whose mutation step runs BlackJAX MCMC kernels.

    Particles are mapped to the preconditioned space, moved with one of
    BlackJAX's ``rmh``, ``nuts`` or ``hmc`` kernels (one independent chain
    per particle, vectorised with ``jax.vmap``), and mapped back.
    """

    def __init__(
        self,
        log_likelihood,
        log_prior,
        dims,
        prior_flow,
        xp,
        parameters=None,
        preconditioning_transform=None,
        rng: np.random.Generator | None = None,
    ):
        """Set up the sampler.

        All arguments except ``rng`` are forwarded unchanged to
        ``SMCSampler.__init__``.

        Parameters
        ----------
        rng : numpy.random.Generator | None
            NumPy generator stored on the instance. It is used to seed the
            JAX key when :meth:`sample` is called without ``rng_key``. A
            fresh ``default_rng()`` is created when omitted.
        """
        # For JAX compatibility, the original xp is passed through untouched.
        super().__init__(
            log_likelihood=log_likelihood,
            log_prior=log_prior,
            dims=dims,
            prior_flow=prior_flow,
            xp=xp,
            parameters=parameters,
            preconditioning_transform=preconditioning_transform,
        )
        # The JAX key is created lazily in ``sample``.
        self.key = None
        self.rng = rng or np.random.default_rng()

    def log_prob(self, x, beta=None):
        """Evaluate the tempered target log-density in the preconditioned space.

        Parameters
        ----------
        x : array-like
            Batch of points in the preconditioned (latent) space.
        beta : float | None
            Temperature passed to ``SMCSamples.log_p_t``.

        Returns
        -------
        array
            ``log_p_t(beta)`` plus the log absolute Jacobian determinant of
            the inverse preconditioning transform, with NaNs mapped to
            ``-inf``.
        """
        # Convert to the sampler's array namespace when the input is array-like.
        if hasattr(x, "__array__"):
            x_original = asarray(x, self.xp)
        else:
            x_original = x

        # Transform back to parameter space.
        x_params, log_abs_det_jacobian = (
            self.preconditioning_transform.inverse(x_original)
        )
        samples = SMCSamples(x_params, xp=self.xp)

        # Populate every term needed by ``log_p_t``.
        log_q = self.prior_flow.log_prob(samples.x)
        samples.log_q = samples.array_to_namespace(log_q)
        samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )

        # Target density in the latent space includes the Jacobian term.
        log_prob = samples.log_p_t(
            beta=beta
        ).flatten() + samples.array_to_namespace(log_abs_det_jacobian)

        # NaN log-probabilities are treated as zero probability.
        log_prob = self.xp.where(
            self.xp.isnan(log_prob), -self.xp.inf, log_prob
        )

        return log_prob

    @track_calls
    def sample(
        self,
        n_samples: int,
        n_steps: int | None = None,
        adaptive: bool = True,
        target_efficiency: float = 0.5,
        target_efficiency_rate: float = 1.0,
        n_final_samples: int | None = None,
        sampler_kwargs: dict | None = None,
        rng_key=None,
    ):
        """Sample using BlackJAX SMC.

        Parameters
        ----------
        n_samples : int
            Number of samples to draw.
        n_steps : int
            Number of SMC steps.
        adaptive : bool
            Whether to use adaptive tempering.
        target_efficiency : float
            Target efficiency for adaptive tempering.
        target_efficiency_rate : float
            Forwarded unchanged to ``SMCSampler.sample``.
        n_final_samples : int | None
            Number of final samples to return.
        sampler_kwargs : dict | None
            Additional arguments for the BlackJAX sampler. The dictionary
            is copied, so the caller's object is never modified.
            - algorithm: str, one of "nuts", "hmc", "rwmh", "random_walk"
            - n_steps: int, number of MCMC steps per mutation
            - step_size: float, step size for HMC/NUTS
            - inverse_mass_matrix: array, inverse mass matrix
            - sigma: float or array, proposal scale for random walk MH
              (scalar, per-dimension vector, or full covariance matrix)
            - num_integration_steps: int, integration steps for HMC
        rng_key : jax.random.key | None
            JAX random key for reproducibility. When omitted, a key is
            seeded from the NumPy generator given to the constructor.
        """
        # Copy so the defaults below do not leak into the caller's dict.
        self.sampler_kwargs = dict(sampler_kwargs) if sampler_kwargs else {}
        self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
        self.sampler_kwargs.setdefault("algorithm", "nuts")
        self.sampler_kwargs.setdefault("step_size", 1e-3)
        self.sampler_kwargs.setdefault("inverse_mass_matrix", None)
        self.sampler_kwargs.setdefault("sigma", 0.1)  # For random walk MH

        if rng_key is None:
            import jax

            # Seed from the instance generator instead of a fixed constant,
            # so a seeded ``rng`` controls the MCMC randomness and repeated
            # runs do not reuse an identical random stream. The bound keeps
            # the seed inside the int32 range.
            seed = int(self.rng.integers(0, 2**31 - 1))
            self.key = jax.random.key(seed)
        else:
            self.key = rng_key

        return super().sample(
            n_samples,
            n_steps=n_steps,
            adaptive=adaptive,
            target_efficiency=target_efficiency,
            target_efficiency_rate=target_efficiency_rate,
            n_final_samples=n_final_samples,
        )

    def mutate(self, particles, beta, n_steps=None):
        """Mutate particles using BlackJAX MCMC.

        Parameters
        ----------
        particles : SMCSamples
            Current particles; only ``particles.x`` is read.
        beta : float
            Temperature of the target for this mutation.
        n_steps : int | None
            Number of MCMC steps per particle. Falls back to
            ``sampler_kwargs["n_steps"]`` when ``None`` (or zero).

        Returns
        -------
        SMCSamples
            Mutated particles with ``log_q``, ``log_prior`` and
            ``log_likelihood`` evaluated.

        Raises
        ------
        ValueError
            If the configured algorithm is not supported, if ``sigma`` has
            an unsupported shape, or if ``log_q`` contains NaN.
        """
        import blackjax
        import jax

        logger.debug("Mutating particles with BlackJAX")

        # Advance the stored key so successive mutations use fresh randomness.
        self.key, subkey = jax.random.split(self.key)

        # Transform particles to latent space and hand them to JAX.
        z = self.fit_preconditioning_transform(particles.x)
        z_jax = jax.numpy.asarray(to_numpy(z))

        # Log probability function for this beta (single-particle signature).
        log_prob_fn = partial(self._jax_log_prob, beta=beta)

        algorithm = self.sampler_kwargs["algorithm"].lower()
        n_steps = n_steps or self.sampler_kwargs["n_steps"]

        if algorithm in ("rwmh", "random_walk"):
            proposal_fn = self._build_random_walk_proposal(
                self.sampler_kwargs.get("sigma", 0.1)
            )
            # BlackJAX RMH expects a proposal function, not a covariance.
            kernel = blackjax.rmh(log_prob_fn, proposal_fn)
        elif algorithm == "nuts":
            kernel = blackjax.nuts(
                log_prob_fn,
                step_size=self.sampler_kwargs["step_size"],
                inverse_mass_matrix=self._resolve_inverse_mass_matrix(),
            )
        elif algorithm == "hmc":
            kernel = blackjax.hmc(
                log_prob_fn,
                step_size=self.sampler_kwargs["step_size"],
                num_integration_steps=self.sampler_kwargs.get(
                    "num_integration_steps", 10
                ),
                inverse_mass_matrix=self._resolve_inverse_mass_matrix(),
            )
        else:
            raise ValueError(f"Unsupported algorithm: {algorithm}")

        # Every algorithm honours the same effective ``n_steps``.
        final_states, all_infos = self._run_chains(
            kernel, subkey, z_jax, n_steps
        )
        mean_acceptance = self._mean_acceptance(all_infos)

        # Convert back to parameter space.
        z_final_np = to_numpy(final_states.position)
        x_final = self.preconditioning_transform.inverse(z_final_np)[0]

        # Store MCMC diagnostics.
        self.history.mcmc_acceptance.append(mean_acceptance)

        # Create new samples and evaluate all densities at the new positions.
        samples = SMCSamples(x_final, xp=self.xp, beta=beta)
        samples.log_q = samples.array_to_namespace(
            self.prior_flow.log_prob(samples.x)
        )
        samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )

        if samples.xp.isnan(samples.log_q).any():
            raise ValueError("Log proposal contains NaN values")

        return samples

    def _resolve_inverse_mass_matrix(self):
        """Return the configured inverse mass matrix, defaulting to identity."""
        import jax.numpy as jnp

        inverse_mass_matrix = self.sampler_kwargs.get("inverse_mass_matrix")
        # Explicit ``is None`` test: ``or`` would raise on array values.
        if inverse_mass_matrix is None:
            return jnp.eye(self.dims)
        return inverse_mass_matrix

    def _build_random_walk_proposal(self, sigma):
        """Build a Gaussian random-walk proposal ``(key, position) -> position``.

        ``sigma`` may be a scalar, a length-``dims`` vector of per-dimension
        scales, or a ``(dims, dims)`` covariance matrix. The case is chosen
        from the number of array dimensions, not from ``len(sigma)``, which
        cannot tell a vector from a square matrix.
        """
        import jax
        import jax.numpy as jnp

        ndim = np.ndim(sigma)

        if ndim == 0:

            def proposal_fn(key, position):
                return position + sigma * jax.random.normal(
                    key, position.shape
                )

            return proposal_fn

        sigma_array = jnp.asarray(sigma)

        if ndim == 1:
            if sigma_array.shape != (self.dims,):
                raise ValueError(
                    f"sigma vector must have shape ({self.dims},), "
                    f"got {sigma_array.shape}"
                )

            def proposal_fn(key, position):
                return position + sigma_array * jax.random.normal(
                    key, position.shape
                )

            return proposal_fn

        if sigma_array.shape != (self.dims, self.dims):
            raise ValueError(
                f"sigma matrix must have shape ({self.dims}, {self.dims}), "
                f"got {sigma_array.shape}"
            )
        zero_mean = jnp.zeros(self.dims)

        def proposal_fn(key, position):
            return position + jax.random.multivariate_normal(
                key, zero_mean, sigma_array
            )

        return proposal_fn

    @staticmethod
    def _run_chains(kernel, key, positions, n_steps):
        """Run one ``n_steps``-long chain per row of ``positions``.

        Returns the final kernel states and the per-step info objects, both
        batched over particles by ``jax.vmap``.
        """
        import jax

        def run_single_chain(chain_key, initial_position):
            def one_step(state, step_key):
                state, info = kernel.step(step_key, state)
                # Only the info is stacked; intermediate states are unused.
                return state, info

            step_keys = jax.random.split(chain_key, n_steps)
            return jax.lax.scan(
                one_step, kernel.init(initial_position), step_keys
            )

        chain_keys = jax.random.split(key, positions.shape[0])
        return jax.vmap(run_single_chain)(chain_keys, positions)

    @staticmethod
    def _mean_acceptance(infos):
        """Average acceptance over all particles and steps as a float.

        Uses ``is_accepted`` when the info object provides it and falls
        back to ``acceptance_rate`` otherwise.
        NOTE(review): the NUTS info object is believed not to expose
        ``is_accepted``; confirm against the installed BlackJAX version.
        """
        import jax.numpy as jnp

        accepted = getattr(infos, "is_accepted", None)
        if accepted is None:
            accepted = infos.acceptance_rate
        return float(jnp.mean(accepted))

    def _jax_log_prob(self, z, beta):
        """Single-particle wrapper around :meth:`log_prob` for BlackJAX."""
        import jax.numpy as jnp

        # ``log_prob`` is batched: add a batch axis, then strip it again.
        z_expanded = jnp.expand_dims(z, 0)
        log_prob = self.log_prob(z_expanded, beta=beta)
        return log_prob[0]
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from ...samples import SMCSamples
|
|
7
|
+
from ...utils import track_calls
|
|
8
|
+
from .base import NumpySMCSampler
|
|
9
|
+
|
|
10
|
+
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class EmceeSMC(NumpySMCSampler):
    """SMC sampler whose mutation step runs an emcee ``EnsembleSampler``."""

    @track_calls
    def sample(
        self,
        n_samples: int,
        n_steps: int | None = None,
        adaptive: bool = True,
        target_efficiency: float = 0.5,
        target_efficiency_rate: float = 1.0,
        sampler_kwargs: dict | None = None,
        n_final_samples: int | None = None,
    ):
        """Configure the emcee options and run the base-class SMC loop.

        Parameters
        ----------
        sampler_kwargs : dict | None
            Keyword arguments for ``EnsembleSampler.run_mcmc``. Defaults
            ``nsteps`` to ``5 * dims`` and ``progress`` to ``True``. An
            optional ``"moves"`` entry is removed and passed to the
            ``EnsembleSampler`` constructor instead. The dictionary is
            copied, so the caller's object is never modified.

        All other arguments are forwarded unchanged to the base class
        ``sample``.
        """
        # Copy first: popping "moves" from the caller's dict would drop it
        # on any later call that reuses the same dictionary.
        self.sampler_kwargs = dict(sampler_kwargs) if sampler_kwargs else {}
        self.sampler_kwargs.setdefault("nsteps", 5 * self.dims)
        self.sampler_kwargs.setdefault("progress", True)
        self.emcee_moves = self.sampler_kwargs.pop("moves", None)
        return super().sample(
            n_samples,
            n_steps=n_steps,
            adaptive=adaptive,
            target_efficiency=target_efficiency,
            target_efficiency_rate=target_efficiency_rate,
            n_final_samples=n_final_samples,
        )

    def mutate(self, particles, beta, n_steps=None):
        """Move the particles with emcee, using them as the walker ensemble.

        Parameters
        ----------
        particles : SMCSamples
            Current particles; only ``particles.x`` is read.
        beta : float
            Temperature passed to ``log_prob`` as an extra argument.
        n_steps : int | None
            Overrides ``sampler_kwargs["nsteps"]`` for this call when given.

        Returns
        -------
        SMCSamples
            Final walker positions mapped back to parameter space, with
            ``log_q``, ``log_prior`` and ``log_likelihood`` evaluated.

        Raises
        ------
        ValueError
            If ``log_q`` contains NaN.
        """
        import emcee

        logger.info("Mutating particles")
        sampler = emcee.EnsembleSampler(
            len(particles.x),
            self.dims,
            self.log_prob,
            args=(beta,),
            vectorize=True,
            moves=self.emcee_moves,
        )
        z = self.fit_preconditioning_transform(particles.x)

        # Work on a copy so the per-call override does not persist.
        kwargs = copy.deepcopy(self.sampler_kwargs)
        if n_steps is not None:
            kwargs["nsteps"] = n_steps
        sampler.run_mcmc(z, **kwargs)

        self.history.mcmc_acceptance.append(
            np.mean(sampler.acceptance_fraction)
        )
        # Discard 20% of the steps that were actually run, i.e. after the
        # ``n_steps`` override, not 20% of the configured default.
        self.history.mcmc_autocorr.append(
            sampler.get_autocorr_time(
                quiet=True, discard=int(0.2 * kwargs["nsteps"])
            )
        )

        # Keep only the last step of the (unflattened) chain.
        z = sampler.get_chain(flat=False)[-1, ...]
        x = self.preconditioning_transform.inverse(z)[0]

        samples = SMCSamples(x, xp=self.xp, beta=beta)
        samples.log_q = samples.array_to_namespace(
            self.prior_flow.log_prob(samples.x)
        )
        samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )
        if samples.xp.isnan(samples.log_q).any():
            raise ValueError("Log proposal contains NaN values")
        return samples
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ...samples import SMCSamples
|
|
6
|
+
from ...utils import to_numpy, track_calls
|
|
7
|
+
from .base import NumpySMCSampler
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MiniPCNSMC(NumpySMCSampler):
    """MiniPCN SMC sampler."""

    # NumPy generator handed to minipcn; set on each call to ``sample``.
    rng = None

    def log_prob(self, x, beta=None):
        """Return the base-class log-probability converted with ``to_numpy``."""
        return to_numpy(super().log_prob(x, beta))

    @track_calls
    def sample(
        self,
        n_samples: int,
        n_steps: int | None = None,
        min_step: float | None = None,
        max_n_steps: int | None = None,
        adaptive: bool = True,
        target_efficiency: float = 0.5,
        target_efficiency_rate: float = 1.0,
        n_final_samples: int | None = None,
        sampler_kwargs: dict | None = None,
        rng: np.random.Generator | None = None,
    ):
        """Configure the minipcn options and run the base-class SMC loop.

        Parameters
        ----------
        sampler_kwargs : dict | None
            Options for the minipcn sampler. Defaults ``n_steps`` to
            ``5 * dims``, ``target_acceptance_rate`` to ``0.234`` and
            ``step_fn`` to ``"tpcn"``. The dictionary is copied, so the
            caller's object is never modified.
        rng : numpy.random.Generator | None
            Generator passed to minipcn; a fresh ``default_rng()`` is
            created when omitted.

        All other arguments are forwarded unchanged to the base class
        ``sample``.
        """
        # Copy so the defaults below do not leak into the caller's dict.
        self.sampler_kwargs = dict(sampler_kwargs) if sampler_kwargs else {}
        self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
        self.sampler_kwargs.setdefault("target_acceptance_rate", 0.234)
        self.sampler_kwargs.setdefault("step_fn", "tpcn")
        self.rng = rng or np.random.default_rng()
        return super().sample(
            n_samples,
            n_steps=n_steps,
            adaptive=adaptive,
            target_efficiency=target_efficiency,
            target_efficiency_rate=target_efficiency_rate,
            n_final_samples=n_final_samples,
            min_step=min_step,
            max_n_steps=max_n_steps,
        )

    def mutate(self, particles, beta, n_steps=None):
        """Move the particles with a minipcn ``Sampler``.

        Parameters
        ----------
        particles : SMCSamples
            Current particles; only ``particles.x`` is read.
        beta : float
            Temperature bound into the log-probability function.
        n_steps : int | None
            Number of steps; falls back to ``sampler_kwargs["n_steps"]``
            when ``None`` (or zero).

        Returns
        -------
        SMCSamples
            Last chain state mapped back to parameter space, with
            ``log_q``, ``log_prior`` and ``log_likelihood`` evaluated.

        Raises
        ------
        ValueError
            If ``log_q`` contains NaN.
        """
        from minipcn import Sampler

        log_prob_fn = partial(self.log_prob, beta=beta)

        sampler = Sampler(
            log_prob_fn=log_prob_fn,
            step_fn=self.sampler_kwargs["step_fn"],
            rng=self.rng,
            dims=self.dims,
            target_acceptance_rate=self.sampler_kwargs[
                "target_acceptance_rate"
            ],
        )
        # Map to the transformed space for sampling.
        z = to_numpy(self.fit_preconditioning_transform(particles.x))
        chain, history = sampler.sample(
            z,
            n_steps=n_steps or self.sampler_kwargs["n_steps"],
        )
        # Only the final chain state becomes the new particle set.
        x = self.preconditioning_transform.inverse(chain[-1])[0]

        self.history.mcmc_acceptance.append(np.mean(history.acceptance_rate))

        samples = SMCSamples(x, xp=self.xp, beta=beta)
        samples.log_q = samples.array_to_namespace(
            self.prior_flow.log_prob(samples.x)
        )
        samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )
        if samples.xp.isnan(samples.log_q).any():
            raise ValueError("Log proposal contains NaN values")
        return samples
|