jaxspec 0.2.2.dev0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/_plot.py +5 -5
- jaxspec/analysis/results.py +38 -25
- jaxspec/data/obsconf.py +9 -3
- jaxspec/data/observation.py +3 -1
- jaxspec/data/ogip.py +9 -2
- jaxspec/data/util.py +17 -11
- jaxspec/experimental/interpolator.py +74 -0
- jaxspec/experimental/interpolator_jax.py +79 -0
- jaxspec/experimental/intrument_models.py +159 -0
- jaxspec/experimental/nested_sampler.py +78 -0
- jaxspec/experimental/tabulated.py +264 -0
- jaxspec/fit/__init__.py +3 -0
- jaxspec/{fit.py → fit/_bayesian_model.py} +86 -338
- jaxspec/{_fit → fit}/_build_model.py +42 -6
- jaxspec/fit/_fitter.py +255 -0
- jaxspec/model/abc.py +52 -80
- jaxspec/model/additive.py +14 -5
- jaxspec/model/background.py +17 -14
- jaxspec/model/instrument.py +81 -0
- jaxspec/model/list.py +4 -1
- jaxspec/model/multiplicative.py +32 -12
- jaxspec/util/integrate.py +17 -5
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/METADATA +9 -9
- jaxspec-0.3.0.dist-info/RECORD +42 -0
- jaxspec-0.2.2.dev0.dist-info/RECORD +0 -34
- /jaxspec/{_fit → experimental}/__init__.py +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/WHEEL +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/entry_points.txt +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/licenses/LICENSE.md +0 -0
jaxspec/{fit.py → fit/_bayesian_model.py}

@@ -1,54 +1,54 @@
 import operator
-import warnings
 
-from abc import ABC, abstractmethod
 from collections.abc import Callable
 from functools import cached_property
-from typing import …
+from typing import Any
 
-import arviz as az
 import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
 import numpyro
 
-from …
+from flax import nnx
 from jax.experimental import mesh_utils
 from jax.random import PRNGKey
 from jax.sharding import PositionalSharding
-from numpyro.contrib.nested_sampling import NestedSampler
 from numpyro.distributions import Poisson, TransformedDistribution
-from numpyro.infer import …
+from numpyro.infer import Predictive
 from numpyro.infer.inspect import get_model_relations
 from numpyro.infer.reparam import TransformReparam
 from numpyro.infer.util import log_density
 
-from .…
-from .analysis._plot import (
+from ..analysis._plot import (
     _error_bars_for_observed_data,
     _plot_binned_samples_with_error,
     _plot_poisson_data_with_error,
 )
-from …
-from .…
-from …
-from …
-from …
+from ..data import ObsConfiguration
+from ..model.abc import SpectralModel
+from ..model.background import BackgroundModel
+from ..model.instrument import InstrumentModel
+from ..util.typing import PriorDictType
+from ._build_model import build_prior, forward_model
 
 
-class BayesianModel:
+class BayesianModel(nnx.Module):
     """
     Base class for a Bayesian model. This class contains the necessary methods to build a model, sample from the prior
     and compute the log-likelihood and posterior probability.
     """
 
+    settings: dict[str, Any]
+
     def __init__(
         self,
         model: SpectralModel,
         prior_distributions: PriorDictType | Callable,
         observations: ObsConfiguration | list[ObsConfiguration] | dict[str, ObsConfiguration],
         background_model: BackgroundModel = None,
+        instrument_model: InstrumentModel = None,
         sparsify_matrix: bool = False,
+        n_points: int = 2,
     ):
         """
         Build a Bayesian model for a given spectral model and observations.
@@ -59,74 +59,36 @@ class BayesianModel:
                 callable function that returns parameter samples.
             observations: the observations to fit the model to.
             background_model: the background model to fit.
+            instrument_model: the instrument model to fit.
             sparsify_matrix: whether to sparsify the transfer matrix.
         """
-
+
+        self.spectral_model = model
         self._observations = observations
         self.background_model = background_model
-        self.…
+        self.instrument_model = instrument_model
+        self.settings = {"sparse": sparsify_matrix}
 
         if not callable(prior_distributions):
-            # Validate the entry with pydantic
-            # prior = PriorDictModel.from_dict(prior_distributions).…
 
             def prior_distributions_func():
                 return build_prior(
-                    prior_distributions, …
+                    prior_distributions,
+                    expand_shape=(len(self._observation_container),),
+                    prefix="mod/~/",
                 )
 
         else:
            prior_distributions_func = prior_distributions
 
         self.prior_distributions_func = prior_distributions_func
-        self.init_params = self.prior_samples()
-
-        # Check the priors are suited for the observations
-        split_parameters = [
-            (param, shape[-1])
-            for param, shape in jax.tree.map(lambda x: x.shape, self.init_params).items()
-            if (len(shape) > 1)
-            and not param.startswith("_")
-            and not param.startswith("bkg")  # hardcoded for subtracted background
-        ]
-
-        for parameter, proposed_number_of_obs in split_parameters:
-            if proposed_number_of_obs != len(self.observation_container):
-                raise ValueError(
-                    f"Invalid splitting in the prior distribution. "
-                    f"Expected {len(self.observation_container)} but got {proposed_number_of_obs} for {parameter}"
-                )
-
-    @cached_property
-    def observation_container(self) -> dict[str, ObsConfiguration]:
-        """
-        The observations used in the fit as a dictionary of observations.
-        """
-
-        if isinstance(self._observations, dict):
-            return self._observations
-
-        elif isinstance(self._observations, list):
-            return {f"data_{i}": obs for i, obs in enumerate(self._observations)}
-
-        elif isinstance(self._observations, ObsConfiguration):
-            return {"data": self._observations}
-
-        else:
-            raise ValueError(f"Invalid type for observations : {type(self._observations)}")
-
-    @cached_property
-    def numpyro_model(self) -> Callable:
-        """
-        Build the numpyro model using the observed data, the prior distributions and the spectral model.
-        """
 
         def numpyro_model(observed=True):
             # Instantiate and register the parameters of the spectral model and the background
             prior_params = self.prior_distributions_func()
 
             # Iterate over all the observations in our container and build a single numpyro model for each observation
-            for i, (name, observation) in enumerate(self.…
+            for i, (name, observation) in enumerate(self._observation_container.items()):
                 # Check that we can indeed fit a background
                 if (getattr(observation, "folded_background", None) is not None) and (
                     self.background_model is not None
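Both new keyword arguments surface in the constructor above. A minimal construction sketch follows, assuming the import path implied by the fit.py → fit/_bayesian_model.py rename; the spectral model, observation, and prior values are illustrative placeholders, not taken from jaxspec documentation:

# Hypothetical usage of the extended 0.3.0 constructor; only the keyword
# names are taken from this diff, everything else is a placeholder.
import numpyro.distributions as dist

from jaxspec.fit._bayesian_model import BayesianModel  # path assumed from the rename above

prior = {
    "powerlaw_1_alpha": dist.Uniform(0.0, 5.0),
    "powerlaw_1_norm": dist.LogUniform(1e-6, 1e-2),
}

bayesian_model = BayesianModel(
    spectral_model,          # a SpectralModel built elsewhere
    prior,
    observation,             # ObsConfiguration, or a list/dict of them
    background_model=None,   # optional BackgroundModel
    instrument_model=None,   # new in 0.3.0: optional InstrumentModel (gain/shift)
    sparsify_matrix=False,
    n_points=2,              # new in 0.3.0: forwarded to forward_model/photon_flux
)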
@@ -150,23 +112,70 @@ class BayesianModel:
                 # They can be identical or different for each observation
                 params = jax.tree.map(lambda x: x[i], prior_params)
 
+                if self.instrument_model is not None:
+                    gain, shift = self.instrument_model.get_gain_and_shift_model(name)
+                else:
+                    gain, shift = None, None
+
                 # Forward model the observation and get the associated countrate
                 obs_model = jax.jit(
-                    lambda par: forward_model(…
+                    lambda par: forward_model(
+                        self.spectral_model,
+                        par,
+                        observation,
+                        sparse=self.settings.get("sparse", False),
+                        gain=gain,
+                        shift=shift,
+                        n_points=n_points,
+                    )
                 )
+
                 obs_countrate = obs_model(params)
 
                 # Register the observation as an observed site
-                with numpyro.plate("…
+                with numpyro.plate("obs_plate/~/" + name, len(observation.folded_counts)):
                     numpyro.sample(
-                        "…
-                        Poisson(
-                            obs_countrate + bkg_countrate
-                        ),  # / observation.folded_backratio.data
+                        "obs/~/" + name,
+                        Poisson(obs_countrate + bkg_countrate / observation.folded_backratio.data),
                         obs=observation.folded_counts.data if observed else None,
                     )
 
-
+        self.numpyro_model = numpyro_model
+        self._init_params = self.prior_samples()
+        # Check the priors are suited for the observations
+        split_parameters = [
+            (param, shape[-1])
+            for param, shape in jax.tree.map(lambda x: x.shape, self._init_params).items()
+            if (len(shape) > 1)
+            and not param.startswith("_")
+            and not param.startswith("bkg")  # hardcoded for subtracted background
+            and not param.startswith("ins")
+        ]
+
+        for parameter, proposed_number_of_obs in split_parameters:
+            if proposed_number_of_obs != len(self._observation_container):
+                raise ValueError(
+                    f"Invalid splitting in the prior distribution. "
+                    f"Expected {len(self._observation_container)} but got {proposed_number_of_obs} for {parameter}"
+                )
+
+    @cached_property
+    def _observation_container(self) -> dict[str, ObsConfiguration]:
+        """
+        The observations used in the fit as a dictionary of observations.
+        """
+
+        if isinstance(self._observations, dict):
+            return self._observations
+
+        elif isinstance(self._observations, list):
+            return {f"data_{i}": obs for i, obs in enumerate(self._observations)}
+
+        elif isinstance(self._observations, ObsConfiguration):
+            return {"data": self._observations}
+
+        else:
+            raise ValueError(f"Invalid type for observations : {type(self._observations)}")
 
     @cached_property
     def transformed_numpyro_model(self) -> Callable:
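Observed numpyro sites are renamed in this hunk from the old "obs_<name>" pattern (still visible in the removed filter_inference_data further down) to "obs/~/<name>". A hedged sketch of the downstream effect, assuming mock_observations (called later in this diff) returns a dict keyed by site name:

# Assumption: keys follow the names produced by _observation_container,
# i.e. "data" for a single ObsConfiguration, "data_0", "data_1", ... for a list.
from jax.random import PRNGKey

mocked = bayesian_model.mock_observations(prior_params, key=PRNGKey(0))
counts = mocked["obs/~/data"]  # mocked counts for the single-observation case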
@@ -235,7 +244,7 @@
         return log_posterior_prob
 
     @cached_property
-    def …
+    def _parameter_names(self) -> list[str]:
         """
         A list of parameter names for the model.
         """
@@ -260,7 +269,7 @@
         """
         input_params = {}
 
-        for index, key in enumerate(self.…
+        for index, key in enumerate(self._parameter_names):
             input_params[key] = theta[index]
 
         return input_params
@@ -270,9 +279,9 @@
         Convert a dictionary of parameters to an array of parameters.
         """
 
-        theta = jnp.zeros(len(self.…
+        theta = jnp.zeros(len(self._parameter_names))
 
-        for index, key in enumerate(self.…
+        for index, key in enumerate(self._parameter_names):
             theta = theta.at[index].set(dict_of_params[key])
 
         return theta
@@ -289,7 +298,7 @@
         @jax.jit
         def prior_sample(key):
             return Predictive(
-                self.numpyro_model, return_sites=self.…
+                self.numpyro_model, return_sites=self._parameter_names, num_samples=num_samples
             )(key, observed=False)
 
         return prior_sample(key)
@@ -327,7 +336,7 @@
         sharded_parameters = jax.device_put(prior_params, sharding)
         posterior_observations = self.mock_observations(sharded_parameters, key=key_posterior)
 
-        for key, value in self.…
+        for key, value in self._observation_container.items():
             fig, ax = plt.subplots(
                 nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
             )
@@ -349,7 +358,7 @@
             )
 
             prior_plot = _plot_binned_samples_with_error(
-                ax[0], value.out_energies, posterior_observations["…
+                ax[0], value.out_energies, posterior_observations["obs/~/" + key], n_sigmas=3
             )
 
             legend_plots.append((true_data_plot,))
@@ -358,7 +367,7 @@
             legend_labels.append("Prior Predictive")
 
             # rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
-            counts = posterior_observations["…
+            counts = posterior_observations["obs/~/" + key]
             observed = value.folded_counts.values
 
             num_samples = counts.shape[0]
@@ -387,264 +396,3 @@
         plt.suptitle(f"Prior Predictive coverage for {key}")
         plt.tight_layout()
         plt.show()
-
-
-class BayesianModelFitter(BayesianModel, ABC):
-    def build_inference_data(
-        self,
-        posterior_samples,
-        num_chains: int = 1,
-        num_predictive_samples: int = 1000,
-        key: PRNGKey = PRNGKey(42),
-        use_transformed_model: bool = False,
-        filter_inference_data: bool = True,
-    ) -> az.InferenceData:
-        """
-        Build an [InferenceData][arviz.InferenceData] object from posterior samples.
-
-        Parameters:
-            posterior_samples: the samples from the posterior distribution.
-            num_chains: the number of chains used to sample the posterior.
-            num_predictive_samples: the number of samples to draw from the prior.
-            key: the random key used to initialize the sampler.
-            use_transformed_model: whether to use the transformed model to build the InferenceData.
-            filter_inference_data: whether to filter the InferenceData to keep only the relevant parameters.
-        """
-
-        numpyro_model = (
-            self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
-        )
-
-        keys = random.split(key, 3)
-
-        posterior_predictive = Predictive(numpyro_model, posterior_samples)(keys[0], observed=False)
-
-        prior = Predictive(numpyro_model, num_samples=num_predictive_samples * num_chains)(
-            keys[1], observed=False
-        )
-
-        log_likelihood = numpyro.infer.log_likelihood(numpyro_model, posterior_samples)
-
-        seeded_model = numpyro.handlers.substitute(
-            numpyro.handlers.seed(numpyro_model, keys[3]),
-            substitute_fn=numpyro.infer.init_to_sample,
-        )
-
-        observations = {
-            name: site["value"]
-            for name, site in numpyro.handlers.trace(seeded_model).get_trace().items()
-            if site["type"] == "sample" and site["is_observed"]
-        }
-
-        def reshape_first_dimension(arr):
-            new_dim = arr.shape[0] // num_chains
-            new_shape = (num_chains, new_dim) + arr.shape[1:]
-            reshaped_array = arr.reshape(new_shape)
-
-            return reshaped_array
-
-        posterior_samples = {
-            key: reshape_first_dimension(value) for key, value in posterior_samples.items()
-        }
-        prior = {key: value[None, :] for key, value in prior.items()}
-        posterior_predictive = {
-            key: reshape_first_dimension(value) for key, value in posterior_predictive.items()
-        }
-        log_likelihood = {
-            key: reshape_first_dimension(value) for key, value in log_likelihood.items()
-        }
-
-        inference_data = az.from_dict(
-            posterior_samples,
-            prior=prior,
-            posterior_predictive=posterior_predictive,
-            log_likelihood=log_likelihood,
-            observed_data=observations,
-        )
-
-        return (
-            self.filter_inference_data(inference_data) if filter_inference_data else inference_data
-        )
-
-    def filter_inference_data(
-        self,
-        inference_data: az.InferenceData,
-    ) -> az.InferenceData:
-        """
-        Filter the inference data to keep only the relevant parameters for the observations.
-
-        - Removes predictive parameters from deterministic random variables (e.g. kernel of background GP)
-        - Removes parameters build from reparametrised variables (e.g. ending with `"_base"`)
-        """
-
-        predictive_parameters = []
-
-        for key, value in self.observation_container.items():
-            if self.background_model is not None:
-                predictive_parameters.append(f"obs_{key}")
-                predictive_parameters.append(f"bkg_{key}")
-            else:
-                predictive_parameters.append(f"obs_{key}")
-
-        inference_data.posterior_predictive = inference_data.posterior_predictive[
-            predictive_parameters
-        ]
-
-        parameters = [
-            x
-            for x in inference_data.posterior.keys()
-            if not x.endswith("_base") or x.startswith("_")
-        ]
-        inference_data.posterior = inference_data.posterior[parameters]
-        inference_data.prior = inference_data.prior[parameters]
-
-        return inference_data
-
-    @abstractmethod
-    def fit(self, **kwargs) -> FitResult: ...
-
-
-class MCMCFitter(BayesianModelFitter):
-    """
-    A class to fit a model to a given set of observation using a Bayesian approach. This class uses samplers
-    from numpyro to perform the inference on the model parameters.
-    """
-
-    kernel_dict = {
-        "nuts": NUTS,
-        "aies": AIES,
-        "ess": ESS,
-    }
-
-    def fit(
-        self,
-        rng_key: int = 0,
-        num_chains: int = len(jax.devices()),
-        num_warmup: int = 1000,
-        num_samples: int = 1000,
-        sampler: Literal["nuts", "aies", "ess"] = "nuts",
-        use_transformed_model: bool = True,
-        kernel_kwargs: dict = {},
-        mcmc_kwargs: dict = {},
-    ) -> FitResult:
-        """
-        Fit the model to the data using a MCMC sampler from numpyro.
-
-        Parameters:
-            rng_key: the random key used to initialize the sampler.
-            num_chains: the number of chains to run.
-            num_warmup: the number of warmup steps.
-            num_samples: the number of samples to draw.
-            sampler: the sampler to use. Can be one of "nuts", "aies" or "ess".
-            use_transformed_model: whether to use the transformed model to build the InferenceData.
-            kernel_kwargs: additional arguments to pass to the kernel. See [`NUTS`][numpyro.infer.mcmc.MCMCKernel] for more details.
-            mcmc_kwargs: additional arguments to pass to the MCMC sampler. See [`MCMC`][numpyro.infer.mcmc.MCMC] for more details.
-
-        Returns:
-            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
-        """
-
-        bayesian_model = (
-            self.transformed_numpyro_model if use_transformed_model else self.numpyro_model
-        )
-
-        chain_kwargs = {
-            "num_warmup": num_warmup,
-            "num_samples": num_samples,
-            "num_chains": num_chains,
-        }
-
-        kernel = self.kernel_dict[sampler](bayesian_model, **kernel_kwargs)
-
-        mcmc_kwargs = chain_kwargs | mcmc_kwargs
-
-        if sampler in ["aies", "ess"] and mcmc_kwargs.get("chain_method", None) != "vectorized":
-            mcmc_kwargs["chain_method"] = "vectorized"
-            warnings.warn("The chain_method is set to 'vectorized' for AIES and ESS samplers")
-
-        mcmc = MCMC(kernel, **mcmc_kwargs)
-        keys = random.split(random.PRNGKey(rng_key), 3)
-
-        mcmc.run(keys[0])
-
-        posterior = mcmc.get_samples()
-
-        inference_data = self.build_inference_data(
-            posterior, num_chains=num_chains, use_transformed_model=True
-        )
-
-        return FitResult(
-            self,
-            inference_data,
-            background_model=self.background_model,
-        )
-
-
-class NSFitter(BayesianModelFitter):
-    r"""
-    A class to fit a model to a given set of observation using the Nested Sampling algorithm. This class uses the
-    [`DefaultNestedSampler`][jaxns.DefaultNestedSampler] from [`jaxns`](https://jaxns.readthedocs.io/en/latest/) which
-    implements the [Phantom-Powered Nested Sampling](https://arxiv.org/abs/2312.11330) algorithm.
-
-    !!! info
-        Ensure large prior volume is covered by the prior distributions to ensure the algorithm yield proper results.
-
-    """
-
-    def fit(
-        self,
-        rng_key: int = 0,
-        num_samples: int = 1000,
-        num_live_points: int = 1000,
-        plot_diagnostics=False,
-        termination_kwargs: dict | None = None,
-        verbose=True,
-    ) -> FitResult:
-        """
-        Fit the model to the data using the Phantom-Powered nested sampling algorithm.
-
-        Parameters:
-            rng_key: the random key used to initialize the sampler.
-            num_samples: the number of samples to draw.
-            num_live_points: the number of live points to use at the start of the NS algorithm.
-            plot_diagnostics: whether to plot the diagnostics of the NS algorithm.
-            termination_kwargs: additional arguments to pass to the termination criterion of the NS algorithm.
-            verbose: whether to print the progress of the NS algorithm.
-
-        Returns:
-            A [`FitResult`][jaxspec.analysis.results.FitResult] instance containing the results of the fit.
-        """
-
-        bayesian_model = self.transformed_numpyro_model
-        keys = random.split(random.PRNGKey(rng_key), 4)
-
-        ns = NestedSampler(
-            bayesian_model,
-            constructor_kwargs=dict(
-                verbose=verbose,
-                difficult_model=True,
-                max_samples=1e5,
-                parameter_estimation=True,
-                gradient_guided=True,
-                devices=jax.devices(),
-                # init_efficiency_threshold=0.01,
-                num_live_points=num_live_points,
-            ),
-            termination_kwargs=termination_kwargs if termination_kwargs else dict(),
-        )
-
-        ns.run(keys[0])
-
-        if plot_diagnostics:
-            ns.diagnostics()
-
-        posterior = ns.get_samples(keys[1], num_samples=num_samples)
-        inference_data = self.build_inference_data(
-            posterior, num_chains=1, use_transformed_model=True
-        )
-
-        return FitResult(
-            self,
-            inference_data,
-            background_model=self.background_model,
-        )
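The fitter classes removed above do not disappear from the package: the file summary lists a new jaxspec/fit/_fitter.py (+255) and jaxspec/fit/__init__.py (+3), so they presumably move there. A usage sketch based solely on the MCMCFitter.fit signature removed above; the import path is an assumption, not verified against the new module:

# Hypothetical: MCMCFitter re-exported from jaxspec.fit in 0.3.0; the fit()
# keywords below are taken verbatim from the removed signature.
from jaxspec.fit import MCMCFitter

result = MCMCFitter(spectral_model, prior, observation).fit(
    rng_key=0,
    num_chains=4,
    num_warmup=1000,
    num_samples=1000,
    sampler="nuts",  # "aies" and "ess" force chain_method="vectorized"
)
# result is a jaxspec.analysis.results.FitResult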
jaxspec/{_fit → fit}/_build_model.py

@@ -1,5 +1,7 @@
+from collections.abc import Callable
 from typing import TYPE_CHECKING
 
+import jax
 import jax.numpy as jnp
 import numpy as np
 import numpyro
@@ -9,9 +11,15 @@ from jax.typing import ArrayLike
 from numpyro.distributions import Distribution
 
 if TYPE_CHECKING:
-    from …
-    from …
-    from …
+    from jaxspec.data import ObsConfiguration
+    from jaxspec.model.abc import SpectralModel
+    from jaxspec.util.typing import PriorDictType
+
+
+class TiedParameter:
+    def __init__(self, tied_to, func):
+        self.tied_to = tied_to
+        self.func = func
 
 
 def forward_model(
@@ -19,6 +27,10 @@ def forward_model(
     parameters,
     obs_configuration: "ObsConfiguration",
     sparse=False,
+    gain: Callable | None = None,
+    shift: Callable | None = None,
+    split_branches: bool = False,
+    n_points: int | None = 2,
 ):
     energies = np.asarray(obs_configuration.in_energies)
 
@@ -31,10 +43,24 @@ def forward_model(
     else:
         transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())
 
-…
+    energies = shift(energies) if shift is not None else energies
+    energies = jnp.clip(energies, min=1e-6)  # Ensure shifted energies remain positive
+    factor = gain(energies) if gain is not None else 1.0
+    factor = jnp.clip(factor, min=0.0)  # Ensure the gain is positive to avoid NaNs
 
-…
-…
+    if not split_branches:
+        expected_counts = transfer_matrix @ (
+            model.photon_flux(parameters, *energies, n_points=n_points) * factor
+        )
+        return jnp.clip(expected_counts, min=1e-6)  # Ensure the expected counts are positive
+
+    else:
+        model_flux = model.photon_flux(
+            parameters, *energies, split_branches=True, n_points=n_points
+        )
+        return jax.tree.map(
+            lambda f: jnp.clip(transfer_matrix @ (f * factor), min=1e-6), model_flux
+        )
 
 
 def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
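As the hunk above shows, gain and shift are plain callables applied to the input energy grid before folding through the transfer matrix, and the jnp.clip calls floor pathological values instead of letting NaNs propagate. A minimal sketch with made-up functional forms (the actual parametrisation lives in jaxspec/model/instrument.py, which is not shown in this diff):

# Illustrative callables only: a rigid energy-scale offset and a linear gain.
def shift(energies):
    return energies + 0.01  # +10 eV offset, arbitrary value

def gain(energies):
    return 1.0 + 0.02 * (energies - 1.0)  # linear gain around 1 keV, arbitrary

counts = forward_model(
    model, parameters, obs_configuration,
    gain=gain, shift=shift, n_points=2,
)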
@@ -43,15 +69,20 @@ def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
     Must be used within a numpyro model.
     """
     parameters = {}
+    params_to_tie = {}
 
     for key, value in prior.items():
         # Split the key to extract the module name and parameter name
         module_name, param_name = key.rsplit("_", 1)
+
         if isinstance(value, Distribution):
             parameters[key] = jnp.ones(expand_shape) * numpyro.sample(
                 f"{prefix}{module_name}_{param_name}", value
             )
 
+        elif isinstance(value, TiedParameter):
+            params_to_tie[key] = value
+
         elif isinstance(value, ArrayLike):
             parameters[key] = jnp.ones(expand_shape) * value
 
@@ -60,4 +91,9 @@ def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
             f"Invalid prior type {type(value)} for parameter {prefix}{module_name}_{param_name} : {value}"
         )
 
+    for key, value in params_to_tie.items():
+        func_to_apply = value.func
+        tied_to = value.tied_to
+        parameters[key] = func_to_apply(parameters[tied_to])
+
     return parameters
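TiedParameter entries are set aside during the first pass over the prior dictionary and resolved only after every sampled (Distribution) and fixed (ArrayLike) entry has been collected, so the tied_to key must already be present in parameters when its turn comes. A hypothetical prior mixing the three accepted value types (component names are illustrative):

# Hypothetical prior dict for build_prior: sampled, fixed and tied entries.
import numpyro.distributions as dist

prior = {
    "blackbodyrad_1_kT": dist.Uniform(0.1, 2.0),  # sampled from a Distribution
    "blackbodyrad_1_norm": 1e-2,                  # fixed ArrayLike value
    # Tied: the second temperature is always twice the first one.
    "blackbodyrad_2_kT": TiedParameter("blackbodyrad_1_kT", lambda kt: 2.0 * kt),
}

params = build_prior(prior, prefix="mod/~/")  # must run inside a numpyro model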