jaxspec 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/results.py +297 -121
- jaxspec/data/__init__.py +4 -4
- jaxspec/data/obsconf.py +53 -8
- jaxspec/data/util.py +114 -84
- jaxspec/fit.py +335 -96
- jaxspec/model/__init__.py +0 -1
- jaxspec/model/_additive/apec.py +56 -117
- jaxspec/model/_additive/apec_loaders.py +42 -59
- jaxspec/model/additive.py +194 -55
- jaxspec/model/background.py +50 -16
- jaxspec/model/multiplicative.py +63 -41
- jaxspec/util/__init__.py +45 -0
- jaxspec/util/abundance.py +5 -3
- jaxspec/util/online_storage.py +28 -0
- jaxspec/util/typing.py +43 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/METADATA +14 -10
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/RECORD +19 -25
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/WHEEL +1 -1
- jaxspec/data/example_data/MOS1.pha +0 -46
- jaxspec/data/example_data/MOS2.pha +0 -42
- jaxspec/data/example_data/PN.pha +1 -293
- jaxspec/data/example_data/fakeit.pha +1 -335
- jaxspec/tables/abundances.dat +0 -31
- jaxspec/tables/xsect_phabs_aspl.fits +0 -0
- jaxspec/tables/xsect_tbabs_wilm.fits +0 -0
- jaxspec/tables/xsect_wabs_angr.fits +0 -0
- {jaxspec-0.0.6.dist-info → jaxspec-0.0.8.dist-info}/LICENSE.md +0 -0
jaxspec/analysis/results.py
CHANGED
|
@@ -1,20 +1,29 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Literal, TypeVar
|
|
5
|
+
|
|
1
6
|
import arviz as az
|
|
7
|
+
import astropy.units as u
|
|
8
|
+
import jax
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
2
11
|
import numpy as np
|
|
3
12
|
import xarray as xr
|
|
4
|
-
|
|
5
|
-
from ..data import ObsConfiguration
|
|
6
|
-
from ..model.abc import SpectralModel
|
|
7
|
-
from ..model.background import BackgroundModel
|
|
8
|
-
from collections.abc import Mapping
|
|
9
|
-
from typing import TypeVar, Tuple, Literal, Any
|
|
13
|
+
|
|
10
14
|
from astropy.cosmology import Cosmology, Planck18
|
|
11
|
-
import astropy.units as u
|
|
12
15
|
from astropy.units import Unit
|
|
16
|
+
from chainconsumer import Chain, ChainConsumer, PlotConfig
|
|
13
17
|
from haiku.data_structures import traverse
|
|
14
|
-
from chainconsumer import Chain, PlotConfig, ChainConsumer
|
|
15
|
-
import jax
|
|
16
18
|
from jax.typing import ArrayLike
|
|
19
|
+
from numpyro.handlers import seed
|
|
17
20
|
from scipy.integrate import trapezoid
|
|
21
|
+
from scipy.special import gammaln
|
|
22
|
+
from scipy.stats import nbinom
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from ..fit import BayesianModel
|
|
26
|
+
from ..model.background import BackgroundModel
|
|
18
27
|
|
|
19
28
|
K = TypeVar("K")
|
|
20
29
|
V = TypeVar("V")
|
|
@@ -29,7 +38,6 @@ def _plot_binned_samples_with_error(
|
|
|
29
38
|
x_bins: ArrayLike,
|
|
30
39
|
denominator: ArrayLike | None = None,
|
|
31
40
|
y_samples: ArrayLike | None = None,
|
|
32
|
-
y_observed: ArrayLike | None = None,
|
|
33
41
|
color=(0.15, 0.25, 0.45),
|
|
34
42
|
percentile: tuple = (16, 84),
|
|
35
43
|
):
|
|
@@ -38,7 +46,8 @@ def _plot_binned_samples_with_error(
|
|
|
38
46
|
computes the percentiles of the posterior predictive distribution and plot them as a shaded
|
|
39
47
|
area. If the observed data is provided, it is also plotted as a step function.
|
|
40
48
|
|
|
41
|
-
Parameters
|
|
49
|
+
Parameters
|
|
50
|
+
----------
|
|
42
51
|
x_bins: The bin edges of the data (2 x N).
|
|
43
52
|
y_samples: The samples of the posterior predictive distribution (Samples X N).
|
|
44
53
|
denominator: Values used to divided the samples, i.e. to get energy flux (N).
|
|
@@ -51,22 +60,15 @@ def _plot_binned_samples_with_error(
|
|
|
51
60
|
|
|
52
61
|
mean, envelope = None, None
|
|
53
62
|
|
|
54
|
-
if
|
|
55
|
-
|
|
63
|
+
if denominator is None:
|
|
64
|
+
denominator = np.ones_like(x_bins[0])
|
|
56
65
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
(mean,) = ax.step(
|
|
65
|
-
list(x_bins[0]) + [x_bins[1][-1]], # x_bins[1][-1]+1],
|
|
66
|
-
list(y_observed / denominator) + [np.nan], # + [np.nan, np.nan],
|
|
67
|
-
where="pre",
|
|
68
|
-
c=color,
|
|
69
|
-
)
|
|
66
|
+
mean = ax.stairs(
|
|
67
|
+
list(np.median(y_samples, axis=0) / denominator),
|
|
68
|
+
edges=[*list(x_bins[0]), x_bins[1][-1]],
|
|
69
|
+
color=color,
|
|
70
|
+
alpha=0.7,
|
|
71
|
+
)
|
|
70
72
|
|
|
71
73
|
if y_samples is not None:
|
|
72
74
|
if denominator is None:
|
|
@@ -77,63 +79,35 @@ def _plot_binned_samples_with_error(
|
|
|
77
79
|
# The legend cannot handle fill_between, so we pass a fill to get a fancy icon
|
|
78
80
|
(envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
|
|
79
81
|
|
|
80
|
-
ax.
|
|
81
|
-
|
|
82
|
-
list(
|
|
83
|
-
|
|
82
|
+
ax.stairs(
|
|
83
|
+
percentiles[1] / denominator,
|
|
84
|
+
edges=[*list(x_bins[0]), x_bins[1][-1]],
|
|
85
|
+
baseline=percentiles[0] / denominator,
|
|
84
86
|
alpha=0.3,
|
|
85
|
-
|
|
86
|
-
|
|
87
|
+
fill=True,
|
|
88
|
+
color=color,
|
|
87
89
|
)
|
|
88
90
|
|
|
89
91
|
return [(mean, envelope)]
|
|
90
92
|
|
|
91
93
|
|
|
92
|
-
def format_parameters(parameter_name):
|
|
93
|
-
computed_parameters = ["Photon flux", "Energy flux", "Luminosity"]
|
|
94
|
-
|
|
95
|
-
if parameter_name == "weight":
|
|
96
|
-
# ChainConsumer add a weight column to the samples
|
|
97
|
-
return parameter_name
|
|
98
|
-
|
|
99
|
-
for parameter in computed_parameters:
|
|
100
|
-
if parameter in parameter_name:
|
|
101
|
-
return parameter_name
|
|
102
|
-
|
|
103
|
-
# Find second occurrence of the character '_'
|
|
104
|
-
first_occurrence = parameter_name.find("_")
|
|
105
|
-
second_occurrence = parameter_name.find("_", first_occurrence + 1)
|
|
106
|
-
module = parameter_name[:second_occurrence]
|
|
107
|
-
parameter = parameter_name[second_occurrence + 1 :]
|
|
108
|
-
|
|
109
|
-
name, number = module.split("_")
|
|
110
|
-
module = rf"[{name.capitalize()} ({number})]"
|
|
111
|
-
|
|
112
|
-
if parameter == "norm":
|
|
113
|
-
return r"Norm " + module
|
|
114
|
-
|
|
115
|
-
else:
|
|
116
|
-
return rf"${parameter}$" + module
|
|
117
|
-
|
|
118
|
-
|
|
119
94
|
class FitResult:
|
|
120
95
|
"""
|
|
121
|
-
|
|
96
|
+
Container for the result of a fit using any ModelFitter class.
|
|
122
97
|
"""
|
|
123
98
|
|
|
124
99
|
# TODO : Add type hints
|
|
125
100
|
def __init__(
|
|
126
101
|
self,
|
|
127
|
-
|
|
128
|
-
obsconf: ObsConfiguration | dict[str, ObsConfiguration],
|
|
102
|
+
bayesian_fitter: BayesianModel,
|
|
129
103
|
inference_data: az.InferenceData,
|
|
130
104
|
structure: Mapping[K, V],
|
|
131
105
|
background_model: BackgroundModel = None,
|
|
132
106
|
):
|
|
133
|
-
self.model = model
|
|
134
|
-
self.
|
|
107
|
+
self.model = bayesian_fitter.model
|
|
108
|
+
self.bayesian_fitter = bayesian_fitter
|
|
135
109
|
self.inference_data = inference_data
|
|
136
|
-
self.obsconfs =
|
|
110
|
+
self.obsconfs = bayesian_fitter.observation_container
|
|
137
111
|
self.background_model = background_model
|
|
138
112
|
self._structure = structure
|
|
139
113
|
|
|
@@ -141,22 +115,94 @@ class FitResult:
|
|
|
141
115
|
for group in self.inference_data.groups():
|
|
142
116
|
group_name = group.split("/")[-1]
|
|
143
117
|
metadata = getattr(self.inference_data, group_name).attrs
|
|
144
|
-
metadata["model"] = str(model)
|
|
118
|
+
metadata["model"] = str(self.model)
|
|
145
119
|
# TODO : Store metadata about observations used in the fitting process
|
|
146
120
|
|
|
147
121
|
@property
|
|
148
122
|
def converged(self) -> bool:
|
|
149
|
-
"""
|
|
123
|
+
r"""
|
|
150
124
|
Convergence of the chain as computed by the $\hat{R}$ statistic.
|
|
151
125
|
"""
|
|
152
126
|
|
|
153
127
|
return all(az.rhat(self.inference_data) < 1.01)
|
|
154
128
|
|
|
129
|
+
@property
|
|
130
|
+
def _structured_samples(self):
|
|
131
|
+
"""
|
|
132
|
+
Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
|
|
133
|
+
"""
|
|
134
|
+
|
|
135
|
+
samples_flat = self._structured_samples_flat
|
|
136
|
+
|
|
137
|
+
samples_haiku = {}
|
|
138
|
+
|
|
139
|
+
for module, parameter, value in traverse(self._structure):
|
|
140
|
+
if samples_haiku.get(module, None) is None:
|
|
141
|
+
samples_haiku[module] = {}
|
|
142
|
+
samples_haiku[module][parameter] = samples_flat[f"{module}_{parameter}"]
|
|
143
|
+
|
|
144
|
+
return samples_haiku
|
|
145
|
+
|
|
146
|
+
@property
|
|
147
|
+
def _structured_samples_flat(self):
|
|
148
|
+
"""
|
|
149
|
+
Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
|
|
150
|
+
"""
|
|
151
|
+
|
|
152
|
+
var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
|
|
153
|
+
posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
|
|
154
|
+
samples_flat = {key: posterior[key].data for key in var_names}
|
|
155
|
+
|
|
156
|
+
return samples_flat
|
|
157
|
+
|
|
158
|
+
@property
|
|
159
|
+
def input_parameters(self) -> HaikuDict[ArrayLike]:
|
|
160
|
+
"""
|
|
161
|
+
The input parameters of the model.
|
|
162
|
+
"""
|
|
163
|
+
|
|
164
|
+
posterior = az.extract(self.inference_data, combined=False)
|
|
165
|
+
|
|
166
|
+
samples_shape = (len(posterior.coords["chain"]), len(posterior.coords["draw"]))
|
|
167
|
+
|
|
168
|
+
total_shape = tuple(posterior.sizes[d] for d in posterior.coords)
|
|
169
|
+
|
|
170
|
+
posterior = {key: posterior[key].data for key in posterior.data_vars}
|
|
171
|
+
|
|
172
|
+
with seed(rng_seed=0):
|
|
173
|
+
input_parameters = self.bayesian_fitter.prior_distributions_func()
|
|
174
|
+
|
|
175
|
+
for module, parameter, value in traverse(input_parameters):
|
|
176
|
+
if f"{module}_{parameter}" in posterior.keys():
|
|
177
|
+
# We add as extra dimension as there might be different values per observation
|
|
178
|
+
if posterior[f"{module}_{parameter}"].shape == samples_shape:
|
|
179
|
+
to_set = posterior[f"{module}_{parameter}"][..., None]
|
|
180
|
+
else:
|
|
181
|
+
to_set = posterior[f"{module}_{parameter}"]
|
|
182
|
+
|
|
183
|
+
input_parameters[module][parameter] = to_set
|
|
184
|
+
|
|
185
|
+
else:
|
|
186
|
+
# The parameter is fixed in this case, so we just broadcast is over chain and draws
|
|
187
|
+
input_parameters[module][parameter] = value[None, None, ...]
|
|
188
|
+
|
|
189
|
+
if len(total_shape) < len(input_parameters[module][parameter].shape):
|
|
190
|
+
# If there are only chains and draws, we reduce
|
|
191
|
+
input_parameters[module][parameter] = input_parameters[module][parameter][..., 0]
|
|
192
|
+
|
|
193
|
+
else:
|
|
194
|
+
input_parameters[module][parameter] = jnp.broadcast_to(
|
|
195
|
+
input_parameters[module][parameter], total_shape
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
return input_parameters
|
|
199
|
+
|
|
155
200
|
def photon_flux(
|
|
156
201
|
self,
|
|
157
202
|
e_min: float,
|
|
158
203
|
e_max: float,
|
|
159
204
|
unit: Unit = u.photon / u.cm**2 / u.s,
|
|
205
|
+
register: bool = False,
|
|
160
206
|
) -> ArrayLike:
|
|
161
207
|
"""
|
|
162
208
|
Compute the unfolded photon flux in a given energy band. The flux is then added to
|
|
@@ -166,19 +212,31 @@ class FitResult:
|
|
|
166
212
|
e_min: The lower bound of the energy band in observer frame.
|
|
167
213
|
e_max: The upper bound of the energy band in observer frame.
|
|
168
214
|
unit: The unit of the photon flux.
|
|
215
|
+
register: Whether to register the flux with the other posterior parameters.
|
|
169
216
|
|
|
170
217
|
!!! warning
|
|
171
218
|
Computation of the folded flux is not implemented yet. Feel free to open an
|
|
172
219
|
[issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
|
|
173
220
|
"""
|
|
174
221
|
|
|
175
|
-
|
|
222
|
+
@jax.jit
|
|
223
|
+
@jnp.vectorize
|
|
224
|
+
def vectorized_flux(*pars):
|
|
225
|
+
parameters_pytree = jax.tree.unflatten(pytree_def, pars)
|
|
226
|
+
return self.model.photon_flux(
|
|
227
|
+
parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=100
|
|
228
|
+
)[0]
|
|
176
229
|
|
|
230
|
+
flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
|
|
231
|
+
flux = vectorized_flux(*flat_tree)
|
|
177
232
|
conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
|
|
178
|
-
|
|
179
233
|
value = flux * conversion_factor
|
|
180
|
-
|
|
181
|
-
|
|
234
|
+
|
|
235
|
+
if register:
|
|
236
|
+
self.inference_data.posterior[f"photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
237
|
+
list(self.inference_data.posterior.coords),
|
|
238
|
+
value,
|
|
239
|
+
)
|
|
182
240
|
|
|
183
241
|
return value
|
|
184
242
|
|
|
@@ -187,6 +245,7 @@ class FitResult:
|
|
|
187
245
|
e_min: float,
|
|
188
246
|
e_max: float,
|
|
189
247
|
unit: Unit = u.erg / u.cm**2 / u.s,
|
|
248
|
+
register: bool = False,
|
|
190
249
|
) -> ArrayLike:
|
|
191
250
|
"""
|
|
192
251
|
Compute the unfolded energy flux in a given energy band. The flux is then added to
|
|
@@ -196,20 +255,31 @@ class FitResult:
|
|
|
196
255
|
e_min: The lower bound of the energy band in observer frame.
|
|
197
256
|
e_max: The upper bound of the energy band in observer frame.
|
|
198
257
|
unit: The unit of the energy flux.
|
|
258
|
+
register: Whether to register the flux with the other posterior parameters.
|
|
199
259
|
|
|
200
260
|
!!! warning
|
|
201
261
|
Computation of the folded flux is not implemented yet. Feel free to open an
|
|
202
262
|
[issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
|
|
203
263
|
"""
|
|
204
264
|
|
|
205
|
-
|
|
265
|
+
@jax.jit
|
|
266
|
+
@jnp.vectorize
|
|
267
|
+
def vectorized_flux(*pars):
|
|
268
|
+
parameters_pytree = jax.tree.unflatten(pytree_def, pars)
|
|
269
|
+
return self.model.energy_flux(
|
|
270
|
+
parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=100
|
|
271
|
+
)[0]
|
|
206
272
|
|
|
273
|
+
flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
|
|
274
|
+
flux = vectorized_flux(*flat_tree)
|
|
207
275
|
conversion_factor = (u.keV / u.cm**2 / u.s).to(unit)
|
|
208
|
-
|
|
209
276
|
value = flux * conversion_factor
|
|
210
277
|
|
|
211
|
-
|
|
212
|
-
|
|
278
|
+
if register:
|
|
279
|
+
self.inference_data.posterior[f"energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
280
|
+
list(self.inference_data.posterior.coords),
|
|
281
|
+
value,
|
|
282
|
+
)
|
|
213
283
|
|
|
214
284
|
return value
|
|
215
285
|
|
|
@@ -217,10 +287,11 @@ class FitResult:
|
|
|
217
287
|
self,
|
|
218
288
|
e_min: float,
|
|
219
289
|
e_max: float,
|
|
220
|
-
redshift: float | ArrayLike = 0,
|
|
290
|
+
redshift: float | ArrayLike = 0.1,
|
|
221
291
|
observer_frame: bool = True,
|
|
222
292
|
cosmology: Cosmology = Planck18,
|
|
223
293
|
unit: Unit = u.erg / u.s,
|
|
294
|
+
register: bool = False,
|
|
224
295
|
) -> ArrayLike:
|
|
225
296
|
"""
|
|
226
297
|
Compute the luminosity of the source specifying its redshift. The luminosity is then added to
|
|
@@ -233,41 +304,65 @@ class FitResult:
|
|
|
233
304
|
observer_frame: Whether the input bands are defined in observer frame or not.
|
|
234
305
|
cosmology: Chosen cosmology.
|
|
235
306
|
unit: The unit of the luminosity.
|
|
307
|
+
register: Whether to register the flux with the other posterior parameters.
|
|
236
308
|
"""
|
|
237
309
|
|
|
238
310
|
if not observer_frame:
|
|
239
311
|
raise NotImplementedError()
|
|
240
312
|
|
|
241
|
-
|
|
242
|
-
|
|
313
|
+
@jax.jit
|
|
314
|
+
@jnp.vectorize
|
|
315
|
+
def vectorized_flux(*pars):
|
|
316
|
+
parameters_pytree = jax.tree.unflatten(pytree_def, pars)
|
|
317
|
+
return self.model.energy_flux(
|
|
318
|
+
parameters_pytree,
|
|
319
|
+
jnp.asarray([e_min]) * (1 + redshift),
|
|
320
|
+
jnp.asarray([e_max]) * (1 + redshift),
|
|
321
|
+
n_points=100,
|
|
322
|
+
)[0]
|
|
323
|
+
|
|
324
|
+
flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
|
|
325
|
+
flux = vectorized_flux(*flat_tree) * (u.keV / u.cm**2 / u.s)
|
|
243
326
|
value = (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
|
|
244
327
|
|
|
245
|
-
|
|
246
|
-
|
|
328
|
+
if register:
|
|
329
|
+
self.inference_data.posterior[f"luminosity_{e_min:.1f}_{e_max:.1f}"] = (
|
|
330
|
+
list(self.inference_data.posterior.coords),
|
|
331
|
+
value,
|
|
332
|
+
)
|
|
247
333
|
|
|
248
334
|
return value
|
|
249
335
|
|
|
250
|
-
def to_chain(self, name: str,
|
|
336
|
+
def to_chain(self, name: str, parameters_type: Literal["model", "bkg"] = "model") -> Chain:
|
|
251
337
|
"""
|
|
252
|
-
Return a ChainConsumer Chain object from the posterior distribution of the
|
|
338
|
+
Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.
|
|
253
339
|
|
|
254
340
|
Parameters:
|
|
255
341
|
name: The name of the chain.
|
|
256
|
-
|
|
342
|
+
parameters_type: The parameters_type to include in the chain.
|
|
257
343
|
"""
|
|
258
344
|
|
|
259
345
|
obs_id = self.inference_data.copy()
|
|
260
346
|
|
|
261
|
-
if
|
|
262
|
-
keys_to_drop = [
|
|
263
|
-
|
|
347
|
+
if parameters_type == "model":
|
|
348
|
+
keys_to_drop = [
|
|
349
|
+
key
|
|
350
|
+
for key in obs_id.posterior.keys()
|
|
351
|
+
if (key.startswith("_") or key.startswith("bkg"))
|
|
352
|
+
]
|
|
353
|
+
elif parameters_type == "bkg":
|
|
264
354
|
keys_to_drop = [key for key in obs_id.posterior.keys() if not key.startswith("bkg")]
|
|
265
355
|
else:
|
|
266
|
-
raise ValueError(f"Unknown value for
|
|
356
|
+
raise ValueError(f"Unknown value for parameters_type: {parameters_type}")
|
|
267
357
|
|
|
268
358
|
obs_id.posterior = obs_id.posterior.drop_vars(keys_to_drop)
|
|
269
359
|
chain = Chain.from_arviz(obs_id, name)
|
|
270
|
-
|
|
360
|
+
|
|
361
|
+
"""
|
|
362
|
+
chain.samples.columns = [
|
|
363
|
+
format_parameters(parameter) for parameter in chain.samples.columns
|
|
364
|
+
]
|
|
365
|
+
"""
|
|
271
366
|
|
|
272
367
|
return chain
|
|
273
368
|
|
|
@@ -304,7 +399,7 @@ class FitResult:
|
|
|
304
399
|
for module, parameter, value in traverse(self._structure):
|
|
305
400
|
if params.get(module, None) is None:
|
|
306
401
|
params[module] = {}
|
|
307
|
-
params[module][parameter] = self.
|
|
402
|
+
params[module][parameter] = self.samples_flat[f"{module}_{parameter}"]
|
|
308
403
|
|
|
309
404
|
return params
|
|
310
405
|
|
|
@@ -328,19 +423,50 @@ class FitResult:
|
|
|
328
423
|
return {key: posterior[key].data for key in var_names}
|
|
329
424
|
|
|
330
425
|
@property
|
|
331
|
-
def
|
|
426
|
+
def log_likelihood(self) -> xr.Dataset:
|
|
332
427
|
"""
|
|
333
|
-
Return the
|
|
428
|
+
Return the log_likelihood of each observation
|
|
334
429
|
"""
|
|
335
430
|
log_likelihood = az.extract(self.inference_data, group="log_likelihood")
|
|
336
|
-
dimensions_to_reduce = [
|
|
431
|
+
dimensions_to_reduce = [
|
|
432
|
+
coord for coord in log_likelihood.coords if coord not in ["sample", "draw", "chain"]
|
|
433
|
+
]
|
|
337
434
|
return log_likelihood.sum(dimensions_to_reduce)
|
|
338
435
|
|
|
436
|
+
@property
|
|
437
|
+
def c_stat(self):
|
|
438
|
+
r"""
|
|
439
|
+
Return the C-statistic of the model
|
|
440
|
+
|
|
441
|
+
The C-statistic is defined as:
|
|
442
|
+
|
|
443
|
+
$$ C = 2 \sum_{i} M - D*log(M) + D*log(D) - D $$
|
|
444
|
+
or
|
|
445
|
+
$$ C = 2 \sum_{i} M - D*log(M)$$
|
|
446
|
+
for bins with no counts
|
|
447
|
+
|
|
448
|
+
"""
|
|
449
|
+
|
|
450
|
+
exclude_dims = ["chain", "draw", "sample"]
|
|
451
|
+
all_dims = list(self.inference_data.log_likelihood.dims)
|
|
452
|
+
reduce_dims = [dim for dim in all_dims if dim not in exclude_dims]
|
|
453
|
+
data = self.inference_data.observed_data
|
|
454
|
+
c_stat = -2 * (
|
|
455
|
+
self.log_likelihood
|
|
456
|
+
+ (gammaln(data + 1) - (xr.where(data > 0, data * (np.log(data) - 1), 0))).sum(
|
|
457
|
+
dim=reduce_dims
|
|
458
|
+
)
|
|
459
|
+
)
|
|
460
|
+
|
|
461
|
+
return c_stat
|
|
462
|
+
|
|
339
463
|
def plot_ppc(
|
|
340
464
|
self,
|
|
341
|
-
percentile:
|
|
465
|
+
percentile: tuple[int, int] = (16, 84),
|
|
342
466
|
x_unit: str | u.Unit = "keV",
|
|
343
|
-
y_type: Literal[
|
|
467
|
+
y_type: Literal[
|
|
468
|
+
"counts", "countrate", "photon_flux", "photon_flux_density"
|
|
469
|
+
] = "photon_flux_density",
|
|
344
470
|
) -> plt.Figure:
|
|
345
471
|
r"""
|
|
346
472
|
Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
|
|
@@ -383,21 +509,31 @@ class FitResult:
|
|
|
383
509
|
# and enable weird broadcasting that makes the plot fail
|
|
384
510
|
|
|
385
511
|
fig, axs = plt.subplots(
|
|
386
|
-
2,
|
|
512
|
+
2,
|
|
513
|
+
len(obsconf_container),
|
|
514
|
+
figsize=(6 * len(obsconf_container), 6),
|
|
515
|
+
sharex=True,
|
|
516
|
+
height_ratios=[0.7, 0.3],
|
|
387
517
|
)
|
|
388
518
|
|
|
389
519
|
plot_ylabels_once = True
|
|
390
520
|
|
|
391
521
|
for name, obsconf, ax in zip(
|
|
392
|
-
obsconf_container.keys(),
|
|
522
|
+
obsconf_container.keys(),
|
|
523
|
+
obsconf_container.values(),
|
|
524
|
+
axs.T if len(obsconf_container) > 1 else [axs],
|
|
393
525
|
):
|
|
394
526
|
legend_plots = []
|
|
395
527
|
legend_labels = []
|
|
396
|
-
count = az.extract(
|
|
528
|
+
count = az.extract(
|
|
529
|
+
self.inference_data, var_names=f"obs_{name}", group="posterior_predictive"
|
|
530
|
+
).values.T
|
|
397
531
|
bkg_count = (
|
|
398
532
|
None
|
|
399
533
|
if self.background_model is None
|
|
400
|
-
else az.extract(
|
|
534
|
+
else az.extract(
|
|
535
|
+
self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive"
|
|
536
|
+
).values.T
|
|
401
537
|
)
|
|
402
538
|
|
|
403
539
|
xbins = obsconf.out_energies * u.keV
|
|
@@ -413,18 +549,13 @@ class FitResult:
|
|
|
413
549
|
integrated_arf = (
|
|
414
550
|
trapezoid(interpolated_arf, x=e_grid, axis=0)
|
|
415
551
|
/ (
|
|
416
|
-
np.abs(
|
|
552
|
+
np.abs(
|
|
553
|
+
xbins[1] - xbins[0]
|
|
554
|
+
) # Must fold in abs because some units reverse the ordering of the bins
|
|
417
555
|
)
|
|
418
556
|
* u.cm**2
|
|
419
557
|
)
|
|
420
558
|
|
|
421
|
-
"""
|
|
422
|
-
if xbins[0][0] < 1 < xbins[1][-1]:
|
|
423
|
-
xticks = [np.floor(xbins[0][0] * 10) / 10, 1.0, np.floor(xbins[1][-1])]
|
|
424
|
-
else:
|
|
425
|
-
xticks = [np.floor(xbins[0][0] * 10) / 10, np.floor(xbins[1][-1])]
|
|
426
|
-
"""
|
|
427
|
-
|
|
428
559
|
match y_type:
|
|
429
560
|
case "counts":
|
|
430
561
|
denominator = 1
|
|
@@ -437,50 +568,93 @@ class FitResult:
|
|
|
437
568
|
|
|
438
569
|
y_samples = (count * u.photon / denominator).to(y_units)
|
|
439
570
|
y_observed = (obsconf.folded_counts.data * u.photon / denominator).to(y_units)
|
|
571
|
+
y_observed_low = (
|
|
572
|
+
nbinom.ppf(percentile[0] / 100, obsconf.folded_counts.data, 0.5)
|
|
573
|
+
* u.photon
|
|
574
|
+
/ denominator
|
|
575
|
+
).to(y_units)
|
|
576
|
+
y_observed_high = (
|
|
577
|
+
nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
|
|
578
|
+
* u.photon
|
|
579
|
+
/ denominator
|
|
580
|
+
).to(y_units)
|
|
440
581
|
|
|
441
582
|
# Use the helper function to plot the data and posterior predictive
|
|
442
583
|
legend_plots += _plot_binned_samples_with_error(
|
|
443
584
|
ax[0],
|
|
444
585
|
xbins.value,
|
|
445
586
|
y_samples=y_samples.value,
|
|
446
|
-
y_observed=y_observed.value,
|
|
447
587
|
denominator=np.ones_like(y_observed).value,
|
|
448
588
|
color=color,
|
|
449
589
|
percentile=percentile,
|
|
450
590
|
)
|
|
451
591
|
|
|
452
|
-
legend_labels.append("
|
|
592
|
+
legend_labels.append("Model")
|
|
593
|
+
|
|
594
|
+
true_data_plot = ax[0].errorbar(
|
|
595
|
+
np.sqrt(xbins.value[0] * xbins.value[1]),
|
|
596
|
+
y_observed.value,
|
|
597
|
+
xerr=np.abs(xbins.value - np.sqrt(xbins.value[0] * xbins.value[1])),
|
|
598
|
+
yerr=[
|
|
599
|
+
y_observed.value - y_observed_low.value,
|
|
600
|
+
y_observed_high.value - y_observed.value,
|
|
601
|
+
],
|
|
602
|
+
color="black",
|
|
603
|
+
linestyle="none",
|
|
604
|
+
alpha=0.3,
|
|
605
|
+
capsize=2,
|
|
606
|
+
)
|
|
607
|
+
|
|
608
|
+
legend_plots.append((true_data_plot,))
|
|
609
|
+
legend_labels.append("Observed")
|
|
453
610
|
|
|
454
611
|
if self.background_model is not None:
|
|
455
612
|
# We plot the background only if it is included in the fit, i.e. by subtracting
|
|
456
613
|
ratio = obsconf.folded_backratio.data
|
|
457
614
|
y_samples_bkg = (bkg_count * u.photon / (denominator * ratio)).to(y_units)
|
|
458
|
-
y_observed_bkg = (
|
|
615
|
+
y_observed_bkg = (
|
|
616
|
+
obsconf.folded_background.data * u.photon / (denominator * ratio)
|
|
617
|
+
).to(y_units)
|
|
459
618
|
legend_plots += _plot_binned_samples_with_error(
|
|
460
619
|
ax[0],
|
|
461
620
|
xbins.value,
|
|
462
621
|
y_samples=y_samples_bkg.value,
|
|
463
|
-
y_observed=y_observed_bkg.value,
|
|
464
622
|
denominator=np.ones_like(y_observed).value,
|
|
465
623
|
color=(0.26787604, 0.60085972, 0.63302651),
|
|
466
624
|
percentile=percentile,
|
|
467
625
|
)
|
|
468
626
|
|
|
469
|
-
legend_labels.append("
|
|
627
|
+
legend_labels.append("Model (bkg)")
|
|
628
|
+
|
|
629
|
+
residual_samples = (obsconf.folded_counts.data - count) / np.diff(
|
|
630
|
+
np.percentile(count, percentile, axis=0), axis=0
|
|
631
|
+
)
|
|
470
632
|
|
|
471
633
|
residuals = np.percentile(
|
|
472
|
-
|
|
634
|
+
residual_samples,
|
|
473
635
|
percentile,
|
|
474
636
|
axis=0,
|
|
475
637
|
)
|
|
476
638
|
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
639
|
+
median_residuals = np.median(
|
|
640
|
+
residual_samples,
|
|
641
|
+
axis=0,
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
ax[1].stairs(
|
|
645
|
+
residuals[1],
|
|
646
|
+
edges=[*list(xbins.value[0]), xbins.value[1][-1]],
|
|
647
|
+
baseline=list(residuals[0]),
|
|
481
648
|
alpha=0.3,
|
|
482
|
-
step="post",
|
|
483
649
|
facecolor=color,
|
|
650
|
+
fill=True,
|
|
651
|
+
)
|
|
652
|
+
|
|
653
|
+
ax[1].stairs(
|
|
654
|
+
median_residuals,
|
|
655
|
+
edges=[*list(xbins.value[0]), xbins.value[1][-1]],
|
|
656
|
+
color=color,
|
|
657
|
+
alpha=0.7,
|
|
484
658
|
)
|
|
485
659
|
|
|
486
660
|
max_residuals = np.max(np.abs(residuals))
|
|
@@ -502,7 +676,8 @@ class FitResult:
|
|
|
502
676
|
ax[1].set_xlabel(f"Frequency \n[{x_unit:latex_inline}]")
|
|
503
677
|
case _:
|
|
504
678
|
RuntimeError(
|
|
505
|
-
f"Unknown physical type for x_units: {x_unit}. "
|
|
679
|
+
f"Unknown physical type for x_units: {x_unit}. "
|
|
680
|
+
f"Must be 'length', 'energy' or 'frequency'"
|
|
506
681
|
)
|
|
507
682
|
|
|
508
683
|
ax[1].axhline(0, color=color, ls="--")
|
|
@@ -536,15 +711,16 @@ class FitResult:
|
|
|
536
711
|
|
|
537
712
|
def plot_corner(
|
|
538
713
|
self,
|
|
539
|
-
config: PlotConfig = PlotConfig(usetex=False, summarise=False, label_font_size=
|
|
714
|
+
config: PlotConfig = PlotConfig(usetex=False, summarise=False, label_font_size=12),
|
|
540
715
|
**kwargs: Any,
|
|
541
716
|
) -> plt.Figure:
|
|
542
717
|
"""
|
|
543
|
-
Plot the corner plot of the posterior distribution of the
|
|
718
|
+
Plot the corner plot of the posterior distribution of the parameters_type. This method uses the ChainConsumer.
|
|
544
719
|
|
|
545
720
|
Parameters:
|
|
546
721
|
config: The configuration of the plot.
|
|
547
|
-
**kwargs: Additional arguments passed to ChainConsumer.plotter.plot.
|
|
722
|
+
**kwargs: Additional arguments passed to ChainConsumer.plotter.plot. Some useful parameters are :
|
|
723
|
+
- columns : list of parameters to plot.
|
|
548
724
|
"""
|
|
549
725
|
|
|
550
726
|
consumer = ChainConsumer()
|