redback 1.0.31__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. redback/__init__.py +3 -2
  2. redback/analysis.py +321 -4
  3. redback/filters.py +57 -23
  4. redback/get_data/directory.py +18 -0
  5. redback/likelihoods.py +260 -0
  6. redback/model_library.py +12 -2
  7. redback/plotting.py +335 -4
  8. redback/priors/blackbody_spectrum_with_absorption_and_emission_lines.prior +9 -0
  9. redback/priors/csm_shock_and_arnett_two_rphots.prior +11 -0
  10. redback/priors/exp_rise_powerlaw_decline.prior +6 -0
  11. redback/priors/powerlaw_spectrum_with_absorption_and_emission_lines.prior +8 -0
  12. redback/priors/salt2.prior +6 -0
  13. redback/priors/shock_cooling_and_arnett_bolometric.prior +11 -0
  14. redback/priors/shockcooling_morag.prior +6 -0
  15. redback/priors/shockcooling_morag_and_arnett.prior +10 -0
  16. redback/priors/shockcooling_morag_and_arnett_bolometric.prior +9 -0
  17. redback/priors/shockcooling_morag_bolometric.prior +5 -0
  18. redback/priors/shockcooling_sapirandwaxman.prior +6 -0
  19. redback/priors/shockcooling_sapirandwaxman_bolometric.prior +5 -0
  20. redback/priors/shockcooling_sapirwaxman_and_arnett.prior +10 -0
  21. redback/priors/shockcooling_sapirwaxman_and_arnett_bolometric.prior +9 -0
  22. redback/priors/shocked_cocoon_and_arnett.prior +13 -0
  23. redback/priors/synchrotron_ism.prior +6 -0
  24. redback/priors/synchrotron_massloss.prior +6 -0
  25. redback/priors/synchrotron_pldensity.prior +7 -0
  26. redback/priors/thermal_synchrotron_v2_fluxdensity.prior +8 -0
  27. redback/priors/thermal_synchrotron_v2_lnu.prior +7 -0
  28. redback/priors.py +10 -3
  29. redback/result.py +9 -1
  30. redback/sampler.py +46 -4
  31. redback/sed.py +48 -1
  32. redback/simulate_transients.py +5 -1
  33. redback/tables/filters.csv +265 -254
  34. redback/transient/__init__.py +2 -3
  35. redback/transient/transient.py +648 -10
  36. redback/transient_models/__init__.py +3 -2
  37. redback/transient_models/extinction_models.py +3 -2
  38. redback/transient_models/gaussianprocess_models.py +45 -0
  39. redback/transient_models/general_synchrotron_models.py +296 -6
  40. redback/transient_models/phenomenological_models.py +154 -7
  41. redback/transient_models/shock_powered_models.py +503 -40
  42. redback/transient_models/spectral_models.py +82 -0
  43. redback/transient_models/supernova_models.py +405 -31
  44. redback/transient_models/tde_models.py +57 -41
  45. redback/utils.py +302 -51
  46. {redback-1.0.31.dist-info → redback-1.12.0.dist-info}/METADATA +8 -6
  47. {redback-1.0.31.dist-info → redback-1.12.0.dist-info}/RECORD +50 -29
  48. {redback-1.0.31.dist-info → redback-1.12.0.dist-info}/WHEEL +1 -1
  49. {redback-1.0.31.dist-info → redback-1.12.0.dist-info/licenses}/LICENCE.md +0 -0
  50. {redback-1.0.31.dist-info → redback-1.12.0.dist-info}/top_level.txt +0 -0
redback/likelihoods.py CHANGED
@@ -204,6 +204,266 @@ class GaussianLikelihood(_RedbackLikelihood):
204
204
  def _gaussian_log_likelihood(res: np.ndarray, sigma: Union[float, np.ndarray]) -> Any:
205
205
  return np.sum(- (res / sigma) ** 2 / 2 - np.log(2 * np.pi * sigma ** 2) / 2)
206
206
 
207
class MixtureGaussianLikelihood(GaussianLikelihood):
    def __init__(self, x: np.ndarray, y: np.ndarray,
                 sigma: Union[float, None, np.ndarray],
                 function: callable, kwargs: dict = None,
                 priors=None, fiducial_parameters=None) -> None:
        """
        Mixture Gaussian likelihood that handles outliers by modeling each data point's likelihood
        as a weighted sum of two Gaussians. The likelihood for each datum is given by

            L_i = α * N(y_i | f(x_i), σ²) + (1 - α) * N(y_i | f(x_i), σ_out²)

        where:
            - N(y_i | f(x_i), σ²) is the Gaussian probability density evaluated at y_i with mean f(x_i)
              and variance σ².
            - α is the inlier fraction (between 0 and 1).
            - σ_out is the standard deviation for the outlier component.

        In addition, the posterior probability that a data point is an outlier is computed via

            P(outlier | r) = [(1 - α) * p_out(r)] / [α * p_in(r) + (1 - α) * p_out(r)]

        where r is the residual (y - f(x)).

        Parameters
        ----------
        x : np.ndarray
            Independent variable data.
        y : np.ndarray
            Observed dependent variable data.
        sigma : float, None, or np.ndarray
            Standard deviation for the inlier component.
        function : callable
            Model function. It should accept x as the first argument.
        kwargs : dict, optional
            Additional keyword arguments for the model function.
        priors : dict, optional
            Priors for the parameters; passed through to ``GaussianLikelihood``.
        fiducial_parameters : dict, optional
            Starting guesses for the model parameters; passed through to ``GaussianLikelihood``.

        Notes
        -----
        The mixture parameters are not constructor arguments. They are read from
        ``self.parameters`` and only receive a default here when absent:

        alpha
            Inlier fraction, i.e. fraction of data points from the underlying model. Default 0.9.
        sigma_out
            Standard deviation of the outlier component. Defaults to 10 * sigma when ``sigma``
            is a scalar int/float, and to 10.0 otherwise (e.g. array or None sigma).
        """
        super().__init__(x=x, y=y, sigma=sigma, function=function, kwargs=kwargs, priors=priors,
                         fiducial_parameters=fiducial_parameters)

        # Set default mixture parameters if not provided.
        if 'alpha' not in self.parameters:
            self.parameters['alpha'] = 0.9
        if 'sigma_out' not in self.parameters:
            if sigma is not None and isinstance(sigma, (int, float)):
                self.parameters['sigma_out'] = sigma * 10
            else:
                self.parameters['sigma_out'] = 10.0

        self._noise_log_likelihood = None

    @staticmethod
    def _log_component_terms(res: np.ndarray,
                              sigma: Union[float, np.ndarray],
                              sigma_out: Union[float, np.ndarray],
                              alpha: float) -> tuple:
        """
        Weighted log densities of the two mixture components for each residual.

        :param res: Residuals.
        :param sigma: Standard deviation of the inlier Gaussian.
        :param sigma_out: Standard deviation of the outlier Gaussian.
        :param alpha: Inlier fraction (between 0 and 1).
        :return: Tuple ``(log(alpha) + log p_in(res), log(1 - alpha) + log p_out(res))``.
        """
        # alpha == 0 or alpha == 1 gives log(0) = -inf, which np.logaddexp / np.exp handle
        # correctly; only the divide-by-zero warning is suppressed.
        with np.errstate(divide='ignore'):
            log_alpha = np.log(alpha)
            log_one_minus_alpha = np.log(1 - alpha)
        logp_in = -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * (res / sigma) ** 2
        logp_out = -0.5 * np.log(2 * np.pi) - np.log(sigma_out) - 0.5 * (res / sigma_out) ** 2
        return log_alpha + logp_in, log_one_minus_alpha + logp_out

    def _mixture_gaussian_log_likelihood(self, res: np.ndarray,
                                         sigma: Union[float, np.ndarray],
                                         sigma_out: Union[float, np.ndarray],
                                         alpha: float) -> np.ndarray:
        """
        Compute the log-likelihood of the residuals under a mixture of two Gaussians in a stable
        manner using the log-sum-exp trick.

        Parameters
        ----------
        res : np.ndarray
            Array of residuals.
        sigma : float or np.ndarray
            Standard deviation for the inlier Gaussian.
        sigma_out : float or np.ndarray
            Standard deviation for the outlier Gaussian.
        alpha : float
            Inlier fraction (between 0 and 1).

        Returns
        -------
        np.ndarray
            Log-likelihood for each residual under the mixture model.
        """
        term_in, term_out = self._log_component_terms(res, sigma, sigma_out, alpha)
        # Element-wise log(exp(term_in) + exp(term_out)) without overflow/underflow.
        return np.logaddexp(term_in, term_out)

    def p_in(self, r: np.ndarray) -> np.ndarray:
        """
        Compute the inlier probability density for residuals.

        Parameters
        ----------
        r : np.ndarray
            Residuals.

        Returns
        -------
        np.ndarray
            Inlier probability density evaluated at each residual.
        """
        sigma = self.sigma
        return (1 / (np.sqrt(2 * np.pi) * sigma)) * np.exp(-0.5 * (r / sigma) ** 2)

    def p_out(self, r: np.ndarray) -> np.ndarray:
        """
        Compute the outlier probability density for residuals.

        Parameters
        ----------
        r : np.ndarray
            Residuals.

        Returns
        -------
        np.ndarray
            Outlier probability density evaluated at each residual.
        """
        sigma_out = self.parameters.get('sigma_out')
        return (1 / (np.sqrt(2 * np.pi) * sigma_out)) * np.exp(-0.5 * (r / sigma_out) ** 2)

    def log_likelihood(self) -> float:
        """
        Compute the total log-likelihood for the mixture model.

        For each data point, the log-likelihood is given by

            log(L_i) = log(α * N(0, σ²) + (1 - α) * N(0, σ_out²)).

        Returns
        -------
        float
            The overall log-likelihood (summed over all data points).
        """
        res = self.residual
        alpha = self.parameters.get('alpha')
        sigma_out = self.parameters.get('sigma_out')
        return np.sum(self._mixture_gaussian_log_likelihood(res, self.sigma, sigma_out, alpha))

    def calculate_outlier_posteriors(self, model_prediction: np.ndarray) -> np.ndarray:
        """
        Calculate the posterior probability that each data point is an outlier.

        Given a model prediction, the residual for each point is computed as:
            r = y - model_prediction.
        Then the posterior is given by

            P(outlier | r) = [(1 - α) * p_out(r)] / [α * p_in(r) + (1 - α) * p_out(r)].

        The ratio is evaluated in log space, so residuals far out in the tails (where both
        densities underflow to 0 in linear space) still get a posterior close to 1.

        Parameters
        ----------
        model_prediction : np.ndarray
            Model predictions for each data point.

        Returns
        -------
        np.ndarray
            An array of posterior probabilities (between 0 and 1) for each data point being an
            outlier. Points where the mixture density is zero or undefined (e.g. infinite or NaN
            residuals) get 0.0.
        """
        r = self.y - model_prediction
        alpha = self.parameters.get('alpha')
        sigma_out = self.parameters.get('sigma_out')
        term_in, term_out = self._log_component_terms(r, self.sigma, sigma_out, alpha)
        log_denominator = np.logaddexp(term_in, term_out)
        # term_out <= log_denominator, so the exponent is <= 0 and cannot overflow; the errstate
        # only covers -inf - (-inf) = nan for degenerate residuals, which are masked below.
        with np.errstate(invalid='ignore'):
            posteriors = np.exp(term_out - log_denominator)
        return np.where(np.isfinite(log_denominator), posteriors, 0.0)
380
+
381
class StudentTLikelihood(GaussianLikelihood):
    def __init__(
            self, x: np.ndarray, y: np.ndarray, sigma: Union[float, None, np.ndarray],
            function: callable, kwargs: dict = None, priors=None, fiducial_parameters=None) -> None:
        """
        A Student-t likelihood that handles outliers by assuming that the data are distributed
        according to a Student-t distribution. The probability density function for each residual is
        given by:

            p(r | ν, σ) = Γ((ν+1)/2) / [√(νπ) σ Γ(ν/2)] · [1 + (r/σ)²/ν]^(-(ν+1)/2)

        where ν (nu) is the degrees of freedom and σ is the scale (standard deviation).

        The degrees of freedom is not a constructor argument: it is read from
        ``self.parameters['nu']`` and set to 3.0 here if absent, so it can be fixed or sampled
        like any other parameter. It must be positive; non-positive values give a
        log-likelihood of -inf.

        :param x: The x values.
        :type x: np.ndarray
        :param y: The y values.
        :type y: np.ndarray
        :param sigma: The scale (standard deviation) for the model residuals.
        :type sigma: Union[float, None, np.ndarray]
        :param function:
            The python function to fit to the data. Note, this must take the
            dependent variable as its first argument. The other arguments will require a prior
            and will be sampled over (unless a fixed value is given).
        :type function: callable
        :param kwargs: Any additional keywords for 'function'.
        :type kwargs: dict
        :param priors: The priors for the parameters. Default to None if not provided.
        :type priors: Union[dict, None]
        :param fiducial_parameters: The starting guesses for model parameters to use in the optimization.
        :type fiducial_parameters: Union[dict, None]
        """
        self._noise_log_likelihood = None
        super().__init__(x=x, y=y, sigma=sigma, function=function, kwargs=kwargs, priors=priors,
                         fiducial_parameters=fiducial_parameters)

        # Set default degrees of freedom for the Student-t distribution if not provided.
        if 'nu' not in self.parameters:
            self.parameters['nu'] = 3.0

    def noise_log_likelihood(self) -> float:
        """
        Compute the log-likelihood assuming the signal is pure noise (i.e. the residuals are just y).

        The value is cached on the first call, using the value of ``nu`` in ``self.parameters``
        at that time.
        """
        if self._noise_log_likelihood is None:
            nu = self.parameters.get('nu')
            self._noise_log_likelihood = self._student_t_log_likelihood(res=self.y, sigma=self.sigma, nu=nu)
        return self._noise_log_likelihood

    def log_likelihood(self) -> float:
        """
        Compute the total log-likelihood for the Student-t likelihood model.

        :return: The log-likelihood, passed through ``np.nan_to_num`` (so -inf becomes the most
                 negative finite float).
        :rtype: float
        """
        nu = self.parameters.get('nu')
        return np.nan_to_num(self._student_t_log_likelihood(res=self.residual, sigma=self.sigma, nu=nu))

    @staticmethod
    def _student_t_log_likelihood(res: np.ndarray, sigma: Union[float, np.ndarray], nu: float) -> Any:
        """
        Computes the log likelihood of the Student-t distribution for the residuals.

        For each data point, the log probability is given by:

            log[p(r|ν,σ)] = gammaln((ν+1)/2) - gammaln(ν/2)
                            - 0.5*log(νπ) - log(σ)
                            - ((ν+1)/2)*log[1 + (r/σ)²/ν]

        :param res: The residuals.
        :type res: np.ndarray
        :param sigma: The scale parameter.
        :type sigma: Union[float, np.ndarray]
        :param nu: Degrees of freedom; must be positive.
        :type nu: float
        :return: The total log-likelihood, or -inf if ``nu`` is not positive (or NaN).
        :rtype: float
        """
        # The density is only defined for nu > 0. Without this guard nu <= 0 yields NaN, which
        # log_likelihood's nan_to_num would map to 0.0 -- a spuriously *high* log-likelihood
        # that a sampler would happily move towards.
        if not nu > 0:
            return -np.inf
        term1 = gammaln((nu + 1) / 2) - gammaln(nu / 2)
        term2 = - 0.5 * np.log(nu * np.pi)
        term3 = - np.log(sigma)
        term4 = - ((nu + 1) / 2) * np.log(1 + (res / sigma) ** 2 / nu)
        log_pdf = term1 + term2 + term3 + term4
        return np.sum(log_pdf)
466
+
207
467
 
208
468
  class GaussianLikelihoodUniformXErrors(GaussianLikelihood):
209
469
  def __init__(
redback/model_library.py CHANGED
@@ -1,18 +1,28 @@
1
1
  from redback.transient_models import afterglow_models, \
2
2
  extinction_models, kilonova_models, fireball_models, \
3
3
  gaussianprocess_models, magnetar_models, magnetar_driven_ejecta_models, phase_models, phenomenological_models, \
4
- prompt_models, shock_powered_models, supernova_models, tde_models, integrated_flux_afterglow_models, combined_models, general_synchrotron_models
4
+ prompt_models, shock_powered_models, supernova_models, tde_models, integrated_flux_afterglow_models, combined_models, \
5
+ general_synchrotron_models, spectral_models
6
+
5
7
  from redback.utils import get_functions_dict
6
8
 
modules = [afterglow_models, extinction_models, fireball_models,
           gaussianprocess_models, integrated_flux_afterglow_models, kilonova_models,
           magnetar_models, magnetar_driven_ejecta_models,
           phase_models, phenomenological_models, prompt_models, shock_powered_models, supernova_models,
           tde_models, combined_models, general_synchrotron_models, spectral_models]

# Subset of modules whose functions are additionally collected into ``base_models_dict``.
base_modules = [extinction_models, phase_models]

all_models_dict = {}
base_models_dict = {}
modules_dict = {}

# get_functions_dict(module) returns a dict keyed by the module's short name (last dotted
# component); its value maps function names to functions.
for module in modules:
    models_dict = get_functions_dict(module)
    modules_dict.update(models_dict)
    all_models_dict.update(models_dict[module.__name__.rpartition('.')[2]])

for mod in base_modules:
    models_dict = get_functions_dict(mod)
    base_models_dict.update(models_dict[mod.__name__.rpartition('.')[2]])
redback/plotting.py CHANGED
@@ -33,6 +33,9 @@ class _FilePathGetter(object):
33
33
 
34
34
 
35
35
  class Plotter(object):
36
+ """
37
+ Base class for all lightcurve plotting classes in redback.
38
+ """
36
39
 
37
40
  capsize = KwargsAccessorWithDefault("capsize", 0.)
38
41
  legend_location = KwargsAccessorWithDefault("legend_location", "best")
@@ -43,9 +46,9 @@ class Plotter(object):
43
46
  band_scaling = KwargsAccessorWithDefault("band_scaling", {})
44
47
  dpi = KwargsAccessorWithDefault("dpi", 300)
45
48
  elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
46
- errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "x")
49
+ errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "o")
47
50
  model = KwargsAccessorWithDefault("model", None)
48
- ms = KwargsAccessorWithDefault("ms", 1)
51
+ ms = KwargsAccessorWithDefault("ms", 5)
49
52
  axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)
50
53
 
51
54
  max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
@@ -262,6 +265,208 @@ class Plotter(object):
262
265
  if show:
263
266
  plt.show()
264
267
 
268
class SpecPlotter(object):
    """
    Base class for all spectrum plotting classes in redback.

    Plot settings are declared below as ``KwargsAccessorWithDefault`` class attributes, each
    pairing a keyword name with a default value.
    """

    capsize = KwargsAccessorWithDefault("capsize", 0.)
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "x")
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
    color = KwargsAccessorWithDefault("color", "k")
    dpi = KwargsAccessorWithDefault("dpi", 300)
    model = KwargsAccessorWithDefault("model", None)
    ms = KwargsAccessorWithDefault("ms", 1)
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)

    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")

    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
    zorder = KwargsAccessorWithDefault("zorder", -1)
    yscale = KwargsAccessorWithDefault("yscale", "linear")

    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)

    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
    wspace = KwargsAccessorWithDefault("wspace", 0.15)

    random_models = KwargsAccessorWithDefault("random_models", 100)
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)

    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.05)
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.1)
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)

    def __init__(self, spectrum: Union[redback.transient.Spectrum, None], **kwargs) -> None:
        """
        :param spectrum: An instance of `redback.transient.Spectrum`. Contains the data to be plotted.
        :param kwargs: Additional kwargs the plotter uses. -------
        :keyword capsize: Same as matplotlib capsize.
        :keyword elinewidth: same as matplotlib elinewidth
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
        :keyword ms: Same as matplotlib markersize.
        :keyword legend_location: Same as matplotlib legend location.
        :keyword legend_cols: Same as matplotlib legend columns.
        :keyword color: Color of the data points.
        :keyword dpi: Same as matplotlib dpi.
        :keyword model: str or callable, the model to plot.
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
        :keyword random_sample_color: Color of the random sample curves.
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
        :keyword linewidth: Same as matplotlib linewidth
        :keyword zorder: Same as matplotlib zorder
        :keyword yscale: Same as matplotlib yscale, default is linear
        :keyword xy: For `ax.annotate' x and y coordinates of the point to annotate.
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'
        :keyword annotation_size: `size` argument of `ax.annotate`.
        :keyword fontsize_axes: Font size of the x and y labels.
        :keyword fontsize_legend: Font size of the legend.
        :keyword fontsize_figure: Font size of the figure. Relevant for multi-panel plots.
                                  Used on `supxlabel` and `supylabel`.
        :keyword fontsize_ticks: Font size of the axis ticks.
        :keyword hspace: Argument for `subplots_adjust`, sets vertical spacing between panels.
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
                                   'credible_intervals': Plot a credible interval that is calculated based
                                   on the available parameter sets.
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
        :keyword set_same_color_per_subplot: Sets the model curve to be the same color as the data per subplot. Default is 'True'.
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available wavelength values.
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available wavelength values.
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available flux density values.
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available flux density values.
        :keyword xlim_low: Explicit lower x limit; overrides `xlim_low_multiplier`.
        :keyword xlim_high: Explicit upper x limit; overrides `xlim_high_multiplier`.
        :keyword ylim_low: Explicit lower y limit in plotted units (flux density / 1e-17).
        :keyword ylim_high: Explicit upper y limit in plotted units (flux density / 1e-17).
        :keyword posterior: pandas DataFrame of posterior samples with a 'log_likelihood' column.
                            It is sorted in place by 'log_likelihood' on first use.
        :keyword model_kwargs: Extra keyword arguments passed to the model.
        :keyword outdir: Output directory for saved plots.
        :keyword filename: Base filename for saved plots.
        """
        self.transient = spectrum
        self.kwargs = kwargs or dict()
        self._posterior_sorted = False

    keyword_docstring = __init__.__doc__.split("-------")[1]

    def _get_angstroms(self, axes: matplotlib.axes.Axes) -> np.ndarray:
        """
        Wavelength grid on which models are evaluated for plotting.

        :param axes: The axes used in the plotting procedure (or an array of axes; the first is used).
        :type axes: matplotlib.axes.Axes

        :return: 200 wavelengths between the x limits, linearly spaced when the x (wavelength) axis
                 is linear and logarithmically spaced otherwise.
        :rtype: np.ndarray
        """
        ax = axes[0] if isinstance(axes, np.ndarray) else axes

        # The grid spacing must follow the wavelength axis so the model curve is evenly
        # sampled in the coordinates it is drawn in.
        if ax.get_xscale() == 'linear':
            angstroms = np.linspace(self._xlim_low, self._xlim_high, 200)
        else:
            angstroms = np.exp(np.linspace(np.log(self._xlim_low), np.log(self._xlim_high), 200))

        return angstroms

    @property
    def _xlim_low(self) -> float:
        default = self.xlim_low_multiplier * self.transient.angstroms[0]
        # Avoid a zero lower limit, which would break a logarithmic wavelength grid.
        if default == 0:
            default += 1e-3
        return self.kwargs.get("xlim_low", default)

    @property
    def _xlim_high(self) -> float:
        default = self.xlim_high_multiplier * self.transient.angstroms[-1]
        return self.kwargs.get("xlim_high", default)

    @property
    def _ylim_low(self) -> float:
        default = self.ylim_low_multiplier * np.min(self.transient.flux_density)
        # Data are plotted as flux_density / 1e-17, so the default limit is scaled to match.
        return self.kwargs.get("ylim_low", default/1e-17)

    @property
    def _ylim_high(self) -> float:
        default = self.ylim_high_multiplier * np.max(self.transient.flux_density)
        # Data are plotted as flux_density / 1e-17, so the default limit is scaled to match.
        return self.kwargs.get("ylim_high", default/1e-17)

    @property
    def _y_err(self) -> np.ndarray:
        return np.array([np.abs(self.transient.flux_density_err)])

    @property
    def _data_plot_outdir(self) -> str:
        return self._get_outdir(self.transient.directory_structure.directory_path)

    def _get_outdir(self, default: str) -> str:
        return self._get_kwarg_with_default(kwarg="outdir", default=default)

    def get_filename(self, default: str) -> str:
        return self._get_kwarg_with_default(kwarg="filename", default=default)

    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
        # Falsy user values (e.g. None or "") also fall back to the default.
        return self.kwargs.get(kwarg, default) or default

    @property
    def _model_kwargs(self) -> dict:
        return self._get_kwarg_with_default("model_kwargs", dict())

    @property
    def _posterior(self) -> pd.DataFrame:
        posterior = self.kwargs.get("posterior", pd.DataFrame())
        # Sort once (in place) so that the last row is the maximum-likelihood sample.
        if not self._posterior_sorted and posterior is not None:
            posterior.sort_values(by='log_likelihood', inplace=True)
            self._posterior_sorted = True
        return posterior

    @property
    def _max_like_params(self) -> pd.core.series.Series:
        return self._posterior.iloc[-1]

    def _get_random_parameters(self) -> list[pd.core.series.Series]:
        integers = np.arange(len(self._posterior))
        indices = np.random.choice(integers, size=self.random_models)
        return [self._posterior.iloc[idx] for idx in indices]

    _data_plot_filename = _FilenameGetter(suffix="data")
    _spectrum_ppd_plot_filename = _FilenameGetter(suffix="spectrum_ppd")
    _residual_plot_filename = _FilenameGetter(suffix="residual")

    _data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_data_plot_filename")
    _spectrum_ppd_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_spectrum_ppd_plot_filename")
    _residual_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_residual_plot_filename")

    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
        plt.tight_layout()
        if save:
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches, transparent=False, facecolor='white')
        if show:
            plt.show()
469
+
265
470
 
266
471
  class IntegratedFluxPlotter(Plotter):
267
472
 
@@ -385,6 +590,16 @@ class IntegratedFluxPlotter(Plotter):
385
590
  return axes
386
591
 
387
592
 
593
class LuminosityOpticalPlotter(IntegratedFluxPlotter):
    """IntegratedFluxPlotter with axis labels for bolometric-luminosity light curves."""

    @property
    def _xlabel(self) -> str:
        """X-axis label text: 'Time since explosion [days]'."""
        label = "Time since explosion [days]"
        return label

    @property
    def _ylabel(self) -> str:
        """Y-axis label text: bolometric luminosity in units of 10^50 erg/s."""
        label = r"L$_{\rm bol}$ [$10^{50}$ erg s$^{-1}$]"
        return label


class LuminosityPlotter(IntegratedFluxPlotter):
    """Alias of IntegratedFluxPlotter; adds no overrides."""
390
605
 
@@ -805,6 +1020,122 @@ class FluxDensityPlotter(MagnitudePlotter):
class IntegratedFluxOpticalPlotter(MagnitudePlotter):
    """Alias of MagnitudePlotter; adds no overrides."""
807
1022
 
808
- class SpectraPlotter(Plotter):
809
- pass
1023
class SpectrumPlotter(SpecPlotter):
    """
    Plots a spectrum, posterior model draws over it, and the residuals of the best fit.

    Every flux density drawn by this class (data, model curves and residuals) is divided by
    ``_flux_scale`` so that all panels share the same units.
    """

    # Display scaling applied to every flux density drawn by this class.
    _flux_scale = 1e-17

    @property
    def _xlabel(self) -> str:
        """X-axis label, taken from the spectrum object."""
        return self.transient.xlabel

    @property
    def _ylabel(self) -> str:
        """Y-axis label, taken from the spectrum object."""
        return self.transient.ylabel

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and returns Axes.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = axes or plt.gca()

        # Annotate with either the observation time or the transient name.
        if self.transient.plot_with_time_label:
            label = self.transient.time
        else:
            label = self.transient.name
        ax.plot(self.transient.angstroms, self.transient.flux_density / self._flux_scale, color=self.color,
                lw=self.linewidth)
        ax.set_xscale('linear')
        ax.set_yscale(self.yscale)

        ax.set_xlim(self._xlim_low, self._xlim_high)
        ax.set_ylim(self._ylim_low, self._ylim_high)
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.annotate(
            label, xy=self.xy, xycoords=self.xycoords,
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)

        ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_spectrum(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and the fit and returns Axes.

        :param axes: Matplotlib axes to plot the spectrum into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        axes = axes or plt.gca()

        axes = self.plot_data(axes=axes, save=False, show=False)
        angstroms = self._get_angstroms(axes)

        self._plot_spectrums(axes, angstroms)

        self._save_and_show(filepath=self._spectrum_ppd_plot_filepath, save=save, show=show)
        return axes

    def _plot_spectrums(self, axes: matplotlib.axes.Axes, angstroms: np.ndarray) -> None:
        """Draw the max-likelihood model and the posterior uncertainty on ``axes``."""
        if self.plot_max_likelihood:
            ys = self.model(angstroms, **self._max_like_params, **self._model_kwargs)
            axes.plot(angstroms, ys / self._flux_scale, color=self.max_likelihood_color,
                      alpha=self.max_likelihood_alpha, lw=self.linewidth)

        random_ys_list = [self.model(angstroms, **random_params, **self._model_kwargs)
                          for random_params in self._get_random_parameters()]
        if self.uncertainty_mode == "random_models":
            for ys in random_ys_list:
                axes.plot(angstroms, ys / self._flux_scale, color=self.random_sample_color,
                          alpha=self.random_sample_alpha, lw=self.linewidth, zorder=self.zorder)
        elif self.uncertainty_mode == "credible_intervals":
            lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                samples=random_ys_list, interval=self.credible_interval_level)
            axes.fill_between(
                angstroms, lower_bound / self._flux_scale, upper_bound / self._flux_scale,
                alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)

    def plot_residuals(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum with the fit in the top panel and the max-likelihood residuals below, and returns Axes.

        Residuals and their error bars use the same flux scaling as the spectrum panel.

        :param axes: Two matplotlib axes (spectrum, residual) to plot into. Created if not given.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if axes is None:
            _, axes = plt.subplots(
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1]))

        axes[0] = self.plot_spectrum(axes=axes[0], save=False, show=False)
        # Move the shared x label to the bottom panel.
        axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes)
        axes[0].set_xlabel("")
        ys = self.model(self.transient.angstroms, **self._max_like_params, **self._model_kwargs)
        residuals = (self.transient.flux_density - ys) / self._flux_scale
        flux_err = self.transient.flux_density_err
        yerr = None if flux_err is None else np.asarray(flux_err) / self._flux_scale
        axes[1].errorbar(
            self.transient.angstroms, residuals, yerr=yerr,
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize)
        axes[1].set_yscale('linear')
        axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes)
        axes[1].tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
        return axes
@@ -0,0 +1,9 @@
1
redshift = Uniform(minimum=0.01, maximum=2.0, name='redshift', latex_label=r'$z$')
rph = LogUniform(minimum=1e13, maximum=1e16, name='rph', latex_label=r'$R_{\mathrm{ph}}~(\mathrm{cm})$')
temp = Uniform(minimum=1e3, maximum=1e5, name='temp', latex_label=r'$T~(\mathrm{K})$')
lc1 = Uniform(2000, 7000, name='lc1', latex_label=r'$\Lambda_{\mathrm{emission}}$')
lc2 = Uniform(2000, 7000, name='lc2', latex_label=r'$\Lambda_{\mathrm{absorption}}$')
ls1 = LogUniform(1e-34, 1e-31, name='ls1', latex_label=r'$\mathrm{line\ strength}_{\mathrm{emission}}$')
ls2 = LogUniform(1e-34, 1e-31, name='ls2', latex_label=r'$\mathrm{line\ strength}_{\mathrm{absorption}}$')
v1 = Uniform(1e3, 1e4, name='v1', latex_label=r'$v_{\mathrm{emission}}$')
v2 = Uniform(1e3, 1e4, name='v2', latex_label=r'$v_{\mathrm{absorption}}$')