jaxspec 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/_build_model.py +26 -103
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +231 -332
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +17 -4
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +60 -70
- jaxspec/fit.py +76 -44
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +3 -4
- jaxspec/model/multiplicative.py +102 -86
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/METADATA +13 -14
- jaxspec-0.2.1.dist-info/RECORD +34 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.4.dist-info/RECORD +0 -33
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.1.dist-info}/entry_points.txt +0 -0
jaxspec/analysis/results.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from functools import cached_property
|
|
4
4
|
from typing import TYPE_CHECKING, Any, Literal, TypeVar
|
|
5
5
|
|
|
6
6
|
import arviz as az
|
|
@@ -10,17 +10,28 @@ import jax
|
|
|
10
10
|
import jax.numpy as jnp
|
|
11
11
|
import matplotlib.pyplot as plt
|
|
12
12
|
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
13
14
|
import xarray as xr
|
|
14
15
|
|
|
15
16
|
from astropy.cosmology import Cosmology, Planck18
|
|
16
17
|
from astropy.units import Unit
|
|
17
18
|
from chainconsumer import Chain, ChainConsumer, PlotConfig
|
|
18
|
-
from
|
|
19
|
+
from jax.experimental.sparse import BCOO
|
|
19
20
|
from jax.typing import ArrayLike
|
|
20
21
|
from numpyro.handlers import seed
|
|
21
|
-
from scipy.integrate import trapezoid
|
|
22
22
|
from scipy.special import gammaln
|
|
23
|
-
|
|
23
|
+
|
|
24
|
+
from ._plot import (
|
|
25
|
+
BACKGROUND_COLOR,
|
|
26
|
+
BACKGROUND_DATA_COLOR,
|
|
27
|
+
COLOR_CYCLE,
|
|
28
|
+
SPECTRUM_COLOR,
|
|
29
|
+
SPECTRUM_DATA_COLOR,
|
|
30
|
+
_compute_effective_area,
|
|
31
|
+
_error_bars_for_observed_data,
|
|
32
|
+
_plot_binned_samples_with_error,
|
|
33
|
+
_plot_poisson_data_with_error,
|
|
34
|
+
)
|
|
24
35
|
|
|
25
36
|
if TYPE_CHECKING:
|
|
26
37
|
from ..fit import BayesianModel
|
|
@@ -31,67 +42,6 @@ V = TypeVar("V")
|
|
|
31
42
|
T = TypeVar("T")
|
|
32
43
|
|
|
33
44
|
|
|
34
|
-
class HaikuDict(dict[str, dict[str, T]]): ...
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def _plot_binned_samples_with_error(
|
|
38
|
-
ax: plt.Axes,
|
|
39
|
-
x_bins: ArrayLike,
|
|
40
|
-
denominator: ArrayLike | None = None,
|
|
41
|
-
y_samples: ArrayLike | None = None,
|
|
42
|
-
color=(0.15, 0.25, 0.45),
|
|
43
|
-
percentile: tuple = (16, 84),
|
|
44
|
-
):
|
|
45
|
-
"""
|
|
46
|
-
Helper function to plot the posterior predictive distribution of the model. The function
|
|
47
|
-
computes the percentiles of the posterior predictive distribution and plot them as a shaded
|
|
48
|
-
area. If the observed data is provided, it is also plotted as a step function.
|
|
49
|
-
|
|
50
|
-
Parameters
|
|
51
|
-
----------
|
|
52
|
-
x_bins: The bin edges of the data (2 x N).
|
|
53
|
-
y_samples: The samples of the posterior predictive distribution (Samples X N).
|
|
54
|
-
denominator: Values used to divided the samples, i.e. to get energy flux (N).
|
|
55
|
-
ax: The matplotlib axes object.
|
|
56
|
-
color: The color of the posterior predictive distribution.
|
|
57
|
-
y_observed: The observed data (N).
|
|
58
|
-
label: The label of the observed data.
|
|
59
|
-
percentile: The percentile of the posterior predictive distribution to plot.
|
|
60
|
-
"""
|
|
61
|
-
|
|
62
|
-
mean, envelope = None, None
|
|
63
|
-
|
|
64
|
-
if denominator is None:
|
|
65
|
-
denominator = np.ones_like(x_bins[0])
|
|
66
|
-
|
|
67
|
-
mean = ax.stairs(
|
|
68
|
-
list(np.median(y_samples, axis=0) / denominator),
|
|
69
|
-
edges=[*list(x_bins[0]), x_bins[1][-1]],
|
|
70
|
-
color=color,
|
|
71
|
-
alpha=0.7,
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
if y_samples is not None:
|
|
75
|
-
if denominator is None:
|
|
76
|
-
denominator = np.ones_like(x_bins[0])
|
|
77
|
-
|
|
78
|
-
percentiles = np.percentile(y_samples, percentile, axis=0)
|
|
79
|
-
|
|
80
|
-
# The legend cannot handle fill_between, so we pass a fill to get a fancy icon
|
|
81
|
-
(envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
|
|
82
|
-
|
|
83
|
-
ax.stairs(
|
|
84
|
-
percentiles[1] / denominator,
|
|
85
|
-
edges=[*list(x_bins[0]), x_bins[1][-1]],
|
|
86
|
-
baseline=percentiles[0] / denominator,
|
|
87
|
-
alpha=0.3,
|
|
88
|
-
fill=True,
|
|
89
|
-
color=color,
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
return [(mean, envelope)]
|
|
93
|
-
|
|
94
|
-
|
|
95
45
|
class FitResult:
|
|
96
46
|
"""
|
|
97
47
|
Container for the result of a fit using any ModelFitter class.
|
|
@@ -102,7 +52,6 @@ class FitResult:
|
|
|
102
52
|
self,
|
|
103
53
|
bayesian_fitter: BayesianModel,
|
|
104
54
|
inference_data: az.InferenceData,
|
|
105
|
-
structure: Mapping[K, V],
|
|
106
55
|
background_model: BackgroundModel = None,
|
|
107
56
|
):
|
|
108
57
|
self.model = bayesian_fitter.model
|
|
@@ -110,7 +59,6 @@ class FitResult:
|
|
|
110
59
|
self.inference_data = inference_data
|
|
111
60
|
self.obsconfs = bayesian_fitter.observation_container
|
|
112
61
|
self.background_model = background_model
|
|
113
|
-
self._structure = structure
|
|
114
62
|
|
|
115
63
|
# Add the model used in fit to the metadata
|
|
116
64
|
for group in self.inference_data.groups():
|
|
@@ -127,37 +75,38 @@ class FitResult:
|
|
|
127
75
|
|
|
128
76
|
return all(az.rhat(self.inference_data) < 1.01)
|
|
129
77
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
"""
|
|
133
|
-
Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
|
|
134
|
-
"""
|
|
135
|
-
|
|
136
|
-
samples_flat = self._structured_samples_flat
|
|
78
|
+
def _ppc_folded_branches(self, obs_id):
|
|
79
|
+
obs = self.obsconfs[obs_id]
|
|
137
80
|
|
|
138
|
-
|
|
81
|
+
if len(next(iter(self.input_parameters.values())).shape) > 2:
|
|
82
|
+
idx = list(self.obsconfs.keys()).index(obs_id)
|
|
83
|
+
obs_parameters = jax.tree.map(lambda x: x[..., idx], self.input_parameters)
|
|
139
84
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
samples_haiku[module] = {}
|
|
143
|
-
samples_haiku[module][parameter] = samples_flat[f"{module}_{parameter}"]
|
|
85
|
+
else:
|
|
86
|
+
obs_parameters = self.input_parameters
|
|
144
87
|
|
|
145
|
-
|
|
88
|
+
if self.bayesian_fitter.sparse:
|
|
89
|
+
transfer_matrix = BCOO.from_scipy_sparse(
|
|
90
|
+
obs.transfer_matrix.data.to_scipy_sparse().tocsr()
|
|
91
|
+
)
|
|
146
92
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
"""
|
|
150
|
-
Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
|
|
151
|
-
"""
|
|
93
|
+
else:
|
|
94
|
+
transfer_matrix = np.asarray(obs.transfer_matrix.data.todense())
|
|
152
95
|
|
|
153
|
-
|
|
154
|
-
posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
|
|
155
|
-
samples_flat = {key: posterior[key].data for key in var_names}
|
|
96
|
+
energies = obs.in_energies
|
|
156
97
|
|
|
157
|
-
|
|
98
|
+
flux_func = jax.jit(
|
|
99
|
+
jax.vmap(jax.vmap(lambda p: self.model.photon_flux(p, *energies, split_branches=True)))
|
|
100
|
+
)
|
|
101
|
+
convolve_func = jax.jit(
|
|
102
|
+
jax.vmap(jax.vmap(lambda flux: jnp.clip(transfer_matrix @ flux, a_min=1e-6)))
|
|
103
|
+
)
|
|
104
|
+
return jax.tree.map(
|
|
105
|
+
lambda flux: np.random.poisson(convolve_func(flux)), flux_func(obs_parameters)
|
|
106
|
+
)
|
|
158
107
|
|
|
159
|
-
@
|
|
160
|
-
def input_parameters(self) ->
|
|
108
|
+
@cached_property
|
|
109
|
+
def input_parameters(self) -> dict[str, ArrayLike]:
|
|
161
110
|
"""
|
|
162
111
|
The input parameters of the model.
|
|
163
112
|
"""
|
|
@@ -173,7 +122,9 @@ class FitResult:
|
|
|
173
122
|
with seed(rng_seed=0):
|
|
174
123
|
input_parameters = self.bayesian_fitter.prior_distributions_func()
|
|
175
124
|
|
|
176
|
-
for
|
|
125
|
+
for key, value in input_parameters.items():
|
|
126
|
+
module, parameter = key.rsplit("_", 1)
|
|
127
|
+
|
|
177
128
|
if f"{module}_{parameter}" in posterior.keys():
|
|
178
129
|
# We add as extra dimension as there might be different values per observation
|
|
179
130
|
if posterior[f"{module}_{parameter}"].shape == samples_shape:
|
|
@@ -181,19 +132,21 @@ class FitResult:
|
|
|
181
132
|
else:
|
|
182
133
|
to_set = posterior[f"{module}_{parameter}"]
|
|
183
134
|
|
|
184
|
-
input_parameters[module
|
|
135
|
+
input_parameters[f"{module}_{parameter}"] = to_set
|
|
185
136
|
|
|
186
137
|
else:
|
|
187
138
|
# The parameter is fixed in this case, so we just broadcast is over chain and draws
|
|
188
|
-
input_parameters[module
|
|
139
|
+
input_parameters[f"{module}_{parameter}"] = value[None, None, ...]
|
|
189
140
|
|
|
190
|
-
if len(total_shape) < len(input_parameters[module
|
|
141
|
+
if len(total_shape) < len(input_parameters[f"{module}_{parameter}"].shape):
|
|
191
142
|
# If there are only chains and draws, we reduce
|
|
192
|
-
input_parameters[module
|
|
143
|
+
input_parameters[f"{module}_{parameter}"] = input_parameters[
|
|
144
|
+
f"{module}_{parameter}"
|
|
145
|
+
][..., 0]
|
|
193
146
|
|
|
194
147
|
else:
|
|
195
|
-
input_parameters[module
|
|
196
|
-
input_parameters[module
|
|
148
|
+
input_parameters[f"{module}_{parameter}"] = jnp.broadcast_to(
|
|
149
|
+
input_parameters[f"{module}_{parameter}"], total_shape
|
|
197
150
|
)
|
|
198
151
|
|
|
199
152
|
return input_parameters
|
|
@@ -346,94 +299,46 @@ class FitResult:
|
|
|
346
299
|
|
|
347
300
|
return value
|
|
348
301
|
|
|
349
|
-
def to_chain(self, name: str
|
|
302
|
+
def to_chain(self, name: str) -> Chain:
|
|
350
303
|
"""
|
|
351
304
|
Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.
|
|
352
305
|
|
|
353
306
|
Parameters:
|
|
354
307
|
name: The name of the chain.
|
|
355
|
-
parameters_type: The parameters_type to include in the chain.
|
|
356
308
|
"""
|
|
357
309
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
key
|
|
363
|
-
for key in obs_id.posterior.keys()
|
|
364
|
-
if (key.startswith("_") or key.startswith("bkg"))
|
|
365
|
-
]
|
|
366
|
-
elif parameters_type == "bkg":
|
|
367
|
-
keys_to_drop = [key for key in obs_id.posterior.keys() if not key.startswith("bkg")]
|
|
368
|
-
else:
|
|
369
|
-
raise ValueError(f"Unknown value for parameters_type: {parameters_type}")
|
|
370
|
-
|
|
371
|
-
obs_id.posterior = obs_id.posterior.drop_vars(keys_to_drop)
|
|
372
|
-
chain = Chain.from_arviz(obs_id, name)
|
|
373
|
-
|
|
374
|
-
"""
|
|
375
|
-
chain.samples.columns = [
|
|
376
|
-
format_parameters(parameter) for parameter in chain.samples.columns
|
|
310
|
+
keys_to_drop = [
|
|
311
|
+
key
|
|
312
|
+
for key in self.inference_data.posterior.keys()
|
|
313
|
+
if (key.startswith("_") or key.startswith("bkg"))
|
|
377
314
|
]
|
|
378
|
-
"""
|
|
379
315
|
|
|
380
|
-
|
|
316
|
+
reduced_id = az.extract(
|
|
317
|
+
self.inference_data,
|
|
318
|
+
var_names=[f"~{key}" for key in keys_to_drop] if keys_to_drop else None,
|
|
319
|
+
group="posterior",
|
|
320
|
+
)
|
|
381
321
|
|
|
382
|
-
|
|
383
|
-
def samples_haiku(self) -> HaikuDict[ArrayLike]:
|
|
384
|
-
"""
|
|
385
|
-
Haiku-like structure for the samples e.g.
|
|
386
|
-
|
|
387
|
-
```
|
|
388
|
-
{
|
|
389
|
-
'powerlaw_1' :
|
|
390
|
-
{
|
|
391
|
-
'alpha': ...,
|
|
392
|
-
'amplitude': ...
|
|
393
|
-
},
|
|
394
|
-
|
|
395
|
-
'blackbody_1':
|
|
396
|
-
{
|
|
397
|
-
'kT': ...,
|
|
398
|
-
'norm': ...
|
|
399
|
-
},
|
|
400
|
-
|
|
401
|
-
'tbabs_1':
|
|
402
|
-
{
|
|
403
|
-
'nH': ...
|
|
404
|
-
}
|
|
405
|
-
}
|
|
406
|
-
```
|
|
322
|
+
df_list = []
|
|
407
323
|
|
|
408
|
-
|
|
324
|
+
for var, array in reduced_id.data_vars.items():
|
|
325
|
+
extra_dims = [dim for dim in array.dims if dim not in ["sample"]]
|
|
409
326
|
|
|
410
|
-
|
|
327
|
+
if extra_dims:
|
|
328
|
+
dim = extra_dims[
|
|
329
|
+
0
|
|
330
|
+
] # We only support the case where the extra dimension comes from the observations
|
|
411
331
|
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
332
|
+
for coord, obs_id in zip(array.coords[dim], self.obsconfs.keys()):
|
|
333
|
+
df = array.loc[{dim: coord}].to_pandas()
|
|
334
|
+
df.name += f"\n[{obs_id}]"
|
|
335
|
+
df_list.append(df)
|
|
336
|
+
else:
|
|
337
|
+
df_list.append(array.to_pandas())
|
|
416
338
|
|
|
417
|
-
|
|
339
|
+
df = pd.concat(df_list, axis=1)
|
|
418
340
|
|
|
419
|
-
|
|
420
|
-
def samples_flat(self) -> dict[str, ArrayLike]:
|
|
421
|
-
"""
|
|
422
|
-
Flat structure for the samples e.g.
|
|
423
|
-
|
|
424
|
-
```
|
|
425
|
-
{
|
|
426
|
-
'powerlaw_1_alpha': ...,
|
|
427
|
-
'powerlaw_1_amplitude': ...,
|
|
428
|
-
'blackbody_1_kT': ...,
|
|
429
|
-
'blackbody_1_norm': ...,
|
|
430
|
-
'tbabs_1_nH': ...,
|
|
431
|
-
}
|
|
432
|
-
```
|
|
433
|
-
"""
|
|
434
|
-
var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
|
|
435
|
-
posterior = az.extract(self.inference_data, var_names=var_names)
|
|
436
|
-
return {key: posterior[key].data for key in var_names}
|
|
341
|
+
return Chain(samples=df, name=name)
|
|
437
342
|
|
|
438
343
|
@property
|
|
439
344
|
def log_likelihood(self) -> xr.Dataset:
|
|
@@ -475,12 +380,20 @@ class FitResult:
|
|
|
475
380
|
|
|
476
381
|
def plot_ppc(
|
|
477
382
|
self,
|
|
478
|
-
|
|
383
|
+
n_sigmas: int = 1,
|
|
479
384
|
x_unit: str | u.Unit = "keV",
|
|
480
385
|
y_type: Literal[
|
|
481
386
|
"counts", "countrate", "photon_flux", "photon_flux_density"
|
|
482
387
|
] = "photon_flux_density",
|
|
483
|
-
|
|
388
|
+
plot_background: bool = True,
|
|
389
|
+
plot_components: bool = False,
|
|
390
|
+
scale: Literal["linear", "semilogx", "semilogy", "loglog"] = "loglog",
|
|
391
|
+
alpha_envelope: (float, float) = (0.15, 0.25),
|
|
392
|
+
style: str | Any = "default",
|
|
393
|
+
title: str | None = None,
|
|
394
|
+
figsize: tuple[float, float] = (6, 6),
|
|
395
|
+
x_lims: tuple[float, float] | None = None,
|
|
396
|
+
) -> list[plt.Figure]:
|
|
484
397
|
r"""
|
|
485
398
|
Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
|
|
486
399
|
following formula:
|
|
@@ -489,15 +402,24 @@ class FitResult:
|
|
|
489
402
|
{(\text{Posterior counts})_{84\%}-(\text{Posterior counts})_{16\%}} $$
|
|
490
403
|
|
|
491
404
|
Parameters:
|
|
492
|
-
|
|
405
|
+
n_sigmas: The number of sigmas to plot the envelops.
|
|
493
406
|
x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
|
|
494
407
|
y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
|
|
408
|
+
plot_background: Whether to plot the background model if it is included in the fit.
|
|
409
|
+
plot_components: Whether to plot the components of the model separately.
|
|
410
|
+
scale: The axes scaling
|
|
411
|
+
alpha_envelope: The transparency range for envelops
|
|
412
|
+
style: The style of the plot. It can be either a string or a matplotlib style context.
|
|
413
|
+
title: The title of the plot.
|
|
414
|
+
figsize: The size of the figure.
|
|
415
|
+
x_lims: The limits of the x-axis.
|
|
495
416
|
|
|
496
417
|
Returns:
|
|
497
|
-
|
|
418
|
+
A list of matplotlib figures for each observation in the model.
|
|
498
419
|
"""
|
|
499
420
|
|
|
500
421
|
obsconf_container = self.obsconfs
|
|
422
|
+
figure_list = []
|
|
501
423
|
x_unit = u.Unit(x_unit)
|
|
502
424
|
|
|
503
425
|
match y_type:
|
|
@@ -514,60 +436,24 @@ class FitResult:
|
|
|
514
436
|
f"Unknown y_type: {y_type}. Must be 'counts', 'countrate', 'photon_flux' or 'photon_flux_density'"
|
|
515
437
|
)
|
|
516
438
|
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
len(obsconf_container),
|
|
527
|
-
figsize=(6 * len(obsconf_container), 6),
|
|
528
|
-
sharex=True,
|
|
529
|
-
height_ratios=[0.7, 0.3],
|
|
530
|
-
)
|
|
531
|
-
|
|
532
|
-
plot_ylabels_once = True
|
|
439
|
+
with plt.style.context(style):
|
|
440
|
+
for obs_id, obsconf in obsconf_container.items():
|
|
441
|
+
fig, ax = plt.subplots(
|
|
442
|
+
2,
|
|
443
|
+
1,
|
|
444
|
+
figsize=figsize,
|
|
445
|
+
sharex="col",
|
|
446
|
+
height_ratios=[0.7, 0.3],
|
|
447
|
+
)
|
|
533
448
|
|
|
534
|
-
for name, obsconf, ax in zip(
|
|
535
|
-
obsconf_container.keys(),
|
|
536
|
-
obsconf_container.values(),
|
|
537
|
-
axs.T if len(obsconf_container) > 1 else [axs],
|
|
538
|
-
):
|
|
539
449
|
legend_plots = []
|
|
540
450
|
legend_labels = []
|
|
451
|
+
|
|
541
452
|
count = az.extract(
|
|
542
|
-
self.inference_data, var_names=f"obs_{
|
|
453
|
+
self.inference_data, var_names=f"obs_{obs_id}", group="posterior_predictive"
|
|
543
454
|
).values.T
|
|
544
|
-
bkg_count = (
|
|
545
|
-
None
|
|
546
|
-
if self.background_model is None
|
|
547
|
-
else az.extract(
|
|
548
|
-
self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive"
|
|
549
|
-
).values.T
|
|
550
|
-
)
|
|
551
455
|
|
|
552
|
-
xbins = obsconf
|
|
553
|
-
xbins = xbins.to(x_unit, u.spectral())
|
|
554
|
-
|
|
555
|
-
# This compute the total effective area within all bins
|
|
556
|
-
# This is a bit weird since the following computation is equivalent to ignoring the RMF
|
|
557
|
-
exposure = obsconf.exposure.data * u.s
|
|
558
|
-
mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
|
|
559
|
-
mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
|
|
560
|
-
e_grid = np.linspace(*xbins, 10)
|
|
561
|
-
interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
|
|
562
|
-
integrated_arf = (
|
|
563
|
-
trapezoid(interpolated_arf, x=e_grid, axis=0)
|
|
564
|
-
/ (
|
|
565
|
-
np.abs(
|
|
566
|
-
xbins[1] - xbins[0]
|
|
567
|
-
) # Must fold in abs because some units reverse the ordering of the bins
|
|
568
|
-
)
|
|
569
|
-
* u.cm**2
|
|
570
|
-
)
|
|
456
|
+
xbins, exposure, integrated_arf = _compute_effective_area(obsconf, x_unit)
|
|
571
457
|
|
|
572
458
|
match y_type:
|
|
573
459
|
case "counts":
|
|
@@ -580,133 +466,125 @@ class FitResult:
|
|
|
580
466
|
denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure
|
|
581
467
|
|
|
582
468
|
y_samples = (count * u.ct / denominator).to(y_units)
|
|
583
|
-
|
|
584
|
-
y_observed_low = (
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
/ denominator
|
|
588
|
-
).to(y_units)
|
|
589
|
-
y_observed_high = (
|
|
590
|
-
nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
|
|
591
|
-
* u.ct
|
|
592
|
-
/ denominator
|
|
593
|
-
).to(y_units)
|
|
469
|
+
|
|
470
|
+
y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
|
|
471
|
+
obsconf.folded_counts.data, denominator, y_units
|
|
472
|
+
)
|
|
594
473
|
|
|
595
474
|
# Use the helper function to plot the data and posterior predictive
|
|
596
|
-
|
|
475
|
+
model_plot = _plot_binned_samples_with_error(
|
|
597
476
|
ax[0],
|
|
598
477
|
xbins.value,
|
|
599
|
-
y_samples
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
478
|
+
y_samples.value,
|
|
479
|
+
color=SPECTRUM_COLOR,
|
|
480
|
+
n_sigmas=n_sigmas,
|
|
481
|
+
alpha_envelope=alpha_envelope,
|
|
603
482
|
)
|
|
604
483
|
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
np.sqrt(xbins.value[0] * xbins.value[1]),
|
|
484
|
+
true_data_plot = _plot_poisson_data_with_error(
|
|
485
|
+
ax[0],
|
|
486
|
+
xbins.value,
|
|
609
487
|
y_observed.value,
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
],
|
|
615
|
-
color="black",
|
|
616
|
-
linestyle="none",
|
|
617
|
-
alpha=0.3,
|
|
618
|
-
capsize=2,
|
|
488
|
+
y_observed_low.value,
|
|
489
|
+
y_observed_high.value,
|
|
490
|
+
color=SPECTRUM_DATA_COLOR,
|
|
491
|
+
alpha=0.7,
|
|
619
492
|
)
|
|
620
493
|
|
|
621
494
|
legend_plots.append((true_data_plot,))
|
|
622
495
|
legend_labels.append("Observed")
|
|
496
|
+
legend_plots += model_plot
|
|
497
|
+
legend_labels.append("Model")
|
|
498
|
+
|
|
499
|
+
# Plot the residuals
|
|
500
|
+
residual_samples = (obsconf.folded_counts.data - count) / np.diff(
|
|
501
|
+
np.percentile(count, [16, 84], axis=0), axis=0
|
|
502
|
+
)
|
|
503
|
+
|
|
504
|
+
_plot_binned_samples_with_error(
|
|
505
|
+
ax[1],
|
|
506
|
+
xbins.value,
|
|
507
|
+
residual_samples,
|
|
508
|
+
color=SPECTRUM_COLOR,
|
|
509
|
+
n_sigmas=n_sigmas,
|
|
510
|
+
alpha_envelope=alpha_envelope,
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
if plot_components:
|
|
514
|
+
for (component_name, count), color in zip(
|
|
515
|
+
self._ppc_folded_branches(obs_id).items(), COLOR_CYCLE
|
|
516
|
+
):
|
|
517
|
+
# _ppc_folded_branches returns (n_chains, n_draws, n_bins) shaped arrays so we must flatten it
|
|
518
|
+
y_samples = (
|
|
519
|
+
count.reshape((count.shape[0] * count.shape[1], -1))
|
|
520
|
+
* u.ct
|
|
521
|
+
/ denominator
|
|
522
|
+
).to(y_units)
|
|
523
|
+
component_plot = _plot_binned_samples_with_error(
|
|
524
|
+
ax[0],
|
|
525
|
+
xbins.value,
|
|
526
|
+
y_samples.value,
|
|
527
|
+
color=color,
|
|
528
|
+
linestyle="dashdot",
|
|
529
|
+
n_sigmas=n_sigmas,
|
|
530
|
+
alpha_envelope=alpha_envelope,
|
|
531
|
+
)
|
|
623
532
|
|
|
624
|
-
|
|
533
|
+
name = component_name.split("*")[-1]
|
|
534
|
+
|
|
535
|
+
legend_plots += component_plot
|
|
536
|
+
legend_labels.append(name)
|
|
537
|
+
|
|
538
|
+
if self.background_model is not None and plot_background:
|
|
625
539
|
# We plot the background only if it is included in the fit, i.e. by subtracting
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
540
|
+
bkg_count = (
|
|
541
|
+
None
|
|
542
|
+
if self.background_model is None
|
|
543
|
+
else az.extract(
|
|
544
|
+
self.inference_data,
|
|
545
|
+
var_names=f"bkg_{obs_id}",
|
|
546
|
+
group="posterior_predictive",
|
|
547
|
+
).values.T
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
y_samples_bkg = (bkg_count * u.ct / denominator).to(y_units)
|
|
551
|
+
|
|
552
|
+
y_observed_bkg, y_observed_bkg_low, y_observed_bkg_high = (
|
|
553
|
+
_error_bars_for_observed_data(
|
|
554
|
+
obsconf.folded_background.data, denominator, y_units
|
|
555
|
+
)
|
|
556
|
+
)
|
|
557
|
+
|
|
558
|
+
model_bkg_plot = _plot_binned_samples_with_error(
|
|
643
559
|
ax[0],
|
|
644
560
|
xbins.value,
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
561
|
+
y_samples_bkg.value,
|
|
562
|
+
color=BACKGROUND_COLOR,
|
|
563
|
+
alpha_envelope=alpha_envelope,
|
|
564
|
+
n_sigmas=n_sigmas,
|
|
649
565
|
)
|
|
650
566
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
np.sqrt(xbins.value[0] * xbins.value[1]),
|
|
567
|
+
true_bkg_plot = _plot_poisson_data_with_error(
|
|
568
|
+
ax[0],
|
|
569
|
+
xbins.value,
|
|
655
570
|
y_observed_bkg.value,
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
],
|
|
661
|
-
color="black",
|
|
662
|
-
linestyle="none",
|
|
663
|
-
alpha=0.3,
|
|
664
|
-
capsize=2,
|
|
571
|
+
y_observed_bkg_low.value,
|
|
572
|
+
y_observed_bkg_high.value,
|
|
573
|
+
color=BACKGROUND_DATA_COLOR,
|
|
574
|
+
alpha=0.7,
|
|
665
575
|
)
|
|
666
576
|
|
|
667
577
|
legend_plots.append((true_bkg_plot,))
|
|
668
578
|
legend_labels.append("Observed (bkg)")
|
|
579
|
+
legend_plots += model_bkg_plot
|
|
580
|
+
legend_labels.append("Model (bkg)")
|
|
669
581
|
|
|
670
|
-
|
|
671
|
-
np.percentile(count, percentile, axis=0), axis=0
|
|
672
|
-
)
|
|
673
|
-
|
|
674
|
-
residuals = np.percentile(
|
|
675
|
-
residual_samples,
|
|
676
|
-
percentile,
|
|
677
|
-
axis=0,
|
|
678
|
-
)
|
|
679
|
-
|
|
680
|
-
median_residuals = np.median(
|
|
681
|
-
residual_samples,
|
|
682
|
-
axis=0,
|
|
683
|
-
)
|
|
684
|
-
|
|
685
|
-
ax[1].stairs(
|
|
686
|
-
residuals[1],
|
|
687
|
-
edges=[*list(xbins.value[0]), xbins.value[1][-1]],
|
|
688
|
-
baseline=list(residuals[0]),
|
|
689
|
-
alpha=0.3,
|
|
690
|
-
facecolor=color,
|
|
691
|
-
fill=True,
|
|
692
|
-
)
|
|
693
|
-
|
|
694
|
-
ax[1].stairs(
|
|
695
|
-
median_residuals,
|
|
696
|
-
edges=[*list(xbins.value[0]), xbins.value[1][-1]],
|
|
697
|
-
color=color,
|
|
698
|
-
alpha=0.7,
|
|
699
|
-
)
|
|
700
|
-
|
|
701
|
-
max_residuals = np.max(np.abs(residuals))
|
|
582
|
+
max_residuals = np.max(np.abs(residual_samples))
|
|
702
583
|
|
|
703
584
|
ax[0].loglog()
|
|
704
585
|
ax[1].set_ylim(-max(3.5, max_residuals), +max(3.5, max_residuals))
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
|
|
708
|
-
ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
|
|
709
|
-
plot_ylabels_once = False
|
|
586
|
+
ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
|
|
587
|
+
ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
|
|
710
588
|
|
|
711
589
|
match getattr(x_unit, "physical_type"):
|
|
712
590
|
case "length":
|
|
@@ -721,24 +599,45 @@ class FitResult:
|
|
|
721
599
|
f"Must be 'length', 'energy' or 'frequency'"
|
|
722
600
|
)
|
|
723
601
|
|
|
724
|
-
ax[1].axhline(0, color=
|
|
725
|
-
ax[1].axhline(-3, color=
|
|
726
|
-
ax[1].axhline(3, color=
|
|
602
|
+
ax[1].axhline(0, color=SPECTRUM_DATA_COLOR, ls="--")
|
|
603
|
+
ax[1].axhline(-3, color=SPECTRUM_DATA_COLOR, ls=":")
|
|
604
|
+
ax[1].axhline(3, color=SPECTRUM_DATA_COLOR, ls=":")
|
|
727
605
|
|
|
728
|
-
# ax[1].set_xticks(xticks, labels=xticks)
|
|
729
|
-
# ax[1].xaxis.set_minor_formatter(ticker.LogFormatter(minor_thresholds=(np.inf, np.inf)))
|
|
730
606
|
ax[1].set_yticks([-3, 0, 3], labels=[-3, 0, 3])
|
|
731
607
|
ax[1].set_yticks(range(-3, 4), minor=True)
|
|
732
608
|
|
|
733
609
|
ax[0].set_xlim(xbins.value.min(), xbins.value.max())
|
|
734
610
|
|
|
735
611
|
ax[0].legend(legend_plots, legend_labels)
|
|
736
|
-
|
|
612
|
+
|
|
613
|
+
match scale:
|
|
614
|
+
case "linear":
|
|
615
|
+
ax[0].set_xscale("linear")
|
|
616
|
+
ax[0].set_yscale("linear")
|
|
617
|
+
case "semilogx":
|
|
618
|
+
ax[0].set_xscale("log")
|
|
619
|
+
ax[0].set_yscale("linear")
|
|
620
|
+
case "semilogy":
|
|
621
|
+
ax[0].set_xscale("linear")
|
|
622
|
+
ax[0].set_yscale("log")
|
|
623
|
+
case "loglog":
|
|
624
|
+
ax[0].set_xscale("log")
|
|
625
|
+
ax[0].set_yscale("log")
|
|
626
|
+
|
|
627
|
+
if x_lims is not None:
|
|
628
|
+
ax[0].set_xlim(*x_lims)
|
|
629
|
+
|
|
737
630
|
fig.align_ylabels()
|
|
738
631
|
plt.subplots_adjust(hspace=0.0)
|
|
739
632
|
fig.tight_layout()
|
|
633
|
+
figure_list.append(fig)
|
|
634
|
+
fig.suptitle(f"Posterior predictive - {obs_id}" if title is None else title)
|
|
635
|
+
# fig.show()
|
|
636
|
+
|
|
637
|
+
plt.tight_layout()
|
|
638
|
+
plt.show()
|
|
740
639
|
|
|
741
|
-
|
|
640
|
+
return figure_list
|
|
742
641
|
|
|
743
642
|
def table(self) -> str:
|
|
744
643
|
r"""
|
|
@@ -765,7 +664,7 @@ class FitResult:
|
|
|
765
664
|
"""
|
|
766
665
|
|
|
767
666
|
consumer = ChainConsumer()
|
|
768
|
-
consumer.add_chain(self.to_chain(
|
|
667
|
+
consumer.add_chain(self.to_chain("Results"))
|
|
769
668
|
consumer.set_plot_config(config)
|
|
770
669
|
|
|
771
670
|
# Context for default mpl style
|