jaxspec 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/__init__.py +0 -0
- jaxspec/_fit/_build_model.py +63 -0
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +238 -336
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +68 -11
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +5 -75
- jaxspec/fit.py +101 -140
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +94 -87
- jaxspec/model/multiplicative.py +101 -85
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/METADATA +36 -16
- jaxspec-0.2.0.dist-info/RECORD +34 -0
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.3.dist-info/RECORD +0 -31
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/entry_points.txt +0 -0
jaxspec/_fit/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import numpy as np
|
|
5
|
+
import numpyro
|
|
6
|
+
|
|
7
|
+
from jax.experimental.sparse import BCOO
|
|
8
|
+
from jax.typing import ArrayLike
|
|
9
|
+
from numpyro.distributions import Distribution
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from ..data import ObsConfiguration
|
|
13
|
+
from ..model.abc import SpectralModel
|
|
14
|
+
from ..util.typing import PriorDictType
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def forward_model(
    model: "SpectralModel",
    parameters,
    obs_configuration: "ObsConfiguration",
    sparse=False,
):
    """
    Fold the photon flux of a spectral model through the transfer matrix of an observation.

    Parameters:
        model: The spectral model; its ``photon_flux(parameters, *energies)`` method is evaluated
            on ``obs_configuration.in_energies`` unpacked along the first axis.
        parameters: The model parameters, forwarded unchanged to ``model.photon_flux``.
        obs_configuration: The observational configuration providing ``in_energies`` and
            ``transfer_matrix``.
        sparse: If True, the transfer matrix is converted to a JAX ``BCOO`` sparse matrix before
            the product; otherwise a dense NumPy array is used.

    Returns:
        The expected counts (transfer matrix @ photon flux), floored at 1e-6.
    """
    # Rows of in_energies become the positional energy arguments of photon_flux
    # (presumably lower / upper bin edges — TODO confirm).
    energies = np.asarray(obs_configuration.in_energies)

    if sparse:
        # NOTE(review): the original note says "transfer_matrix.data.density > 0.015 is a good
        # criterion to consider sparsifying"; sparse storage usually pays off at LOW density,
        # so the direction of this threshold should be confirmed.
        transfer_matrix = BCOO.from_scipy_sparse(
            obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
        )
    else:
        transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())

    expected_counts = transfer_matrix @ model.photon_flux(parameters, *energies)

    # The result is floored at 1e-6 to avoid 0 round-off and diverging likelihoods.
    # jnp.maximum is equivalent to jnp.clip(x, a_min=...) but avoids the a_min keyword,
    # which is deprecated in recent JAX releases.
    return jnp.maximum(expected_counts, 1e-6)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
    """
    Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.

    Must be used within a numpyro model.

    Parameters:
        prior: Mapping from parameter name to either a numpyro ``Distribution`` (the parameter
            is sampled under the site name ``f"{prefix}{key}"``) or an ``ArrayLike`` value
            (the parameter is fixed to that value).
        expand_shape: Shape the returned values are broadcast to, through a multiplication by
            ``jnp.ones(expand_shape)``.
        prefix: String prepended to every numpyro sample site name.

    Returns:
        A dictionary with the same keys as ``prior`` holding the sampled or fixed values.

    Raises:
        ValueError: If an entry of ``prior`` is neither a ``Distribution`` nor ``ArrayLike``.
    """
    parameters = {}

    for key, value in prior.items():
        # The sample site name is the key itself with the prefix prepended; no splitting is
        # needed, so keys without an underscore are accepted as well.
        site_name = f"{prefix}{key}"

        if isinstance(value, Distribution):
            parameters[key] = jnp.ones(expand_shape) * numpyro.sample(site_name, value)

        # NOTE(review): isinstance against the ArrayLike union requires Python >= 3.10.
        elif isinstance(value, ArrayLike):
            parameters[key] = jnp.ones(expand_shape) * value

        else:
            raise ValueError(
                f"Invalid prior type {type(value)} for parameter {site_name} : {value}"
            )

    return parameters
|
jaxspec/analysis/_plot.py
CHANGED
|
@@ -1,22 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import catppuccin
|
|
1
4
|
import matplotlib.pyplot as plt
|
|
2
5
|
import numpy as np
|
|
3
6
|
|
|
7
|
+
from astropy import units as u
|
|
8
|
+
from catppuccin.extras.matplotlib import load_color
|
|
9
|
+
from cycler import cycler
|
|
4
10
|
from jax.typing import ArrayLike
|
|
5
|
-
from scipy.
|
|
11
|
+
from scipy.integrate import trapezoid
|
|
12
|
+
from scipy.stats import nbinom, norm
|
|
13
|
+
|
|
14
|
+
from jaxspec.data import ObsConfiguration
|
|
15
|
+
|
|
16
|
+
# Colour theme shared by every plot: the "latte" flavour of the Catppuccin palette
PALETTE = catppuccin.PALETTE.latte

# Colours cycled through when plotting several components, from "mauve" to "sky"
COLOR_CYCLE = [
    load_color(PALETTE.identifier, name)
    for name in ("mauve", "pink", "red", "maroon", "peach", "yellow", "green", "teal", "sky")
]

# Line styles combined with the colours above
LINESTYLE_CYCLE = ["dashed", "dotted", "dashdot", "solid"]

# Outer product of both cycles: every colour is used with one line style before the
# next line style is picked
SPECS_CYCLE = cycler(linestyle=LINESTYLE_CYCLE) * cycler(color=COLOR_CYCLE)

# Default colours for spectrum model / data and background model / data
SPECTRUM_COLOR = load_color(PALETTE.identifier, "blue")
SPECTRUM_DATA_COLOR = load_color(PALETTE.identifier, "overlay2")
BACKGROUND_COLOR = load_color(PALETTE.identifier, "sapphire")
BACKGROUND_DATA_COLOR = load_color(PALETTE.identifier, "overlay0")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def sigma_to_percentile_intervals(sigmas):
    """
    Convert Gaussian significance levels into two-sided percentile intervals.

    Parameters:
        sigmas: Iterable of significance levels, in units of standard deviations.

    Returns:
        A list holding one ``(lower, upper)`` tuple per sigma, where ``lower`` and ``upper`` are
        100 times the standard normal CDF evaluated at ``-sigma`` and ``+sigma``.
    """
    return [(100 * norm.cdf(-s), 100 * norm.cdf(s)) for s in sigmas]
|
|
6
40
|
|
|
7
41
|
|
|
8
42
|
def _plot_poisson_data_with_error(
|
|
9
43
|
ax: plt.Axes,
|
|
10
44
|
x_bins: ArrayLike,
|
|
11
45
|
y: ArrayLike,
|
|
12
|
-
|
|
46
|
+
y_low: ArrayLike,
|
|
47
|
+
y_high: ArrayLike,
|
|
48
|
+
color=SPECTRUM_DATA_COLOR,
|
|
49
|
+
linestyle="none",
|
|
50
|
+
alpha=0.3,
|
|
13
51
|
):
|
|
14
52
|
"""
|
|
15
53
|
Plot Poisson data with error bars. We extrapolate the intrinsic error of the observation assuming a prior rate
|
|
16
54
|
distributed according to a Gamma RV.
|
|
17
55
|
"""
|
|
18
|
-
y_low = nbinom.ppf(percentiles[0] / 100, y, 0.5)
|
|
19
|
-
y_high = nbinom.ppf(percentiles[1] / 100, y, 0.5)
|
|
20
56
|
|
|
21
57
|
ax_to_plot = ax.errorbar(
|
|
22
58
|
np.sqrt(x_bins[0] * x_bins[1]),
|
|
@@ -26,10 +62,133 @@ def _plot_poisson_data_with_error(
|
|
|
26
62
|
y - y_low,
|
|
27
63
|
y_high - y,
|
|
28
64
|
],
|
|
29
|
-
color=
|
|
30
|
-
linestyle=
|
|
31
|
-
alpha=
|
|
65
|
+
color=color,
|
|
66
|
+
linestyle=linestyle,
|
|
67
|
+
alpha=alpha,
|
|
32
68
|
capsize=2,
|
|
33
69
|
)
|
|
34
70
|
|
|
35
71
|
return ax_to_plot
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _plot_binned_samples_with_error(
    ax: plt.Axes,
    x_bins: ArrayLike,
    y_samples: ArrayLike,
    color=SPECTRUM_COLOR,
    alpha_median: float = 0.7,
    alpha_envelope: tuple[float, float] = (0.15, 0.25),
    linestyle="solid",
    n_sigmas=3,
):
    """
    Helper function to plot binned samples (e.g. of a posterior predictive distribution) as a
    median step line surrounded by shaded percentile envelopes.

    Parameters:
        ax: The matplotlib axes object.
        x_bins: The bin edges of the data (2 x N); the steps are drawn on ``x_bins[0]``
            followed by the last value of ``x_bins[1]``.
        y_samples: The samples to summarize (Samples x N).
        color: The color of the median line and of the envelopes.
        alpha_median: Opacity of the median line.
        alpha_envelope: Opacities of the envelopes: the widest (``n_sigmas``) envelope uses
            ``alpha_envelope[0]``, the 1-sigma envelope uses ``alpha_envelope[1]``, and the
            others are linearly interpolated in between.
        linestyle: Line style of the median line.
        n_sigmas: Number of nested envelopes, spanning the normal-equivalent percentile
            intervals from ±``n_sigmas`` down to ±1 sigma.

    Returns:
        A one-element list holding the ``(median, envelope)`` artists, usable as a legend handle.
    """
    # Steps are drawn on the lower edges plus the final upper edge
    edges = [*list(x_bins[0]), x_bins[1][-1]]

    median = ax.stairs(
        list(np.median(y_samples, axis=0)),
        edges=edges,
        color=color,
        alpha=alpha_median,
        linestyle=linestyle,
    )

    # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
    (envelope,) = ax.fill(np.nan, np.nan, alpha=alpha_envelope[-1], facecolor=color)

    # With a single envelope, linspace would only return alpha_envelope[0]; swapping the pair
    # makes the 1-sigma envelope use alpha_envelope[1] whatever n_sigmas is.
    if n_sigmas == 1:
        alpha_envelope = (alpha_envelope[1], alpha_envelope[0])

    # Draw from the widest (n_sigmas) to the narrowest (1 sigma) interval
    for percentile, alpha in zip(
        sigma_to_percentile_intervals(list(range(n_sigmas, 0, -1))),
        np.linspace(*alpha_envelope, n_sigmas),
    ):
        lower, upper = np.percentile(y_samples, percentile, axis=0)
        ax.stairs(
            upper,
            edges=edges,
            baseline=lower,
            alpha=alpha,
            fill=True,
            color=color,
        )

    return [(median, envelope)]
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _compute_effective_area(
    obsconf: ObsConfiguration,
    x_unit: str | u.Unit = "keV",
):
    """
    Helper function to compute the bins and effective area of an observational configuration

    Parameters:
        obsconf: The observational configuration.
        x_unit: The unit of the x-axis. It can be either a string (parsable by astropy.units) or
            an astropy unit. It must be convertible, through astropy's spectral equivalencies,
            to either a length, a frequency or an energy.

    Returns:
        xbins: ``obsconf.out_energies`` (taken as keV) converted to ``x_unit``.
        exposure: ``obsconf.exposure.data`` in seconds.
        integrated_arf: The ARF averaged over each bin of ``xbins``, in cm².
    """

    # Note to Simon : do not change xbins[1] - xbins[0] to
    # np.diff, you already did this twice and forgot that it does not work since diff keeps the dimensions
    # and enable weird broadcasting that makes the plot fail

    xbins = obsconf.out_energies * u.keV
    xbins = xbins.to(x_unit, u.spectral())

    # This computes the total effective area within all bins
    # This is a bit weird since the following computation is equivalent to ignoring the RMF
    exposure = obsconf.exposure.data * u.s
    mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
    mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())

    # np.interp requires increasing sample points, but some units (e.g. wavelengths) reverse
    # the ordering of the ARF grid: sort the grid and the areas together before interpolating.
    order = np.argsort(mid_bins_arf.value)
    mid_bins_arf = mid_bins_arf[order]
    area = np.asarray(obsconf.area)[order]

    e_grid = np.linspace(*xbins, 10)
    interpolated_arf = np.interp(e_grid, mid_bins_arf, area)

    # Mean ARF over each bin = integral over the bin / bin width. When the unit reverses the
    # ordering of the bins, both the integral and the signed width change sign, so the ratio
    # stays positive (dividing by the absolute width would make the area negative).
    integrated_arf = trapezoid(interpolated_arf, x=e_grid, axis=0) / (xbins[1] - xbins[0]) * u.cm**2

    return xbins, exposure, integrated_arf
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _error_bars_for_observed_data(observed_counts, denominator, units, sigma=1):
    r"""
    Compute observed values and their error bars, assuming a prior Gamma distribution.

    The bounds are quantiles of a negative binomial distribution with ``n=observed_counts`` and
    ``p=0.5``, taken at the two-sided normal-equivalent percentiles matching ``sigma``.

    Parameters:
        observed_counts: array of integer counts
        denominator: normalization factor (e.g. effective area)
        units: unit to convert to
        sigma: dispersion to use for quantiles computation

    Returns:
        y_observed: observed counts in the desired units
        y_observed_low: lower bound of the error bars
        y_observed_high: upper bound of the error bars
    """
    lower_pct, upper_pct = sigma_to_percentile_intervals([sigma])[0]

    def _normalize(counts):
        # Attach a count unit, divide by the normalization and convert to the target units
        return (counts * u.ct / denominator).to(units)

    y_observed = _normalize(observed_counts)
    y_observed_low = _normalize(nbinom.ppf(lower_pct / 100, observed_counts, 0.5))
    y_observed_high = _normalize(nbinom.ppf(upper_pct / 100, observed_counts, 0.5))

    return y_observed, y_observed_low, y_observed_high
|