astro-otter 0.0.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of astro-otter might be problematic. Click here for more details.

@@ -39,7 +39,7 @@ class OtterPlotter:
39
39
  elif self.backend == "plotly.graph_objects":
40
40
  self.plot = self._plot_plotly
41
41
  else:
42
- raise ValueError("Unknown backend!")
42
+ raise ValueError("Unknown plotting backend!")
43
43
 
44
44
  def _plot_matplotlib(self, x, y, xerr=None, yerr=None, ax=None, **kwargs):
45
45
  """
@@ -53,17 +53,19 @@ class OtterPlotter:
53
53
  ax.errorbar(x, y, xerr=xerr, yerr=yerr, **kwargs)
54
54
  return ax
55
55
 
56
- def _plot_plotly(self, x, y, xerr=None, yerr=None, go=None, *args, **kwargs):
56
+ def _plot_plotly(self, x, y, xerr=None, yerr=None, ax=None, *args, **kwargs):
57
57
  """
58
58
  General plotting method using plotly, is called by _plotly_light_curve and
59
59
  _plotly_sed
60
60
  """
61
61
 
62
- if go is None:
62
+ if ax is None:
63
63
  go = self.plotter.Figure()
64
+ else:
65
+ go = ax
64
66
 
65
67
  fig = go.add_scatter(
66
- x=x, y=y, error_x=dict(array=xerr), error_y=dict(array=yerr)
68
+ x=x, y=y, error_x=dict(array=xerr), error_y=dict(array=yerr), **kwargs
67
69
  )
68
70
 
69
71
  return fig
otter/plotter/plotter.py CHANGED
@@ -3,7 +3,185 @@ Some utilities to create common plots for transients that use the OtterPlotter
3
3
  """
4
4
 
5
5
  from __future__ import annotations
6
+ from warnings import warn
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import pandas as pd
11
+
6
12
  from .otter_plotter import OtterPlotter
13
+ from ..exceptions import FailedQueryError
14
+ from ..io.otter import Transient, Otter
15
+
16
+
17
def query_quick_view(
    db: Otter = None,
    otter_path: str = None,
    ptype: str = "both",
    sed_dim: str = "freq",
    dt_over_t: float = 0,
    plotting_kwargs: dict = None,
    phot_cleaning_kwargs: dict = None,
    result_length_tol=10,
    **kwargs,
) -> list[plt.Figure]:
    """
    Queries otter and then plots all of the transients. It will either query the Otter
    object provided in db or construct an Otter object from otter_path.

    Args:
        db (otter.Otter) : The otter object to query. Used in preference to
                           otter_path when both are given.
        otter_path (str) : The path used to construct an Otter object when db is
                           not provided.
        ptype (str) : The plot type to generate. Valid options are
                      - both -> Plot both light curve and sed (default)
                      - sed -> Plot just the sed
                      - lc -> Plot just the light curve
        sed_dim (str) : The x dimension to plot in the SED. Options are "freq" or
                        "wave". Default is "freq".
        dt_over_t (float) : Unitless, fractional time tolerance used to group the
                            photometry into SED epochs; forwarded to
                            otter.plotter.quick_view. Default is 0.
        plotting_kwargs (dict) : dictionary of key word arguments to pass to
                                 otter.plotter.plot_light_curve or
                                 otter.plotter.plot_sed. Default is None (no
                                 extra arguments).
        phot_cleaning_kwargs (dict) : Keyword arguments passed to
                                      otter.Transient.clean_photometry. Default
                                      is None (no extra arguments).
        result_length_tol (int) : If the query result is longer than this it will throw
                                  an error to prevent 100s of plots from spitting out
                                  (and likely crashing your computer). Default is 10 but
                                  can be adjusted.
        **kwargs : Arguments to pass to otter.Otter.query

    Returns:
        A list of matplotlib pyplot Figure objects that we plotted. Transients for
        which quick_view raises KeyError or FailedQueryError are skipped with a
        warning.

    Raises:
        ValueError: If neither db nor otter_path is provided.
        RuntimeError: If the query returns more than result_length_tol results.
    """
    # None defaults avoid sharing one mutable dict between calls.
    if plotting_kwargs is None:
        plotting_kwargs = {}
    if phot_cleaning_kwargs is None:
        phot_cleaning_kwargs = {}

    if db is None:
        if otter_path is None:
            raise ValueError("Either the db or otter_path arguments must be provided!")
        db = Otter(otter_path)

    res = db.query(**kwargs)

    # Guard against accidentally producing an unmanageable number of figures.
    if len(res) > result_length_tol:
        raise RuntimeError(
            f"This query returned {len(res)} results which is greater than the given "
            f"tolerance of {result_length_tol}! Either increase the result_length_tol"
            " keyword or pass in a stricter query!"
        )

    figs = []
    for t in res:
        try:
            fig = quick_view(
                t, ptype, sed_dim, dt_over_t, plotting_kwargs, **phot_cleaning_kwargs
            )
        except (KeyError, FailedQueryError):
            warn(f"No photometry associated with {t.default_name}, skipping!")
            continue

        fig.suptitle(t.default_name)
        figs.append(fig)

    return figs
87
+
88
+
89
def quick_view(
    t: Transient,
    ptype: str = "both",
    sed_dim: str = "freq",
    dt_over_t: float = 0,
    plotting_kwargs: dict = None,
    **kwargs,
) -> plt.Figure:
    """
    Generate a quick view (not necessarily publication ready) of the transient's
    light curve, SED, or both. Default is to do both.

    Args:
        t (otter.Transient) : An otter Transient object to grab photometry from
        ptype (str) : The plot type to generate. Valid options are
                      - both -> Plot both light curve and sed (default)
                      - sed -> Plot just the sed
                      - lc -> Plot just the light curve
        sed_dim (str) : The x dimension to plot in the SED. Options are "freq" or
                        "wave". Default is "freq".
        dt_over_t (float) : Unitless, fractional time tolerance used to group the
                            photometry into SED epochs. After sorting by date, a
                            point starts a new group when its gap from the
                            previous point exceeds dt_over_t times its own date.
                            Default is 0, so every distinct date is its own group.
        plotting_kwargs (dict) : dictionary of key word arguments to pass to
                                 otter.plotter.plot_light_curve or
                                 otter.plotter.plot_sed. Default is None (no
                                 extra arguments).
        **kwargs : Any other arguments to pass to otter.Transient.clean_photometry

    Returns:
        The matplotlib figure used for plotting.

    Raises:
        ValueError: If a non-matplotlib backend is requested, or if ptype (or,
                    when an SED is plotted, sed_dim) is not recognized.
    """
    if plotting_kwargs is None:
        plotting_kwargs = {}

    backend = plotting_kwargs.get("backend", "matplotlib.pyplot")
    if backend not in {"matplotlib.pyplot", "matplotlib"}:
        raise ValueError(
            "Only matplotlib.pyplot backend is available for quick_view!"
            + " To use plotly, use the plotting functionality individually!"
        )

    plt_lc = ptype in {"both", "lc"}
    plt_sed = ptype in {"both", "sed"}
    if not (plt_lc or plt_sed):
        raise ValueError(
            f"ptype value '{ptype}' is not recognized! Options are 'both', 'sed', "
            + "or 'lc'."
        )
    # Validate before doing any photometry cleaning or creating figures.
    if plt_sed and sed_dim not in {"wave", "freq"}:
        raise ValueError("sed_dim value is not recognized!")

    allphot = t.clean_photometry(**kwargs)
    allphot = allphot.sort_values("converted_date")
    # Group the time-sorted photometry into SED epochs. The first row has no
    # previous point, so its gap is set to -inf and it never opens a new group.
    allphot["time_tol"] = dt_over_t * allphot["converted_date"]
    allphot["time_diff"] = allphot["converted_date"].diff().fillna(-np.inf)
    allphot["time_grp"] = (allphot.time_diff > allphot.time_tol).cumsum()

    if ptype == "both":
        fig, (lc_ax, sed_ax) = plt.subplots(1, 2)
    elif ptype == "sed":
        fig, sed_ax = plt.subplots()
    else:
        fig, lc_ax = plt.subplots()

    # Only forward flux errors when at least one is known; otherwise the
    # plotting functions receive flux_err=None (indexing None would crash).
    has_flux_err = not np.all(pd.isna(allphot.converted_flux_err))

    if plt_lc:
        for filt, phot in allphot.groupby("filter_name"):
            plot_light_curve(
                date=phot.converted_date,
                flux=phot.converted_flux,
                flux_err=phot.converted_flux_err if has_flux_err else None,
                xlabel=f"Date [{phot.converted_date_unit.values[0]}]",
                ylabel=f"Flux [{phot.converted_flux_unit.values[0]}]",
                ax=lc_ax,
                label=filt,
                **plotting_kwargs,
            )

    if plt_sed:
        for _, phot in allphot.groupby("time_grp"):
            if sed_dim == "wave":
                wave_or_freq = phot.converted_wave
                xlab = f"Wavelength [{phot.converted_wave_unit.values[0]}]"
            else:
                wave_or_freq = phot.converted_freq
                xlab = f"Frequency [{phot.converted_freq_unit.values[0]}]"

            plot_sed(
                wave_or_freq=wave_or_freq,
                flux=phot.converted_flux,
                flux_err=phot.converted_flux_err if has_flux_err else None,
                ax=sed_ax,
                xlabel=xlab,
                ylabel=f"Flux [{phot.converted_flux_unit.values[0]}]",
                label=phot.converted_date.mean(),
                **plotting_kwargs,
            )

        # sed_ax only exists when an SED panel was created
        sed_ax.set_xscale("log")

    return fig
7
185
 
8
186
 
9
187
  def plot_light_curve(
@@ -46,7 +224,7 @@ def plot_light_curve(
46
224
  fig.set_xlabel(xlabel)
47
225
 
48
226
  elif backend == "plotly":
49
- fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel, **kwargs)
227
+ fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel)
50
228
 
51
229
  return fig
52
230
 
@@ -91,6 +269,6 @@ def plot_sed(
91
269
  fig.set_xlabel(xlabel)
92
270
 
93
271
  elif backend == "plotly":
94
- fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel, **kwargs)
272
+ fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel)
95
273
 
96
274
  return fig
otter/schema.py ADDED
@@ -0,0 +1,296 @@
1
+ """
2
+ Pydantic Schema Model of our JSON schema
3
+ """
4
+
5
+ from pydantic import BaseModel, model_validator, field_validator, ValidationError
6
+ from typing import Optional, Union, List
7
+
8
+
9
class VersionSchema(BaseModel):
    """
    Version information for the schema a document was written against.

    Both fields are optional and default to None.
    """

    # version identifier, given either as a string or an integer
    value: Union[str, int] = None
    # free-form note accompanying the version
    comment: str = None
12
+
13
+
14
class _AliasSchema(BaseModel):
    """
    One alternative name of a transient together with where it comes from.
    """

    # the alias itself
    value: str
    # one reference, or a list of references, supporting this alias
    reference: Union[str, List[str]]
17
+
18
+
19
class _XrayModelSchema(BaseModel):
    """
    Description of an X-ray model (name, parameters, and energy range) that can
    be attached to a photometry entry via ``PhotometrySchema.xray_model``.
    """

    # Field names such as ``model_name`` / ``model_reference`` start with
    # "model_", which pydantic reserves as a protected namespace. Clearing the
    # protected namespaces silences the warnings that would otherwise be emitted.
    model_config = {"protected_namespaces": ()}

    # ---- required keywords ----
    model_name: str
    # param_names / param_values / param_units are parallel lists
    # (presumably of equal length -- not enforced here)
    param_names: List[str]
    param_values: List[Union[float, int, str]]
    param_units: List[Union[str, None]]
    min_energy: Union[float, int, str]
    max_energy: Union[float, int, str]
    energy_units: str

    # ---- optional keywords ----
    param_value_err_upper: Optional[List[Union[float, int, str]]] = None
    param_value_err_lower: Optional[List[Union[float, int, str]]] = None
    param_upperlimit: Optional[List[Union[float, int, str]]] = None
    param_descriptions: Optional[List[str]] = None
    model_reference: Optional[Union[str, List[str]]] = None
39
+
40
+
41
class _ErrDetailSchema(BaseModel):
    """
    Optional breakdown of an uncertainty into its individual components.

    Every component is optional and, when given, is a list of values.
    """

    upper: Optional[List[Union[float, int, str]]] = None
    lower: Optional[List[Union[float, int, str]]] = None
    systematic: Optional[List[Union[float, int, str]]] = None
    statistical: Optional[List[Union[float, int, str]]] = None
    iss: Optional[List[Union[float, int, str]]] = None
48
+
49
+
50
class NameSchema(BaseModel):
    """
    Naming information for a transient: its default name plus all known aliases.
    """

    # the name used to refer to the transient by default
    default_name: str
    # every alias, each with its own reference(s)
    alias: list[_AliasSchema]
53
+
54
+
55
class CoordinateSchema(BaseModel):
    """
    A single coordinate measurement of a transient.

    At least one complete coordinate pair (ra/dec, l/b, or lon/lat) must be
    provided. Every pair that is provided must also have both of its unit fields.
    """

    reference: Union[List[str], str]
    ra: Union[str, float] = None
    dec: Union[str, float] = None
    l: Union[str, float] = None  # noqa: E741
    b: Union[str, float] = None
    lon: Union[str, float] = None
    lat: Union[str, float] = None
    ra_units: str = None
    dec_units: str = None
    l_units: str = None
    b_units: str = None
    lon_units: str = None
    lat_units: str = None
    ra_error: Union[str, float] = None
    dec_error: Union[str, float] = None
    l_error: Union[str, float] = None
    b_error: Union[str, float] = None
    lon_error: Union[str, float] = None
    lat_error: Union[str, float] = None
    epoch: str = None
    frame: str = "J2000"
    coord_type: str = None
    computed: bool = False
    default: bool = False

    @model_validator(mode="after")
    def _has_coordinate(self):
        # Pydantic validators must raise ValueError (it is wrapped into a
        # ValidationError for the caller); ValidationError cannot be built from
        # a plain message string.
        uses_ra_dec = self.ra is not None and self.dec is not None
        uses_galactic = self.l is not None and self.b is not None
        uses_lon_lat = self.lon is not None and self.lat is not None

        if not (uses_ra_dec or uses_galactic or uses_lon_lat):
            raise ValueError("Must have RA/Dec, l/b, and/or lon/lat!")

        # Every pair that is present needs units for both of its components.
        required_units = []
        if uses_ra_dec:
            required_units += [("ra_units", "RA"), ("dec_units", "Dec")]
        if uses_galactic:
            required_units += [("l_units", "l"), ("b_units", "b")]
        if uses_lon_lat:
            required_units += [("lon_units", "lon"), ("lat_units", "lat")]

        for unit_field, label in required_units:
            if getattr(self, unit_field) is None:
                raise ValueError(f"{unit_field} must be provided for {label}!")

        return self
109
+
110
+
111
class DistanceSchema(BaseModel):
    """
    A distance measurement (or redshift) to the transient.

    ``unit`` may only be omitted when ``distance_type`` is "redshift".
    """

    value: Union[str, float, int]
    unit: str = None
    reference: Union[str, List[str]]
    distance_type: str
    error: Union[str, float, int] = None
    cosmology: str = None
    computed: bool = False
    uuid: str = None
    default: bool = False

    @model_validator(mode="after")
    def _has_units(self):
        # Redshift is dimensionless; every other distance type needs a unit.
        # Raise ValueError so pydantic reports it as a ValidationError.
        if self.distance_type != "redshift" and self.unit is None:
            raise ValueError("Need units if the distance_type is not redshift!")

        return self
128
+
129
+
130
class ClassificationSchema(BaseModel):
    """
    One proposed classification of the transient and the confidence in it.
    """

    # required
    object_class: str
    confidence: float
    reference: Union[str, List[str]]
    # optional
    default: bool = False
    class_type: str = None
136
+
137
+
138
class ReferenceSchema(BaseModel):
    """
    Maps a reference identifier to a name suitable for displaying to people.
    """

    # identifier used in the ``reference`` fields of other entries
    name: str
    human_readable_name: str
141
+
142
+
143
class DateSchema(BaseModel):
    """
    A reference date for the transient, in a stated format and of a stated type.
    """

    value: Union[str, int, float]
    # how ``value`` should be interpreted
    date_format: str
    date_type: str
    reference: Union[str, List[str]]
    computed: bool = None
149
+
150
+
151
class PhotometrySchema(BaseModel):
    """
    A block of photometric measurements from one or more references.

    Most fields accept either a single value or a list of values. The fields
    named in ``ensure_list`` are always normalized to lists after validation.
    """

    # ---- measurements ----
    reference: Union[List[str], str]
    raw: list[Union[float, int]]
    raw_err: Optional[List[float]] = []
    raw_units: Union[str, List[str]]
    value: Optional[list[Union[float, int]]] = None
    value_err: Optional[list[Union[float, int]]] = None
    value_units: Optional[Union[str, List[str]]] = None
    epoch_zeropoint: Optional[Union[float, str, int]] = None
    epoch_redshift: Optional[Union[float, int]] = None

    # ---- filter / observation type ----
    filter: Optional[Union[str, List[str]]] = None
    filter_key: Union[str, List[str]]
    obs_type: Union[str, List[str]]
    telescope_area: Optional[Union[float, List[float]]] = None

    # ---- timing ----
    date: Union[str, float, List[Union[str, float]]]
    date_format: Union[str, List[str]]
    date_err: Optional[Union[str, float, List[Union[str, float]]]] = None

    # ---- flags and observation metadata ----
    ignore: Optional[Union[bool, List[bool]]] = None
    upperlimit: Optional[Union[bool, List[bool]]] = None
    sigma: Optional[Union[str, float, List[Union[str, float]]]] = None
    sky: Optional[Union[str, float, List[Union[str, float]]]] = None
    telescope: Optional[Union[str, List[str]]] = None
    instrument: Optional[Union[str, List[str]]] = None
    phot_type: Optional[Union[str, List[str]]] = None
    exptime: Optional[Union[str, int, float, List[Union[str, int, float]]]] = None
    aperture: Optional[Union[str, int, float, List[Union[str, int, float]]]] = None
    observer: Optional[Union[str, List[str]]] = None
    reducer: Optional[Union[str, List[str]]] = None
    pipeline: Optional[Union[str, List[str]]] = None

    # ---- corrections: corr_* flags and their val_* values ----
    corr_k: Optional[Union[bool, str, List[Union[bool, str]]]] = None
    corr_s: Optional[Union[bool, str, List[Union[bool, str]]]] = None
    corr_av: Optional[Union[bool, str, List[Union[bool, str]]]] = None
    corr_host: Optional[Union[bool, str, List[Union[bool, str]]]] = None
    corr_hostav: Optional[Union[bool, str, List[Union[bool, str]]]] = None
    val_k: Optional[Union[float, int, str, List[Union[float, int, str]]]] = None
    val_s: Optional[Union[float, int, str, List[Union[float, int, str]]]] = None
    val_av: Optional[Union[float, int, str, List[Union[float, int, str]]]] = None
    val_host: Optional[Union[float, int, str, List[Union[float, int, str]]]] = None
    val_hostav: Optional[Union[float, int, str, List[Union[float, int, str]]]] = None

    # ---- nested detail ----
    xray_model: Optional[Union[List[_XrayModelSchema], List[None]]] = None
    raw_err_detail: Optional[_ErrDetailSchema] = None
    value_err_detail: Optional[_ErrDetailSchema] = None

    @field_validator(
        "raw_units",
        "raw_err",
        "filter_key",
        "obs_type",
        "date_format",
        "upperlimit",
        "date",
        "telescope",
    )
    @classmethod
    def ensure_list(cls, v):
        # Wrap scalars so downstream code can always iterate these fields.
        # NOTE(review): an explicitly passed None (allowed for raw_err,
        # upperlimit, telescope) becomes [None] here -- confirm that is intended.
        return v if isinstance(v, list) else [v]

    @model_validator(mode="after")
    def _ensure_xray_model(self):
        """
        Placeholder for enforcing that X-ray photometry carries an xray_model.

        The check is intentionally disabled until the data is set up correctly,
        so this validator currently accepts every instance unchanged.
        """
        # TODO: once the data is ready, raise when obs_type is "xray" and
        # xray_model is None. NOTE(review): obs_type has already been turned
        # into a list by ensure_list, so the check must test membership
        # ("xray" in self.obs_type) rather than equality with "xray".
        return self
223
+
224
+
225
class FilterSchema(BaseModel):
    """
    Maps a ``filter_key`` (as used in photometry entries) to a filter name and
    its optional wavelength, frequency, and zero-point properties.
    """

    filter_key: str
    filter_name: str

    # wavelength description
    wave_eff: Union[str, float, int] = None
    wave_min: Union[str, float, int] = None
    wave_max: Union[str, float, int] = None

    # frequency description
    freq_eff: Union[str, float, int] = None
    freq_min: Union[str, float, int] = None
    freq_max: Union[str, float, int] = None

    # zero point
    zp: Union[str, float, int] = None

    # units / systems for the quantities above
    wave_units: Union[str, float, int] = None
    freq_units: Union[str, float, int] = None
    zp_units: Union[str, float, int] = None
    zp_system: Union[str, float, int] = None
239
+
240
+
241
class HostSchema(BaseModel):
    """
    Information about the host galaxy of the transient.

    A host must be identified by a name and/or a full RA/Dec coordinate; when
    coordinates are given, both of their units must be given too.
    """

    reference: Union[str, List[str]]
    host_ra: Optional[Union[str, float]] = None
    host_dec: Optional[Union[str, float]] = None
    host_ra_units: Optional[str] = None
    host_dec_units: Optional[str] = None
    host_z: Optional[Union[str, int, float]] = None
    host_type: Optional[str] = None
    host_name: Optional[str] = None

    @model_validator(mode="after")
    def _has_coordinate_or_name(self):
        # Validators raise ValueError; pydantic wraps it into a ValidationError.
        has_ra = self.host_ra is not None
        has_dec = self.host_dec is not None

        # Report a half-specified coordinate first, since it is the most
        # specific problem.
        if has_ra != has_dec:
            raise ValueError("Please provide RA AND Dec, not just one or the other!")

        has_coordinate = has_ra and has_dec

        # if it has the RA/Dec keys, make sure it also has the unit keys
        if has_coordinate:
            if self.host_ra_units is None:
                raise ValueError("Need RA unit if coordinates are provided!")
            if self.host_dec_units is None:
                raise ValueError("Need Dec unit if coordinates are provided!")

        # we need either the coordinate or name to identify this object
        # Both are okay too (more info is always better)
        if not has_coordinate and self.host_name is None:
            raise ValueError("Need to provide a Host name and/or host coordinates!")

        return self
279
+
280
+
281
class OtterSchema(BaseModel):
    """
    Top-level schema of a single OTTER transient document.

    ``filter_alias`` is required whenever ``photometry`` is present.
    """

    schema_version: Optional[VersionSchema] = None
    name: NameSchema
    coordinate: list[CoordinateSchema]
    distance: Optional[list[DistanceSchema]] = None
    classification: Optional[list[ClassificationSchema]] = None
    reference_alias: list[ReferenceSchema]
    date_reference: Optional[list[DateSchema]] = None
    photometry: Optional[list[PhotometrySchema]] = None
    filter_alias: Optional[list[FilterSchema]] = None
    host: Optional[list[HostSchema]] = None

    @model_validator(mode="after")
    def _verify_filter_alias(self):
        # Photometry entries carry a filter_key, and filter_alias entries
        # describe those keys, so photometry without filter_alias is invalid.
        # Raise ValueError so pydantic reports a proper ValidationError.
        if self.photometry is not None and self.filter_alias is None:
            raise ValueError("filter_alias is needed if photometry is given!")

        # An "after" model validator must hand back the validated instance.
        return self