jaxspec 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects each version exactly as it appears in its public registry.
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections.abc import Mapping
3
+ from functools import cached_property
4
4
  from typing import TYPE_CHECKING, Any, Literal, TypeVar
5
5
 
6
6
  import arviz as az
@@ -10,17 +10,28 @@ import jax
10
10
  import jax.numpy as jnp
11
11
  import matplotlib.pyplot as plt
12
12
  import numpy as np
13
+ import pandas as pd
13
14
  import xarray as xr
14
15
 
15
16
  from astropy.cosmology import Cosmology, Planck18
16
17
  from astropy.units import Unit
17
18
  from chainconsumer import Chain, ChainConsumer, PlotConfig
18
- from haiku.data_structures import traverse
19
+ from jax.experimental.sparse import BCOO
19
20
  from jax.typing import ArrayLike
20
21
  from numpyro.handlers import seed
21
- from scipy.integrate import trapezoid
22
22
  from scipy.special import gammaln
23
- from scipy.stats import nbinom
23
+
24
+ from ._plot import (
25
+ BACKGROUND_COLOR,
26
+ BACKGROUND_DATA_COLOR,
27
+ COLOR_CYCLE,
28
+ SPECTRUM_COLOR,
29
+ SPECTRUM_DATA_COLOR,
30
+ _compute_effective_area,
31
+ _error_bars_for_observed_data,
32
+ _plot_binned_samples_with_error,
33
+ _plot_poisson_data_with_error,
34
+ )
24
35
 
25
36
  if TYPE_CHECKING:
26
37
  from ..fit import BayesianModel
@@ -31,67 +42,6 @@ V = TypeVar("V")
31
42
  T = TypeVar("T")
32
43
 
33
44
 
34
- class HaikuDict(dict[str, dict[str, T]]): ...
35
-
36
-
37
- def _plot_binned_samples_with_error(
38
- ax: plt.Axes,
39
- x_bins: ArrayLike,
40
- denominator: ArrayLike | None = None,
41
- y_samples: ArrayLike | None = None,
42
- color=(0.15, 0.25, 0.45),
43
- percentile: tuple = (16, 84),
44
- ):
45
- """
46
- Helper function to plot the posterior predictive distribution of the model. The function
47
- computes the percentiles of the posterior predictive distribution and plot them as a shaded
48
- area. If the observed data is provided, it is also plotted as a step function.
49
-
50
- Parameters
51
- ----------
52
- x_bins: The bin edges of the data (2 x N).
53
- y_samples: The samples of the posterior predictive distribution (Samples X N).
54
- denominator: Values used to divided the samples, i.e. to get energy flux (N).
55
- ax: The matplotlib axes object.
56
- color: The color of the posterior predictive distribution.
57
- y_observed: The observed data (N).
58
- label: The label of the observed data.
59
- percentile: The percentile of the posterior predictive distribution to plot.
60
- """
61
-
62
- mean, envelope = None, None
63
-
64
- if denominator is None:
65
- denominator = np.ones_like(x_bins[0])
66
-
67
- mean = ax.stairs(
68
- list(np.median(y_samples, axis=0) / denominator),
69
- edges=[*list(x_bins[0]), x_bins[1][-1]],
70
- color=color,
71
- alpha=0.7,
72
- )
73
-
74
- if y_samples is not None:
75
- if denominator is None:
76
- denominator = np.ones_like(x_bins[0])
77
-
78
- percentiles = np.percentile(y_samples, percentile, axis=0)
79
-
80
- # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
81
- (envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
82
-
83
- ax.stairs(
84
- percentiles[1] / denominator,
85
- edges=[*list(x_bins[0]), x_bins[1][-1]],
86
- baseline=percentiles[0] / denominator,
87
- alpha=0.3,
88
- fill=True,
89
- color=color,
90
- )
91
-
92
- return [(mean, envelope)]
93
-
94
-
95
45
  class FitResult:
96
46
  """
97
47
  Container for the result of a fit using any ModelFitter class.
@@ -102,7 +52,6 @@ class FitResult:
102
52
  self,
103
53
  bayesian_fitter: BayesianModel,
104
54
  inference_data: az.InferenceData,
105
- structure: Mapping[K, V],
106
55
  background_model: BackgroundModel = None,
107
56
  ):
108
57
  self.model = bayesian_fitter.model
@@ -110,7 +59,6 @@ class FitResult:
110
59
  self.inference_data = inference_data
111
60
  self.obsconfs = bayesian_fitter.observation_container
112
61
  self.background_model = background_model
113
- self._structure = structure
114
62
 
115
63
  # Add the model used in fit to the metadata
116
64
  for group in self.inference_data.groups():
@@ -127,37 +75,38 @@ class FitResult:
127
75
 
128
76
  return all(az.rhat(self.inference_data) < 1.01)
129
77
 
130
- @property
131
- def _structured_samples(self):
132
- """
133
- Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
134
- """
135
-
136
- samples_flat = self._structured_samples_flat
78
+ def _ppc_folded_branches(self, obs_id):
79
+ obs = self.obsconfs[obs_id]
137
80
 
138
- samples_haiku = {}
81
+ if len(next(iter(self.input_parameters.values())).shape) > 2:
82
+ idx = list(self.obsconfs.keys()).index(obs_id)
83
+ obs_parameters = jax.tree.map(lambda x: x[..., idx], self.input_parameters)
139
84
 
140
- for module, parameter, value in traverse(self._structure):
141
- if samples_haiku.get(module, None) is None:
142
- samples_haiku[module] = {}
143
- samples_haiku[module][parameter] = samples_flat[f"{module}_{parameter}"]
85
+ else:
86
+ obs_parameters = self.input_parameters
144
87
 
145
- return samples_haiku
88
+ if self.bayesian_fitter.sparse:
89
+ transfer_matrix = BCOO.from_scipy_sparse(
90
+ obs.transfer_matrix.data.to_scipy_sparse().tocsr()
91
+ )
146
92
 
147
- @property
148
- def _structured_samples_flat(self):
149
- """
150
- Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
151
- """
93
+ else:
94
+ transfer_matrix = np.asarray(obs.transfer_matrix.data.todense())
152
95
 
153
- var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
154
- posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
155
- samples_flat = {key: posterior[key].data for key in var_names}
96
+ energies = obs.in_energies
156
97
 
157
- return samples_flat
98
+ flux_func = jax.jit(
99
+ jax.vmap(jax.vmap(lambda p: self.model.photon_flux(p, *energies, split_branches=True)))
100
+ )
101
+ convolve_func = jax.jit(
102
+ jax.vmap(jax.vmap(lambda flux: jnp.clip(transfer_matrix @ flux, a_min=1e-6)))
103
+ )
104
+ return jax.tree.map(
105
+ lambda flux: np.random.poisson(convolve_func(flux)), flux_func(obs_parameters)
106
+ )
158
107
 
159
- @property
160
- def input_parameters(self) -> HaikuDict[ArrayLike]:
108
+ @cached_property
109
+ def input_parameters(self) -> dict[str, ArrayLike]:
161
110
  """
162
111
  The input parameters of the model.
163
112
  """
@@ -173,7 +122,9 @@ class FitResult:
173
122
  with seed(rng_seed=0):
174
123
  input_parameters = self.bayesian_fitter.prior_distributions_func()
175
124
 
176
- for module, parameter, value in traverse(input_parameters):
125
+ for key, value in input_parameters.items():
126
+ module, parameter = key.rsplit("_", 1)
127
+
177
128
  if f"{module}_{parameter}" in posterior.keys():
178
129
  # We add as extra dimension as there might be different values per observation
179
130
  if posterior[f"{module}_{parameter}"].shape == samples_shape:
@@ -181,19 +132,21 @@ class FitResult:
181
132
  else:
182
133
  to_set = posterior[f"{module}_{parameter}"]
183
134
 
184
- input_parameters[module][parameter] = to_set
135
+ input_parameters[f"{module}_{parameter}"] = to_set
185
136
 
186
137
  else:
187
138
  # The parameter is fixed in this case, so we just broadcast is over chain and draws
188
- input_parameters[module][parameter] = value[None, None, ...]
139
+ input_parameters[f"{module}_{parameter}"] = value[None, None, ...]
189
140
 
190
- if len(total_shape) < len(input_parameters[module][parameter].shape):
141
+ if len(total_shape) < len(input_parameters[f"{module}_{parameter}"].shape):
191
142
  # If there are only chains and draws, we reduce
192
- input_parameters[module][parameter] = input_parameters[module][parameter][..., 0]
143
+ input_parameters[f"{module}_{parameter}"] = input_parameters[
144
+ f"{module}_{parameter}"
145
+ ][..., 0]
193
146
 
194
147
  else:
195
- input_parameters[module][parameter] = jnp.broadcast_to(
196
- input_parameters[module][parameter], total_shape
148
+ input_parameters[f"{module}_{parameter}"] = jnp.broadcast_to(
149
+ input_parameters[f"{module}_{parameter}"], total_shape
197
150
  )
198
151
 
199
152
  return input_parameters
@@ -346,94 +299,46 @@ class FitResult:
346
299
 
347
300
  return value
348
301
 
349
- def to_chain(self, name: str, parameters_type: Literal["model", "bkg"] = "model") -> Chain:
302
+ def to_chain(self, name: str) -> Chain:
350
303
  """
351
304
  Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.
352
305
 
353
306
  Parameters:
354
307
  name: The name of the chain.
355
- parameters_type: The parameters_type to include in the chain.
356
308
  """
357
309
 
358
- obs_id = self.inference_data.copy()
359
-
360
- if parameters_type == "model":
361
- keys_to_drop = [
362
- key
363
- for key in obs_id.posterior.keys()
364
- if (key.startswith("_") or key.startswith("bkg"))
365
- ]
366
- elif parameters_type == "bkg":
367
- keys_to_drop = [key for key in obs_id.posterior.keys() if not key.startswith("bkg")]
368
- else:
369
- raise ValueError(f"Unknown value for parameters_type: {parameters_type}")
370
-
371
- obs_id.posterior = obs_id.posterior.drop_vars(keys_to_drop)
372
- chain = Chain.from_arviz(obs_id, name)
373
-
374
- """
375
- chain.samples.columns = [
376
- format_parameters(parameter) for parameter in chain.samples.columns
310
+ keys_to_drop = [
311
+ key
312
+ for key in self.inference_data.posterior.keys()
313
+ if (key.startswith("_") or key.startswith("bkg"))
377
314
  ]
378
- """
379
315
 
380
- return chain
316
+ reduced_id = az.extract(
317
+ self.inference_data,
318
+ var_names=[f"~{key}" for key in keys_to_drop] if keys_to_drop else None,
319
+ group="posterior",
320
+ )
381
321
 
382
- @property
383
- def samples_haiku(self) -> HaikuDict[ArrayLike]:
384
- """
385
- Haiku-like structure for the samples e.g.
386
-
387
- ```
388
- {
389
- 'powerlaw_1' :
390
- {
391
- 'alpha': ...,
392
- 'amplitude': ...
393
- },
394
-
395
- 'blackbody_1':
396
- {
397
- 'kT': ...,
398
- 'norm': ...
399
- },
400
-
401
- 'tbabs_1':
402
- {
403
- 'nH': ...
404
- }
405
- }
406
- ```
322
+ df_list = []
407
323
 
408
- """
324
+ for var, array in reduced_id.data_vars.items():
325
+ extra_dims = [dim for dim in array.dims if dim not in ["sample"]]
409
326
 
410
- params = {}
327
+ if extra_dims:
328
+ dim = extra_dims[
329
+ 0
330
+ ] # We only support the case where the extra dimension comes from the observations
411
331
 
412
- for module, parameter, value in traverse(self._structure):
413
- if params.get(module, None) is None:
414
- params[module] = {}
415
- params[module][parameter] = self.samples_flat[f"{module}_{parameter}"]
332
+ for coord, obs_id in zip(array.coords[dim], self.obsconfs.keys()):
333
+ df = array.loc[{dim: coord}].to_pandas()
334
+ df.name += f"\n[{obs_id}]"
335
+ df_list.append(df)
336
+ else:
337
+ df_list.append(array.to_pandas())
416
338
 
417
- return params
339
+ df = pd.concat(df_list, axis=1)
418
340
 
419
- @property
420
- def samples_flat(self) -> dict[str, ArrayLike]:
421
- """
422
- Flat structure for the samples e.g.
423
-
424
- ```
425
- {
426
- 'powerlaw_1_alpha': ...,
427
- 'powerlaw_1_amplitude': ...,
428
- 'blackbody_1_kT': ...,
429
- 'blackbody_1_norm': ...,
430
- 'tbabs_1_nH': ...,
431
- }
432
- ```
433
- """
434
- var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
435
- posterior = az.extract(self.inference_data, var_names=var_names)
436
- return {key: posterior[key].data for key in var_names}
341
+ return Chain(samples=df, name=name)
437
342
 
438
343
  @property
439
344
  def log_likelihood(self) -> xr.Dataset:
@@ -475,12 +380,18 @@ class FitResult:
475
380
 
476
381
  def plot_ppc(
477
382
  self,
478
- percentile: tuple[int, int] = (16, 84),
383
+ n_sigmas: int = 1,
479
384
  x_unit: str | u.Unit = "keV",
480
385
  y_type: Literal[
481
386
  "counts", "countrate", "photon_flux", "photon_flux_density"
482
387
  ] = "photon_flux_density",
483
- ) -> plt.Figure:
388
+ plot_background: bool = True,
389
+ plot_components: bool = False,
390
+ scale: Literal["linear", "semilogx", "semilogy", "loglog"] = "loglog",
391
+ alpha_envelope: (float, float) = (0.15, 0.25),
392
+ style: str | Any = "default",
393
+ title: str | None = None,
394
+ ) -> list[plt.Figure]:
484
395
  r"""
485
396
  Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
486
397
  following formula:
@@ -492,12 +403,18 @@ class FitResult:
492
403
  percentile: The percentile of the posterior predictive distribution to plot.
493
404
  x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
494
405
  y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
406
+ plot_background: Whether to plot the background model if it is included in the fit.
407
+ plot_components: Whether to plot the components of the model separately.
408
+ scale: The axes scaling
409
+ alpha_envelope: The transparency range for envelops
410
+ style: The style of the plot. It can be either a string or a matplotlib style context.
495
411
 
496
412
  Returns:
497
- The matplotlib figure.
413
+ A list of matplotlib figures for each observation in the model.
498
414
  """
499
415
 
500
416
  obsconf_container = self.obsconfs
417
+ figure_list = []
501
418
  x_unit = u.Unit(x_unit)
502
419
 
503
420
  match y_type:
@@ -514,60 +431,24 @@ class FitResult:
514
431
  f"Unknown y_type: {y_type}. Must be 'counts', 'countrate', 'photon_flux' or 'photon_flux_density'"
515
432
  )
516
433
 
517
- color = (0.15, 0.25, 0.45)
518
-
519
- with plt.style.context("default"):
520
- # Note to Simon : do not change xbins[1] - xbins[0] to
521
- # np.diff, you already did this twice and forgot that it does not work since diff keeps the dimensions
522
- # and enable weird broadcasting that makes the plot fail
523
-
524
- fig, axs = plt.subplots(
525
- 2,
526
- len(obsconf_container),
527
- figsize=(6 * len(obsconf_container), 6),
528
- sharex=True,
529
- height_ratios=[0.7, 0.3],
530
- )
531
-
532
- plot_ylabels_once = True
434
+ with plt.style.context(style):
435
+ for obs_id, obsconf in obsconf_container.items():
436
+ fig, ax = plt.subplots(
437
+ 2,
438
+ 1,
439
+ figsize=(6, 6),
440
+ sharex="col",
441
+ height_ratios=[0.7, 0.3],
442
+ )
533
443
 
534
- for name, obsconf, ax in zip(
535
- obsconf_container.keys(),
536
- obsconf_container.values(),
537
- axs.T if len(obsconf_container) > 1 else [axs],
538
- ):
539
444
  legend_plots = []
540
445
  legend_labels = []
446
+
541
447
  count = az.extract(
542
- self.inference_data, var_names=f"obs_{name}", group="posterior_predictive"
448
+ self.inference_data, var_names=f"obs_{obs_id}", group="posterior_predictive"
543
449
  ).values.T
544
- bkg_count = (
545
- None
546
- if self.background_model is None
547
- else az.extract(
548
- self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive"
549
- ).values.T
550
- )
551
450
 
552
- xbins = obsconf.out_energies * u.keV
553
- xbins = xbins.to(x_unit, u.spectral())
554
-
555
- # This compute the total effective area within all bins
556
- # This is a bit weird since the following computation is equivalent to ignoring the RMF
557
- exposure = obsconf.exposure.data * u.s
558
- mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
559
- mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
560
- e_grid = np.linspace(*xbins, 10)
561
- interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
562
- integrated_arf = (
563
- trapezoid(interpolated_arf, x=e_grid, axis=0)
564
- / (
565
- np.abs(
566
- xbins[1] - xbins[0]
567
- ) # Must fold in abs because some units reverse the ordering of the bins
568
- )
569
- * u.cm**2
570
- )
451
+ xbins, exposure, integrated_arf = _compute_effective_area(obsconf, x_unit)
571
452
 
572
453
  match y_type:
573
454
  case "counts":
@@ -580,133 +461,123 @@ class FitResult:
580
461
  denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure
581
462
 
582
463
  y_samples = (count * u.ct / denominator).to(y_units)
583
- y_observed = (obsconf.folded_counts.data * u.ct / denominator).to(y_units)
584
- y_observed_low = (
585
- nbinom.ppf(percentile[0] / 100, obsconf.folded_counts.data, 0.5)
586
- * u.ct
587
- / denominator
588
- ).to(y_units)
589
- y_observed_high = (
590
- nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
591
- * u.ct
592
- / denominator
593
- ).to(y_units)
464
+
465
+ y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
466
+ obsconf.folded_counts.data, denominator, y_units
467
+ )
594
468
 
595
469
  # Use the helper function to plot the data and posterior predictive
596
- legend_plots += _plot_binned_samples_with_error(
470
+ model_plot = _plot_binned_samples_with_error(
597
471
  ax[0],
598
472
  xbins.value,
599
- y_samples=y_samples.value,
600
- denominator=np.ones_like(y_observed).value,
601
- color=color,
602
- percentile=percentile,
473
+ y_samples.value,
474
+ color=SPECTRUM_COLOR,
475
+ n_sigmas=n_sigmas,
476
+ alpha_envelope=alpha_envelope,
603
477
  )
604
478
 
605
- legend_labels.append("Model")
606
-
607
- true_data_plot = ax[0].errorbar(
608
- np.sqrt(xbins.value[0] * xbins.value[1]),
479
+ true_data_plot = _plot_poisson_data_with_error(
480
+ ax[0],
481
+ xbins.value,
609
482
  y_observed.value,
610
- xerr=np.abs(xbins.value - np.sqrt(xbins.value[0] * xbins.value[1])),
611
- yerr=[
612
- y_observed.value - y_observed_low.value,
613
- y_observed_high.value - y_observed.value,
614
- ],
615
- color="black",
616
- linestyle="none",
617
- alpha=0.3,
618
- capsize=2,
483
+ y_observed_low.value,
484
+ y_observed_high.value,
485
+ color=SPECTRUM_DATA_COLOR,
486
+ alpha=0.7,
619
487
  )
620
488
 
621
489
  legend_plots.append((true_data_plot,))
622
490
  legend_labels.append("Observed")
491
+ legend_plots += model_plot
492
+ legend_labels.append("Model")
493
+
494
+ # Plot the residuals
495
+ residual_samples = (obsconf.folded_counts.data - count) / np.diff(
496
+ np.percentile(count, [16, 84], axis=0), axis=0
497
+ )
623
498
 
624
- if self.background_model is not None:
499
+ _plot_binned_samples_with_error(
500
+ ax[1],
501
+ xbins.value,
502
+ residual_samples,
503
+ color=SPECTRUM_COLOR,
504
+ n_sigmas=n_sigmas,
505
+ alpha_envelope=alpha_envelope,
506
+ )
507
+
508
+ if plot_components:
509
+ for (component_name, count), color in zip(
510
+ self._ppc_folded_branches(obs_id).items(), COLOR_CYCLE
511
+ ):
512
+ # _ppc_folded_branches returns (n_chains, n_draws, n_bins) shaped arrays so we must flatten it
513
+ y_samples = (
514
+ count.reshape((count.shape[0] * count.shape[1], -1))
515
+ * u.ct
516
+ / denominator
517
+ ).to(y_units)
518
+ component_plot = _plot_binned_samples_with_error(
519
+ ax[0],
520
+ xbins.value,
521
+ y_samples.value,
522
+ color=color,
523
+ linestyle="dashdot",
524
+ n_sigmas=n_sigmas,
525
+ alpha_envelope=alpha_envelope,
526
+ )
527
+
528
+ legend_plots += component_plot
529
+ legend_labels.append(component_name)
530
+
531
+ if self.background_model is not None and plot_background:
625
532
  # We plot the background only if it is included in the fit, i.e. by subtracting
626
- ratio = obsconf.folded_backratio.data
627
- y_samples_bkg = (bkg_count * u.ct / (denominator * ratio)).to(y_units)
628
- y_observed_bkg = (
629
- obsconf.folded_background.data * u.ct / (denominator * ratio)
630
- ).to(y_units)
631
- y_observed_bkg_low = (
632
- nbinom.ppf(percentile[0] / 100, obsconf.folded_background.data, 0.5)
633
- * u.ct
634
- / (denominator * ratio)
635
- ).to(y_units)
636
- y_observed_bkg_high = (
637
- nbinom.ppf(percentile[1] / 100, obsconf.folded_background.data, 0.5)
638
- * u.ct
639
- / (denominator * ratio)
640
- ).to(y_units)
641
-
642
- legend_plots += _plot_binned_samples_with_error(
533
+ bkg_count = (
534
+ None
535
+ if self.background_model is None
536
+ else az.extract(
537
+ self.inference_data,
538
+ var_names=f"bkg_{obs_id}",
539
+ group="posterior_predictive",
540
+ ).values.T
541
+ )
542
+
543
+ y_samples_bkg = (bkg_count * u.ct / denominator).to(y_units)
544
+
545
+ y_observed_bkg, y_observed_bkg_low, y_observed_bkg_high = (
546
+ _error_bars_for_observed_data(
547
+ obsconf.folded_background.data, denominator, y_units
548
+ )
549
+ )
550
+
551
+ model_bkg_plot = _plot_binned_samples_with_error(
643
552
  ax[0],
644
553
  xbins.value,
645
- y_samples=y_samples_bkg.value,
646
- denominator=np.ones_like(y_observed).value,
647
- color=(0.26787604, 0.60085972, 0.63302651),
648
- percentile=percentile,
554
+ y_samples_bkg.value,
555
+ color=BACKGROUND_COLOR,
556
+ alpha_envelope=alpha_envelope,
557
+ n_sigmas=n_sigmas,
649
558
  )
650
559
 
651
- legend_labels.append("Model (bkg)")
652
-
653
- true_bkg_plot = ax[0].errorbar(
654
- np.sqrt(xbins.value[0] * xbins.value[1]),
560
+ true_bkg_plot = _plot_poisson_data_with_error(
561
+ ax[0],
562
+ xbins.value,
655
563
  y_observed_bkg.value,
656
- xerr=np.abs(xbins.value - np.sqrt(xbins.value[0] * xbins.value[1])),
657
- yerr=[
658
- y_observed_bkg.value - y_observed_bkg_low.value,
659
- y_observed_bkg_high.value - y_observed_bkg.value,
660
- ],
661
- color="black",
662
- linestyle="none",
663
- alpha=0.3,
664
- capsize=2,
564
+ y_observed_bkg_low.value,
565
+ y_observed_bkg_high.value,
566
+ color=BACKGROUND_DATA_COLOR,
567
+ alpha=0.7,
665
568
  )
666
569
 
667
570
  legend_plots.append((true_bkg_plot,))
668
571
  legend_labels.append("Observed (bkg)")
572
+ legend_plots += model_bkg_plot
573
+ legend_labels.append("Model (bkg)")
669
574
 
670
- residual_samples = (obsconf.folded_counts.data - count) / np.diff(
671
- np.percentile(count, percentile, axis=0), axis=0
672
- )
673
-
674
- residuals = np.percentile(
675
- residual_samples,
676
- percentile,
677
- axis=0,
678
- )
679
-
680
- median_residuals = np.median(
681
- residual_samples,
682
- axis=0,
683
- )
684
-
685
- ax[1].stairs(
686
- residuals[1],
687
- edges=[*list(xbins.value[0]), xbins.value[1][-1]],
688
- baseline=list(residuals[0]),
689
- alpha=0.3,
690
- facecolor=color,
691
- fill=True,
692
- )
693
-
694
- ax[1].stairs(
695
- median_residuals,
696
- edges=[*list(xbins.value[0]), xbins.value[1][-1]],
697
- color=color,
698
- alpha=0.7,
699
- )
700
-
701
- max_residuals = np.max(np.abs(residuals))
575
+ max_residuals = np.max(np.abs(residual_samples))
702
576
 
703
577
  ax[0].loglog()
704
578
  ax[1].set_ylim(-max(3.5, max_residuals), +max(3.5, max_residuals))
705
-
706
- if plot_ylabels_once:
707
- ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
708
- ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
709
- plot_ylabels_once = False
579
+ ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
580
+ ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
710
581
 
711
582
  match getattr(x_unit, "physical_type"):
712
583
  case "length":
@@ -721,24 +592,42 @@ class FitResult:
721
592
  f"Must be 'length', 'energy' or 'frequency'"
722
593
  )
723
594
 
724
- ax[1].axhline(0, color=color, ls="--")
725
- ax[1].axhline(-3, color=color, ls=":")
726
- ax[1].axhline(3, color=color, ls=":")
595
+ ax[1].axhline(0, color=SPECTRUM_DATA_COLOR, ls="--")
596
+ ax[1].axhline(-3, color=SPECTRUM_DATA_COLOR, ls=":")
597
+ ax[1].axhline(3, color=SPECTRUM_DATA_COLOR, ls=":")
727
598
 
728
- # ax[1].set_xticks(xticks, labels=xticks)
729
- # ax[1].xaxis.set_minor_formatter(ticker.LogFormatter(minor_thresholds=(np.inf, np.inf)))
730
599
  ax[1].set_yticks([-3, 0, 3], labels=[-3, 0, 3])
731
600
  ax[1].set_yticks(range(-3, 4), minor=True)
732
601
 
733
602
  ax[0].set_xlim(xbins.value.min(), xbins.value.max())
734
603
 
735
604
  ax[0].legend(legend_plots, legend_labels)
736
- fig.suptitle(self.model.to_string())
605
+
606
+ match scale:
607
+ case "linear":
608
+ ax[0].set_xscale("linear")
609
+ ax[0].set_yscale("linear")
610
+ case "semilogx":
611
+ ax[0].set_xscale("log")
612
+ ax[0].set_yscale("linear")
613
+ case "semilogy":
614
+ ax[0].set_xscale("linear")
615
+ ax[0].set_yscale("log")
616
+ case "loglog":
617
+ ax[0].set_xscale("log")
618
+ ax[0].set_yscale("log")
619
+
737
620
  fig.align_ylabels()
738
621
  plt.subplots_adjust(hspace=0.0)
739
622
  fig.tight_layout()
623
+ figure_list.append(fig)
624
+ fig.suptitle(f"Posterior predictive - {obs_id}" if title is None else title)
625
+ # fig.show()
626
+
627
+ plt.tight_layout()
628
+ plt.show()
740
629
 
741
- return fig
630
+ return figure_list
742
631
 
743
632
  def table(self) -> str:
744
633
  r"""