jaxspec 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,26 @@
1
+ from collections.abc import Mapping
2
+ from typing import Any, Literal, TypeVar
3
+
1
4
  import arviz as az
5
+ import astropy.units as u
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import matplotlib.pyplot as plt
2
9
  import numpy as np
3
10
  import xarray as xr
4
- import matplotlib.pyplot as plt
5
- from ..data import ObsConfiguration
6
- from ..model.abc import SpectralModel
7
- from ..model.background import BackgroundModel
8
- from collections.abc import Mapping
9
- from typing import TypeVar, Tuple, Literal, Any
11
+
10
12
  from astropy.cosmology import Cosmology, Planck18
11
- import astropy.units as u
12
13
  from astropy.units import Unit
14
+ from chainconsumer import Chain, ChainConsumer, PlotConfig
13
15
  from haiku.data_structures import traverse
14
- from chainconsumer import Chain, PlotConfig, ChainConsumer
15
- import jax
16
16
  from jax.typing import ArrayLike
17
17
  from scipy.integrate import trapezoid
18
+ from scipy.special import gammaln
19
+ from scipy.stats import nbinom
20
+
21
+ from ..data import ObsConfiguration
22
+ from ..model.abc import SpectralModel
23
+ from ..model.background import BackgroundModel
18
24
 
19
25
  K = TypeVar("K")
20
26
  V = TypeVar("V")
@@ -29,7 +35,6 @@ def _plot_binned_samples_with_error(
29
35
  x_bins: ArrayLike,
30
36
  denominator: ArrayLike | None = None,
31
37
  y_samples: ArrayLike | None = None,
32
- y_observed: ArrayLike | None = None,
33
38
  color=(0.15, 0.25, 0.45),
34
39
  percentile: tuple = (16, 84),
35
40
  ):
@@ -38,7 +43,8 @@ def _plot_binned_samples_with_error(
38
43
  computes the percentiles of the posterior predictive distribution and plot them as a shaded
39
44
  area. If the observed data is provided, it is also plotted as a step function.
40
45
 
41
- Parameters:
46
+ Parameters
47
+ ----------
42
48
  x_bins: The bin edges of the data (2 x N).
43
49
  y_samples: The samples of the posterior predictive distribution (Samples X N).
44
50
  denominator: Values used to divided the samples, i.e. to get energy flux (N).
@@ -51,22 +57,15 @@ def _plot_binned_samples_with_error(
51
57
 
52
58
  mean, envelope = None, None
53
59
 
54
- if x_bins is None:
55
- raise ValueError("x_bins cannot be None.")
60
+ if denominator is None:
61
+ denominator = np.ones_like(x_bins[0])
56
62
 
57
- if (y_samples is None) and (y_observed is None):
58
- raise ValueError("Either a y_samples or y_observed must be provided.")
59
-
60
- if y_observed is not None:
61
- if denominator is None:
62
- denominator = np.ones_like(x_bins[0])
63
-
64
- (mean,) = ax.step(
65
- list(x_bins[0]) + [x_bins[1][-1]], # x_bins[1][-1]+1],
66
- list(y_observed / denominator) + [np.nan], # + [np.nan, np.nan],
67
- where="pre",
68
- c=color,
69
- )
63
+ mean = ax.stairs(
64
+ list(np.median(y_samples, axis=0) / denominator),
65
+ edges=[*list(x_bins[0]), x_bins[1][-1]],
66
+ color=color,
67
+ alpha=0.7,
68
+ )
70
69
 
71
70
  if y_samples is not None:
72
71
  if denominator is None:
@@ -77,48 +76,21 @@ def _plot_binned_samples_with_error(
77
76
  # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
78
77
  (envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
79
78
 
80
- ax.fill_between(
81
- list(x_bins[0]) + [x_bins[1][-1]], # + [x_bins[1][-1], x_bins[1][-1] + 1],
82
- list(percentiles[0] / denominator) + [np.nan], # + [np.nan, np.nan],
83
- list(percentiles[1] / denominator) + [np.nan], # + [np.nan, np.nan],
79
+ ax.stairs(
80
+ percentiles[1] / denominator,
81
+ edges=[*list(x_bins[0]), x_bins[1][-1]],
82
+ baseline=percentiles[0] / denominator,
84
83
  alpha=0.3,
85
- step="pre",
86
- facecolor=color,
84
+ fill=True,
85
+ color=color,
87
86
  )
88
87
 
89
88
  return [(mean, envelope)]
90
89
 
91
90
 
92
- def format_parameters(parameter_name):
93
- computed_parameters = ["Photon flux", "Energy flux", "Luminosity"]
94
-
95
- if parameter_name == "weight":
96
- # ChainConsumer add a weight column to the samples
97
- return parameter_name
98
-
99
- for parameter in computed_parameters:
100
- if parameter in parameter_name:
101
- return parameter_name
102
-
103
- # Find second occurrence of the character '_'
104
- first_occurrence = parameter_name.find("_")
105
- second_occurrence = parameter_name.find("_", first_occurrence + 1)
106
- module = parameter_name[:second_occurrence]
107
- parameter = parameter_name[second_occurrence + 1 :]
108
-
109
- name, number = module.split("_")
110
- module = rf"[{name.capitalize()} ({number})]"
111
-
112
- if parameter == "norm":
113
- return r"Norm " + module
114
-
115
- else:
116
- return rf"${parameter}$" + module
117
-
118
-
119
91
  class FitResult:
120
92
  """
121
- This class is the container for the result of a fit using any ModelFitter class.
93
+ Container for the result of a fit using any ModelFitter class.
122
94
  """
123
95
 
124
96
  # TODO : Add type hints
@@ -133,7 +105,9 @@ class FitResult:
133
105
  self.model = model
134
106
  self._structure = structure
135
107
  self.inference_data = inference_data
136
- self.obsconfs = {"Observation": obsconf} if isinstance(obsconf, ObsConfiguration) else obsconf
108
+ self.obsconfs = (
109
+ {"Observation": obsconf} if isinstance(obsconf, ObsConfiguration) else obsconf
110
+ )
137
111
  self.background_model = background_model
138
112
  self._structure = structure
139
113
 
@@ -146,39 +120,70 @@ class FitResult:
146
120
 
147
121
  @property
148
122
  def converged(self) -> bool:
149
- """
123
+ r"""
150
124
  Convergence of the chain as computed by the $\hat{R}$ statistic.
151
125
  """
152
126
 
153
127
  return all(az.rhat(self.inference_data) < 1.01)
154
128
 
129
+ @property
130
+ def _structured_samples(self):
131
+ """
132
+ Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
133
+ """
134
+
135
+ var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
136
+ posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
137
+ samples_flat = {key: posterior[key].data for key in var_names}
138
+
139
+ samples_haiku = {}
140
+
141
+ for module, parameter, value in traverse(self._structure):
142
+ if samples_haiku.get(module, None) is None:
143
+ samples_haiku[module] = {}
144
+ samples_haiku[module][parameter] = samples_flat[f"{module}_{parameter}"]
145
+
146
+ return samples_haiku
147
+
155
148
  def photon_flux(
156
149
  self,
157
150
  e_min: float,
158
151
  e_max: float,
159
152
  unit: Unit = u.photon / u.cm**2 / u.s,
153
+ register: bool = False,
160
154
  ) -> ArrayLike:
161
155
  """
162
156
  Compute the unfolded photon flux in a given energy band. The flux is then added to
163
157
  the result parameters so covariance can be plotted.
164
158
 
165
- Parameters:
159
+ Parameters
160
+ ----------
166
161
  e_min: The lower bound of the energy band in observer frame.
167
162
  e_max: The upper bound of the energy band in observer frame.
168
163
  unit: The unit of the photon flux.
164
+ register: Whether to register the flux with the other posterior parameters.
169
165
 
170
166
  !!! warning
171
167
  Computation of the folded flux is not implemented yet. Feel free to open an
172
168
  [issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
173
169
  """
174
170
 
175
- flux = jax.vmap(lambda p: self.model.photon_flux(p, np.asarray([e_min]), np.asarray([e_max])))(self.params)
171
+ samples = self._structured_samples
172
+ init_shape = jax.tree.leaves(samples)[0].shape
176
173
 
177
- conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
174
+ flux = jax.vmap(
175
+ lambda p: self.model.photon_flux(p, jnp.asarray([e_min]), jnp.asarray([e_max]))
176
+ )(jax.tree.map(lambda x: x.ravel(), samples))
178
177
 
178
+ flux = jax.tree.map(lambda x: x.reshape(init_shape), flux)
179
+ conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
179
180
  value = flux * conversion_factor
180
- # TODO : fix this since sample doesn't exist anymore
181
- self.samples[rf"Photon flux ({e_min:.1f}-{e_max:.1f} keV)"] = value
181
+
182
+ if register:
183
+ self.inference_data.posterior[f"flux_{e_min:.1f}_{e_max:.1f}"] = (
184
+ ["chain", "draw"],
185
+ value,
186
+ )
182
187
 
183
188
  return value
184
189
 
@@ -187,29 +192,41 @@ class FitResult:
187
192
  e_min: float,
188
193
  e_max: float,
189
194
  unit: Unit = u.erg / u.cm**2 / u.s,
195
+ register: bool = False,
190
196
  ) -> ArrayLike:
191
197
  """
192
198
  Compute the unfolded energy flux in a given energy band. The flux is then added to
193
199
  the result parameters so covariance can be plotted.
194
200
 
195
- Parameters:
201
+ Parameters
202
+ ----------
196
203
  e_min: The lower bound of the energy band in observer frame.
197
204
  e_max: The upper bound of the energy band in observer frame.
198
205
  unit: The unit of the energy flux.
206
+ register: Whether to register the flux with the other posterior parameters.
199
207
 
200
208
  !!! warning
201
209
  Computation of the folded flux is not implemented yet. Feel free to open an
202
210
  [issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
203
211
  """
204
212
 
205
- flux = jax.vmap(lambda p: self.model.energy_flux(p, np.asarray([e_min]), np.asarray([e_max])))(self.params)
213
+ samples = self._structured_samples
214
+ init_shape = jax.tree.leaves(samples)[0].shape
206
215
 
207
- conversion_factor = (u.keV / u.cm**2 / u.s).to(unit)
216
+ flux = jax.vmap(
217
+ lambda p: self.model.energy_flux(p, jnp.asarray([e_min]), jnp.asarray([e_max]))
218
+ )(jax.tree.map(lambda x: x.ravel(), samples))
219
+
220
+ flux = jax.tree.map(lambda x: x.reshape(init_shape), flux)
208
221
 
222
+ conversion_factor = (u.keV / u.cm**2 / u.s).to(unit)
209
223
  value = flux * conversion_factor
210
224
 
211
- # TODO : fix this since sample doesn't exist anymore
212
- self.samples[rf"Energy flux ({e_min:.1f}-{e_max:.1f} keV)"] = value
225
+ if register:
226
+ self.inference_data.posterior[f"eflux_{e_min:.1f}_{e_max:.1f}"] = (
227
+ ["chain", "draw"],
228
+ value,
229
+ )
213
230
 
214
231
  return value
215
232
 
@@ -217,57 +234,84 @@ class FitResult:
217
234
  self,
218
235
  e_min: float,
219
236
  e_max: float,
220
- redshift: float | ArrayLike = 0,
237
+ redshift: float | ArrayLike = 0.1,
221
238
  observer_frame: bool = True,
222
239
  cosmology: Cosmology = Planck18,
223
240
  unit: Unit = u.erg / u.s,
241
+ register: bool = False,
224
242
  ) -> ArrayLike:
225
243
  """
226
244
  Compute the luminosity of the source specifying its redshift. The luminosity is then added to
227
245
  the result parameters so covariance can be plotted.
228
246
 
229
- Parameters:
247
+ Parameters
248
+ ----------
230
249
  e_min: The lower bound of the energy band.
231
250
  e_max: The upper bound of the energy band.
232
251
  redshift: The redshift of the source. It can be a distribution of redshifts.
233
252
  observer_frame: Whether the input bands are defined in observer frame or not.
234
253
  cosmology: Chosen cosmology.
235
254
  unit: The unit of the luminosity.
255
+ register: Whether to register the flux with the other posterior parameters.
236
256
  """
237
257
 
238
258
  if not observer_frame:
239
259
  raise NotImplementedError()
240
260
 
241
- flux = self.energy_flux(e_min * (1 + redshift), e_max * (1 + redshift)) * (u.erg / u.cm**2 / u.s)
261
+ samples = self._structured_samples
262
+ init_shape = jax.tree.leaves(samples)[0].shape
242
263
 
264
+ flux = jax.vmap(
265
+ lambda p: self.model.energy_flux(
266
+ p, jnp.asarray([e_min]) * (1 + redshift), jnp.asarray([e_max])
267
+ )
268
+ * (1 + redshift)
269
+ )(jax.tree.map(lambda x: x.ravel(), samples))
270
+
271
+ flux = jax.tree.map(
272
+ lambda x: np.asarray(x.reshape(init_shape)) * (u.keV / u.cm**2 / u.s), flux
273
+ )
243
274
  value = (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
244
275
 
245
- # TODO : fix this since sample doesn't exist anymore
246
- self.samples[rf"Luminosity ({e_min:.1f}-{e_max:.1f} keV)"] = value
276
+ if register:
277
+ self.inference_data.posterior[f"luminosity_{e_min:.1f}_{e_max:.1f}"] = (
278
+ ["chain", "draw"],
279
+ value,
280
+ )
247
281
 
248
282
  return value
249
283
 
250
- def to_chain(self, name: str, parameters: Literal["model", "bkg"] = "model") -> Chain:
284
+ def to_chain(self, name: str, parameters_type: Literal["model", "bkg"] = "model") -> Chain:
251
285
  """
252
- Return a ChainConsumer Chain object from the posterior distribution of the parameters.
286
+ Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.
253
287
 
254
- Parameters:
288
+ Parameters
289
+ ----------
255
290
  name: The name of the chain.
256
- parameters: The parameters to include in the chain.
291
+ parameters_type: The parameters_type to include in the chain.
257
292
  """
258
293
 
259
294
  obs_id = self.inference_data.copy()
260
295
 
261
- if parameters == "model":
262
- keys_to_drop = [key for key in obs_id.posterior.keys() if (key.startswith("_") or key.startswith("bkg"))]
263
- elif parameters == "bkg":
296
+ if parameters_type == "model":
297
+ keys_to_drop = [
298
+ key
299
+ for key in obs_id.posterior.keys()
300
+ if (key.startswith("_") or key.startswith("bkg"))
301
+ ]
302
+ elif parameters_type == "bkg":
264
303
  keys_to_drop = [key for key in obs_id.posterior.keys() if not key.startswith("bkg")]
265
304
  else:
266
- raise ValueError(f"Unknown value for parameters: {parameters}")
305
+ raise ValueError(f"Unknown value for parameters_type: {parameters_type}")
267
306
 
268
307
  obs_id.posterior = obs_id.posterior.drop_vars(keys_to_drop)
269
308
  chain = Chain.from_arviz(obs_id, name)
270
- chain.samples.columns = [format_parameters(parameter) for parameter in chain.samples.columns]
309
+
310
+ """
311
+ chain.samples.columns = [
312
+ format_parameters(parameter) for parameter in chain.samples.columns
313
+ ]
314
+ """
271
315
 
272
316
  return chain
273
317
 
@@ -304,7 +348,7 @@ class FitResult:
304
348
  for module, parameter, value in traverse(self._structure):
305
349
  if params.get(module, None) is None:
306
350
  params[module] = {}
307
- params[module][parameter] = self.samples[f"{module}_{parameter}"]
351
+ params[module][parameter] = self.samples_flat[f"{module}_{parameter}"]
308
352
 
309
353
  return params
310
354
 
@@ -328,19 +372,50 @@ class FitResult:
328
372
  return {key: posterior[key].data for key in var_names}
329
373
 
330
374
  @property
331
- def likelihood(self) -> xr.Dataset:
375
+ def log_likelihood(self) -> xr.Dataset:
332
376
  """
333
- Return the likelihood of each observation
377
+ Return the log_likelihood of each observation
334
378
  """
335
379
  log_likelihood = az.extract(self.inference_data, group="log_likelihood")
336
- dimensions_to_reduce = [coord for coord in log_likelihood.coords if coord not in ["sample", "draw", "chain"]]
380
+ dimensions_to_reduce = [
381
+ coord for coord in log_likelihood.coords if coord not in ["sample", "draw", "chain"]
382
+ ]
337
383
  return log_likelihood.sum(dimensions_to_reduce)
338
384
 
385
+ @property
386
+ def c_stat(self):
387
+ r"""
388
+ Return the C-statistic of the model
389
+
390
+ The C-statistic is defined as:
391
+
392
+ $$ C = 2 \sum_{i} M - D*log(M) + D*log(D) - D $$
393
+ or
394
+ $$ C = 2 \sum_{i} M - D*log(M)$$
395
+ for bins with no counts
396
+
397
+ """
398
+
399
+ exclude_dims = ["chain", "draw", "sample"]
400
+ all_dims = list(self.inference_data.log_likelihood.dims)
401
+ reduce_dims = [dim for dim in all_dims if dim not in exclude_dims]
402
+ data = self.inference_data.observed_data
403
+ c_stat = -2 * (
404
+ self.log_likelihood
405
+ + (gammaln(data + 1) - (xr.where(data > 0, data * (np.log(data) - 1), 0))).sum(
406
+ dim=reduce_dims
407
+ )
408
+ )
409
+
410
+ return c_stat
411
+
339
412
  def plot_ppc(
340
413
  self,
341
- percentile: Tuple[int, int] = (14, 86),
414
+ percentile: tuple[int, int] = (16, 84),
342
415
  x_unit: str | u.Unit = "keV",
343
- y_type: Literal["counts", "countrate", "photon_flux", "photon_flux_density"] = "photon_flux_density",
416
+ y_type: Literal[
417
+ "counts", "countrate", "photon_flux", "photon_flux_density"
418
+ ] = "photon_flux_density",
344
419
  ) -> plt.Figure:
345
420
  r"""
346
421
  Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
@@ -349,12 +424,14 @@ class FitResult:
349
424
  $$ \text{Residual} = \frac{\text{Observed counts} - \text{Posterior counts}}
350
425
  {(\text{Posterior counts})_{84\%}-(\text{Posterior counts})_{16\%}} $$
351
426
 
352
- Parameters:
427
+ Parameters
428
+ ----------
353
429
  percentile: The percentile of the posterior predictive distribution to plot.
354
430
  x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
355
431
  y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
356
432
 
357
433
  Returns:
434
+ -------
358
435
  The matplotlib figure.
359
436
  """
360
437
 
@@ -383,21 +460,31 @@ class FitResult:
383
460
  # and enable weird broadcasting that makes the plot fail
384
461
 
385
462
  fig, axs = plt.subplots(
386
- 2, len(obsconf_container), figsize=(6 * len(obsconf_container), 6), sharex=True, height_ratios=[0.7, 0.3]
463
+ 2,
464
+ len(obsconf_container),
465
+ figsize=(6 * len(obsconf_container), 6),
466
+ sharex=True,
467
+ height_ratios=[0.7, 0.3],
387
468
  )
388
469
 
389
470
  plot_ylabels_once = True
390
471
 
391
472
  for name, obsconf, ax in zip(
392
- obsconf_container.keys(), obsconf_container.values(), axs.T if len(obsconf_container) > 1 else [axs]
473
+ obsconf_container.keys(),
474
+ obsconf_container.values(),
475
+ axs.T if len(obsconf_container) > 1 else [axs],
393
476
  ):
394
477
  legend_plots = []
395
478
  legend_labels = []
396
- count = az.extract(self.inference_data, var_names=f"obs_{name}", group="posterior_predictive").values.T
479
+ count = az.extract(
480
+ self.inference_data, var_names=f"obs_{name}", group="posterior_predictive"
481
+ ).values.T
397
482
  bkg_count = (
398
483
  None
399
484
  if self.background_model is None
400
- else az.extract(self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive").values.T
485
+ else az.extract(
486
+ self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive"
487
+ ).values.T
401
488
  )
402
489
 
403
490
  xbins = obsconf.out_energies * u.keV
@@ -413,18 +500,13 @@ class FitResult:
413
500
  integrated_arf = (
414
501
  trapezoid(interpolated_arf, x=e_grid, axis=0)
415
502
  / (
416
- np.abs(xbins[1] - xbins[0]) # Must fold in abs because some units reverse the ordering of the bins
503
+ np.abs(
504
+ xbins[1] - xbins[0]
505
+ ) # Must fold in abs because some units reverse the ordering of the bins
417
506
  )
418
507
  * u.cm**2
419
508
  )
420
509
 
421
- """
422
- if xbins[0][0] < 1 < xbins[1][-1]:
423
- xticks = [np.floor(xbins[0][0] * 10) / 10, 1.0, np.floor(xbins[1][-1])]
424
- else:
425
- xticks = [np.floor(xbins[0][0] * 10) / 10, np.floor(xbins[1][-1])]
426
- """
427
-
428
510
  match y_type:
429
511
  case "counts":
430
512
  denominator = 1
@@ -437,50 +519,93 @@ class FitResult:
437
519
 
438
520
  y_samples = (count * u.photon / denominator).to(y_units)
439
521
  y_observed = (obsconf.folded_counts.data * u.photon / denominator).to(y_units)
522
+ y_observed_low = (
523
+ nbinom.ppf(percentile[0] / 100, obsconf.folded_counts.data, 0.5)
524
+ * u.photon
525
+ / denominator
526
+ ).to(y_units)
527
+ y_observed_high = (
528
+ nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
529
+ * u.photon
530
+ / denominator
531
+ ).to(y_units)
440
532
 
441
533
  # Use the helper function to plot the data and posterior predictive
442
534
  legend_plots += _plot_binned_samples_with_error(
443
535
  ax[0],
444
536
  xbins.value,
445
537
  y_samples=y_samples.value,
446
- y_observed=y_observed.value,
447
538
  denominator=np.ones_like(y_observed).value,
448
539
  color=color,
449
540
  percentile=percentile,
450
541
  )
451
542
 
452
- legend_labels.append("Source + Background")
543
+ legend_labels.append("Model")
544
+
545
+ true_data_plot = ax[0].errorbar(
546
+ np.sqrt(xbins.value[0] * xbins.value[1]),
547
+ y_observed.value,
548
+ xerr=np.abs(xbins.value - np.sqrt(xbins.value[0] * xbins.value[1])),
549
+ yerr=[
550
+ y_observed.value - y_observed_low.value,
551
+ y_observed_high.value - y_observed.value,
552
+ ],
553
+ color="black",
554
+ linestyle="none",
555
+ alpha=0.3,
556
+ capsize=2,
557
+ )
558
+
559
+ legend_plots.append((true_data_plot,))
560
+ legend_labels.append("Observed")
453
561
 
454
562
  if self.background_model is not None:
455
563
  # We plot the background only if it is included in the fit, i.e. by subtracting
456
564
  ratio = obsconf.folded_backratio.data
457
565
  y_samples_bkg = (bkg_count * u.photon / (denominator * ratio)).to(y_units)
458
- y_observed_bkg = (obsconf.folded_background.data * u.photon / (denominator * ratio)).to(y_units)
566
+ y_observed_bkg = (
567
+ obsconf.folded_background.data * u.photon / (denominator * ratio)
568
+ ).to(y_units)
459
569
  legend_plots += _plot_binned_samples_with_error(
460
570
  ax[0],
461
571
  xbins.value,
462
572
  y_samples=y_samples_bkg.value,
463
- y_observed=y_observed_bkg.value,
464
573
  denominator=np.ones_like(y_observed).value,
465
574
  color=(0.26787604, 0.60085972, 0.63302651),
466
575
  percentile=percentile,
467
576
  )
468
577
 
469
- legend_labels.append("Background")
578
+ legend_labels.append("Model (bkg)")
579
+
580
+ residual_samples = (obsconf.folded_counts.data - count) / np.diff(
581
+ np.percentile(count, percentile, axis=0), axis=0
582
+ )
470
583
 
471
584
  residuals = np.percentile(
472
- (obsconf.folded_counts.data - count) / np.diff(np.percentile(count, percentile, axis=0), axis=0),
585
+ residual_samples,
473
586
  percentile,
474
587
  axis=0,
475
588
  )
476
589
 
477
- ax[1].fill_between(
478
- list(xbins.value[0]) + [xbins.value[1][-1]],
479
- list(residuals[0]) + [residuals[0][-1]],
480
- list(residuals[1]) + [residuals[1][-1]],
590
+ median_residuals = np.median(
591
+ residual_samples,
592
+ axis=0,
593
+ )
594
+
595
+ ax[1].stairs(
596
+ residuals[1],
597
+ edges=[*list(xbins.value[0]), xbins.value[1][-1]],
598
+ baseline=list(residuals[0]),
481
599
  alpha=0.3,
482
- step="post",
483
600
  facecolor=color,
601
+ fill=True,
602
+ )
603
+
604
+ ax[1].stairs(
605
+ median_residuals,
606
+ edges=[*list(xbins.value[0]), xbins.value[1][-1]],
607
+ color=color,
608
+ alpha=0.7,
484
609
  )
485
610
 
486
611
  max_residuals = np.max(np.abs(residuals))
@@ -502,7 +627,8 @@ class FitResult:
502
627
  ax[1].set_xlabel(f"Frequency \n[{x_unit:latex_inline}]")
503
628
  case _:
504
629
  RuntimeError(
505
- f"Unknown physical type for x_units: {x_unit}. " f"Must be 'length', 'energy' or 'frequency'"
630
+ f"Unknown physical type for x_units: {x_unit}. "
631
+ f"Must be 'length', 'energy' or 'frequency'"
506
632
  )
507
633
 
508
634
  ax[1].axhline(0, color=color, ls="--")
@@ -536,15 +662,18 @@ class FitResult:
536
662
 
537
663
  def plot_corner(
538
664
  self,
539
- config: PlotConfig = PlotConfig(usetex=False, summarise=False, label_font_size=6),
665
+ config: PlotConfig = PlotConfig(usetex=False, summarise=False, label_font_size=12),
540
666
  **kwargs: Any,
541
667
  ) -> plt.Figure:
542
668
  """
543
- Plot the corner plot of the posterior distribution of the parameters. This method uses the ChainConsumer.
669
+ Plot the corner plot of the posterior distribution of the parameters_type. This method uses the ChainConsumer.
544
670
 
545
- Parameters:
671
+ Parameters
672
+ ----------
546
673
  config: The configuration of the plot.
547
- **kwargs: Additional arguments passed to ChainConsumer.plotter.plot.
674
+ parameters: The parameters to include in the plot using the following format: `blackbody_1_kT`.
675
+ **kwargs: Additional arguments passed to ChainConsumer.plotter.plot. Some useful parameters are :
676
+ - columns : list of parameters to plot.
548
677
  """
549
678
 
550
679
  consumer = ChainConsumer()
jaxspec/data/__init__.py CHANGED
@@ -1,9 +1,9 @@
1
- # precommit is suppressing these imports
2
- from .obsconf import ObsConfiguration # noqa: F401
3
- from .instrument import Instrument # noqa: F401
4
- from .observation import Observation # noqa: F401
5
1
  import astropy.units as u
6
2
 
3
+ from .instrument import Instrument
4
+ from .obsconf import ObsConfiguration
5
+ from .observation import Observation
6
+
7
7
  u.add_enabled_aliases({"counts": u.count})
8
8
  u.add_enabled_aliases({"channel": u.dimensionless_unscaled})
9
9
  # Arbitrary units are found in .rsp files , let's hope it is compatible with what we would expect as the rmf x arf