jaxspec 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,62 +1,25 @@
1
- import jax
2
- import numpyro
3
- import haiku as hk
4
- import numpy as np
1
+ from typing import TYPE_CHECKING
2
+
5
3
  import jax.numpy as jnp
6
- from typing import Callable
4
+ import numpy as np
5
+ import numpyro
6
+
7
7
  from jax.experimental.sparse import BCOO
8
- from typing import TYPE_CHECKING
9
- from numpyro.distributions import Poisson
10
8
  from jax.typing import ArrayLike
11
9
  from numpyro.distributions import Distribution
12
10
 
13
-
14
11
  if TYPE_CHECKING:
15
- from ..model.abc import SpectralModel
16
12
  from ..data import ObsConfiguration
17
- from ..util.typing import PriorDictModel, PriorDictType
18
-
19
-
20
-
21
- class CountForwardModel(hk.Module):
22
- """
23
- A haiku module which allows to build the function that simulates the measured counts
24
- """
25
-
26
- # TODO: It has no point of being a haiku module, it should be a simple function
27
-
28
- def __init__(self, model: 'SpectralModel', folding: 'ObsConfiguration', sparse=False):
29
- super().__init__()
30
- self.model = model
31
- self.energies = jnp.asarray(folding.in_energies)
32
-
33
- if (
34
- sparse
35
- ): # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
36
- self.transfer_matrix = BCOO.from_scipy_sparse(
37
- folding.transfer_matrix.data.to_scipy_sparse().tocsr()
38
- )
39
-
40
- else:
41
- self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
42
-
43
- def __call__(self, parameters):
44
- """
45
- Compute the count functions for a given observation.
46
- """
47
-
48
- expected_counts = self.transfer_matrix @ self.model.photon_flux(parameters, *self.energies)
49
-
50
- return jnp.clip(expected_counts, a_min=1e-6)
13
+ from ..model.abc import SpectralModel
14
+ from ..util.typing import PriorDictType
51
15
 
52
16
 
53
17
  def forward_model(
54
- model: 'SpectralModel',
55
- parameters,
56
- obs_configuration: 'ObsConfiguration',
57
- sparse=False,
58
- ):
59
-
18
+ model: "SpectralModel",
19
+ parameters,
20
+ obs_configuration: "ObsConfiguration",
21
+ sparse=False,
22
+ ):
60
23
  energies = np.asarray(obs_configuration.in_energies)
61
24
 
62
25
  if sparse:
@@ -74,67 +37,27 @@ def forward_model(
74
37
  return jnp.clip(expected_counts, a_min=1e-6)
75
38
 
76
39
 
77
- def build_numpyro_model_for_single_obs(
78
- obs,
79
- model,
80
- background_model,
81
- name: str = "",
82
- sparse: bool = False,
83
- ) -> Callable:
84
- """
85
- Build a numpyro model for a given observation and spectral model.
86
- """
87
-
88
- def numpyro_model(prior_params, observed=True):
89
-
90
- # Return the expected countrate for a set of parameters
91
- obs_model = jax.jit(lambda par: forward_model(model, par, obs, sparse=sparse))
92
- countrate = obs_model(prior_params)
93
-
94
- # Handle the background model
95
- if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
96
- bkg_countrate = background_model.numpyro_model(
97
- obs, model, name="bkg_" + name, observed=observed
98
- )
99
-
100
- elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
101
- raise ValueError(
102
- "Trying to fit a background model but no background is linked to this observation"
103
- )
104
-
105
- else:
106
- bkg_countrate = 0.0
107
-
108
-
109
- # Register the observed value
110
- # This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
111
- with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
112
- numpyro.sample(
113
- "obs_" + name,
114
- Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
115
- obs=obs.folded_counts.data if observed else None,
116
- )
117
-
118
- return numpyro_model
119
-
120
-
121
def build_prior(prior: "PriorDictType", expand_shape: tuple = (), prefix=""):
    """
    Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
    Must be used within a numpyro model.

    Parameters:
        prior: mapping from parameter keys of the form ``<module>_<parameter>`` to either a numpyro
            ``Distribution`` (sampled with ``numpyro.sample``) or an array-like fixed value.
        expand_shape: shape each value is broadcast to, via ``jnp.ones(expand_shape) * value``.
        prefix: string prepended to each key to build the numpyro sample-site name.

    Returns:
        A dictionary with the same keys as ``prior`` holding the sampled or fixed values.

    Raises:
        ValueError: if a key contains no underscore, or a value is neither a ``Distribution``
            nor array-like.
    """
    parameters = {}

    for key, value in prior.items():
        # Keys are expected as "<module>_<parameter>". Check this explicitly so a malformed key
        # fails with a readable message rather than a tuple-unpacking error.
        if "_" not in key:
            raise ValueError(
                f"Invalid parameter name {prefix}{key} : expected '<module>_<parameter>'"
            )

        # Splitting "<module>_<parameter>" and re-joining it with "_" gives back the key itself,
        # so the sample-site name is simply the prefixed key.
        site_name = f"{prefix}{key}"

        if isinstance(value, Distribution):
            parameters[key] = jnp.ones(expand_shape) * numpyro.sample(site_name, value)

        elif isinstance(value, ArrayLike):
            parameters[key] = jnp.ones(expand_shape) * value

        else:
            raise ValueError(
                f"Invalid prior type {type(value)} for parameter {site_name} : {value}"
            )

    return parameters
jaxspec/analysis/_plot.py CHANGED
@@ -1,22 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ import catppuccin
1
4
  import matplotlib.pyplot as plt
2
5
  import numpy as np
3
6
 
7
+ from astropy import units as u
8
+ from catppuccin.extras.matplotlib import load_color
9
+ from cycler import cycler
4
10
  from jax.typing import ArrayLike
5
- from scipy.stats import nbinom
11
+ from scipy.integrate import trapezoid
12
+ from scipy.stats import nbinom, norm
13
+
14
+ from jaxspec.data import ObsConfiguration
15
+
16
# Colour scheme shared by the plotting helpers: the catppuccin "latte" flavour.
PALETTE = catppuccin.PALETTE.latte

# Colours cycled through when several models are drawn, ordered from "mauve" back to "sky".
COLOR_CYCLE = [
    load_color(PALETTE.identifier, name)
    for name in reversed(
        ["sky", "teal", "green", "yellow", "peach", "maroon", "red", "pink", "mauve"]
    )
]

LINESTYLE_CYCLE = ["dashed", "dotted", "dashdot", "solid"]

# Product of two cyclers: every line style is paired with every colour of COLOR_CYCLE.
SPECS_CYCLE = cycler(linestyle=LINESTYLE_CYCLE) * cycler(color=COLOR_CYCLE)

# SPECTRUM_COLOR is the default for posterior-predictive bands and SPECTRUM_DATA_COLOR for
# observed-data error bars; the BACKGROUND_* pair presumably plays the same roles for
# background spectra (not used in the visible code — confirm).
SPECTRUM_COLOR = load_color(PALETTE.identifier, "blue")
SPECTRUM_DATA_COLOR = load_color(PALETTE.identifier, "overlay2")
BACKGROUND_COLOR = load_color(PALETTE.identifier, "sapphire")
BACKGROUND_DATA_COLOR = load_color(PALETTE.identifier, "overlay0")
31
+
32
+
33
def sigma_to_percentile_intervals(sigmas):
    """
    Convert Gaussian significance levels into two-sided percentile bounds.

    Parameters:
        sigmas: iterable of sigma values.

    Returns:
        A list with one ``(lower, upper)`` tuple per sigma, expressed in percent, where
        ``lower = 100 * Phi(-sigma)`` and ``upper = 100 * Phi(sigma)`` and ``Phi`` is the
        standard normal CDF.
    """
    return [(100 * norm.cdf(-s), 100 * norm.cdf(s)) for s in sigmas]
6
40
 
7
41
 
8
42
  def _plot_poisson_data_with_error(
9
43
  ax: plt.Axes,
10
44
  x_bins: ArrayLike,
11
45
  y: ArrayLike,
12
- percentiles: tuple = (16, 84),
46
+ y_low: ArrayLike,
47
+ y_high: ArrayLike,
48
+ color=SPECTRUM_DATA_COLOR,
49
+ linestyle="none",
50
+ alpha=0.3,
13
51
  ):
14
52
  """
15
53
  Plot Poisson data with error bars. We extrapolate the intrinsic error of the observation assuming a prior rate
16
54
  distributed according to a Gamma RV.
17
55
  """
18
- y_low = nbinom.ppf(percentiles[0] / 100, y, 0.5)
19
- y_high = nbinom.ppf(percentiles[1] / 100, y, 0.5)
20
56
 
21
57
  ax_to_plot = ax.errorbar(
22
58
  np.sqrt(x_bins[0] * x_bins[1]),
@@ -26,10 +62,133 @@ def _plot_poisson_data_with_error(
26
62
  y - y_low,
27
63
  y_high - y,
28
64
  ],
29
- color="black",
30
- linestyle="none",
31
- alpha=0.3,
65
+ color=color,
66
+ linestyle=linestyle,
67
+ alpha=alpha,
32
68
  capsize=2,
33
69
  )
34
70
 
35
71
  return ax_to_plot
72
+
73
+
74
def _plot_binned_samples_with_error(
    ax: plt.Axes,
    x_bins: ArrayLike,
    y_samples: ArrayLike,
    color=SPECTRUM_COLOR,
    alpha_median: float = 0.7,
    alpha_envelope: tuple[float, float] = (0.15, 0.25),
    linestyle="solid",
    n_sigmas=3,
):
    """
    Helper function to plot the posterior predictive distribution of the model. The median of the
    samples is drawn as a step function, and the 1 to ``n_sigmas`` sigma percentile intervals of
    the samples are drawn as nested shaded step areas.

    Parameters:
        ax: The matplotlib axes object.
        x_bins: The bin edges of the data (2 x N): lower edges in ``x_bins[0]``, upper edges in
            ``x_bins[1]``.
        y_samples: The samples of the posterior predictive distribution (Samples x N).
        color: The color of the posterior predictive distribution.
        alpha_median: Opacity of the median step line.
        alpha_envelope: Opacities interpolated (with ``np.linspace``) across the shaded intervals,
            from the widest interval to the narrowest.
        linestyle: Line style of the median step line.
        n_sigmas: Number of sigma intervals to shade.

    Returns:
        A one-element list holding a ``(median, envelope)`` tuple of artists, suitable as a
        combined legend handle.
    """
    # Consecutive lower edges plus the last upper edge give the N + 1 edges ``stairs`` expects.
    edges = [*list(x_bins[0]), x_bins[1][-1]]

    median = ax.stairs(
        list(np.median(y_samples, axis=0)),
        edges=edges,
        color=color,
        alpha=alpha_median,
        linestyle=linestyle,
    )

    # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
    (envelope,) = ax.fill(np.nan, np.nan, alpha=alpha_envelope[-1], facecolor=color)

    if n_sigmas == 1:
        # np.linspace(a, b, 1) yields only ``a``: swap so a single interval uses the second alpha.
        alpha_envelope = (alpha_envelope[1], alpha_envelope[0])

    # Intervals are drawn from n_sigmas down to 1 sigma, so narrower bands are drawn last.
    for percentile, alpha in zip(
        sigma_to_percentile_intervals(list(range(n_sigmas, 0, -1))),
        np.linspace(*alpha_envelope, n_sigmas),
    ):
        percentiles = np.percentile(y_samples, percentile, axis=0)
        ax.stairs(
            percentiles[1],
            edges=edges,
            baseline=percentiles[0],
            alpha=alpha,
            fill=True,
            color=color,
        )

    return [(median, envelope)]
125
+
126
+
127
def _compute_effective_area(
    obsconf: ObsConfiguration,
    x_unit: str | u.Unit = "keV",
):
    """
    Helper function to compute the bins and effective area of an observational configuration

    Parameters:
        obsconf: The observational configuration.
        x_unit: The unit of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.

    Returns:
        xbins: the output bin edges converted to ``x_unit`` (lower edges first, upper edges second).
        exposure: the exposure as a quantity in seconds.
        integrated_arf: the ARF averaged over each output bin, in cm^2.
    """

    # Do not replace xbins[1] - xbins[0] with np.diff: np.diff keeps the leading dimension,
    # which enables unintended broadcasting and makes the plot fail.

    xbins = obsconf.out_energies * u.keV
    xbins = xbins.to(x_unit, u.spectral())

    # This computes the total effective area within all bins
    # This is a bit weird since the following computation is equivalent to ignoring the RMF
    exposure = obsconf.exposure.data * u.s
    mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
    mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())

    # Some units (e.g. wavelengths) reverse the ordering of the grid, and np.interp silently
    # returns meaningless values when its sample points are not increasing: sort them first.
    order = np.argsort(mid_bins_arf, kind="stable")
    area = np.asarray(obsconf.area)[order]

    e_grid = np.linspace(*xbins, 10)
    interpolated_arf = np.interp(e_grid, mid_bins_arf[order], area)

    # When x_unit reverses the bin ordering, e_grid runs from high to low values and both the
    # integral and xbins[1] - xbins[0] are negative; dividing by the signed width keeps the
    # averaged effective area positive in every case.
    integrated_arf = (
        trapezoid(interpolated_arf, x=e_grid, axis=0) / (xbins[1] - xbins[0]) * u.cm**2
    )

    return xbins, exposure, integrated_arf
164
+
165
+
166
def _error_bars_for_observed_data(observed_counts, denominator, units, sigma=1):
    """
    Compute the error bars for the observed data assuming a prior Gamma distribution

    The bounds are quantiles of a negative binomial distribution with ``n = observed_counts``
    and ``p = 0.5``, taken at the two-sided percentiles matching ``sigma``.

    Parameters:
        observed_counts: array of integer counts
        denominator: normalization factor (e.g. effective area)
        units: unit to convert to
        sigma: dispersion to use for quantiles computation

    Returns:
        y_observed: observed counts in the desired units
        y_observed_low: lower bound of the error bars
        y_observed_high: upper bound of the error bars
    """
    lower_percent, upper_percent = sigma_to_percentile_intervals([sigma])[0]

    def _normalise(counts):
        # Attach count units, divide by the normalisation, then convert to the target units.
        return (counts * u.ct / denominator).to(units)

    y_observed = _normalise(observed_counts)
    y_observed_low = _normalise(nbinom.ppf(lower_percent / 100, observed_counts, 0.5))
    y_observed_high = _normalise(nbinom.ppf(upper_percent / 100, observed_counts, 0.5))

    return y_observed, y_observed_low, y_observed_high