redback 1.1__py3-none-any.whl → 1.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. redback/__init__.py +1 -1
  2. redback/filters.py +57 -45
  3. redback/likelihoods.py +274 -6
  4. redback/model_library.py +2 -2
  5. redback/plotting.py +5 -3
  6. redback/priors/blackbody_spectrum_at_z.prior +3 -0
  7. redback/priors/bpl_cooling_envelope.prior +9 -0
  8. redback/priors/gaussianrise_cooling_envelope.prior +1 -5
  9. redback/priors/gaussianrise_cooling_envelope_bolometric.prior +1 -5
  10. redback/priors/powerlaw_plus_blackbody.prior +12 -0
  11. redback/priors/powerlaw_plus_blackbody_spectrum_at_z.prior +13 -0
  12. redback/priors/salt2.prior +6 -0
  13. redback/priors/shock_cooling_and_arnett_bolometric.prior +11 -0
  14. redback/priors/smooth_exponential_powerlaw_cooling_envelope_bolometric.prior +9 -0
  15. redback/priors/sn_nickel_fallback.prior +9 -0
  16. redback/priors/wr_bh_merger.prior +10 -0
  17. redback/priors/wr_bh_merger_bolometric.prior +8 -0
  18. redback/priors.py +14 -3
  19. redback/sed.py +185 -41
  20. redback/simulate_transients.py +13 -3
  21. redback/tables/filters.csv +260 -258
  22. redback/tables/qdot_rosswogkorobkin24.npz +0 -0
  23. redback/transient_models/afterglow_models.py +32 -16
  24. redback/transient_models/combined_models.py +16 -11
  25. redback/transient_models/extinction_models.py +310 -84
  26. redback/transient_models/gaussianprocess_models.py +1 -12
  27. redback/transient_models/kilonova_models.py +3 -3
  28. redback/transient_models/phase_models.py +97 -43
  29. redback/transient_models/phenomenological_models.py +172 -0
  30. redback/transient_models/spectral_models.py +101 -0
  31. redback/transient_models/stellar_interaction_models.py +254 -0
  32. redback/transient_models/supernova_models.py +349 -62
  33. redback/transient_models/tde_models.py +193 -54
  34. redback/utils.py +34 -7
  35. {redback-1.1.dist-info → redback-1.12.1.dist-info}/METADATA +7 -4
  36. {redback-1.1.dist-info → redback-1.12.1.dist-info}/RECORD +39 -28
  37. {redback-1.1.dist-info → redback-1.12.1.dist-info}/WHEEL +1 -1
  38. redback/tables/qdot_rosswogkorobkin24.pck +0 -0
  39. {redback-1.1.dist-info → redback-1.12.1.dist-info}/licenses/LICENCE.md +0 -0
  40. {redback-1.1.dist-info → redback-1.12.1.dist-info}/top_level.txt +0 -0
redback/__init__.py CHANGED
@@ -5,5 +5,5 @@ from redback.transient import afterglow, kilonova, prompt, supernova, tde
5
5
  from redback.sampler import fit_model
6
6
  from redback.utils import setup_logger
7
7
 
8
- __version__ = "1.1.0"
8
+ __version__ = "1.12.1"
9
9
  setup_logger(log_level='info')
redback/filters.py CHANGED
@@ -2,63 +2,73 @@ from astropy.io import ascii
2
2
  from astropy import units as u
3
3
  from astroquery.svo_fps import SvoFps
4
4
  import numpy as np
5
+ from redback.utils import calc_effective_width_hz_from_angstrom
5
6
  import redback
6
7
  import sncosmo
7
8
 
8
- def add_to_database(LABEL, WAVELENGTH, ZEROFLUX, DATABASE, PLOT_LABEL, EFFECTIVE_WIDTH):
9
-
9
+ def add_to_database(label, wavelength, zeroflux, database, plot_label, effective_width):
10
+
10
11
  """
11
12
  Add a filter to the Redback filter database.
12
- :param LABEL: name of the filter in the Redback filter database
13
- :param WAVELENGTH: central wavelength of the filter as defined on SVO
13
+
14
+ :param label: name of the filter in the Redback filter database
15
+ :param wavelength: central wavelength of the filter as defined on SVO in m
16
+ :param zeroflux: zero flux of the filter in erg/cm^2/s/Hz
17
+ :param database: filter database
18
+ :param plot_label: plot label. If none is provided, it will use LABEL (default: None).
19
+ :param effective_width: effective width of the filter in Angstrom
14
20
  :return: None
15
21
  """
16
22
 
17
- frequency = 3.0e8 / WAVELENGTH
18
- effective_width = 3.0e8 / EFFECTIVE_WIDTH
19
- print(effective_width)
20
- DATABASE.add_row([LABEL, frequency, WAVELENGTH*1e10, 'black', ZEROFLUX, LABEL, PLOT_LABEL, effective_width])
23
+ frequency = 3.0e8 / wavelength
24
+ effective_width = calc_effective_width_hz_from_angstrom(effective_width=effective_width,
25
+ effective_wavelength=wavelength * 1e10)
26
+ database.add_row([label, frequency, wavelength * 1e10, 'black', zeroflux, label, plot_label, effective_width])
21
27
 
22
- def add_to_sncosmo(LABEL, TRANSMISSION):
28
+ def add_to_sncosmo(label, transmission):
23
29
 
24
30
  """
25
31
  Add a filter to the Redback filter database.
26
- :param LABEL: name of the filter in the Redback filter database
32
+
33
+ :param label: name of the filter in the Redback filter database
27
34
  :param WAVELENGTH: central wavelength of the filter as defined on SVO
28
35
  :return: None
29
36
  """
30
37
 
31
- band = sncosmo.Bandpass(TRANSMISSION['Wavelength'], TRANSMISSION['Transmission'], name=LABEL, wave_unit=u.angstrom)
32
- sncosmo.register(band, LABEL, force=True)
38
+ band = sncosmo.Bandpass(transmission['Wavelength'], transmission['Transmission'], name=label, wave_unit=u.angstrom)
39
+ sncosmo.register(band, label, force=True)
33
40
 
34
- def add_filter_svo(FILTER, LABEL, PLOT_LABEL=None, OVERWRITE=False):
41
+ def add_filter_svo(filter, label, plot_label=None, overwrite=False):
35
42
 
36
43
  """
37
44
  Wrapper to add a filter from SVO to SNCosmo and the Redback filter database
38
- :param FILTER: record from the SVO query
39
- :param LABEL: name of the filter in SNCosmo
45
+
46
+ :param filter: record from the SVO query
47
+ :param label: name of the filter in SNCosmo
48
+ :param plot_label: plot label. If none is provided, it will use LABEL (default: None).
49
+ :param overwrite: overwrite any existing entry? (default: False)
40
50
  :return: None
41
51
  """
42
52
 
43
- redback_db_fname = path = redback.__path__[0] + '/tables/filters.csv'
53
+ redback_db_fname = redback.__path__[0] + '/tables/filters.csv'
44
54
  database_filters = ascii.read(redback_db_fname)
45
55
 
46
- mask = np.where((database_filters['bands'] == LABEL) & (database_filters['sncosmo_name'] == LABEL))[0]
56
+ mask = np.where((database_filters['bands'] == label) & (database_filters['sncosmo_name'] == label))[0]
47
57
 
48
58
  # Only add filter to filter database if entry does not exist in the Redback database by default
49
59
 
50
60
  # If no entry exists or you choose to overwrite an entry
51
- if (len(mask) == 0) or ( (len(mask) != 0) & OVERWRITE ):
61
+ if (len(mask) == 0) or ((len(mask) != 0) & overwrite):
52
62
 
53
63
  if len(mask) > 0:
54
64
  database_filters.remove_rows(mask)
55
65
 
56
66
  # Reference (=pivot) wavelength, unit: AA
57
- wavelength_pivot = FILTER['WavelengthRef']
67
+ wavelength_pivot = filter['WavelengthRef']
58
68
 
59
69
  # Effective width
60
70
  # defined as int( T(lambda), lambda ) / max( T(lambda) ), unit: AA
61
- effective_width = FILTER['WidthEff']
71
+ effective_width = filter['WidthEff']
62
72
 
63
73
  # Zero flux
64
74
 
@@ -74,14 +84,14 @@ def add_filter_svo(FILTER, LABEL, PLOT_LABEL=None, OVERWRITE=False):
74
84
 
75
85
  # Add to Redback
76
86
 
77
- plot_label = PLOT_LABEL if PLOT_LABEL != None else LABEL
87
+ plot_label = plot_label if plot_label != None else label
78
88
 
79
- add_to_database(LABEL, wavelength_pivot * 1.0e-10, zeroflux, database_filters, plot_label, effective_width)
89
+ add_to_database(label, wavelength_pivot * 1.0e-10, zeroflux, database_filters, plot_label, effective_width)
80
90
 
81
91
  # Non-standard filters always needs to be re-added to SN Cosmo even if an entry exists in filter.csv
82
92
 
83
- filter_transmission = SvoFps.get_transmission_data(FILTER['filterID'])
84
- add_to_sncosmo(LABEL, filter_transmission)
93
+ filter_transmission = SvoFps.get_transmission_data(filter['filterID'])
94
+ add_to_sncosmo(label, filter_transmission)
85
95
 
86
96
  # Prettify output
87
97
 
@@ -92,16 +102,16 @@ def add_filter_svo(FILTER, LABEL, PLOT_LABEL=None, OVERWRITE=False):
92
102
 
93
103
  database_filters.write(redback_db_fname, overwrite=True, format='csv')
94
104
 
95
- def add_filter_user(FILE, LABEL, PLOT_LABEL=None, OVERWRITE=False):
105
+ def add_filter_user(file, label, plot_label=None, overwrite=False):
96
106
 
97
107
  """
98
108
  Wrapper to add a user filter from SVO to SNCosmo and the Redback filter database
99
- :param FILE: file name that contains the transmission function
109
+ :param file: file name that contains the transmission function
100
110
  (Must have two columns, wavelength must be in AA)
101
- :param LABEL: name of the filter
111
+ :param label: name of the filter
102
112
  :param DATABASE: location of the Redback filter database
103
- :param PLOT_LABEL: plot label. If none is provided, it will use LABEL (default: None).
104
- :param OVERWRITE: overwrite any existing entry? (default: False)
113
+ :param plot_label: plot label. If none is provided, it will use LABEL (default: None).
114
+ :param overwrite: overwrite any existing entry? (default: False)
105
115
  :return: None
106
116
  """
107
117
 
@@ -111,21 +121,21 @@ def add_filter_user(FILE, LABEL, PLOT_LABEL=None, OVERWRITE=False):
111
121
  database_filters = ascii.read(redback_db_fname)
112
122
 
113
123
  # Check whether such an entry already exists
114
- mask = np.where((database_filters['bands'] == LABEL) & (database_filters['sncosmo_name'] == LABEL))[0]
124
+ mask = np.where((database_filters['bands'] == label) & (database_filters['sncosmo_name'] == label))[0]
115
125
 
116
126
  # Add to SNCosmo
117
127
  # Needs to be done even if an entry exists in filters.csv
118
128
 
119
- filter_transmission = ascii.read(FILE)
129
+ filter_transmission = ascii.read(file)
120
130
  filter_transmission.rename_columns(list(filter_transmission.keys()), ['Wavelength', 'Transmission'])
121
131
 
122
- add_to_sncosmo(LABEL, filter_transmission)
132
+ add_to_sncosmo(label, filter_transmission)
123
133
 
124
134
  # Add to filter.csv
125
135
 
126
136
  # If no entry exists or you choose to overwrite an entry
127
137
 
128
- if (len(mask) == 0) or ( (len(mask) != 0) & OVERWRITE ):
138
+ if (len(mask) == 0) or ((len(mask) != 0) & overwrite):
129
139
 
130
140
  if len(mask) > 0:
131
141
  database_filters.remove_rows(mask)
@@ -156,11 +166,11 @@ def add_filter_user(FILE, LABEL, PLOT_LABEL=None, OVERWRITE=False):
156
166
 
157
167
  # Add to Redback
158
168
 
159
- plot_label = PLOT_LABEL if PLOT_LABEL != None else LABEL
169
+ plot_label = plot_label if plot_label != None else label
160
170
 
161
- print(LABEL, wavelength_pivot * 1.0e-10, zeroflux, plot_label)
171
+ print(label, wavelength_pivot * 1.0e-10, zeroflux, plot_label)
162
172
 
163
- add_to_database(LABEL, wavelength_pivot * 1.0e-10, zeroflux, database_filters, plot_label, effective_width)
173
+ add_to_database(label, wavelength_pivot * 1.0e-10, zeroflux, database_filters, plot_label, effective_width)
164
174
 
165
175
  # Prettify output
166
176
 
@@ -173,7 +183,7 @@ def add_filter_user(FILE, LABEL, PLOT_LABEL=None, OVERWRITE=False):
173
183
 
174
184
  else:
175
185
 
176
- print('Filter {} already exists. Set OVERWRITE to True if you want to overwrite the existing entry'.format(LABEL))
186
+ print('Filter {} already exists. Set OVERWRITE to True if you want to overwrite the existing entry'.format(label))
177
187
 
178
188
  def add_common_filters(overwrite=False):
179
189
 
@@ -191,7 +201,7 @@ def add_common_filters(overwrite=False):
191
201
  filter_label = ['grond::' + x.split('/')[1].split('.')[1] for x in filter_list['filterID']]
192
202
  plot_label = ['GROND/' + x.split('/')[1].split('.')[1] for x in filter_list['filterID']]
193
203
 
194
- [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], OVERWRITE=overwrite) for ii in range(len(filter_list))]
204
+ [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], overwrite=overwrite) for ii in range(len(filter_list))]
195
205
 
196
206
  print('done.\n')
197
207
 
@@ -206,7 +216,7 @@ def add_common_filters(overwrite=False):
206
216
  filter_label = ['efosc2::' + x for x in filter_list['Band']]
207
217
  plot_label = ['EFOSC/' + x for x in filter_list['Band']]
208
218
 
209
- [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], OVERWRITE=overwrite) for ii in range(len(filter_list))]
219
+ [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], overwrite=overwrite) for ii in range(len(filter_list))]
210
220
 
211
221
  print('done.\n')
212
222
 
@@ -218,7 +228,7 @@ def add_common_filters(overwrite=False):
218
228
  filter_label = ['euclid::' + x.split('/')[1].split('.')[1] for x in filter_list['filterID']]
219
229
  plot_label = ['EUCLID/' + x.split('/')[1].split('.')[1].upper() for x in filter_list['filterID']]
220
230
 
221
- [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], OVERWRITE=overwrite) for ii in range(len(filter_list))]
231
+ [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], overwrite=overwrite) for ii in range(len(filter_list))]
222
232
 
223
233
 
224
234
  filter_list = SvoFps.get_filter_list(facility='Euclid', instrument='NISP')
@@ -227,7 +237,7 @@ def add_common_filters(overwrite=False):
227
237
  filter_label = ['euclid::' + x.split('/')[1].split('.')[1] for x in filter_list['filterID']]
228
238
  plot_label = ['EUCLID/' + x.split('/')[1].split('.')[1] for x in filter_list['filterID']]
229
239
 
230
- [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], OVERWRITE=overwrite) for ii in range(len(filter_list))]
240
+ [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], overwrite=overwrite) for ii in range(len(filter_list))]
231
241
 
232
242
  print('done.\n')
233
243
 
@@ -239,7 +249,7 @@ def add_common_filters(overwrite=False):
239
249
  filter_label = ['irac::' + x.split('/')[1].split('.')[1] for x in filter_list['filterID']]
240
250
  plot_label = ['IRAC/' + x.split('/')[1].split('.')[1] for x in filter_list['filterID']]
241
251
 
242
- [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], OVERWRITE=overwrite) for ii in range(len(filter_list))]
252
+ [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], overwrite=overwrite) for ii in range(len(filter_list))]
243
253
 
244
254
  print('done.\n')
245
255
 
@@ -251,7 +261,7 @@ def add_common_filters(overwrite=False):
251
261
  filter_label = ['wise::' + x.split('/')[1].split('.')[1] for x in filter_list['filterID']]
252
262
  plot_label = ['WISE/' + x.split('/')[1].split('.')[1] for x in filter_list['filterID']]
253
263
 
254
- [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], OVERWRITE=overwrite) for ii in range(len(filter_list))]
264
+ [add_filter_svo(filter_list[ii], filter_label[ii], plot_label[ii], overwrite=overwrite) for ii in range(len(filter_list))]
255
265
 
256
266
  print('done.\n')
257
267
 
@@ -265,6 +275,7 @@ def show_all_filters():
265
275
  def add_effective_widths():
266
276
  """
267
277
  Adds effective widths to the Redback filter database
278
+
268
279
  :return: None
269
280
  """
270
281
  import pandas as pd
@@ -280,8 +291,9 @@ def add_effective_widths():
280
291
  # Calculate the effective width:
281
292
  # effective_width = ∫T(λ) dλ / max(T(λ))
282
293
  effective_width = np.trapz(trans, waves) / np.max(trans)
283
- effective_width = effective_width * u.Angstrom
284
- eff_width[ii] = effective_width.to(u.Hz, equivalencies=u.spectral()).value
294
+ effective_width = calc_effective_width_hz_from_angstrom(effective_width=effective_width,
295
+ effective_wavelength=band.wave_eff)
296
+ eff_width[ii] = effective_width
285
297
  except Exception:
286
298
  redback.utils.logger.warning("Failed for band={} at index={}".format(bb, ii))
287
299
  eff_width[ii] = db['wavelength [Hz]'].iloc[ii]
redback/likelihoods.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
2
2
  from typing import Any, Union
3
3
 
4
4
  import bilby
5
- from scipy.special import gammaln
5
+ from scipy.special import gammaln, erf
6
6
  from redback.utils import logger
7
7
  from bilby.core.prior import DeltaFunction, Constraint
8
8
 
@@ -180,9 +180,17 @@ class GaussianLikelihood(_RedbackLikelihood):
180
180
  else:
181
181
  raise ValueError('Sigma must be either float or array-like x.')
182
182
 
183
+ @property
184
+ def model_output(self) -> np.ndarray:
185
+ """
186
+ :return: The model output for the given x values.
187
+ :rtype: np.ndarray
188
+ """
189
+ return self.function(self.x, **self.parameters, **self.kwargs)
190
+
183
191
  @property
184
192
  def residual(self) -> np.ndarray:
185
- return self.y - self.function(self.x, **self.parameters, **self.kwargs)
193
+ return self.y - self.model_output
186
194
 
187
195
  def noise_log_likelihood(self) -> float:
188
196
  """
@@ -204,6 +212,266 @@ class GaussianLikelihood(_RedbackLikelihood):
204
212
  def _gaussian_log_likelihood(res: np.ndarray, sigma: Union[float, np.ndarray]) -> Any:
205
213
  return np.sum(- (res / sigma) ** 2 / 2 - np.log(2 * np.pi * sigma ** 2) / 2)
206
214
 
215
+
216
+ class GaussianLikelihoodWithUpperLimits(GaussianLikelihood):
217
+ def __init__(
218
+ self, x: np.ndarray, y: np.ndarray, sigma: Union[float, None, np.ndarray],
219
+ function: callable, kwargs: dict = None, priors=None,
220
+ fiducial_parameters=None, detections: Union[np.ndarray, None] = None,
221
+ upper_limit_sigma: Union[float, np.ndarray] = 3.0,
222
+ data_mode: str = 'flux') -> None:
223
+ """A Gaussian likelihood that handles upper limits - extends the base GaussianLikelihood.
224
+
225
+ :param x: The x values.
226
+ :type x: np.ndarray
227
+ :param y: The y values. For upper limits, these are the reported limit values.
228
+ :type y: np.ndarray
229
+ :param sigma: The standard deviation of the noise for detections.
230
+ :type sigma: Union[float, None, np.ndarray]
231
+ :param function:
232
+ The python function to fit to the data. Note, this must take the
233
+ dependent variable as its first argument. The other arguments
234
+ will require a prior and will be sampled over (unless a fixed
235
+ value is given).
236
+ :type function: callable
237
+ :param kwargs: Any additional keywords for 'function'.
238
+ :type kwargs: dict
239
+ :param priors: The priors for the parameters. Default to None if not provided.
240
+ Only necessary if using maximum likelihood estimation functionality.
241
+ :type priors: Union[dict, None]
242
+ :param fiducial_parameters: The starting guesses for model parameters to
243
+ use in the optimization for maximum likelihood estimation. Default to None if not provided.
244
+ :type fiducial_parameters: Union[dict, None]
245
+ :param detections: Array indicating which data points are detections.
246
+ Can be boolean (True/False) or integer (1/0). 1 = detection, 0 = upper limit.
247
+ If None, all data points are treated as detections.
248
+ :type detections: Union[np.ndarray, None]
249
+ :param upper_limit_sigma: The sigma level for upper limits. Can be a single value
250
+ (e.g., 3.0 for all 3-sigma limits) or an array with different sigma levels for each
251
+ upper limit. Default is 3.0.
252
+ :type upper_limit_sigma: Union[float, np.ndarray]
253
+ :param data_mode: Whether data is in 'flux' or 'magnitude' units. This affects
254
+ how upper limits are interpreted. For flux: upper limit means "true value < limit".
255
+ For magnitude: upper limit means "true value > limit" (fainter than limit).
256
+ :type data_mode: str
257
+ """
258
+
259
+ # Initialize the parent class first
260
+ super().__init__(x=x, y=y, sigma=sigma, function=function, kwargs=kwargs,
261
+ priors=priors, fiducial_parameters=fiducial_parameters)
262
+
263
+ # Add upper limit functionality
264
+ self.detections = detections
265
+ self.upper_limit_sigma = upper_limit_sigma
266
+ self.data_mode = data_mode
267
+
268
+ @property
269
+ def detections(self) -> np.ndarray:
270
+ return self._detections
271
+
272
+ @detections.setter
273
+ def detections(self, detections: Union[np.ndarray, None]) -> None:
274
+ if detections is None:
275
+ self._detections = np.ones(len(self.x), dtype=bool) # All detections by default
276
+ elif len(detections) == len(self.x):
277
+ # Convert to boolean array, handles both 0/1 and True/False
278
+ self._detections = np.array(detections, dtype=bool)
279
+ else:
280
+ raise ValueError('detections must have the same length as x.')
281
+
282
+ @property
283
+ def upper_limits(self) -> np.ndarray:
284
+ """Derived property: upper_limits is the inverse of detections"""
285
+ return ~self._detections
286
+
287
+ @property
288
+ def upper_limit_sigma(self) -> Union[float, np.ndarray]:
289
+ return self._upper_limit_sigma
290
+
291
+ @upper_limit_sigma.setter
292
+ def upper_limit_sigma(self, upper_limit_sigma: Union[float, np.ndarray]) -> None:
293
+ if isinstance(upper_limit_sigma, (float, int)):
294
+ self._upper_limit_sigma = float(upper_limit_sigma)
295
+ elif isinstance(upper_limit_sigma, np.ndarray):
296
+ if len(upper_limit_sigma) == len(self.x):
297
+ self._upper_limit_sigma = upper_limit_sigma
298
+ elif hasattr(self, '_detections') and len(upper_limit_sigma) == np.sum(~self._detections):
299
+ # Array length matches number of upper limits
300
+ self._upper_limit_sigma = upper_limit_sigma
301
+ else:
302
+ raise ValueError('upper_limit_sigma array must have length equal to x or to number of upper limits.')
303
+ else:
304
+ raise ValueError('upper_limit_sigma must be a float or array.')
305
+
306
+ @property
307
+ def data_mode(self) -> str:
308
+ return self._data_mode
309
+
310
+ @data_mode.setter
311
+ def data_mode(self, data_mode: str) -> None:
312
+ if data_mode.lower() not in ['flux', 'magnitude', 'mag']:
313
+ raise ValueError("data_mode must be 'flux' or 'magnitude' (or 'mag')")
314
+ self._data_mode = data_mode.lower()
315
+ if self._data_mode == 'mag':
316
+ self._data_mode = 'magnitude'
317
+
318
+ def get_upper_limit_sigma_values(self) -> np.ndarray:
319
+ """
320
+ Get the sigma values for upper limits only.
321
+
322
+ :return: Array of sigma values for upper limit data points.
323
+ :rtype: np.ndarray
324
+ """
325
+ if not np.any(self.upper_limits):
326
+ return np.array([])
327
+
328
+ if isinstance(self.upper_limit_sigma, (float, int)):
329
+ # Same sigma level for all upper limits
330
+ n_upper_limits = np.sum(self.upper_limits)
331
+ return np.full(n_upper_limits, self.upper_limit_sigma)
332
+ elif len(self.upper_limit_sigma) == len(self.x):
333
+ # Sigma level for each data point, extract upper limits only
334
+ return self.upper_limit_sigma[self.upper_limits]
335
+ else:
336
+ # Array already has length equal to number of upper limits
337
+ return self.upper_limit_sigma
338
+
339
+ @staticmethod
340
+ def _normal_cdf(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
341
+ """
342
+ Fast computation of normal CDF using erf.
343
+ CDF(x) = 0.5 * (1 + erf(x / sqrt(2)))
344
+
345
+ :param x: Standardized values (z-scores)
346
+ :return: CDF values
347
+ """
348
+ return 0.5 * (1.0 + erf(x / np.sqrt(2)))
349
+
350
+ def _upper_limit_log_likelihood(self, observed: np.ndarray, model: np.ndarray) -> float:
351
+ """
352
+ Calculate log-likelihood contribution from upper limits only.
353
+
354
+ :param observed: Upper limit values
355
+ :param model: Model predictions at upper limit points
356
+ :return: Log-likelihood contribution from upper limits
357
+ """
358
+ if not np.any(self.upper_limits):
359
+ return 0.0
360
+
361
+ model_ul = model[self.upper_limits]
362
+ observed_ul = observed[self.upper_limits]
363
+
364
+ # Get the sigma levels for each upper limit
365
+ ul_sigma_levels = self.get_upper_limit_sigma_values()
366
+
367
+ # The measurement uncertainty - this calculation depends on data mode
368
+ if self.data_mode == 'magnitude':
369
+ # For magnitudes, the uncertainty is typically symmetric in mag space
370
+ # If we don't have explicit uncertainties for upper limits, we need to estimate them
371
+ # This is a common issue - often we only have the sigma level, not the actual uncertainty
372
+ # We'll use a reasonable default or derive from the limit
373
+
374
+ # Option 1: Use a typical photometric uncertainty (you may want to adjust this)
375
+ sigma_measurement = np.full_like(observed_ul, 0.1) # Assume 0.1 mag uncertainty
376
+
377
+ # Option 2: Derive from the sigma level (uncomment if preferred)
378
+ # sigma_measurement = observed_ul / ul_sigma_levels # This assumes the limit is sigma_level * uncertainty
379
+
380
+ else: # flux mode
381
+ sigma_measurement = observed_ul / ul_sigma_levels
382
+
383
+ # Calculate the probability based on data mode
384
+ if self.data_mode == 'magnitude':
385
+ # For magnitudes: upper limit means "true magnitude > observed_limit" (fainter than limit)
386
+ # We want: P(true_mag > observed_upper_limit | model_prediction)
387
+ # This is 1 - CDF(observed_limit) = CDF(-standardized) due to symmetry
388
+ standardized = (observed_ul - model_ul) / sigma_measurement
389
+ # P(X > observed) = 1 - P(X <= observed) = 1 - CDF(standardized)
390
+ survival_prob = 1.0 - self._normal_cdf(standardized)
391
+ cdf_values = survival_prob
392
+
393
+ else: # flux mode
394
+ # For flux: upper limit means "true flux < observed_limit"
395
+ # We want: P(true_flux < observed_upper_limit | model_prediction)
396
+ standardized = (observed_ul - model_ul) / sigma_measurement
397
+ cdf_values = self._normal_cdf(standardized)
398
+
399
+ # Add small epsilon to avoid log(0) and clip to valid range
400
+ epsilon = 1e-30
401
+ cdf_values = np.clip(cdf_values, epsilon, 1.0 - epsilon)
402
+
403
+ return np.sum(np.log(cdf_values))
404
+
405
+ def noise_log_likelihood(self) -> float:
406
+ """
407
+ Override parent method to include upper limits in noise likelihood.
408
+
409
+ :return: The noise log-likelihood, i.e. the log-likelihood assuming the signal is just noise.
410
+ :rtype: float
411
+ """
412
+ if self._noise_log_likelihood is None:
413
+ # Detections part (use parent class method for detected points only)
414
+ if np.any(self.detections):
415
+ y_det = self.y[self.detections]
416
+ sigma_det = self.sigma if np.isscalar(self.sigma) else self.sigma[self.detections]
417
+ detection_noise_ll = self._gaussian_log_likelihood(res=y_det, sigma=sigma_det)
418
+ else:
419
+ detection_noise_ll = 0.0
420
+
421
+ # Upper limits part (assume model = 0 for noise in flux, or some reference mag for magnitudes)
422
+ if self.data_mode == 'magnitude':
423
+ # For magnitudes, "no signal" might mean very faint (large magnitude)
424
+ # You might want to adjust this based on your specific case
425
+ noise_model = np.full_like(self.y, 60.0) # Assume 30 mag as "no signal"
426
+ else:
427
+ noise_model = np.zeros_like(self.y)
428
+
429
+ ul_noise_ll = self._upper_limit_log_likelihood(observed=self.y, model=noise_model)
430
+
431
+ self._noise_log_likelihood = detection_noise_ll + ul_noise_ll
432
+
433
+ return self._noise_log_likelihood
434
+
435
+ def log_likelihood(self) -> float:
436
+ """
437
+ Override parent method to include upper limits.
438
+
439
+ :return: The log-likelihood including upper limits.
440
+ :rtype: float
441
+ """
442
+ # Detections part (use parent class method for detected points only)
443
+ if np.any(self.detections):
444
+ residual_det = self.residual[self.detections]
445
+ sigma_det = self.sigma if np.isscalar(self.sigma) else self.sigma[self.detections]
446
+ detection_ll = self._gaussian_log_likelihood(res=residual_det, sigma=sigma_det)
447
+ else:
448
+ detection_ll = 0.0
449
+
450
+ # Upper limits part
451
+ ul_ll = self._upper_limit_log_likelihood(observed=self.y, model=self.model_output)
452
+
453
+ return np.nan_to_num(detection_ll + ul_ll)
454
+
455
+ def summary(self) -> dict:
456
+ """
457
+ Provide a summary of the likelihood setup.
458
+
459
+ :return: Dictionary with summary information
460
+ """
461
+ n_detections = np.sum(self.detections)
462
+ n_upper_limits = np.sum(self.upper_limits)
463
+
464
+ summary_dict = {
465
+ 'total_data_points': len(self.x),
466
+ 'detections': n_detections,
467
+ 'upper_limits': n_upper_limits,
468
+ 'data_mode': self.data_mode,
469
+ 'upper_limit_sigma_levels': self.get_upper_limit_sigma_values() if n_upper_limits > 0 else None
470
+ }
471
+
472
+ return summary_dict
473
+
474
+
207
475
  class MixtureGaussianLikelihood(GaussianLikelihood):
208
476
  def __init__(self, x: np.ndarray, y: np.ndarray,
209
477
  sigma: Union[float, None, np.ndarray],
@@ -534,7 +802,6 @@ class GaussianLikelihoodUniformXErrors(GaussianLikelihood):
534
802
  """
535
803
  return np.nan_to_num(self.log_likelihood_x() + self.log_likelihood_y())
536
804
 
537
-
538
805
  class GaussianLikelihoodQuadratureNoise(GaussianLikelihood):
539
806
  def __init__(
540
807
  self, x: np.ndarray, y: np.ndarray, sigma_i: Union[float, None, np.ndarray],
@@ -628,7 +895,7 @@ class GaussianLikelihoodWithFractionalNoise(GaussianLikelihood):
628
895
  :return: The standard deviation of the full noise
629
896
  :rtype: Union[float, np.ndarray]
630
897
  """
631
- model_y = self.function(self.x, **self.parameters, **self.kwargs)
898
+ model_y = self.model_output
632
899
  return np.sqrt(self.sigma_i**2.*model_y**2)
633
900
 
634
901
  def noise_log_likelihood(self) -> float:
@@ -652,7 +919,8 @@ class GaussianLikelihoodWithSystematicNoise(GaussianLikelihood):
652
919
  self, x: np.ndarray, y: np.ndarray, sigma_i: Union[float, None, np.ndarray],
653
920
  function: callable, kwargs: dict = None, priors=None, fiducial_parameters=None) -> None:
654
921
  """
655
- A Gaussian likelihood with a systematic noise term that is proportional to the model + some additive noise.
922
+ A Gaussian likelihood with a systematic noise term that is proportional to the model +
923
+ the original data noise added in quadrature.
656
924
  The parameters are inferred from the arguments of function
657
925
 
658
926
  :param x: The x values.
@@ -688,7 +956,7 @@ class GaussianLikelihoodWithSystematicNoise(GaussianLikelihood):
688
956
  :return: The standard deviation of the full noise
689
957
  :rtype: Union[float, np.ndarray]
690
958
  """
691
- model_y = self.function(self.x, **self.parameters, **self.kwargs)
959
+ model_y = self.model_output
692
960
  return np.sqrt(self.sigma_i**2. + model_y**2*self.sigma**2.)
693
961
 
694
962
  def noise_log_likelihood(self) -> float:
redback/model_library.py CHANGED
@@ -2,7 +2,7 @@ from redback.transient_models import afterglow_models, \
2
2
  extinction_models, kilonova_models, fireball_models, \
3
3
  gaussianprocess_models, magnetar_models, magnetar_driven_ejecta_models, phase_models, phenomenological_models, \
4
4
  prompt_models, shock_powered_models, supernova_models, tde_models, integrated_flux_afterglow_models, combined_models, \
5
- general_synchrotron_models, spectral_models
5
+ general_synchrotron_models, spectral_models, stellar_interaction_models
6
6
 
7
7
  from redback.utils import get_functions_dict
8
8
 
@@ -10,7 +10,7 @@ modules = [afterglow_models, extinction_models, fireball_models,
10
10
  gaussianprocess_models, integrated_flux_afterglow_models, kilonova_models,
11
11
  magnetar_models, magnetar_driven_ejecta_models,
12
12
  phase_models, phenomenological_models, prompt_models, shock_powered_models, supernova_models,
13
- tde_models, combined_models, general_synchrotron_models, spectral_models]
13
+ tde_models, combined_models, general_synchrotron_models, spectral_models, stellar_interaction_models]
14
14
 
15
15
  base_modules = [extinction_models, phase_models]
16
16
 
redback/plotting.py CHANGED
@@ -672,6 +672,7 @@ class MagnitudePlotter(Plotter):
672
672
  if self.transient.magnitude_data:
673
673
  ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
674
674
  ax.invert_yaxis()
675
+ ax.set_yscale('linear')
675
676
  else:
676
677
  ax.set_ylim(self._ylim_low, self._ylim_high)
677
678
  ax.set_yscale("log")
@@ -680,6 +681,7 @@ class MagnitudePlotter(Plotter):
680
681
  if self.transient.magnitude_data:
681
682
  ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
682
683
  ax.invert_yaxis()
684
+ ax.set_yscale('linear')
683
685
  else:
684
686
  ax.set_ylim(self._get_ylim_low_with_indices(indices=indices),
685
687
  self._get_ylim_high_with_indices(indices=indices))
@@ -742,7 +744,7 @@ class MagnitudePlotter(Plotter):
742
744
  color = self._colors[list(self._filters).index(band)]
743
745
  if band_label_generator is None:
744
746
  if band in self.band_scaling:
745
- label = str(self.band_scaling.get(band)) + ' ' + self.band_scaling.get("type") + ' ' + band
747
+ label = band + ' ' + self.band_scaling.get("type") + ' ' + str(self.band_scaling.get(band))
746
748
  else:
747
749
  label = band
748
750
  else:
@@ -806,7 +808,6 @@ class MagnitudePlotter(Plotter):
806
808
  axes = axes or plt.gca()
807
809
 
808
810
  axes = self.plot_data(axes=axes, save=False, show=False)
809
- axes.set_yscale('log')
810
811
 
811
812
  times = self._get_times(axes)
812
813
  bands_to_plot = self._get_bands_to_plot
@@ -983,7 +984,8 @@ class MagnitudePlotter(Plotter):
983
984
  continue
984
985
  new_model_kwargs = self._model_kwargs.copy()
985
986
  new_model_kwargs['frequency'] = freq
986
- new_model_kwargs['bands'] = band
987
+ new_model_kwargs['bands'] = redback.utils.sncosmo_bandname_from_band([band])
988
+ new_model_kwargs['bands'] = [new_model_kwargs['bands'][0] for _ in range(len(times))]
987
989
 
988
990
  if self.set_same_color_per_subplot is True:
989
991
  color = self._colors[list(self._filters).index(band)]
@@ -0,0 +1,3 @@
1
+ redshift = Uniform(0.01, 1.0, 'redshift', latex_label = r'$z$')
2
+ rph = LogUniform(1e13, 1e20, 'rph', latex_label = r'$R_{\mathrm{ph}}~(\mathrm{cm})$')
3
+ temp = Uniform(2000, 15000, 'temp', latex_label = r'$T~(\mathrm{K})$')
@@ -0,0 +1,9 @@
1
+ redshift = Uniform(1e-6, 3, 'redshift', latex_label = r'$z$')
2
+ peak_time = LogUniform(0.1,60, name='peak_time', latex_label = r'$t_{\mathrm{peak}}~(\mathrm{day})$')
3
+ alpha_1 = Uniform(0.5, 3, name='alpha_1', latex_label=r'$\\alpha_{1}$')
4
+ alpha_2 = Uniform(-3, -0.5, name='alpha_2', latex_label=r'$\\alpha_{2}$')
5
+ mbh_6 = LogUniform(0.01, 20, name='mbh_6', latex_label = r'$M_{\mathrm{BH}}~(10^{6}~M_\odot)$')
6
+ stellar_mass = LogUniform(0.1, 10, name='stellar_mass', latex_label = r'$M_{\mathrm{star}}~(M_\odot)$')
7
+ eta = LogUniform(1e-4, 0.1, name='eta', latex_label=r'$\\eta$')
8
+ alpha = LogUniform(0.1, 1, name='alpha', latex_label=r'$\\alpha$')
9
+ beta = Uniform(1, 5, name='beta', latex_label=r'$\\beta$')