astro-otter 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of astro-otter might be problematic. Click here for more details.

otter/io/transient.py CHANGED
@@ -3,10 +3,13 @@ Class for a transient,
3
3
  basically just inherits the dict properties with some overwriting
4
4
  """
5
5
 
6
+ from __future__ import annotations
6
7
  import warnings
7
8
  from copy import deepcopy
8
9
  import re
9
10
  from collections.abc import MutableMapping
11
+ from typing_extensions import Self
12
+ import logging
10
13
 
11
14
  import numpy as np
12
15
  import pandas as pd
@@ -15,9 +18,6 @@ import astropy.units as u
15
18
  from astropy.time import Time
16
19
  from astropy.coordinates import SkyCoord
17
20
 
18
- from synphot.units import VEGAMAG, convert_flux
19
- from synphot.spectrum import SourceSpectrum
20
-
21
21
  from ..exceptions import (
22
22
  FailedQueryError,
23
23
  IOError,
@@ -25,10 +25,12 @@ from ..exceptions import (
25
25
  TransientMergeError,
26
26
  )
27
27
  from ..util import XRAY_AREAS
28
+ from .host import Host
28
29
 
29
30
  warnings.simplefilter("once", RuntimeWarning)
30
31
  warnings.simplefilter("once", UserWarning)
31
32
  np.seterr(divide="ignore")
33
+ logger = logging.getLogger(__name__)
32
34
 
33
35
 
34
36
  class Transient(MutableMapping):
@@ -37,7 +39,9 @@ class Transient(MutableMapping):
37
39
  Overwrite the dictionary init
38
40
 
39
41
  Args:
40
- d [dict]: A transient dictionary
42
+ d (dict): A transient dictionary
43
+ name (str): The default name of the transient, default is None and it will
44
+ be inferred from the input dictionary.
41
45
  """
42
46
  self.data = d
43
47
 
@@ -68,7 +72,7 @@ class Transient(MutableMapping):
68
72
  """
69
73
 
70
74
  if isinstance(keys, (list, tuple)):
71
- return Transient({key: self[key] for key in keys})
75
+ return Transient({key: (self[key] if key in self else []) for key in keys})
72
76
  elif isinstance(keys, str) and "/" in keys: # this is for a path
73
77
  s = "']['".join(keys.split("/"))
74
78
  s = "['" + s
@@ -85,6 +89,10 @@ class Transient(MutableMapping):
85
89
  return self.data[keys]
86
90
 
87
91
  def __setitem__(self, key, value):
92
+ """
93
+ Override set item to work with the '/' syntax
94
+ """
95
+
88
96
  if isinstance(key, str) and "/" in key: # this is for a path
89
97
  s = "']['".join(key.split("/"))
90
98
  s = "['" + s
@@ -109,7 +117,7 @@ class Transient(MutableMapping):
109
117
 
110
118
  def __repr__(self, html=False):
111
119
  if not html:
112
- return str(self.data)
120
+ return f"Transient(\n\tName: {self.default_name},\n\tKeys: {self.keys()}\n)"
113
121
  else:
114
122
  html = ""
115
123
 
@@ -188,6 +196,19 @@ class Transient(MutableMapping):
188
196
  + " You can set strict_merge=False to override the check"
189
197
  )
190
198
 
199
+ # create set of the allowed keywords
200
+ allowed_keywords = {
201
+ "name",
202
+ "date_reference",
203
+ "coordinate",
204
+ "distance",
205
+ "filter_alias",
206
+ "schema_version",
207
+ "photometry",
208
+ "classification",
209
+ "host",
210
+ }
211
+
191
212
  # create a blank dictionary since we don't want to overwrite this object
192
213
  out = {}
193
214
 
@@ -222,24 +243,8 @@ class Transient(MutableMapping):
222
243
  continue
223
244
 
224
245
  # There are some special keys that we are expecting
225
- if key == "name":
226
- self._merge_names(other, out)
227
- elif key == "coordinate":
228
- self._merge_coords(other, out)
229
- elif key == "date_reference":
230
- self._merge_date(other, out)
231
- elif key == "distance":
232
- self._merge_distance(other, out)
233
- elif key == "filter_alias":
234
- self._merge_filter_alias(other, out)
235
- elif key == "schema_version":
236
- self._merge_schema_version(other, out)
237
- elif key == "photometry":
238
- self._merge_photometry(other, out)
239
- elif key == "spectra":
240
- self._merge_spectra(other, out)
241
- elif key == "classification":
242
- self._merge_class(other, out)
246
+ if key in allowed_keywords:
247
+ Transient._merge_arbitrary(key, self, other, out)
243
248
  else:
244
249
  # this is an unexpected key!
245
250
  if strict_merge:
@@ -260,14 +265,17 @@ class Transient(MutableMapping):
260
265
  # now return out as a Transient Object
261
266
  return Transient(out)
262
267
 
263
- def get_meta(self, keys=None):
268
+ def get_meta(self, keys=None) -> Self:
264
269
  """
265
270
  Get the metadata (no photometry or spectra)
266
271
 
267
272
  This essentially just wraps on __getitem__ but with some checks
268
273
 
269
274
  Args:
270
- keys [list[str]] : list of keys
275
+ keys (list[str]) : list of keys to get the metadata for from the transient
276
+
277
+ Returns:
278
+ A Transient object of just the meta data
271
279
  """
272
280
  if keys is None:
273
281
  keys = list(self.keys())
@@ -296,9 +304,16 @@ class Transient(MutableMapping):
296
304
 
297
305
  return self[keys]
298
306
 
299
- def get_skycoord(self, coord_format="icrs"):
307
+ def get_skycoord(self, coord_format="icrs") -> SkyCoord:
300
308
  """
301
309
  Convert the coordinates to an astropy SkyCoord
310
+
311
+ Args:
312
+ coord_format (str): Astropy coordinate format to convert the SkyCoord to
313
+ defaults to icrs.
314
+
315
+ Returns:
316
+ Astropy.coordinates.SkyCoord of the default coordinate for the transient
302
317
  """
303
318
 
304
319
  # now we can generate the SkyCoord
@@ -309,9 +324,12 @@ class Transient(MutableMapping):
309
324
 
310
325
  return coord
311
326
 
312
- def get_discovery_date(self):
327
+ def get_discovery_date(self) -> Time:
313
328
  """
314
- Get the default discovery date
329
+ Get the default discovery date for this Transient
330
+
331
+ Returns:
332
+ astropy.time.Time of the default discovery date
315
333
  """
316
334
  key = "date_reference"
317
335
  date = self._get_default(key, filt='df["date_type"] == "discovery"')
@@ -320,11 +338,14 @@ class Transient(MutableMapping):
320
338
  else:
321
339
  f = "mjd"
322
340
 
323
- return Time(date["value"], format=f)
341
+ return Time(str(date["value"]).strip(), format=f)
324
342
 
325
- def get_redshift(self):
343
+ def get_redshift(self) -> float:
326
344
  """
327
- Get the default redshift
345
+ Get the default redshift of this Transient
346
+
347
+ Returns:
348
+ Float value of the default redshift
328
349
  """
329
350
  f = "df['distance_type']=='redshift'"
330
351
  default = self._get_default("distance", filt=f)
@@ -333,7 +354,73 @@ class Transient(MutableMapping):
333
354
  else:
334
355
  return default["value"]
335
356
 
336
- def _get_default(self, key, filt=""):
357
+ def get_classification(self) -> tuple(str, float, list):
358
+ """
359
+ Get the default classification of this Transient.
360
+ This normally corresponds to the highest confidence classification that we have
361
+ stored for the transient.
362
+
363
+ Returns:
364
+ The default object class as a string, the confidence level in that class,
365
+ and a list of the bibcodes corresponding to that classification. Or, None
366
+ if there is no classification.
367
+ """
368
+ default = self._get_default("classification")
369
+ if default is None:
370
+ return default
371
+ return default.object_class, default.confidence, default.reference
372
+
373
+ def get_host(self, max_hosts=3, **kwargs) -> list[Host]:
374
+ """
375
+ Gets the default host information of this Transient. This returns an otter.Host
376
+ object. If no host is known in OTTER, it uses astro-ghost to find the best
377
+ match.
378
+
379
+ Args:
380
+ max_hosts [int] : The maximum number of hosts to return
381
+ **kwargs : keyword arguments to be passed to getGHOST
382
+
383
+ Returns:
384
+ A list of otter.Host objects. This is useful becuase the Host objects have
385
+ useful methods for querying public catalogs for data of the host.
386
+ """
387
+ # first try to get the host information from our local database
388
+ if "host" in self:
389
+ host = [
390
+ Host(transient_name=self.default_name, **dict(h)) for h in self["host"]
391
+ ]
392
+
393
+ # then try astro-ghost
394
+ else:
395
+ logger.warn(
396
+ "No host known, trying to find it with astro-ghost. \
397
+ See https://uiucsnastro-ghost.readthedocs.io/en/latest/index.html"
398
+ )
399
+
400
+ # this import has to be here otherwise the code breaks
401
+ from astro_ghost.ghostHelperFunctions import getTransientHosts, getGHOST
402
+
403
+ getGHOST(real=False, verbose=1)
404
+ res = getTransientHosts(
405
+ [self.default_name], [self.get_skycoord()], verbose=False
406
+ )
407
+
408
+ host = [
409
+ Host(
410
+ host_ra=row["raStack"],
411
+ host_dec=row["decStack"],
412
+ host_ra_units="deg",
413
+ host_dec_units="deg",
414
+ host_name=row["objName"],
415
+ transient_name=self.default_name,
416
+ reference=["astro-ghost"],
417
+ )
418
+ for i, row in res.iterrows()
419
+ ]
420
+
421
+ return host
422
+
423
+ def _get_default(self, key, filt=None):
337
424
  """
338
425
  Get the default of key
339
426
 
@@ -346,7 +433,8 @@ class Transient(MutableMapping):
346
433
  raise KeyError(f"This transient does not have {key} associated with it!")
347
434
 
348
435
  df = pd.DataFrame(self[key])
349
- df = df[eval(filt)] # apply the filters
436
+ if filt is not None:
437
+ df = df[eval(filt)] # apply the filters
350
438
 
351
439
  if "default" in df:
352
440
  # first try to get the default
@@ -390,11 +478,37 @@ class Transient(MutableMapping):
390
478
  wave_unit: u.Unit = "nm",
391
479
  by: str = "raw",
392
480
  obs_type: str = None,
393
- ):
481
+ ) -> pd.DataFrame:
394
482
  """
395
483
  Ensure the photometry associated with this transient is all in the same
396
484
  units/system/etc
397
- """
485
+
486
+ Args:
487
+ flux_unit (astropy.unit.Unit): The astropy unit or string representation of
488
+ an astropy unit to convert and return the
489
+ flux as.
490
+ date_unit (str): Valid astropy date format string.
491
+ freq_unit (astropy.unit.Unit): The astropy unit or string representation of
492
+ an astropy unit to convert and return the
493
+ frequency as.
494
+ wave_unit (astropy.unit.Unit): The astropy unit or string representation of
495
+ an astropy unit to convert and return the
496
+ wavelength as.
497
+ by (str): Either 'raw' or 'value'. 'raw' is the default and is highly
498
+ recommended! If 'value' is used it may skip some photometry.
499
+ See the schema definition to understand this keyword completely
500
+ before using it.
501
+ obs_type (str): "radio", "xray", or "uvoir". If provided, it only returns
502
+ data taken within that range of wavelengths/frequencies.
503
+ Default is None which will return all of the data.
504
+
505
+ Returns:
506
+ A pandas DataFrame of the cleaned up photometry in the requested units
507
+ """
508
+ # these imports need to be here for some reason
509
+ # otherwise the code breaks
510
+ from synphot.units import VEGAMAG, convert_flux
511
+ from synphot.spectrum import SourceSpectrum
398
512
 
399
513
  # check inputs
400
514
  if by not in {"value", "raw"}:
@@ -429,6 +543,9 @@ class Transient(MutableMapping):
429
543
  else:
430
544
  by = "value"
431
545
 
546
+ # skip rows where 'by' is nan
547
+ df = df[df[by].notna()]
548
+
432
549
  # drop irrelevant obs_types before continuing
433
550
  if obs_type is not None:
434
551
  valid_obs_types = {"radio", "uvoir", "xray"}
@@ -474,8 +591,9 @@ class Transient(MutableMapping):
474
591
  )
475
592
 
476
593
  unit = unit[0]
594
+ isvegamag = "vega" in unit.lower()
477
595
  try:
478
- if "vega" in unit.lower():
596
+ if isvegamag:
479
597
  astropy_units = VEGAMAG
480
598
  else:
481
599
  astropy_units = u.Unit(unit)
@@ -506,25 +624,24 @@ class Transient(MutableMapping):
506
624
  astropy_units
507
625
  ) # assume error and values have the same unit
508
626
 
509
- # get the effective wavelength
627
+ # get and save the effective wavelength
510
628
  if "freq_eff" in data and not np.isnan(data["freq_eff"].iloc[0]):
511
- freq_units = data["freq_units"]
512
- if len(np.unique(freq_units)) > 1:
513
- raise OtterLimitationError(
514
- "Can not convert different units to the same unit!"
515
- )
516
-
517
- freq_eff = np.array(data["freq_eff"]) * u.Unit(freq_units.iloc[0])
518
- wave_eff = freq_eff.to(u.AA, equivalencies=u.spectral())
629
+ zz = zip(data["freq_eff"], data["freq_units"])
630
+ freq_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], freq_unit)
631
+ wave_eff = freq_eff.to(wave_unit, equivalencies=u.spectral())
519
632
 
520
633
  elif "wave_eff" in data and not np.isnan(data["wave_eff"].iloc[0]):
521
- wave_units = data["wave_units"]
522
- if len(np.unique(wave_units)) > 1:
523
- raise OtterLimitationError(
524
- "Can not convert different units to the same unit!"
525
- )
634
+ zz = zip(data["wave_eff"], data["wave_units"])
635
+ wave_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], wave_unit)
636
+ freq_eff = wave_eff.to(freq_unit, equivalencies=u.spectral())
637
+
638
+ else:
639
+ raise ValueError("No known frequency or wavelength, please fix!")
526
640
 
527
- wave_eff = np.array(data["wave_eff"]) * u.Unit(wave_units.iloc[0])
641
+ data["converted_wave"] = wave_eff.value
642
+ data["converted_wave_unit"] = wave_unit
643
+ data["converted_freq"] = freq_eff.value
644
+ data["converted_freq_unit"] = freq_unit
528
645
 
529
646
  # convert using synphot
530
647
  # stuff has to be done slightly differently for xray than for the others
@@ -544,39 +661,48 @@ class Transient(MutableMapping):
544
661
 
545
662
  # we also need to make this wave_min and wave_max
546
663
  # instead of just the effective wavelength like for radio and uvoir
547
- wave_eff = np.array(
548
- list(zip(data["wave_min"], data["wave_max"]))
549
- ) * u.Unit(wave_units.iloc[0])
664
+ zz = zip(data["wave_min"], data["wave_max"], data["wave_units"])
665
+ wave_eff = u.Quantity(
666
+ [np.array([m, M]) * u.Unit(uu) for m, M, uu in zz],
667
+ u.Unit(wave_unit),
668
+ )
550
669
 
551
670
  else:
552
671
  area = None
553
672
 
554
- # we unfortunately have to loop over the points here because
555
- # syncphot does not work with a 2D array of min max wavelengths
556
- # for converting counts to other flux units. It also can't convert
557
- # vega mags with a wavelength array because it then interprets that as the
558
- # wavelengths corresponding to the SourceSpectrum.from_vega()
559
- flux, flux_err = [], []
560
- for wave, xray_point, xray_point_err in zip(wave_eff, q, q_err):
561
- f_val = convert_flux(
562
- wave,
563
- xray_point,
564
- u.Unit(flux_unit),
565
- vegaspec=SourceSpectrum.from_vega(),
566
- area=area,
567
- )
568
- f_err = convert_flux(
569
- wave,
570
- xray_point_err,
571
- u.Unit(flux_unit),
572
- vegaspec=SourceSpectrum.from_vega(),
573
- area=area,
574
- )
673
+ if obstype == "xray" or isvegamag:
674
+ # we unfortunately have to loop over the points here because
675
+ # syncphot does not work with a 2D array of min max wavelengths
676
+ # for converting counts to other flux units. It also can't convert
677
+ # vega mags with a wavelength array because it interprets that as the
678
+ # wavelengths corresponding to the SourceSpectrum.from_vega()
679
+
680
+ flux, flux_err = [], []
681
+ for wave, xray_point, xray_point_err in zip(wave_eff, q, q_err):
682
+ f_val = convert_flux(
683
+ wave,
684
+ xray_point,
685
+ u.Unit(flux_unit),
686
+ vegaspec=SourceSpectrum.from_vega(),
687
+ area=area,
688
+ )
689
+ f_err = convert_flux(
690
+ wave,
691
+ xray_point_err,
692
+ u.Unit(flux_unit),
693
+ vegaspec=SourceSpectrum.from_vega(),
694
+ area=area,
695
+ )
696
+
697
+ # then we take the average of the minimum and maximum values
698
+ # computed by syncphot
699
+ flux.append(np.mean(f_val).value)
700
+ flux_err.append(np.mean(f_err).value)
575
701
 
576
- # then we take the average of the minimum and maximum values
577
- # computed by syncphot
578
- flux.append(np.mean(f_val).value)
579
- flux_err.append(np.mean(f_err).value)
702
+ else:
703
+ # this will be faster and cover most cases
704
+ flux = convert_flux(wave_eff, q, u.Unit(flux_unit))
705
+ flux_err = convert_flux(wave_eff, q_err, u.Unit(flux_unit))
580
706
 
581
707
  flux = np.array(flux) * u.Unit(flux_unit)
582
708
  flux_err = np.array(flux_err) * u.Unit(flux_unit)
@@ -590,7 +716,7 @@ class Transient(MutableMapping):
590
716
  outdata = pd.concat(outdata)
591
717
 
592
718
  # copy over the flux units
593
- outdata["converted_flux_unit"] = [flux_unit] * len(outdata)
719
+ outdata["converted_flux_unit"] = flux_unit
594
720
 
595
721
  # make sure all the datetimes are in the same format here too!!
596
722
  times = [
@@ -598,27 +724,7 @@ class Transient(MutableMapping):
598
724
  for d, f in zip(outdata.date, outdata.date_format.str.lower())
599
725
  ]
600
726
  outdata["converted_date"] = times
601
- outdata["converted_date_unit"] = [date_unit] * len(outdata)
602
-
603
- # same with frequencies and wavelengths
604
- freqs = []
605
- waves = []
606
-
607
- for _, row in df.iterrows():
608
- if "freq_eff" in row and not np.isnan(row["freq_eff"]):
609
- val = row["freq_eff"] * u.Unit(row["freq_units"])
610
- elif "wave_eff" in df and not np.isnan(row["wave_eff"]):
611
- val = row["wave_eff"] * u.Unit(row["wave_units"])
612
- else:
613
- raise ValueError("No known frequency or wavelength, please fix!")
614
-
615
- freqs.append(val.to(freq_unit, equivalencies=u.spectral()).value)
616
- waves.append(val.to(wave_unit, equivalencies=u.spectral()).value)
617
-
618
- outdata["converted_freq"] = freqs
619
- outdata["converted_wave"] = waves
620
- outdata["converted_wave_unit"] = [wave_unit] * len(outdata)
621
- outdata["converted_freq_unit"] = [freq_unit] * len(outdata)
727
+ outdata["converted_date_unit"] = date_unit
622
728
 
623
729
  return outdata
624
730
 
@@ -707,16 +813,6 @@ class Transient(MutableMapping):
707
813
  bothlines = [{"value": k, "reference": t1map[k] + t2map[k]} for k in inboth]
708
814
  out[key]["alias"] = line2 + line1 + bothlines
709
815
 
710
- def _merge_coords(t1, t2, out): # noqa: N805
711
- """
712
- Merge the coordinates subdictionaries for t1 and t2 and put it in out
713
-
714
- Use pandas to drop any duplicates
715
- """
716
- key = "coordinate"
717
-
718
- Transient._merge_arbitrary(key, t1, t2, out)
719
-
720
816
  def _merge_filter_alias(t1, t2, out): # noqa: N805
721
817
  """
722
818
  Combine the filter alias lists across the transient objects
@@ -748,8 +844,15 @@ class Transient(MutableMapping):
748
844
  key = "photometry"
749
845
 
750
846
  out[key] = deepcopy(t1[key])
751
- refs = np.array([d["reference"] for d in out[key]])
847
+ refs = [] # np.array([d["reference"] for d in out[key]])
752
848
  # merge_dups = lambda val: np.sum(val) if np.any(val.isna()) else val.iloc[0]
849
+ for val in out[key]:
850
+ if isinstance(val, list):
851
+ refs += val
852
+ elif isinstance(val, np.ndarray):
853
+ refs += list(val)
854
+ else:
855
+ refs.append(val)
753
856
 
754
857
  for val in t2[key]:
755
858
  # first check if t2's reference is in out
@@ -774,12 +877,6 @@ class Transient(MutableMapping):
774
877
 
775
878
  out[key][i1] = newdict # replace the dictionary at i1 with the new dict
776
879
 
777
- def _merge_spectra(t1, t2, out): # noqa: N805
778
- """
779
- Combine spectra sources
780
- """
781
- pass
782
-
783
880
  def _merge_class(t1, t2, out): # noqa: N805
784
881
  """
785
882
  Combine the classification attribute
@@ -815,22 +912,6 @@ class Transient(MutableMapping):
815
912
  else:
816
913
  item["default"] = False
817
914
 
818
- def _merge_date(t1, t2, out): # noqa: N805
819
- """
820
- Combine epoch data across two transients and write it to "out"
821
- """
822
- key = "date_reference"
823
-
824
- Transient._merge_arbitrary(key, t1, t2, out)
825
-
826
- def _merge_distance(t1, t2, out): # noqa: N805
827
- """
828
- Combine distance information for these two transients
829
- """
830
- key = "distance"
831
-
832
- Transient._merge_arbitrary(key, t1, t2, out)
833
-
834
915
  @staticmethod
835
916
  def _merge_arbitrary(key, t1, t2, out):
836
917
  """
@@ -840,24 +921,36 @@ class Transient(MutableMapping):
840
921
  a NxM pandas dataframe!
841
922
  """
842
923
 
843
- df1 = pd.DataFrame(t1[key])
844
- df2 = pd.DataFrame(t2[key])
924
+ if key == "name":
925
+ t1._merge_names(t2, out)
926
+ elif key == "filter_alias":
927
+ t1._merge_filter_alias(t2, out)
928
+ elif key == "schema_version":
929
+ t1._merge_schema_version(t2, out)
930
+ elif key == "photometry":
931
+ t1._merge_photometry(t2, out)
932
+ elif key == "classification":
933
+ t1._merge_class(t2, out)
934
+ else:
935
+ # this is where we can standardize some of the merging
936
+ df1 = pd.DataFrame(t1[key])
937
+ df2 = pd.DataFrame(t2[key])
845
938
 
846
- merged_with_dups = pd.concat([df1, df2]).reset_index(drop=True)
939
+ merged_with_dups = pd.concat([df1, df2]).reset_index(drop=True)
847
940
 
848
- # have to get the indexes to drop using a string rep of the df
849
- # this is cause we have lists in some cells
850
- to_drop = merged_with_dups.astype(str).drop_duplicates().index
941
+ # have to get the indexes to drop using a string rep of the df
942
+ # this is cause we have lists in some cells
943
+ to_drop = merged_with_dups.astype(str).drop_duplicates().index
851
944
 
852
- merged = merged_with_dups.iloc[to_drop].reset_index(drop=True)
945
+ merged = merged_with_dups.iloc[to_drop].reset_index(drop=True)
853
946
 
854
- outdict = merged.to_dict(orient="records")
947
+ outdict = merged.to_dict(orient="records")
855
948
 
856
- outdict_cleaned = Transient._remove_nans(
857
- outdict
858
- ) # clear out the nans from pandas conversion
949
+ outdict_cleaned = Transient._remove_nans(
950
+ outdict
951
+ ) # clear out the nans from pandas conversion
859
952
 
860
- out[key] = outdict_cleaned
953
+ out[key] = outdict_cleaned
861
954
 
862
955
  @staticmethod
863
956
  def _remove_nans(d):
@@ -6,19 +6,21 @@ Currently supported backends are:
6
6
  - plotly
7
7
  """
8
8
 
9
+ from __future__ import annotations
9
10
  import importlib
10
11
 
11
12
 
12
13
  class OtterPlotter:
13
- def __init__(self, backend):
14
- """
15
- Handles the backend for the "plotter" module
14
+ """
15
+ Handles the backend for the "plotter" module
16
16
 
17
- Args:
18
- backend [string]: a string of the module name to import and use
19
- as the backend. Currently supported are "matplotlib",
20
- "matplotlib.pyplot", "plotly", and "plotly.graph_objects"
21
- """
17
+ Args:
18
+ backend (string): a string of the module name to import and use
19
+ as the backend. Currently supported are "matplotlib",
20
+ "matplotlib.pyplot", "plotly", and "plotly.graph_objects"
21
+ """
22
+
23
+ def __init__(self, backend):
22
24
  if backend == "matplotlib.pyplot":
23
25
  self.backend = backend
24
26
  elif backend == "pyplot.graph_objects":