jaxspec 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,25 +1,37 @@
  from __future__ import annotations

- from collections.abc import Mapping
+ from functools import cached_property
  from typing import TYPE_CHECKING, Any, Literal, TypeVar

  import arviz as az
+ import astropy.cosmology.units as cu
  import astropy.units as u
  import jax
  import jax.numpy as jnp
  import matplotlib.pyplot as plt
  import numpy as np
+ import pandas as pd
  import xarray as xr

  from astropy.cosmology import Cosmology, Planck18
  from astropy.units import Unit
  from chainconsumer import Chain, ChainConsumer, PlotConfig
- from haiku.data_structures import traverse
+ from jax.experimental.sparse import BCOO
  from jax.typing import ArrayLike
  from numpyro.handlers import seed
- from scipy.integrate import trapezoid
  from scipy.special import gammaln
- from scipy.stats import nbinom
+
+ from ._plot import (
+     BACKGROUND_COLOR,
+     BACKGROUND_DATA_COLOR,
+     COLOR_CYCLE,
+     SPECTRUM_COLOR,
+     SPECTRUM_DATA_COLOR,
+     _compute_effective_area,
+     _error_bars_for_observed_data,
+     _plot_binned_samples_with_error,
+     _plot_poisson_data_with_error,
+ )

  if TYPE_CHECKING:
      from ..fit import BayesianModel
@@ -30,67 +42,6 @@ V = TypeVar("V")
  T = TypeVar("T")


- class HaikuDict(dict[str, dict[str, T]]): ...
-
-
- def _plot_binned_samples_with_error(
-     ax: plt.Axes,
-     x_bins: ArrayLike,
-     denominator: ArrayLike | None = None,
-     y_samples: ArrayLike | None = None,
-     color=(0.15, 0.25, 0.45),
-     percentile: tuple = (16, 84),
- ):
-     """
-     Helper function to plot the posterior predictive distribution of the model. The function
-     computes the percentiles of the posterior predictive distribution and plot them as a shaded
-     area. If the observed data is provided, it is also plotted as a step function.
-
-     Parameters
-     ----------
-     x_bins: The bin edges of the data (2 x N).
-     y_samples: The samples of the posterior predictive distribution (Samples X N).
-     denominator: Values used to divided the samples, i.e. to get energy flux (N).
-     ax: The matplotlib axes object.
-     color: The color of the posterior predictive distribution.
-     y_observed: The observed data (N).
-     label: The label of the observed data.
-     percentile: The percentile of the posterior predictive distribution to plot.
-     """
-
-     mean, envelope = None, None
-
-     if denominator is None:
-         denominator = np.ones_like(x_bins[0])
-
-     mean = ax.stairs(
-         list(np.median(y_samples, axis=0) / denominator),
-         edges=[*list(x_bins[0]), x_bins[1][-1]],
-         color=color,
-         alpha=0.7,
-     )
-
-     if y_samples is not None:
-         if denominator is None:
-             denominator = np.ones_like(x_bins[0])
-
-         percentiles = np.percentile(y_samples, percentile, axis=0)
-
-         # The legend cannot handle fill_between, so we pass a fill to get a fancy icon
-         (envelope,) = ax.fill(np.nan, np.nan, alpha=0.3, facecolor=color)
-
-         ax.stairs(
-             percentiles[1] / denominator,
-             edges=[*list(x_bins[0]), x_bins[1][-1]],
-             baseline=percentiles[0] / denominator,
-             alpha=0.3,
-             fill=True,
-             color=color,
-         )
-
-     return [(mean, envelope)]
-
-
  class FitResult:
      """
      Container for the result of a fit using any ModelFitter class.
@@ -101,7 +52,6 @@ class FitResult:
          self,
          bayesian_fitter: BayesianModel,
         inference_data: az.InferenceData,
-         structure: Mapping[K, V],
          background_model: BackgroundModel = None,
      ):
          self.model = bayesian_fitter.model
@@ -109,7 +59,6 @@ class FitResult:
          self.inference_data = inference_data
          self.obsconfs = bayesian_fitter.observation_container
          self.background_model = background_model
-         self._structure = structure

          # Add the model used in fit to the metadata
          for group in self.inference_data.groups():
@@ -126,37 +75,38 @@ class FitResult:

          return all(az.rhat(self.inference_data) < 1.01)

-     @property
-     def _structured_samples(self):
-         """
-         Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
-         """
-
-         samples_flat = self._structured_samples_flat
+     def _ppc_folded_branches(self, obs_id):
+         obs = self.obsconfs[obs_id]

-         samples_haiku = {}
+         if len(next(iter(self.input_parameters.values())).shape) > 2:
+             idx = list(self.obsconfs.keys()).index(obs_id)
+             obs_parameters = jax.tree.map(lambda x: x[..., idx], self.input_parameters)

-         for module, parameter, value in traverse(self._structure):
-             if samples_haiku.get(module, None) is None:
-                 samples_haiku[module] = {}
-             samples_haiku[module][parameter] = samples_flat[f"{module}_{parameter}"]
+         else:
+             obs_parameters = self.input_parameters

-         return samples_haiku
+         if self.bayesian_fitter.sparse:
+             transfer_matrix = BCOO.from_scipy_sparse(
+                 obs.transfer_matrix.data.to_scipy_sparse().tocsr()
+             )

-     @property
-     def _structured_samples_flat(self):
-         """
-         Get samples from the parameter posterior distribution but keep their shape in terms of draw and chains.
-         """
+         else:
+             transfer_matrix = np.asarray(obs.transfer_matrix.data.todense())

-         var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
-         posterior = az.extract(self.inference_data, var_names=var_names, combined=False)
-         samples_flat = {key: posterior[key].data for key in var_names}
+         energies = obs.in_energies

-         return samples_flat
+         flux_func = jax.jit(
+             jax.vmap(jax.vmap(lambda p: self.model.photon_flux(p, *energies, split_branches=True)))
+         )
+         convolve_func = jax.jit(
+             jax.vmap(jax.vmap(lambda flux: jnp.clip(transfer_matrix @ flux, a_min=1e-6)))
+         )
+         return jax.tree.map(
+             lambda flux: np.random.poisson(convolve_func(flux)), flux_func(obs_parameters)
+         )

-     @property
-     def input_parameters(self) -> HaikuDict[ArrayLike]:
+     @cached_property
+     def input_parameters(self) -> dict[str, ArrayLike]:
          """
          The input parameters of the model.
          """
@@ -172,7 +122,9 @@ class FitResult:
          with seed(rng_seed=0):
              input_parameters = self.bayesian_fitter.prior_distributions_func()

-         for module, parameter, value in traverse(input_parameters):
+         for key, value in input_parameters.items():
+             module, parameter = key.rsplit("_", 1)
+
              if f"{module}_{parameter}" in posterior.keys():
                  # We add as extra dimension as there might be different values per observation
                  if posterior[f"{module}_{parameter}"].shape == samples_shape:
@@ -180,19 +132,21 @@ class FitResult:
                  else:
                      to_set = posterior[f"{module}_{parameter}"]

-                 input_parameters[module][parameter] = to_set
+                 input_parameters[f"{module}_{parameter}"] = to_set

              else:
                  # The parameter is fixed in this case, so we just broadcast is over chain and draws
-                 input_parameters[module][parameter] = value[None, None, ...]
+                 input_parameters[f"{module}_{parameter}"] = value[None, None, ...]

-             if len(total_shape) < len(input_parameters[module][parameter].shape):
+             if len(total_shape) < len(input_parameters[f"{module}_{parameter}"].shape):
                  # If there are only chains and draws, we reduce
-                 input_parameters[module][parameter] = input_parameters[module][parameter][..., 0]
+                 input_parameters[f"{module}_{parameter}"] = input_parameters[
+                     f"{module}_{parameter}"
+                 ][..., 0]

              else:
-                 input_parameters[module][parameter] = jnp.broadcast_to(
-                     input_parameters[module][parameter], total_shape
+                 input_parameters[f"{module}_{parameter}"] = jnp.broadcast_to(
+                     input_parameters[f"{module}_{parameter}"], total_shape
                  )

          return input_parameters
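With haiku gone, the nested `{"module": {"parameter": ...}}` layout is replaced by a flat dict keyed `"{module}_{parameter}"`, and the pair is recovered with `rsplit("_", 1)`. A small illustration of the convention (parameter names invented):

    # Flat keys replace the former nested {"powerlaw_1": {"alpha": ...}} layout
    params = {"powerlaw_1_alpha": 1.7, "tbabs_1_nH": 0.3}

    for key, value in params.items():
        # rsplit from the right keeps underscores inside the module name intact
        module, parameter = key.rsplit("_", 1)
        print(module, parameter, value)  # e.g. "powerlaw_1", "alpha", 1.7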
@@ -287,7 +241,8 @@ class FitResult:
          self,
          e_min: float,
          e_max: float,
-         redshift: float | ArrayLike = 0.1,
+         redshift: float | ArrayLike = None,
+         distance: float | ArrayLike = None,
          observer_frame: bool = True,
          cosmology: Cosmology = Planck18,
          unit: Unit = u.erg / u.s,
@@ -310,6 +265,17 @@ class FitResult:
          if not observer_frame:
              raise NotImplementedError()

+         if redshift is None and distance is None:
+             raise ValueError("Either redshift or distance must be specified.")
+
+         if distance is not None:
+             if redshift is not None:
+                 raise ValueError("Redshift must be None as a distance is specified.")
+             else:
+                 redshift = distance.to(
+                     cu.redshift, cu.redshift_distance(cosmology, kind="luminosity")
+                 ).value
+
          @jax.jit
          @jnp.vectorize
          def vectorized_flux(*pars):
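The distance-to-redshift conversion leans on astropy's cosmology unit equivalencies, as in the added code above. A standalone sketch of the same conversion (the distance value is invented):

    import astropy.cosmology.units as cu
    import astropy.units as u
    from astropy.cosmology import Planck18

    distance = 450 * u.Mpc  # hypothetical luminosity distance

    # Invert the luminosity distance-redshift relation for the chosen cosmology
    z = distance.to(cu.redshift, cu.redshift_distance(Planck18, kind="luminosity"))
    print(z)  # roughly z ~ 0.1 for Planck18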
@@ -333,94 +299,46 @@ class FitResult:

          return value

-     def to_chain(self, name: str, parameters_type: Literal["model", "bkg"] = "model") -> Chain:
+     def to_chain(self, name: str) -> Chain:
          """
          Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.

          Parameters:
              name: The name of the chain.
-             parameters_type: The parameters_type to include in the chain.
          """

-         obs_id = self.inference_data.copy()
-
-         if parameters_type == "model":
-             keys_to_drop = [
-                 key
-                 for key in obs_id.posterior.keys()
-                 if (key.startswith("_") or key.startswith("bkg"))
-             ]
-         elif parameters_type == "bkg":
-             keys_to_drop = [key for key in obs_id.posterior.keys() if not key.startswith("bkg")]
-         else:
-             raise ValueError(f"Unknown value for parameters_type: {parameters_type}")
-
-         obs_id.posterior = obs_id.posterior.drop_vars(keys_to_drop)
-         chain = Chain.from_arviz(obs_id, name)
-
-         """
-         chain.samples.columns = [
-             format_parameters(parameter) for parameter in chain.samples.columns
+         keys_to_drop = [
+             key
+             for key in self.inference_data.posterior.keys()
+             if (key.startswith("_") or key.startswith("bkg"))
          ]
-         """

-         return chain
+         reduced_id = az.extract(
+             self.inference_data,
+             var_names=[f"~{key}" for key in keys_to_drop] if keys_to_drop else None,
+             group="posterior",
+         )

-     @property
-     def samples_haiku(self) -> HaikuDict[ArrayLike]:
-         """
-         Haiku-like structure for the samples e.g.
-
-         ```
-         {
-             'powerlaw_1' :
-             {
-                 'alpha': ...,
-                 'amplitude': ...
-             },
-
-             'blackbody_1':
-             {
-                 'kT': ...,
-                 'norm': ...
-             },
-
-             'tbabs_1':
-             {
-                 'nH': ...
-             }
-         }
-         ```
+         df_list = []

-         """
+         for var, array in reduced_id.data_vars.items():
+             extra_dims = [dim for dim in array.dims if dim not in ["sample"]]

-         params = {}
+             if extra_dims:
+                 dim = extra_dims[
+                     0
+                 ]  # We only support the case where the extra dimension comes from the observations

-         for module, parameter, value in traverse(self._structure):
-             if params.get(module, None) is None:
-                 params[module] = {}
-             params[module][parameter] = self.samples_flat[f"{module}_{parameter}"]
+                 for coord, obs_id in zip(array.coords[dim], self.obsconfs.keys()):
+                     df = array.loc[{dim: coord}].to_pandas()
+                     df.name += f"\n[{obs_id}]"
+                     df_list.append(df)
+             else:
+                 df_list.append(array.to_pandas())

-         return params
+         df = pd.concat(df_list, axis=1)

-     @property
-     def samples_flat(self) -> dict[str, ArrayLike]:
-         """
-         Flat structure for the samples e.g.
-
-         ```
-         {
-             'powerlaw_1_alpha': ...,
-             'powerlaw_1_amplitude': ...,
-             'blackbody_1_kT': ...,
-             'blackbody_1_norm': ...,
-             'tbabs_1_nH': ...,
-         }
-         ```
-         """
-         var_names = [f"{m}_{n}" for m, n, _ in traverse(self._structure)]
-         posterior = az.extract(self.inference_data, var_names=var_names)
-         return {key: posterior[key].data for key in var_names}
+         return Chain(samples=df, name=name)

      @property
      def log_likelihood(self) -> xr.Dataset:
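`to_chain` now assembles a pandas DataFrame itself and wraps it in a `Chain` directly. A hedged usage sketch (the `result` variable is hypothetical, and the plotting calls assume the ChainConsumer 1.x API):

    from chainconsumer import ChainConsumer

    # `result` is a FitResult from a jaxspec fit (hypothetical variable)
    chain = result.to_chain("my fit")

    cc = ChainConsumer()
    cc.add_chain(chain)
    fig = cc.plotter.plot()  # corner plot of the model parameters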
@@ -462,12 +380,18 @@ class FitResult:

      def plot_ppc(
          self,
-         percentile: tuple[int, int] = (16, 84),
+         n_sigmas: int = 1,
          x_unit: str | u.Unit = "keV",
          y_type: Literal[
              "counts", "countrate", "photon_flux", "photon_flux_density"
          ] = "photon_flux_density",
-     ) -> plt.Figure:
+         plot_background: bool = True,
+         plot_components: bool = False,
+         scale: Literal["linear", "semilogx", "semilogy", "loglog"] = "loglog",
+         alpha_envelope: (float, float) = (0.15, 0.25),
+         style: str | Any = "default",
+         title: str | None = None,
+     ) -> list[plt.Figure]:
          r"""
          Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
          following formula:
@@ -479,82 +403,52 @@ class FitResult:
              percentile: The percentile of the posterior predictive distribution to plot.
              x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
              y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
+             plot_background: Whether to plot the background model if it is included in the fit.
+             plot_components: Whether to plot the components of the model separately.
+             scale: The axes scaling
+             alpha_envelope: The transparency range for envelops
+             style: The style of the plot. It can be either a string or a matplotlib style context.

          Returns:
-             The matplotlib figure.
+             A list of matplotlib figures for each observation in the model.
          """

          obsconf_container = self.obsconfs
+         figure_list = []
          x_unit = u.Unit(x_unit)

          match y_type:
              case "counts":
-                 y_units = u.photon
+                 y_units = u.ct
              case "countrate":
-                 y_units = u.photon / u.s
+                 y_units = u.ct / u.s
              case "photon_flux":
-                 y_units = u.photon / u.cm**2 / u.s
+                 y_units = u.ct / u.cm**2 / u.s
              case "photon_flux_density":
-                 y_units = u.photon / u.cm**2 / u.s / x_unit
+                 y_units = u.ct / u.cm**2 / u.s / x_unit
              case _:
                  raise ValueError(
                      f"Unknown y_type: {y_type}. Must be 'counts', 'countrate', 'photon_flux' or 'photon_flux_density'"
                  )

-         color = (0.15, 0.25, 0.45)
-
-         with plt.style.context("default"):
-             # Note to Simon : do not change xbins[1] - xbins[0] to
-             # np.diff, you already did this twice and forgot that it does not work since diff keeps the dimensions
-             # and enable weird broadcasting that makes the plot fail
-
-             fig, axs = plt.subplots(
-                 2,
-                 len(obsconf_container),
-                 figsize=(6 * len(obsconf_container), 6),
-                 sharex=True,
-                 height_ratios=[0.7, 0.3],
-             )
-
-             plot_ylabels_once = True
+         with plt.style.context(style):
+             for obs_id, obsconf in obsconf_container.items():
+                 fig, ax = plt.subplots(
+                     2,
+                     1,
+                     figsize=(6, 6),
+                     sharex="col",
+                     height_ratios=[0.7, 0.3],
+                 )

-             for name, obsconf, ax in zip(
-                 obsconf_container.keys(),
-                 obsconf_container.values(),
-                 axs.T if len(obsconf_container) > 1 else [axs],
-             ):
                  legend_plots = []
                  legend_labels = []
+
                  count = az.extract(
-                     self.inference_data, var_names=f"obs_{name}", group="posterior_predictive"
+                     self.inference_data, var_names=f"obs_{obs_id}", group="posterior_predictive"
                  ).values.T
-                 bkg_count = (
-                     None
-                     if self.background_model is None
-                     else az.extract(
-                         self.inference_data, var_names=f"bkg_{name}", group="posterior_predictive"
-                     ).values.T
-                 )

-                 xbins = obsconf.out_energies * u.keV
-                 xbins = xbins.to(x_unit, u.spectral())
-
-                 # This compute the total effective area within all bins
-                 # This is a bit weird since the following computation is equivalent to ignoring the RMF
-                 exposure = obsconf.exposure.data * u.s
-                 mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
-                 mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
-                 e_grid = np.linspace(*xbins, 10)
-                 interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
-                 integrated_arf = (
-                     trapezoid(interpolated_arf, x=e_grid, axis=0)
-                     / (
-                         np.abs(
-                             xbins[1] - xbins[0]
-                         )  # Must fold in abs because some units reverse the ordering of the bins
-                     )
-                     * u.cm**2
-                 )
+                 xbins, exposure, integrated_arf = _compute_effective_area(obsconf, x_unit)

                  match y_type:
                      case "counts":
@@ -566,134 +460,124 @@ class FitResult:
                      case "photon_flux_density":
                          denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure

-                 y_samples = (count * u.photon / denominator).to(y_units)
-                 y_observed = (obsconf.folded_counts.data * u.photon / denominator).to(y_units)
-                 y_observed_low = (
-                     nbinom.ppf(percentile[0] / 100, obsconf.folded_counts.data, 0.5)
-                     * u.photon
-                     / denominator
-                 ).to(y_units)
-                 y_observed_high = (
-                     nbinom.ppf(percentile[1] / 100, obsconf.folded_counts.data, 0.5)
-                     * u.photon
-                     / denominator
-                 ).to(y_units)
+                 y_samples = (count * u.ct / denominator).to(y_units)
+
+                 y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
+                     obsconf.folded_counts.data, denominator, y_units
+                 )

                  # Use the helper function to plot the data and posterior predictive
-                 legend_plots += _plot_binned_samples_with_error(
+                 model_plot = _plot_binned_samples_with_error(
                      ax[0],
                      xbins.value,
-                     y_samples=y_samples.value,
-                     denominator=np.ones_like(y_observed).value,
-                     color=color,
-                     percentile=percentile,
+                     y_samples.value,
+                     color=SPECTRUM_COLOR,
+                     n_sigmas=n_sigmas,
+                     alpha_envelope=alpha_envelope,
                  )

-                 legend_labels.append("Model")
-
-                 true_data_plot = ax[0].errorbar(
-                     np.sqrt(xbins.value[0] * xbins.value[1]),
+                 true_data_plot = _plot_poisson_data_with_error(
+                     ax[0],
+                     xbins.value,
                      y_observed.value,
-                     xerr=np.abs(xbins.value - np.sqrt(xbins.value[0] * xbins.value[1])),
-                     yerr=[
-                         y_observed.value - y_observed_low.value,
-                         y_observed_high.value - y_observed.value,
-                     ],
-                     color="black",
-                     linestyle="none",
-                     alpha=0.3,
-                     capsize=2,
+                     y_observed_low.value,
+                     y_observed_high.value,
+                     color=SPECTRUM_DATA_COLOR,
+                     alpha=0.7,
                  )

                  legend_plots.append((true_data_plot,))
                  legend_labels.append("Observed")
+                 legend_plots += model_plot
+                 legend_labels.append("Model")

-                 if self.background_model is not None:
+                 # Plot the residuals
+                 residual_samples = (obsconf.folded_counts.data - count) / np.diff(
+                     np.percentile(count, [16, 84], axis=0), axis=0
+                 )
+
+                 _plot_binned_samples_with_error(
+                     ax[1],
+                     xbins.value,
+                     residual_samples,
+                     color=SPECTRUM_COLOR,
+                     n_sigmas=n_sigmas,
+                     alpha_envelope=alpha_envelope,
+                 )
+
+                 if plot_components:
+                     for (component_name, count), color in zip(
+                         self._ppc_folded_branches(obs_id).items(), COLOR_CYCLE
+                     ):
+                         # _ppc_folded_branches returns (n_chains, n_draws, n_bins) shaped arrays so we must flatten it
+                         y_samples = (
+                             count.reshape((count.shape[0] * count.shape[1], -1))
+                             * u.ct
+                             / denominator
+                         ).to(y_units)
+                         component_plot = _plot_binned_samples_with_error(
+                             ax[0],
+                             xbins.value,
+                             y_samples.value,
+                             color=color,
+                             linestyle="dashdot",
+                             n_sigmas=n_sigmas,
+                             alpha_envelope=alpha_envelope,
+                         )
+
+                         legend_plots += component_plot
+                         legend_labels.append(component_name)
+
+                 if self.background_model is not None and plot_background:
                      # We plot the background only if it is included in the fit, i.e. by subtracting
-                     ratio = obsconf.folded_backratio.data
-                     y_samples_bkg = (bkg_count * u.photon / (denominator * ratio)).to(y_units)
-                     y_observed_bkg = (
-                         obsconf.folded_background.data * u.photon / (denominator * ratio)
-                     ).to(y_units)
-                     y_observed_bkg_low = (
-                         nbinom.ppf(percentile[0] / 100, obsconf.folded_background.data, 0.5)
-                         * u.photon
-                         / (denominator * ratio)
-                     ).to(y_units)
-                     y_observed_bkg_high = (
-                         nbinom.ppf(percentile[1] / 100, obsconf.folded_background.data, 0.5)
-                         * u.photon
-                         / (denominator * ratio)
-                     ).to(y_units)
-
-                     legend_plots += _plot_binned_samples_with_error(
+                     bkg_count = (
+                         None
+                         if self.background_model is None
+                         else az.extract(
+                             self.inference_data,
+                             var_names=f"bkg_{obs_id}",
+                             group="posterior_predictive",
+                         ).values.T
+                     )
+
+                     y_samples_bkg = (bkg_count * u.ct / denominator).to(y_units)
+
+                     y_observed_bkg, y_observed_bkg_low, y_observed_bkg_high = (
+                         _error_bars_for_observed_data(
+                             obsconf.folded_background.data, denominator, y_units
+                         )
+                     )
+
+                     model_bkg_plot = _plot_binned_samples_with_error(
                          ax[0],
                          xbins.value,
-                         y_samples=y_samples_bkg.value,
-                         denominator=np.ones_like(y_observed).value,
-                         color=(0.26787604, 0.60085972, 0.63302651),
-                         percentile=percentile,
+                         y_samples_bkg.value,
+                         color=BACKGROUND_COLOR,
+                         alpha_envelope=alpha_envelope,
+                         n_sigmas=n_sigmas,
                      )

-                     legend_labels.append("Model (bkg)")
-
-                     true_bkg_plot = ax[0].errorbar(
-                         np.sqrt(xbins.value[0] * xbins.value[1]),
+                     true_bkg_plot = _plot_poisson_data_with_error(
+                         ax[0],
+                         xbins.value,
                          y_observed_bkg.value,
-                         xerr=np.abs(xbins.value - np.sqrt(xbins.value[0] * xbins.value[1])),
-                         yerr=[
-                             y_observed_bkg.value - y_observed_bkg_low.value,
-                             y_observed_bkg_high.value - y_observed_bkg.value,
-                         ],
-                         color="black",
-                         linestyle="none",
-                         alpha=0.3,
-                         capsize=2,
+                         y_observed_bkg_low.value,
+                         y_observed_bkg_high.value,
+                         color=BACKGROUND_DATA_COLOR,
+                         alpha=0.7,
                      )

                      legend_plots.append((true_bkg_plot,))
                      legend_labels.append("Observed (bkg)")
+                     legend_plots += model_bkg_plot
+                     legend_labels.append("Model (bkg)")

-                 residual_samples = (obsconf.folded_counts.data - count) / np.diff(
-                     np.percentile(count, percentile, axis=0), axis=0
-                 )
-
-                 residuals = np.percentile(
-                     residual_samples,
-                     percentile,
-                     axis=0,
-                 )
-
-                 median_residuals = np.median(
-                     residual_samples,
-                     axis=0,
-                 )
-
-                 ax[1].stairs(
-                     residuals[1],
-                     edges=[*list(xbins.value[0]), xbins.value[1][-1]],
-                     baseline=list(residuals[0]),
-                     alpha=0.3,
-                     facecolor=color,
-                     fill=True,
-                 )
-
-                 ax[1].stairs(
-                     median_residuals,
-                     edges=[*list(xbins.value[0]), xbins.value[1][-1]],
-                     color=color,
-                     alpha=0.7,
-                 )
-
-                 max_residuals = np.max(np.abs(residuals))
+                 max_residuals = np.max(np.abs(residual_samples))

                  ax[0].loglog()
                  ax[1].set_ylim(-max(3.5, max_residuals), +max(3.5, max_residuals))
-
-                 if plot_ylabels_once:
-                     ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
-                     ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
-                     plot_ylabels_once = False
+                 ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
+                 ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")

                  match getattr(x_unit, "physical_type"):
                      case "length":
@@ -708,24 +592,42 @@ class FitResult:
                              f"Must be 'length', 'energy' or 'frequency'"
                          )

-                 ax[1].axhline(0, color=color, ls="--")
-                 ax[1].axhline(-3, color=color, ls=":")
-                 ax[1].axhline(3, color=color, ls=":")
+                 ax[1].axhline(0, color=SPECTRUM_DATA_COLOR, ls="--")
+                 ax[1].axhline(-3, color=SPECTRUM_DATA_COLOR, ls=":")
+                 ax[1].axhline(3, color=SPECTRUM_DATA_COLOR, ls=":")

-                 # ax[1].set_xticks(xticks, labels=xticks)
-                 # ax[1].xaxis.set_minor_formatter(ticker.LogFormatter(minor_thresholds=(np.inf, np.inf)))
                  ax[1].set_yticks([-3, 0, 3], labels=[-3, 0, 3])
                  ax[1].set_yticks(range(-3, 4), minor=True)

                  ax[0].set_xlim(xbins.value.min(), xbins.value.max())

                  ax[0].legend(legend_plots, legend_labels)
-                 fig.suptitle(self.model.to_string())
+
+                 match scale:
+                     case "linear":
+                         ax[0].set_xscale("linear")
+                         ax[0].set_yscale("linear")
+                     case "semilogx":
+                         ax[0].set_xscale("log")
+                         ax[0].set_yscale("linear")
+                     case "semilogy":
+                         ax[0].set_xscale("linear")
+                         ax[0].set_yscale("log")
+                     case "loglog":
+                         ax[0].set_xscale("log")
+                         ax[0].set_yscale("log")
+
                  fig.align_ylabels()
                  plt.subplots_adjust(hspace=0.0)
                  fig.tight_layout()
+                 figure_list.append(fig)
+                 fig.suptitle(f"Posterior predictive - {obs_id}" if title is None else title)
+                 # fig.show()
+
+             plt.tight_layout()
+             plt.show()

-         return fig
+         return figure_list

      def table(self) -> str:
          r"""