aspire_inference-0.1.0a2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,158 @@
+ import numpy as np
+
+ from ..samples import Samples, to_numpy
+ from ..utils import track_calls
+ from .base import Sampler
+
+
+ class MCMCSampler(Sampler):
+     def draw_initial_samples(self, n_samples: int) -> Samples:
+         """Draw initial samples from the prior flow."""
+         # Flow may propose samples outside prior bounds, so we may need
+         # to try multiple times to get enough valid samples.
+         n_samples_drawn = 0
+         samples = None
+         while n_samples_drawn < n_samples:
+             n_to_draw = n_samples - n_samples_drawn
+             x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
+             new_samples = Samples(x, xp=self.xp, log_q=log_q)
+             new_samples.log_prior = new_samples.array_to_namespace(
+                 self.log_prior(new_samples)
+             )
+             # A sample is valid if it lies in the prior support, i.e. its
+             # log-prior is finite.
+             valid = self.xp.isfinite(new_samples.log_prior)
+             n_valid = self.xp.sum(valid)
+             if n_valid > 0:
+                 if samples is None:
+                     samples = new_samples[valid]
+                 else:
+                     samples = Samples.concatenate(
+                         [samples, new_samples[valid]]
+                     )
+                 n_samples_drawn += n_valid
+
+         if n_samples_drawn > n_samples:
+             samples = samples[:n_samples]
+
+         samples.log_likelihood = samples.array_to_namespace(
+             self.log_likelihood(samples)
+         )
+         return samples
+
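For context, the rejection loop above relies on the log-prior being non-finite outside the prior support. A minimal standalone sketch of the same accept-and-top-up pattern, assuming a uniform prior on [0, 1] and a Gaussian stand-in for the flow proposal (all names here are illustrative, not package API):

import numpy as np

rng = np.random.default_rng(0)

def log_prior(x):
    # Uniform on [0, 1]: 0 inside the support, -inf outside.
    return np.where((x >= 0) & (x <= 1), 0.0, -np.inf)

n_samples = 500
batches = []
n_drawn = 0
while n_drawn < n_samples:
    # The proposal can leak outside the prior bounds, like the flow above.
    x = rng.normal(0.5, 0.5, size=n_samples - n_drawn)
    valid = np.isfinite(log_prior(x))
    if valid.any():
        batches.append(x[valid])
        n_drawn += valid.sum()
samples = np.concatenate(batches)[:n_samples]
print(samples.shape)  # (500,)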
+     def log_prob(self, z):
+         """Compute the log probability of the samples.
+
+         Input samples are in the transformed space.
+         """
+         x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
+         samples = Samples(x, xp=self.xp)
+         samples.log_prior = self.log_prior(samples)
+         samples.log_likelihood = self.log_likelihood(samples)
+         # Change of variables: include the log |det Jacobian| of the
+         # inverse transform so the density is valid in z-space.
+         log_prob = (
+             samples.log_likelihood
+             + samples.log_prior
+             + samples.array_to_namespace(log_abs_det_jacobian)
+         )
+         return to_numpy(log_prob).flatten()
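The Jacobian term above is the usual change-of-variables correction: for z = f(x), log p_Z(z) = log p_X(x) + log|det J| of the inverse map. A standalone sketch with a hypothetical affine transform (mu, sigma, and inverse are illustrative, not part of the package):

import numpy as np
from scipy.stats import norm

mu, sigma = 2.0, 3.0  # hypothetical affine preconditioning: z = (x - mu) / sigma

def inverse(z):
    # Returns x and log|det dx/dz| = log(sigma) per sample.
    return mu + sigma * z, np.log(sigma) * np.ones_like(z)

z = np.array([-1.0, 0.0, 1.0])
x, log_abs_det_jacobian = inverse(z)
log_prob_z = norm(mu, sigma).logpdf(x) + log_abs_det_jacobian
# The corrected density in z-space matches a standard normal, as expected.
assert np.allclose(log_prob_z, norm(0, 1).logpdf(z))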
+
+
+ class Emcee(MCMCSampler):
+     @track_calls
+     def sample(
+         self,
+         n_samples: int,
+         nwalkers: int | None = None,
+         nsteps: int = 500,
+         rng=None,
+         discard=0,
+         **kwargs,
+     ) -> Samples:
+         from emcee import EnsembleSampler
+
+         nwalkers = nwalkers or n_samples
+         self.sampler = EnsembleSampler(
+             nwalkers,
+             self.dims,
+             log_prob_fn=self.log_prob,
+             vectorize=True,
+         )
+
+         rng = rng or np.random.default_rng()
+
+         # One walker per initial sample, drawn from the prior flow.
+         samples = self.draw_initial_samples(nwalkers)
+         p0 = samples.x
+
+         # Fit the preconditioning transform on the initial samples and
+         # map them into the transformed (z) space.
+         z0 = to_numpy(self.preconditioning_transform.fit(p0))
+
+         self.sampler.run_mcmc(z0, nsteps, **kwargs)
+
+         z = self.sampler.get_chain(flat=True, discard=discard)
+         x = self.preconditioning_transform.inverse(z)[0]
+
+         # Separate importance-weighted draws from the prior flow are used
+         # to estimate the evidence.
+         x_evidence, log_q = self.prior_flow.sample_and_log_prob(n_samples)
+         samples_evidence = Samples(x_evidence, log_q=log_q, xp=self.xp)
+         samples_evidence.log_prior = self.log_prior(samples_evidence)
+         samples_evidence.log_likelihood = self.log_likelihood(samples_evidence)
+         samples_evidence.compute_weights()
+
+         samples_mcmc = Samples(x, xp=self.xp, parameters=self.parameters)
+         samples_mcmc.log_prior = samples_mcmc.array_to_namespace(
+             self.log_prior(samples_mcmc)
+         )
+         samples_mcmc.log_likelihood = samples_mcmc.array_to_namespace(
+             self.log_likelihood(samples_mcmc)
+         )
+         samples_mcmc.log_evidence = samples_mcmc.array_to_namespace(
+             samples_evidence.log_evidence
+         )
+         samples_mcmc.log_evidence_error = samples_mcmc.array_to_namespace(
+             samples_evidence.log_evidence_error
+         )
+
+         return samples_mcmc
+
+
+ class MiniPCN(MCMCSampler):
+     @track_calls
+     def sample(
+         self,
+         n_samples,
+         rng=None,
+         target_acceptance_rate=0.234,
+         n_steps=100,
+         thin=1,
+         burnin=0,
+         last_step_only=False,
+         step_fn="tpcn",
+     ):
+         from minipcn import Sampler
+
+         rng = rng or np.random.default_rng()
+         p0 = self.draw_initial_samples(n_samples).x
+
+         # Fit the preconditioning transform and move to the z-space.
+         z0 = to_numpy(self.preconditioning_transform.fit(p0))
+
+         self.sampler = Sampler(
+             log_prob_fn=self.log_prob,
+             step_fn=step_fn,
+             rng=rng,
+             dims=self.dims,
+             target_acceptance_rate=target_acceptance_rate,
+         )
+
+         chain, history = self.sampler.sample(z0, n_steps=n_steps)
+
+         if last_step_only:
+             z = chain[-1]
+         else:
+             z = chain[burnin::thin].reshape(-1, self.dims)
+
+         x = self.preconditioning_transform.inverse(z)[0]
+
+         samples_mcmc = Samples(x, xp=self.xp, parameters=self.parameters)
+         samples_mcmc.log_prior = samples_mcmc.array_to_namespace(
+             self.log_prior(samples_mcmc)
+         )
+         samples_mcmc.log_likelihood = samples_mcmc.array_to_namespace(
+             self.log_likelihood(samples_mcmc)
+         )
+         return samples_mcmc
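For reference, the emcee usage in Emcee.sample follows the standard EnsembleSampler pattern. A self-contained toy example (independent of this package) with a Gaussian target, showing the vectorize=True contract and the flattened, burn-in-discarded chain:

import numpy as np
from emcee import EnsembleSampler

ndim, nwalkers, nsteps = 2, 32, 500

def log_prob(z):
    # Vectorised: z has shape (nwalkers, ndim) because vectorize=True.
    return -0.5 * np.sum(z**2, axis=-1)

sampler = EnsembleSampler(nwalkers, ndim, log_prob_fn=log_prob, vectorize=True)
z0 = np.random.default_rng(0).normal(size=(nwalkers, ndim))
sampler.run_mcmc(z0, nsteps)

# Discard burn-in and flatten the walkers into a single chain, as above.
z = sampler.get_chain(flat=True, discard=250)
print(z.shape)  # (nwalkers * (nsteps - 250), ndim)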
File without changes
@@ -0,0 +1,312 @@
+ import logging
+ from typing import Callable
+
+ import array_api_compat.numpy as np
+
+ from ...flows.base import Flow
+ from ...history import SMCHistory
+ from ...samples import SMCSamples
+ from ...utils import (
+     asarray,
+     effective_sample_size,
+     track_calls,
+     update_at_indices,
+ )
+ from ..mcmc import MCMCSampler
+
+ logger = logging.getLogger(__name__)
+
+
+ class SMCSampler(MCMCSampler):
+     """Base class for Sequential Monte Carlo samplers."""
+
+     def __init__(
+         self,
+         log_likelihood: Callable,
+         log_prior: Callable,
+         dims: int,
+         prior_flow: Flow,
+         xp: Callable,
+         parameters: list[str] | None = None,
+         rng: np.random.Generator | None = None,
+         preconditioning_transform: Callable | None = None,
+     ):
+         super().__init__(
+             log_likelihood=log_likelihood,
+             log_prior=log_prior,
+             dims=dims,
+             prior_flow=prior_flow,
+             xp=xp,
+             parameters=parameters,
+             preconditioning_transform=preconditioning_transform,
+         )
+         self.rng = rng or np.random.default_rng()
+         self._adaptive_target_efficiency = False
+
+     @property
+     def target_efficiency(self):
+         return self._target_efficiency
+
+     @target_efficiency.setter
+     def target_efficiency(self, value: float | tuple):
+         """Set the target efficiency.
+
+         Parameters
+         ----------
+         value : float or tuple
+             If a float, the target efficiency to use for all iterations.
+             If a tuple of two floats, the target efficiency will adapt from
+             the first value to the second value over the course of the SMC
+             iterations. See `target_efficiency_rate` for details.
+         """
+         if isinstance(value, float):
+             if not (0 < value < 1):
+                 raise ValueError("target_efficiency must be in (0, 1)")
+             self._target_efficiency = value
+             self._adaptive_target_efficiency = False
+         elif len(value) != 2:
+             raise ValueError(
+                 "target_efficiency must be a float or a tuple of two floats"
+             )
+         else:
+             value = tuple(map(float, value))
+             if not (0 < value[0] < value[1] < 1):
+                 raise ValueError(
+                     "target_efficiency tuple must be in (0, 1) and increasing"
+                 )
+             self._target_efficiency = value
+             self._adaptive_target_efficiency = True
+
+     def current_target_efficiency(self, beta: float) -> float:
+         """Get the current target efficiency based on beta."""
+         if self._adaptive_target_efficiency:
+             # Interpolate from the first to the second value as beta
+             # increases from 0 to 1.
+             return self._target_efficiency[0] + (
+                 self._target_efficiency[1] - self._target_efficiency[0]
+             ) * (beta**self.target_efficiency_rate)
+         else:
+             return self._target_efficiency
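A worked example of the adaptive schedule: with target_efficiency = (0.2, 0.8) and target_efficiency_rate = 2.0, the target grows from 0.2 to 0.8 as beta goes from 0 to 1, quadratically in beta. A standalone sketch (t0, t1, and rate mirror the tuple form above):

t0, t1, rate = 0.2, 0.8, 2.0

def current_target_efficiency(beta):
    return t0 + (t1 - t0) * beta**rate

for beta in (0.0, 0.5, 1.0):
    print(beta, current_target_efficiency(beta))
# 0.0 -> 0.2, 0.5 -> 0.35, 1.0 -> 0.8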
+     def determine_beta(
+         self,
+         samples: SMCSamples,
+         beta: float,
+         beta_step: float,
+         min_step: float,
+     ) -> tuple[float, float]:
+         """Determine the next beta value.
+
+         Parameters
+         ----------
+         samples : SMCSamples
+             The current samples.
+         beta : float
+             The current beta value.
+         beta_step : float
+             The fixed beta step size if not adaptive.
+         min_step : float
+             The minimum beta step size.
+
+         Returns
+         -------
+         beta : float
+             The new beta value.
+         min_step : float
+             The new minimum step size if adaptive_min_step is True.
+         """
+         if not self.adaptive:
+             beta += beta_step
+             if beta >= 1.0:
+                 beta = 1.0
+         else:
+             # Bisect for the largest beta whose effective sample size
+             # still meets the target efficiency.
+             beta_prev = beta
+             beta_min = beta_prev
+             beta_max = 1.0
+             tol = 1e-5
+             target_eff = self.current_target_efficiency(beta_prev)
+             eff_beta_max = effective_sample_size(
+                 samples.log_weights(beta_max)
+             ) / len(samples)
+             if eff_beta_max >= target_eff:
+                 beta_min = 1.0
+             while beta_max - beta_min > tol:
+                 beta_try = 0.5 * (beta_max + beta_min)
+                 eff = effective_sample_size(
+                     samples.log_weights(beta_try)
+                 ) / len(samples)
+                 if eff >= target_eff:
+                     beta_min = beta_try
+                 else:
+                     beta_max = beta_try
+             beta_star = beta_min
+
+             if self.adaptive_min_step and beta_star < 1.0:
+                 # Rescale the minimum step to the remaining beta range;
+                 # skipped when beta_star is 1.0 to avoid dividing by zero.
+                 min_step = min_step * (1 - beta_prev) / (1 - beta_star)
+             beta = max(beta_star, beta_prev + min_step)
+             beta = min(beta, 1.0)
+         return beta, min_step
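The bisection can be exercised in isolation. A toy sketch with hypothetical Gaussian log-likelihoods and the Kish effective sample size (assumed here to match the package's effective_sample_size up to details):

import numpy as np

def effective_sample_size(log_w):
    # Kish ESS computed from log-weights.
    w = np.exp(log_w - log_w.max())
    return w.sum() ** 2 / (w**2).sum()

rng = np.random.default_rng(0)
log_l = rng.normal(size=1000)  # toy log-likelihoods

target_eff, beta_min, beta_max = 0.5, 0.0, 1.0
while beta_max - beta_min > 1e-5:
    beta_try = 0.5 * (beta_max + beta_min)
    # Tempered importance weights relative to beta = 0.
    eff = effective_sample_size(beta_try * log_l) / len(log_l)
    if eff >= target_eff:
        beta_min = beta_try
    else:
        beta_max = beta_try
print(f"next beta: {beta_min:.3f}")  # largest beta with >= 50% efficiency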
+     @track_calls
+     def sample(
+         self,
+         n_samples: int,
+         n_steps: int | None = None,
+         adaptive: bool = True,
+         min_step: float | None = None,
+         max_n_steps: int | None = None,
+         target_efficiency: float = 0.5,
+         target_efficiency_rate: float = 1.0,
+         n_final_samples: int | None = None,
+     ) -> SMCSamples:
+         samples = self.draw_initial_samples(n_samples)
+         samples = SMCSamples.from_samples(samples, xp=self.xp, beta=0.0)
+         self.fit_preconditioning_transform(samples.x)
+
+         if self.xp.isnan(samples.log_q).any():
+             raise ValueError("Log proposal contains NaN values")
+         if self.xp.isnan(samples.log_prior).any():
+             raise ValueError("Log prior contains NaN values")
+         if self.xp.isnan(samples.log_likelihood).any():
+             raise ValueError("Log likelihood contains NaN values")
+
+         logger.debug(f"Initial sample summary: {samples}")
+
+         # Remove n_final_steps from sampler_kwargs if present; it is only
+         # used for the final mutation below.
+         self.sampler_kwargs = self.sampler_kwargs or {}
+         n_final_steps = self.sampler_kwargs.pop("n_final_steps", None)
+
+         self.history = SMCHistory()
+
+         self.target_efficiency = target_efficiency
+         self.target_efficiency_rate = target_efficiency_rate
+
+         if n_steps is not None:
+             beta_step = 1 / n_steps
+         elif not adaptive:
+             raise ValueError("Either n_steps or adaptive=True must be set")
+         else:
+             beta_step = np.nan
+         self.adaptive = adaptive
+         beta = 0.0
+
+         if min_step is None:
+             if max_n_steps is None:
+                 min_step = 0.0
+                 self.adaptive_min_step = False
+             else:
+                 min_step = 1 / max_n_steps
+                 self.adaptive_min_step = True
+         else:
+             self.adaptive_min_step = False
+
+         iterations = 0
+         while True:
+             iterations += 1
+
+             beta, min_step = self.determine_beta(
+                 samples,
+                 beta,
+                 beta_step,
+                 min_step,
+             )
+             self.history.eff_target.append(
+                 self.current_target_efficiency(beta)
+             )
+
+             logger.info(f"it {iterations} - beta: {beta}")
+             self.history.beta.append(beta)
+
+             ess = effective_sample_size(samples.log_weights(beta))
+             eff = ess / len(samples)
+             if eff < 0.1:
+                 logger.warning(
+                     f"it {iterations} - Low sample efficiency: {eff:.2f}"
+                 )
+             self.history.ess.append(ess)
+             logger.info(
+                 f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
+             )
+             self.history.ess_target.append(
+                 effective_sample_size(samples.log_weights(1.0))
+             )
+
+             log_evidence_ratio = samples.log_evidence_ratio(beta)
+             log_evidence_ratio_var = samples.log_evidence_ratio_variance(beta)
+             self.history.log_norm_ratio.append(log_evidence_ratio)
+             self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
+             logger.info(
+                 f"it {iterations} - Log evidence ratio: "
+                 f"{log_evidence_ratio:.2f} +/- "
+                 f"{np.sqrt(log_evidence_ratio_var):.2f}"
+             )
+
+             samples = samples.resample(beta, rng=self.rng)
+
+             samples = self.mutate(samples, beta)
+             if beta == 1.0 or (
+                 max_n_steps is not None and iterations >= max_n_steps
+             ):
+                 break
+
+         # If n_final_samples is not None, perform additional mutation steps
+         if n_final_samples is not None:
+             logger.info(f"Generating {n_final_samples} final samples")
+             final_samples = samples.resample(
+                 1.0, n_samples=n_final_samples, rng=self.rng
+             )
+             samples = self.mutate(final_samples, 1.0, n_steps=n_final_steps)
+
+         # The total log evidence is the sum of the per-iteration ratios.
+         samples.log_evidence = samples.xp.sum(
+             asarray(self.history.log_norm_ratio, self.xp)
+         )
+         samples.log_evidence_error = samples.xp.sqrt(
+             samples.xp.sum(asarray(self.history.log_norm_ratio_var, self.xp))
+         )
+
+         final_samples = samples.to_standard_samples()
+         logger.info(
+             f"Log evidence: {final_samples.log_evidence:.2f} +/- "
+             f"{final_samples.log_evidence_error:.2f}"
+         )
+         return final_samples
+
+     def mutate(self, samples, beta, n_steps=None):
+         """Mutate the samples. Must be implemented by subclasses."""
+         raise NotImplementedError
+
+     def log_prob(self, z, beta=None):
+         x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
+         samples = SMCSamples(x, xp=self.xp)
+         log_q = self.prior_flow.log_prob(samples.x)
+         samples.log_q = samples.array_to_namespace(log_q)
+         samples.log_prior = self.log_prior(samples)
+         samples.log_likelihood = self.log_likelihood(samples)
+         log_prob = samples.log_p_t(
+             beta=beta
+         ).flatten() + samples.array_to_namespace(log_abs_det_jacobian)
+
+         # Replace NaNs with -inf so the MCMC kernel rejects such points.
+         log_prob = update_at_indices(
+             log_prob, self.xp.isnan(log_prob), -self.xp.inf
+         )
+         return log_prob
+
+
+ class NumpySMCSampler(SMCSampler):
+     def __init__(
+         self,
+         log_likelihood,
+         log_prior,
+         dims,
+         prior_flow,
+         xp,
+         parameters=None,
+         preconditioning_transform=None,
+     ):
+         # Always run the preconditioning transform with the NumPy
+         # namespace, regardless of the sampler's array namespace.
+         if preconditioning_transform is not None:
+             preconditioning_transform = preconditioning_transform.new_instance(
+                 xp=np
+             )
+         super().__init__(
+             log_likelihood,
+             log_prior,
+             dims,
+             prior_flow,
+             xp,
+             parameters=parameters,
+             preconditioning_transform=preconditioning_transform,
+         )
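To make the structure of SMCSampler.sample concrete, here is a self-contained toy SMC loop for a 1D Gaussian model with a known evidence. It mirrors the temper / reweight / resample / mutate / accumulate pattern above (with a fixed beta schedule and a plain random-walk mutation) and is a sketch, not the package's implementation:

import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(42)
n = 2000

# Toy model: prior N(0, 1), likelihood N(x; 2, 0.5).
def log_likelihood(x):
    return norm(2.0, 0.5).logpdf(x)

def log_prior(x):
    return norm(0.0, 1.0).logpdf(x)

x = rng.normal(size=n)  # initial samples from the prior (beta = 0)
beta, log_z = 0.0, 0.0
while beta < 1.0:
    beta_prev, beta = beta, min(beta + 0.1, 1.0)  # fixed schedule
    log_w = (beta - beta_prev) * log_likelihood(x)
    # Accumulate the log evidence ratio via a log-mean-exp of the weights.
    log_z += np.log(np.mean(np.exp(log_w - log_w.max()))) + log_w.max()
    # Multinomial resampling proportional to the weights.
    w = np.exp(log_w - log_w.max())
    x = rng.choice(x, size=n, p=w / w.sum())
    # Mutate with a few random-walk Metropolis steps targeting
    # prior(x) * likelihood(x)**beta.
    for _ in range(5):
        prop = x + 0.5 * rng.normal(size=n)
        log_acc = (beta * log_likelihood(prop) + log_prior(prop)) - (
            beta * log_likelihood(x) + log_prior(x)
        )
        x = np.where(np.log(rng.uniform(size=n)) < log_acc, prop, x)

# Compare to the analytic evidence: Z = N(2; 0, sqrt(1 + 0.5**2)).
print(log_z, norm(0, np.sqrt(1 + 0.5**2)).logpdf(2.0))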