jaxspec 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/_plot.py +35 -0
- jaxspec/data/obsconf.py +34 -0
- jaxspec/fit.py +247 -261
- jaxspec/model/abc.py +60 -11
- jaxspec/model/additive.py +1 -3
- jaxspec/model/multiplicative.py +3 -11
- jaxspec/util/typing.py +27 -2
- {jaxspec-0.1.0.dist-info → jaxspec-0.1.2.dist-info}/METADATA +13 -7
- {jaxspec-0.1.0.dist-info → jaxspec-0.1.2.dist-info}/RECORD +12 -14
- jaxspec/model/_additive/__init__.py +0 -0
- jaxspec/model/_additive/apec.py +0 -316
- jaxspec/model/_additive/apec_loaders.py +0 -73
- {jaxspec-0.1.0.dist-info → jaxspec-0.1.2.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.0.dist-info → jaxspec-0.1.2.dist-info}/WHEEL +0 -0
- {jaxspec-0.1.0.dist-info → jaxspec-0.1.2.dist-info}/entry_points.txt +0 -0
jaxspec/fit.py (CHANGED)
```diff
@@ -1,4 +1,5 @@
 import operator
+import warnings
 
 from abc import ABC, abstractmethod
 from collections.abc import Callable
@@ -9,29 +10,27 @@ import arviz as az
 import haiku as hk
 import jax
 import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
 import numpyro
-import optimistix as optx
 
 from jax import random
 from jax.experimental.sparse import BCOO
-from jax.flatten_util import ravel_pytree
 from jax.random import PRNGKey
 from jax.tree_util import tree_map
 from jax.typing import ArrayLike
 from numpyro.contrib.nested_sampling import NestedSampler
 from numpyro.distributions import Distribution, Poisson, TransformedDistribution
-from numpyro.infer import MCMC, NUTS, Predictive
-from numpyro.infer.initialization import init_to_value
+from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
 from numpyro.infer.inspect import get_model_relations
 from numpyro.infer.reparam import TransformReparam
-from numpyro.infer.util import
-from scipy.stats import Covariance, multivariate_normal
+from numpyro.infer.util import log_density
 
+from .analysis._plot import _plot_poisson_data_with_error
 from .analysis.results import FitResult
 from .data import ObsConfiguration
 from .model.abc import SpectralModel
 from .model.background import BackgroundModel
-from .util import catchtime
 from .util.typing import PriorDictModel, PriorDictType
 
 
```
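The import changes sketch the whole release: the `optimistix`/`scipy` minimization path is gone, plotting utilities arrive for the new coverage checks, and numpyro's ensemble samplers `AIES` and `ESS` now sit alongside `NUTS`. For orientation, here is a minimal, self-contained sketch of driving numpyro's `AIES` kernel on a toy model (illustrative only, not jaxspec code); ensemble kernels need several vectorized chains, which is exactly the constraint `MCMCFitter.fit` enforces further down:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import AIES, MCMC

def toy_model(obs=None):
    # One latent parameter, standing in for a spectral model parameter
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=obs)

# Ensemble samplers evolve a population of walkers, so numpyro requires
# several chains run with chain_method="vectorized"
mcmc = MCMC(
    AIES(toy_model),
    num_warmup=500,
    num_samples=500,
    num_chains=8,
    chain_method="vectorized",
)
mcmc.run(random.PRNGKey(0), obs=jnp.array(1.3))
print(mcmc.get_samples()["mu"].shape)  # chains are flattened together by default
```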
```diff
@@ -100,27 +99,6 @@ def build_numpyro_model_for_single_obs(
     return numpyro_model
 
 
-def filter_inference_data(
-    inference_data, observation_container, background_model=None
-) -> az.InferenceData:
-    predictive_parameters = []
-
-    for key, value in observation_container.items():
-        if background_model is not None:
-            predictive_parameters.append(f"obs_{key}")
-            predictive_parameters.append(f"bkg_{key}")
-        else:
-            predictive_parameters.append(f"obs_{key}")
-
-    inference_data.posterior_predictive = inference_data.posterior_predictive[predictive_parameters]
-
-    parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
-    inference_data.posterior = inference_data.posterior[parameters]
-    inference_data.prior = inference_data.prior[parameters]
-
-    return inference_data
-
-
 class CountForwardModel(hk.Module):
     """
     A haiku module which allows to build the function that simulates the measured counts
@@ -153,7 +131,8 @@ class CountForwardModel(hk.Module):
 
 class BayesianModel:
     """
-
+    Base class for a Bayesian model. This class contains the necessary methods to build a model, sample from the prior
+    and compute the log-likelihood and posterior probability.
     """
 
     def __init__(
@@ -165,6 +144,8 @@ class BayesianModel:
         sparsify_matrix: bool = False,
     ):
         """
+        Build a Bayesian model for a given spectral model and observations.
+
         Parameters:
             model: the spectral model to fit.
             prior_distributions: a nested dictionary containing the prior distributions for the model parameters, or a
@@ -181,7 +162,7 @@ class BayesianModel:
 
         if not callable(prior_distributions):
             # Validate the entry with pydantic
-            prior = PriorDictModel(
+            prior = PriorDictModel.from_dict(prior_distributions).nested_dict
 
             def prior_distributions_func():
                 return build_prior(prior, expand_shape=(len(self.observation_container),))
@@ -190,7 +171,7 @@ class BayesianModel:
             prior_distributions_func = prior_distributions
 
         self.prior_distributions_func = prior_distributions_func
-        self.init_params = self.
+        self.init_params = self.prior_samples()
 
     @cached_property
     def observation_container(self) -> dict[str, ObsConfiguration]:
```
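Replacing the bare `PriorDictModel(...)` constructor with `PriorDictModel.from_dict(prior_distributions).nested_dict` means non-callable priors are validated by pydantic before the numpyro model is built. As a hedged illustration of the nested shape this expects, with component and parameter names invented for the example rather than taken from this diff:

```python
import numpyro.distributions as dist

# Hypothetical nested prior specification: outer keys name model components,
# inner keys map parameter names to numpyro distributions.
prior_distributions = {
    "powerlaw_1": {
        "alpha": dist.Uniform(0.0, 5.0),
        "norm": dist.LogUniform(1e-6, 1e-2),
    },
    "tbabs_1": {
        "N_H": dist.Uniform(0.0, 1.0),
    },
}

# BayesianModel would validate this via PriorDictModel.from_dict(prior_distributions)
# and then expand it into per-observation priors with build_prior(...).
```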
```diff
@@ -214,9 +195,6 @@
     def numpyro_model(self) -> Callable:
         """
         Build the numpyro model using the observed data, the prior distributions and the spectral model.
-
-        Returns:
-            A model function that can be used with numpyro.
         """
 
         def model(observed=True):
@@ -256,9 +234,6 @@
     def log_likelihood_per_obs(self) -> Callable:
         """
         Build the log likelihood function for each bins in each observation.
-
-        Returns:
-            Callable log-likelihood function.
         """
 
         @jax.jit
@@ -293,6 +268,9 @@
         that can be fetched with the [`parameter_names`][jaxspec.fit.BayesianModel.parameter_names].
         """
 
+        # This is required as numpyro.infer.util.log_densities does not check parameter validity by itself
+        numpyro.enable_validation()
+
         @jax.jit
         def log_posterior_prob(constrained_params):
             log_posterior_prob, _ = log_density(
@@ -312,6 +290,16 @@
         observed_sites = relations["observed"]
         return [site for site in all_sites if site not in observed_sites]
 
+    @cached_property
+    def observation_names(self) -> list[str]:
+        """
+        List of the observations.
+        """
+        relations = get_model_relations(self.numpyro_model)
+        all_sites = relations["sample_sample"].keys()
+        observed_sites = relations["observed"]
+        return [site for site in all_sites if site in observed_sites]
+
     def array_to_dict(self, theta):
         """
         Convert an array of parameters to a dictionary of parameters.
```
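The new `observation_names` property mirrors the existing `parameter_names`: both query numpyro's model inspector and partition sample sites on membership in `relations["observed"]`. A standalone sketch of that mechanism on a toy model (illustrative, not jaxspec code):

```python
import numpyro
import numpyro.distributions as dist
from numpyro.infer.inspect import get_model_relations

def toy_model():
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=1.3)

relations = get_model_relations(toy_model)
all_sites = relations["sample_sample"].keys()  # every sample site in the model
observed_sites = relations["observed"]         # only the conditioned sites

print([s for s in all_sites if s not in observed_sites])  # latent, e.g. ['mu']
print([s for s in all_sites if s in observed_sites])      # observed, e.g. ['obs']
```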
```diff
@@ -335,7 +323,7 @@
 
         return theta
 
-    def
+    def prior_samples(self, key: PRNGKey = PRNGKey(0), num_samples: int = 100):
         """
         Get initial parameters for the model by sampling from the prior distribution
 
@@ -344,242 +332,264 @@
             num_samples: the number of samples to draw from the prior.
         """
 
-
-
-
+        @jax.jit
+        def prior_sample(key):
+            return Predictive(
+                self.numpyro_model, return_sites=self.parameter_names, num_samples=num_samples
+            )(key, observed=False)
 
+        return prior_sample(key)
 
-
-
-
+    def mock_observations(self, parameters, key: PRNGKey = PRNGKey(0)):
+        @jax.jit
+        def fakeit(key, parameters):
+            return Predictive(
+                self.numpyro_model,
+                return_sites=self.observation_names,
+                posterior_samples=parameters,
+            )(key, observed=False)
 
+        return fakeit(key, parameters)
 
-
-
-
-
-
+    def prior_predictive_coverage(
+        self, key: PRNGKey = PRNGKey(0), num_samples: int = 1000, percentiles: tuple = (16, 84)
+    ):
+        """
+        Check if the prior distribution include the observed data.
+        """
+        key_prior, key_posterior = jax.random.split(key, 2)
+        prior_params = self.prior_samples(key=key_prior, num_samples=num_samples)
+        posterior_observations = self.mock_observations(prior_params, key=key_posterior)
 
-
+        for key, value in self.observation_container.items():
+            fig, axs = plt.subplots(
+                nrows=2, ncols=1, sharex=True, figsize=(8, 8), height_ratios=[3, 1]
+            )
+
+            _plot_poisson_data_with_error(
+                axs[0],
+                value.out_energies,
+                value.folded_counts.values,
+                percentiles=percentiles,
+            )
+
+            axs[0].stairs(
+                np.max(posterior_observations["obs_" + key], axis=0),
+                edges=[*list(value.out_energies[0]), value.out_energies[1][-1]],
+                baseline=np.min(posterior_observations["obs_" + key], axis=0),
+                alpha=0.3,
+                fill=True,
+                color=(0.15, 0.25, 0.45),
+            )
+
+            # rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
+            counts = posterior_observations["obs_" + key]
+            observed = value.folded_counts.values
+
+            num_samples = counts.shape[0]
+
+            less_than_obs = (counts < observed).sum(axis=0)
+            equal_to_obs = (counts == observed).sum(axis=0)
+
+            rank = (less_than_obs + 0.5 * equal_to_obs) / num_samples * 100
+
+            axs[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])
+
+            axs[1].plot(
+                (value.out_energies.min(), value.out_energies.max()),
+                (50, 50),
+                color="black",
+                linestyle="--",
+            )
+
+            axs[1].set_xlabel("Energy (keV)")
+            axs[0].set_ylabel("Counts")
+            axs[1].set_ylabel("Rank (%)")
+            axs[1].set_ylim(0, 100)
+            axs[0].set_xlim(value.out_energies.min(), value.out_energies.max())
+            axs[0].loglog()
+            plt.suptitle(f"Prior Predictive coverage for {key}")
+            plt.show()
+
+
+class BayesianModelFitter(BayesianModel, ABC):
+    def build_inference_data(
         self,
-
-        num_chains: int =
-
-
-
-
-
-        kernel_kwargs: dict = {},
-        mcmc_kwargs: dict = {},
-    ) -> FitResult:
+        posterior_samples,
+        num_chains: int = 1,
+        num_predictive_samples: int = 1000,
+        key: PRNGKey = PRNGKey(42),
+        use_transformed_model: bool = False,
+        filter_inference_data: bool = True,
+    ) -> az.InferenceData:
         """
-
+        Build an [InferenceData][arviz.InferenceData] object from posterior samples.
 
         Parameters:
-
-            num_chains: the number of chains to
-
-
-
-
-            dense_mass: whether to use a dense mass for the NUTS sampler.
-            mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
-
-        Returns:
-            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+            posterior_samples: the samples from the posterior distribution.
+            num_chains: the number of chains used to sample the posterior.
+            num_predictive_samples: the number of samples to draw from the prior.
+            key: the random key used to initialize the sampler.
+            use_transformed_model: whether to use the transformed model to build the InferenceData.
+            filter_inference_data: whether to filter the InferenceData to keep only the relevant parameters.
         """
 
-
-
-
-        chain_kwargs = {
-            "num_warmup": num_warmup,
-            "num_samples": num_samples,
-            "num_chains": num_chains,
-        }
-
-        kernel = NUTS(
-            bayesian_model,
-            max_tree_depth=max_tree_depth,
-            target_accept_prob=target_accept_prob,
-            dense_mass=dense_mass,
-            **kernel_kwargs,
+        numpyro_model = (
+            self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
         )
 
-
-        keys = random.split(random.PRNGKey(rng_key), 3)
+        keys = random.split(key, 3)
 
-
+        posterior_predictive = Predictive(numpyro_model, posterior_samples)(keys[0], observed=False)
 
-
+        prior = Predictive(numpyro_model, num_samples=num_predictive_samples * num_chains)(
             keys[1], observed=False
         )
 
-
+        log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)
 
-
-
+        seeded_model = numpyro.handlers.substitute(
+            numpyro.handlers.seed(numpyro_model, keys[3]),
+            substitute_fn=numpyro.infer.init_to_sample,
         )
 
-
-
-
+        observations = {
+            name: site["value"]
+            for name, site in numpyro.handlers.trace(seeded_model).get_trace().items()
+            if site["type"] == "sample" and site["is_observed"]
+        }
 
-
-
-
-
-            background_model=self.background_model,
-        )
+        def reshape_first_dimension(arr):
+            new_dim = arr.shape[0] // num_chains
+            new_shape = (num_chains, new_dim) + arr.shape[1:]
+            reshaped_array = arr.reshape(new_shape)
 
+            return reshaped_array
 
-
-
-
-
-
-
+        posterior_samples = {
+            key: reshape_first_dimension(value) for key, value in posterior_samples.items()
+        }
+        prior = {key: value[None, :] for key, value in prior.items()}
+        posterior_predictive = {
+            key: reshape_first_dimension(value) for key, value in posterior_predictive.items()
+        }
+        log_likelihood = {
+            key: reshape_first_dimension(value) for key, value in log_likelihood.items()
+        }
 
-
-
-
-
-
-
-        refine_first_guess=True,
-    ) -> FitResult:
-        """
-        Fit the model to the data using L-BFGS algorithm.
+        inference_data = az.from_dict(
+            posterior_samples,
+            prior=prior,
+            posterior_predictive=posterior_predictive,
+            log_likelihood=log_likelihood,
+            observed_data=observations,
+        )
 
-
-
-
-            num_samples: the number of sample to draw from the best-fit covariance.
+        return (
+            self.filter_inference_data(inference_data) if filter_inference_data else inference_data
+        )
 
-
-
+    def filter_inference_data(
+        self,
+        inference_data: az.InferenceData,
+    ) -> az.InferenceData:
         """
+        Filter the inference data to keep only the relevant parameters for the observations.
 
-
-
+        - Removes predictive parameters from deterministic random variables (e.g. kernel of background GP)
+        - Removes parameters build from reparametrised variables (e.g. ending with `"_base"`)
+        """
 
-
-        # We initialize the parameters by randomly sampling from the prior
-        local_keys = jax.random.split(keys[0], 2)
+        predictive_parameters = []
 
-
-
+        for key, value in self.observation_container.items():
+            if self.background_model is not None:
+                predictive_parameters.append(f"obs_{key}")
+                predictive_parameters.append(f"bkg_{key}")
+            else:
+                predictive_parameters.append(f"obs_{key}")
 
-
-
-
-            starting_value[f"{m}_{n}"] = val
+        inference_data.posterior_predictive = inference_data.posterior_predictive[
+            predictive_parameters
+        ]
 
-
-
-
+        parameters = [x for x in inference_data.posterior.keys() if not x.endswith("_base")]
+        inference_data.posterior = inference_data.posterior[parameters]
+        inference_data.prior = inference_data.prior[parameters]
 
-
-        init_params, _ = numpyro.infer.util.find_valid_initial_params(keys[0], bayesian_model)
+        return inference_data
 
-
+    @abstractmethod
+    def fit(self, **kwargs) -> FitResult: ...
 
-        @jax.jit
-        def nll(unconstrained_params, _):
-            constrained_params = constrain_fn(
-                bayesian_model, tuple(), dict(observed=True), unconstrained_params
-            )
 
-
-
-
+class MCMCFitter(BayesianModelFitter):
+    """
+    A class to fit a model to a given set of observation using a Bayesian approach. This class uses samplers
+    from numpyro to perform the inference on the model parameters.
+    """
 
-
-
+    kernel_dict = {
+        "nuts": NUTS,
+        "aies": AIES,
+        "ess": ESS,
+    }
 
+    def fit(
+        self,
+        rng_key: int = 0,
+        num_chains: int = len(jax.devices()),
+        num_warmup: int = 1000,
+        num_samples: int = 1000,
+        sampler: Literal["nuts", "aies", "ess"] = "nuts",
+        use_transformed_model: bool = True,
+        kernel_kwargs: dict = {},
+        mcmc_kwargs: dict = {},
+    ) -> FitResult:
         """
-
-        with catchtime("Refine_first"):
-            solution = optx.least_squares(
-                nll,
-                optx.BestSoFarMinimiser(optx.OptaxMinimiser(optax.adam(1e-4), 1e-6, 1e-6)),
-                init_params,
-                max_steps=1000,
-                throw=False
-            )
-            init_params = solution.value
-        """
+        Fit the model to the data using a MCMC sampler from numpyro.
 
-
-
-
-
-
-
-
-
-
-            nll,
-            solver,
-            init_params,
-            max_steps=num_iter_max,
-        )
+        Parameters:
+            rng_key: the random key used to initialize the sampler.
+            num_chains: the number of chains to run.
+            num_warmup: the number of warmup steps.
+            num_samples: the number of samples to draw.
+            sampler: the sampler to use. Can be one of "nuts", "aies" or "ess".
+            use_transformed_model: whether to use the transformed model to build the InferenceData.
+            kernel_kwargs: additional arguments to pass to the kernel. See [`NUTS`][numpyro.infer.mcmc.MCMCKernel] for more details.
+            mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
 
-
-
+        Returns:
+            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
+        """
 
-
-
-
-        )(value_flat)
+        bayesian_model = (
+            self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
+        )
 
-
+        chain_kwargs = {
+            "num_warmup": num_warmup,
+            "num_samples": num_samples,
+            "num_chains": num_chains,
+        }
 
-
+        kernel = self.kernel_dict[sampler](bayesian_model, **kernel_kwargs)
 
-
-        posterior_samples = jax.jit(
-            jax.vmap(lambda p: constrain_fn(bayesian_model, tuple(), dict(observed=True), p))
-        )(samples)
+        mcmc_kwargs = chain_kwargs | mcmc_kwargs
 
-
-
-
-        )
-        prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
-        log_likelihood = numpyro.infer.log_likelihood(bayesian_model, posterior_samples)
+        if sampler in ["aies", "ess"] and mcmc_kwargs.get("chain_method", None) != "vectorized":
+            mcmc_kwargs["chain_method"] = "vectorized"
+            warnings.warn("The chain_method is set to 'vectorized' for AIES and ESS samplers")
 
-
-
-            reshape the samples so that it is arviz compliant with an extra starting dimension
-            """
-            return tree_map(lambda x: x[None, ...], chain)
+        mcmc = MCMC(kernel, **mcmc_kwargs)
+        keys = random.split(random.PRNGKey(rng_key), 3)
 
-
-        seeded_model = numpyro.handlers.substitute(
-            numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
-            substitute_fn=numpyro.infer.init_to_sample,
-        )
-        trace = numpyro.handlers.trace(seeded_model).get_trace()
-        observations = {
-            name: site["value"]
-            for name, site in trace.items()
-            if site["type"] == "sample" and site["is_observed"]
-        }
+        mcmc.run(keys[0])
 
-
-        inference_data = az.from_dict(
-            sanitize_chain(posterior_samples),
-            prior=sanitize_chain(prior),
-            posterior_predictive=sanitize_chain(posterior_predictive),
-            log_likelihood=sanitize_chain(log_likelihood),
-            observed_data=observations,
-        )
+        posterior = mcmc.get_samples()
 
-        inference_data =
-
+        inference_data = self.build_inference_data(
+            posterior, num_chains=num_chains, use_transformed_model=True
         )
 
         return FitResult(
```
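The heart of the new `prior_predictive_coverage` is a rank statistic: each observed bin is ranked among the prior predictive draws, counting ties at half weight so that the discreteness of Poisson counts does not bias the rank. A self-contained sketch with fake data (shapes and rates are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
counts = rng.poisson(lam=5.0, size=(1000, 30))  # fake prior-predictive draws, (num_samples, num_bins)
observed = rng.poisson(lam=5.0, size=30)        # fake observed spectrum, (num_bins,)

less_than_obs = (counts < observed).sum(axis=0)
equal_to_obs = (counts == observed).sum(axis=0)

# Mid-rank in percent, exactly as computed in prior_predictive_coverage
rank = (less_than_obs + 0.5 * equal_to_obs) / counts.shape[0] * 100

# If the priors cover the data, these ranks scatter uniformly in [0, 100],
# hence the dashed 50% reference line in the lower panel of the plot.
print(rank.min(), rank.max())
```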
```diff
@@ -590,18 +600,22 @@ class MinimizationFitter(BayesianModelFitter):
         )
 
 
-class
+class NSFitter(BayesianModelFitter):
     r"""
     A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
     [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
     implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
-
+
+    !!! info
+        Ensure large prior volume is covered by the prior distributions to ensure the algorithm yield proper results.
+
     """
 
     def fit(
         self,
         rng_key: int = 0,
         num_samples: int = 1000,
+        num_live_points: int = 1000,
         plot_diagnostics=False,
         termination_kwargs: dict | None = None,
         verbose=True,
@@ -612,6 +626,10 @@ class NestedSamplingFitter(BayesianModelFitter):
         Parameters:
             rng_key: the random key used to initialize the sampler.
             num_samples: the number of samples to draw.
+            num_live_points: the number of live points to use at the start of the NS algorithm.
+            plot_diagnostics: whether to plot the diagnostics of the NS algorithm.
+            termination_kwargs: additional arguments to pass to the termination criterion of the NS algorithm.
+            verbose: whether to print the progress of the NS algorithm.
 
         Returns:
             A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
@@ -628,7 +646,7 @@ class NestedSamplingFitter(BayesianModelFitter):
                 difficult_model=True,
                 max_samples=1e6,
                 parameter_estimation=True,
-                num_live_points=
+                num_live_points=num_live_points,
             ),
             termination_kwargs=termination_kwargs if termination_kwargs else dict(),
         )
@@ -638,41 +656,9 @@ class NestedSamplingFitter(BayesianModelFitter):
         if plot_diagnostics:
             ns.diagnostics()
 
-
-
-
-            keys[2], observed=False
-        )
-
-        prior = Predictive(bayesian_model, num_samples=num_samples)(keys[3], observed=False)
-
-        seeded_model = numpyro.handlers.substitute(
-            numpyro.handlers.seed(bayesian_model, jax.random.PRNGKey(0)),
-            substitute_fn=numpyro.infer.init_to_sample,
-        )
-        trace = numpyro.handlers.trace(seeded_model).get_trace()
-        observations = {
-            name: site["value"]
-            for name, site in trace.items()
-            if site["type"] == "sample" and site["is_observed"]
-        }
-
-        def sanitize_chain(chain):
-            """
-            reshape the samples so that it is arviz compliant with an extra starting dimension
-            """
-            return tree_map(lambda x: x[None, ...], chain)
-
-        inference_data = az.from_dict(
-            sanitize_chain(posterior_samples),
-            prior=sanitize_chain(prior),
-            posterior_predictive=sanitize_chain(posterior_predictive),
-            log_likelihood=sanitize_chain(log_likelihood),
-            observed_data=observations,
-        )
-
-        inference_data = filter_inference_data(
-            inference_data, self.observation_container, self.background_model
+        posterior = ns.get_samples(keys[1], num_samples=num_samples)
+        inference_data = self.build_inference_data(
+            posterior, num_chains=1, use_transformed_model=True
         )
 
         return FitResult(
```