aspire-inference 0.1.0a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,318 @@
+ import logging
+ from typing import Any, Callable
+
+ import array_api_compat.numpy as np
+
+ from ...flows.base import Flow
+ from ...history import SMCHistory
+ from ...samples import SMCSamples
+ from ...utils import (
+     asarray,
+     effective_sample_size,
+     track_calls,
+     update_at_indices,
+ )
+ from ..mcmc import MCMCSampler
+
+ logger = logging.getLogger(__name__)
+
+
+ class SMCSampler(MCMCSampler):
+     """Base class for Sequential Monte Carlo samplers."""
+
+     def __init__(
+         self,
+         log_likelihood: Callable,
+         log_prior: Callable,
+         dims: int,
+         prior_flow: Flow,
+         xp: Any,
+         dtype: Any | str | None = None,
+         parameters: list[str] | None = None,
+         rng: np.random.Generator | None = None,
+         preconditioning_transform: Callable | None = None,
+     ):
+         super().__init__(
+             log_likelihood=log_likelihood,
+             log_prior=log_prior,
+             dims=dims,
+             prior_flow=prior_flow,
+             xp=xp,
+             dtype=dtype,
+             parameters=parameters,
+             preconditioning_transform=preconditioning_transform,
+         )
+         self.rng = rng or np.random.default_rng()
+         self._adaptive_target_efficiency = False
+
+     @property
+     def target_efficiency(self):
+         return self._target_efficiency
+
+     @target_efficiency.setter
+     def target_efficiency(self, value: float | tuple):
+         """Set the target efficiency.
+
+         Parameters
+         ----------
+         value : float or tuple
+             If a float, the target efficiency to use for all iterations.
+             If a tuple of two floats, the target efficiency will adapt from
+             the first value to the second value over the course of the SMC
+             iterations. See `target_efficiency_rate` for details.
+         """
+         if isinstance(value, float):
+             if not (0 < value < 1):
+                 raise ValueError("target_efficiency must be in (0, 1)")
+             self._target_efficiency = value
+             self._adaptive_target_efficiency = False
+         elif not isinstance(value, (tuple, list)) or len(value) != 2:
+             raise ValueError(
+                 "target_efficiency must be a float or tuple of two floats"
+             )
+         else:
+             value = tuple(map(float, value))
+             if not (0 < value[0] < value[1] < 1):
+                 raise ValueError(
+                     "target_efficiency tuple must be in (0, 1) and increasing"
+                 )
+             self._target_efficiency = value
+             self._adaptive_target_efficiency = True
+
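+     # Usage sketch for the setter (illustrative values):
+     #   sampler.target_efficiency = 0.5         # fixed target for all iterations
+     #   sampler.target_efficiency = (0.2, 0.8)  # adapts from 0.2 to 0.8 with beta
+     #   sampler.target_efficiency = 1.5         # raises ValueError
+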
+     def current_target_efficiency(self, beta: float) -> float:
+         """Get the current target efficiency based on beta."""
+         if self._adaptive_target_efficiency:
+             return self._target_efficiency[0] + (
+                 self._target_efficiency[1] - self._target_efficiency[0]
+             ) * (beta**self.target_efficiency_rate)
+         else:
+             return self._target_efficiency
+
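+     # Worked example (assuming target_efficiency_rate=1.0): with
+     # target_efficiency=(0.3, 0.9) the target interpolates linearly in beta,
+     # so beta=0.0 -> 0.3, beta=0.5 -> 0.6, beta=1.0 -> 0.9.
+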
+     def determine_beta(
+         self,
+         samples: SMCSamples,
+         beta: float,
+         beta_step: float,
+         min_step: float,
+     ) -> tuple[float, float]:
+         """Determine the next beta value.
+
+         Parameters
+         ----------
+         samples : SMCSamples
+             The current samples.
+         beta : float
+             The current beta value.
+         beta_step : float
+             The fixed beta step size if not adaptive.
+         min_step : float
+             The minimum beta step size.
+
+         Returns
+         -------
+         beta : float
+             The new beta value.
+         min_step : float
+             The new minimum step size if adaptive_min_step is True.
+         """
+         if not self.adaptive:
+             beta += beta_step
+             if beta >= 1.0:
+                 beta = 1.0
+         else:
+             beta_prev = beta
+             beta_min = beta_prev
+             beta_max = 1.0
+             tol = 1e-5
+             target_eff = self.current_target_efficiency(beta_prev)
+             eff_beta_max = effective_sample_size(
+                 samples.log_weights(beta_max)
+             ) / len(samples)
+             if eff_beta_max >= target_eff:
+                 # beta = 1.0 already meets the target, so skip the bisection
+                 beta_min = 1.0
+             while beta_max - beta_min > tol:
+                 beta_try = 0.5 * (beta_max + beta_min)
+                 eff = effective_sample_size(
+                     samples.log_weights(beta_try)
+                 ) / len(samples)
+                 if eff >= target_eff:
+                     beta_min = beta_try
+                 else:
+                     beta_max = beta_try
+             beta_star = beta_min
+
+             # Guard against division by zero when beta_star reaches 1.0
+             if self.adaptive_min_step and beta_star < 1.0:
+                 min_step = min_step * (1 - beta_prev) / (1 - beta_star)
+             beta = max(beta_star, beta_prev + min_step)
+             beta = min(beta, 1.0)
+         return beta, min_step
+
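+     # Bisection intuition (illustrative): if ESS/N at beta=1 already meets
+     # the target, beta jumps straight to 1; otherwise the loop finds, within
+     # tol, the largest beta whose importance weights keep ESS/N at or above
+     # the target.
+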
+     @track_calls
+     def sample(
+         self,
+         n_samples: int,
+         n_steps: int | None = None,
+         adaptive: bool = True,
+         min_step: float | None = None,
+         max_n_steps: int | None = None,
+         target_efficiency: float | tuple = 0.5,
+         target_efficiency_rate: float = 1.0,
+         n_final_samples: int | None = None,
+     ) -> SMCSamples:
+         """Run the SMC sampler and return the final samples."""
+         samples = self.draw_initial_samples(n_samples)
+         samples = SMCSamples.from_samples(
+             samples, xp=self.xp, beta=0.0, dtype=self.dtype
+         )
+         self.fit_preconditioning_transform(samples.x)
+
+         if self.xp.isnan(samples.log_q).any():
+             raise ValueError("Log proposal contains NaN values")
+         if self.xp.isnan(samples.log_prior).any():
+             raise ValueError("Log prior contains NaN values")
+         if self.xp.isnan(samples.log_likelihood).any():
+             raise ValueError("Log likelihood contains NaN values")
+
+         logger.debug(f"Initial sample summary: {samples}")
+
+         # Remove n_final_steps from sampler_kwargs if present
+         self.sampler_kwargs = self.sampler_kwargs or {}
+         n_final_steps = self.sampler_kwargs.pop("n_final_steps", None)
+
+         self.history = SMCHistory()
+
+         self.target_efficiency = target_efficiency
+         self.target_efficiency_rate = target_efficiency_rate
+
+         if n_steps is not None:
+             beta_step = 1 / n_steps
+         elif not adaptive:
+             raise ValueError("Either n_steps or adaptive=True must be set")
+         else:
+             beta_step = np.nan
+         self.adaptive = adaptive
+         beta = 0.0
+
+         if min_step is None:
+             if max_n_steps is None:
+                 min_step = 0.0
+                 self.adaptive_min_step = False
+             else:
+                 min_step = 1 / max_n_steps
+                 self.adaptive_min_step = True
+         else:
+             self.adaptive_min_step = False
+
+         iterations = 0
+         while True:
+             iterations += 1
+
+             beta, min_step = self.determine_beta(
+                 samples,
+                 beta,
+                 beta_step,
+                 min_step,
+             )
+             self.history.eff_target.append(
+                 self.current_target_efficiency(beta)
+             )
+
+             logger.info(f"it {iterations} - beta: {beta}")
+             self.history.beta.append(beta)
+
+             ess = effective_sample_size(samples.log_weights(beta))
+             eff = ess / len(samples)
+             if eff < 0.1:
+                 logger.warning(
+                     f"it {iterations} - Low sample efficiency: {eff:.2f}"
+                 )
+             self.history.ess.append(ess)
+             logger.info(
+                 f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
+             )
+             self.history.ess_target.append(
+                 effective_sample_size(samples.log_weights(1.0))
+             )
+
+             log_evidence_ratio = samples.log_evidence_ratio(beta)
+             log_evidence_ratio_var = samples.log_evidence_ratio_variance(beta)
+             self.history.log_norm_ratio.append(log_evidence_ratio)
+             self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
+             logger.info(
+                 f"it {iterations} - Log evidence ratio: "
+                 f"{log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
+             )
+
+             samples = samples.resample(beta, rng=self.rng)
+
+             samples = self.mutate(samples, beta)
+             if beta == 1.0 or (
+                 max_n_steps is not None and iterations >= max_n_steps
+             ):
+                 break
+
+         # If n_final_samples is set, resample and perform additional mutation steps
+         if n_final_samples is not None:
+             logger.info(f"Generating {n_final_samples} final samples")
+             final_samples = samples.resample(
+                 1.0, n_samples=n_final_samples, rng=self.rng
+             )
+             samples = self.mutate(final_samples, 1.0, n_steps=n_final_steps)
+
+         # log Z is the sum of the per-iteration log evidence ratios
+         samples.log_evidence = samples.xp.sum(
+             asarray(self.history.log_norm_ratio, self.xp)
+         )
+         samples.log_evidence_error = samples.xp.sqrt(
+             samples.xp.sum(asarray(self.history.log_norm_ratio_var, self.xp))
+         )
+
+         final_samples = samples.to_standard_samples()
+         logger.info(
+             f"Log evidence: {final_samples.log_evidence:.2f} "
+             f"+/- {final_samples.log_evidence_error:.2f}"
+         )
+         return final_samples
+
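+     # Usage sketch (illustrative names):
+     #   sampler = NumpySMCSampler(loglike, logprior, dims, flow, xp=np)
+     #   samples = sampler.sample(1000, adaptive=True, target_efficiency=0.5)
+     #   print(samples.log_evidence, samples.log_evidence_error)
+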
+     def mutate(self, particles, beta, n_steps=None):
+         """Mutate particles at inverse temperature beta.
+
+         Must be implemented by subclasses.
+         """
+         raise NotImplementedError
+
+     def log_prob(self, z, beta=None):
+         # Change of variables: log p(z) = log p_t(x) + log|det dx/dz|
+         x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
+         samples = SMCSamples(x, xp=self.xp, beta=beta, dtype=self.dtype)
+         log_q = self.prior_flow.log_prob(samples.x)
+         samples.log_q = samples.array_to_namespace(log_q)
+         samples.log_prior = self.log_prior(samples)
+         samples.log_likelihood = self.log_likelihood(samples)
+         log_prob = samples.log_p_t(
+             beta=beta
+         ).flatten() + samples.array_to_namespace(log_abs_det_jacobian)
+
+         # Replace NaNs with -inf so invalid proposals are rejected
+         log_prob = update_at_indices(
+             log_prob, self.xp.isnan(log_prob), -self.xp.inf
+         )
+         return log_prob
+
+
+ class NumpySMCSampler(SMCSampler):
+     """SMC sampler backed by NumPy arrays."""
+
+     def __init__(
+         self,
+         log_likelihood,
+         log_prior,
+         dims,
+         prior_flow,
+         xp,
+         dtype=None,
+         parameters=None,
+         preconditioning_transform=None,
+     ):
+         if preconditioning_transform is not None:
+             preconditioning_transform = preconditioning_transform.new_instance(
+                 xp=np
+             )
+         super().__init__(
+             log_likelihood,
+             log_prior,
+             dims,
+             prior_flow=prior_flow,
+             xp=xp,
+             dtype=dtype,
+             parameters=parameters,
+             preconditioning_transform=preconditioning_transform,
+         )
@@ -0,0 +1,332 @@
+ import logging
+ from functools import partial
+
+ import numpy as np
+
+ from ...samples import SMCSamples
+ from ...utils import asarray, to_numpy, track_calls
+ from .base import SMCSampler
+
+ logger = logging.getLogger(__name__)
+
+
+ class BlackJAXSMC(SMCSampler):
+     """BlackJAX SMC sampler."""
+
+     def __init__(
+         self,
+         log_likelihood,
+         log_prior,
+         dims,
+         prior_flow,
+         xp,
+         dtype=None,
+         parameters=None,
+         preconditioning_transform=None,
+         rng: np.random.Generator | None = None,
+     ):
+         # Keep the caller-provided array namespace; JAX arrays are converted
+         # explicitly in log_prob and mutate.
+         super().__init__(
+             log_likelihood=log_likelihood,
+             log_prior=log_prior,
+             dims=dims,
+             prior_flow=prior_flow,
+             xp=xp,
+             dtype=dtype,
+             parameters=parameters,
+             preconditioning_transform=preconditioning_transform,
+         )
+         self.key = None
+         self.rng = rng or np.random.default_rng()
+
+     def log_prob(self, x, beta=None):
+         """Log probability function compatible with BlackJAX."""
+         # Convert to the sampler's array namespace for computation
+         if hasattr(x, "__array__"):
+             x_original = asarray(x, self.xp)
+         else:
+             x_original = x
+
+         # Transform back to parameter space
+         x_params, log_abs_det_jacobian = (
+             self.preconditioning_transform.inverse(x_original)
+         )
+         samples = SMCSamples(x_params, xp=self.xp, dtype=self.dtype)
+
+         # Compute log probabilities
+         log_q = self.prior_flow.log_prob(samples.x)
+         samples.log_q = samples.array_to_namespace(log_q)
+         samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
+         samples.log_likelihood = samples.array_to_namespace(
+             self.log_likelihood(samples)
+         )
+
+         # Compute target log probability, including the Jacobian correction
+         log_prob = samples.log_p_t(
+             beta=beta
+         ).flatten() + samples.array_to_namespace(log_abs_det_jacobian)
+
+         # Handle NaN values
+         log_prob = self.xp.where(
+             self.xp.isnan(log_prob), -self.xp.inf, log_prob
+         )
+
+         return log_prob
+
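+     # Note: xp.where handles NaNs functionally, without in-place mutation,
+     # which matters for JAX's immutable arrays; NaNs become -inf so the MCMC
+     # kernel rejects those proposals.
+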
+     @track_calls
+     def sample(
+         self,
+         n_samples: int,
+         n_steps: int | None = None,
+         adaptive: bool = True,
+         target_efficiency: float = 0.5,
+         target_efficiency_rate: float = 1.0,
+         n_final_samples: int | None = None,
+         sampler_kwargs: dict | None = None,
+         rng_key=None,
+     ):
+         """Sample using BlackJAX SMC.
+
+         Parameters
+         ----------
+         n_samples : int
+             Number of samples to draw.
+         n_steps : int | None
+             Number of SMC steps.
+         adaptive : bool
+             Whether to use adaptive tempering.
+         target_efficiency : float
+             Target efficiency for adaptive tempering.
+         target_efficiency_rate : float
+             Rate at which an adaptive target efficiency evolves with beta.
+         n_final_samples : int | None
+             Number of final samples to return.
+         sampler_kwargs : dict | None
+             Additional arguments for the BlackJAX sampler.
+
+             - algorithm: str, one of "nuts", "hmc", "rwmh", "random_walk"
+             - n_steps: int, number of MCMC steps per mutation
+             - step_size: float, step size for HMC/NUTS
+             - inverse_mass_matrix: array, inverse mass matrix
+             - sigma: float or array, proposal covariance for random walk MH
+             - num_integration_steps: int, integration steps for HMC
+         rng_key : jax.random.key | None
+             JAX random key for reproducibility.
+         """
+         self.sampler_kwargs = sampler_kwargs or {}
+         self.sampler_kwargs.setdefault("n_steps", 5 * self.dims)
+         self.sampler_kwargs.setdefault("algorithm", "nuts")
+         self.sampler_kwargs.setdefault("step_size", 1e-3)
+         self.sampler_kwargs.setdefault("inverse_mass_matrix", None)
+         self.sampler_kwargs.setdefault("sigma", 0.1)  # For random walk MH
+
+         # Initialize JAX random key (fixed default seed for reproducibility)
+         if rng_key is None:
+             import jax
+
+             self.key = jax.random.key(42)
+         else:
+             self.key = rng_key
+
+         return super().sample(
+             n_samples,
+             n_steps=n_steps,
+             adaptive=adaptive,
+             target_efficiency=target_efficiency,
+             target_efficiency_rate=target_efficiency_rate,
+             n_final_samples=n_final_samples,
+         )
+
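+     # Usage sketch (illustrative names):
+     #   smc = BlackJAXSMC(loglike, logprior, dims, flow, xp=jnp)
+     #   samples = smc.sample(
+     #       2000,
+     #       sampler_kwargs={"algorithm": "hmc", "step_size": 0.1, "n_steps": 20},
+     #   )
+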
+     def mutate(self, particles, beta, n_steps=None):
+         """Mutate particles using BlackJAX MCMC."""
+         import blackjax
+         import jax
+
+         logger.debug("Mutating particles with BlackJAX")
+
+         # Split the random key
+         self.key, subkey = jax.random.split(self.key)
+
+         # Transform particles to latent space
+         z = self.fit_preconditioning_transform(particles.x)
+
+         # Convert to JAX arrays
+         z_jax = jax.numpy.asarray(to_numpy(z))
+
+         # Create log probability function for this beta
+         log_prob_fn = partial(self._jax_log_prob, beta=beta)
+
+         # Choose BlackJAX algorithm
+         algorithm = self.sampler_kwargs["algorithm"].lower()
+
+         n_steps = n_steps or self.sampler_kwargs["n_steps"]
+
+         if algorithm == "rwmh" or algorithm == "random_walk":
+             # Initialize Random Walk Metropolis-Hastings sampler
+             sigma = self.sampler_kwargs.get("sigma", 0.1)
+
+             # BlackJAX RMH expects a transition function, not a covariance
+             if isinstance(sigma, (int, float)):
+                 # Create a multivariate normal proposal function
+                 def proposal_fn(key, position):
+                     return position + sigma * jax.random.normal(
+                         key, position.shape
+                     )
+             else:
+                 # For more complex covariance structures
+                 if len(sigma) == self.dims:
+                     # Diagonal covariance
+                     sigma_diag = jax.numpy.array(sigma)
+
+                     def proposal_fn(key, position):
+                         return position + sigma_diag * jax.random.normal(
+                             key, position.shape
+                         )
+                 else:
+                     # Full covariance matrix
+                     sigma_matrix = jax.numpy.array(sigma)
+
+                     def proposal_fn(key, position):
+                         return position + jax.random.multivariate_normal(
+                             key, jax.numpy.zeros(self.dims), sigma_matrix
+                         )
+
+             rwmh = blackjax.rmh(log_prob_fn, proposal_fn)
+
+             # Initialize states for each particle
+             n_particles = z_jax.shape[0]
+             keys = jax.random.split(subkey, n_particles)
+
+             # Vectorized initialization and sampling
+             def init_and_sample(key, z_init):
+                 state = rwmh.init(z_init)
+
+                 def one_step(state, key):
+                     state, info = rwmh.step(key, state)
+                     return state, (state, info)
+
+                 keys = jax.random.split(key, n_steps)
+                 final_state, (states, infos) = jax.lax.scan(
+                     one_step, state, keys
+                 )
+                 return final_state, infos
+
+             # Vectorize over particles
+             final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
+
+             # Extract final positions
+             z_final = final_states.position
+
+             # Calculate acceptance rates
+             acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
+             mean_acceptance = jax.numpy.mean(acceptance_rates)
+
+         elif algorithm == "nuts":
+             # Initialize step size and mass matrix if not provided
+             inverse_mass_matrix = self.sampler_kwargs["inverse_mass_matrix"]
+             if inverse_mass_matrix is None:
+                 inverse_mass_matrix = jax.numpy.eye(self.dims)
+
+             step_size = self.sampler_kwargs["step_size"]
+
+             # Initialize NUTS sampler
+             nuts = blackjax.nuts(
+                 log_prob_fn,
+                 step_size=step_size,
+                 inverse_mass_matrix=inverse_mass_matrix,
+             )
+
+             # Initialize states for each particle
+             n_particles = z_jax.shape[0]
+             keys = jax.random.split(subkey, n_particles)
+
+             # Vectorized initialization and sampling
+             def init_and_sample(key, z_init):
+                 state = nuts.init(z_init)
+
+                 def one_step(state, key):
+                     state, info = nuts.step(key, state)
+                     return state, (state, info)
+
+                 # Use the (possibly overridden) n_steps, as in the RWMH branch
+                 keys = jax.random.split(key, n_steps)
+                 final_state, (states, infos) = jax.lax.scan(
+                     one_step, state, keys
+                 )
+                 return final_state, infos
+
+             # Vectorize over particles
+             final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
+
+             # Extract final positions
+             z_final = final_states.position
+
+             # Calculate acceptance rates
+             acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
+             mean_acceptance = jax.numpy.mean(acceptance_rates)
+
+         elif algorithm == "hmc":
+             # Resolve the inverse mass matrix with an explicit None check;
+             # `or` is ambiguous for array values
+             inverse_mass_matrix = self.sampler_kwargs["inverse_mass_matrix"]
+             if inverse_mass_matrix is None:
+                 inverse_mass_matrix = jax.numpy.eye(self.dims)
+
+             # Initialize HMC sampler
+             hmc = blackjax.hmc(
+                 log_prob_fn,
+                 step_size=self.sampler_kwargs["step_size"],
+                 num_integration_steps=self.sampler_kwargs.get(
+                     "num_integration_steps", 10
+                 ),
+                 inverse_mass_matrix=inverse_mass_matrix,
+             )
+
+             # Similar vectorized sampling as NUTS
+             n_particles = z_jax.shape[0]
+             keys = jax.random.split(subkey, n_particles)
+
+             def init_and_sample(key, z_init):
+                 state = hmc.init(z_init)
+
+                 def one_step(state, key):
+                     state, info = hmc.step(key, state)
+                     return state, (state, info)
+
+                 keys = jax.random.split(key, n_steps)
+                 final_state, (states, infos) = jax.lax.scan(
+                     one_step, state, keys
+                 )
+                 return final_state, infos
+
+             final_states, all_infos = jax.vmap(init_and_sample)(keys, z_jax)
+             z_final = final_states.position
+             acceptance_rates = jax.numpy.mean(all_infos.is_accepted, axis=1)
+             mean_acceptance = jax.numpy.mean(acceptance_rates)
+
+         else:
+             raise ValueError(f"Unsupported algorithm: {algorithm}")
+
+         # Convert back to parameter space
+         z_final_np = to_numpy(z_final)
+         x_final = self.preconditioning_transform.inverse(z_final_np)[0]
+
+         # Store MCMC diagnostics
+         self.history.mcmc_acceptance.append(float(mean_acceptance))
+
+         # Create new samples
+         samples = SMCSamples(x_final, xp=self.xp, beta=beta, dtype=self.dtype)
+         samples.log_q = samples.array_to_namespace(
+             self.prior_flow.log_prob(samples.x)
+         )
+         samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
+         samples.log_likelihood = samples.array_to_namespace(
+             self.log_likelihood(samples)
+         )
+
+         if samples.xp.isnan(samples.log_q).any():
+             raise ValueError("Log proposal contains NaN values")
+
+         return samples
+
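+     # All three kernels above share the same pattern: jax.vmap over particles
+     # of a jax.lax.scan over MCMC steps. Minimal sketch (names illustrative):
+     #   def chain(key, x0):
+     #       def step(state, k):
+     #           state, info = kernel.step(k, state)
+     #           return state, info
+     #       return jax.lax.scan(step, kernel.init(x0), jax.random.split(key, n))
+     #   finals, infos = jax.vmap(chain)(keys, x_batch)
+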
+     def _jax_log_prob(self, z, beta):
+         """JAX-compatible log probability function."""
+         import jax.numpy as jnp
+
+         # Single particle version for JAX
+         z_expanded = jnp.expand_dims(z, 0)  # Add batch dimension
+         log_prob = self.log_prob(z_expanded, beta=beta)
+         return log_prob[0]  # Remove batch dimension