jaxspec 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,20 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Mapping
4
+ from typing import TYPE_CHECKING, Any, Literal, TypeVar
5
+
1
6
  import arviz as az
7
+ import astropy.units as u
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import matplotlib.pyplot as plt
2
11
  import numpy as np
3
12
  import xarray as xr
4
- import matplotlib.pyplot as plt
5
- from ..data import ObsConfiguration
6
- from ..model.abc import SpectralModel
7
- from ..model.background import BackgroundModel
8
- from collections.abc import Mapping
9
- from typing import TypeVar, Tuple, Literal, Any
13
+
10
14
  from astropy.cosmology import Cosmology, Planck18
11
- import astropy.units as u
12
15
  from astropy.units import Unit
16
+ from chainconsumer import Chain, ChainConsumer, PlotConfig
13
17
  from haiku.data_structures import traverse
14
- from chainconsumer import Chain, PlotConfig, ChainConsumer
15
- import jax
16
18
  from jax.typing import ArrayLike
19
+ from numpyro.handlers import seed
17
20
  from scipy.integrate import trapezoid
21
+ from scipy.special import gammaln
22
+ from scipy.stats import nbinom
23
+
24
+ if TYPE_CHECKING:
25
+ from ..fit import BayesianModel
26
+ from ..model.background import BackgroundModel
18
27
 
19
28
  K = TypeVar("K")
20
29
  V = TypeVar("V")
@@ -29,7 +38,6 @@ def _plot_binned_samples_with_error(
29
38
  x_bins: ArrayLike,
30
39
  denominator: ArrayLike | None = None,
31
40
  y_samples: ArrayLike | None = None,
32
- y_observed: ArrayLike | None = None,
33
41
  color=(0.15, 0.25, 0.45),
34
42
  percentile: tuple = (16, 84),
35
43
  ):
@@ -38,7 +46,8 @@ def _plot_binned_samples_with_error(
38
46
  computes the percentiles of the posterior predictive distribution and plot them as a shaded
39
47
  area. If the observed data is provided, it is also plotted as a step function.
40
48
 
41
- Parameters:
49
+ Parameters
50
+ ----------
42
51
  x_bins: The bin edges of the data (2 x N).
43
52
  y_samples: The samples of the posterior predictive distribution (Samples X N).
44
53
  denominator: Values used to divided the samples, i.e. to get energy flux (N).
@@ -51,22 +60,15 @@ def _plot_binned_samples_with_error(
51
60
 
52
61
  mean, envelope = None, None
53
62
 
54
- if x_bins is None:
55
- raise ValueError("x_bins cannot be None.")
63
+ if denominator is None:
64
+ denominator = np.ones_like(x_bins[0])
56
65
 
57
- if (y_samples is None) and (y_observed is None):
58
- raise ValueError("Either a y_samples or y_observed must be provided.")
59
-
60
- if y_observed is not None:
61
- if denominator is None:
62
- denominator = np.ones_like(x_bins[0])
63
-
64
- (mean,) = ax.step(
65
- list(x_bins[0]) + [x_bins[1][-1]], # x_bins[1][-1]+1],
66
- list(y_observed / denominator) + [np.nan], # + [np.nan, np.nan],
67
- where="pre",
68
- c=color,
69
- )
66
+ mean = ax.stairs(
67
+ list(np.median(y_samples, axis=0) / denominator),
68
+ edges=[*list(x_bins[0]), x_bins[1][-1]],
69
+ color=color,
70
+ alpha=0.7,
71
+ )
70
72
 
71
73
  if y_samples is not None:
72
74
  if denominator is None:
@@ -77,63 +79,35 @@ def _plot_binned_samples_with_error(
77
79
  # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
78
80
  (envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
79
81
 
80
- ax.fill_between(
81
- list(x_bins[0]) + [x_bins[1][-1]], # + [x_bins[1][-1], x_bins[1][-1] + 1],
82
- list(percentiles[0] / denominator) + [np.nan], # + [np.nan, np.nan],
83
- list(percentiles[1] / denominator) + [np.nan], # + [np.nan, np.nan],
82
+ ax.stairs(
83
+ percentiles[1] / denominator,
84
+ edges=[*list(x_bins[0]), x_bins[1][-1]],
85
+ baseline=percentiles[0] / denominator,
84
86
  alpha=0.3,
85
- step="pre",
86
- facecolor=color,
87
+ fill=True,
88
+ color=color,
87
89
  )
88
90
 
89
91
  return [(mean, envelope)]
90
92
 
91
93
 
92
- def format_parameters(parameter_name):
93
- computed_parameters = ["Photon flux", "Energy flux", "Luminosity"]
94
-
95
- if parameter_name == "weight":
96
- # ChainConsumer add a weight column to the samples
97
- return parameter_name
98
-
99
- for parameter in computed_parameters:
100
- if parameter in parameter_name:
101
- return parameter_name
102
-
103
- # Find second occurrence of the character '_'
104
- first_occurrence = parameter_name.find("_")
105
- second_occurrence = parameter_name.find("_", first_occurrence + 1)
106
- module = parameter_name[:second_occurrence]
107
- parameter = parameter_name[second_occurrence + 1 :]
108
-
109
- name, number = module.split("_")
110
- module = rf"[{name.capitalize()} ({number})]"
111
-
112
- if parameter == "norm":
113
- return r"Norm " + module
114
-
115
- else:
116
- return rf"${parameter}$" + module
117
-
118
-
119
94
  class FitResult:
120
95
  """
121
- This class is the container for the result of a fit using any ModelFitter class.
96
+ Container for the result of a fit using any ModelFitter class.
122
97
  """
123
98
 
124
99
  # TODO : Add type hints
125
100
  def __init__(
126
101
  self,
127
- model: SpectralModel,
128
- obsconf: ObsConfiguration | dict[str, ObsConfiguration],
102
+ bayesian_fitter: BayesianModel,
129
103
  inference_data: az.InferenceData,
130
104
  structure: Mapping[K, V],
131
105
  background_model: BackgroundModel = None,
132
106
  ):
133
- self.model = model
134
- self._structure = structure
107
+ self.model = bayesian_fitter.model
108
+ self.bayesian_fitter = bayesian_fitter
135
109
  self.inference_data = inference_data
136
- self.obsconfs = {"Observation": obsconf} if isinstance(obsconf, ObsConfiguration) else obsconf
110
+ self.obsconfs = bayesian_fitter.observation_container
137
111
  self.background_model = background_model
138
112
  self._structure = structure
139
113
 
@@ -141,22 +115,94 @@ class FitResult:
141
115
  for group in self.inference_data.groups():
142
116
  group_name = group.split("/")[-1]
143
117
  metadata = getattr(self.inference_data, group_name).attrs
144
- metadata["model"] = str(model)
118
+ metadata["model"] = str(self.model)
145
119
  # TODO : Store metadata about observations used in the fitting process
146
120
 
147
121
  @property
148
122
  def converged(self) -> bool:
149
- """
123
+ r"""
150
124
  Convergence of the chain as computed by the $\hat{R}$ statistic.
151
125
  """
152
126
 
153
127
  return all(az.rhat(self.inference_data) < 1.01)
154
128
 
129
+ @property
130
+ def _structured_samples(self):
131
+ """
132
+ Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
133
+ """
134
+
135
+ samples_flat = self._structured_samples_flat
136
+
137
+ samples_haiku = {}
138
+
139
+ for module, parameter, value in traverse(self._structure):
140
+ if samples_haiku.get(module, None) is None:
141
+ samples_haiku[module] = {}
142
+ samples_haiku[module][parameter] = samples_flat[f"{module}_{parameter}"]
143
+
144
+ return samples_haiku
145
+
146
+ @property
147
+ def _structured_samples_flat(self):
148
+ """
149
+ Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
150
+ """
151
+
152
+ var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
153
+ posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
154
+ samples_flat = {key: posterior[key].data for key in var_names}
155
+
156
+ return samples_flat
157
+
158
+ @property
159
+ def input_parameters(self) -> HaikuDict[ArrayLike]:
160
+ """
161
+ The input parameters of the model.
162
+ """
163
+
164
+ posterior = az.extract(self.inference_data, combined=False)
165
+
166
+ samples_shape = (len(posterior.coords["chain"]), len(posterior.coords["draw"]))
167
+
168
+ total_shape = tuple(posterior.sizes[d] for d in posterior.coords)
169
+
170
+ posterior = {key: posterior[key].data for key in posterior.data_vars}
171
+
172
+ with seed(rng_seed=0):
173
+ input_parameters = self.bayesian_fitter.prior_distributions_func()
174
+
175
+ for module, parameter, value in traverse(input_parameters):
176
+ if f"{module}_{parameter}" in posterior.keys():
177
+ # We add as extra dimension as there might be different values per observation
178
+ if posterior[f"{module}_{parameter}"].shape == samples_shape:
179
+ to_set = posterior[f"{module}_{parameter}"][..., None]
180
+ else:
181
+ to_set = posterior[f"{module}_{parameter}"]
182
+
183
+ input_parameters[module][parameter] = to_set
184
+
185
+ else:
186
+ # The parameter is fixed in this case, so we just broadcast is over chain and draws
187
+ input_parameters[module][parameter] = value[None, None, ...]
188
+
189
+ if len(total_shape) < len(input_parameters[module][parameter].shape):
190
+ # If there are only chains and draws, we reduce
191
+ input_parameters[module][parameter] = input_parameters[module][parameter][..., 0]
192
+
193
+ else:
194
+ input_parameters[module][parameter] = jnp.broadcast_to(
195
+ input_parameters[module][parameter], total_shape
196
+ )
197
+
198
+ return input_parameters
199
+
155
200
  def photon_flux(
156
201
  self,
157
202
  e_min: float,
158
203
  e_max: float,
159
204
  unit: Unit = u.photon / u.cm**2 / u.s,
205
+ register: bool = False,
160
206
  ) -> ArrayLike:
161
207
  """
162
208
  Compute the unfolded photon flux in a given energy band. The flux is then added to
@@ -166,19 +212,31 @@ class FitResult:
166
212
  e_min: The lower bound of the energy band in observer frame.
167
213
  e_max: The upper bound of the energy band in observer frame.
168
214
  unit: The unit of the photon flux.
215
+ register: Whether to register the flux with the other posterior parameters.
169
216
 
170
217
  !!! warning
171
218
  Computation of the folded flux is not implemented yet. Feel free to open an
172
219
  [issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
173
220
  """
174
221
 
175
- flux = jax.vmap(lambda p: self.model.photon_flux(p, np.asarray([e_min]), np.asarray([e_max])))(self.params)
222
+ @jax.jit
223
+ @jnp.vectorize
224
+ def vectorized_flux(*pars):
225
+ parameters_pytree = jax.tree.unflatten(pytree_def, pars)
226
+ return self.model.photon_flux(
227
+ parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=100
228
+ )[0]
176
229
 
230
+ flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
231
+ flux = vectorized_flux(*flat_tree)
177
232
  conversion_factor = (u.photon / u.cm**2 / u.s).to(unit)
178
-
179
233
  value = flux * conversion_factor
180
- # TODO : fix this since sample doesn't exist anymore
181
- self.samples[rf"Photon flux ({e_min:.1f}-{e_max:.1f} keV)"] = value
234
+
235
+ if register:
236
+ self.inference_data.posterior[f"photon_flux_{e_min:.1f}_{e_max:.1f}"] = (
237
+ list(self.inference_data.posterior.coords),
238
+ value,
239
+ )
182
240
 
183
241
  return value
184
242
 
@@ -187,6 +245,7 @@ class FitResult:
187
245
  e_min: float,
188
246
  e_max: float,
189
247
  unit: Unit = u.erg / u.cm**2 / u.s,
248
+ register: bool = False,
190
249
  ) -> ArrayLike:
191
250
  """
192
251
  Compute the unfolded energy flux in a given energy band. The flux is then added to
@@ -196,20 +255,31 @@ class FitResult:
196
255
  e_min: The lower bound of the energy band in observer frame.
197
256
  e_max: The upper bound of the energy band in observer frame.
198
257
  unit: The unit of the energy flux.
258
+ register: Whether to register the flux with the other posterior parameters.
199
259
 
200
260
  !!! warning
201
261
  Computation of the folded flux is not implemented yet. Feel free to open an
202
262
  [issue](https://github.com/renecotyfanboy/jaxspec/issues) in the GitHub repository.
203
263
  """
204
264
 
205
- flux = jax.vmap(lambda p: self.model.energy_flux(p, np.asarray([e_min]), np.asarray([e_max])))(self.params)
265
+ @jax.jit
266
+ @jnp.vectorize
267
+ def vectorized_flux(*pars):
268
+ parameters_pytree = jax.tree.unflatten(pytree_def, pars)
269
+ return self.model.energy_flux(
270
+ parameters_pytree, jnp.asarray([e_min]), jnp.asarray([e_max]), n_points=100
271
+ )[0]
206
272
 
273
+ flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
274
+ flux = vectorized_flux(*flat_tree)
207
275
  conversion_factor = (u.keV / u.cm**2 / u.s).to(unit)
208
-
209
276
  value = flux * conversion_factor
210
277
 
211
- # TODO : fix this since sample doesn't exist anymore
212
- self.samples[rf"Energy flux ({e_min:.1f}-{e_max:.1f} keV)"] = value
278
+ if register:
279
+ self.inference_data.posterior[f"energy_flux_{e_min:.1f}_{e_max:.1f}"] = (
280
+ list(self.inference_data.posterior.coords),
281
+ value,
282
+ )
213
283
 
214
284
  return value
215
285
 
@@ -217,10 +287,11 @@ class FitResult:
217
287
  self,
218
288
  e_min: float,
219
289
  e_max: float,
220
- redshift: float | ArrayLike = 0,
290
+ redshift: float | ArrayLike = 0.1,
221
291
  observer_frame: bool = True,
222
292
  cosmology: Cosmology = Planck18,
223
293
  unit: Unit = u.erg / u.s,
294
+ register: bool = False,
224
295
  ) -> ArrayLike:
225
296
  """
226
297
  Compute the luminosity of the source specifying its redshift. The luminosity is then added to
@@ -233,41 +304,65 @@ class FitResult:
233
304
  observer_frame: Whether the input bands are defined in observer frame or not.
234
305
  cosmology: Chosen cosmology.
235
306
  unit: The unit of the luminosity.
307
+ register: Whether to register the flux with the other posterior parameters.
236
308
  """
237
309
 
238
310
  if not observer_frame:
239
311
  raise NotImplementedError()
240
312
 
241
- flux = self.energy_flux(e_min * (1 + redshift), e_max * (1 + redshift)) * (u.erg / u.cm**2 / u.s)
242
-
313
+ @jax.jit
314
+ @jnp.vectorize
315
+ def vectorized_flux(*pars):
316
+ parameters_pytree = jax.tree.unflatten(pytree_def, pars)
317
+ return self.model.energy_flux(
318
+ parameters_pytree,
319
+ jnp.asarray([e_min]) * (1 + redshift),
320
+ jnp.asarray([e_max]) * (1 + redshift),
321
+ n_points=100,
322
+ )[0]
323
+
324
+ flat_tree, pytree_def = jax.tree.flatten(self.input_parameters)
325
+ flux = vectorized_flux(*flat_tree) * (u.keV / u.cm**2 / u.s)
243
326
  value = (flux * (4 * np.pi * cosmology.luminosity_distance(redshift) ** 2)).to(unit)
244
327
 
245
- # TODO : fix this since sample doesn't exist anymore
246
- self.samples[rf"Luminosity ({e_min:.1f}-{e_max:.1f} keV)"] = value
328
+ if register:
329
+ self.inference_data.posterior[f"luminosity_{e_min:.1f}_{e_max:.1f}"] = (
330
+ list(self.inference_data.posterior.coords),
331
+ value,
332
+ )
247
333
 
248
334
  return value
249
335
 
250
- def to_chain(self, name: str, parameters: Literal["model", "bkg"] = "model") -> Chain:
336
+ def to_chain(self, name: str, parameters_type: Literal["model", "bkg"] = "model") -> Chain:
251
337
  """
252
- Return a ChainConsumer Chain object from the posterior distribution of the parameters.
338
+ Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.
253
339
 
254
340
  Parameters:
255
341
  name: The name of the chain.
256
- parameters: The parameters to include in the chain.
342
+ parameters_type: The parameters_type to include in the chain.
257
343
  """
258
344
 
259
345
  obs_id = self.inference_data.copy()
260
346
 
261
- if parameters == "model":
262
- keys_to_drop = [key for key in obs_id.posterior.keys() if (key.startswith("_") or key.startswith("bkg"))]
263
- elif parameters == "bkg":
347
+ if parameters_type == "model":
348
+ keys_to_drop = [
349
+ key
350
+ for key in obs_id.posterior.keys()
351
+ if (key.startswith("_") or key.startswith("bkg"))
352
+ ]
353
+ elif parameters_type == "bkg":
264
354
  keys_to_drop = [key for key in obs_id.posterior.keys() if not key.startswith("bkg")]
265
355
  else:
266
- raise ValueError(f"Unknown value for parameters: {parameters}")
356
+ raise ValueError(f"Unknown value for parameters_type: {parameters_type}")
267
357
 
268
358
  obs_id.posterior = obs_id.posterior.drop_vars(keys_to_drop)
269
359
  chain = Chain.from_arviz(obs_id, name)
270
- chain.samples.columns = [format_parameters(parameter) for parameter in chain.samples.columns]
360
+
361
+ """
362
+ chain.samples.columns = [
363
+ format_parameters(parameter) for parameter in chain.samples.columns
364
+ ]
365
+ """
271
366
 
272
367
  return chain
273
368
 
@@ -304,7 +399,7 @@ class FitResult:
304
399
  for module, parameter, value in traverse(self._structure):
305
400
  if params.get(module, None) is None:
306
401
  params[module] = {}
307
- params[module][parameter] = self.samples[f"{module}_{parameter}"]
402
+ params[module][parameter] = self.samples_flat[f"{module}_{parameter}"]
308
403
 
309
404
  return params
310
405
 
@@ -328,19 +423,50 @@ class FitResult:
328
423
  return {key: posterior[key].data for key in var_names}
329
424
 
330
425
  @property
331
- def likelihood(self) -> xr.Dataset:
426
+ def log_likelihood(self) -> xr.Dataset:
332
427
  """
333
- Return the likelihood of each observation
428
+ Return the log_likelihood of each observation
334
429
  """
335
430
  log_likelihood = az.extract(self.inference_data, group="log_likelihood")
336
- dimensions_to_reduce = [coord for coord in log_likelihood.coords if coord not in ["sample", "draw", "chain"]]
431
+ dimensions_to_reduce = [
432
+ coord for coord in log_likelihood.coords if coord not in ["sample", "draw", "chain"]
433
+ ]
337
434
  return log_likelihood.sum(dimensions_to_reduce)
338
435
 
436
+ @property
437
+ def c_stat(self):
438
+ r"""
439
+ Return the C-statistic of the model
440
+
441
+ The C-statistic is defined as:
442
+
443
+ $$ C = 2 \sum_{i} M - D*log(M) + D*log(D) - D $$
444
+ or
445
+ $$ C = 2 \sum_{i} M - D*log(M)$$
446
+ for bins with no counts
447
+
448
+ """
449
+
450
+ exclude_dims = ["chain", "draw", "sample"]
451
+ all_dims = list(self.inference_data.log_likelihood.dims)
452
+ reduce_dims = [dim for dim in all_dims if dim not in exclude_dims]
453
+ data = self.inference_data.observed_data
454
+ c_stat = -2 * (
455
+ self.log_likelihood
456
+ + (gammaln(data + 1) - (xr.where(data > 0, data * (np.log(data) - 1), 0))).sum(
457
+ dim=reduce_dims
458
+ )
459
+ )
460
+
461
+ return c_stat
462
+
339
463
  def plot_ppc(
340
464
  self,
341
- percentile: Tuple[int, int] = (14, 86),
465
+ percentile: tuple[int, int] = (16, 84),
342
466
  x_unit: str | u.Unit = "keV",
343
- y_type: Literal["counts", "countrate", "photon_flux", "photon_flux_density"] = "photon_flux_density",
467
+ y_type: Literal[
468
+ "counts", "countrate", "photon_flux", "photon_flux_density"
469
+ ] = "photon_flux_density",
344
470
  ) -> plt.Figure:
345
471
  r"""
346
472
  Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
@@ -383,21 +509,31 @@ class FitResult:
383
509
  # and enable weird broadcasting that makes the plot fail
384
510
 
385
511
  fig, axs = plt.subplots(
386
- 2, len(obsconf_container), figsize=(6 * len(obsconf_container), 6), sharex=True, height_ratios=[0.7, 0.3]
512
+ 2,
513
+ len(obsconf_container),
514
+ figsize=(6 * len(obsconf_container), 6),
515
+ sharex=True,
516
+ height_ratios=[0.7, 0.3],
387
517
  )
388
518
 
389
519
  plot_ylabels_once = True
390
520
 
391
521
  for name, obsconf, ax in zip(
392
- obsconf_container.keys(), obsconf_container.values(), axs.T if len(obsconf_container) > 1 else [axs]
522
+ obsconf_container.keys(),
523
+ obsconf_container.values(),
524
+ axs.T if len(obsconf_container) > 1 else [axs],
393
525
  ):
394
526
  legend_plots = []
395
527
  legend_labels = []
396
- count = az.extract(self.inference_data, var_names=f"obs_{name}", group="posterior_predictive").values.T
528
+ count = az.extract(
529
+ self.inference_data, var_names=f"obs_{name}", group="posterior_predictive"
530
+ ).values.T
397
531
  bkg_count = (
398
532
  None
399
533
  if self.background_model is None
400
- else az.extract(self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive").values.T
534
+ else az.extract(
535
+ self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive"
536
+ ).values.T
401
537
  )
402
538
 
403
539
  xbins = obsconf.out_energies * u.keV
@@ -413,18 +549,13 @@ class FitResult:
413
549
  integrated_arf = (
414
550
  trapezoid(interpolated_arf, x=e_grid, axis=0)
415
551
  / (
416
- np.abs(xbins[1] - xbins[0]) # Must fold in abs because some units reverse the ordering of the bins
552
+ np.abs(
553
+ xbins[1] - xbins[0]
554
+ ) # Must fold in abs because some units reverse the ordering of the bins
417
555
  )
418
556
  * u.cm**2
419
557
  )
420
558
 
421
- """
422
- if xbins[0][0] < 1 < xbins[1][-1]:
423
- xticks = [np.floor(xbins[0][0] * 10) / 10, 1.0, np.floor(xbins[1][-1])]
424
- else:
425
- xticks = [np.floor(xbins[0][0] * 10) / 10, np.floor(xbins[1][-1])]
426
- """
427
-
428
559
  match y_type:
429
560
  case "counts":
430
561
  denominator = 1
@@ -437,50 +568,93 @@ class FitResult:
437
568
 
438
569
  y_samples = (count * u.photon / denominator).to(y_units)
439
570
  y_observed = (obsconf.folded_counts.data * u.photon / denominator).to(y_units)
571
+ y_observed_low = (
572
+ nbinom.ppf(percentile[0] / 100, obsconf.folded_counts.data, 0.5)
573
+ * u.photon
574
+ / denominator
575
+ ).to(y_units)
576
+ y_observed_high = (
577
+ nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
578
+ * u.photon
579
+ / denominator
580
+ ).to(y_units)
440
581
 
441
582
  # Use the helper function to plot the data and posterior predictive
442
583
  legend_plots += _plot_binned_samples_with_error(
443
584
  ax[0],
444
585
  xbins.value,
445
586
  y_samples=y_samples.value,
446
- y_observed=y_observed.value,
447
587
  denominator=np.ones_like(y_observed).value,
448
588
  color=color,
449
589
  percentile=percentile,
450
590
  )
451
591
 
452
- legend_labels.append("Source + Background")
592
+ legend_labels.append("Model")
593
+
594
+ true_data_plot = ax[0].errorbar(
595
+ np.sqrt(xbins.value[0] * xbins.value[1]),
596
+ y_observed.value,
597
+ xerr=np.abs(xbins.value - np.sqrt(xbins.value[0] * xbins.value[1])),
598
+ yerr=[
599
+ y_observed.value - y_observed_low.value,
600
+ y_observed_high.value - y_observed.value,
601
+ ],
602
+ color="black",
603
+ linestyle="none",
604
+ alpha=0.3,
605
+ capsize=2,
606
+ )
607
+
608
+ legend_plots.append((true_data_plot,))
609
+ legend_labels.append("Observed")
453
610
 
454
611
  if self.background_model is not None:
455
612
  # We plot the background only if it is included in the fit, i.e. by subtracting
456
613
  ratio = obsconf.folded_backratio.data
457
614
  y_samples_bkg = (bkg_count * u.photon / (denominator * ratio)).to(y_units)
458
- y_observed_bkg = (obsconf.folded_background.data * u.photon / (denominator * ratio)).to(y_units)
615
+ y_observed_bkg = (
616
+ obsconf.folded_background.data * u.photon / (denominator * ratio)
617
+ ).to(y_units)
459
618
  legend_plots += _plot_binned_samples_with_error(
460
619
  ax[0],
461
620
  xbins.value,
462
621
  y_samples=y_samples_bkg.value,
463
- y_observed=y_observed_bkg.value,
464
622
  denominator=np.ones_like(y_observed).value,
465
623
  color=(0.26787604, 0.60085972, 0.63302651),
466
624
  percentile=percentile,
467
625
  )
468
626
 
469
- legend_labels.append("Background")
627
+ legend_labels.append("Model (bkg)")
628
+
629
+ residual_samples = (obsconf.folded_counts.data - count) / np.diff(
630
+ np.percentile(count, percentile, axis=0), axis=0
631
+ )
470
632
 
471
633
  residuals = np.percentile(
472
- (obsconf.folded_counts.data - count) / np.diff(np.percentile(count, percentile, axis=0), axis=0),
634
+ residual_samples,
473
635
  percentile,
474
636
  axis=0,
475
637
  )
476
638
 
477
- ax[1].fill_between(
478
- list(xbins.value[0]) + [xbins.value[1][-1]],
479
- list(residuals[0]) + [residuals[0][-1]],
480
- list(residuals[1]) + [residuals[1][-1]],
639
+ median_residuals = np.median(
640
+ residual_samples,
641
+ axis=0,
642
+ )
643
+
644
+ ax[1].stairs(
645
+ residuals[1],
646
+ edges=[*list(xbins.value[0]), xbins.value[1][-1]],
647
+ baseline=list(residuals[0]),
481
648
  alpha=0.3,
482
- step="post",
483
649
  facecolor=color,
650
+ fill=True,
651
+ )
652
+
653
+ ax[1].stairs(
654
+ median_residuals,
655
+ edges=[*list(xbins.value[0]), xbins.value[1][-1]],
656
+ color=color,
657
+ alpha=0.7,
484
658
  )
485
659
 
486
660
  max_residuals = np.max(np.abs(residuals))
@@ -502,7 +676,8 @@ class FitResult:
502
676
  ax[1].set_xlabel(f"Frequency \n[{x_unit:latex_inline}]")
503
677
  case _:
504
678
  RuntimeError(
505
- f"Unknown physical type for x_units: {x_unit}. " f"Must be 'length', 'energy' or 'frequency'"
679
+ f"Unknown physical type for x_units: {x_unit}. "
680
+ f"Must be 'length', 'energy' or 'frequency'"
506
681
  )
507
682
 
508
683
  ax[1].axhline(0, color=color, ls="--")
@@ -536,15 +711,16 @@ class FitResult:
536
711
 
537
712
  def plot_corner(
538
713
  self,
539
- config: PlotConfig = PlotConfig(usetex=False, summarise=False, label_font_size=6),
714
+ config: PlotConfig = PlotConfig(usetex=False, summarise=False, label_font_size=12),
540
715
  **kwargs: Any,
541
716
  ) -> plt.Figure:
542
717
  """
543
- Plot the corner plot of the posterior distribution of the parameters. This method uses the ChainConsumer.
718
+ Plot the corner plot of the posterior distribution of the parameters_type. This method uses the ChainConsumer.
544
719
 
545
720
  Parameters:
546
721
  config: The configuration of the plot.
547
- **kwargs: Additional arguments passed to ChainConsumer.plotter.plot.
722
+ **kwargs: Additional arguments passed to ChainConsumer.plotter.plot. Some useful parameters are :
723
+ - columns : list of parameters to plot.
548
724
  """
549
725
 
550
726
  consumer = ChainConsumer()