jaxspec 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/_build_model.py +26 -103
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +219 -330
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +17 -4
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +5 -75
- jaxspec/fit.py +56 -44
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +3 -4
- jaxspec/model/multiplicative.py +101 -85
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/METADATA +12 -13
- jaxspec-0.2.0.dist-info/RECORD +34 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.4.dist-info/RECORD +0 -33
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/entry_points.txt +0 -0
jaxspec/analysis/results.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, Literal, TypeVar

 import arviz as az
```
```diff
@@ -10,17 +10,28 @@ import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 import xarray as xr

 from astropy.cosmology import Cosmology, Planck18
 from astropy.units import Unit
 from chainconsumer import Chain, ChainConsumer, PlotConfig
-from
+from jax.experimental.sparse import BCOO
 from jax.typing import ArrayLike
 from numpyro.handlers import seed
-from scipy.integrate import trapezoid
 from scipy.special import gammaln
-
+
+from ._plot import (
+    BACKGROUND_COLOR,
+    BACKGROUND_DATA_COLOR,
+    COLOR_CYCLE,
+    SPECTRUM_COLOR,
+    SPECTRUM_DATA_COLOR,
+    _compute_effective_area,
+    _error_bars_for_observed_data,
+    _plot_binned_samples_with_error,
+    _plot_poisson_data_with_error,
+)

 if TYPE_CHECKING:
     from ..fit import BayesianModel
```
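The new `BCOO` import is what lets the posterior-predictive code keep the instrument response as a sparse matrix on device. Below is a minimal sketch of that pattern, assuming a SciPy CSR matrix as input; the `rmf` and `flux` arrays are mock stand-ins, not a real response file:

```python
import jax.numpy as jnp
import numpy as np
from jax.experimental.sparse import BCOO
from scipy.sparse import random as sparse_random

rmf = sparse_random(30, 100, density=0.05, format="csr")  # mock response matrix
flux = jnp.ones(100)                                      # mock photon flux per input bin

transfer_matrix = BCOO.from_scipy_sparse(rmf)
count_rate = transfer_matrix @ flux  # sparse matrix-vector product, jit-compatible
np.testing.assert_allclose(count_rate, rmf @ np.ones(100), rtol=1e-5)
```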
```diff
@@ -31,67 +42,6 @@ V = TypeVar("V")
 T = TypeVar("T")


-class HaikuDict(dict[str, dict[str, T]]): ...
-
-
-def _plot_binned_samples_with_error(
-    ax: plt.Axes,
-    x_bins: ArrayLike,
-    denominator: ArrayLike | None = None,
-    y_samples: ArrayLike | None = None,
-    color=(0.15, 0.25, 0.45),
-    percentile: tuple = (16, 84),
-):
-    """
-    Helper function to plot the posterior predictive distribution of the model. The function
-    computes the percentiles of the posterior predictive distribution and plot them as a shaded
-    area. If the observed data is provided, it is also plotted as a step function.
-
-    Parameters
-    ----------
-    x_bins: The bin edges of the data (2 x N).
-    y_samples: The samples of the posterior predictive distribution (Samples X N).
-    denominator: Values used to divided the samples, i.e. to get energy flux (N).
-    ax: The matplotlib axes object.
-    color: The color of the posterior predictive distribution.
-    y_observed: The observed data (N).
-    label: The label of the observed data.
-    percentile: The percentile of the posterior predictive distribution to plot.
-    """
-
-    mean, envelope = None, None
-
-    if denominator is None:
-        denominator = np.ones_like(x_bins[0])
-
-    mean = ax.stairs(
-        list(np.median(y_samples, axis=0) / denominator),
-        edges=[*list(x_bins[0]), x_bins[1][-1]],
-        color=color,
-        alpha=0.7,
-    )
-
-    if y_samples is not None:
-        if denominator is None:
-            denominator = np.ones_like(x_bins[0])
-
-        percentiles = np.percentile(y_samples, percentile, axis=0)
-
-        # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
-        (envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
-
-        ax.stairs(
-            percentiles[1] / denominator,
-            edges=[*list(x_bins[0]), x_bins[1][-1]],
-            baseline=percentiles[0] / denominator,
-            alpha=0.3,
-            fill=True,
-            color=color,
-        )
-
-    return [(mean, envelope)]
-
-
 class FitResult:
     """
     Container for the result of a fit using any ModelFitter class.
```
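The removed helper carried a matplotlib workaround worth noting: legends render poorly for `fill_between`/filled `stairs` handles, so it created an invisible `ax.fill(np.nan, np.nan, ...)` patch purely to get a clean legend swatch. A standalone sketch of that trick, with mock data:

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.linspace(1.0, 10.0, 50)
ax.fill_between(x, np.sin(x) - 0.2, np.sin(x) + 0.2, alpha=0.3, color="C0")

# Invisible patch: its handle draws as a tidy filled icon in the legend
(envelope_handle,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor="C0")
ax.legend([envelope_handle], ["posterior envelope"])
plt.show()
```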
```diff
@@ -102,7 +52,6 @@ class FitResult:
         self,
         bayesian_fitter: BayesianModel,
         inference_data: az.InferenceData,
-        structure: Mapping[K, V],
         background_model: BackgroundModel = None,
     ):
         self.model = bayesian_fitter.model
```
```diff
@@ -110,7 +59,6 @@ class FitResult:
         self.inference_data = inference_data
         self.obsconfs = bayesian_fitter.observation_container
         self.background_model = background_model
-        self._structure = structure

         # Add the model used in fit to the metadata
         for group in self.inference_data.groups():
```
```diff
@@ -127,37 +75,38 @@ class FitResult:

         return all(az.rhat(self.inference_data) < 1.01)

-
-
-        """
-        Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
-        """
-
-        samples_flat = self._structured_samples_flat
+    def _ppc_folded_branches(self, obs_id):
+        obs = self.obsconfs[obs_id]

-
+        if len(next(iter(self.input_parameters.values())).shape) > 2:
+            idx = list(self.obsconfs.keys()).index(obs_id)
+            obs_parameters = jax.tree.map(lambda x: x[..., idx], self.input_parameters)

-
-
-        samples_haiku[module] = {}
-        samples_haiku[module][parameter] = samples_flat[f"{module}_{parameter}"]
+        else:
+            obs_parameters = self.input_parameters

-
+        if self.bayesian_fitter.sparse:
+            transfer_matrix = BCOO.from_scipy_sparse(
+                obs.transfer_matrix.data.to_scipy_sparse().tocsr()
+            )

-
-
-        """
-        Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
-        """
+        else:
+            transfer_matrix = np.asarray(obs.transfer_matrix.data.todense())

-
-        posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
-        samples_flat = {key: posterior[key].data for key in var_names}
+        energies = obs.in_energies

-
+        flux_func = jax.jit(
+            jax.vmap(jax.vmap(lambda p: self.model.photon_flux(p, *energies, split_branches=True)))
+        )
+        convolve_func = jax.jit(
+            jax.vmap(jax.vmap(lambda flux: jnp.clip(transfer_matrix @ flux, a_min=1e-6)))
+        )
+        return jax.tree.map(
+            lambda flux: np.random.poisson(convolve_func(flux)), flux_func(obs_parameters)
+        )

-    @
-    def input_parameters(self) ->
+    @cached_property
+    def input_parameters(self) -> dict[str, ArrayLike]:
         """
         The input parameters of the model.
         """
```
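`_ppc_folded_branches` evaluates the model once per (chain, draw) sample by nesting two `jax.vmap`s, then draws Poisson counts from the folded rates. A toy sketch of the same batching pattern, with a hypothetical `toy_flux` standing in for `model.photon_flux`:

```python
import jax
import jax.numpy as jnp
import numpy as np

def toy_flux(params):  # hypothetical stand-in for model.photon_flux
    return params["norm"] * jnp.ones(100)

params = {"norm": jnp.ones((4, 500))}  # posterior samples shaped (n_chains, n_draws)

flux_func = jax.jit(jax.vmap(jax.vmap(toy_flux)))  # map over draws, then over chains
rates = jnp.clip(flux_func(params), 1e-6)          # floor keeps Poisson rates positive
counts = np.random.poisson(rates)                  # posterior predictive counts
print(counts.shape)                                # (4, 500, 100)
```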
```diff
@@ -173,7 +122,9 @@ class FitResult:
         with seed(rng_seed=0):
             input_parameters = self.bayesian_fitter.prior_distributions_func()

-        for
+        for key, value in input_parameters.items():
+            module, parameter = key.rsplit("_", 1)
+
             if f"{module}_{parameter}" in posterior.keys():
                 # We add as extra dimension as there might be different values per observation
                 if posterior[f"{module}_{parameter}"].shape == samples_shape:
```
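The `rsplit("_", 1)` idiom above relies on jaxspec's flat site naming, where a sample site is named `{module}_{parameter}` (e.g. `powerlaw_1_alpha`, as in the docstrings removed further down). Splitting from the right keeps underscores inside the module name intact:

```python
key = "powerlaw_1_alpha"
module, parameter = key.rsplit("_", 1)
print(module, parameter)  # -> powerlaw_1 alpha
```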
```diff
@@ -181,19 +132,21 @@ class FitResult:
                 else:
                     to_set = posterior[f"{module}_{parameter}"]

-                input_parameters[module
+                input_parameters[f"{module}_{parameter}"] = to_set

             else:
                 # The parameter is fixed in this case, so we just broadcast is over chain and draws
-                input_parameters[module
+                input_parameters[f"{module}_{parameter}"] = value[None, None, ...]

-            if len(total_shape) < len(input_parameters[module
+            if len(total_shape) < len(input_parameters[f"{module}_{parameter}"].shape):
                 # If there are only chains and draws, we reduce
-                input_parameters[module
+                input_parameters[f"{module}_{parameter}"] = input_parameters[
+                    f"{module}_{parameter}"
+                ][..., 0]

             else:
-                input_parameters[module
-                input_parameters[module
+                input_parameters[f"{module}_{parameter}"] = jnp.broadcast_to(
+                    input_parameters[f"{module}_{parameter}"], total_shape
                 )

         return input_parameters
```
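The branches above normalize every parameter to the same leading (chain, draw) shape; fixed parameters get dummy axes and are then broadcast. A minimal sketch with illustrative shapes:

```python
import jax.numpy as jnp

total_shape = (4, 500)          # (n_chains, n_draws) of the posterior
fixed_value = jnp.asarray(1.7)  # a parameter that was frozen during the fit

promoted = jnp.broadcast_to(fixed_value[None, None, ...], total_shape)
print(promoted.shape)           # (4, 500): now aligned with the sampled parameters
```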
```diff
@@ -346,94 +299,46 @@ class FitResult:

         return value

-    def to_chain(self, name: str
+    def to_chain(self, name: str) -> Chain:
         """
         Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.

         Parameters:
             name: The name of the chain.
-            parameters_type: The parameters_type to include in the chain.
         """

-
-
-
-
-                key
-                for key in obs_id.posterior.keys()
-                if (key.startswith("_") or key.startswith("bkg"))
-            ]
-        elif parameters_type == "bkg":
-            keys_to_drop = [key for key in obs_id.posterior.keys() if not key.startswith("bkg")]
-        else:
-            raise ValueError(f"Unknown value for parameters_type: {parameters_type}")
-
-        obs_id.posterior = obs_id.posterior.drop_vars(keys_to_drop)
-        chain = Chain.from_arviz(obs_id, name)
-
-        """
-        chain.samples.columns = [
-            format_parameters(parameter) for parameter in chain.samples.columns
+        keys_to_drop = [
+            key
+            for key in self.inference_data.posterior.keys()
+            if (key.startswith("_") or key.startswith("bkg"))
         ]
-        """

-
+        reduced_id = az.extract(
+            self.inference_data,
+            var_names=[f"~{key}" for key in keys_to_drop] if keys_to_drop else None,
+            group="posterior",
+        )

-
-    def samples_haiku(self) -> HaikuDict[ArrayLike]:
-        """
-        Haiku-like structure for the samples e.g.
-
-        ```
-        {
-            'powerlaw_1' :
-            {
-                'alpha': ...,
-                'amplitude': ...
-            },
-
-            'blackbody_1':
-            {
-                'kT': ...,
-                'norm': ...
-            },
-
-            'tbabs_1':
-            {
-                'nH': ...
-            }
-        }
-        ```
+        df_list = []

-
+        for var, array in reduced_id.data_vars.items():
+            extra_dims = [dim for dim in array.dims if dim not in ["sample"]]

-
+            if extra_dims:
+                dim = extra_dims[
+                    0
+                ]  # We only support the case where the extra dimension comes from the observations

-
-
-
-
+                for coord, obs_id in zip(array.coords[dim], self.obsconfs.keys()):
+                    df = array.loc[{dim: coord}].to_pandas()
+                    df.name += f"\n[{obs_id}]"
+                    df_list.append(df)
+            else:
+                df_list.append(array.to_pandas())

-
+        df = pd.concat(df_list, axis=1)

-
-    def samples_flat(self) -> dict[str, ArrayLike]:
-        """
-        Flat structure for the samples e.g.
-
-        ```
-        {
-            'powerlaw_1_alpha': ...,
-            'powerlaw_1_amplitude': ...,
-            'blackbody_1_kT': ...,
-            'blackbody_1_norm': ...,
-            'tbabs_1_nH': ...,
-        }
-        ```
-        """
-        var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
-        posterior = az.extract(self.inference_data, var_names=var_names)
-        return {key: posterior[key].data for key in var_names}
+        return Chain(samples=df, name=name)

     @property
     def log_likelihood(self) -> xr.Dataset:
```
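The rewritten `to_chain` leans on ArviZ's variable-name negation: passing `var_names=["~name"]` to `az.extract` excludes a variable instead of selecting it. A self-contained sketch with mock posterior variables:

```python
import arviz as az
import numpy as np

idata = az.from_dict(
    posterior={
        "alpha": np.random.randn(4, 500),
        "bkg_norm": np.random.randn(4, 500),  # mock nuisance variable
    }
)

reduced = az.extract(idata, var_names=["~bkg_norm"], group="posterior")
print(list(reduced.data_vars))  # ['alpha']; chain/draw are stacked into a "sample" dim
```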
```diff
@@ -475,12 +380,18 @@ class FitResult:

     def plot_ppc(
         self,
-
+        n_sigmas: int = 1,
         x_unit: str | u.Unit = "keV",
         y_type: Literal[
             "counts", "countrate", "photon_flux", "photon_flux_density"
         ] = "photon_flux_density",
-
+        plot_background: bool = True,
+        plot_components: bool = False,
+        scale: Literal["linear", "semilogx", "semilogy", "loglog"] = "loglog",
+        alpha_envelope: (float, float) = (0.15, 0.25),
+        style: str | Any = "default",
+        title: str | None = None,
+    ) -> list[plt.Figure]:
         r"""
         Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
         following formula:
```
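A hypothetical call against the new signature, assuming `result` is a `FitResult` from a completed fit; the argument names are taken from the diff above:

```python
figs = result.plot_ppc(
    n_sigmas=2,                    # half-width of the posterior envelopes
    x_unit="angstrom",             # anything astropy can parse as a spectral unit
    y_type="photon_flux_density",
    plot_components=True,          # overlay each additive branch of the model
    scale="semilogy",
    title="Posterior predictive check",
)
figs[0].savefig("ppc_obs0.png")    # one figure per observation is returned
```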
```diff
@@ -492,12 +403,18 @@ class FitResult:
            percentile: The percentile of the posterior predictive distribution to plot.
            x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
            y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
+           plot_background: Whether to plot the background model if it is included in the fit.
+           plot_components: Whether to plot the components of the model separately.
+           scale: The axes scaling
+           alpha_envelope: The transparency range for envelops
+           style: The style of the plot. It can be either a string or a matplotlib style context.

         Returns:
-
+            A list of matplotlib figures for each observation in the model.
         """

         obsconf_container = self.obsconfs
+        figure_list = []
         x_unit = u.Unit(x_unit)

         match y_type:
```
```diff
@@ -514,60 +431,24 @@ class FitResult:
                     f"Unknown y_type: {y_type}. Must be 'counts', 'countrate', 'photon_flux' or 'photon_flux_density'"
                 )

-
-
-
-
-
-
-
-
-
-            len(obsconf_container),
-            figsize=(6 * len(obsconf_container), 6),
-            sharex=True,
-            height_ratios=[0.7, 0.3],
-        )
-
-        plot_ylabels_once = True
+        with plt.style.context(style):
+            for obs_id, obsconf in obsconf_container.items():
+                fig, ax = plt.subplots(
+                    2,
+                    1,
+                    figsize=(6, 6),
+                    sharex="col",
+                    height_ratios=[0.7, 0.3],
+                )

-        for name, obsconf, ax in zip(
-            obsconf_container.keys(),
-            obsconf_container.values(),
-            axs.T if len(obsconf_container) > 1 else [axs],
-        ):
                 legend_plots = []
                 legend_labels = []
+
                 count = az.extract(
-                    self.inference_data, var_names=f"obs_{
+                    self.inference_data, var_names=f"obs_{obs_id}", group="posterior_predictive"
                 ).values.T
-                bkg_count = (
-                    None
-                    if self.background_model is None
-                    else az.extract(
-                        self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive"
-                    ).values.T
-                )

-                xbins = obsconf
-                xbins = xbins.to(x_unit, u.spectral())
-
-                # This compute the total effective area within all bins
-                # This is a bit weird since the following computation is equivalent to ignoring the RMF
-                exposure = obsconf.exposure.data * u.s
-                mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
-                mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
-                e_grid = np.linspace(*xbins, 10)
-                interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
-                integrated_arf = (
-                    trapezoid(interpolated_arf, x=e_grid, axis=0)
-                    / (
-                        np.abs(
-                            xbins[1] - xbins[0]
-                        )  # Must fold in abs because some units reverse the ordering of the bins
-                    )
-                    * u.cm**2
-                )
+                xbins, exposure, integrated_arf = _compute_effective_area(obsconf, x_unit)

                 match y_type:
                     case "counts":
```
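The inline effective-area computation removed above (now provided by `_plot._compute_effective_area`) averages the ARF over each output bin: interpolate onto a sub-grid per bin, integrate with the trapezoid rule, and divide by the bin width. A mock-data sketch of that calculation:

```python
import numpy as np
from scipy.integrate import trapezoid

bins = np.array([[1.0, 2.0, 3.0],             # low bin edges (keV)
                 [2.0, 3.0, 4.0]])            # high bin edges (keV)
arf_energy = np.linspace(0.5, 5.0, 50)        # mock ARF energy grid
arf_area = 100.0 * np.exp(-arf_energy / 3.0)  # mock effective area (cm^2)

e_grid = np.linspace(bins[0], bins[1], 10)            # (10, n_bins) sub-grid per bin
interpolated = np.interp(e_grid, arf_energy, arf_area)
mean_area = trapezoid(interpolated, x=e_grid, axis=0) / np.abs(bins[1] - bins[0])
print(mean_area)  # bin-averaged effective area, one value per output bin
```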
```diff
@@ -580,133 +461,123 @@ class FitResult:
                         denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure

                 y_samples = (count * u.ct / denominator).to(y_units)
-
-                y_observed_low = (
-
-
-                    / denominator
-                ).to(y_units)
-                y_observed_high = (
-                    nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
-                    * u.ct
-                    / denominator
-                ).to(y_units)
+
+                y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
+                    obsconf.folded_counts.data, denominator, y_units
+                )

                 # Use the helper function to plot the data and posterior predictive
-
+                model_plot = _plot_binned_samples_with_error(
                     ax[0],
                     xbins.value,
-                    y_samples
-
-
-
+                    y_samples.value,
+                    color=SPECTRUM_COLOR,
+                    n_sigmas=n_sigmas,
+                    alpha_envelope=alpha_envelope,
                 )

-
-
-
-                    np.sqrt(xbins.value[0] * xbins.value[1]),
+                true_data_plot = _plot_poisson_data_with_error(
+                    ax[0],
+                    xbins.value,
                     y_observed.value,
-
-
-
-
-                    ],
-                    color="black",
-                    linestyle="none",
-                    alpha=0.3,
-                    capsize=2,
+                    y_observed_low.value,
+                    y_observed_high.value,
+                    color=SPECTRUM_DATA_COLOR,
+                    alpha=0.7,
                 )

                 legend_plots.append((true_data_plot,))
                 legend_labels.append("Observed")
+                legend_plots += model_plot
+                legend_labels.append("Model")
+
+                # Plot the residuals
+                residual_samples = (obsconf.folded_counts.data - count) / np.diff(
+                    np.percentile(count, [16, 84], axis=0), axis=0
+                )

-
+                _plot_binned_samples_with_error(
+                    ax[1],
+                    xbins.value,
+                    residual_samples,
+                    color=SPECTRUM_COLOR,
+                    n_sigmas=n_sigmas,
+                    alpha_envelope=alpha_envelope,
+                )
+
+                if plot_components:
+                    for (component_name, count), color in zip(
+                        self._ppc_folded_branches(obs_id).items(), COLOR_CYCLE
+                    ):
+                        # _ppc_folded_branches returns (n_chains, n_draws, n_bins) shaped arrays so we must flatten it
+                        y_samples = (
+                            count.reshape((count.shape[0] * count.shape[1], -1))
+                            * u.ct
+                            / denominator
+                        ).to(y_units)
+                        component_plot = _plot_binned_samples_with_error(
+                            ax[0],
+                            xbins.value,
+                            y_samples.value,
+                            color=color,
+                            linestyle="dashdot",
+                            n_sigmas=n_sigmas,
+                            alpha_envelope=alpha_envelope,
+                        )
+
+                        legend_plots += component_plot
+                        legend_labels.append(component_name)
+
+                if self.background_model is not None and plot_background:
                     # We plot the background only if it is included in the fit, i.e. by subtracting
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    bkg_count = (
+                        None
+                        if self.background_model is None
+                        else az.extract(
+                            self.inference_data,
+                            var_names=f"bkg_{obs_id}",
+                            group="posterior_predictive",
+                        ).values.T
+                    )
+
+                    y_samples_bkg = (bkg_count * u.ct / denominator).to(y_units)
+
+                    y_observed_bkg, y_observed_bkg_low, y_observed_bkg_high = (
+                        _error_bars_for_observed_data(
+                            obsconf.folded_background.data, denominator, y_units
+                        )
+                    )
+
+                    model_bkg_plot = _plot_binned_samples_with_error(
                         ax[0],
                         xbins.value,
-
-
-
-
+                        y_samples_bkg.value,
+                        color=BACKGROUND_COLOR,
+                        alpha_envelope=alpha_envelope,
+                        n_sigmas=n_sigmas,
                     )

-
-
-
-                        np.sqrt(xbins.value[0] * xbins.value[1]),
+                    true_bkg_plot = _plot_poisson_data_with_error(
+                        ax[0],
+                        xbins.value,
                         y_observed_bkg.value,
-
-
-
-
-                        ],
-                        color="black",
-                        linestyle="none",
-                        alpha=0.3,
-                        capsize=2,
+                        y_observed_bkg_low.value,
+                        y_observed_bkg_high.value,
+                        color=BACKGROUND_DATA_COLOR,
+                        alpha=0.7,
                     )

                     legend_plots.append((true_bkg_plot,))
                     legend_labels.append("Observed (bkg)")
+                    legend_plots += model_bkg_plot
+                    legend_labels.append("Model (bkg)")

-
-                    np.percentile(count, percentile, axis=0), axis=0
-                )
-
-                residuals = np.percentile(
-                    residual_samples,
-                    percentile,
-                    axis=0,
-                )
-
-                median_residuals = np.median(
-                    residual_samples,
-                    axis=0,
-                )
-
-                ax[1].stairs(
-                    residuals[1],
-                    edges=[*list(xbins.value[0]), xbins.value[1][-1]],
-                    baseline=list(residuals[0]),
-                    alpha=0.3,
-                    facecolor=color,
-                    fill=True,
-                )
-
-                ax[1].stairs(
-                    median_residuals,
-                    edges=[*list(xbins.value[0]), xbins.value[1][-1]],
-                    color=color,
-                    alpha=0.7,
-                )
-
-                max_residuals = np.max(np.abs(residuals))
+                max_residuals = np.max(np.abs(residual_samples))

                 ax[0].loglog()
                 ax[1].set_ylim(-max(3.5, max_residuals), +max(3.5, max_residuals))
-
-
-                ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
-                ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
-                plot_ylabels_once = False
+                ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
+                ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")

                 match getattr(x_unit, "physical_type"):
                     case "length":
```
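The residuals introduced above are defined against the posterior predictive itself: observed counts minus replicated counts, scaled by the 16th-84th percentile spread of the replicates (roughly one sigma). A mock-data sketch of that definition:

```python
import numpy as np

observed = np.random.poisson(50, size=30)            # one spectrum, 30 channels
replicated = np.random.poisson(50, size=(2000, 30))  # posterior predictive draws

spread = np.diff(np.percentile(replicated, [16, 84], axis=0), axis=0)  # (1, 30)
residual_samples = (observed - replicated) / spread  # sigma-like units, (2000, 30)
print(np.median(residual_samples, axis=0)[:5])
```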
```diff
@@ -721,24 +592,42 @@ class FitResult:
                             f"Must be 'length', 'energy' or 'frequency'"
                         )

-                ax[1].axhline(0, color=
-                ax[1].axhline(-3, color=
-                ax[1].axhline(3, color=
+                ax[1].axhline(0, color=SPECTRUM_DATA_COLOR, ls="--")
+                ax[1].axhline(-3, color=SPECTRUM_DATA_COLOR, ls=":")
+                ax[1].axhline(3, color=SPECTRUM_DATA_COLOR, ls=":")

-                # ax[1].set_xticks(xticks, labels=xticks)
-                # ax[1].xaxis.set_minor_formatter(ticker.LogFormatter(minor_thresholds=(np.inf, np.inf)))
                 ax[1].set_yticks([-3, 0, 3], labels=[-3, 0, 3])
                 ax[1].set_yticks(range(-3, 4), minor=True)

                 ax[0].set_xlim(xbins.value.min(), xbins.value.max())

                 ax[0].legend(legend_plots, legend_labels)
-
+
+                match scale:
+                    case "linear":
+                        ax[0].set_xscale("linear")
+                        ax[0].set_yscale("linear")
+                    case "semilogx":
+                        ax[0].set_xscale("log")
+                        ax[0].set_yscale("linear")
+                    case "semilogy":
+                        ax[0].set_xscale("linear")
+                        ax[0].set_yscale("log")
+                    case "loglog":
+                        ax[0].set_xscale("log")
+                        ax[0].set_yscale("log")
+
                 fig.align_ylabels()
                 plt.subplots_adjust(hspace=0.0)
                 fig.tight_layout()
+                figure_list.append(fig)
+                fig.suptitle(f"Posterior predictive - {obs_id}" if title is None else title)
+                # fig.show()
+
+            plt.tight_layout()
+            plt.show()

-
+        return figure_list

     def table(self) -> str:
         r"""
```