jaxspec 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from collections.abc import Mapping
+ from functools import cached_property
  from typing import TYPE_CHECKING, Any, Literal, TypeVar

  import arviz as az
@@ -10,17 +10,28 @@ import jax
  import jax.numpy as jnp
  import matplotlib.pyplot as plt
  import numpy as np
+ import pandas as pd
  import xarray as xr

  from astropy.cosmology import Cosmology, Planck18
  from astropy.units import Unit
  from chainconsumer import Chain, ChainConsumer, PlotConfig
- from haiku.data_structures import traverse
+ from jax.experimental.sparse import BCOO
  from jax.typing import ArrayLike
  from numpyro.handlers import seed
- from scipy.integrate import trapezoid
  from scipy.special import gammaln
- from scipy.stats import nbinom
+
+ from ._plot import (
+     BACKGROUND_COLOR,
+     BACKGROUND_DATA_COLOR,
+     COLOR_CYCLE,
+     SPECTRUM_COLOR,
+     SPECTRUM_DATA_COLOR,
+     _compute_effective_area,
+     _error_bars_for_observed_data,
+     _plot_binned_samples_with_error,
+     _plot_poisson_data_with_error,
+ )

  if TYPE_CHECKING:
      from ..fit import BayesianModel
@@ -31,67 +42,6 @@ V = TypeVar("V")
  T = TypeVar("T")


- class HaikuDict(dict[str, dict[str, T]]): ...
-
-
- def _plot_binned_samples_with_error(
-     ax: plt.Axes,
-     x_bins: ArrayLike,
-     denominator: ArrayLike | None = None,
-     y_samples: ArrayLike | None = None,
-     color=(0.15, 0.25, 0.45),
-     percentile: tuple = (16, 84),
- ):
-     """
-     Helper function to plot the posterior predictive distribution of the model. The function
-     computes the percentiles of the posterior predictive distribution and plot them as a shaded
-     area. If the observed data is provided, it is also plotted as a step function.
-
-     Parameters
-     ----------
-     x_bins: The bin edges of the data (2 x N).
-     y_samples: The samples of the posterior predictive distribution (Samples X N).
-     denominator: Values used to divided the samples, i.e. to get energy flux (N).
-     ax: The matplotlib axes object.
-     color: The color of the posterior predictive distribution.
-     y_observed: The observed data (N).
-     label: The label of the observed data.
-     percentile: The percentile of the posterior predictive distribution to plot.
-     """
-
-     mean, envelope = None, None
-
-     if denominator is None:
-         denominator = np.ones_like(x_bins[0])
-
-     mean = ax.stairs(
-         list(np.median(y_samples, axis=0) / denominator),
-         edges=[*list(x_bins[0]), x_bins[1][-1]],
-         color=color,
-         alpha=0.7,
-     )
-
-     if y_samples is not None:
-         if denominator is None:
-             denominator = np.ones_like(x_bins[0])
-
-         percentiles = np.percentile(y_samples, percentile, axis=0)
-
-         # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
-         (envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
-
-         ax.stairs(
-             percentiles[1] / denominator,
-             edges=[*list(x_bins[0]), x_bins[1][-1]],
-             baseline=percentiles[0] / denominator,
-             alpha=0.3,
-             fill=True,
-             color=color,
-         )
-
-     return [(mean, envelope)]
-
-
  class FitResult:
      """
      Container for the result of a fit using any ModelFitter class.
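
The helper removed above now lives in the private `._plot` module imported at the top of the file; the underlying technique is unchanged and worth keeping in mind when reading the plotting hunks below: draw the median of the binned samples with `ax.stairs`, then draw a filled `stairs` band between two percentiles as the envelope. A minimal self-contained sketch of that idea (bin edges and data are made up for illustration):

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
edges = np.geomspace(0.5, 10.0, 21)        # 21 edges -> 20 bins
samples = rng.poisson(50, size=(500, 20))  # (n_samples, n_bins)

lo, med, hi = np.percentile(samples, [16, 50, 84], axis=0)

fig, ax = plt.subplots()
ax.stairs(med, edges, color=(0.15, 0.25, 0.45), alpha=0.7)  # median spectrum
ax.stairs(hi, edges, baseline=lo, fill=True, alpha=0.3,
          color=(0.15, 0.25, 0.45))                          # 16th-84th percentile band
ax.set_xscale("log")
plt.show()
```
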
@@ -102,7 +52,6 @@ class FitResult:
          self,
          bayesian_fitter: BayesianModel,
          inference_data: az.InferenceData,
-         structure: Mapping[K, V],
          background_model: BackgroundModel = None,
      ):
          self.model = bayesian_fitter.model
@@ -110,7 +59,6 @@ class FitResult:
          self.inference_data = inference_data
          self.obsconfs = bayesian_fitter.observation_container
          self.background_model = background_model
-         self._structure = structure

          # Add the model used in fit to the metadata
          for group in self.inference_data.groups():
@@ -127,37 +75,38 @@ class FitResult:

          return all(az.rhat(self.inference_data) < 1.01)

-     @property
-     def _structured_samples(self):
-         """
-         Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
-         """
-
-         samples_flat = self._structured_samples_flat
+     def _ppc_folded_branches(self, obs_id):
+         obs = self.obsconfs[obs_id]

-         samples_haiku = {}
+         if len(next(iter(self.input_parameters.values())).shape) > 2:
+             idx = list(self.obsconfs.keys()).index(obs_id)
+             obs_parameters = jax.tree.map(lambda x: x[..., idx], self.input_parameters)

-         for module, parameter, value in traverse(self._structure):
-             if samples_haiku.get(module, None) is None:
-                 samples_haiku[module] = {}
-             samples_haiku[module][parameter] = samples_flat[f"{module}_{parameter}"]
+         else:
+             obs_parameters = self.input_parameters

-         return samples_haiku
+         if self.bayesian_fitter.sparse:
+             transfer_matrix = BCOO.from_scipy_sparse(
+                 obs.transfer_matrix.data.to_scipy_sparse().tocsr()
+             )

-     @property
-     def _structured_samples_flat(self):
-         """
-         Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
-         """
+         else:
+             transfer_matrix = np.asarray(obs.transfer_matrix.data.todense())

-         var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
-         posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
-         samples_flat = {key: posterior[key].data for key in var_names}
+         energies = obs.in_energies

-         return samples_flat
+         flux_func = jax.jit(
+             jax.vmap(jax.vmap(lambda p: self.model.photon_flux(p, *energies, split_branches=True)))
+         )
+         convolve_func = jax.jit(
+             jax.vmap(jax.vmap(lambda flux: jnp.clip(transfer_matrix @ flux, a_min=1e-6)))
+         )
+         return jax.tree.map(
+             lambda flux: np.random.poisson(convolve_func(flux)), flux_func(obs_parameters)
+         )

-     @property
-     def input_parameters(self) -> HaikuDict[ArrayLike]:
+     @cached_property
+     def input_parameters(self) -> dict[str, ArrayLike]:
          """
          The input parameters of the model.
          """
@@ -173,7 +122,9 @@ class FitResult:
          with seed(rng_seed=0):
              input_parameters = self.bayesian_fitter.prior_distributions_func()

-         for module, parameter, value in traverse(input_parameters):
+         for key, value in input_parameters.items():
+             module, parameter = key.rsplit("_", 1)
+
              if f"{module}_{parameter}" in posterior.keys():
                  # We add as extra dimension as there might be different values per observation
                  if posterior[f"{module}_{parameter}"].shape == samples_shape:
@@ -181,19 +132,21 @@ class FitResult:
                  else:
                      to_set = posterior[f"{module}_{parameter}"]

-                 input_parameters[module][parameter] = to_set
+                 input_parameters[f"{module}_{parameter}"] = to_set

              else:
                  # The parameter is fixed in this case, so we just broadcast is over chain and draws
-                 input_parameters[module][parameter] = value[None, None, ...]
+                 input_parameters[f"{module}_{parameter}"] = value[None, None, ...]

-                 if len(total_shape) < len(input_parameters[module][parameter].shape):
+                 if len(total_shape) < len(input_parameters[f"{module}_{parameter}"].shape):
                      # If there are only chains and draws, we reduce
-                     input_parameters[module][parameter] = input_parameters[module][parameter][..., 0]
+                     input_parameters[f"{module}_{parameter}"] = input_parameters[
+                         f"{module}_{parameter}"
+                     ][..., 0]

                  else:
-                     input_parameters[module][parameter] = jnp.broadcast_to(
-                         input_parameters[module][parameter], total_shape
+                     input_parameters[f"{module}_{parameter}"] = jnp.broadcast_to(
+                         input_parameters[f"{module}_{parameter}"], total_shape
                      )

          return input_parameters
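
Note the `rsplit("_", 1)` introduced above: module names such as `powerlaw_1` contain underscores themselves, so only the last underscore separates the module from the parameter in the new flat keys. A quick illustration, using a parameter name taken from the removed docstrings:

```python
# Flat keys follow the "<module>_<parameter>" convention that replaces
# the old nested haiku-style dicts.
key = "powerlaw_1_alpha"
module, parameter = key.rsplit("_", 1)
assert (module, parameter) == ("powerlaw_1", "alpha")
```
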
@@ -346,94 +299,46 @@ class FitResult:

          return value

-     def to_chain(self, name: str, parameters_type: Literal["model", "bkg"] = "model") -> Chain:
+     def to_chain(self, name: str) -> Chain:
          """
          Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.

          Parameters:
              name: The name of the chain.
-             parameters_type: The parameters_type to include in the chain.
          """

-         obs_id = self.inference_data.copy()
-
-         if parameters_type == "model":
-             keys_to_drop = [
-                 key
-                 for key in obs_id.posterior.keys()
-                 if (key.startswith("_") or key.startswith("bkg"))
-             ]
-         elif parameters_type == "bkg":
-             keys_to_drop = [key for key in obs_id.posterior.keys() if not key.startswith("bkg")]
-         else:
-             raise ValueError(f"Unknown value for parameters_type: {parameters_type}")
-
-         obs_id.posterior = obs_id.posterior.drop_vars(keys_to_drop)
-         chain = Chain.from_arviz(obs_id, name)
-
-         """
-         chain.samples.columns = [
-             format_parameters(parameter) for parameter in chain.samples.columns
+         keys_to_drop = [
+             key
+             for key in self.inference_data.posterior.keys()
+             if (key.startswith("_") or key.startswith("bkg"))
          ]
-         """

-         return chain
+         reduced_id = az.extract(
+             self.inference_data,
+             var_names=[f"~{key}" for key in keys_to_drop] if keys_to_drop else None,
+             group="posterior",
+         )

-     @property
-     def samples_haiku(self) -> HaikuDict[ArrayLike]:
-         """
-         Haiku-like structure for the samples e.g.
-
-         ```
-         {
-             'powerlaw_1' :
-             {
-                 'alpha': ...,
-                 'amplitude': ...
-             },
-
-             'blackbody_1':
-             {
-                 'kT': ...,
-                 'norm': ...
-             },
-
-             'tbabs_1':
-             {
-                 'nH': ...
-             }
-         }
-         ```
+         df_list = []

-         """
+         for var, array in reduced_id.data_vars.items():
+             extra_dims = [dim for dim in array.dims if dim not in ["sample"]]

-         params = {}
+             if extra_dims:
+                 dim = extra_dims[
+                     0
+                 ]  # We only support the case where the extra dimension comes from the observations

-         for module, parameter, value in traverse(self._structure):
-             if params.get(module, None) is None:
-                 params[module] = {}
-             params[module][parameter] = self.samples_flat[f"{module}_{parameter}"]
+                 for coord, obs_id in zip(array.coords[dim], self.obsconfs.keys()):
+                     df = array.loc[{dim: coord}].to_pandas()
+                     df.name += f"\n[{obs_id}]"
+                     df_list.append(df)
+             else:
+                 df_list.append(array.to_pandas())

-         return params
+         df = pd.concat(df_list, axis=1)

-     @property
-     def samples_flat(self) -> dict[str, ArrayLike]:
-         """
-         Flat structure for the samples e.g.
-
-         ```
-         {
-             'powerlaw_1_alpha': ...,
-             'powerlaw_1_amplitude': ...,
-             'blackbody_1_kT': ...,
-             'blackbody_1_norm': ...,
-             'tbabs_1_nH': ...,
-         }
-         ```
-         """
-         var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
-         posterior = az.extract(self.inference_data, var_names=var_names)
-         return {key: posterior[key].data for key in var_names}
+         return Chain(samples=df, name=name)

      @property
      def log_likelihood(self) -> xr.Dataset:
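
The rewritten `to_chain` always drops private (`_`-prefixed) and background (`bkg`-prefixed) variables, flattens each remaining posterior variable into a pandas column (suffixing the observation id when a variable carries an extra per-observation dimension), and builds the `Chain` directly instead of going through `Chain.from_arviz`. A hedged usage sketch, assuming `result` is a `FitResult` from a completed fit:

```python
from chainconsumer import ChainConsumer

chain = result.to_chain("my fit")  # one pandas column per model parameter

consumer = ChainConsumer()
consumer.add_chain(chain)
fig = consumer.plotter.plot()      # corner plot of the model parameters
```
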
@@ -475,12 +380,20 @@ class FitResult:

      def plot_ppc(
          self,
-         percentile: tuple[int, int] = (16, 84),
+         n_sigmas: int = 1,
          x_unit: str | u.Unit = "keV",
          y_type: Literal[
              "counts", "countrate", "photon_flux", "photon_flux_density"
          ] = "photon_flux_density",
-     ) -> plt.Figure:
+         plot_background: bool = True,
+         plot_components: bool = False,
+         scale: Literal["linear", "semilogx", "semilogy", "loglog"] = "loglog",
+         alpha_envelope: (float, float) = (0.15, 0.25),
+         style: str | Any = "default",
+         title: str | None = None,
+         figsize: tuple[float, float] = (6, 6),
+         x_lims: tuple[float, float] | None = None,
+     ) -> list[plt.Figure]:
          r"""
          Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
          following formula:
@@ -489,15 +402,24 @@ class FitResult:
          {(\text{Posterior counts})_{84\%}-(\text{Posterior counts})_{16\%}} $$

          Parameters:
-             percentile: The percentile of the posterior predictive distribution to plot.
+             n_sigmas: The number of sigmas to plot the envelops.
              x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
              y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
+             plot_background: Whether to plot the background model if it is included in the fit.
+             plot_components: Whether to plot the components of the model separately.
+             scale: The axes scaling
+             alpha_envelope: The transparency range for envelops
+             style: The style of the plot. It can be either a string or a matplotlib style context.
+             title: The title of the plot.
+             figsize: The size of the figure.
+             x_lims: The limits of the x-axis.

          Returns:
-             The matplotlib figure.
+             A list of matplotlib figures for each observation in the model.
          """

          obsconf_container = self.obsconfs
+         figure_list = []
          x_unit = u.Unit(x_unit)

          match y_type:
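
The switch from `percentile=(16, 84)` to `n_sigmas=1` is essentially a reparametrization: for a Gaussian, the ±1σ interval covers the 16th to 84th percentiles. A quick check of the correspondence:

```python
from scipy.stats import norm

n_sigmas = 1
coverage = norm.cdf(n_sigmas) - norm.cdf(-n_sigmas)  # ~0.6827
percentiles = 100 * norm.cdf([-n_sigmas, n_sigmas])  # ~[15.87, 84.13]
```
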
@@ -514,60 +436,24 @@ class FitResult:
                      f"Unknown y_type: {y_type}. Must be 'counts', 'countrate', 'photon_flux' or 'photon_flux_density'"
                  )

-         color = (0.15, 0.25, 0.45)
-
-         with plt.style.context("default"):
-             # Note to Simon : do not change xbins[1] - xbins[0] to
-             # np.diff, you already did this twice and forgot that it does not work since diff keeps the dimensions
-             # and enable weird broadcasting that makes the plot fail
-
-             fig, axs = plt.subplots(
-                 2,
-                 len(obsconf_container),
-                 figsize=(6 * len(obsconf_container), 6),
-                 sharex=True,
-                 height_ratios=[0.7, 0.3],
-             )
-
-             plot_ylabels_once = True
+         with plt.style.context(style):
+             for obs_id, obsconf in obsconf_container.items():
+                 fig, ax = plt.subplots(
+                     2,
+                     1,
+                     figsize=figsize,
+                     sharex="col",
+                     height_ratios=[0.7, 0.3],
+                 )

-             for name, obsconf, ax in zip(
-                 obsconf_container.keys(),
-                 obsconf_container.values(),
-                 axs.T if len(obsconf_container) > 1 else [axs],
-             ):
                  legend_plots = []
                  legend_labels = []
+
                  count = az.extract(
-                     self.inference_data, var_names=f"obs_{name}", group="posterior_predictive"
+                     self.inference_data, var_names=f"obs_{obs_id}", group="posterior_predictive"
                  ).values.T
-                 bkg_count = (
-                     None
-                     if self.background_model is None
-                     else az.extract(
-                         self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive"
-                     ).values.T
-                 )

-                 xbins = obsconf.out_energies * u.keV
-                 xbins = xbins.to(x_unit, u.spectral())
-
-                 # This compute the total effective area within all bins
-                 # This is a bit weird since the following computation is equivalent to ignoring the RMF
-                 exposure = obsconf.exposure.data * u.s
-                 mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
-                 mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
-                 e_grid = np.linspace(*xbins, 10)
-                 interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
-                 integrated_arf = (
-                     trapezoid(interpolated_arf, x=e_grid, axis=0)
-                     / (
-                         np.abs(
-                             xbins[1] - xbins[0]
-                         )  # Must fold in abs because some units reverse the ordering of the bins
-                     )
-                     * u.cm**2
-                 )
+                 xbins, exposure, integrated_arf = _compute_effective_area(obsconf, x_unit)

                  match y_type:
                      case "counts":
@@ -580,133 +466,125 @@ class FitResult:
                          denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure

                  y_samples = (count * u.ct / denominator).to(y_units)
-                 y_observed = (obsconf.folded_counts.data * u.ct / denominator).to(y_units)
-                 y_observed_low = (
-                     nbinom.ppf(percentile[0] / 100, obsconf.folded_counts.data, 0.5)
-                     * u.ct
-                     / denominator
-                 ).to(y_units)
-                 y_observed_high = (
-                     nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
-                     * u.ct
-                     / denominator
-                 ).to(y_units)
+
+                 y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
+                     obsconf.folded_counts.data, denominator, y_units
+                 )

                  # Use the helper function to plot the data and posterior predictive
-                 legend_plots += _plot_binned_samples_with_error(
+                 model_plot = _plot_binned_samples_with_error(
                      ax[0],
                      xbins.value,
-                     y_samples=y_samples.value,
-                     denominator=np.ones_like(y_observed).value,
-                     color=color,
-                     percentile=percentile,
+                     y_samples.value,
+                     color=SPECTRUM_COLOR,
+                     n_sigmas=n_sigmas,
+                     alpha_envelope=alpha_envelope,
                  )

-                 legend_labels.append("Model")
-
-                 true_data_plot = ax[0].errorbar(
-                     np.sqrt(xbins.value[0] * xbins.value[1]),
+                 true_data_plot = _plot_poisson_data_with_error(
+                     ax[0],
+                     xbins.value,
                      y_observed.value,
-                     xerr=np.abs(xbins.value - np.sqrt(xbins.value[0] * xbins.value[1])),
-                     yerr=[
-                         y_observed.value - y_observed_low.value,
-                         y_observed_high.value - y_observed.value,
-                     ],
-                     color="black",
-                     linestyle="none",
-                     alpha=0.3,
-                     capsize=2,
+                     y_observed_low.value,
+                     y_observed_high.value,
+                     color=SPECTRUM_DATA_COLOR,
+                     alpha=0.7,
                  )

                  legend_plots.append((true_data_plot,))
                  legend_labels.append("Observed")
+                 legend_plots += model_plot
+                 legend_labels.append("Model")
+
+                 # Plot the residuals
+                 residual_samples = (obsconf.folded_counts.data - count) / np.diff(
+                     np.percentile(count, [16, 84], axis=0), axis=0
+                 )
+
+                 _plot_binned_samples_with_error(
+                     ax[1],
+                     xbins.value,
+                     residual_samples,
+                     color=SPECTRUM_COLOR,
+                     n_sigmas=n_sigmas,
+                     alpha_envelope=alpha_envelope,
+                 )
+
+                 if plot_components:
+                     for (component_name, count), color in zip(
+                         self._ppc_folded_branches(obs_id).items(), COLOR_CYCLE
+                     ):
+                         # _ppc_folded_branches returns (n_chains, n_draws, n_bins) shaped arrays so we must flatten it
+                         y_samples = (
+                             count.reshape((count.shape[0] * count.shape[1], -1))
+                             * u.ct
+                             / denominator
+                         ).to(y_units)
+                         component_plot = _plot_binned_samples_with_error(
+                             ax[0],
+                             xbins.value,
+                             y_samples.value,
+                             color=color,
+                             linestyle="dashdot",
+                             n_sigmas=n_sigmas,
+                             alpha_envelope=alpha_envelope,
+                         )

-                 if self.background_model is not None:
+                         name = component_name.split("*")[-1]
+
+                         legend_plots += component_plot
+                         legend_labels.append(name)
+
+                 if self.background_model is not None and plot_background:
                      # We plot the background only if it is included in the fit, i.e. by subtracting
-                     ratio = obsconf.folded_backratio.data
-                     y_samples_bkg = (bkg_count * u.ct / (denominator * ratio)).to(y_units)
-                     y_observed_bkg = (
-                         obsconf.folded_background.data * u.ct / (denominator * ratio)
-                     ).to(y_units)
-                     y_observed_bkg_low = (
-                         nbinom.ppf(percentile[0] / 100, obsconf.folded_background.data, 0.5)
-                         * u.ct
-                         / (denominator * ratio)
-                     ).to(y_units)
-                     y_observed_bkg_high = (
-                         nbinom.ppf(percentile[1] / 100, obsconf.folded_background.data, 0.5)
-                         * u.ct
-                         / (denominator * ratio)
-                     ).to(y_units)
-
-                     legend_plots += _plot_binned_samples_with_error(
+                     bkg_count = (
+                         None
+                         if self.background_model is None
+                         else az.extract(
+                             self.inference_data,
+                             var_names=f"bkg_{obs_id}",
+                             group="posterior_predictive",
+                         ).values.T
+                     )
+
+                     y_samples_bkg = (bkg_count * u.ct / denominator).to(y_units)
+
+                     y_observed_bkg, y_observed_bkg_low, y_observed_bkg_high = (
+                         _error_bars_for_observed_data(
+                             obsconf.folded_background.data, denominator, y_units
+                         )
+                     )
+
+                     model_bkg_plot = _plot_binned_samples_with_error(
                          ax[0],
                          xbins.value,
-                         y_samples=y_samples_bkg.value,
-                         denominator=np.ones_like(y_observed).value,
-                         color=(0.26787604, 0.60085972, 0.63302651),
-                         percentile=percentile,
+                         y_samples_bkg.value,
+                         color=BACKGROUND_COLOR,
+                         alpha_envelope=alpha_envelope,
+                         n_sigmas=n_sigmas,
                      )

-                     legend_labels.append("Model (bkg)")
-
-                     true_bkg_plot = ax[0].errorbar(
-                         np.sqrt(xbins.value[0] * xbins.value[1]),
+                     true_bkg_plot = _plot_poisson_data_with_error(
+                         ax[0],
+                         xbins.value,
                          y_observed_bkg.value,
-                         xerr=np.abs(xbins.value - np.sqrt(xbins.value[0] * xbins.value[1])),
-                         yerr=[
-                             y_observed_bkg.value - y_observed_bkg_low.value,
-                             y_observed_bkg_high.value - y_observed_bkg.value,
-                         ],
-                         color="black",
-                         linestyle="none",
-                         alpha=0.3,
-                         capsize=2,
+                         y_observed_bkg_low.value,
+                         y_observed_bkg_high.value,
+                         color=BACKGROUND_DATA_COLOR,
+                         alpha=0.7,
                      )

                      legend_plots.append((true_bkg_plot,))
                      legend_labels.append("Observed (bkg)")
+                     legend_plots += model_bkg_plot
+                     legend_labels.append("Model (bkg)")

-                 residual_samples = (obsconf.folded_counts.data - count) / np.diff(
-                     np.percentile(count, percentile, axis=0), axis=0
-                 )
-
-                 residuals = np.percentile(
-                     residual_samples,
-                     percentile,
-                     axis=0,
-                 )
-
-                 median_residuals = np.median(
-                     residual_samples,
-                     axis=0,
-                 )
-
-                 ax[1].stairs(
-                     residuals[1],
-                     edges=[*list(xbins.value[0]), xbins.value[1][-1]],
-                     baseline=list(residuals[0]),
-                     alpha=0.3,
-                     facecolor=color,
-                     fill=True,
-                 )
-
-                 ax[1].stairs(
-                     median_residuals,
-                     edges=[*list(xbins.value[0]), xbins.value[1][-1]],
-                     color=color,
-                     alpha=0.7,
-                 )
-
-                 max_residuals = np.max(np.abs(residuals))
+                 max_residuals = np.max(np.abs(residual_samples))

                  ax[0].loglog()
                  ax[1].set_ylim(-max(3.5, max_residuals), +max(3.5, max_residuals))
-
-                 if plot_ylabels_once:
-                     ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
-                     ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
-                     plot_ylabels_once = False
+                 ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
+                 ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")

                  match getattr(x_unit, "physical_type"):
                      case "length":
721
599
  f"Must be 'length', 'energy' or 'frequency'"
722
600
  )
723
601
 
724
- ax[1].axhline(0, color=color, ls="--")
725
- ax[1].axhline(-3, color=color, ls=":")
726
- ax[1].axhline(3, color=color, ls=":")
602
+ ax[1].axhline(0, color=SPECTRUM_DATA_COLOR, ls="--")
603
+ ax[1].axhline(-3, color=SPECTRUM_DATA_COLOR, ls=":")
604
+ ax[1].axhline(3, color=SPECTRUM_DATA_COLOR, ls=":")
727
605
 
728
- # ax[1].set_xticks(xticks, labels=xticks)
729
- # ax[1].xaxis.set_minor_formatter(ticker.LogFormatter(minor_thresholds=(np.inf, np.inf)))
730
606
  ax[1].set_yticks([-3, 0, 3], labels=[-3, 0, 3])
731
607
  ax[1].set_yticks(range(-3, 4), minor=True)
732
608
 
733
609
  ax[0].set_xlim(xbins.value.min(), xbins.value.max())
734
610
 
735
611
  ax[0].legend(legend_plots, legend_labels)
736
- fig.suptitle(self.model.to_string())
612
+
613
+ match scale:
614
+ case "linear":
615
+ ax[0].set_xscale("linear")
616
+ ax[0].set_yscale("linear")
617
+ case "semilogx":
618
+ ax[0].set_xscale("log")
619
+ ax[0].set_yscale("linear")
620
+ case "semilogy":
621
+ ax[0].set_xscale("linear")
622
+ ax[0].set_yscale("log")
623
+ case "loglog":
624
+ ax[0].set_xscale("log")
625
+ ax[0].set_yscale("log")
626
+
627
+ if x_lims is not None:
628
+ ax[0].set_xlim(*x_lims)
629
+
737
630
  fig.align_ylabels()
738
631
  plt.subplots_adjust(hspace=0.0)
739
632
  fig.tight_layout()
633
+ figure_list.append(fig)
634
+ fig.suptitle(f"Posterior predictive - {obs_id}" if title is None else title)
635
+ # fig.show()
636
+
637
+ plt.tight_layout()
638
+ plt.show()
740
639
 
741
- return fig
640
+ return figure_list
742
641
 
743
642
  def table(self) -> str:
744
643
  r"""
@@ -765,7 +664,7 @@ class FitResult:
          """

          consumer = ChainConsumer()
-         consumer.add_chain(self.to_chain(self.model.to_string()))
+         consumer.add_chain(self.to_chain("Results"))
          consumer.set_plot_config(config)

          # Context for default mpl style