jaxspec 0.0.4__tar.gz → 0.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {jaxspec-0.0.4 → jaxspec-0.0.6}/PKG-INFO +5 -3
  2. {jaxspec-0.0.4 → jaxspec-0.0.6}/README.md +1 -1
  3. {jaxspec-0.0.4 → jaxspec-0.0.6}/pyproject.toml +4 -2
  4. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/__init__.py +1 -1
  5. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/analysis/compare.py +3 -3
  6. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/analysis/results.py +239 -110
  7. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/instrument.py +0 -2
  8. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/ogip.py +18 -0
  9. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/util.py +11 -3
  10. jaxspec-0.0.6/src/jaxspec/fit.py +347 -0
  11. jaxspec-0.0.6/src/jaxspec/model/_additive/apec.py +377 -0
  12. jaxspec-0.0.6/src/jaxspec/model/_additive/apec_loaders.py +90 -0
  13. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/abc.py +55 -7
  14. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/additive.py +2 -51
  15. jaxspec-0.0.6/src/jaxspec/tables/abundances.dat +31 -0
  16. jaxspec-0.0.6/src/jaxspec/util/__init__.py +0 -0
  17. jaxspec-0.0.6/src/jaxspec/util/abundance.py +111 -0
  18. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/util/integrate.py +5 -4
  19. jaxspec-0.0.4/src/jaxspec/fit.py +0 -253
  20. {jaxspec-0.0.4 → jaxspec-0.0.6}/LICENSE.md +0 -0
  21. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/analysis/__init__.py +0 -0
  22. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/__init__.py +0 -0
  23. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS1.arf +0 -0
  24. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS1.pha +0 -0
  25. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS1.rmf +0 -0
  26. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS1_spectrum_grp.fits +0 -0
  27. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS1background_spectrum.fits +0 -0
  28. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS2.arf +0 -0
  29. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS2.pha +0 -0
  30. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS2.rmf +0 -0
  31. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS2_spectrum_grp.fits +0 -0
  32. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/MOS2background_spectrum.fits +0 -0
  33. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/PN.arf +0 -0
  34. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/PN.pha +0 -0
  35. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/PN.rmf +0 -0
  36. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/PN_spectrum_grp20.fits +0 -0
  37. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/PNbackground_spectrum.fits +0 -0
  38. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/example_data/fakeit.pha +0 -0
  39. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/grouping.py +0 -0
  40. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/obsconf.py +0 -0
  41. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/data/observation.py +0 -0
  42. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/__init__.py +0 -0
  43. {jaxspec-0.0.4/src/jaxspec/util → jaxspec-0.0.6/src/jaxspec/model/_additive}/__init__.py +0 -0
  44. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/background.py +0 -0
  45. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/list.py +0 -0
  46. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/model/multiplicative.py +0 -0
  47. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/tables/xsect_phabs_aspl.fits +0 -0
  48. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/tables/xsect_tbabs_wilm.fits +0 -0
  49. {jaxspec-0.0.4 → jaxspec-0.0.6}/src/jaxspec/tables/xsect_wabs_angr.fits +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxspec
3
- Version: 0.0.4
3
+ Version: 0.0.6
4
4
  Summary: jaxspec is a bayesian spectral fitting library for X-ray astronomy.
5
5
  License: MIT
6
6
  Author: sdupourque
@@ -10,7 +10,7 @@ Classifier: License :: OSI Approved :: MIT License
10
10
  Classifier: Programming Language :: Python :: 3
11
11
  Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
13
- Requires-Dist: arviz (>=0.17.1,<0.18.0)
13
+ Requires-Dist: arviz (>=0.17.1,<0.19.0)
14
14
  Requires-Dist: astropy (>=6.0.0,<7.0.0)
15
15
  Requires-Dist: chainconsumer (>=1.0.0,<2.0.0)
16
16
  Requires-Dist: cmasher (>=1.6.3,<2.0.0)
@@ -20,11 +20,13 @@ Requires-Dist: jax (>=0.4.23,<0.5.0)
20
20
  Requires-Dist: jaxlib (>=0.4.23,<0.5.0)
21
21
  Requires-Dist: jaxopt (>=0.8.1,<0.9.0)
22
22
  Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
23
+ Requires-Dist: mendeleev (>=0.15.0,<0.16.0)
23
24
  Requires-Dist: mkdocstrings (>=0.24.0,<0.25.0)
24
25
  Requires-Dist: networkx (>=3.1,<4.0)
25
26
  Requires-Dist: numpy (>=1.26.1,<2.0.0)
26
27
  Requires-Dist: numpyro (>=0.13.2,<0.15.0)
27
28
  Requires-Dist: pandas (>=2.2.0,<3.0.0)
29
+ Requires-Dist: pyzmq (<26)
28
30
  Requires-Dist: scipy (<1.13)
29
31
  Requires-Dist: seaborn (>=0.13.1,<0.14.0)
30
32
  Requires-Dist: simpleeval (>=0.9.13,<0.10.0)
@@ -50,7 +52,7 @@ Documentation : https://jaxspec.readthedocs.io/en/latest/
50
52
 
51
53
  ## Installation
52
54
 
53
- We recommend the users to start from a fresh Python 3.10 [conda environment](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).
55
+ We recommend the users to start from a fresh Python 3.10 [conda environment](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).
54
56
 
55
57
  ```
56
58
  conda create -n jaxspec python=3.10
@@ -16,7 +16,7 @@ Documentation : https://jaxspec.readthedocs.io/en/latest/
16
16
 
17
17
  ## Installation
18
18
 
19
- We recommend the users to start from a fresh Python 3.10 [conda environment](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).
19
+ We recommend the users to start from a fresh Python 3.10 [conda environment](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).
20
20
 
21
21
  ```
22
22
  conda create -n jaxspec python=3.10
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "jaxspec"
3
- version = "0.0.4"
3
+ version = "0.0.6"
4
4
  description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
5
5
  authors = ["sdupourque <sdupourque@irap.omp.eu>"]
6
6
  license = "MIT"
@@ -17,7 +17,7 @@ numpyro = ">=0.13.2,<0.15.0"
17
17
  dm-haiku = ">=0.0.11,<0.0.13"
18
18
  networkx = "^3.1"
19
19
  matplotlib = "^3.8.0"
20
- arviz = "^0.17.1"
20
+ arviz = ">=0.17.1,<0.19.0"
21
21
  chainconsumer = "^1.0.0"
22
22
  simpleeval = "^0.9.13"
23
23
  cmasher = "^1.6.3"
@@ -28,6 +28,8 @@ seaborn = "^0.13.1"
28
28
  mkdocstrings = "^0.24.0"
29
29
  sparse = "^0.15.1"
30
30
  scipy = "<1.13"
31
+ mendeleev = "^0.15.0"
32
+ pyzmq = "<26"
31
33
 
32
34
 
33
35
  [tool.poetry.group.docs.dependencies]
@@ -1,5 +1,5 @@
1
1
  """
2
- This is the root of jaxspec's module. TODO: clear this up.
2
+ This is the root of jaxspec's module.
3
3
  """
4
4
 
5
5
  import importlib.metadata
@@ -1,9 +1,9 @@
1
1
  from typing import Dict
2
- from .results import ChainResult
2
+ from .results import FitResult
3
3
  from chainconsumer import ChainConsumer
4
4
 
5
5
 
6
- def plot_corner_comparison(obs_dict: Dict[str, ChainResult], **kwargs):
6
+ def plot_corner_comparison(obs_dict: Dict[str, FitResult], **kwargs):
7
7
  """
8
8
  Plot the correlation plot of parameters from different fitted observations. Observations are passed in as a
9
9
  dictionary. Each observation is named according to its key. It shall be used to compare the same model independently
@@ -16,6 +16,6 @@ def plot_corner_comparison(obs_dict: Dict[str, ChainResult], **kwargs):
16
16
  c = ChainConsumer()
17
17
 
18
18
  for name, obs in obs_dict.items():
19
- c.add_chain(obs.chain(name))
19
+ c.add_chain(obs.to_chain(name))
20
20
 
21
21
  return c.plotter.plot(**kwargs)
@@ -1,11 +1,12 @@
1
1
  import arviz as az
2
2
  import numpy as np
3
+ import xarray as xr
3
4
  import matplotlib.pyplot as plt
4
5
  from ..data import ObsConfiguration
5
6
  from ..model.abc import SpectralModel
6
7
  from ..model.background import BackgroundModel
7
8
  from collections.abc import Mapping
8
- from typing import TypeVar, Tuple, Literal
9
+ from typing import TypeVar, Tuple, Literal, Any
9
10
  from astropy.cosmology import Cosmology, Planck18
10
11
  import astropy.units as u
11
12
  from astropy.units import Unit
@@ -17,6 +18,10 @@ from scipy.integrate import trapezoid
17
18
 
18
19
  K = TypeVar("K")
19
20
  V = TypeVar("V")
21
+ T = TypeVar("T")
22
+
23
+
24
+ class HaikuDict(dict[str, dict[str, T]]): ...
20
25
 
21
26
 
22
27
  def _plot_binned_samples_with_error(
@@ -69,7 +74,7 @@ def _plot_binned_samples_with_error(
69
74
 
70
75
  percentiles = np.percentile(y_samples, percentile, axis=0)
71
76
 
72
- # The legend cannot handle fill_between, so we pass a fill to get a fancy icone
77
+ # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
73
78
  (envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
74
79
 
75
80
  ax.fill_between(
@@ -111,24 +116,24 @@ def format_parameters(parameter_name):
111
116
  return rf"${parameter}$" + module
112
117
 
113
118
 
114
- class ChainResult:
115
- # TODO : Add docstring
119
+ class FitResult:
120
+ """
121
+ This class is the container for the result of a fit using any ModelFitter class.
122
+ """
123
+
116
124
  # TODO : Add type hints
117
- # TODO : Add proper separation between params and samples, cf from haiku and numpyro
118
125
  def __init__(
119
126
  self,
120
127
  model: SpectralModel,
121
- folding_model: ObsConfiguration,
128
+ obsconf: ObsConfiguration | dict[str, ObsConfiguration],
122
129
  inference_data: az.InferenceData,
123
- samples,
124
130
  structure: Mapping[K, V],
125
131
  background_model: BackgroundModel = None,
126
132
  ):
127
133
  self.model = model
128
134
  self._structure = structure
129
135
  self.inference_data = inference_data
130
- self.folding_model = folding_model
131
- self.samples = samples
136
+ self.obsconfs = {"Observation": obsconf} if isinstance(obsconf, ObsConfiguration) else obsconf
132
137
  self.background_model = background_model
133
138
  self._structure = structure
134
139
 
@@ -139,6 +144,14 @@ class ChainResult:
139
144
  metadata["model"] = str(model)
140
145
  # TODO : Store metadata about observations used in the fitting process
141
146
 
147
+ @property
148
+ def converged(self) -> bool:
149
+ """
150
+ Convergence of the chain as computed by the $\hat{R}$ statistic.
151
+ """
152
+
153
+ return all(az.rhat(self.inference_data) < 1.01)
154
+
142
155
  def photon_flux(
143
156
  self,
144
157
  e_min: float,
@@ -164,7 +177,7 @@ class ChainResult:
164
177
  conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
165
178
 
166
179
  value = flux * conversion_factor
167
-
180
+ # TODO : fix this since sample doesn't exist anymore
168
181
  self.samples[rf"Photon flux ({e_min:.1f}-{e_max:.1f} keV)"] = value
169
182
 
170
183
  return value
@@ -195,6 +208,7 @@ class ChainResult:
195
208
 
196
209
  value = flux * conversion_factor
197
210
 
211
+ # TODO : fix this since sample doesn't exist anymore
198
212
  self.samples[rf"Energy flux ({e_min:.1f}-{e_max:.1f} keV)"] = value
199
213
 
200
214
  return value
@@ -207,7 +221,7 @@ class ChainResult:
207
221
  observer_frame: bool = True,
208
222
  cosmology: Cosmology = Planck18,
209
223
  unit: Unit = u.erg / u.s,
210
- ):
224
+ ) -> ArrayLike:
211
225
  """
212
226
  Compute the luminosity of the source specifying its redshift. The luminosity is then added to
213
227
  the result parameters so covariance can be plotted.
@@ -228,11 +242,12 @@ class ChainResult:
228
242
 
229
243
  value = (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
230
244
 
245
+ # TODO : fix this since sample doesn't exist anymore
231
246
  self.samples[rf"Luminosity ({e_min:.1f}-{e_max:.1f} keV)"] = value
232
247
 
233
248
  return value
234
249
 
235
- def chain(self, name: str, parameters: Literal["model", "bkg"] = "model") -> Chain:
250
+ def to_chain(self, name: str, parameters: Literal["model", "bkg"] = "model") -> Chain:
236
251
  """
237
252
  Return a ChainConsumer Chain object from the posterior distribution of the parameters.
238
253
 
@@ -257,9 +272,31 @@ class ChainResult:
257
272
  return chain
258
273
 
259
274
  @property
260
- def params(self):
275
+ def samples_haiku(self) -> HaikuDict[ArrayLike]:
261
276
  """
262
- Haiku-like structure for the parameters
277
+ Haiku-like structure for the samples e.g.
278
+
279
+ ```
280
+ {
281
+ 'powerlaw_1' :
282
+ {
283
+ 'alpha': ...,
284
+ 'amplitude': ...
285
+ },
286
+
287
+ 'blackbody_1':
288
+ {
289
+ 'kT': ...,
290
+ 'norm': ...
291
+ },
292
+
293
+ 'tbabs_1':
294
+ {
295
+ 'nH': ...
296
+ }
297
+ }
298
+ ```
299
+
263
300
  """
264
301
 
265
302
  params = {}
@@ -271,7 +308,40 @@ class ChainResult:
271
308
 
272
309
  return params
273
310
 
274
- def plot_ppc(self, percentile: Tuple[int, int] = (14, 86)) -> plt.Figure:
311
+ @property
312
+ def samples_flat(self) -> dict[str, ArrayLike]:
313
+ """
314
+ Flat structure for the samples e.g.
315
+
316
+ ```
317
+ {
318
+ 'powerlaw_1_alpha': ...,
319
+ 'powerlaw_1_amplitude': ...,
320
+ 'blackbody_1_kT': ...,
321
+ 'blackbody_1_norm': ...,
322
+ 'tbabs_1_nH': ...,
323
+ }
324
+ ```
325
+ """
326
+ var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
327
+ posterior = az.extract(self.inference_data, var_names=var_names)
328
+ return {key: posterior[key].data for key in var_names}
329
+
330
+ @property
331
+ def likelihood(self) -> xr.Dataset:
332
+ """
333
+ Return the likelihood of each observation
334
+ """
335
+ log_likelihood = az.extract(self.inference_data, group="log_likelihood")
336
+ dimensions_to_reduce = [coord for coord in log_likelihood.coords if coord not in ["sample", "draw", "chain"]]
337
+ return log_likelihood.sum(dimensions_to_reduce)
338
+
339
+ def plot_ppc(
340
+ self,
341
+ percentile: Tuple[int, int] = (14, 86),
342
+ x_unit: str | u.Unit = "keV",
343
+ y_type: Literal["counts", "countrate", "photon_flux", "photon_flux_density"] = "photon_flux_density",
344
+ ) -> plt.Figure:
275
345
  r"""
276
346
  Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
277
347
  following formula:
@@ -281,115 +351,176 @@ class ChainResult:
281
351
 
282
352
  Parameters:
283
353
  percentile: The percentile of the posterior predictive distribution to plot.
354
+ x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
355
+ y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
284
356
 
285
357
  Returns:
286
- The matplotlib two panel figure.
358
+ The matplotlib figure.
287
359
  """
288
360
 
289
- folding_model = self.folding_model
290
- count = az.extract(self.inference_data, var_names="obs", group="posterior_predictive").values.T
291
- bkg_count = (
292
- None
293
- if self.background_model is None
294
- else az.extract(self.inference_data, var_names="bkg", group="posterior_predictive").values.T
295
- )
296
-
297
- legend_plots = []
298
- legend_labels = []
361
+ obsconf_container = self.obsconfs
362
+ x_unit = u.Unit(x_unit)
363
+
364
+ match y_type:
365
+ case "counts":
366
+ y_units = u.photon
367
+ case "countrate":
368
+ y_units = u.photon / u.s
369
+ case "photon_flux":
370
+ y_units = u.photon / u.cm**2 / u.s
371
+ case "photon_flux_density":
372
+ y_units = u.photon / u.cm**2 / u.s / x_unit
373
+ case _:
374
+ raise ValueError(
375
+ f"Unknown y_type: {y_type}. Must be 'counts', 'countrate', 'photon_flux' or 'photon_flux_density'"
376
+ )
299
377
 
300
378
  color = (0.15, 0.25, 0.45)
301
379
 
302
380
  with plt.style.context("default"):
303
- # Note to Simon : do not change folding_model.out_energies[1] - folding_model.out_energies[0] to
381
+ # Note to Simon : do not change xbins[1] - xbins[0] to
304
382
  # np.diff, you already did this twice and forgot that it does not work since diff keeps the dimensions
305
383
  # and enable weird broadcasting that makes the plot fail
306
384
 
307
- fig, axs = plt.subplots(2, 1, figsize=(6, 6), sharex=True, height_ratios=[0.7, 0.3])
308
-
309
- mid_bins_arf = folding_model.in_energies.mean(axis=0)
310
-
311
- e_grid = np.linspace(*folding_model.out_energies, 10)
312
- interpolated_arf = np.interp(e_grid, mid_bins_arf, folding_model.area)
313
- integrated_arf = trapezoid(interpolated_arf, x=e_grid, axis=0) / (
314
- folding_model.out_energies[1] - folding_model.out_energies[0]
315
- )
316
-
317
- if folding_model.out_energies[0][0] < 1 < folding_model.out_energies[1][-1]:
318
- xticks = [np.floor(folding_model.out_energies[0][0] * 10) / 10, 1.0, np.floor(folding_model.out_energies[1][-1])]
319
- else:
320
- xticks = [np.floor(folding_model.out_energies[0][0] * 10) / 10, np.floor(folding_model.out_energies[1][-1])]
321
-
322
- denominator = (
323
- (folding_model.out_energies[1] - folding_model.out_energies[0]) * folding_model.exposure.data * integrated_arf
385
+ fig, axs = plt.subplots(
386
+ 2, len(obsconf_container), figsize=(6 * len(obsconf_container), 6), sharex=True, height_ratios=[0.7, 0.3]
324
387
  )
325
388
 
326
- # Use the helper function to plot the data and posterior predictive
327
- legend_plots += _plot_binned_samples_with_error(
328
- axs[0],
329
- folding_model.out_energies,
330
- y_samples=count,
331
- y_observed=folding_model.folded_counts.data,
332
- denominator=denominator,
333
- color=color,
334
- percentile=percentile,
335
- )
389
+ plot_ylabels_once = True
390
+
391
+ for name, obsconf, ax in zip(
392
+ obsconf_container.keys(), obsconf_container.values(), axs.T if len(obsconf_container) > 1 else [axs]
393
+ ):
394
+ legend_plots = []
395
+ legend_labels = []
396
+ count = az.extract(self.inference_data, var_names=f"obs_{name}", group="posterior_predictive").values.T
397
+ bkg_count = (
398
+ None
399
+ if self.background_model is None
400
+ else az.extract(self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive").values.T
401
+ )
336
402
 
337
- legend_labels.append("Source + Background")
403
+ xbins = obsconf.out_energies * u.keV
404
+ xbins = xbins.to(x_unit, u.spectral())
405
+
406
+ # This compute the total effective area within all bins
407
+ # This is a bit weird since the following computation is equivalent to ignoring the RMF
408
+ exposure = obsconf.exposure.data * u.s
409
+ mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
410
+ mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
411
+ e_grid = np.linspace(*xbins, 10)
412
+ interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
413
+ integrated_arf = (
414
+ trapezoid(interpolated_arf, x=e_grid, axis=0)
415
+ / (
416
+ np.abs(xbins[1] - xbins[0]) # Must fold in abs because some units reverse the ordering of the bins
417
+ )
418
+ * u.cm**2
419
+ )
338
420
 
339
- if self.background_model is not None:
340
- # We plot the background only if it is included in the fit, i.e. by subtracting
421
+ """
422
+ if xbins[0][0] < 1 < xbins[1][-1]:
423
+ xticks = [np.floor(xbins[0][0] * 10) / 10, 1.0, np.floor(xbins[1][-1])]
424
+ else:
425
+ xticks = [np.floor(xbins[0][0] * 10) / 10, np.floor(xbins[1][-1])]
426
+ """
427
+
428
+ match y_type:
429
+ case "counts":
430
+ denominator = 1
431
+ case "countrate":
432
+ denominator = exposure
433
+ case "photon_flux":
434
+ denominator = integrated_arf * exposure
435
+ case "photon_flux_density":
436
+ denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure
437
+
438
+ y_samples = (count * u.photon / denominator).to(y_units)
439
+ y_observed = (obsconf.folded_counts.data * u.photon / denominator).to(y_units)
440
+
441
+ # Use the helper function to plot the data and posterior predictive
341
442
  legend_plots += _plot_binned_samples_with_error(
342
- axs[0],
343
- folding_model.out_energies,
344
- y_observed=folding_model.folded_background.data,
345
- y_samples=bkg_count,
346
- denominator=denominator * folding_model.folded_backratio.data,
347
- color=(0.26787604, 0.60085972, 0.63302651),
443
+ ax[0],
444
+ xbins.value,
445
+ y_samples=y_samples.value,
446
+ y_observed=y_observed.value,
447
+ denominator=np.ones_like(y_observed).value,
448
+ color=color,
348
449
  percentile=percentile,
349
450
  )
350
451
 
351
- legend_labels.append("Background")
352
-
353
- residuals = np.percentile(
354
- (folding_model.folded_counts.data - count) / np.diff(np.percentile(count, percentile, axis=0), axis=0),
355
- percentile,
356
- axis=0,
357
- )
358
-
359
- axs[1].fill_between(
360
- list(folding_model.out_energies[0]) + [folding_model.out_energies[1][-1]],
361
- list(residuals[0]) + [residuals[0][-1]],
362
- list(residuals[1]) + [residuals[1][-1]],
363
- alpha=0.3,
364
- step="post",
365
- facecolor=color,
366
- )
367
-
368
- max_residuals = np.max(np.abs(residuals))
369
-
370
- axs[0].loglog()
371
- axs[0].set_ylabel("Folded spectrum\n" + r"[Counts s$^{-1}$ keV$^{-1}$ cm$^{-2}$]")
372
-
373
- axs[1].set_ylim(-max(3.5, max_residuals), +max(3.5, max_residuals))
374
- axs[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
375
- axs[1].set_xlabel("Energy \n[keV]")
376
-
377
- axs[1].axhline(0, color=color, ls="--")
378
- axs[1].axhline(-3, color=color, ls=":")
379
- axs[1].axhline(3, color=color, ls=":")
380
-
381
- axs[1].set_xticks(xticks, labels=xticks)
382
- axs[1].set_yticks([-3, 0, 3], labels=[-3, 0, 3])
383
- axs[1].set_yticks(range(-3, 4), minor=True)
384
-
385
- axs[0].set_xlim(folding_model.out_energies.min(), folding_model.out_energies.max())
452
+ legend_labels.append("Source + Background")
453
+
454
+ if self.background_model is not None:
455
+ # We plot the background only if it is included in the fit, i.e. by subtracting
456
+ ratio = obsconf.folded_backratio.data
457
+ y_samples_bkg = (bkg_count * u.photon / (denominator * ratio)).to(y_units)
458
+ y_observed_bkg = (obsconf.folded_background.data * u.photon / (denominator * ratio)).to(y_units)
459
+ legend_plots += _plot_binned_samples_with_error(
460
+ ax[0],
461
+ xbins.value,
462
+ y_samples=y_samples_bkg.value,
463
+ y_observed=y_observed_bkg.value,
464
+ denominator=np.ones_like(y_observed).value,
465
+ color=(0.26787604, 0.60085972, 0.63302651),
466
+ percentile=percentile,
467
+ )
468
+
469
+ legend_labels.append("Background")
470
+
471
+ residuals = np.percentile(
472
+ (obsconf.folded_counts.data - count) / np.diff(np.percentile(count, percentile, axis=0), axis=0),
473
+ percentile,
474
+ axis=0,
475
+ )
386
476
 
387
- axs[0].legend(legend_plots, legend_labels)
388
- fig.suptitle(self.model.to_string())
477
+ ax[1].fill_between(
478
+ list(xbins.value[0]) + [xbins.value[1][-1]],
479
+ list(residuals[0]) + [residuals[0][-1]],
480
+ list(residuals[1]) + [residuals[1][-1]],
481
+ alpha=0.3,
482
+ step="post",
483
+ facecolor=color,
484
+ )
389
485
 
390
- fig.align_ylabels()
391
- plt.subplots_adjust(hspace=0.0)
392
- fig.tight_layout()
486
+ max_residuals = np.max(np.abs(residuals))
487
+
488
+ ax[0].loglog()
489
+ ax[1].set_ylim(-max(3.5, max_residuals), +max(3.5, max_residuals))
490
+
491
+ if plot_ylabels_once:
492
+ ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
493
+ ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
494
+ plot_ylabels_once = False
495
+
496
+ match getattr(x_unit, "physical_type"):
497
+ case "length":
498
+ ax[1].set_xlabel(f"Wavelength \n[{x_unit:latex_inline}]")
499
+ case "energy":
500
+ ax[1].set_xlabel(f"Energy \n[{x_unit:latex_inline}]")
501
+ case "frequency":
502
+ ax[1].set_xlabel(f"Frequency \n[{x_unit:latex_inline}]")
503
+ case _:
504
+ RuntimeError(
505
+ f"Unknown physical type for x_units: {x_unit}. " f"Must be 'length', 'energy' or 'frequency'"
506
+ )
507
+
508
+ ax[1].axhline(0, color=color, ls="--")
509
+ ax[1].axhline(-3, color=color, ls=":")
510
+ ax[1].axhline(3, color=color, ls=":")
511
+
512
+ # ax[1].set_xticks(xticks, labels=xticks)
513
+ # ax[1].xaxis.set_minor_formatter(ticker.LogFormatter(minor_thresholds=(np.inf, np.inf)))
514
+ ax[1].set_yticks([-3, 0, 3], labels=[-3, 0, 3])
515
+ ax[1].set_yticks(range(-3, 4), minor=True)
516
+
517
+ ax[0].set_xlim(xbins.value.min(), xbins.value.max())
518
+
519
+ ax[0].legend(legend_plots, legend_labels)
520
+ fig.suptitle(self.model.to_string())
521
+ fig.align_ylabels()
522
+ plt.subplots_adjust(hspace=0.0)
523
+ fig.tight_layout()
393
524
 
394
525
  return fig
395
526
 
@@ -399,29 +530,27 @@ class ChainResult:
399
530
  """
400
531
 
401
532
  consumer = ChainConsumer()
402
- consumer.add_chain(self.chain(self.model.to_string()))
533
+ consumer.add_chain(self.to_chain(self.model.to_string()))
403
534
 
404
535
  return consumer.analysis.get_latex_table(caption="Results of the fit", label="tab:results")
405
536
 
406
537
  def plot_corner(
407
538
  self,
408
539
  config: PlotConfig = PlotConfig(usetex=False, summarise=False, label_font_size=6),
409
- style="default",
410
- **kwargs,
411
- ):
540
+ **kwargs: Any,
541
+ ) -> plt.Figure:
412
542
  """
413
543
  Plot the corner plot of the posterior distribution of the parameters. This method uses the ChainConsumer.
414
544
 
415
545
  Parameters:
416
546
  config: The configuration of the plot.
417
- style: The matplotlib style of the plot.
418
547
  **kwargs: Additional arguments passed to ChainConsumer.plotter.plot.
419
548
  """
420
549
 
421
550
  consumer = ChainConsumer()
422
- consumer.add_chain(self.chain(self.model.to_string()))
551
+ consumer.add_chain(self.to_chain(self.model.to_string()))
423
552
  consumer.set_plot_config(config)
424
553
 
425
554
  # Context for default mpl style
426
- with plt.style.context(style):
555
+ with plt.style.context("default"):
427
556
  return consumer.plotter.plot(**kwargs)
@@ -72,8 +72,6 @@ class Instrument(xr.Dataset):
72
72
  Parameters:
73
73
  rmf_path: The RMF file path.
74
74
  arf_path: The ARF file path.
75
- exposure: The exposure time in second.
76
- grouping: The grouping matrix.
77
75
  """
78
76
 
79
77
  rmf = DataRMF.from_file(rmf_path)
@@ -75,6 +75,24 @@ class DataPHA:
75
75
  data = QTable.read(pha_file, "SPECTRUM")
76
76
  header = fits.getheader(pha_file, "SPECTRUM")
77
77
 
78
+ if header.get("HDUCLAS2") == "NET":
79
+ raise ValueError(
80
+ f"The HDUCLAS2={header.get('HDUCLAS2')} keyword in the PHA file is not supported."
81
+ f"Please open an issue if this is required."
82
+ )
83
+
84
+ if header.get("HDUCLAS3") == "RATE":
85
+ raise ValueError(
86
+ f"The HDUCLAS3={header.get('HDUCLAS3')} keyword in the PHA file is not supported."
87
+ f"Please open an issue if this is required."
88
+ )
89
+
90
+ if header.get("HDUCLAS4") == "TYPE:II":
91
+ raise ValueError(
92
+ f"The HDUCLAS4={header.get('HDUCLAS4')} keyword in the PHA file is not supported."
93
+ f"Please open an issue if this is required."
94
+ )
95
+
78
96
  if header.get("GROUPING") == 0:
79
97
  grouping = None
80
98
  elif "GROUPING" in data.colnames: