jaxspec 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/_build_model.py +26 -103
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +219 -330
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +17 -4
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +5 -75
- jaxspec/fit.py +56 -44
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +3 -4
- jaxspec/model/multiplicative.py +101 -85
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/METADATA +12 -13
- jaxspec-0.2.0.dist-info/RECORD +34 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.4.dist-info/RECORD +0 -33
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/entry_points.txt +0 -0
jaxspec/_fit/_build_model.py
CHANGED
@@ -1,62 +1,25 @@
-import
-
-import haiku as hk
-import numpy as np
+from typing import TYPE_CHECKING
+
 import jax.numpy as jnp
-
+import numpy as np
+import numpyro
+
 from jax.experimental.sparse import BCOO
-from typing import TYPE_CHECKING
-from numpyro.distributions import Poisson
 from jax.typing import ArrayLike
 from numpyro.distributions import Distribution
 
-
 if TYPE_CHECKING:
-    from ..model.abc import SpectralModel
     from ..data import ObsConfiguration
-    from ..
-
-
-
-class CountForwardModel(hk.Module):
-    """
-    A haiku module which allows to build the function that simulates the measured counts
-    """
-
-    # TODO: It has no point of being a haiku module, it should be a simple function
-
-    def __init__(self, model: 'SpectralModel', folding: 'ObsConfiguration', sparse=False):
-        super().__init__()
-        self.model = model
-        self.energies = jnp.asarray(folding.in_energies)
-
-        if (
-            sparse
-        ):  # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
-            self.transfer_matrix = BCOO.from_scipy_sparse(
-                folding.transfer_matrix.data.to_scipy_sparse().tocsr()
-            )
-
-        else:
-            self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
-
-    def __call__(self, parameters):
-        """
-        Compute the count functions for a given observation.
-        """
-
-        expected_counts = self.transfer_matrix @ self.model.photon_flux(parameters, *self.energies)
-
-        return jnp.clip(expected_counts, a_min=1e-6)
+    from ..model.abc import SpectralModel
+    from ..util.typing import PriorDictType
 
 
 def forward_model(
-
-
-
-
-
+    model: "SpectralModel",
+    parameters,
+    obs_configuration: "ObsConfiguration",
+    sparse=False,
+):
     energies = np.asarray(obs_configuration.in_energies)
 
     if sparse:
@@ -74,67 +37,27 @@ def forward_model(
     return jnp.clip(expected_counts, a_min=1e-6)
 
 
-def
-    obs,
-    model,
-    background_model,
-    name: str = "",
-    sparse: bool = False,
-) -> Callable:
-    """
-    Build a numpyro model for a given observation and spectral model.
-    """
-
-    def numpyro_model(prior_params, observed=True):
-
-        # Return the expected countrate for a set of parameters
-        obs_model = jax.jit(lambda par: forward_model(model, par, obs, sparse=sparse))
-        countrate = obs_model(prior_params)
-
-        # Handle the background model
-        if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
-            bkg_countrate = background_model.numpyro_model(
-                obs, model, name="bkg_" + name, observed=observed
-            )
-
-        elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
-            raise ValueError(
-                "Trying to fit a background model but no background is linked to this observation"
-            )
-
-        else:
-            bkg_countrate = 0.0
-
-
-        # Register the observed value
-        # This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
-        with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
-            numpyro.sample(
-                "obs_" + name,
-                Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
-                obs=obs.folded_counts.data if observed else None,
-            )
-
-    return numpyro_model
-
-
-def build_prior(prior: 'PriorDictType', expand_shape: tuple = (), prefix=""):
+def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
     """
     Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
     Must be used within a numpyro model.
     """
-    parameters =
-
-    for
-
-
+    parameters = {}
+
+    for key, value in prior.items():
+        # Split the key to extract the module name and parameter name
+        module_name, param_name = key.rsplit("_", 1)
+        if isinstance(value, Distribution):
+            parameters[key] = jnp.ones(expand_shape) * numpyro.sample(
+                f"{prefix}{module_name}_{param_name}", value
+            )
 
-        elif isinstance(
-            parameters[
+        elif isinstance(value, ArrayLike):
+            parameters[key] = jnp.ones(expand_shape) * value
 
         else:
            raise ValueError(
-                f"Invalid prior type {type(
+                f"Invalid prior type {type(value)} for parameter {prefix}{module_name}_{param_name} : {value}"
            )
 
-    return parameters
+    return parameters
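The refactor drops the haiku-based CountForwardModel in favour of the plain forward_model function, and build_prior becomes the bridge between a user-supplied prior dictionary and numpyro sample sites. A minimal sketch of how the new build_prior could be exercised, assuming prior keys follow the "<module>_<parameter>" convention it splits with rsplit("_", 1); the powerlaw_alpha / powerlaw_norm names and the prior choices are illustrative, not taken from this diff:

import numpyro.distributions as dist
from numpyro import handlers

from jaxspec._fit._build_model import build_prior

# Hypothetical prior dictionary; Distribution entries become sample sites,
# while plain arrays or scalars would be broadcast as fixed values instead.
prior = {
    "powerlaw_alpha": dist.Uniform(0.0, 5.0),
    "powerlaw_norm": dist.LogUniform(1e-6, 1e-2),
}

def model():
    # build_prior must run inside a numpyro model so the sample sites are registered
    return build_prior(prior, expand_shape=(), prefix="")

# Trace the model once with a fixed seed to inspect the registered sites
with handlers.seed(rng_seed=0):
    trace = handlers.trace(model).get_trace()

print(sorted(trace.keys()))  # -> ['powerlaw_alpha', 'powerlaw_norm']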
jaxspec/analysis/_plot.py
CHANGED
@@ -1,22 +1,58 @@
+from __future__ import annotations
+
+import catppuccin
 import matplotlib.pyplot as plt
 import numpy as np
 
+from astropy import units as u
+from catppuccin.extras.matplotlib import load_color
+from cycler import cycler
 from jax.typing import ArrayLike
-from scipy.
+from scipy.integrate import trapezoid
+from scipy.stats import nbinom, norm
+
+from jaxspec.data import ObsConfiguration
+
+PALETTE = catppuccin.PALETTE.latte
+
+COLOR_CYCLE = [
+    load_color(PALETTE.identifier, color)
+    for color in ["sky", "teal", "green", "yellow", "peach", "maroon", "red", "pink", "mauve"][::-1]
+]
+
+LINESTYLE_CYCLE = ["dashed", "dotted", "dashdot", "solid"]
+
+SPECS_CYCLE = cycler(linestyle=LINESTYLE_CYCLE) * cycler(color=COLOR_CYCLE)
+
+SPECTRUM_COLOR = load_color(PALETTE.identifier, "blue")
+SPECTRUM_DATA_COLOR = load_color(PALETTE.identifier, "overlay2")
+BACKGROUND_COLOR = load_color(PALETTE.identifier, "sapphire")
+BACKGROUND_DATA_COLOR = load_color(PALETTE.identifier, "overlay0")
+
+
+def sigma_to_percentile_intervals(sigmas):
+    intervals = []
+    for sigma in sigmas:
+        lower_bound = 100 * norm.cdf(-sigma)
+        upper_bound = 100 * norm.cdf(sigma)
+        intervals.append((lower_bound, upper_bound))
+    return intervals
 
 
 def _plot_poisson_data_with_error(
     ax: plt.Axes,
     x_bins: ArrayLike,
     y: ArrayLike,
-
+    y_low: ArrayLike,
+    y_high: ArrayLike,
+    color=SPECTRUM_DATA_COLOR,
+    linestyle="none",
+    alpha=0.3,
 ):
     """
     Plot Poisson data with error bars. We extrapolate the intrinsic error of the observation assuming a prior rate
     distributed according to a Gamma RV.
     """
-    y_low = nbinom.ppf(percentiles[0] / 100, y, 0.5)
-    y_high = nbinom.ppf(percentiles[1] / 100, y, 0.5)
 
     ax_to_plot = ax.errorbar(
         np.sqrt(x_bins[0] * x_bins[1]),
@@ -26,10 +62,133 @@ def _plot_poisson_data_with_error(
             y - y_low,
             y_high - y,
         ],
-        color=
-        linestyle=
-        alpha=
+        color=color,
+        linestyle=linestyle,
+        alpha=alpha,
         capsize=2,
     )
 
     return ax_to_plot
+
+
+def _plot_binned_samples_with_error(
+    ax: plt.Axes,
+    x_bins: ArrayLike,
+    y_samples: ArrayLike,
+    color=SPECTRUM_COLOR,
+    alpha_median: float = 0.7,
+    alpha_envelope: (float, float) = (0.15, 0.25),
+    linestyle="solid",
+    n_sigmas=3,
+):
+    """
+    Helper function to plot the posterior predictive distribution of the model. The function
+    computes the percentiles of the posterior predictive distribution and plot them as a shaded
+    area. If the observed data is provided, it is also plotted as a step function.
+
+    Parameters:
+        x_bins: The bin edges of the data (2 x N).
+        y_samples: The samples of the posterior predictive distribution (Samples X N).
+        ax: The matplotlib axes object.
+        color: The color of the posterior predictive distribution.
+    """
+
+    median = ax.stairs(
+        list(np.median(y_samples, axis=0)),
+        edges=[*list(x_bins[0]), x_bins[1][-1]],
+        color=color,
+        alpha=alpha_median,
+        linestyle=linestyle,
+    )
+
+    # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
+    (envelope,) = ax.fill(np.nan, np.nan, alpha=alpha_envelope[-1], facecolor=color)
+
+    if n_sigmas == 1:
+        alpha_envelope = (alpha_envelope[1], alpha_envelope[0])
+
+    for percentile, alpha in zip(
+        sigma_to_percentile_intervals(list(range(n_sigmas, 0, -1))),
+        np.linspace(*alpha_envelope, n_sigmas),
+    ):
+        percentiles = np.percentile(y_samples, percentile, axis=0)
+        ax.stairs(
+            percentiles[1],
+            edges=[*list(x_bins[0]), x_bins[1][-1]],
+            baseline=percentiles[0],
+            alpha=alpha,
+            fill=True,
+            color=color,
+        )
+
+    return [(median, envelope)]
+
+
+def _compute_effective_area(
+    obsconf: ObsConfiguration,
+    x_unit: str | u.Unit = "keV",
+):
+    """
+    Helper function to compute the bins and effective area of an observational configuration
+
+    Parameters:
+        obsconf: The observational configuration.
+        x_unit: The unit of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
+    """
+
+    # Note to Simon : do not change xbins[1] - xbins[0] to
+    # np.diff, you already did this twice and forgot that it does not work since diff keeps the dimensions
+    # and enable weird broadcasting that makes the plot fail
+
+    xbins = obsconf.out_energies * u.keV
+    xbins = xbins.to(x_unit, u.spectral())
+
+    # This computes the total effective area within all bins
+    # This is a bit weird since the following computation is equivalent to ignoring the RMF
+    exposure = obsconf.exposure.data * u.s
+    mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
+    mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
+    e_grid = np.linspace(*xbins, 10)
+    interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
+    integrated_arf = (
+        trapezoid(interpolated_arf, x=e_grid, axis=0)
+        / (
+            np.abs(
+                xbins[1] - xbins[0]
+            )  # Must fold in abs because some units reverse the ordering of the bins
+        )
+        * u.cm**2
+    )
+
+    return xbins, exposure, integrated_arf
+
+
+def _error_bars_for_observed_data(observed_counts, denominator, units, sigma=1):
+    r"""
+    Compute the error bars for the observed data assuming a prior Gamma distribution
+
+    Parameters:
+        observed_counts: array of integer counts
+        denominator: normalization factor (e.g. effective area)
+        units: unit to convert to
+        sigma: dispersion to use for quantiles computation
+
+    Returns:
+        y_observed: observed counts in the desired units
+        y_observed_low: lower bound of the error bars
+        y_observed_high: upper bound of the error bars
+    """
+
+    percentile = sigma_to_percentile_intervals([sigma])[0]
+
+    y_observed = (observed_counts * u.ct / denominator).to(units)
+
+    y_observed_low = (
+        nbinom.ppf(percentile[0] / 100, observed_counts, 0.5) * u.ct / denominator
+    ).to(units)
+
+    y_observed_high = (
+        nbinom.ppf(percentile[1] / 100, observed_counts, 0.5) * u.ct / denominator
+    ).to(units)
+
+    return y_observed, y_observed_low, y_observed_high
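The new plotting helpers rest on two standard results: a k-sigma Gaussian interval corresponds to the percentile pair (100 * Phi(-k), 100 * Phi(k)), and, for a Poisson bin with n observed counts and a Gamma prior on the rate (as the docstrings above state), the predictive count distribution is a negative binomial, consistent with the nbinom.ppf(..., observed_counts, 0.5) calls. A small self-contained check of those numbers using scipy only; the 20-count bin is an arbitrary example, not data from the package:

from scipy.stats import nbinom, norm

def sigma_to_percentile_intervals(sigmas):
    # Same mapping as the helper introduced in _plot.py
    return [(100 * norm.cdf(-s), 100 * norm.cdf(s)) for s in sigmas]

# 1 sigma maps to roughly the (15.9, 84.1) percentile interval
low, high = sigma_to_percentile_intervals([1])[0]
print(round(low, 1), round(high, 1))  # 15.9 84.1

# Gamma-prior error bars for a single Poisson bin with 20 observed counts,
# mirroring the quantile computation in _error_bars_for_observed_data
counts = 20
y_low = nbinom.ppf(low / 100, counts, 0.5)
y_high = nbinom.ppf(high / 100, counts, 0.5)
print(y_low, y_high)  # asymmetric bounds straddling the observed 20 counts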