jaxspec 0.2.2.dev0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/_plot.py +5 -5
- jaxspec/analysis/results.py +41 -26
- jaxspec/data/obsconf.py +9 -3
- jaxspec/data/observation.py +3 -1
- jaxspec/data/ogip.py +9 -2
- jaxspec/data/util.py +17 -11
- jaxspec/experimental/interpolator.py +74 -0
- jaxspec/experimental/interpolator_jax.py +79 -0
- jaxspec/experimental/intrument_models.py +159 -0
- jaxspec/experimental/nested_sampler.py +78 -0
- jaxspec/experimental/tabulated.py +264 -0
- jaxspec/fit/__init__.py +3 -0
- jaxspec/{fit.py → fit/_bayesian_model.py} +84 -336
- jaxspec/{_fit → fit}/_build_model.py +42 -6
- jaxspec/fit/_fitter.py +255 -0
- jaxspec/model/abc.py +52 -80
- jaxspec/model/additive.py +14 -5
- jaxspec/model/background.py +17 -14
- jaxspec/model/instrument.py +81 -0
- jaxspec/model/list.py +4 -1
- jaxspec/model/multiplicative.py +32 -12
- jaxspec/util/integrate.py +17 -5
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.1.dist-info}/METADATA +11 -11
- jaxspec-0.3.1.dist-info/RECORD +42 -0
- jaxspec-0.2.2.dev0.dist-info/RECORD +0 -34
- /jaxspec/{_fit → experimental}/__init__.py +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.1.dist-info}/WHEEL +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.1.dist-info}/entry_points.txt +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -1,54 +1,53 @@
|
|
|
1
1
|
import operator
|
|
2
|
-
import warnings
|
|
3
2
|
|
|
4
|
-
from abc import ABC, abstractmethod
|
|
5
3
|
from collections.abc import Callable
|
|
6
4
|
from functools import cached_property
|
|
7
|
-
from typing import
|
|
5
|
+
from typing import Any
|
|
8
6
|
|
|
9
|
-
import arviz as az
|
|
10
7
|
import jax
|
|
11
8
|
import jax.numpy as jnp
|
|
12
9
|
import matplotlib.pyplot as plt
|
|
13
10
|
import numpyro
|
|
14
11
|
|
|
15
|
-
from
|
|
16
|
-
from jax.experimental import mesh_utils
|
|
12
|
+
from flax import nnx
|
|
17
13
|
from jax.random import PRNGKey
|
|
18
|
-
from jax.sharding import
|
|
19
|
-
from numpyro.contrib.nested_sampling import NestedSampler
|
|
14
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
20
15
|
from numpyro.distributions import Poisson, TransformedDistribution
|
|
21
|
-
from numpyro.infer import
|
|
16
|
+
from numpyro.infer import Predictive
|
|
22
17
|
from numpyro.infer.inspect import get_model_relations
|
|
23
18
|
from numpyro.infer.reparam import TransformReparam
|
|
24
19
|
from numpyro.infer.util import log_density
|
|
25
20
|
|
|
26
|
-
from .
|
|
27
|
-
from .analysis._plot import (
|
|
21
|
+
from ..analysis._plot import (
|
|
28
22
|
_error_bars_for_observed_data,
|
|
29
23
|
_plot_binned_samples_with_error,
|
|
30
24
|
_plot_poisson_data_with_error,
|
|
31
25
|
)
|
|
32
|
-
from
|
|
33
|
-
from .
|
|
34
|
-
from
|
|
35
|
-
from
|
|
36
|
-
from
|
|
26
|
+
from ..data import ObsConfiguration
|
|
27
|
+
from ..model.abc import SpectralModel
|
|
28
|
+
from ..model.background import BackgroundModel
|
|
29
|
+
from ..model.instrument import InstrumentModel
|
|
30
|
+
from ..util.typing import PriorDictType
|
|
31
|
+
from ._build_model import build_prior, forward_model
|
|
37
32
|
|
|
38
33
|
|
|
39
|
-
class BayesianModel:
|
|
34
|
+
class BayesianModel(nnx.Module):
|
|
40
35
|
"""
|
|
41
36
|
Base class for a Bayesian model. This class contains the necessary methods to build a model, sample from the prior
|
|
42
37
|
and compute the log-likelihood and posterior probability.
|
|
43
38
|
"""
|
|
44
39
|
|
|
40
|
+
settings: dict[str, Any]
|
|
41
|
+
|
|
45
42
|
def __init__(
|
|
46
43
|
self,
|
|
47
44
|
model: SpectralModel,
|
|
48
45
|
prior_distributions: PriorDictType | Callable,
|
|
49
46
|
observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
|
|
50
47
|
background_model: BackgroundModel = None,
|
|
48
|
+
instrument_model: InstrumentModel = None,
|
|
51
49
|
sparsify_matrix: bool = False,
|
|
50
|
+
n_points: int = 2,
|
|
52
51
|
):
|
|
53
52
|
"""
|
|
54
53
|
Build a Bayesian model for a given spectral model and observations.
|
|
@@ -59,74 +58,36 @@ class BayesianModel:
|
|
|
59
58
|
callable function that returns parameter samples.
|
|
60
59
|
observations: the observations to fit the model to.
|
|
61
60
|
background_model: the background model to fit.
|
|
61
|
+
instrument_model: the instrument model to fit.
|
|
62
62
|
sparsify_matrix: whether to sparsify the transfer matrix.
|
|
63
63
|
"""
|
|
64
|
-
|
|
64
|
+
|
|
65
|
+
self.spectral_model = model
|
|
65
66
|
self._observations = observations
|
|
66
67
|
self.background_model = background_model
|
|
67
|
-
self.
|
|
68
|
+
self.instrument_model = instrument_model
|
|
69
|
+
self.settings = {"sparse": sparsify_matrix}
|
|
68
70
|
|
|
69
71
|
if not callable(prior_distributions):
|
|
70
|
-
# Validate the entry with pydantic
|
|
71
|
-
# prior = PriorDictModel.from_dict(prior_distributions).
|
|
72
72
|
|
|
73
73
|
def prior_distributions_func():
|
|
74
74
|
return build_prior(
|
|
75
|
-
prior_distributions,
|
|
75
|
+
prior_distributions,
|
|
76
|
+
expand_shape=(len(self._observation_container),),
|
|
77
|
+
prefix="mod/~/",
|
|
76
78
|
)
|
|
77
79
|
|
|
78
80
|
else:
|
|
79
81
|
prior_distributions_func = prior_distributions
|
|
80
82
|
|
|
81
83
|
self.prior_distributions_func = prior_distributions_func
|
|
82
|
-
self.init_params = self.prior_samples()
|
|
83
|
-
|
|
84
|
-
# Check the priors are suited for the observations
|
|
85
|
-
split_parameters = [
|
|
86
|
-
(param, shape[-1])
|
|
87
|
-
for param, shape in jax.tree.map(lambda x: x.shape, self.init_params).items()
|
|
88
|
-
if (len(shape) > 1)
|
|
89
|
-
and not param.startswith("_")
|
|
90
|
-
and not param.startswith("bkg") # hardcoded for subtracted background
|
|
91
|
-
]
|
|
92
|
-
|
|
93
|
-
for parameter, proposed_number_of_obs in split_parameters:
|
|
94
|
-
if proposed_number_of_obs != len(self.observation_container):
|
|
95
|
-
raise ValueError(
|
|
96
|
-
f"Invalid splitting in the prior distribution. "
|
|
97
|
-
f"Expected {len(self.observation_container)} but got {proposed_number_of_obs} for {parameter}"
|
|
98
|
-
)
|
|
99
|
-
|
|
100
|
-
@cached_property
|
|
101
|
-
def observation_container(self) -> dict[str, ObsConfiguration]:
|
|
102
|
-
"""
|
|
103
|
-
The observations used in the fit as a dictionary of observations.
|
|
104
|
-
"""
|
|
105
|
-
|
|
106
|
-
if isinstance(self._observations, dict):
|
|
107
|
-
return self._observations
|
|
108
|
-
|
|
109
|
-
elif isinstance(self._observations, list):
|
|
110
|
-
return {f"data_{i}": obs for i, obs in enumerate(self._observations)}
|
|
111
|
-
|
|
112
|
-
elif isinstance(self._observations, ObsConfiguration):
|
|
113
|
-
return {"data": self._observations}
|
|
114
|
-
|
|
115
|
-
else:
|
|
116
|
-
raise ValueError(f"Invalid type for observations : {type(self._observations)}")
|
|
117
|
-
|
|
118
|
-
@cached_property
|
|
119
|
-
def numpyro_model(self) -> Callable:
|
|
120
|
-
"""
|
|
121
|
-
Build the numpyro model using the observed data, the prior distributions and the spectral model.
|
|
122
|
-
"""
|
|
123
84
|
|
|
124
85
|
def numpyro_model(observed=True):
|
|
125
86
|
# Instantiate and register the parameters of the spectral model and the background
|
|
126
87
|
prior_params = self.prior_distributions_func()
|
|
127
88
|
|
|
128
89
|
# Iterate over all the observations in our container and build a single numpyro model for each observation
|
|
129
|
-
for i, (name, observation) in enumerate(self.
|
|
90
|
+
for i, (name, observation) in enumerate(self._observation_container.items()):
|
|
130
91
|
# Check that we can indeed fit a background
|
|
131
92
|
if (getattr(observation, "folded_background", None) is not None) and (
|
|
132
93
|
self.background_model is not None
|
|
@@ -150,23 +111,70 @@ class BayesianModel:
|
|
|
150
111
|
# They can be identical or different for each observation
|
|
151
112
|
params = jax.tree.map(lambda x: x[i], prior_params)
|
|
152
113
|
|
|
114
|
+
if self.instrument_model is not None:
|
|
115
|
+
gain, shift = self.instrument_model.get_gain_and_shift_model(name)
|
|
116
|
+
else:
|
|
117
|
+
gain, shift = None, None
|
|
118
|
+
|
|
153
119
|
# Forward model the observation and get the associated countrate
|
|
154
120
|
obs_model = jax.jit(
|
|
155
|
-
lambda par: forward_model(
|
|
121
|
+
lambda par: forward_model(
|
|
122
|
+
self.spectral_model,
|
|
123
|
+
par,
|
|
124
|
+
observation,
|
|
125
|
+
sparse=self.settings.get("sparse", False),
|
|
126
|
+
gain=gain,
|
|
127
|
+
shift=shift,
|
|
128
|
+
n_points=n_points,
|
|
129
|
+
)
|
|
156
130
|
)
|
|
131
|
+
|
|
157
132
|
obs_countrate = obs_model(params)
|
|
158
133
|
|
|
159
134
|
# Register the observation as an observed site
|
|
160
|
-
with numpyro.plate("
|
|
135
|
+
with numpyro.plate("obs_plate/~/" + name, len(observation.folded_counts)):
|
|
161
136
|
numpyro.sample(
|
|
162
|
-
"
|
|
163
|
-
Poisson(
|
|
164
|
-
obs_countrate + bkg_countrate
|
|
165
|
-
), # / observation.folded_backratio.data
|
|
137
|
+
"obs/~/" + name,
|
|
138
|
+
Poisson(obs_countrate + bkg_countrate / observation.folded_backratio.data),
|
|
166
139
|
obs=observation.folded_counts.data if observed else None,
|
|
167
140
|
)
|
|
168
141
|
|
|
169
|
-
|
|
142
|
+
self.numpyro_model = numpyro_model
|
|
143
|
+
self._init_params = self.prior_samples()
|
|
144
|
+
# Check the priors are suited for the observations
|
|
145
|
+
split_parameters = [
|
|
146
|
+
(param, shape[-1])
|
|
147
|
+
for param, shape in jax.tree.map(lambda x: x.shape, self._init_params).items()
|
|
148
|
+
if (len(shape) > 1)
|
|
149
|
+
and not param.startswith("_")
|
|
150
|
+
and not param.startswith("bkg") # hardcoded for subtracted background
|
|
151
|
+
and not param.startswith("ins")
|
|
152
|
+
]
|
|
153
|
+
|
|
154
|
+
for parameter, proposed_number_of_obs in split_parameters:
|
|
155
|
+
if proposed_number_of_obs != len(self._observation_container):
|
|
156
|
+
raise ValueError(
|
|
157
|
+
f"Invalid splitting in the prior distribution. "
|
|
158
|
+
f"Expected {len(self._observation_container)} but got {proposed_number_of_obs} for {parameter}"
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
@cached_property
|
|
162
|
+
def _observation_container(self) -> dict[str, ObsConfiguration]:
|
|
163
|
+
"""
|
|
164
|
+
The observations used in the fit as a dictionary of observations.
|
|
165
|
+
"""
|
|
166
|
+
|
|
167
|
+
if isinstance(self._observations, dict):
|
|
168
|
+
return self._observations
|
|
169
|
+
|
|
170
|
+
elif isinstance(self._observations, list):
|
|
171
|
+
return {f"data_{i}": obs for i, obs in enumerate(self._observations)}
|
|
172
|
+
|
|
173
|
+
elif isinstance(self._observations, ObsConfiguration):
|
|
174
|
+
return {"data": self._observations}
|
|
175
|
+
|
|
176
|
+
else:
|
|
177
|
+
raise ValueError(f"Invalid type for observations : {type(self._observations)}")
|
|
170
178
|
|
|
171
179
|
@cached_property
|
|
172
180
|
def transformed_numpyro_model(self) -> Callable:
|
|
@@ -315,7 +323,8 @@ class BayesianModel:
|
|
|
315
323
|
"""
|
|
316
324
|
key_prior, key_posterior = jax.random.split(key, 2)
|
|
317
325
|
n_devices = len(jax.local_devices())
|
|
318
|
-
|
|
326
|
+
mesh = jax.make_mesh((n_devices,), ("batch",))
|
|
327
|
+
sharding = NamedSharding(mesh, PartitionSpec("batch"))
|
|
319
328
|
|
|
320
329
|
# Sample from prior and correct if the number of samples is not a multiple of the number of devices
|
|
321
330
|
if num_samples % n_devices != 0:
|
|
@@ -327,7 +336,7 @@ class BayesianModel:
|
|
|
327
336
|
sharded_parameters = jax.device_put(prior_params, sharding)
|
|
328
337
|
posterior_observations = self.mock_observations(sharded_parameters, key=key_posterior)
|
|
329
338
|
|
|
330
|
-
for key, value in self.
|
|
339
|
+
for key, value in self._observation_container.items():
|
|
331
340
|
fig, ax = plt.subplots(
|
|
332
341
|
nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
|
|
333
342
|
)
|
|
@@ -349,7 +358,7 @@ class BayesianModel:
|
|
|
349
358
|
)
|
|
350
359
|
|
|
351
360
|
prior_plot = _plot_binned_samples_with_error(
|
|
352
|
-
ax[0], value.out_energies, posterior_observations["
|
|
361
|
+
ax[0], value.out_energies, posterior_observations["obs/~/" + key], n_sigmas=3
|
|
353
362
|
)
|
|
354
363
|
|
|
355
364
|
legend_plots.append((true_data_plot,))
|
|
@@ -358,7 +367,7 @@ class BayesianModel:
|
|
|
358
367
|
legend_labels.append("Prior Predictive")
|
|
359
368
|
|
|
360
369
|
# rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
|
|
361
|
-
counts = posterior_observations["
|
|
370
|
+
counts = posterior_observations["obs/~/" + key]
|
|
362
371
|
observed = value.folded_counts.values
|
|
363
372
|
|
|
364
373
|
num_samples = counts.shape[0]
|
|
@@ -387,264 +396,3 @@ class BayesianModel:
|
|
|
387
396
|
plt.suptitle(f"Prior Predictive coverage for {key}")
|
|
388
397
|
plt.tight_layout()
|
|
389
398
|
plt.show()
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
class BayesianModelFitter(BayesianModel, ABC):
|
|
393
|
-
def build_inference_data(
|
|
394
|
-
self,
|
|
395
|
-
posterior_samples,
|
|
396
|
-
num_chains: int = 1,
|
|
397
|
-
num_predictive_samples: int = 1000,
|
|
398
|
-
key: PRNGKey = PRNGKey(42),
|
|
399
|
-
use_transformed_model: bool = False,
|
|
400
|
-
filter_inference_data: bool = True,
|
|
401
|
-
) -> az.InferenceData:
|
|
402
|
-
"""
|
|
403
|
-
Build an [InferenceData][arviz.InferenceData] object from posterior samples.
|
|
404
|
-
|
|
405
|
-
Parameters:
|
|
406
|
-
posterior_samples: the samples from the posterior distribution.
|
|
407
|
-
num_chains: the number of chains used to sample the posterior.
|
|
408
|
-
num_predictive_samples: the number of samples to draw from the prior.
|
|
409
|
-
key: the random key used to initialize the sampler.
|
|
410
|
-
use_transformed_model: whether to use the transformed model to build the InferenceData.
|
|
411
|
-
filter_inference_data: whether to filter the InferenceData to keep only the relevant parameters.
|
|
412
|
-
"""
|
|
413
|
-
|
|
414
|
-
numpyro_model = (
|
|
415
|
-
self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
|
|
416
|
-
)
|
|
417
|
-
|
|
418
|
-
keys = random.split(key, 3)
|
|
419
|
-
|
|
420
|
-
posterior_predictive = Predictive(numpyro_model, posterior_samples)(keys[0], observed=False)
|
|
421
|
-
|
|
422
|
-
prior = Predictive(numpyro_model, num_samples=num_predictive_samples * num_chains)(
|
|
423
|
-
keys[1], observed=False
|
|
424
|
-
)
|
|
425
|
-
|
|
426
|
-
log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)
|
|
427
|
-
|
|
428
|
-
seeded_model = numpyro.handlers.substitute(
|
|
429
|
-
numpyro.handlers.seed(numpyro_model, keys[3]),
|
|
430
|
-
substitute_fn=numpyro.infer.init_to_sample,
|
|
431
|
-
)
|
|
432
|
-
|
|
433
|
-
observations = {
|
|
434
|
-
name: site["value"]
|
|
435
|
-
for name, site in numpyro.handlers.trace(seeded_model).get_trace().items()
|
|
436
|
-
if site["type"] == "sample" and site["is_observed"]
|
|
437
|
-
}
|
|
438
|
-
|
|
439
|
-
def reshape_first_dimension(arr):
|
|
440
|
-
new_dim = arr.shape[0] // num_chains
|
|
441
|
-
new_shape = (num_chains, new_dim) + arr.shape[1:]
|
|
442
|
-
reshaped_array = arr.reshape(new_shape)
|
|
443
|
-
|
|
444
|
-
return reshaped_array
|
|
445
|
-
|
|
446
|
-
posterior_samples = {
|
|
447
|
-
key: reshape_first_dimension(value) for key, value in posterior_samples.items()
|
|
448
|
-
}
|
|
449
|
-
prior = {key: value[None, :] for key, value in prior.items()}
|
|
450
|
-
posterior_predictive = {
|
|
451
|
-
key: reshape_first_dimension(value) for key, value in posterior_predictive.items()
|
|
452
|
-
}
|
|
453
|
-
log_likelihood = {
|
|
454
|
-
key: reshape_first_dimension(value) for key, value in log_likelihood.items()
|
|
455
|
-
}
|
|
456
|
-
|
|
457
|
-
inference_data = az.from_dict(
|
|
458
|
-
posterior_samples,
|
|
459
|
-
prior=prior,
|
|
460
|
-
posterior_predictive=posterior_predictive,
|
|
461
|
-
log_likelihood=log_likelihood,
|
|
462
|
-
observed_data=observations,
|
|
463
|
-
)
|
|
464
|
-
|
|
465
|
-
return (
|
|
466
|
-
self.filter_inference_data(inference_data) if filter_inference_data else inference_data
|
|
467
|
-
)
|
|
468
|
-
|
|
469
|
-
def filter_inference_data(
|
|
470
|
-
self,
|
|
471
|
-
inference_data: az.InferenceData,
|
|
472
|
-
) -> az.InferenceData:
|
|
473
|
-
"""
|
|
474
|
-
Filter the inference data to keep only the relevant parameters for the observations.
|
|
475
|
-
|
|
476
|
-
- Removes predictive parameters from deterministic random variables (e.g. kernel of background GP)
|
|
477
|
-
- Removes parameters build from reparametrised variables (e.g. ending with `"_base"`)
|
|
478
|
-
"""
|
|
479
|
-
|
|
480
|
-
predictive_parameters = []
|
|
481
|
-
|
|
482
|
-
for key, value in self.observation_container.items():
|
|
483
|
-
if self.background_model is not None:
|
|
484
|
-
predictive_parameters.append(f"obs_{key}")
|
|
485
|
-
predictive_parameters.append(f"bkg_{key}")
|
|
486
|
-
else:
|
|
487
|
-
predictive_parameters.append(f"obs_{key}")
|
|
488
|
-
|
|
489
|
-
inference_data.posterior_predictive = inference_data.posterior_predictive[
|
|
490
|
-
predictive_parameters
|
|
491
|
-
]
|
|
492
|
-
|
|
493
|
-
parameters = [
|
|
494
|
-
x
|
|
495
|
-
for x in inference_data.posterior.keys()
|
|
496
|
-
if not x.endswith("_base") or x.startswith("_")
|
|
497
|
-
]
|
|
498
|
-
inference_data.posterior = inference_data.posterior[parameters]
|
|
499
|
-
inference_data.prior = inference_data.prior[parameters]
|
|
500
|
-
|
|
501
|
-
return inference_data
|
|
502
|
-
|
|
503
|
-
@abstractmethod
|
|
504
|
-
def fit(self, **kwargs) -> FitResult: ...
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
class MCMCFitter(BayesianModelFitter):
|
|
508
|
-
"""
|
|
509
|
-
A class to fit a model to a given set of observation using a Bayesian approach. This class uses samplers
|
|
510
|
-
from numpyro to perform the inference on the model parameters.
|
|
511
|
-
"""
|
|
512
|
-
|
|
513
|
-
kernel_dict = {
|
|
514
|
-
"nuts": NUTS,
|
|
515
|
-
"aies": AIES,
|
|
516
|
-
"ess": ESS,
|
|
517
|
-
}
|
|
518
|
-
|
|
519
|
-
def fit(
|
|
520
|
-
self,
|
|
521
|
-
rng_key: int = 0,
|
|
522
|
-
num_chains: int = len(jax.devices()),
|
|
523
|
-
num_warmup: int = 1000,
|
|
524
|
-
num_samples: int = 1000,
|
|
525
|
-
sampler: Literal["nuts", "aies", "ess"] = "nuts",
|
|
526
|
-
use_transformed_model: bool = True,
|
|
527
|
-
kernel_kwargs: dict = {},
|
|
528
|
-
mcmc_kwargs: dict = {},
|
|
529
|
-
) -> FitResult:
|
|
530
|
-
"""
|
|
531
|
-
Fit the model to the data using a MCMC sampler from numpyro.
|
|
532
|
-
|
|
533
|
-
Parameters:
|
|
534
|
-
rng_key: the random key used to initialize the sampler.
|
|
535
|
-
num_chains: the number of chains to run.
|
|
536
|
-
num_warmup: the number of warmup steps.
|
|
537
|
-
num_samples: the number of samples to draw.
|
|
538
|
-
sampler: the sampler to use. Can be one of "nuts", "aies" or "ess".
|
|
539
|
-
use_transformed_model: whether to use the transformed model to build the InferenceData.
|
|
540
|
-
kernel_kwargs: additional arguments to pass to the kernel. See [`NUTS`][numpyro.infer.mcmc.MCMCKernel] for more details.
|
|
541
|
-
mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
|
|
542
|
-
|
|
543
|
-
Returns:
|
|
544
|
-
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
545
|
-
"""
|
|
546
|
-
|
|
547
|
-
bayesian_model = (
|
|
548
|
-
self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
|
|
549
|
-
)
|
|
550
|
-
|
|
551
|
-
chain_kwargs = {
|
|
552
|
-
"num_warmup": num_warmup,
|
|
553
|
-
"num_samples": num_samples,
|
|
554
|
-
"num_chains": num_chains,
|
|
555
|
-
}
|
|
556
|
-
|
|
557
|
-
kernel = self.kernel_dict[sampler](bayesian_model, **kernel_kwargs)
|
|
558
|
-
|
|
559
|
-
mcmc_kwargs = chain_kwargs | mcmc_kwargs
|
|
560
|
-
|
|
561
|
-
if sampler in ["aies", "ess"] and mcmc_kwargs.get("chain_method", None) != "vectorized":
|
|
562
|
-
mcmc_kwargs["chain_method"] = "vectorized"
|
|
563
|
-
warnings.warn("The chain_method is set to 'vectorized' for AIES and ESS samplers")
|
|
564
|
-
|
|
565
|
-
mcmc = MCMC(kernel, **mcmc_kwargs)
|
|
566
|
-
keys = random.split(random.PRNGKey(rng_key), 3)
|
|
567
|
-
|
|
568
|
-
mcmc.run(keys[0])
|
|
569
|
-
|
|
570
|
-
posterior = mcmc.get_samples()
|
|
571
|
-
|
|
572
|
-
inference_data = self.build_inference_data(
|
|
573
|
-
posterior, num_chains=num_chains, use_transformed_model=True
|
|
574
|
-
)
|
|
575
|
-
|
|
576
|
-
return FitResult(
|
|
577
|
-
self,
|
|
578
|
-
inference_data,
|
|
579
|
-
background_model=self.background_model,
|
|
580
|
-
)
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
class NSFitter(BayesianModelFitter):
|
|
584
|
-
r"""
|
|
585
|
-
A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
|
|
586
|
-
[`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
|
|
587
|
-
implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
|
|
588
|
-
|
|
589
|
-
!!! info
|
|
590
|
-
Ensure large prior volume is covered by the prior distributions to ensure the algorithm yield proper results.
|
|
591
|
-
|
|
592
|
-
"""
|
|
593
|
-
|
|
594
|
-
def fit(
|
|
595
|
-
self,
|
|
596
|
-
rng_key: int = 0,
|
|
597
|
-
num_samples: int = 1000,
|
|
598
|
-
num_live_points: int = 1000,
|
|
599
|
-
plot_diagnostics=False,
|
|
600
|
-
termination_kwargs: dict | None = None,
|
|
601
|
-
verbose=True,
|
|
602
|
-
) -> FitResult:
|
|
603
|
-
"""
|
|
604
|
-
Fit the model to the data using the Phantom-Powered nested sampling algorithm.
|
|
605
|
-
|
|
606
|
-
Parameters:
|
|
607
|
-
rng_key: the random key used to initialize the sampler.
|
|
608
|
-
num_samples: the number of samples to draw.
|
|
609
|
-
num_live_points: the number of live points to use at the start of the NS algorithm.
|
|
610
|
-
plot_diagnostics: whether to plot the diagnostics of the NS algorithm.
|
|
611
|
-
termination_kwargs: additional arguments to pass to the termination criterion of the NS algorithm.
|
|
612
|
-
verbose: whether to print the progress of the NS algorithm.
|
|
613
|
-
|
|
614
|
-
Returns:
|
|
615
|
-
A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
|
|
616
|
-
"""
|
|
617
|
-
|
|
618
|
-
bayesian_model = self.transformed_numpyro_model
|
|
619
|
-
keys = random.split(random.PRNGKey(rng_key), 4)
|
|
620
|
-
|
|
621
|
-
ns = NestedSampler(
|
|
622
|
-
bayesian_model,
|
|
623
|
-
constructor_kwargs=dict(
|
|
624
|
-
verbose=verbose,
|
|
625
|
-
difficult_model=True,
|
|
626
|
-
max_samples=1e5,
|
|
627
|
-
parameter_estimation=True,
|
|
628
|
-
gradient_guided=True,
|
|
629
|
-
devices=jax.devices(),
|
|
630
|
-
# init_efficiency_threshold=0.01,
|
|
631
|
-
num_live_points=num_live_points,
|
|
632
|
-
),
|
|
633
|
-
termination_kwargs=termination_kwargs if termination_kwargs else dict(),
|
|
634
|
-
)
|
|
635
|
-
|
|
636
|
-
ns.run(keys[0])
|
|
637
|
-
|
|
638
|
-
if plot_diagnostics:
|
|
639
|
-
ns.diagnostics()
|
|
640
|
-
|
|
641
|
-
posterior = ns.get_samples(keys[1], num_samples=num_samples)
|
|
642
|
-
inference_data = self.build_inference_data(
|
|
643
|
-
posterior, num_chains=1, use_transformed_model=True
|
|
644
|
-
)
|
|
645
|
-
|
|
646
|
-
return FitResult(
|
|
647
|
-
self,
|
|
648
|
-
inference_data,
|
|
649
|
-
background_model=self.background_model,
|
|
650
|
-
)
|
|
@@ -1,5 +1,7 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
1
2
|
from typing import TYPE_CHECKING
|
|
2
3
|
|
|
4
|
+
import jax
|
|
3
5
|
import jax.numpy as jnp
|
|
4
6
|
import numpy as np
|
|
5
7
|
import numpyro
|
|
@@ -9,9 +11,15 @@ from jax.typing import ArrayLike
|
|
|
9
11
|
from numpyro.distributions import Distribution
|
|
10
12
|
|
|
11
13
|
if TYPE_CHECKING:
|
|
12
|
-
from
|
|
13
|
-
from
|
|
14
|
-
from
|
|
14
|
+
from jaxspec.data import ObsConfiguration
|
|
15
|
+
from jaxspec.model.abc import SpectralModel
|
|
16
|
+
from jaxspec.util.typing import PriorDictType
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TiedParameter:
|
|
20
|
+
def __init__(self, tied_to, func):
|
|
21
|
+
self.tied_to = tied_to
|
|
22
|
+
self.func = func
|
|
15
23
|
|
|
16
24
|
|
|
17
25
|
def forward_model(
|
|
@@ -19,6 +27,10 @@ def forward_model(
|
|
|
19
27
|
parameters,
|
|
20
28
|
obs_configuration: "ObsConfiguration",
|
|
21
29
|
sparse=False,
|
|
30
|
+
gain: Callable | None = None,
|
|
31
|
+
shift: Callable | None = None,
|
|
32
|
+
split_branches: bool = False,
|
|
33
|
+
n_points: int | None = 2,
|
|
22
34
|
):
|
|
23
35
|
energies = np.asarray(obs_configuration.in_energies)
|
|
24
36
|
|
|
@@ -31,10 +43,24 @@ def forward_model(
|
|
|
31
43
|
else:
|
|
32
44
|
transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
|
|
33
45
|
|
|
34
|
-
|
|
46
|
+
energies = shift(energies) if shift is not None else energies
|
|
47
|
+
energies = jnp.clip(energies, min=1e-6) # Ensure shifted energies remain positive
|
|
48
|
+
factor = gain(energies) if gain is not None else 1.0
|
|
49
|
+
factor = jnp.clip(factor, min=0.0) # Ensure the gain is positive to avoid NaNs
|
|
35
50
|
|
|
36
|
-
|
|
37
|
-
|
|
51
|
+
if not split_branches:
|
|
52
|
+
expected_counts = transfer_matrix @ (
|
|
53
|
+
model.photon_flux(parameters, *energies, n_points=n_points) * factor
|
|
54
|
+
)
|
|
55
|
+
return jnp.clip(expected_counts, min=1e-6) # Ensure the expected counts are positive
|
|
56
|
+
|
|
57
|
+
else:
|
|
58
|
+
model_flux = model.photon_flux(
|
|
59
|
+
parameters, *energies, split_branches=True, n_points=n_points
|
|
60
|
+
)
|
|
61
|
+
return jax.tree.map(
|
|
62
|
+
lambda f: jnp.clip(transfer_matrix @ (f * factor), min=1e-6), model_flux
|
|
63
|
+
)
|
|
38
64
|
|
|
39
65
|
|
|
40
66
|
def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
|
|
@@ -43,15 +69,20 @@ def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
|
|
|
43
69
|
Must be used within a numpyro model.
|
|
44
70
|
"""
|
|
45
71
|
parameters = {}
|
|
72
|
+
params_to_tie = {}
|
|
46
73
|
|
|
47
74
|
for key, value in prior.items():
|
|
48
75
|
# Split the key to extract the module name and parameter name
|
|
49
76
|
module_name, param_name = key.rsplit("_", 1)
|
|
77
|
+
|
|
50
78
|
if isinstance(value, Distribution):
|
|
51
79
|
parameters[key] = jnp.ones(expand_shape) * numpyro.sample(
|
|
52
80
|
f"{prefix}{module_name}_{param_name}", value
|
|
53
81
|
)
|
|
54
82
|
|
|
83
|
+
elif isinstance(value, TiedParameter):
|
|
84
|
+
params_to_tie[key] = value
|
|
85
|
+
|
|
55
86
|
elif isinstance(value, ArrayLike):
|
|
56
87
|
parameters[key] = jnp.ones(expand_shape) * value
|
|
57
88
|
|
|
@@ -60,4 +91,9 @@ def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
|
|
|
60
91
|
f"Invalid prior type {type(value)} for parameter {prefix}{module_name}_{param_name} : {value}"
|
|
61
92
|
)
|
|
62
93
|
|
|
94
|
+
for key, value in params_to_tie.items():
|
|
95
|
+
func_to_apply = value.func
|
|
96
|
+
tied_to = value.tied_to
|
|
97
|
+
parameters[key] = func_to_apply(parameters[tied_to])
|
|
98
|
+
|
|
63
99
|
return parameters
|