jaxspec 0.0.4__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxspec-0.0.4 → jaxspec-0.0.6}/PKG-INFO +5 -3
- {jaxspec-0.0.4 → jaxspec-0.0.6}/README.md +1 -1
- {jaxspec-0.0.4 → jaxspec-0.0.6}/pyproject.toml +4 -2
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/__init__.py +1 -1
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/analysis/compare.py +3 -3
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/analysis/results.py +239 -110
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/instrument.py +0 -2
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/ogip.py +18 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/util.py +11 -3
- jaxspec-0.0.6/src/jaxspec/fit.py +347 -0
- jaxspec-0.0.6/src/jaxspec/model/_additive/apec.py +377 -0
- jaxspec-0.0.6/src/jaxspec/model/_additive/apec_loaders.py +90 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/abc.py +55 -7
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/additive.py +2 -51
- jaxspec-0.0.6/src/jaxspec/tables/abundances.dat +31 -0
- jaxspec-0.0.6/src/jaxspec/util/__init__.py +0 -0
- jaxspec-0.0.6/src/jaxspec/util/abundance.py +111 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/util/integrate.py +5 -4
- jaxspec-0.0.4/src/jaxspec/fit.py +0 -253
- {jaxspec-0.0.4 → jaxspec-0.0.6}/LICENSE.md +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/analysis/__init__.py +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/__init__.py +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS1.arf +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS1.pha +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS1.rmf +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS1_spectrum_grp.fits +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS1background_spectrum.fits +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS2.arf +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS2.pha +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS2.rmf +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS2_spectrum_grp.fits +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS2background_spectrum.fits +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/PN.arf +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/PN.pha +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/PN.rmf +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/PN_spectrum_grp20.fits +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/PNbackground_spectrum.fits +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/fakeit.pha +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/grouping.py +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/obsconf.py +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/observation.py +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/__init__.py +0 -0
- {jaxspec-0.0.4/src/jaxspec/util → jaxspec-0.0.6/src/jaxspec/model/_additive}/__init__.py +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/background.py +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/list.py +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/multiplicative.py +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/tables/xsect_phabs_aspl.fits +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/tables/xsect_tbabs_wilm.fits +0 -0
- {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/tables/xsect_wabs_angr.fits +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jaxspec
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.6
|
|
4
4
|
Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
|
|
5
5
|
License: MIT
|
|
6
6
|
Author: sdupourque
|
|
@@ -10,7 +10,7 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
10
10
|
Classifier: Programming Language :: Python :: 3
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.10
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
-
Requires-Dist: arviz (>=0.17.1,<0.
|
|
13
|
+
Requires-Dist: arviz (>=0.17.1,<0.19.0)
|
|
14
14
|
Requires-Dist: astropy (>=6.0.0,<7.0.0)
|
|
15
15
|
Requires-Dist: chainconsumer (>=1.0.0,<2.0.0)
|
|
16
16
|
Requires-Dist: cmasher (>=1.6.3,<2.0.0)
|
|
@@ -20,11 +20,13 @@ Requires-Dist: jax (>=0.4.23,<0.5.0)
|
|
|
20
20
|
Requires-Dist: jaxlib (>=0.4.23,<0.5.0)
|
|
21
21
|
Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
|
|
22
22
|
Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
|
|
23
|
+
Requires-Dist: mendeleev (>=0.15.0,<0.16.0)
|
|
23
24
|
Requires-Dist: mkdocstrings (>=0.24.0,<0.25.0)
|
|
24
25
|
Requires-Dist: networkx (>=3.1,<4.0)
|
|
25
26
|
Requires-Dist: numpy (>=1.26.1,<2.0.0)
|
|
26
27
|
Requires-Dist: numpyro (>=0.13.2,<0.15.0)
|
|
27
28
|
Requires-Dist: pandas (>=2.2.0,<3.0.0)
|
|
29
|
+
Requires-Dist: pyzmq (<26)
|
|
28
30
|
Requires-Dist: scipy (<1.13)
|
|
29
31
|
Requires-Dist: seaborn (>=0.13.1,<0.14.0)
|
|
30
32
|
Requires-Dist: simpleeval (>=0.9.13,<0.10.0)
|
|
@@ -50,7 +52,7 @@ Documentation : https://jaxspec.readthedocs.io/en/latest/
|
|
|
50
52
|
|
|
51
53
|
## Installation
|
|
52
54
|
|
|
53
|
-
We recommend the users to start from a fresh Python 3.10 [conda environment](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).
|
|
55
|
+
We recommend the users to start from a fresh Python 3.10 [conda environment](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).
|
|
54
56
|
|
|
55
57
|
```
|
|
56
58
|
conda create -n jaxspec python=3.10
|
|
@@ -16,7 +16,7 @@ Documentation : https://jaxspec.readthedocs.io/en/latest/
|
|
|
16
16
|
|
|
17
17
|
## Installation
|
|
18
18
|
|
|
19
|
-
We recommend the users to start from a fresh Python 3.10 [conda environment](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).
|
|
19
|
+
We recommend the users to start from a fresh Python 3.10 [conda environment](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).
|
|
20
20
|
|
|
21
21
|
```
|
|
22
22
|
conda create -n jaxspec python=3.10
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "jaxspec"
|
|
3
|
-
version = "0.0.
|
|
3
|
+
version = "0.0.6"
|
|
4
4
|
description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
|
|
5
5
|
authors = ["sdupourque <sdupourque@irap.omp.eu>"]
|
|
6
6
|
license = "MIT"
|
|
@@ -17,7 +17,7 @@ numpyro = ">=0.13.2,<0.15.0"
|
|
|
17
17
|
dm-haiku = ">=0.0.11,<0.0.13"
|
|
18
18
|
networkx = "^3.1"
|
|
19
19
|
matplotlib = "^3.8.0"
|
|
20
|
-
arviz = "
|
|
20
|
+
arviz = ">=0.17.1,<0.19.0"
|
|
21
21
|
chainconsumer = "^1.0.0"
|
|
22
22
|
simpleeval = "^0.9.13"
|
|
23
23
|
cmasher = "^1.6.3"
|
|
@@ -28,6 +28,8 @@ seaborn = "^0.13.1"
|
|
|
28
28
|
mkdocstrings = "^0.24.0"
|
|
29
29
|
sparse = "^0.15.1"
|
|
30
30
|
scipy = "<1.13"
|
|
31
|
+
mendeleev = "^0.15.0"
|
|
32
|
+
pyzmq = "<26"
|
|
31
33
|
|
|
32
34
|
|
|
33
35
|
[tool.poetry.group.docs.dependencies]
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from typing import Dict
|
|
2
|
-
from .results import
|
|
2
|
+
from .results import FitResult
|
|
3
3
|
from chainconsumer import ChainConsumer
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
def plot_corner_comparison(obs_dict: Dict[str,
|
|
6
|
+
def plot_corner_comparison(obs_dict: Dict[str, FitResult], **kwargs):
|
|
7
7
|
"""
|
|
8
8
|
Plot the correlation plot of parameters from different fitted observations. Observations are passed in as a
|
|
9
9
|
dictionary. Each observation is named according to its key. It shall be used to compare the same model independently
|
|
@@ -16,6 +16,6 @@ def plot_corner_comparison(obs_dict: Dict[str, ChainResult], **kwargs):
|
|
|
16
16
|
c = ChainConsumer()
|
|
17
17
|
|
|
18
18
|
for name, obs in obs_dict.items():
|
|
19
|
-
c.add_chain(obs.
|
|
19
|
+
c.add_chain(obs.to_chain(name))
|
|
20
20
|
|
|
21
21
|
return c.plotter.plot(**kwargs)
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import arviz as az
|
|
2
2
|
import numpy as np
|
|
3
|
+
import xarray as xr
|
|
3
4
|
import matplotlib.pyplot as plt
|
|
4
5
|
from ..data import ObsConfiguration
|
|
5
6
|
from ..model.abc import SpectralModel
|
|
6
7
|
from ..model.background import BackgroundModel
|
|
7
8
|
from collections.abc import Mapping
|
|
8
|
-
from typing import TypeVar, Tuple, Literal
|
|
9
|
+
from typing import TypeVar, Tuple, Literal, Any
|
|
9
10
|
from astropy.cosmology import Cosmology, Planck18
|
|
10
11
|
import astropy.units as u
|
|
11
12
|
from astropy.units import Unit
|
|
@@ -17,6 +18,10 @@ from scipy.integrate import trapezoid
|
|
|
17
18
|
|
|
18
19
|
K = TypeVar("K")
|
|
19
20
|
V = TypeVar("V")
|
|
21
|
+
T = TypeVar("T")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class HaikuDict(dict[str, dict[str, T]]): ...
|
|
20
25
|
|
|
21
26
|
|
|
22
27
|
def _plot_binned_samples_with_error(
|
|
@@ -69,7 +74,7 @@ def _plot_binned_samples_with_error(
|
|
|
69
74
|
|
|
70
75
|
percentiles = np.percentile(y_samples, percentile, axis=0)
|
|
71
76
|
|
|
72
|
-
# The legend cannot handle fill_between, so we pass a fill to get a fancy
|
|
77
|
+
# The legend cannot handle fill_between, so we pass a fill to get a fancy icon
|
|
73
78
|
(envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
|
|
74
79
|
|
|
75
80
|
ax.fill_between(
|
|
@@ -111,24 +116,24 @@ def format_parameters(parameter_name):
|
|
|
111
116
|
return rf"${parameter}$" + module
|
|
112
117
|
|
|
113
118
|
|
|
114
|
-
class
|
|
115
|
-
|
|
119
|
+
class FitResult:
|
|
120
|
+
"""
|
|
121
|
+
This class is the container for the result of a fit using any ModelFitter class.
|
|
122
|
+
"""
|
|
123
|
+
|
|
116
124
|
# TODO : Add type hints
|
|
117
|
-
# TODO : Add proper separation between params and samples, cf from haiku and numpyro
|
|
118
125
|
def __init__(
|
|
119
126
|
self,
|
|
120
127
|
model: SpectralModel,
|
|
121
|
-
|
|
128
|
+
obsconf: ObsConfiguration | dict[str, ObsConfiguration],
|
|
122
129
|
inference_data: az.InferenceData,
|
|
123
|
-
samples,
|
|
124
130
|
structure: Mapping[K, V],
|
|
125
131
|
background_model: BackgroundModel = None,
|
|
126
132
|
):
|
|
127
133
|
self.model = model
|
|
128
134
|
self._structure = structure
|
|
129
135
|
self.inference_data = inference_data
|
|
130
|
-
self.
|
|
131
|
-
self.samples = samples
|
|
136
|
+
self.obsconfs = {"Observation": obsconf} if isinstance(obsconf, ObsConfiguration) else obsconf
|
|
132
137
|
self.background_model = background_model
|
|
133
138
|
self._structure = structure
|
|
134
139
|
|
|
@@ -139,6 +144,14 @@ class ChainResult:
|
|
|
139
144
|
metadata["model"] = str(model)
|
|
140
145
|
# TODO : Store metadata about observations used in the fitting process
|
|
141
146
|
|
|
147
|
+
@property
|
|
148
|
+
def converged(self) -> bool:
|
|
149
|
+
"""
|
|
150
|
+
Convergence of the chain as computed by the $\hat{R}$ statistic.
|
|
151
|
+
"""
|
|
152
|
+
|
|
153
|
+
return all(az.rhat(self.inference_data) < 1.01)
|
|
154
|
+
|
|
142
155
|
def photon_flux(
|
|
143
156
|
self,
|
|
144
157
|
e_min: float,
|
|
@@ -164,7 +177,7 @@ class ChainResult:
|
|
|
164
177
|
conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
|
|
165
178
|
|
|
166
179
|
value = flux * conversion_factor
|
|
167
|
-
|
|
180
|
+
# TODO : fix this since sample doesn't exist anymore
|
|
168
181
|
self.samples[rf"Photon flux ({e_min:.1f}-{e_max:.1f} keV)"] = value
|
|
169
182
|
|
|
170
183
|
return value
|
|
@@ -195,6 +208,7 @@ class ChainResult:
|
|
|
195
208
|
|
|
196
209
|
value = flux * conversion_factor
|
|
197
210
|
|
|
211
|
+
# TODO : fix this since sample doesn't exist anymore
|
|
198
212
|
self.samples[rf"Energy flux ({e_min:.1f}-{e_max:.1f} keV)"] = value
|
|
199
213
|
|
|
200
214
|
return value
|
|
@@ -207,7 +221,7 @@ class ChainResult:
|
|
|
207
221
|
observer_frame: bool = True,
|
|
208
222
|
cosmology: Cosmology = Planck18,
|
|
209
223
|
unit: Unit = u.erg / u.s,
|
|
210
|
-
):
|
|
224
|
+
) -> ArrayLike:
|
|
211
225
|
"""
|
|
212
226
|
Compute the luminosity of the source specifying its redshift. The luminosity is then added to
|
|
213
227
|
the result parameters so covariance can be plotted.
|
|
@@ -228,11 +242,12 @@ class ChainResult:
|
|
|
228
242
|
|
|
229
243
|
value = (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
|
|
230
244
|
|
|
245
|
+
# TODO : fix this since sample doesn't exist anymore
|
|
231
246
|
self.samples[rf"Luminosity ({e_min:.1f}-{e_max:.1f} keV)"] = value
|
|
232
247
|
|
|
233
248
|
return value
|
|
234
249
|
|
|
235
|
-
def
|
|
250
|
+
def to_chain(self, name: str, parameters: Literal["model", "bkg"] = "model") -> Chain:
|
|
236
251
|
"""
|
|
237
252
|
Return a ChainConsumer Chain object from the posterior distribution of the parameters.
|
|
238
253
|
|
|
@@ -257,9 +272,31 @@ class ChainResult:
|
|
|
257
272
|
return chain
|
|
258
273
|
|
|
259
274
|
@property
|
|
260
|
-
def
|
|
275
|
+
def samples_haiku(self) -> HaikuDict[ArrayLike]:
|
|
261
276
|
"""
|
|
262
|
-
Haiku-like structure for the
|
|
277
|
+
Haiku-like structure for the samples e.g.
|
|
278
|
+
|
|
279
|
+
```
|
|
280
|
+
{
|
|
281
|
+
'powerlaw_1' :
|
|
282
|
+
{
|
|
283
|
+
'alpha': ...,
|
|
284
|
+
'amplitude': ...
|
|
285
|
+
},
|
|
286
|
+
|
|
287
|
+
'blackbody_1':
|
|
288
|
+
{
|
|
289
|
+
'kT': ...,
|
|
290
|
+
'norm': ...
|
|
291
|
+
},
|
|
292
|
+
|
|
293
|
+
'tbabs_1':
|
|
294
|
+
{
|
|
295
|
+
'nH': ...
|
|
296
|
+
}
|
|
297
|
+
}
|
|
298
|
+
```
|
|
299
|
+
|
|
263
300
|
"""
|
|
264
301
|
|
|
265
302
|
params = {}
|
|
@@ -271,7 +308,40 @@ class ChainResult:
|
|
|
271
308
|
|
|
272
309
|
return params
|
|
273
310
|
|
|
274
|
-
|
|
311
|
+
@property
|
|
312
|
+
def samples_flat(self) -> dict[str, ArrayLike]:
|
|
313
|
+
"""
|
|
314
|
+
Flat structure for the samples e.g.
|
|
315
|
+
|
|
316
|
+
```
|
|
317
|
+
{
|
|
318
|
+
'powerlaw_1_alpha': ...,
|
|
319
|
+
'powerlaw_1_amplitude': ...,
|
|
320
|
+
'blackbody_1_kT': ...,
|
|
321
|
+
'blackbody_1_norm': ...,
|
|
322
|
+
'tbabs_1_nH': ...,
|
|
323
|
+
}
|
|
324
|
+
```
|
|
325
|
+
"""
|
|
326
|
+
var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
|
|
327
|
+
posterior = az.extract(self.inference_data, var_names=var_names)
|
|
328
|
+
return {key: posterior[key].data for key in var_names}
|
|
329
|
+
|
|
330
|
+
@property
|
|
331
|
+
def likelihood(self) -> xr.Dataset:
|
|
332
|
+
"""
|
|
333
|
+
Return the likelihood of each observation
|
|
334
|
+
"""
|
|
335
|
+
log_likelihood = az.extract(self.inference_data, group="log_likelihood")
|
|
336
|
+
dimensions_to_reduce = [coord for coord in log_likelihood.coords if coord not in ["sample", "draw", "chain"]]
|
|
337
|
+
return log_likelihood.sum(dimensions_to_reduce)
|
|
338
|
+
|
|
339
|
+
def plot_ppc(
|
|
340
|
+
self,
|
|
341
|
+
percentile: Tuple[int, int] = (14, 86),
|
|
342
|
+
x_unit: str | u.Unit = "keV",
|
|
343
|
+
y_type: Literal["counts", "countrate", "photon_flux", "photon_flux_density"] = "photon_flux_density",
|
|
344
|
+
) -> plt.Figure:
|
|
275
345
|
r"""
|
|
276
346
|
Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
|
|
277
347
|
following formula:
|
|
@@ -281,115 +351,176 @@ class ChainResult:
|
|
|
281
351
|
|
|
282
352
|
Parameters:
|
|
283
353
|
percentile: The percentile of the posterior predictive distribution to plot.
|
|
354
|
+
x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
|
|
355
|
+
y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
|
|
284
356
|
|
|
285
357
|
Returns:
|
|
286
|
-
The matplotlib
|
|
358
|
+
The matplotlib figure.
|
|
287
359
|
"""
|
|
288
360
|
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
361
|
+
obsconf_container = self.obsconfs
|
|
362
|
+
x_unit = u.Unit(x_unit)
|
|
363
|
+
|
|
364
|
+
match y_type:
|
|
365
|
+
case "counts":
|
|
366
|
+
y_units = u.photon
|
|
367
|
+
case "countrate":
|
|
368
|
+
y_units = u.photon / u.s
|
|
369
|
+
case "photon_flux":
|
|
370
|
+
y_units = u.photon / u.cm**2 / u.s
|
|
371
|
+
case "photon_flux_density":
|
|
372
|
+
y_units = u.photon / u.cm**2 / u.s / x_unit
|
|
373
|
+
case _:
|
|
374
|
+
raise ValueError(
|
|
375
|
+
f"Unknown y_type: {y_type}. Must be 'counts', 'countrate', 'photon_flux' or 'photon_flux_density'"
|
|
376
|
+
)
|
|
299
377
|
|
|
300
378
|
color = (0.15, 0.25, 0.45)
|
|
301
379
|
|
|
302
380
|
with plt.style.context("default"):
|
|
303
|
-
# Note to Simon : do not change
|
|
381
|
+
# Note to Simon : do not change xbins[1] - xbins[0] to
|
|
304
382
|
# np.diff, you already did this twice and forgot that it does not work since diff keeps the dimensions
|
|
305
383
|
# and enable weird broadcasting that makes the plot fail
|
|
306
384
|
|
|
307
|
-
fig, axs = plt.subplots(
|
|
308
|
-
|
|
309
|
-
mid_bins_arf = folding_model.in_energies.mean(axis=0)
|
|
310
|
-
|
|
311
|
-
e_grid = np.linspace(*folding_model.out_energies, 10)
|
|
312
|
-
interpolated_arf = np.interp(e_grid, mid_bins_arf, folding_model.area)
|
|
313
|
-
integrated_arf = trapezoid(interpolated_arf, x=e_grid, axis=0) / (
|
|
314
|
-
folding_model.out_energies[1] - folding_model.out_energies[0]
|
|
315
|
-
)
|
|
316
|
-
|
|
317
|
-
if folding_model.out_energies[0][0] < 1 < folding_model.out_energies[1][-1]:
|
|
318
|
-
xticks = [np.floor(folding_model.out_energies[0][0] * 10) / 10, 1.0, np.floor(folding_model.out_energies[1][-1])]
|
|
319
|
-
else:
|
|
320
|
-
xticks = [np.floor(folding_model.out_energies[0][0] * 10) / 10, np.floor(folding_model.out_energies[1][-1])]
|
|
321
|
-
|
|
322
|
-
denominator = (
|
|
323
|
-
(folding_model.out_energies[1] - folding_model.out_energies[0]) * folding_model.exposure.data * integrated_arf
|
|
385
|
+
fig, axs = plt.subplots(
|
|
386
|
+
2, len(obsconf_container), figsize=(6 * len(obsconf_container), 6), sharex=True, height_ratios=[0.7, 0.3]
|
|
324
387
|
)
|
|
325
388
|
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
389
|
+
plot_ylabels_once = True
|
|
390
|
+
|
|
391
|
+
for name, obsconf, ax in zip(
|
|
392
|
+
obsconf_container.keys(), obsconf_container.values(), axs.T if len(obsconf_container) > 1 else [axs]
|
|
393
|
+
):
|
|
394
|
+
legend_plots = []
|
|
395
|
+
legend_labels = []
|
|
396
|
+
count = az.extract(self.inference_data, var_names=f"obs_{name}", group="posterior_predictive").values.T
|
|
397
|
+
bkg_count = (
|
|
398
|
+
None
|
|
399
|
+
if self.background_model is None
|
|
400
|
+
else az.extract(self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive").values.T
|
|
401
|
+
)
|
|
336
402
|
|
|
337
|
-
|
|
403
|
+
xbins = obsconf.out_energies * u.keV
|
|
404
|
+
xbins = xbins.to(x_unit, u.spectral())
|
|
405
|
+
|
|
406
|
+
# This compute the total effective area within all bins
|
|
407
|
+
# This is a bit weird since the following computation is equivalent to ignoring the RMF
|
|
408
|
+
exposure = obsconf.exposure.data * u.s
|
|
409
|
+
mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
|
|
410
|
+
mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
|
|
411
|
+
e_grid = np.linspace(*xbins, 10)
|
|
412
|
+
interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
|
|
413
|
+
integrated_arf = (
|
|
414
|
+
trapezoid(interpolated_arf, x=e_grid, axis=0)
|
|
415
|
+
/ (
|
|
416
|
+
np.abs(xbins[1] - xbins[0]) # Must fold in abs because some units reverse the ordering of the bins
|
|
417
|
+
)
|
|
418
|
+
* u.cm**2
|
|
419
|
+
)
|
|
338
420
|
|
|
339
|
-
|
|
340
|
-
|
|
421
|
+
"""
|
|
422
|
+
if xbins[0][0] < 1 < xbins[1][-1]:
|
|
423
|
+
xticks = [np.floor(xbins[0][0] * 10) / 10, 1.0, np.floor(xbins[1][-1])]
|
|
424
|
+
else:
|
|
425
|
+
xticks = [np.floor(xbins[0][0] * 10) / 10, np.floor(xbins[1][-1])]
|
|
426
|
+
"""
|
|
427
|
+
|
|
428
|
+
match y_type:
|
|
429
|
+
case "counts":
|
|
430
|
+
denominator = 1
|
|
431
|
+
case "countrate":
|
|
432
|
+
denominator = exposure
|
|
433
|
+
case "photon_flux":
|
|
434
|
+
denominator = integrated_arf * exposure
|
|
435
|
+
case "photon_flux_density":
|
|
436
|
+
denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure
|
|
437
|
+
|
|
438
|
+
y_samples = (count * u.photon / denominator).to(y_units)
|
|
439
|
+
y_observed = (obsconf.folded_counts.data * u.photon / denominator).to(y_units)
|
|
440
|
+
|
|
441
|
+
# Use the helper function to plot the data and posterior predictive
|
|
341
442
|
legend_plots += _plot_binned_samples_with_error(
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
denominator=
|
|
347
|
-
color=
|
|
443
|
+
ax[0],
|
|
444
|
+
xbins.value,
|
|
445
|
+
y_samples=y_samples.value,
|
|
446
|
+
y_observed=y_observed.value,
|
|
447
|
+
denominator=np.ones_like(y_observed).value,
|
|
448
|
+
color=color,
|
|
348
449
|
percentile=percentile,
|
|
349
450
|
)
|
|
350
451
|
|
|
351
|
-
legend_labels.append("Background")
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
axs[1].set_xlabel("Energy \n[keV]")
|
|
376
|
-
|
|
377
|
-
axs[1].axhline(0, color=color, ls="--")
|
|
378
|
-
axs[1].axhline(-3, color=color, ls=":")
|
|
379
|
-
axs[1].axhline(3, color=color, ls=":")
|
|
380
|
-
|
|
381
|
-
axs[1].set_xticks(xticks, labels=xticks)
|
|
382
|
-
axs[1].set_yticks([-3, 0, 3], labels=[-3, 0, 3])
|
|
383
|
-
axs[1].set_yticks(range(-3, 4), minor=True)
|
|
384
|
-
|
|
385
|
-
axs[0].set_xlim(folding_model.out_energies.min(), folding_model.out_energies.max())
|
|
452
|
+
legend_labels.append("Source + Background")
|
|
453
|
+
|
|
454
|
+
if self.background_model is not None:
|
|
455
|
+
# We plot the background only if it is included in the fit, i.e. by subtracting
|
|
456
|
+
ratio = obsconf.folded_backratio.data
|
|
457
|
+
y_samples_bkg = (bkg_count * u.photon / (denominator * ratio)).to(y_units)
|
|
458
|
+
y_observed_bkg = (obsconf.folded_background.data * u.photon / (denominator * ratio)).to(y_units)
|
|
459
|
+
legend_plots += _plot_binned_samples_with_error(
|
|
460
|
+
ax[0],
|
|
461
|
+
xbins.value,
|
|
462
|
+
y_samples=y_samples_bkg.value,
|
|
463
|
+
y_observed=y_observed_bkg.value,
|
|
464
|
+
denominator=np.ones_like(y_observed).value,
|
|
465
|
+
color=(0.26787604, 0.60085972, 0.63302651),
|
|
466
|
+
percentile=percentile,
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
legend_labels.append("Background")
|
|
470
|
+
|
|
471
|
+
residuals = np.percentile(
|
|
472
|
+
(obsconf.folded_counts.data - count) / np.diff(np.percentile(count, percentile, axis=0), axis=0),
|
|
473
|
+
percentile,
|
|
474
|
+
axis=0,
|
|
475
|
+
)
|
|
386
476
|
|
|
387
|
-
|
|
388
|
-
|
|
477
|
+
ax[1].fill_between(
|
|
478
|
+
list(xbins.value[0]) + [xbins.value[1][-1]],
|
|
479
|
+
list(residuals[0]) + [residuals[0][-1]],
|
|
480
|
+
list(residuals[1]) + [residuals[1][-1]],
|
|
481
|
+
alpha=0.3,
|
|
482
|
+
step="post",
|
|
483
|
+
facecolor=color,
|
|
484
|
+
)
|
|
389
485
|
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
486
|
+
max_residuals = np.max(np.abs(residuals))
|
|
487
|
+
|
|
488
|
+
ax[0].loglog()
|
|
489
|
+
ax[1].set_ylim(-max(3.5, max_residuals), +max(3.5, max_residuals))
|
|
490
|
+
|
|
491
|
+
if plot_ylabels_once:
|
|
492
|
+
ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
|
|
493
|
+
ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
|
|
494
|
+
plot_ylabels_once = False
|
|
495
|
+
|
|
496
|
+
match getattr(x_unit, "physical_type"):
|
|
497
|
+
case "length":
|
|
498
|
+
ax[1].set_xlabel(f"Wavelength \n[{x_unit:latex_inline}]")
|
|
499
|
+
case "energy":
|
|
500
|
+
ax[1].set_xlabel(f"Energy \n[{x_unit:latex_inline}]")
|
|
501
|
+
case "frequency":
|
|
502
|
+
ax[1].set_xlabel(f"Frequency \n[{x_unit:latex_inline}]")
|
|
503
|
+
case _:
|
|
504
|
+
RuntimeError(
|
|
505
|
+
f"Unknown physical type for x_units: {x_unit}. " f"Must be 'length', 'energy' or 'frequency'"
|
|
506
|
+
)
|
|
507
|
+
|
|
508
|
+
ax[1].axhline(0, color=color, ls="--")
|
|
509
|
+
ax[1].axhline(-3, color=color, ls=":")
|
|
510
|
+
ax[1].axhline(3, color=color, ls=":")
|
|
511
|
+
|
|
512
|
+
# ax[1].set_xticks(xticks, labels=xticks)
|
|
513
|
+
# ax[1].xaxis.set_minor_formatter(ticker.LogFormatter(minor_thresholds=(np.inf, np.inf)))
|
|
514
|
+
ax[1].set_yticks([-3, 0, 3], labels=[-3, 0, 3])
|
|
515
|
+
ax[1].set_yticks(range(-3, 4), minor=True)
|
|
516
|
+
|
|
517
|
+
ax[0].set_xlim(xbins.value.min(), xbins.value.max())
|
|
518
|
+
|
|
519
|
+
ax[0].legend(legend_plots, legend_labels)
|
|
520
|
+
fig.suptitle(self.model.to_string())
|
|
521
|
+
fig.align_ylabels()
|
|
522
|
+
plt.subplots_adjust(hspace=0.0)
|
|
523
|
+
fig.tight_layout()
|
|
393
524
|
|
|
394
525
|
return fig
|
|
395
526
|
|
|
@@ -399,29 +530,27 @@ class ChainResult:
|
|
|
399
530
|
"""
|
|
400
531
|
|
|
401
532
|
consumer = ChainConsumer()
|
|
402
|
-
consumer.add_chain(self.
|
|
533
|
+
consumer.add_chain(self.to_chain(self.model.to_string()))
|
|
403
534
|
|
|
404
535
|
return consumer.analysis.get_latex_table(caption="Results of the fit", label="tab:results")
|
|
405
536
|
|
|
406
537
|
def plot_corner(
|
|
407
538
|
self,
|
|
408
539
|
config: PlotConfig = PlotConfig(usetex=False, summarise=False, label_font_size=6),
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
):
|
|
540
|
+
**kwargs: Any,
|
|
541
|
+
) -> plt.Figure:
|
|
412
542
|
"""
|
|
413
543
|
Plot the corner plot of the posterior distribution of the parameters. This method uses the ChainConsumer.
|
|
414
544
|
|
|
415
545
|
Parameters:
|
|
416
546
|
config: The configuration of the plot.
|
|
417
|
-
style: The matplotlib style of the plot.
|
|
418
547
|
**kwargs: Additional arguments passed to ChainConsumer.plotter.plot.
|
|
419
548
|
"""
|
|
420
549
|
|
|
421
550
|
consumer = ChainConsumer()
|
|
422
|
-
consumer.add_chain(self.
|
|
551
|
+
consumer.add_chain(self.to_chain(self.model.to_string()))
|
|
423
552
|
consumer.set_plot_config(config)
|
|
424
553
|
|
|
425
554
|
# Context for default mpl style
|
|
426
|
-
with plt.style.context(
|
|
555
|
+
with plt.style.context("default"):
|
|
427
556
|
return consumer.plotter.plot(**kwargs)
|
|
@@ -75,6 +75,24 @@ class DataPHA:
|
|
|
75
75
|
data = QTable.read(pha_file, "SPECTRUM")
|
|
76
76
|
header = fits.getheader(pha_file, "SPECTRUM")
|
|
77
77
|
|
|
78
|
+
if header.get("HDUCLAS2") == "NET":
|
|
79
|
+
raise ValueError(
|
|
80
|
+
f"The HDUCLAS2={header.get('HDUCLAS2')} keyword in the PHA file is not supported."
|
|
81
|
+
f"Please open an issue if this is required."
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
if header.get("HDUCLAS3") == "RATE":
|
|
85
|
+
raise ValueError(
|
|
86
|
+
f"The HDUCLAS3={header.get('HDUCLAS3')} keyword in the PHA file is not supported."
|
|
87
|
+
f"Please open an issue if this is required."
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
if header.get("HDUCLAS4") == "TYPE:II":
|
|
91
|
+
raise ValueError(
|
|
92
|
+
f"The HDUCLAS4={header.get('HDUCLAS4')} keyword in the PHA file is not supported."
|
|
93
|
+
f"Please open an issue if this is required."
|
|
94
|
+
)
|
|
95
|
+
|
|
78
96
|
if header.get("GROUPING") == 0:
|
|
79
97
|
grouping = None
|
|
80
98
|
elif "GROUPING" in data.colnames:
|