aspire-inference 0.1.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,330 @@
+ import logging
+ from functools import partial
+
+ import numpy as np
+
+ from ...samples import SMCSamples
+ from ...utils import asarray, to_numpy, track_calls
+ from .base import SMCSampler
+
+ logger = logging.getLogger(__name__)
+
+
+ class BlackJAXSMC(SMCSampler):
+     """BlackJAX SMC sampler."""
+
+     def __init__(
+         self,
+         log_likelihood,
+         log_prior,
+         dims,
+         prior_flow,
+         xp,
+         parameters=None,
+         preconditioning_transform=None,
+         rng: np.random.Generator | None = None,
+     ):
+         # Keep the original array namespace (xp) for JAX compatibility
+         super().__init__(
+             log_likelihood=log_likelihood,
+             log_prior=log_prior,
+             dims=dims,
+             prior_flow=prior_flow,
+             xp=xp,
+             parameters=parameters,
+             preconditioning_transform=preconditioning_transform,
+         )
+         self.key = None
+         self.rng = rng or np.random.default_rng()
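For orientation, a minimal construction sketch. This is hedged: `flow` and the likelihood/prior callables below are illustrative placeholders, and only the interfaces this class actually calls (`prior_flow.log_prob(x)`, callables taking `SMCSamples`) are assumed.

    import numpy as np

    def log_likelihood(samples):
        # Illustrative standard-normal likelihood over the sample positions
        return -0.5 * np.sum(samples.x**2, axis=-1)

    def log_prior(samples):
        # Illustrative flat prior
        return np.zeros(len(samples.x))

    smc = BlackJAXSMC(
        log_likelihood=log_likelihood,
        log_prior=log_prior,
        dims=2,
        prior_flow=flow,  # placeholder: must expose .log_prob(x)
        xp=np,
        rng=np.random.default_rng(0),
    )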
+
+     def log_prob(self, x, beta=None):
+         """Log probability function compatible with BlackJAX."""
+         # Convert to the original xp format for computation
+         if hasattr(x, "__array__"):
+             x_original = asarray(x, self.xp)
+         else:
+             x_original = x
+
+         # Transform back to parameter space
+         x_params, log_abs_det_jacobian = (
+             self.preconditioning_transform.inverse(x_original)
+         )
+         samples = SMCSamples(x_params, xp=self.xp)
+
+         # Compute log probabilities
+         log_q = self.prior_flow.log_prob(samples.x)
+         samples.log_q = samples.array_to_namespace(log_q)
+         samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
+         samples.log_likelihood = samples.array_to_namespace(
+             self.log_likelihood(samples)
+         )
+
+         # Compute the target log probability, including the Jacobian term
+         log_prob = samples.log_p_t(
+             beta=beta
+         ).flatten() + samples.array_to_namespace(log_abs_det_jacobian)
+
+         # Replace NaN values with -inf so such proposals are rejected
+         log_prob = self.xp.where(
+             self.xp.isnan(log_prob), -self.xp.inf, log_prob
+         )
+
+         return log_prob
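Note: `log_p_t` is provided by `SMCSamples`. Under the usual likelihood-tempering scheme (an assumption stated for orientation, not taken from this diff), the tempered target is

    log p_t(x) = log_prior(x) + beta * log_likelihood(x)

so `beta=0` recovers the prior and `beta=1` the full posterior; the `log_abs_det_jacobian` term accounts for evaluating the density in the preconditioned space.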
+
+     @track_calls
+     def sample(
+         self,
+         n_samples: int,
+         n_steps: int | None = None,
+         adaptive: bool = True,
+         target_efficiency: float = 0.5,
+         target_efficiency_rate: float = 1.0,
+         n_final_samples: int | None = None,
+         sampler_kwargs: dict | None = None,
+         rng_key=None,
+     ):
+         """Sample using BlackJAX SMC.
+
+         Parameters
+         ----------
+         n_samples : int
+             Number of samples to draw.
+         n_steps : int | None
+             Number of SMC steps.
+         adaptive : bool
+             Whether to use adaptive tempering.
+         target_efficiency : float
+             Target efficiency for adaptive tempering.
+         target_efficiency_rate : float
+             Rate used when adapting towards the target efficiency.
+         n_final_samples : int | None
+             Number of final samples to return.
+         sampler_kwargs : dict | None
+             Additional arguments for the BlackJAX sampler:
+
+             - algorithm: str, one of "nuts", "hmc", "rwmh", "random_walk"
+             - n_steps: int, number of MCMC steps per mutation
+             - step_size: float, step size for HMC/NUTS
+             - inverse_mass_matrix: array, inverse mass matrix
+             - sigma: float or array, proposal scale (per-dimension) or
+               full covariance for random-walk MH
+             - num_integration_steps: int, integration steps for HMC
+         rng_key : jax.random.key | None
+             JAX random key for reproducibility.
+         """
+         self.sampler_kwargs = sampler_kwargs or {}
+         self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
+         self.sampler_kwargs.setdefault("algorithm", "nuts")
+         self.sampler_kwargs.setdefault("step_size", 1e-3)
+         self.sampler_kwargs.setdefault("inverse_mass_matrix", None)
+         self.sampler_kwargs.setdefault("sigma", 0.1)  # For random-walk MH
+
+         # Initialize the JAX random key
+         if rng_key is None:
+             import jax
+
+             # Derive the key from the instance RNG so that runs seeded
+             # via ``rng`` remain reproducible
+             self.key = jax.random.key(int(self.rng.integers(2**32)))
+         else:
+             self.key = rng_key
+
+         return super().sample(
+             n_samples,
+             n_steps=n_steps,
+             adaptive=adaptive,
+             target_efficiency=target_efficiency,
+             target_efficiency_rate=target_efficiency_rate,
+             n_final_samples=n_final_samples,
+         )
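A hedged usage sketch for this method (assumes an instance `smc` constructed as above; the keys shown are the ones documented in the docstring):

    import jax

    samples = smc.sample(
        n_samples=1000,
        sampler_kwargs={
            "algorithm": "nuts",
            "step_size": 1e-2,
            "n_steps": 20,
        },
        rng_key=jax.random.key(0),
    )

    # Random-walk variant; sigma may be a float, a length-dims vector,
    # or a full (dims, dims) covariance matrix:
    samples = smc.sample(
        n_samples=1000,
        sampler_kwargs={"algorithm": "rwmh", "sigma": 0.5},
    )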
+
+     def mutate(self, particles, beta, n_steps=None):
+         """Mutate particles using BlackJAX MCMC."""
+         import blackjax
+         import jax
+
+         logger.debug("Mutating particles with BlackJAX")
+
+         # Split the random key
+         self.key, subkey = jax.random.split(self.key)
+
+         # Transform particles to the latent space
+         z = self.fit_preconditioning_transform(particles.x)
+
+         # Convert to JAX arrays
+         z_jax = jax.numpy.asarray(to_numpy(z))
+
+         # Create the log probability function for this beta
+         log_prob_fn = partial(self._jax_log_prob, beta=beta)
+
+         # Choose the BlackJAX algorithm
+         algorithm = self.sampler_kwargs["algorithm"].lower()
+
+         n_steps = n_steps or self.sampler_kwargs["n_steps"]
+
+         if algorithm in ("rwmh", "random_walk"):
+             # Initialize the random-walk Metropolis-Hastings sampler
+             sigma = self.sampler_kwargs.get("sigma", 0.1)
+
+             # BlackJAX RMH expects a proposal function, not a covariance
+             if isinstance(sigma, (int, float)):
+                 # Isotropic Gaussian proposal
+                 def proposal_fn(key, position):
+                     return position + sigma * jax.random.normal(
+                         key, position.shape
+                     )
+             else:
+                 # For array-valued sigma, dispatch on dimensionality rather
+                 # than length: a full (dims, dims) matrix also has
+                 # len == dims, so a length check cannot distinguish the two
+                 sigma_arr = jax.numpy.asarray(sigma)
+                 if sigma_arr.ndim == 1:
+                     # Diagonal proposal (per-dimension scales)
+                     def proposal_fn(key, position):
+                         return position + sigma_arr * jax.random.normal(
+                             key, position.shape
+                         )
+                 else:
+                     # Full covariance matrix
+                     def proposal_fn(key, position):
+                         return position + jax.random.multivariate_normal(
+                             key, jax.numpy.zeros(self.dims), sigma_arr
+                         )
+
+             rwmh = blackjax.rmh(log_prob_fn, proposal_fn)
+
+             # Initialize states for each particle
+             n_particles = z_jax.shape[0]
+             keys = jax.random.split(subkey, n_particles)
+
+             # Vectorized initialization and sampling
+             def init_and_sample(key, z_init):
+                 state = rwmh.init(z_init)
+
+                 def one_step(state, key):
+                     state, info = rwmh.step(key, state)
+                     return state, (state, info)
+
+                 keys = jax.random.split(key, n_steps)
+                 final_state, (states, infos) = jax.lax.scan(
+                     one_step, state, keys
+                 )
+                 return final_state, infos
+
+             # Vectorize over particles
+             final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
+
+             # Extract final positions
+             z_final = final_states.position
+
+             # Calculate acceptance rates
+             acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
+             mean_acceptance = jax.numpy.mean(acceptance_rates)
+
+         elif algorithm == "nuts":
+             # Use the identity if no inverse mass matrix is provided
+             inverse_mass_matrix = self.sampler_kwargs["inverse_mass_matrix"]
+             if inverse_mass_matrix is None:
+                 inverse_mass_matrix = jax.numpy.eye(self.dims)
+
+             step_size = self.sampler_kwargs["step_size"]
+
+             # Initialize the NUTS sampler
+             nuts = blackjax.nuts(
+                 log_prob_fn,
+                 step_size=step_size,
+                 inverse_mass_matrix=inverse_mass_matrix,
+             )
+
+             # Initialize states for each particle
+             n_particles = z_jax.shape[0]
+             keys = jax.random.split(subkey, n_particles)
+
+             # Vectorized initialization and sampling
+             def init_and_sample(key, z_init):
+                 state = nuts.init(z_init)
+
+                 def one_step(state, key):
+                     state, info = nuts.step(key, state)
+                     return state, (state, info)
+
+                 # Use the resolved n_steps so an explicit argument to
+                 # mutate() is respected
+                 keys = jax.random.split(key, n_steps)
+                 final_state, (states, infos) = jax.lax.scan(
+                     one_step, state, keys
+                 )
+                 return final_state, infos
+
+             # Vectorize over particles
+             final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
+
+             # Extract final positions
+             z_final = final_states.position
+
+             # Calculate acceptance rates
+             acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
+             mean_acceptance = jax.numpy.mean(acceptance_rates)
+
+         elif algorithm == "hmc":
+             # Use the identity if no inverse mass matrix is provided; an
+             # explicit None check avoids the ambiguous truth value raised
+             # by ``array or ...`` when a matrix is supplied
+             inverse_mass_matrix = self.sampler_kwargs["inverse_mass_matrix"]
+             if inverse_mass_matrix is None:
+                 inverse_mass_matrix = jax.numpy.eye(self.dims)
+
+             # Initialize the HMC sampler
+             hmc = blackjax.hmc(
+                 log_prob_fn,
+                 step_size=self.sampler_kwargs["step_size"],
+                 num_integration_steps=self.sampler_kwargs.get(
+                     "num_integration_steps", 10
+                 ),
+                 inverse_mass_matrix=inverse_mass_matrix,
+             )
+
+             # Similar vectorized sampling as for NUTS
+             n_particles = z_jax.shape[0]
+             keys = jax.random.split(subkey, n_particles)
+
+             def init_and_sample(key, z_init):
+                 state = hmc.init(z_init)
+
+                 def one_step(state, key):
+                     state, info = hmc.step(key, state)
+                     return state, (state, info)
+
+                 keys = jax.random.split(key, n_steps)
+                 final_state, (states, infos) = jax.lax.scan(
+                     one_step, state, keys
+                 )
+                 return final_state, infos
+
+             final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
+             z_final = final_states.position
+             acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
+             mean_acceptance = jax.numpy.mean(acceptance_rates)
+
+         else:
+             raise ValueError(f"Unsupported algorithm: {algorithm}")
+
+         # Convert back to parameter space
+         z_final_np = to_numpy(z_final)
+         x_final = self.preconditioning_transform.inverse(z_final_np)[0]
+
+         # Store MCMC diagnostics
+         self.history.mcmc_acceptance.append(float(mean_acceptance))
+
+         # Create new samples
+         samples = SMCSamples(x_final, xp=self.xp, beta=beta)
+         samples.log_q = samples.array_to_namespace(
+             self.prior_flow.log_prob(samples.x)
+         )
+         samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
+         samples.log_likelihood = samples.array_to_namespace(
+             self.log_likelihood(samples)
+         )
+
+         if samples.xp.isnan(samples.log_q).any():
+             raise ValueError("Log proposal contains NaN values")
+
+         return samples
+
+     def _jax_log_prob(self, z, beta):
+         """JAX-compatible log probability function."""
+         import jax.numpy as jnp
+
+         # Single-particle version for JAX
+         z_expanded = jnp.expand_dims(z, 0)  # Add batch dimension
+         log_prob = self.log_prob(z_expanded, beta=beta)
+         return log_prob[0]  # Remove batch dimension
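`_jax_log_prob` adapts the batched `log_prob` to the per-particle signature the BlackJAX kernels expect. A sketch of the equivalence (hypothetical values; assumes an instance `smc` as above):

    import jax.numpy as jnp

    z = jnp.zeros(smc.dims)
    # Both lines evaluate the same tempered density for one particle
    single = smc._jax_log_prob(z, beta=0.5)
    batched = smc.log_prob(jnp.expand_dims(z, 0), beta=0.5)[0]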
@@ -0,0 +1,75 @@
+ import copy
+ import logging
+
+ import numpy as np
+
+ from ...samples import SMCSamples
+ from ...utils import track_calls
+ from .base import NumpySMCSampler
+
+ logger = logging.getLogger(__name__)
+
+
+ class EmceeSMC(NumpySMCSampler):
+     """SMC sampler that uses emcee for the mutation steps."""
+
+     @track_calls
+     def sample(
+         self,
+         n_samples: int,
+         n_steps: int | None = None,
+         adaptive: bool = True,
+         target_efficiency: float = 0.5,
+         target_efficiency_rate: float = 1.0,
+         sampler_kwargs: dict | None = None,
+         n_final_samples: int | None = None,
+     ):
+         self.sampler_kwargs = sampler_kwargs or {}
+         self.sampler_kwargs.setdefault("nsteps", 5 * self.dims)
+         self.sampler_kwargs.setdefault("progress", True)
+         # ``moves`` is forwarded to emcee.EnsembleSampler, not run_mcmc
+         self.emcee_moves = self.sampler_kwargs.pop("moves", None)
+         return super().sample(
+             n_samples,
+             n_steps=n_steps,
+             adaptive=adaptive,
+             target_efficiency=target_efficiency,
+             target_efficiency_rate=target_efficiency_rate,
+             n_final_samples=n_final_samples,
+         )
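A hedged usage sketch (assumes an `EmceeSMC` instance `smc`; since `moves` is forwarded to `emcee.EnsembleSampler`, any emcee move can be passed):

    import emcee

    samples = smc.sample(
        n_samples=500,
        sampler_kwargs={
            "nsteps": 100,
            "progress": False,
            "moves": emcee.moves.StretchMove(),
        },
    )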
+
+     def mutate(self, particles, beta, n_steps=None):
+         """Mutate particles using emcee."""
+         import emcee
+
+         logger.info("Mutating particles")
+         sampler = emcee.EnsembleSampler(
+             len(particles.x),
+             self.dims,
+             self.log_prob,
+             args=(beta,),
+             vectorize=True,
+             moves=self.emcee_moves,
+         )
+         # Map to the transformed (latent) space for sampling
+         z = self.fit_preconditioning_transform(particles.x)
+         kwargs = copy.deepcopy(self.sampler_kwargs)
+         if n_steps is not None:
+             kwargs["nsteps"] = n_steps
+         sampler.run_mcmc(z, **kwargs)
+         self.history.mcmc_acceptance.append(
+             np.mean(sampler.acceptance_fraction)
+         )
+         # Discard the first 20% of the steps that were actually run
+         self.history.mcmc_autocorr.append(
+             sampler.get_autocorr_time(
+                 quiet=True, discard=int(0.2 * kwargs["nsteps"])
+             )
+         )
+         z = sampler.get_chain(flat=False)[-1, ...]
+         x = self.preconditioning_transform.inverse(z)[0]
+         samples = SMCSamples(x, xp=self.xp, beta=beta)
+         samples.log_q = samples.array_to_namespace(
+             self.prior_flow.log_prob(samples.x)
+         )
+         samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
+         samples.log_likelihood = samples.array_to_namespace(
+             self.log_likelihood(samples)
+         )
+         if samples.xp.isnan(samples.log_q).any():
+             raise ValueError("Log proposal contains NaN values")
+         return samples
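After a run, the per-mutation diagnostics accumulated above can be inspected (a sketch; `history` is populated by `mutate` as shown):

    print(smc.history.mcmc_acceptance)  # mean acceptance fraction per mutation
    print(smc.history.mcmc_autocorr)    # autocorrelation-time estimates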
@@ -0,0 +1,82 @@
+ from functools import partial
+
+ import numpy as np
+
+ from ...samples import SMCSamples
+ from ...utils import to_numpy, track_calls
+ from .base import NumpySMCSampler
+
+
+ class MiniPCNSMC(NumpySMCSampler):
+     """MiniPCN SMC sampler."""
+
+     rng = None
+
+     def log_prob(self, x, beta=None):
+         """Log probability returned as a NumPy array for minipcn."""
+         return to_numpy(super().log_prob(x, beta))
+
+     @track_calls
+     def sample(
+         self,
+         n_samples: int,
+         n_steps: int | None = None,
+         min_step: float | None = None,
+         max_n_steps: int | None = None,
+         adaptive: bool = True,
+         target_efficiency: float = 0.5,
+         target_efficiency_rate: float = 1.0,
+         n_final_samples: int | None = None,
+         sampler_kwargs: dict | None = None,
+         rng: np.random.Generator | None = None,
+     ):
+         self.sampler_kwargs = sampler_kwargs or {}
+         self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
+         self.sampler_kwargs.setdefault("target_acceptance_rate", 0.234)
+         self.sampler_kwargs.setdefault("step_fn", "tpcn")
+         self.rng = rng or np.random.default_rng()
+         return super().sample(
+             n_samples,
+             n_steps=n_steps,
+             adaptive=adaptive,
+             target_efficiency=target_efficiency,
+             target_efficiency_rate=target_efficiency_rate,
+             n_final_samples=n_final_samples,
+             min_step=min_step,
+             max_n_steps=max_n_steps,
+         )
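A hedged usage sketch (assumes a `MiniPCNSMC` instance `smc`; the kwargs shown mirror the `setdefault` calls above):

    import numpy as np

    samples = smc.sample(
        n_samples=500,
        sampler_kwargs={"step_fn": "tpcn", "target_acceptance_rate": 0.234},
        rng=np.random.default_rng(1),
    )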
+
+     def mutate(self, particles, beta, n_steps=None):
+         """Mutate particles using minipcn."""
+         from minipcn import Sampler
+
+         log_prob_fn = partial(self.log_prob, beta=beta)
+
+         sampler = Sampler(
+             log_prob_fn=log_prob_fn,
+             step_fn=self.sampler_kwargs["step_fn"],
+             rng=self.rng,
+             dims=self.dims,
+             target_acceptance_rate=self.sampler_kwargs[
+                 "target_acceptance_rate"
+             ],
+         )
+         # Map to the transformed dimension for sampling
+         z = to_numpy(self.fit_preconditioning_transform(particles.x))
+         chain, history = sampler.sample(
+             z,
+             n_steps=n_steps or self.sampler_kwargs["n_steps"],
+         )
+         # Keep the final state of the chain and map back to parameter space
+         x = self.preconditioning_transform.inverse(chain[-1])[0]
+
+         self.history.mcmc_acceptance.append(np.mean(history.acceptance_rate))
+
+         samples = SMCSamples(x, xp=self.xp, beta=beta)
+         samples.log_q = samples.array_to_namespace(
+             self.prior_flow.log_prob(samples.x)
+         )
+         samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
+         samples.log_likelihood = samples.array_to_namespace(
+             self.log_likelihood(samples)
+         )
+         if samples.xp.isnan(samples.log_q).any():
+             raise ValueError("Log proposal contains NaN values")
+         return samples