jaxspec 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/results.py +250 -121
- jaxspec/data/__init__.py +4 -4
- jaxspec/data/obsconf.py +53 -8
- jaxspec/data/util.py +29 -20
- jaxspec/fit.py +329 -81
- jaxspec/model/__init__.py +0 -1
- jaxspec/model/_additive/apec.py +56 -117
- jaxspec/model/_additive/apec_loaders.py +42 -59
- jaxspec/model/additive.py +27 -13
- jaxspec/model/background.py +50 -16
- jaxspec/model/multiplicative.py +20 -25
- jaxspec/util/__init__.py +45 -0
- jaxspec/util/abundance.py +5 -3
- jaxspec/util/online_storage.py +15 -0
- jaxspec/util/typing.py +43 -0
- {jaxspec-0.0.5.dist-info → jaxspec-0.0.7.dist-info}/METADATA +12 -9
- {jaxspec-0.0.5.dist-info → jaxspec-0.0.7.dist-info}/RECORD +19 -22
- jaxspec/tables/abundances.dat +0 -31
- jaxspec/tables/new_apec.nc +0 -0
- jaxspec/tables/xsect_phabs_aspl.fits +0 -0
- jaxspec/tables/xsect_tbabs_wilm.fits +0 -0
- jaxspec/tables/xsect_wabs_angr.fits +0 -0
- {jaxspec-0.0.5.dist-info → jaxspec-0.0.7.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.0.5.dist-info → jaxspec-0.0.7.dist-info}/WHEEL +0 -0
jaxspec/analysis/results.py
CHANGED
|
@@ -1,20 +1,26 @@
|
|
|
1
|
+
from collections.abc import Mapping
|
|
2
|
+
from typing import Any, Literal, TypeVar
|
|
3
|
+
|
|
1
4
|
import arviz as az
|
|
5
|
+
import astropy.units as u
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
2
9
|
import numpy as np
|
|
3
10
|
import xarray as xr
|
|
4
|
-
|
|
5
|
-
from ..data import ObsConfiguration
|
|
6
|
-
from ..model.abc import SpectralModel
|
|
7
|
-
from ..model.background import BackgroundModel
|
|
8
|
-
from collections.abc import Mapping
|
|
9
|
-
from typing import TypeVar, Tuple, Literal, Any
|
|
11
|
+
|
|
10
12
|
from astropy.cosmology import Cosmology, Planck18
|
|
11
|
-
import astropy.units as u
|
|
12
13
|
from astropy.units import Unit
|
|
14
|
+
from chainconsumer import Chain, ChainConsumer, PlotConfig
|
|
13
15
|
from haiku.data_structures import traverse
|
|
14
|
-
from chainconsumer import Chain, PlotConfig, ChainConsumer
|
|
15
|
-
import jax
|
|
16
16
|
from jax.typing import ArrayLike
|
|
17
17
|
from scipy.integrate import trapezoid
|
|
18
|
+
from scipy.special import gammaln
|
|
19
|
+
from scipy.stats import nbinom
|
|
20
|
+
|
|
21
|
+
from ..data import ObsConfiguration
|
|
22
|
+
from ..model.abc import SpectralModel
|
|
23
|
+
from ..model.background import BackgroundModel
|
|
18
24
|
|
|
19
25
|
K = TypeVar("K")
|
|
20
26
|
V = TypeVar("V")
|
|
@@ -29,7 +35,6 @@ def _plot_binned_samples_with_error(
|
|
|
29
35
|
x_bins: ArrayLike,
|
|
30
36
|
denominator: ArrayLike | None = None,
|
|
31
37
|
y_samples: ArrayLike | None = None,
|
|
32
|
-
y_observed: ArrayLike | None = None,
|
|
33
38
|
color=(0.15, 0.25, 0.45),
|
|
34
39
|
percentile: tuple = (16, 84),
|
|
35
40
|
):
|
|
@@ -38,7 +43,8 @@ def _plot_binned_samples_with_error(
|
|
|
38
43
|
computes the percentiles of the posterior predictive distribution and plot them as a shaded
|
|
39
44
|
area. If the observed data is provided, it is also plotted as a step function.
|
|
40
45
|
|
|
41
|
-
Parameters
|
|
46
|
+
Parameters
|
|
47
|
+
----------
|
|
42
48
|
x_bins: The bin edges of the data (2 x N).
|
|
43
49
|
y_samples: The samples of the posterior predictive distribution (Samples X N).
|
|
44
50
|
denominator: Values used to divide the samples, i.e. to get energy flux (N).
|
|
@@ -51,22 +57,15 @@ def _plot_binned_samples_with_error(
|
|
|
51
57
|
|
|
52
58
|
mean, envelope = None, None
|
|
53
59
|
|
|
54
|
-
if
|
|
55
|
-
|
|
60
|
+
if denominator is None:
|
|
61
|
+
denominator = np.ones_like(x_bins[0])
|
|
56
62
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
(mean,) = ax.step(
|
|
65
|
-
list(x_bins[0]) + [x_bins[1][-1]], # x_bins[1][-1]+1],
|
|
66
|
-
list(y_observed / denominator) + [np.nan], # + [np.nan, np.nan],
|
|
67
|
-
where="pre",
|
|
68
|
-
c=color,
|
|
69
|
-
)
|
|
63
|
+
mean = ax.stairs(
|
|
64
|
+
list(np.median(y_samples, axis=0) / denominator),
|
|
65
|
+
edges=[*list(x_bins[0]), x_bins[1][-1]],
|
|
66
|
+
color=color,
|
|
67
|
+
alpha=0.7,
|
|
68
|
+
)
|
|
70
69
|
|
|
71
70
|
if y_samples is not None:
|
|
72
71
|
if denominator is None:
|
|
@@ -77,48 +76,21 @@ def _plot_binned_samples_with_error(
|
|
|
77
76
|
# The legend cannot handle fill_between, so we pass a fill to get a fancy icon
|
|
78
77
|
(envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
|
|
79
78
|
|
|
80
|
-
ax.
|
|
81
|
-
|
|
82
|
-
list(
|
|
83
|
-
|
|
79
|
+
ax.stairs(
|
|
80
|
+
percentiles[1] / denominator,
|
|
81
|
+
edges=[*list(x_bins[0]), x_bins[1][-1]],
|
|
82
|
+
baseline=percentiles[0] / denominator,
|
|
84
83
|
alpha=0.3,
|
|
85
|
-
|
|
86
|
-
|
|
84
|
+
fill=True,
|
|
85
|
+
color=color,
|
|
87
86
|
)
|
|
88
87
|
|
|
89
88
|
return [(mean, envelope)]
|
|
90
89
|
|
|
91
90
|
|
|
92
|
-
def format_parameters(parameter_name):
|
|
93
|
-
computed_parameters = ["Photon flux", "Energy flux", "Luminosity"]
|
|
94
|
-
|
|
95
|
-
if parameter_name == "weight":
|
|
96
|
-
# ChainConsumer add a weight column to the samples
|
|
97
|
-
return parameter_name
|
|
98
|
-
|
|
99
|
-
for parameter in computed_parameters:
|
|
100
|
-
if parameter in parameter_name:
|
|
101
|
-
return parameter_name
|
|
102
|
-
|
|
103
|
-
# Find second occurrence of the character '_'
|
|
104
|
-
first_occurrence = parameter_name.find("_")
|
|
105
|
-
second_occurrence = parameter_name.find("_", first_occurrence + 1)
|
|
106
|
-
module = parameter_name[:second_occurrence]
|
|
107
|
-
parameter = parameter_name[second_occurrence + 1 :]
|
|
108
|
-
|
|
109
|
-
name, number = module.split("_")
|
|
110
|
-
module = rf"[{name.capitalize()} ({number})]"
|
|
111
|
-
|
|
112
|
-
if parameter == "norm":
|
|
113
|
-
return r"Norm " + module
|
|
114
|
-
|
|
115
|
-
else:
|
|
116
|
-
return rf"${parameter}$" + module
|
|
117
|
-
|
|
118
|
-
|
|
119
91
|
class FitResult:
|
|
120
92
|
"""
|
|
121
|
-
|
|
93
|
+
Container for the result of a fit using any ModelFitter class.
|
|
122
94
|
"""
|
|
123
95
|
|
|
124
96
|
# TODO : Add type hints
|
|
@@ -133,7 +105,9 @@ class FitResult:
|
|
|
133
105
|
self.model = model
|
|
134
106
|
self._structure = structure
|
|
135
107
|
self.inference_data = inference_data
|
|
136
|
-
self.obsconfs =
|
|
108
|
+
self.obsconfs = (
|
|
109
|
+
{"Observation": obsconf} if isinstance(obsconf, ObsConfiguration) else obsconf
|
|
110
|
+
)
|
|
137
111
|
self.background_model = background_model
|
|
138
112
|
self._structure = structure
|
|
139
113
|
|
|
@@ -146,39 +120,70 @@ class FitResult:
|
|
|
146
120
|
|
|
147
121
|
@property
|
|
148
122
|
def converged(self) -> bool:
|
|
149
|
-
"""
|
|
123
|
+
r"""
|
|
150
124
|
Convergence of the chain as computed by the $\hat{R}$ statistic.
|
|
151
125
|
"""
|
|
152
126
|
|
|
153
127
|
return all(az.rhat(self.inference_data) < 1.01)
|
|
154
128
|
|
|
129
|
+
@property
|
|
130
|
+
def _structured_samples(self):
|
|
131
|
+
"""
|
|
132
|
+
Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
|
|
133
|
+
"""
|
|
134
|
+
|
|
135
|
+
var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
|
|
136
|
+
posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
|
|
137
|
+
samples_flat = {key: posterior[key].data for key in var_names}
|
|
138
|
+
|
|
139
|
+
samples_haiku = {}
|
|
140
|
+
|
|
141
|
+
for module, parameter, value in traverse(self._structure):
|
|
142
|
+
if samples_haiku.get(module, None) is None:
|
|
143
|
+
samples_haiku[module] = {}
|
|
144
|
+
samples_haiku[module][parameter] = samples_flat[f"{module}_{parameter}"]
|
|
145
|
+
|
|
146
|
+
return samples_haiku
|
|
147
|
+
|
|
155
148
|
def photon_flux(
|
|
156
149
|
self,
|
|
157
150
|
e_min: float,
|
|
158
151
|
e_max: float,
|
|
159
152
|
unit: Unit = u.photon / u.cm**2 / u.s,
|
|
153
|
+
register: bool = False,
|
|
160
154
|
) -> ArrayLike:
|
|
161
155
|
"""
|
|
162
156
|
Compute the unfolded photon flux in a given energy band. The flux is then added to
|
|
163
157
|
the result parameters so covariance can be plotted.
|
|
164
158
|
|
|
165
|
-
Parameters
|
|
159
|
+
Parameters
|
|
160
|
+
----------
|
|
166
161
|
e_min: The lower bound of the energy band in observer frame.
|
|
167
162
|
e_max: The upper bound of the energy band in observer frame.
|
|
168
163
|
unit: The unit of the photon flux.
|
|
164
|
+
register: Whether to register the flux with the other posterior parameters.
|
|
169
165
|
|
|
170
166
|
!!! warning
|
|
171
167
|
Computation of the folded flux is not implemented yet. Feel free to open an
|
|
172
168
|
[issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
|
|
173
169
|
"""
|
|
174
170
|
|
|
175
|
-
|
|
171
|
+
samples = self._structured_samples
|
|
172
|
+
init_shape = jax.tree.leaves(samples)[0].shape
|
|
176
173
|
|
|
177
|
-
|
|
174
|
+
flux = jax.vmap(
|
|
175
|
+
lambda p: self.model.photon_flux(p, jnp.asarray([e_min]), jnp.asarray([e_max]))
|
|
176
|
+
)(jax.tree.map(lambda x: x.ravel(), samples))
|
|
178
177
|
|
|
178
|
+
flux = jax.tree.map(lambda x: x.reshape(init_shape), flux)
|
|
179
|
+
conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
|
|
179
180
|
value = flux * conversion_factor
|
|
180
|
-
|
|
181
|
-
|
|
181
|
+
|
|
182
|
+
if register:
|
|
183
|
+
self.inference_data.posterior[f"flux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
184
|
+
["chain", "draw"],
|
|
185
|
+
value,
|
|
186
|
+
)
|
|
182
187
|
|
|
183
188
|
return value
|
|
184
189
|
|
|
@@ -187,29 +192,41 @@ class FitResult:
|
|
|
187
192
|
e_min: float,
|
|
188
193
|
e_max: float,
|
|
189
194
|
unit: Unit = u.erg / u.cm**2 / u.s,
|
|
195
|
+
register: bool = False,
|
|
190
196
|
) -> ArrayLike:
|
|
191
197
|
"""
|
|
192
198
|
Compute the unfolded energy flux in a given energy band. The flux is then added to
|
|
193
199
|
the result parameters so covariance can be plotted.
|
|
194
200
|
|
|
195
|
-
Parameters
|
|
201
|
+
Parameters
|
|
202
|
+
----------
|
|
196
203
|
e_min: The lower bound of the energy band in observer frame.
|
|
197
204
|
e_max: The upper bound of the energy band in observer frame.
|
|
198
205
|
unit: The unit of the energy flux.
|
|
206
|
+
register: Whether to register the flux with the other posterior parameters.
|
|
199
207
|
|
|
200
208
|
!!! warning
|
|
201
209
|
Computation of the folded flux is not implemented yet. Feel free to open an
|
|
202
210
|
[issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
|
|
203
211
|
"""
|
|
204
212
|
|
|
205
|
-
|
|
213
|
+
samples = self._structured_samples
|
|
214
|
+
init_shape = jax.tree.leaves(samples)[0].shape
|
|
206
215
|
|
|
207
|
-
|
|
216
|
+
flux = jax.vmap(
|
|
217
|
+
lambda p: self.model.energy_flux(p, jnp.asarray([e_min]), jnp.asarray([e_max]))
|
|
218
|
+
)(jax.tree.map(lambda x: x.ravel(), samples))
|
|
219
|
+
|
|
220
|
+
flux = jax.tree.map(lambda x: x.reshape(init_shape), flux)
|
|
208
221
|
|
|
222
|
+
conversion_factor = (u.keV / u.cm**2 / u.s).to(unit)
|
|
209
223
|
value = flux * conversion_factor
|
|
210
224
|
|
|
211
|
-
|
|
212
|
-
|
|
225
|
+
if register:
|
|
226
|
+
self.inference_data.posterior[f"eflux_{e_min:.1f}_{e_max:.1f}"] = (
|
|
227
|
+
["chain", "draw"],
|
|
228
|
+
value,
|
|
229
|
+
)
|
|
213
230
|
|
|
214
231
|
return value
|
|
215
232
|
|
|
@@ -217,57 +234,84 @@ class FitResult:
|
|
|
217
234
|
self,
|
|
218
235
|
e_min: float,
|
|
219
236
|
e_max: float,
|
|
220
|
-
redshift: float | ArrayLike = 0,
|
|
237
|
+
redshift: float | ArrayLike = 0.1,
|
|
221
238
|
observer_frame: bool = True,
|
|
222
239
|
cosmology: Cosmology = Planck18,
|
|
223
240
|
unit: Unit = u.erg / u.s,
|
|
241
|
+
register: bool = False,
|
|
224
242
|
) -> ArrayLike:
|
|
225
243
|
"""
|
|
226
244
|
Compute the luminosity of the source specifying its redshift. The luminosity is then added to
|
|
227
245
|
the result parameters so covariance can be plotted.
|
|
228
246
|
|
|
229
|
-
Parameters
|
|
247
|
+
Parameters
|
|
248
|
+
----------
|
|
230
249
|
e_min: The lower bound of the energy band.
|
|
231
250
|
e_max: The upper bound of the energy band.
|
|
232
251
|
redshift: The redshift of the source. It can be a distribution of redshifts.
|
|
233
252
|
observer_frame: Whether the input bands are defined in observer frame or not.
|
|
234
253
|
cosmology: Chosen cosmology.
|
|
235
254
|
unit: The unit of the luminosity.
|
|
255
|
+
register: Whether to register the luminosity with the other posterior parameters.
|
|
236
256
|
"""
|
|
237
257
|
|
|
238
258
|
if not observer_frame:
|
|
239
259
|
raise NotImplementedError()
|
|
240
260
|
|
|
241
|
-
|
|
261
|
+
samples = self._structured_samples
|
|
262
|
+
init_shape = jax.tree.leaves(samples)[0].shape
|
|
242
263
|
|
|
264
|
+
flux = jax.vmap(
|
|
265
|
+
lambda p: self.model.energy_flux(
|
|
266
|
+
p, jnp.asarray([e_min]) * (1 + redshift), jnp.asarray([e_max])
|
|
267
|
+
)
|
|
268
|
+
* (1 + redshift)
|
|
269
|
+
)(jax.tree.map(lambda x: x.ravel(), samples))
|
|
270
|
+
|
|
271
|
+
flux = jax.tree.map(
|
|
272
|
+
lambda x: np.asarray(x.reshape(init_shape)) * (u.keV / u.cm**2 / u.s), flux
|
|
273
|
+
)
|
|
243
274
|
value = (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
|
|
244
275
|
|
|
245
|
-
|
|
246
|
-
|
|
276
|
+
if register:
|
|
277
|
+
self.inference_data.posterior[f"luminosity_{e_min:.1f}_{e_max:.1f}"] = (
|
|
278
|
+
["chain", "draw"],
|
|
279
|
+
value,
|
|
280
|
+
)
|
|
247
281
|
|
|
248
282
|
return value
|
|
249
283
|
|
|
250
|
-
def to_chain(self, name: str,
|
|
284
|
+
def to_chain(self, name: str, parameters_type: Literal["model", "bkg"] = "model") -> Chain:
|
|
251
285
|
"""
|
|
252
|
-
Return a ChainConsumer Chain object from the posterior distribution of the
|
|
286
|
+
Return a ChainConsumer Chain object from the posterior distribution of the parameters selected by `parameters_type`.
|
|
253
287
|
|
|
254
|
-
Parameters
|
|
288
|
+
Parameters
|
|
289
|
+
----------
|
|
255
290
|
name: The name of the chain.
|
|
256
|
-
|
|
291
|
+
parameters_type: The kind of parameters to include in the chain, either "model" or "bkg".
|
|
257
292
|
"""
|
|
258
293
|
|
|
259
294
|
obs_id = self.inference_data.copy()
|
|
260
295
|
|
|
261
|
-
if
|
|
262
|
-
keys_to_drop = [
|
|
263
|
-
|
|
296
|
+
if parameters_type == "model":
|
|
297
|
+
keys_to_drop = [
|
|
298
|
+
key
|
|
299
|
+
for key in obs_id.posterior.keys()
|
|
300
|
+
if (key.startswith("_") or key.startswith("bkg"))
|
|
301
|
+
]
|
|
302
|
+
elif parameters_type == "bkg":
|
|
264
303
|
keys_to_drop = [key for key in obs_id.posterior.keys() if not key.startswith("bkg")]
|
|
265
304
|
else:
|
|
266
|
-
raise ValueError(f"Unknown value for
|
|
305
|
+
raise ValueError(f"Unknown value for parameters_type: {parameters_type}")
|
|
267
306
|
|
|
268
307
|
obs_id.posterior = obs_id.posterior.drop_vars(keys_to_drop)
|
|
269
308
|
chain = Chain.from_arviz(obs_id, name)
|
|
270
|
-
|
|
309
|
+
|
|
310
|
+
"""
|
|
311
|
+
chain.samples.columns = [
|
|
312
|
+
format_parameters(parameter) for parameter in chain.samples.columns
|
|
313
|
+
]
|
|
314
|
+
"""
|
|
271
315
|
|
|
272
316
|
return chain
|
|
273
317
|
|
|
@@ -304,7 +348,7 @@ class FitResult:
|
|
|
304
348
|
for module, parameter, value in traverse(self._structure):
|
|
305
349
|
if params.get(module, None) is None:
|
|
306
350
|
params[module] = {}
|
|
307
|
-
params[module][parameter] = self.
|
|
351
|
+
params[module][parameter] = self.samples_flat[f"{module}_{parameter}"]
|
|
308
352
|
|
|
309
353
|
return params
|
|
310
354
|
|
|
@@ -328,19 +372,50 @@ class FitResult:
|
|
|
328
372
|
return {key: posterior[key].data for key in var_names}
|
|
329
373
|
|
|
330
374
|
@property
|
|
331
|
-
def
|
|
375
|
+
def log_likelihood(self) -> xr.Dataset:
|
|
332
376
|
"""
|
|
333
|
-
Return the
|
|
377
|
+
Return the log_likelihood of each observation
|
|
334
378
|
"""
|
|
335
379
|
log_likelihood = az.extract(self.inference_data, group="log_likelihood")
|
|
336
|
-
dimensions_to_reduce = [
|
|
380
|
+
dimensions_to_reduce = [
|
|
381
|
+
coord for coord in log_likelihood.coords if coord not in ["sample", "draw", "chain"]
|
|
382
|
+
]
|
|
337
383
|
return log_likelihood.sum(dimensions_to_reduce)
|
|
338
384
|
|
|
385
|
+
@property
|
|
386
|
+
def c_stat(self):
|
|
387
|
+
r"""
|
|
388
|
+
Return the C-statistic of the model
|
|
389
|
+
|
|
390
|
+
The C-statistic is defined as:
|
|
391
|
+
|
|
392
|
+
$$ C = 2 \sum_{i} \left[ M_i - D_i \log(M_i) + D_i \log(D_i) - D_i \right] $$
|
|
393
|
+
or
|
|
394
|
+
$$ C = 2 \sum_{i} \left[ M_i - D_i \log(M_i) \right] $$
|
|
395
|
+
for bins with no counts
|
|
396
|
+
|
|
397
|
+
"""
|
|
398
|
+
|
|
399
|
+
exclude_dims = ["chain", "draw", "sample"]
|
|
400
|
+
all_dims = list(self.inference_data.log_likelihood.dims)
|
|
401
|
+
reduce_dims = [dim for dim in all_dims if dim not in exclude_dims]
|
|
402
|
+
data = self.inference_data.observed_data
|
|
403
|
+
c_stat = -2 * (
|
|
404
|
+
self.log_likelihood
|
|
405
|
+
+ (gammaln(data + 1) - (xr.where(data > 0, data * (np.log(data) - 1), 0))).sum(
|
|
406
|
+
dim=reduce_dims
|
|
407
|
+
)
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
return c_stat
|
|
411
|
+
|
|
339
412
|
def plot_ppc(
|
|
340
413
|
self,
|
|
341
|
-
percentile:
|
|
414
|
+
percentile: tuple[int, int] = (16, 84),
|
|
342
415
|
x_unit: str | u.Unit = "keV",
|
|
343
|
-
y_type: Literal[
|
|
416
|
+
y_type: Literal[
|
|
417
|
+
"counts", "countrate", "photon_flux", "photon_flux_density"
|
|
418
|
+
] = "photon_flux_density",
|
|
344
419
|
) -> plt.Figure:
|
|
345
420
|
r"""
|
|
346
421
|
Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
|
|
@@ -349,12 +424,14 @@ class FitResult:
|
|
|
349
424
|
$$ \text{Residual} = \frac{\text{Observed counts} - \text{Posterior counts}}
|
|
350
425
|
{(\text{Posterior counts})_{84\%}-(\text{Posterior counts})_{16\%}} $$
|
|
351
426
|
|
|
352
|
-
Parameters
|
|
427
|
+
Parameters
|
|
428
|
+
----------
|
|
353
429
|
percentile: The percentile of the posterior predictive distribution to plot.
|
|
354
430
|
x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
|
|
355
431
|
y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
|
|
356
432
|
|
|
357
433
|
Returns:
|
|
434
|
+
-------
|
|
358
435
|
The matplotlib figure.
|
|
359
436
|
"""
|
|
360
437
|
|
|
@@ -383,21 +460,31 @@ class FitResult:
|
|
|
383
460
|
# and enable weird broadcasting that makes the plot fail
|
|
384
461
|
|
|
385
462
|
fig, axs = plt.subplots(
|
|
386
|
-
2,
|
|
463
|
+
2,
|
|
464
|
+
len(obsconf_container),
|
|
465
|
+
figsize=(6 * len(obsconf_container), 6),
|
|
466
|
+
sharex=True,
|
|
467
|
+
height_ratios=[0.7, 0.3],
|
|
387
468
|
)
|
|
388
469
|
|
|
389
470
|
plot_ylabels_once = True
|
|
390
471
|
|
|
391
472
|
for name, obsconf, ax in zip(
|
|
392
|
-
obsconf_container.keys(),
|
|
473
|
+
obsconf_container.keys(),
|
|
474
|
+
obsconf_container.values(),
|
|
475
|
+
axs.T if len(obsconf_container) > 1 else [axs],
|
|
393
476
|
):
|
|
394
477
|
legend_plots = []
|
|
395
478
|
legend_labels = []
|
|
396
|
-
count = az.extract(
|
|
479
|
+
count = az.extract(
|
|
480
|
+
self.inference_data, var_names=f"obs_{name}", group="posterior_predictive"
|
|
481
|
+
).values.T
|
|
397
482
|
bkg_count = (
|
|
398
483
|
None
|
|
399
484
|
if self.background_model is None
|
|
400
|
-
else az.extract(
|
|
485
|
+
else az.extract(
|
|
486
|
+
self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive"
|
|
487
|
+
).values.T
|
|
401
488
|
)
|
|
402
489
|
|
|
403
490
|
xbins = obsconf.out_energies * u.keV
|
|
@@ -413,18 +500,13 @@ class FitResult:
|
|
|
413
500
|
integrated_arf = (
|
|
414
501
|
trapezoid(interpolated_arf, x=e_grid, axis=0)
|
|
415
502
|
/ (
|
|
416
|
-
np.abs(
|
|
503
|
+
np.abs(
|
|
504
|
+
xbins[1] - xbins[0]
|
|
505
|
+
) # Must fold in abs because some units reverse the ordering of the bins
|
|
417
506
|
)
|
|
418
507
|
* u.cm**2
|
|
419
508
|
)
|
|
420
509
|
|
|
421
|
-
"""
|
|
422
|
-
if xbins[0][0] < 1 < xbins[1][-1]:
|
|
423
|
-
xticks = [np.floor(xbins[0][0] * 10) / 10, 1.0, np.floor(xbins[1][-1])]
|
|
424
|
-
else:
|
|
425
|
-
xticks = [np.floor(xbins[0][0] * 10) / 10, np.floor(xbins[1][-1])]
|
|
426
|
-
"""
|
|
427
|
-
|
|
428
510
|
match y_type:
|
|
429
511
|
case "counts":
|
|
430
512
|
denominator = 1
|
|
@@ -437,50 +519,93 @@ class FitResult:
|
|
|
437
519
|
|
|
438
520
|
y_samples = (count * u.photon / denominator).to(y_units)
|
|
439
521
|
y_observed = (obsconf.folded_counts.data * u.photon / denominator).to(y_units)
|
|
522
|
+
y_observed_low = (
|
|
523
|
+
nbinom.ppf(percentile[0] / 100, obsconf.folded_counts.data, 0.5)
|
|
524
|
+
* u.photon
|
|
525
|
+
/ denominator
|
|
526
|
+
).to(y_units)
|
|
527
|
+
y_observed_high = (
|
|
528
|
+
nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
|
|
529
|
+
* u.photon
|
|
530
|
+
/ denominator
|
|
531
|
+
).to(y_units)
|
|
440
532
|
|
|
441
533
|
# Use the helper function to plot the data and posterior predictive
|
|
442
534
|
legend_plots += _plot_binned_samples_with_error(
|
|
443
535
|
ax[0],
|
|
444
536
|
xbins.value,
|
|
445
537
|
y_samples=y_samples.value,
|
|
446
|
-
y_observed=y_observed.value,
|
|
447
538
|
denominator=np.ones_like(y_observed).value,
|
|
448
539
|
color=color,
|
|
449
540
|
percentile=percentile,
|
|
450
541
|
)
|
|
451
542
|
|
|
452
|
-
legend_labels.append("
|
|
543
|
+
legend_labels.append("Model")
|
|
544
|
+
|
|
545
|
+
true_data_plot = ax[0].errorbar(
|
|
546
|
+
np.sqrt(xbins.value[0] * xbins.value[1]),
|
|
547
|
+
y_observed.value,
|
|
548
|
+
xerr=np.abs(xbins.value - np.sqrt(xbins.value[0] * xbins.value[1])),
|
|
549
|
+
yerr=[
|
|
550
|
+
y_observed.value - y_observed_low.value,
|
|
551
|
+
y_observed_high.value - y_observed.value,
|
|
552
|
+
],
|
|
553
|
+
color="black",
|
|
554
|
+
linestyle="none",
|
|
555
|
+
alpha=0.3,
|
|
556
|
+
capsize=2,
|
|
557
|
+
)
|
|
558
|
+
|
|
559
|
+
legend_plots.append((true_data_plot,))
|
|
560
|
+
legend_labels.append("Observed")
|
|
453
561
|
|
|
454
562
|
if self.background_model is not None:
|
|
455
563
|
# We plot the background only if it is included in the fit, i.e. by subtracting
|
|
456
564
|
ratio = obsconf.folded_backratio.data
|
|
457
565
|
y_samples_bkg = (bkg_count * u.photon / (denominator * ratio)).to(y_units)
|
|
458
|
-
y_observed_bkg = (
|
|
566
|
+
y_observed_bkg = (
|
|
567
|
+
obsconf.folded_background.data * u.photon / (denominator * ratio)
|
|
568
|
+
).to(y_units)
|
|
459
569
|
legend_plots += _plot_binned_samples_with_error(
|
|
460
570
|
ax[0],
|
|
461
571
|
xbins.value,
|
|
462
572
|
y_samples=y_samples_bkg.value,
|
|
463
|
-
y_observed=y_observed_bkg.value,
|
|
464
573
|
denominator=np.ones_like(y_observed).value,
|
|
465
574
|
color=(0.26787604, 0.60085972, 0.63302651),
|
|
466
575
|
percentile=percentile,
|
|
467
576
|
)
|
|
468
577
|
|
|
469
|
-
legend_labels.append("
|
|
578
|
+
legend_labels.append("Model (bkg)")
|
|
579
|
+
|
|
580
|
+
residual_samples = (obsconf.folded_counts.data - count) / np.diff(
|
|
581
|
+
np.percentile(count, percentile, axis=0), axis=0
|
|
582
|
+
)
|
|
470
583
|
|
|
471
584
|
residuals = np.percentile(
|
|
472
|
-
|
|
585
|
+
residual_samples,
|
|
473
586
|
percentile,
|
|
474
587
|
axis=0,
|
|
475
588
|
)
|
|
476
589
|
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
590
|
+
median_residuals = np.median(
|
|
591
|
+
residual_samples,
|
|
592
|
+
axis=0,
|
|
593
|
+
)
|
|
594
|
+
|
|
595
|
+
ax[1].stairs(
|
|
596
|
+
residuals[1],
|
|
597
|
+
edges=[*list(xbins.value[0]), xbins.value[1][-1]],
|
|
598
|
+
baseline=list(residuals[0]),
|
|
481
599
|
alpha=0.3,
|
|
482
|
-
step="post",
|
|
483
600
|
facecolor=color,
|
|
601
|
+
fill=True,
|
|
602
|
+
)
|
|
603
|
+
|
|
604
|
+
ax[1].stairs(
|
|
605
|
+
median_residuals,
|
|
606
|
+
edges=[*list(xbins.value[0]), xbins.value[1][-1]],
|
|
607
|
+
color=color,
|
|
608
|
+
alpha=0.7,
|
|
484
609
|
)
|
|
485
610
|
|
|
486
611
|
max_residuals = np.max(np.abs(residuals))
|
|
@@ -502,7 +627,8 @@ class FitResult:
|
|
|
502
627
|
ax[1].set_xlabel(f"Frequency \n[{x_unit:latex_inline}]")
|
|
503
628
|
case _:
|
|
504
629
|
RuntimeError(
|
|
505
|
-
f"Unknown physical type for x_units: {x_unit}. "
|
|
630
|
+
f"Unknown physical type for x_units: {x_unit}. "
|
|
631
|
+
f"Must be 'length', 'energy' or 'frequency'"
|
|
506
632
|
)
|
|
507
633
|
|
|
508
634
|
ax[1].axhline(0, color=color, ls="--")
|
|
@@ -536,15 +662,18 @@ class FitResult:
|
|
|
536
662
|
|
|
537
663
|
def plot_corner(
|
|
538
664
|
self,
|
|
539
|
-
config: PlotConfig = PlotConfig(usetex=False, summarise=False, label_font_size=
|
|
665
|
+
config: PlotConfig = PlotConfig(usetex=False, summarise=False, label_font_size=12),
|
|
540
666
|
**kwargs: Any,
|
|
541
667
|
) -> plt.Figure:
|
|
542
668
|
"""
|
|
543
|
-
Plot the corner plot of the posterior distribution of the
|
|
669
|
+
Plot the corner plot of the posterior distribution of the parameters. This method uses ChainConsumer.
|
|
544
670
|
|
|
545
|
-
Parameters
|
|
671
|
+
Parameters
|
|
672
|
+
----------
|
|
546
673
|
config: The configuration of the plot.
|
|
547
|
-
|
|
674
|
+
parameters: The parameters to include in the plot using the following format: `blackbody_1_kT`.
|
|
675
|
+
**kwargs: Additional arguments passed to ChainConsumer.plotter.plot. Some useful parameters are:
|
|
676
|
+
- columns : list of parameters to plot.
|
|
548
677
|
"""
|
|
549
678
|
|
|
550
679
|
consumer = ChainConsumer()
|
jaxspec/data/__init__.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# precommit is suppressing these imports
|
|
2
|
-
from .obsconf import ObsConfiguration # noqa: F401
|
|
3
|
-
from .instrument import Instrument # noqa: F401
|
|
4
|
-
from .observation import Observation # noqa: F401
|
|
5
1
|
import astropy.units as u
|
|
6
2
|
|
|
3
|
+
from .instrument import Instrument
|
|
4
|
+
from .obsconf import ObsConfiguration
|
|
5
|
+
from .observation import Observation
|
|
6
|
+
|
|
7
7
|
u.add_enabled_aliases({"counts": u.count})
|
|
8
8
|
u.add_enabled_aliases({"channel": u.dimensionless_unscaled})
|
|
9
9
|
# Arbitrary units are found in .rsp files; let's hope they are compatible with what we would expect for the rmf x arf
|