aspire-inference 0.1.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +457 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +37 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +82 -0
- aspire/flows/jax/utils.py +54 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +276 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +92 -0
- aspire/samplers/importance.py +18 -0
- aspire/samplers/mcmc.py +158 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +312 -0
- aspire/samplers/smc/blackjax.py +330 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +476 -0
- aspire/transforms.py +491 -0
- aspire/utils.py +491 -0
- aspire_inference-0.1.0a2.dist-info/METADATA +48 -0
- aspire_inference-0.1.0a2.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a2.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a2.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a2.dist-info/top_level.txt +1 -0
aspire/samplers/mcmc.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from ..samples import Samples, to_numpy
|
|
4
|
+
from ..utils import track_calls
|
|
5
|
+
from .base import Sampler
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MCMCSampler(Sampler):
    """Base class for samplers that run MCMC in a preconditioned space.

    Initial points are drawn from ``prior_flow``; the target density is
    ``log_likelihood + log_prior`` evaluated after mapping points back through
    ``preconditioning_transform`` (see :meth:`log_prob`).
    """

    def draw_initial_samples(self, n_samples: int) -> Samples:
        """Draw initial samples from the prior flow.

        The flow may propose points outside the prior support (non-finite
        log-prior), so proposals are drawn repeatedly and only points with a
        finite log-prior are kept until ``n_samples`` have been collected.

        Parameters
        ----------
        n_samples : int
            Number of valid samples to return. Must be at least 1.

        Returns
        -------
        Samples
            Samples with ``log_q``, ``log_prior`` and ``log_likelihood`` set.

        Raises
        ------
        ValueError
            If ``n_samples`` is less than 1.
        """
        if n_samples < 1:
            raise ValueError(f"n_samples must be at least 1, got {n_samples}")

        # NOTE: if the flow never proposes a point with a finite log-prior,
        # this loop does not terminate.
        n_samples_drawn = 0
        samples = None
        while n_samples_drawn < n_samples:
            n_to_draw = n_samples - n_samples_drawn
            x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
            new_samples = Samples(x, xp=self.xp, log_q=log_q)
            new_samples.log_prior = new_samples.array_to_namespace(
                self.log_prior(new_samples)
            )
            valid = self.xp.isfinite(new_samples.log_prior)
            # Convert the backend scalar to a Python int so the counters and
            # the draw size passed back to the flow stay plain integers for
            # every array namespace.
            n_valid = int(self.xp.sum(valid))
            if n_valid > 0:
                if samples is None:
                    samples = new_samples[valid]
                else:
                    samples = Samples.concatenate(
                        [samples, new_samples[valid]]
                    )
                n_samples_drawn += n_valid

        if n_samples_drawn > n_samples:
            samples = samples[:n_samples]

        samples.log_likelihood = samples.array_to_namespace(
            self.log_likelihood(samples)
        )
        return samples

    def log_prob(self, z):
        """Compute the log probability of the samples.

        Input samples are in the transformed space. They are mapped back with
        ``preconditioning_transform.inverse`` and the result is
        ``log_likelihood + log_prior + log|det J|`` of that inverse, returned
        as a flattened NumPy array.
        """
        x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
        samples = Samples(x, xp=self.xp)
        samples.log_prior = self.log_prior(samples)
        samples.log_likelihood = self.log_likelihood(samples)
        log_prob = (
            samples.log_likelihood
            + samples.log_prior
            + samples.array_to_namespace(log_abs_det_jacobian)
        )
        return to_numpy(log_prob).flatten()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Emcee(MCMCSampler):
    """MCMC sampler that runs ``emcee.EnsembleSampler`` in the preconditioned space."""

    @track_calls
    def sample(
        self,
        n_samples: int,
        nwalkers: int | None = None,
        nsteps: int = 500,
        rng=None,
        discard=0,
        **kwargs,
    ) -> Samples:
        """Sample with emcee and estimate the evidence with the prior flow.

        Parameters
        ----------
        n_samples : int
            Number of prior-flow draws used for the evidence estimate, and the
            default number of walkers.
        nwalkers : int, optional
            Number of walkers; defaults to ``n_samples``.
        nsteps : int
            Number of steps passed to ``run_mcmc``.
        rng : optional
            Seed, ``SeedSequence`` or ``numpy.random.Generator``. When given,
            it seeds emcee's internal random state so the run is reproducible.
        discard : int
            Number of initial steps dropped from each walker's chain.
        **kwargs
            Forwarded to ``EnsembleSampler.run_mcmc``.

        Returns
        -------
        Samples
            All retained chain positions (flattened over walkers) with
            ``log_prior``, ``log_likelihood``, ``log_evidence`` and
            ``log_evidence_error`` set.
        """
        from emcee import EnsembleSampler

        nwalkers = nwalkers or n_samples
        self.sampler = EnsembleSampler(
            nwalkers,
            self.dims,
            log_prob_fn=self.log_prob,
            vectorize=True,
        )

        if rng is not None:
            # emcee draws its proposals from its own RandomState; seed that
            # state from ``rng`` so a user-supplied generator controls the run.
            seed = int(np.random.default_rng(rng).integers(2**32))
            self.sampler.random_state = np.random.RandomState(seed).get_state()

        samples = self.draw_initial_samples(nwalkers)
        p0 = samples.x

        z0 = to_numpy(self.preconditioning_transform.fit(p0))

        self.sampler.run_mcmc(z0, nsteps, **kwargs)

        z = self.sampler.get_chain(flat=True, discard=discard)
        x = self.preconditioning_transform.inverse(z)[0]

        # The evidence is estimated from a separate set of prior-flow draws,
        # weighted via ``Samples.compute_weights``.
        x_evidence, log_q = self.prior_flow.sample_and_log_prob(n_samples)
        samples_evidence = Samples(x_evidence, log_q=log_q, xp=self.xp)
        samples_evidence.log_prior = samples_evidence.array_to_namespace(
            self.log_prior(samples_evidence)
        )
        samples_evidence.log_likelihood = samples_evidence.array_to_namespace(
            self.log_likelihood(samples_evidence)
        )
        samples_evidence.compute_weights()

        samples_mcmc = Samples(x, xp=self.xp, parameters=self.parameters)
        samples_mcmc.log_prior = samples_mcmc.array_to_namespace(
            self.log_prior(samples_mcmc)
        )
        samples_mcmc.log_likelihood = samples_mcmc.array_to_namespace(
            self.log_likelihood(samples_mcmc)
        )
        samples_mcmc.log_evidence = samples_mcmc.array_to_namespace(
            samples_evidence.log_evidence
        )
        samples_mcmc.log_evidence_error = samples_mcmc.array_to_namespace(
            samples_evidence.log_evidence_error
        )

        return samples_mcmc
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class MiniPCN(MCMCSampler):
    """MCMC sampler backed by the ``minipcn`` package."""

    @track_calls
    def sample(
        self,
        n_samples,
        rng=None,
        target_acceptance_rate=0.234,
        n_steps=100,
        thin=1,
        burnin=0,
        last_step_only=False,
        step_fn="tpcn",
    ):
        """Run minipcn starting from points drawn from the prior flow.

        Parameters
        ----------
        n_samples : int
            Number of initial points drawn via ``draw_initial_samples``.
        rng : numpy.random.Generator, optional
            Passed to the minipcn sampler; a fresh ``default_rng()`` if omitted.
        target_acceptance_rate : float
            Forwarded to the minipcn sampler.
        n_steps : int
            Number of steps passed to ``minipcn.Sampler.sample``.
        thin, burnin : int
            Slice ``chain[burnin::thin]`` applied to the returned chain
            (ignored when ``last_step_only`` is True).
        last_step_only : bool
            If True, keep only ``chain[-1]``.
        step_fn : str
            Name of the minipcn step function.

        Returns
        -------
        Samples
            Retained points mapped back to the original space, with
            ``log_prior`` and ``log_likelihood`` set.
        """
        # Aliased so it does not shadow the module-level ``Sampler`` base class.
        from minipcn import Sampler as PCNSampler

        generator = rng or np.random.default_rng()
        initial_points = self.draw_initial_samples(n_samples).x
        z_start = to_numpy(self.preconditioning_transform.fit(initial_points))

        self.sampler = PCNSampler(
            log_prob_fn=self.log_prob,
            step_fn=step_fn,
            rng=generator,
            dims=self.dims,
            target_acceptance_rate=target_acceptance_rate,
        )
        chain, _history = self.sampler.sample(z_start, n_steps=n_steps)

        z_kept = (
            chain[-1]
            if last_step_only
            else chain[burnin::thin].reshape(-1, self.dims)
        )
        x_kept = self.preconditioning_transform.inverse(z_kept)[0]

        result = Samples(x_kept, xp=self.xp, parameters=self.parameters)
        result.log_prior = result.array_to_namespace(self.log_prior(result))
        result.log_likelihood = result.array_to_namespace(
            self.log_likelihood(result)
        )
        return result
|
|
File without changes
|
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
import array_api_compat.numpy as np
|
|
5
|
+
|
|
6
|
+
from ...flows.base import Flow
|
|
7
|
+
from ...history import SMCHistory
|
|
8
|
+
from ...samples import SMCSamples
|
|
9
|
+
from ...utils import (
|
|
10
|
+
asarray,
|
|
11
|
+
effective_sample_size,
|
|
12
|
+
track_calls,
|
|
13
|
+
update_at_indices,
|
|
14
|
+
)
|
|
15
|
+
from ..mcmc import MCMCSampler
|
|
16
|
+
|
|
17
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SMCSampler(MCMCSampler):
    """Base class for Sequential Monte Carlo samplers.

    Particles start from ``prior_flow`` at ``beta = 0`` and are advanced
    through a sequence of ``beta`` values up to ``beta = 1``. At each step they
    are resampled with weights ``samples.log_weights(beta)`` and then moved by
    :meth:`mutate`, which subclasses must implement.
    """

    def __init__(
        self,
        log_likelihood: Callable,
        log_prior: Callable,
        dims: int,
        prior_flow: Flow,
        xp: Callable,
        parameters: list[str] | None = None,
        rng: np.random.Generator | None = None,
        preconditioning_transform: Callable | None = None,
    ):
        super().__init__(
            log_likelihood=log_likelihood,
            log_prior=log_prior,
            dims=dims,
            prior_flow=prior_flow,
            xp=xp,
            parameters=parameters,
            preconditioning_transform=preconditioning_transform,
        )
        # Generator used when resampling particles.
        self.rng = rng or np.random.default_rng()
        # Attribute name (including its spelling) kept for compatibility.
        self._adapative_target_efficiency = False

    @property
    def target_efficiency(self):
        return self._target_efficiency

    @target_efficiency.setter
    def target_efficiency(self, value: float | tuple):
        """Set the target efficiency.

        Parameters
        ----------
        value : float or tuple
            If a scalar, the target efficiency to use for all iterations.
            If a tuple of two floats, the target efficiency will adapt from
            the first value to the second value over the course of the SMC
            iterations. See `target_efficiency_rate` for details.

        Raises
        ------
        ValueError
            If a scalar is not in (0, 1), a sequence does not have exactly
            two entries, or the two entries are not increasing within (0, 1).
        """
        try:
            n_values = len(value)
        except TypeError:
            # Scalars (Python int/float, NumPy scalars) have no length.
            n_values = None

        if n_values is None:
            value = float(value)
            if not (0 < value < 1):
                raise ValueError("target_efficiency must be in (0, 1)")
            self._target_efficiency = value
            self._adapative_target_efficiency = False
        elif n_values != 2:
            raise ValueError(
                "target_efficiency must be a float or tuple of two floats"
            )
        else:
            value = tuple(map(float, value))
            if not (0 < value[0] < value[1] < 1):
                raise ValueError(
                    "target_efficiency tuple must be in (0, 1) and increasing"
                )
            self._target_efficiency = value
            self._adapative_target_efficiency = True

    def current_target_efficiency(self, beta: float) -> float:
        """Get the current target efficiency based on beta.

        For a tuple target ``(lo, hi)`` this returns
        ``lo + (hi - lo) * beta ** target_efficiency_rate``; otherwise the
        fixed target.
        """
        if self._adapative_target_efficiency:
            return self._target_efficiency[0] + (
                self._target_efficiency[1] - self._target_efficiency[0]
            ) * (beta**self.target_efficiency_rate)
        else:
            return self._target_efficiency

    def determine_beta(
        self,
        samples: SMCSamples,
        beta: float,
        beta_step: float,
        min_step: float,
    ) -> tuple[float, float]:
        """Determine the next beta value.

        Parameters
        ----------
        samples : SMCSamples
            The current samples.
        beta : float
            The current beta value.
        beta_step : float
            The fixed beta step size if not adaptive.
        min_step : float
            The minimum beta step size.

        Returns
        -------
        beta : float
            The new beta value.
        min_step : float
            The new minimum step size if adaptive_min_step is True.
        """
        if not self.adaptive:
            beta += beta_step
            # Repeated float addition can stop just short of 1.0 (ten steps
            # of 0.1 sum to 0.9999999999999999), which would add a spurious
            # extra iteration; treat anything within a tiny tolerance as 1.0.
            if beta >= 1.0 - 1e-9:
                beta = 1.0
        else:
            beta_prev = beta
            beta_min = beta_prev
            beta_max = 1.0
            tol = 1e-5
            target_eff = self.current_target_efficiency(beta_prev)
            eff_beta_max = effective_sample_size(
                samples.log_weights(beta_max)
            ) / len(samples)
            if eff_beta_max >= target_eff:
                # Going straight to beta = 1 already meets the target, so the
                # bisection below is skipped.
                beta_min = 1.0
            # Bisect on [beta_prev, 1] for where the efficiency crosses the
            # target.
            while beta_max - beta_min > tol:
                beta_try = 0.5 * (beta_max + beta_min)
                eff = effective_sample_size(
                    samples.log_weights(beta_try)
                ) / len(samples)
                if eff >= target_eff:
                    beta_min = beta_try
                else:
                    beta_max = beta_try
            beta_star = beta_min

            if self.adaptive_min_step and beta_star < 1.0:
                # Skipped at beta_star == 1.0: the rescale would divide by
                # zero, and beta is clamped to 1.0 below regardless.
                min_step = min_step * (1 - beta_prev) / (1 - beta_star)
            beta = max(beta_star, beta_prev + min_step)
            beta = min(beta, 1.0)
        return beta, min_step

    @track_calls
    def sample(
        self,
        n_samples: int,
        n_steps: int | None = None,
        adaptive: bool = True,
        min_step: float | None = None,
        max_n_steps: int | None = None,
        target_efficiency: float = 0.5,
        target_efficiency_rate: float = 1.0,
        n_final_samples: int | None = None,
    ) -> SMCSamples:
        """Run the SMC sampler.

        Parameters
        ----------
        n_samples : int
            Number of particles.
        n_steps : int, optional
            Number of equal beta increments (``1 / n_steps``), used when
            ``adaptive`` is False.
        adaptive : bool
            If True, each beta is chosen by bisection against the current
            target efficiency.
        min_step : float, optional
            Minimum beta increment for the adaptive schedule. If omitted and
            ``max_n_steps`` is given, it starts at ``1 / max_n_steps`` and is
            rescaled every iteration.
        max_n_steps : int, optional
            Maximum number of SMC iterations.
        target_efficiency : float or tuple
            Target sample efficiency; see the ``target_efficiency`` property.
        target_efficiency_rate : float
            Exponent on beta used when interpolating a tuple target.
        n_final_samples : int, optional
            If given, resample this many particles at beta = 1 and mutate
            them once more, with ``n_final_steps`` popped from
            ``sampler_kwargs``.

        Returns
        -------
        Samples
            The particles converted with ``to_standard_samples()``, with
            ``log_evidence`` and ``log_evidence_error`` set.

        Raises
        ------
        ValueError
            If neither ``n_steps`` nor ``adaptive`` is set, if ``n_steps`` or
            ``max_n_steps`` is less than 1, if ``target_efficiency`` is
            invalid, or if the initial samples contain NaN values.
        """
        # Validate the schedule before the (potentially expensive) draws.
        if n_steps is not None and n_steps < 1:
            raise ValueError(f"n_steps must be at least 1, got {n_steps}")
        if max_n_steps is not None and max_n_steps < 1:
            raise ValueError(
                f"max_n_steps must be at least 1, got {max_n_steps}"
            )
        if n_steps is None and not adaptive:
            raise ValueError("Either n_steps or adaptive=True must be set")

        samples = self.draw_initial_samples(n_samples)
        samples = SMCSamples.from_samples(samples, xp=self.xp, beta=0.0)
        self.fit_preconditioning_transform(samples.x)

        if self.xp.isnan(samples.log_q).any():
            raise ValueError("Log proposal contains NaN values")
        if self.xp.isnan(samples.log_prior).any():
            raise ValueError("Log prior contains NaN values")
        if self.xp.isnan(samples.log_likelihood).any():
            raise ValueError("Log likelihood contains NaN values")

        logger.debug(f"Initial sample summary: {samples}")

        # Remove n_final_steps from sampler_kwargs if present so it is not
        # forwarded to the regular mutation steps.
        self.sampler_kwargs = self.sampler_kwargs or {}
        n_final_steps = self.sampler_kwargs.pop("n_final_steps", None)

        self.history = SMCHistory()

        self.target_efficiency = target_efficiency
        self.target_efficiency_rate = target_efficiency_rate

        # beta_step is only used by the non-adaptive schedule.
        beta_step = 1 / n_steps if n_steps is not None else np.nan
        self.adaptive = adaptive
        beta = 0.0

        if min_step is None:
            if max_n_steps is None:
                min_step = 0.0
                self.adaptive_min_step = False
            else:
                min_step = 1 / max_n_steps
                self.adaptive_min_step = True
        else:
            self.adaptive_min_step = False

        iterations = 0
        while True:
            iterations += 1

            beta, min_step = self.determine_beta(
                samples,
                beta,
                beta_step,
                min_step,
            )
            self.history.eff_target.append(
                self.current_target_efficiency(beta)
            )

            logger.info(f"it {iterations} - beta: {beta}")
            self.history.beta.append(beta)

            ess = effective_sample_size(samples.log_weights(beta))
            eff = ess / len(samples)
            if eff < 0.1:
                logger.warning(
                    f"it {iterations} - Low sample efficiency: {eff:.2f}"
                )
            self.history.ess.append(ess)
            logger.info(
                f"it {iterations} - ESS: {ess:.1f} ({eff:.2f} efficiency)"
            )
            self.history.ess_target.append(
                effective_sample_size(samples.log_weights(1.0))
            )

            log_evidence_ratio = samples.log_evidence_ratio(beta)
            log_evidence_ratio_var = samples.log_evidence_ratio_variance(beta)
            self.history.log_norm_ratio.append(log_evidence_ratio)
            self.history.log_norm_ratio_var.append(log_evidence_ratio_var)
            logger.info(
                f"it {iterations} - Log evidence ratio: {log_evidence_ratio:.2f} +/- {np.sqrt(log_evidence_ratio_var):.2f}"
            )

            samples = samples.resample(beta, rng=self.rng)

            samples = self.mutate(samples, beta)
            if beta == 1.0:
                break
            if max_n_steps is not None and iterations >= max_n_steps:
                # The accumulated log-evidence only covers the schedule up to
                # the current beta, so make the early stop visible.
                logger.warning(
                    f"Stopped after max_n_steps={max_n_steps} iterations "
                    f"with beta={beta} < 1.0; the log evidence does not "
                    f"cover the full schedule"
                )
                break

        # If n_final_samples is not None, perform additional mutation steps
        if n_final_samples is not None:
            logger.info(f"Generating {n_final_samples} final samples")
            final_samples = samples.resample(
                1.0, n_samples=n_final_samples, rng=self.rng
            )
            samples = self.mutate(final_samples, 1.0, n_steps=n_final_steps)

        samples.log_evidence = samples.xp.sum(
            asarray(self.history.log_norm_ratio, self.xp)
        )
        samples.log_evidence_error = samples.xp.sqrt(
            samples.xp.sum(asarray(self.history.log_norm_ratio_var, self.xp))
        )

        final_samples = samples.to_standard_samples()
        logger.info(
            f"Log evidence: {final_samples.log_evidence:.2f} +/- {final_samples.log_evidence_error:.2f}"
        )
        return final_samples

    def mutate(self, particles, beta=None, n_steps=None):
        """Move the particles at the given ``beta``; implemented by subclasses.

        ``sample`` calls this as ``mutate(samples, beta)`` during the main
        loop and as ``mutate(samples, 1.0, n_steps=n_final_steps)`` for the
        final samples, and uses the return value as the new samples.
        """
        raise NotImplementedError

    def log_prob(self, z, beta=None):
        """Log target density at ``beta`` for points in the preconditioned space.

        ``z`` is mapped back with ``preconditioning_transform.inverse``; the
        result is ``samples.log_p_t(beta)`` plus the log-abs-det Jacobian of
        that inverse, with NaN entries replaced by ``-inf``.
        """
        x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
        samples = SMCSamples(x, xp=self.xp)
        log_q = self.prior_flow.log_prob(samples.x)
        samples.log_q = samples.array_to_namespace(log_q)
        samples.log_prior = self.log_prior(samples)
        samples.log_likelihood = self.log_likelihood(samples)
        log_prob = samples.log_p_t(
            beta=beta
        ).flatten() + samples.array_to_namespace(log_abs_det_jacobian)

        log_prob = update_at_indices(
            log_prob, self.xp.isnan(log_prob), -self.xp.inf
        )
        return log_prob
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class NumpySMCSampler(SMCSampler):
    """SMC sampler whose preconditioning transform operates on NumPy arrays.

    Any supplied ``preconditioning_transform`` is re-created with
    ``new_instance(xp=np)`` before being passed to :class:`SMCSampler`.

    Parameters
    ----------
    rng : numpy.random.Generator, optional
        Generator used for resampling; forwarded to :class:`SMCSampler`
        (which falls back to ``default_rng()`` when omitted).
    """

    def __init__(
        self,
        log_likelihood,
        log_prior,
        dims,
        prior_flow,
        xp,
        parameters=None,
        preconditioning_transform=None,
        rng=None,
    ):
        if preconditioning_transform is not None:
            preconditioning_transform = preconditioning_transform.new_instance(
                xp=np
            )
        super().__init__(
            log_likelihood,
            log_prior,
            dims,
            prior_flow,
            xp,
            parameters=parameters,
            rng=rng,
            preconditioning_transform=preconditioning_transform,
        )
|