jaxspec 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/__init__.py +0 -0
- jaxspec/_fit/_build_model.py +63 -0
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +238 -336
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +68 -11
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +5 -75
- jaxspec/fit.py +101 -140
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +94 -87
- jaxspec/model/multiplicative.py +101 -85
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/METADATA +36 -16
- jaxspec-0.2.0.dist-info/RECORD +34 -0
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.3.dist-info/RECORD +0 -31
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/entry_points.txt +0 -0
jaxspec/analysis/results.py
CHANGED
|
@@ -1,25 +1,37 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from functools import cached_property
|
|
4
4
|
from typing import TYPE_CHECKING, Any, Literal, TypeVar
|
|
5
5
|
|
|
6
6
|
import arviz as az
|
|
7
|
+
import astropy.cosmology.units as cu
|
|
7
8
|
import astropy.units as u
|
|
8
9
|
import jax
|
|
9
10
|
import jax.numpy as jnp
|
|
10
11
|
import matplotlib.pyplot as plt
|
|
11
12
|
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
12
14
|
import xarray as xr
|
|
13
15
|
|
|
14
16
|
from astropy.cosmology import Cosmology, Planck18
|
|
15
17
|
from astropy.units import Unit
|
|
16
18
|
from chainconsumer import Chain, ChainConsumer, PlotConfig
|
|
17
|
-
from
|
|
19
|
+
from jax.experimental.sparse import BCOO
|
|
18
20
|
from jax.typing import ArrayLike
|
|
19
21
|
from numpyro.handlers import seed
|
|
20
|
-
from scipy.integrate import trapezoid
|
|
21
22
|
from scipy.special import gammaln
|
|
22
|
-
|
|
23
|
+
|
|
24
|
+
from ._plot import (
|
|
25
|
+
BACKGROUND_COLOR,
|
|
26
|
+
BACKGROUND_DATA_COLOR,
|
|
27
|
+
COLOR_CYCLE,
|
|
28
|
+
SPECTRUM_COLOR,
|
|
29
|
+
SPECTRUM_DATA_COLOR,
|
|
30
|
+
_compute_effective_area,
|
|
31
|
+
_error_bars_for_observed_data,
|
|
32
|
+
_plot_binned_samples_with_error,
|
|
33
|
+
_plot_poisson_data_with_error,
|
|
34
|
+
)
|
|
23
35
|
|
|
24
36
|
if TYPE_CHECKING:
|
|
25
37
|
from ..fit import BayesianModel
|
|
@@ -30,67 +42,6 @@ V = TypeVar("V")
|
|
|
30
42
|
T = TypeVar("T")
|
|
31
43
|
|
|
32
44
|
|
|
33
|
-
class HaikuDict(dict[str, dict[str, T]]): ...
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def _plot_binned_samples_with_error(
|
|
37
|
-
ax: plt.Axes,
|
|
38
|
-
x_bins: ArrayLike,
|
|
39
|
-
denominator: ArrayLike | None = None,
|
|
40
|
-
y_samples: ArrayLike | None = None,
|
|
41
|
-
color=(0.15, 0.25, 0.45),
|
|
42
|
-
percentile: tuple = (16, 84),
|
|
43
|
-
):
|
|
44
|
-
"""
|
|
45
|
-
Helper function to plot the posterior predictive distribution of the model. The function
|
|
46
|
-
computes the percentiles of the posterior predictive distribution and plot them as a shaded
|
|
47
|
-
area. If the observed data is provided, it is also plotted as a step function.
|
|
48
|
-
|
|
49
|
-
Parameters
|
|
50
|
-
----------
|
|
51
|
-
x_bins: The bin edges of the data (2 x N).
|
|
52
|
-
y_samples: The samples of the posterior predictive distribution (Samples X N).
|
|
53
|
-
denominator: Values used to divided the samples, i.e. to get energy flux (N).
|
|
54
|
-
ax: The matplotlib axes object.
|
|
55
|
-
color: The color of the posterior predictive distribution.
|
|
56
|
-
y_observed: The observed data (N).
|
|
57
|
-
label: The label of the observed data.
|
|
58
|
-
percentile: The percentile of the posterior predictive distribution to plot.
|
|
59
|
-
"""
|
|
60
|
-
|
|
61
|
-
mean, envelope = None, None
|
|
62
|
-
|
|
63
|
-
if denominator is None:
|
|
64
|
-
denominator = np.ones_like(x_bins[0])
|
|
65
|
-
|
|
66
|
-
mean = ax.stairs(
|
|
67
|
-
list(np.median(y_samples, axis=0) / denominator),
|
|
68
|
-
edges=[*list(x_bins[0]), x_bins[1][-1]],
|
|
69
|
-
color=color,
|
|
70
|
-
alpha=0.7,
|
|
71
|
-
)
|
|
72
|
-
|
|
73
|
-
if y_samples is not None:
|
|
74
|
-
if denominator is None:
|
|
75
|
-
denominator = np.ones_like(x_bins[0])
|
|
76
|
-
|
|
77
|
-
percentiles = np.percentile(y_samples, percentile, axis=0)
|
|
78
|
-
|
|
79
|
-
# The legend cannot handle fill_between, so we pass a fill to get a fancy icon
|
|
80
|
-
(envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
|
|
81
|
-
|
|
82
|
-
ax.stairs(
|
|
83
|
-
percentiles[1] / denominator,
|
|
84
|
-
edges=[*list(x_bins[0]), x_bins[1][-1]],
|
|
85
|
-
baseline=percentiles[0] / denominator,
|
|
86
|
-
alpha=0.3,
|
|
87
|
-
fill=True,
|
|
88
|
-
color=color,
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
return [(mean, envelope)]
|
|
92
|
-
|
|
93
|
-
|
|
94
45
|
class FitResult:
|
|
95
46
|
"""
|
|
96
47
|
Container for the result of a fit using any ModelFitter class.
|
|
@@ -101,7 +52,6 @@ class FitResult:
|
|
|
101
52
|
self,
|
|
102
53
|
bayesian_fitter: BayesianModel,
|
|
103
54
|
inference_data: az.InferenceData,
|
|
104
|
-
structure: Mapping[K, V],
|
|
105
55
|
background_model: BackgroundModel = None,
|
|
106
56
|
):
|
|
107
57
|
self.model = bayesian_fitter.model
|
|
@@ -109,7 +59,6 @@ class FitResult:
|
|
|
109
59
|
self.inference_data = inference_data
|
|
110
60
|
self.obsconfs = bayesian_fitter.observation_container
|
|
111
61
|
self.background_model = background_model
|
|
112
|
-
self._structure = structure
|
|
113
62
|
|
|
114
63
|
# Add the model used in fit to the metadata
|
|
115
64
|
for group in self.inference_data.groups():
|
|
@@ -126,37 +75,38 @@ class FitResult:
|
|
|
126
75
|
|
|
127
76
|
return all(az.rhat(self.inference_data) < 1.01)
|
|
128
77
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
"""
|
|
132
|
-
Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
|
|
133
|
-
"""
|
|
134
|
-
|
|
135
|
-
samples_flat = self._structured_samples_flat
|
|
78
|
+
def _ppc_folded_branches(self, obs_id):
|
|
79
|
+
obs = self.obsconfs[obs_id]
|
|
136
80
|
|
|
137
|
-
|
|
81
|
+
if len(next(iter(self.input_parameters.values())).shape) > 2:
|
|
82
|
+
idx = list(self.obsconfs.keys()).index(obs_id)
|
|
83
|
+
obs_parameters = jax.tree.map(lambda x: x[..., idx], self.input_parameters)
|
|
138
84
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
samples_haiku[module] = {}
|
|
142
|
-
samples_haiku[module][parameter] = samples_flat[f"{module}_{parameter}"]
|
|
85
|
+
else:
|
|
86
|
+
obs_parameters = self.input_parameters
|
|
143
87
|
|
|
144
|
-
|
|
88
|
+
if self.bayesian_fitter.sparse:
|
|
89
|
+
transfer_matrix = BCOO.from_scipy_sparse(
|
|
90
|
+
obs.transfer_matrix.data.to_scipy_sparse().tocsr()
|
|
91
|
+
)
|
|
145
92
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
"""
|
|
149
|
-
Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
|
|
150
|
-
"""
|
|
93
|
+
else:
|
|
94
|
+
transfer_matrix = np.asarray(obs.transfer_matrix.data.todense())
|
|
151
95
|
|
|
152
|
-
|
|
153
|
-
posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
|
|
154
|
-
samples_flat = {key: posterior[key].data for key in var_names}
|
|
96
|
+
energies = obs.in_energies
|
|
155
97
|
|
|
156
|
-
|
|
98
|
+
flux_func = jax.jit(
|
|
99
|
+
jax.vmap(jax.vmap(lambda p: self.model.photon_flux(p, *energies, split_branches=True)))
|
|
100
|
+
)
|
|
101
|
+
convolve_func = jax.jit(
|
|
102
|
+
jax.vmap(jax.vmap(lambda flux: jnp.clip(transfer_matrix @ flux, a_min=1e-6)))
|
|
103
|
+
)
|
|
104
|
+
return jax.tree.map(
|
|
105
|
+
lambda flux: np.random.poisson(convolve_func(flux)), flux_func(obs_parameters)
|
|
106
|
+
)
|
|
157
107
|
|
|
158
|
-
@
|
|
159
|
-
def input_parameters(self) ->
|
|
108
|
+
@cached_property
|
|
109
|
+
def input_parameters(self) -> dict[str, ArrayLike]:
|
|
160
110
|
"""
|
|
161
111
|
The input parameters of the model.
|
|
162
112
|
"""
|
|
@@ -172,7 +122,9 @@ class FitResult:
|
|
|
172
122
|
with seed(rng_seed=0):
|
|
173
123
|
input_parameters = self.bayesian_fitter.prior_distributions_func()
|
|
174
124
|
|
|
175
|
-
for
|
|
125
|
+
for key, value in input_parameters.items():
|
|
126
|
+
module, parameter = key.rsplit("_", 1)
|
|
127
|
+
|
|
176
128
|
if f"{module}_{parameter}" in posterior.keys():
|
|
177
129
|
# We add as extra dimension as there might be different values per observation
|
|
178
130
|
if posterior[f"{module}_{parameter}"].shape == samples_shape:
|
|
@@ -180,19 +132,21 @@ class FitResult:
|
|
|
180
132
|
else:
|
|
181
133
|
to_set = posterior[f"{module}_{parameter}"]
|
|
182
134
|
|
|
183
|
-
input_parameters[module
|
|
135
|
+
input_parameters[f"{module}_{parameter}"] = to_set
|
|
184
136
|
|
|
185
137
|
else:
|
|
186
138
|
# The parameter is fixed in this case, so we just broadcast is over chain and draws
|
|
187
|
-
input_parameters[module
|
|
139
|
+
input_parameters[f"{module}_{parameter}"] = value[None, None, ...]
|
|
188
140
|
|
|
189
|
-
if len(total_shape) < len(input_parameters[module
|
|
141
|
+
if len(total_shape) < len(input_parameters[f"{module}_{parameter}"].shape):
|
|
190
142
|
# If there are only chains and draws, we reduce
|
|
191
|
-
input_parameters[module
|
|
143
|
+
input_parameters[f"{module}_{parameter}"] = input_parameters[
|
|
144
|
+
f"{module}_{parameter}"
|
|
145
|
+
][..., 0]
|
|
192
146
|
|
|
193
147
|
else:
|
|
194
|
-
input_parameters[module
|
|
195
|
-
input_parameters[module
|
|
148
|
+
input_parameters[f"{module}_{parameter}"] = jnp.broadcast_to(
|
|
149
|
+
input_parameters[f"{module}_{parameter}"], total_shape
|
|
196
150
|
)
|
|
197
151
|
|
|
198
152
|
return input_parameters
|
|
@@ -287,7 +241,8 @@ class FitResult:
|
|
|
287
241
|
self,
|
|
288
242
|
e_min: float,
|
|
289
243
|
e_max: float,
|
|
290
|
-
redshift: float | ArrayLike =
|
|
244
|
+
redshift: float | ArrayLike = None,
|
|
245
|
+
distance: float | ArrayLike = None,
|
|
291
246
|
observer_frame: bool = True,
|
|
292
247
|
cosmology: Cosmology = Planck18,
|
|
293
248
|
unit: Unit = u.erg / u.s,
|
|
@@ -310,6 +265,17 @@ class FitResult:
|
|
|
310
265
|
if not observer_frame:
|
|
311
266
|
raise NotImplementedError()
|
|
312
267
|
|
|
268
|
+
if redshift is None and distance is None:
|
|
269
|
+
raise ValueError("Either redshift or distance must be specified.")
|
|
270
|
+
|
|
271
|
+
if distance is not None:
|
|
272
|
+
if redshift is not None:
|
|
273
|
+
raise ValueError("Redshift must be None as a distance is specified.")
|
|
274
|
+
else:
|
|
275
|
+
redshift = distance.to(
|
|
276
|
+
cu.redshift, cu.redshift_distance(cosmology, kind="luminosity")
|
|
277
|
+
).value
|
|
278
|
+
|
|
313
279
|
@jax.jit
|
|
314
280
|
@jnp.vectorize
|
|
315
281
|
def vectorized_flux(*pars):
|
|
@@ -333,94 +299,46 @@ class FitResult:
|
|
|
333
299
|
|
|
334
300
|
return value
|
|
335
301
|
|
|
336
|
-
def to_chain(self, name: str
|
|
302
|
+
def to_chain(self, name: str) -> Chain:
|
|
337
303
|
"""
|
|
338
304
|
Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.
|
|
339
305
|
|
|
340
306
|
Parameters:
|
|
341
307
|
name: The name of the chain.
|
|
342
|
-
parameters_type: The parameters_type to include in the chain.
|
|
343
308
|
"""
|
|
344
309
|
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
key
|
|
350
|
-
for key in obs_id.posterior.keys()
|
|
351
|
-
if (key.startswith("_") or key.startswith("bkg"))
|
|
352
|
-
]
|
|
353
|
-
elif parameters_type == "bkg":
|
|
354
|
-
keys_to_drop = [key for key in obs_id.posterior.keys() if not key.startswith("bkg")]
|
|
355
|
-
else:
|
|
356
|
-
raise ValueError(f"Unknown value for parameters_type: {parameters_type}")
|
|
357
|
-
|
|
358
|
-
obs_id.posterior = obs_id.posterior.drop_vars(keys_to_drop)
|
|
359
|
-
chain = Chain.from_arviz(obs_id, name)
|
|
360
|
-
|
|
361
|
-
"""
|
|
362
|
-
chain.samples.columns = [
|
|
363
|
-
format_parameters(parameter) for parameter in chain.samples.columns
|
|
310
|
+
keys_to_drop = [
|
|
311
|
+
key
|
|
312
|
+
for key in self.inference_data.posterior.keys()
|
|
313
|
+
if (key.startswith("_") or key.startswith("bkg"))
|
|
364
314
|
]
|
|
365
|
-
"""
|
|
366
315
|
|
|
367
|
-
|
|
316
|
+
reduced_id = az.extract(
|
|
317
|
+
self.inference_data,
|
|
318
|
+
var_names=[f"~{key}" for key in keys_to_drop] if keys_to_drop else None,
|
|
319
|
+
group="posterior",
|
|
320
|
+
)
|
|
368
321
|
|
|
369
|
-
|
|
370
|
-
def samples_haiku(self) -> HaikuDict[ArrayLike]:
|
|
371
|
-
"""
|
|
372
|
-
Haiku-like structure for the samples e.g.
|
|
373
|
-
|
|
374
|
-
```
|
|
375
|
-
{
|
|
376
|
-
'powerlaw_1' :
|
|
377
|
-
{
|
|
378
|
-
'alpha': ...,
|
|
379
|
-
'amplitude': ...
|
|
380
|
-
},
|
|
381
|
-
|
|
382
|
-
'blackbody_1':
|
|
383
|
-
{
|
|
384
|
-
'kT': ...,
|
|
385
|
-
'norm': ...
|
|
386
|
-
},
|
|
387
|
-
|
|
388
|
-
'tbabs_1':
|
|
389
|
-
{
|
|
390
|
-
'nH': ...
|
|
391
|
-
}
|
|
392
|
-
}
|
|
393
|
-
```
|
|
322
|
+
df_list = []
|
|
394
323
|
|
|
395
|
-
|
|
324
|
+
for var, array in reduced_id.data_vars.items():
|
|
325
|
+
extra_dims = [dim for dim in array.dims if dim not in ["sample"]]
|
|
396
326
|
|
|
397
|
-
|
|
327
|
+
if extra_dims:
|
|
328
|
+
dim = extra_dims[
|
|
329
|
+
0
|
|
330
|
+
] # We only support the case where the extra dimension comes from the observations
|
|
398
331
|
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
332
|
+
for coord, obs_id in zip(array.coords[dim], self.obsconfs.keys()):
|
|
333
|
+
df = array.loc[{dim: coord}].to_pandas()
|
|
334
|
+
df.name += f"\n[{obs_id}]"
|
|
335
|
+
df_list.append(df)
|
|
336
|
+
else:
|
|
337
|
+
df_list.append(array.to_pandas())
|
|
403
338
|
|
|
404
|
-
|
|
339
|
+
df = pd.concat(df_list, axis=1)
|
|
405
340
|
|
|
406
|
-
|
|
407
|
-
def samples_flat(self) -> dict[str, ArrayLike]:
|
|
408
|
-
"""
|
|
409
|
-
Flat structure for the samples e.g.
|
|
410
|
-
|
|
411
|
-
```
|
|
412
|
-
{
|
|
413
|
-
'powerlaw_1_alpha': ...,
|
|
414
|
-
'powerlaw_1_amplitude': ...,
|
|
415
|
-
'blackbody_1_kT': ...,
|
|
416
|
-
'blackbody_1_norm': ...,
|
|
417
|
-
'tbabs_1_nH': ...,
|
|
418
|
-
}
|
|
419
|
-
```
|
|
420
|
-
"""
|
|
421
|
-
var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
|
|
422
|
-
posterior = az.extract(self.inference_data, var_names=var_names)
|
|
423
|
-
return {key: posterior[key].data for key in var_names}
|
|
341
|
+
return Chain(samples=df, name=name)
|
|
424
342
|
|
|
425
343
|
@property
|
|
426
344
|
def log_likelihood(self) -> xr.Dataset:
|
|
@@ -462,12 +380,18 @@ class FitResult:
|
|
|
462
380
|
|
|
463
381
|
def plot_ppc(
|
|
464
382
|
self,
|
|
465
|
-
|
|
383
|
+
n_sigmas: int = 1,
|
|
466
384
|
x_unit: str | u.Unit = "keV",
|
|
467
385
|
y_type: Literal[
|
|
468
386
|
"counts", "countrate", "photon_flux", "photon_flux_density"
|
|
469
387
|
] = "photon_flux_density",
|
|
470
|
-
|
|
388
|
+
plot_background: bool = True,
|
|
389
|
+
plot_components: bool = False,
|
|
390
|
+
scale: Literal["linear", "semilogx", "semilogy", "loglog"] = "loglog",
|
|
391
|
+
alpha_envelope: (float, float) = (0.15, 0.25),
|
|
392
|
+
style: str | Any = "default",
|
|
393
|
+
title: str | None = None,
|
|
394
|
+
) -> list[plt.Figure]:
|
|
471
395
|
r"""
|
|
472
396
|
Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
|
|
473
397
|
following formula:
|
|
@@ -479,82 +403,52 @@ class FitResult:
|
|
|
479
403
|
percentile: The percentile of the posterior predictive distribution to plot.
|
|
480
404
|
x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
|
|
481
405
|
y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
|
|
406
|
+
plot_background: Whether to plot the background model if it is included in the fit.
|
|
407
|
+
plot_components: Whether to plot the components of the model separately.
|
|
408
|
+
scale: The axes scaling
|
|
409
|
+
alpha_envelope: The transparency range for envelops
|
|
410
|
+
style: The style of the plot. It can be either a string or a matplotlib style context.
|
|
482
411
|
|
|
483
412
|
Returns:
|
|
484
|
-
|
|
413
|
+
A list of matplotlib figures for each observation in the model.
|
|
485
414
|
"""
|
|
486
415
|
|
|
487
416
|
obsconf_container = self.obsconfs
|
|
417
|
+
figure_list = []
|
|
488
418
|
x_unit = u.Unit(x_unit)
|
|
489
419
|
|
|
490
420
|
match y_type:
|
|
491
421
|
case "counts":
|
|
492
|
-
y_units = u.
|
|
422
|
+
y_units = u.ct
|
|
493
423
|
case "countrate":
|
|
494
|
-
y_units = u.
|
|
424
|
+
y_units = u.ct / u.s
|
|
495
425
|
case "photon_flux":
|
|
496
|
-
y_units = u.
|
|
426
|
+
y_units = u.ct / u.cm**2 / u.s
|
|
497
427
|
case "photon_flux_density":
|
|
498
|
-
y_units = u.
|
|
428
|
+
y_units = u.ct / u.cm**2 / u.s / x_unit
|
|
499
429
|
case _:
|
|
500
430
|
raise ValueError(
|
|
501
431
|
f"Unknown y_type: {y_type}. Must be 'counts', 'countrate', 'photon_flux' or 'photon_flux_density'"
|
|
502
432
|
)
|
|
503
433
|
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
len(obsconf_container),
|
|
514
|
-
figsize=(6 * len(obsconf_container), 6),
|
|
515
|
-
sharex=True,
|
|
516
|
-
height_ratios=[0.7, 0.3],
|
|
517
|
-
)
|
|
518
|
-
|
|
519
|
-
plot_ylabels_once = True
|
|
434
|
+
with plt.style.context(style):
|
|
435
|
+
for obs_id, obsconf in obsconf_container.items():
|
|
436
|
+
fig, ax = plt.subplots(
|
|
437
|
+
2,
|
|
438
|
+
1,
|
|
439
|
+
figsize=(6, 6),
|
|
440
|
+
sharex="col",
|
|
441
|
+
height_ratios=[0.7, 0.3],
|
|
442
|
+
)
|
|
520
443
|
|
|
521
|
-
for name, obsconf, ax in zip(
|
|
522
|
-
obsconf_container.keys(),
|
|
523
|
-
obsconf_container.values(),
|
|
524
|
-
axs.T if len(obsconf_container) > 1 else [axs],
|
|
525
|
-
):
|
|
526
444
|
legend_plots = []
|
|
527
445
|
legend_labels = []
|
|
446
|
+
|
|
528
447
|
count = az.extract(
|
|
529
|
-
self.inference_data, var_names=f"obs_{
|
|
448
|
+
self.inference_data, var_names=f"obs_{obs_id}", group="posterior_predictive"
|
|
530
449
|
).values.T
|
|
531
|
-
bkg_count = (
|
|
532
|
-
None
|
|
533
|
-
if self.background_model is None
|
|
534
|
-
else az.extract(
|
|
535
|
-
self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive"
|
|
536
|
-
).values.T
|
|
537
|
-
)
|
|
538
450
|
|
|
539
|
-
xbins = obsconf
|
|
540
|
-
xbins = xbins.to(x_unit, u.spectral())
|
|
541
|
-
|
|
542
|
-
# This compute the total effective area within all bins
|
|
543
|
-
# This is a bit weird since the following computation is equivalent to ignoring the RMF
|
|
544
|
-
exposure = obsconf.exposure.data * u.s
|
|
545
|
-
mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
|
|
546
|
-
mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
|
|
547
|
-
e_grid = np.linspace(*xbins, 10)
|
|
548
|
-
interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
|
|
549
|
-
integrated_arf = (
|
|
550
|
-
trapezoid(interpolated_arf, x=e_grid, axis=0)
|
|
551
|
-
/ (
|
|
552
|
-
np.abs(
|
|
553
|
-
xbins[1] - xbins[0]
|
|
554
|
-
) # Must fold in abs because some units reverse the ordering of the bins
|
|
555
|
-
)
|
|
556
|
-
* u.cm**2
|
|
557
|
-
)
|
|
451
|
+
xbins, exposure, integrated_arf = _compute_effective_area(obsconf, x_unit)
|
|
558
452
|
|
|
559
453
|
match y_type:
|
|
560
454
|
case "counts":
|
|
@@ -566,134 +460,124 @@ class FitResult:
|
|
|
566
460
|
case "photon_flux_density":
|
|
567
461
|
denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure
|
|
568
462
|
|
|
569
|
-
y_samples = (count * u.
|
|
570
|
-
|
|
571
|
-
y_observed_low = (
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
/ denominator
|
|
575
|
-
).to(y_units)
|
|
576
|
-
y_observed_high = (
|
|
577
|
-
nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
|
|
578
|
-
* u.photon
|
|
579
|
-
/ denominator
|
|
580
|
-
).to(y_units)
|
|
463
|
+
y_samples = (count * u.ct / denominator).to(y_units)
|
|
464
|
+
|
|
465
|
+
y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
|
|
466
|
+
obsconf.folded_counts.data, denominator, y_units
|
|
467
|
+
)
|
|
581
468
|
|
|
582
469
|
# Use the helper function to plot the data and posterior predictive
|
|
583
|
-
|
|
470
|
+
model_plot = _plot_binned_samples_with_error(
|
|
584
471
|
ax[0],
|
|
585
472
|
xbins.value,
|
|
586
|
-
y_samples
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
473
|
+
y_samples.value,
|
|
474
|
+
color=SPECTRUM_COLOR,
|
|
475
|
+
n_sigmas=n_sigmas,
|
|
476
|
+
alpha_envelope=alpha_envelope,
|
|
590
477
|
)
|
|
591
478
|
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
np.sqrt(xbins.value[0] * xbins.value[1]),
|
|
479
|
+
true_data_plot = _plot_poisson_data_with_error(
|
|
480
|
+
ax[0],
|
|
481
|
+
xbins.value,
|
|
596
482
|
y_observed.value,
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
],
|
|
602
|
-
color="black",
|
|
603
|
-
linestyle="none",
|
|
604
|
-
alpha=0.3,
|
|
605
|
-
capsize=2,
|
|
483
|
+
y_observed_low.value,
|
|
484
|
+
y_observed_high.value,
|
|
485
|
+
color=SPECTRUM_DATA_COLOR,
|
|
486
|
+
alpha=0.7,
|
|
606
487
|
)
|
|
607
488
|
|
|
608
489
|
legend_plots.append((true_data_plot,))
|
|
609
490
|
legend_labels.append("Observed")
|
|
491
|
+
legend_plots += model_plot
|
|
492
|
+
legend_labels.append("Model")
|
|
610
493
|
|
|
611
|
-
|
|
494
|
+
# Plot the residuals
|
|
495
|
+
residual_samples = (obsconf.folded_counts.data - count) / np.diff(
|
|
496
|
+
np.percentile(count, [16, 84], axis=0), axis=0
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
_plot_binned_samples_with_error(
|
|
500
|
+
ax[1],
|
|
501
|
+
xbins.value,
|
|
502
|
+
residual_samples,
|
|
503
|
+
color=SPECTRUM_COLOR,
|
|
504
|
+
n_sigmas=n_sigmas,
|
|
505
|
+
alpha_envelope=alpha_envelope,
|
|
506
|
+
)
|
|
507
|
+
|
|
508
|
+
if plot_components:
|
|
509
|
+
for (component_name, count), color in zip(
|
|
510
|
+
self._ppc_folded_branches(obs_id).items(), COLOR_CYCLE
|
|
511
|
+
):
|
|
512
|
+
# _ppc_folded_branches returns (n_chains, n_draws, n_bins) shaped arrays so we must flatten it
|
|
513
|
+
y_samples = (
|
|
514
|
+
count.reshape((count.shape[0] * count.shape[1], -1))
|
|
515
|
+
* u.ct
|
|
516
|
+
/ denominator
|
|
517
|
+
).to(y_units)
|
|
518
|
+
component_plot = _plot_binned_samples_with_error(
|
|
519
|
+
ax[0],
|
|
520
|
+
xbins.value,
|
|
521
|
+
y_samples.value,
|
|
522
|
+
color=color,
|
|
523
|
+
linestyle="dashdot",
|
|
524
|
+
n_sigmas=n_sigmas,
|
|
525
|
+
alpha_envelope=alpha_envelope,
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
legend_plots += component_plot
|
|
529
|
+
legend_labels.append(component_name)
|
|
530
|
+
|
|
531
|
+
if self.background_model is not None and plot_background:
|
|
612
532
|
# We plot the background only if it is included in the fit, i.e. by subtracting
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
533
|
+
bkg_count = (
|
|
534
|
+
None
|
|
535
|
+
if self.background_model is None
|
|
536
|
+
else az.extract(
|
|
537
|
+
self.inference_data,
|
|
538
|
+
var_names=f"bkg_{obs_id}",
|
|
539
|
+
group="posterior_predictive",
|
|
540
|
+
).values.T
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
y_samples_bkg = (bkg_count * u.ct / denominator).to(y_units)
|
|
544
|
+
|
|
545
|
+
y_observed_bkg, y_observed_bkg_low, y_observed_bkg_high = (
|
|
546
|
+
_error_bars_for_observed_data(
|
|
547
|
+
obsconf.folded_background.data, denominator, y_units
|
|
548
|
+
)
|
|
549
|
+
)
|
|
550
|
+
|
|
551
|
+
model_bkg_plot = _plot_binned_samples_with_error(
|
|
630
552
|
ax[0],
|
|
631
553
|
xbins.value,
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
554
|
+
y_samples_bkg.value,
|
|
555
|
+
color=BACKGROUND_COLOR,
|
|
556
|
+
alpha_envelope=alpha_envelope,
|
|
557
|
+
n_sigmas=n_sigmas,
|
|
636
558
|
)
|
|
637
559
|
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
np.sqrt(xbins.value[0] * xbins.value[1]),
|
|
560
|
+
true_bkg_plot = _plot_poisson_data_with_error(
|
|
561
|
+
ax[0],
|
|
562
|
+
xbins.value,
|
|
642
563
|
y_observed_bkg.value,
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
],
|
|
648
|
-
color="black",
|
|
649
|
-
linestyle="none",
|
|
650
|
-
alpha=0.3,
|
|
651
|
-
capsize=2,
|
|
564
|
+
y_observed_bkg_low.value,
|
|
565
|
+
y_observed_bkg_high.value,
|
|
566
|
+
color=BACKGROUND_DATA_COLOR,
|
|
567
|
+
alpha=0.7,
|
|
652
568
|
)
|
|
653
569
|
|
|
654
570
|
legend_plots.append((true_bkg_plot,))
|
|
655
571
|
legend_labels.append("Observed (bkg)")
|
|
572
|
+
legend_plots += model_bkg_plot
|
|
573
|
+
legend_labels.append("Model (bkg)")
|
|
656
574
|
|
|
657
|
-
|
|
658
|
-
np.percentile(count, percentile, axis=0), axis=0
|
|
659
|
-
)
|
|
660
|
-
|
|
661
|
-
residuals = np.percentile(
|
|
662
|
-
residual_samples,
|
|
663
|
-
percentile,
|
|
664
|
-
axis=0,
|
|
665
|
-
)
|
|
666
|
-
|
|
667
|
-
median_residuals = np.median(
|
|
668
|
-
residual_samples,
|
|
669
|
-
axis=0,
|
|
670
|
-
)
|
|
671
|
-
|
|
672
|
-
ax[1].stairs(
|
|
673
|
-
residuals[1],
|
|
674
|
-
edges=[*list(xbins.value[0]), xbins.value[1][-1]],
|
|
675
|
-
baseline=list(residuals[0]),
|
|
676
|
-
alpha=0.3,
|
|
677
|
-
facecolor=color,
|
|
678
|
-
fill=True,
|
|
679
|
-
)
|
|
680
|
-
|
|
681
|
-
ax[1].stairs(
|
|
682
|
-
median_residuals,
|
|
683
|
-
edges=[*list(xbins.value[0]), xbins.value[1][-1]],
|
|
684
|
-
color=color,
|
|
685
|
-
alpha=0.7,
|
|
686
|
-
)
|
|
687
|
-
|
|
688
|
-
max_residuals = np.max(np.abs(residuals))
|
|
575
|
+
max_residuals = np.max(np.abs(residual_samples))
|
|
689
576
|
|
|
690
577
|
ax[0].loglog()
|
|
691
578
|
ax[1].set_ylim(-max(3.5, max_residuals), +max(3.5, max_residuals))
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
|
|
695
|
-
ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
|
|
696
|
-
plot_ylabels_once = False
|
|
579
|
+
ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
|
|
580
|
+
ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
|
|
697
581
|
|
|
698
582
|
match getattr(x_unit, "physical_type"):
|
|
699
583
|
case "length":
|
|
@@ -708,24 +592,42 @@ class FitResult:
|
|
|
708
592
|
f"Must be 'length', 'energy' or 'frequency'"
|
|
709
593
|
)
|
|
710
594
|
|
|
711
|
-
ax[1].axhline(0, color=
|
|
712
|
-
ax[1].axhline(-3, color=
|
|
713
|
-
ax[1].axhline(3, color=
|
|
595
|
+
ax[1].axhline(0, color=SPECTRUM_DATA_COLOR, ls="--")
|
|
596
|
+
ax[1].axhline(-3, color=SPECTRUM_DATA_COLOR, ls=":")
|
|
597
|
+
ax[1].axhline(3, color=SPECTRUM_DATA_COLOR, ls=":")
|
|
714
598
|
|
|
715
|
-
# ax[1].set_xticks(xticks, labels=xticks)
|
|
716
|
-
# ax[1].xaxis.set_minor_formatter(ticker.LogFormatter(minor_thresholds=(np.inf, np.inf)))
|
|
717
599
|
ax[1].set_yticks([-3, 0, 3], labels=[-3, 0, 3])
|
|
718
600
|
ax[1].set_yticks(range(-3, 4), minor=True)
|
|
719
601
|
|
|
720
602
|
ax[0].set_xlim(xbins.value.min(), xbins.value.max())
|
|
721
603
|
|
|
722
604
|
ax[0].legend(legend_plots, legend_labels)
|
|
723
|
-
|
|
605
|
+
|
|
606
|
+
match scale:
|
|
607
|
+
case "linear":
|
|
608
|
+
ax[0].set_xscale("linear")
|
|
609
|
+
ax[0].set_yscale("linear")
|
|
610
|
+
case "semilogx":
|
|
611
|
+
ax[0].set_xscale("log")
|
|
612
|
+
ax[0].set_yscale("linear")
|
|
613
|
+
case "semilogy":
|
|
614
|
+
ax[0].set_xscale("linear")
|
|
615
|
+
ax[0].set_yscale("log")
|
|
616
|
+
case "loglog":
|
|
617
|
+
ax[0].set_xscale("log")
|
|
618
|
+
ax[0].set_yscale("log")
|
|
619
|
+
|
|
724
620
|
fig.align_ylabels()
|
|
725
621
|
plt.subplots_adjust(hspace=0.0)
|
|
726
622
|
fig.tight_layout()
|
|
623
|
+
figure_list.append(fig)
|
|
624
|
+
fig.suptitle(f"Posterior predictive - {obs_id}" if title is None else title)
|
|
625
|
+
# fig.show()
|
|
626
|
+
|
|
627
|
+
plt.tight_layout()
|
|
628
|
+
plt.show()
|
|
727
629
|
|
|
728
|
-
|
|
630
|
+
return figure_list
|
|
729
631
|
|
|
730
632
|
def table(self) -> str:
|
|
731
633
|
r"""
|