jaxspec 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,63 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ import numpyro
6
+
7
+ from jax.experimental.sparse import BCOO
8
+ from jax.typing import ArrayLike
9
+ from numpyro.distributions import Distribution
10
+
11
+ if TYPE_CHECKING:
12
+ from ..data import ObsConfiguration
13
+ from ..model.abc import SpectralModel
14
+ from ..util.typing import PriorDictType
15
+
16
+
17
def forward_model(
    model: "SpectralModel",
    parameters,
    obs_configuration: "ObsConfiguration",
    sparse=False,
):
    """
    Fold a spectral model through an observation's response and return the expected counts.

    Parameters:
        model: the spectral model to evaluate.
        parameters: parameter values forwarded verbatim to ``model.photon_flux``.
        obs_configuration: observation providing the input energy grid and the transfer matrix.
        sparse: if True, convert the transfer matrix to a sparse BCOO matrix before folding.

    Returns:
        Expected counts per channel, clipped below at 1e-6.
    """
    energies = np.asarray(obs_configuration.in_energies)

    if sparse:
        # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
        transfer_matrix = BCOO.from_scipy_sparse(
            obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
        )

    else:
        transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())

    # in_energies is (2, N): unpacking passes the low and high bin edges separately
    expected_counts = transfer_matrix @ model.photon_flux(parameters, *energies)

    # The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods.
    # Use the positional lower bound: the `a_min` keyword was deprecated and then
    # removed from jnp.clip when JAX adopted the NumPy 2.0 clip signature.
    return jnp.clip(expected_counts, 1e-6)
38
+
39
+
40
def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
    """
    Sample parameters from a dictionary of prior specifications.

    Each key is expected to look like ``"<module>_<param>"``. Distribution values
    are sampled through numpyro, array-like values are used as-is, and every
    result is broadcast against ``expand_shape``. Must be used within a numpyro model.
    """
    sampled_parameters = {}
    broadcaster = jnp.ones(expand_shape)

    for key, value in prior.items():
        # The key encodes both the module name and the parameter name
        module_name, param_name = key.rsplit("_", 1)
        site_name = f"{prefix}{module_name}_{param_name}"

        if isinstance(value, Distribution):
            sampled_parameters[key] = broadcaster * numpyro.sample(site_name, value)
        elif isinstance(value, ArrayLike):
            sampled_parameters[key] = broadcaster * value
        else:
            raise ValueError(
                f"Invalid prior type {type(value)} for parameter {site_name} : {value}"
            )

    return sampled_parameters
jaxspec/analysis/_plot.py CHANGED
@@ -1,22 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ import catppuccin
1
4
  import matplotlib.pyplot as plt
2
5
  import numpy as np
3
6
 
7
+ from astropy import units as u
8
+ from catppuccin.extras.matplotlib import load_color
9
+ from cycler import cycler
4
10
  from jax.typing import ArrayLike
5
- from scipy.stats import nbinom
11
+ from scipy.integrate import trapezoid
12
+ from scipy.stats import nbinom, norm
13
+
14
+ from jaxspec.data import ObsConfiguration
15
+
16
# The catppuccin "latte" palette drives the whole plotting color scheme
PALETTE = catppuccin.PALETTE.latte

# Reversed warm-to-cool sequence cycled through when plotting several components
COLOR_CYCLE = [
    load_color(PALETTE.identifier, color)
    for color in ["sky", "teal", "green", "yellow", "peach", "maroon", "red", "pink", "mauve"][::-1]
]

LINESTYLE_CYCLE = ["dashed", "dotted", "dashdot", "solid"]

# Outer product: every linestyle is paired with every color
SPECS_CYCLE = cycler(linestyle=LINESTYLE_CYCLE) * cycler(color=COLOR_CYCLE)

# Dedicated colors for the source spectrum and background, model vs. data
SPECTRUM_COLOR = load_color(PALETTE.identifier, "blue")
SPECTRUM_DATA_COLOR = load_color(PALETTE.identifier, "overlay2")
BACKGROUND_COLOR = load_color(PALETTE.identifier, "sapphire")
BACKGROUND_DATA_COLOR = load_color(PALETTE.identifier, "overlay0")
31
+
32
+
33
def sigma_to_percentile_intervals(sigmas):
    """
    Convert sigma levels into two-sided percentile intervals.

    For each sigma, returns ``(100 * cdf(-sigma), 100 * cdf(sigma))`` of the
    standard normal distribution, e.g. 1 sigma maps to roughly (15.87, 84.13).
    """
    return [(100 * norm.cdf(-sigma), 100 * norm.cdf(sigma)) for sigma in sigmas]
6
40
 
7
41
 
8
42
  def _plot_poisson_data_with_error(
9
43
  ax: plt.Axes,
10
44
  x_bins: ArrayLike,
11
45
  y: ArrayLike,
12
- percentiles: tuple = (16, 84),
46
+ y_low: ArrayLike,
47
+ y_high: ArrayLike,
48
+ color=SPECTRUM_DATA_COLOR,
49
+ linestyle="none",
50
+ alpha=0.3,
13
51
  ):
14
52
  """
15
53
  Plot Poisson data with error bars. We extrapolate the intrinsic error of the observation assuming a prior rate
16
54
  distributed according to a Gamma RV.
17
55
  """
18
- y_low = nbinom.ppf(percentiles[0] / 100, y, 0.5)
19
- y_high = nbinom.ppf(percentiles[1] / 100, y, 0.5)
20
56
 
21
57
  ax_to_plot = ax.errorbar(
22
58
  np.sqrt(x_bins[0] * x_bins[1]),
@@ -26,10 +62,133 @@ def _plot_poisson_data_with_error(
26
62
  y - y_low,
27
63
  y_high - y,
28
64
  ],
29
- color="black",
30
- linestyle="none",
31
- alpha=0.3,
65
+ color=color,
66
+ linestyle=linestyle,
67
+ alpha=alpha,
32
68
  capsize=2,
33
69
  )
34
70
 
35
71
  return ax_to_plot
72
+
73
+
74
def _plot_binned_samples_with_error(
    ax: plt.Axes,
    x_bins: ArrayLike,
    y_samples: ArrayLike,
    color=SPECTRUM_COLOR,
    alpha_median: float = 0.7,
    alpha_envelope: tuple[float, float] = (0.15, 0.25),
    linestyle="solid",
    n_sigmas=3,
):
    """
    Helper function to plot the posterior predictive distribution of the model. The function
    computes the percentiles of the posterior predictive distribution and plot them as a shaded
    area. If the observed data is provided, it is also plotted as a step function.

    Parameters:
        ax: The matplotlib axes object.
        x_bins: The bin edges of the data (2 x N).
        y_samples: The samples of the posterior predictive distribution (Samples X N).
        color: The color of the posterior predictive distribution.
        alpha_median: Opacity of the median step line.
        alpha_envelope: (min, max) opacity range used for the shaded sigma envelopes.
        linestyle: Line style of the median step line.
        n_sigmas: Number of nested sigma envelopes to draw.

    Returns:
        A single-element list holding the (median line, envelope patch) pair, for legends.
    """

    median = ax.stairs(
        list(np.median(y_samples, axis=0)),
        edges=[*list(x_bins[0]), x_bins[1][-1]],
        color=color,
        alpha=alpha_median,
        linestyle=linestyle,
    )

    # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
    (envelope,) = ax.fill(np.nan, np.nan, alpha=alpha_envelope[-1], facecolor=color)

    # With a single envelope, use the strongest opacity for it
    if n_sigmas == 1:
        alpha_envelope = (alpha_envelope[1], alpha_envelope[0])

    # Draw envelopes from the widest (faintest) to the narrowest (strongest)
    for percentile, alpha in zip(
        sigma_to_percentile_intervals(list(range(n_sigmas, 0, -1))),
        np.linspace(*alpha_envelope, n_sigmas),
    ):
        percentiles = np.percentile(y_samples, percentile, axis=0)
        ax.stairs(
            percentiles[1],
            edges=[*list(x_bins[0]), x_bins[1][-1]],
            baseline=percentiles[0],
            alpha=alpha,
            fill=True,
            color=color,
        )

    return [(median, envelope)]
125
+
126
+
127
def _compute_effective_area(
    obsconf: ObsConfiguration,
    x_unit: str | u.Unit = "keV",
):
    """
    Helper function to compute the bins and effective area of an observational configuration

    Parameters:
        obsconf: The observational configuration.
        x_unit: The unit of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
    """

    # Note to Simon : do not change xbins[1] - xbins[0] to
    # np.diff, you already did this twice and forgot that it does not work since diff keeps the dimensions
    # and enable weird broadcasting that makes the plot fail
    xbins = (obsconf.out_energies * u.keV).to(x_unit, u.spectral())

    exposure = obsconf.exposure.data * u.s

    # This computes the total effective area within all bins.
    # This is a bit weird since the following computation is equivalent to ignoring the RMF
    mid_bins_arf = (obsconf.in_energies.mean(axis=0) * u.keV).to(x_unit, u.spectral())
    e_grid = np.linspace(*xbins, 10)
    interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)

    # Must fold in abs because some units reverse the ordering of the bins
    bin_width = np.abs(xbins[1] - xbins[0])
    integrated_arf = trapezoid(interpolated_arf, x=e_grid, axis=0) / bin_width * u.cm**2

    return xbins, exposure, integrated_arf
164
+
165
+
166
def _error_bars_for_observed_data(observed_counts, denominator, units, sigma=1):
    r"""
    Compute the error bars for the observed data assuming a prior Gamma distribution

    Parameters:
        observed_counts: array of integer counts
        denominator: normalization factor (e.g. effective area)
        units: unit to convert to
        sigma: dispersion to use for quantiles computation

    Returns:
        y_observed: observed counts in the desired units
        y_observed_low: lower bound of the error bars
        y_observed_high: upper bound of the error bars
    """

    lower_percentile, upper_percentile = sigma_to_percentile_intervals([sigma])[0]

    def to_units(counts):
        # Normalize raw counts and convert them to the requested units
        return (counts * u.ct / denominator).to(units)

    y_observed = to_units(observed_counts)
    y_observed_low = to_units(nbinom.ppf(lower_percentile / 100, observed_counts, 0.5))
    y_observed_high = to_units(nbinom.ppf(upper_percentile / 100, observed_counts, 0.5))

    return y_observed, y_observed_low, y_observed_high