astro-otter 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of astro-otter might be problematic.

otter/io/transient.py CHANGED
@@ -209,6 +209,36 @@ class Transient(MutableMapping):
             "host",
         }
 
+        merge_subkeys_map = {
+            "name": None,
+            "date_reference": ["value", "date_format", "date_type"],
+            "coordinate": None,  # may need to update this if we run into problems
+            "distance": ["value", "distance_type", "unit"],
+            "filter_alias": None,
+            "schema_version": None,
+            "photometry": None,
+            "classification": None,
+            "host": [
+                "host_ra",
+                "host_dec",
+                "host_ra_units",
+                "host_dec_units",
+                "host_name",
+            ],
+        }
+
+        groupby_key_for_default_map = {
+            "name": None,
+            "date_reference": "date_type",
+            "coordinate": "coordinate_type",
+            "distance": "distance_type",
+            "filter_alias": None,
+            "schema_version": None,
+            "photometry": None,
+            "classification": None,
+            "host": None,
+        }
+
         # create a blank dictionary since we don't want to overwrite this object
         out = {}
 
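The two maps introduced above are consumed by the updated _merge_arbitrary call in the next hunk: merge_subkeys_map[key] lists the subkeys that identify duplicate entries for that key (None meaning every column except "reference" is used), and groupby_key_for_default_map[key] names the column used to group rows when a default entry is chosen. A worked sketch of that merge appears after the _merge_arbitrary hunk near the end of this file's diff.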
@@ -244,14 +274,19 @@ class Transient(MutableMapping):
 
         # There are some special keys that we are expecting
         if key in allowed_keywords:
-            Transient._merge_arbitrary(key, self, other, out)
+            Transient._merge_arbitrary(
+                key,
+                self,
+                other,
+                out,
+                merge_subkeys=merge_subkeys_map[key],
+                groupby_key=groupby_key_for_default_map[key],
+            )
         else:
             # this is an unexpected key!
             if strict_merge:
                 # since this is a strict merge we don't want unexpected data!
-                raise TransientMergeError(
-                    f"{key} was not expected! Only keeping the old information!"
-                )
+                raise TransientMergeError(f"{key} was not expected! Can not merge!")
             else:
                 # Throw a warning and only keep the old stuff
                 warnings.warn(
@@ -332,7 +367,14 @@ class Transient(MutableMapping):
             astropy.time.Time of the default discovery date
         """
         key = "date_reference"
-        date = self._get_default(key, filt='df["date_type"] == "discovery"')
+        try:
+            date = self._get_default(key, filt='df["date_type"] == "discovery"')
+        except KeyError:
+            return None
+
+        if date is None:
+            return date
+
         if "date_format" in date:
             f = date["date_format"]
         else:
@@ -370,11 +412,13 @@ class Transient(MutableMapping):
             return default
         return default.object_class, default.confidence, default.reference
 
-    def get_host(self, max_hosts=3, **kwargs) -> list[Host]:
+    def get_host(self, max_hosts=3, search=False, **kwargs) -> list[Host]:
         """
         Gets the default host information of this Transient. This returns an otter.Host
-        object. If no host is known in OTTER, it uses astro-ghost to find the best
-        match.
+        object. If search=True, it will also check the BLAST host association database
+        for the best match and return it as well. Note that if search is True then
+        this has the potential to return max_hosts + 1, if BLAST also returns a result.
+        The BLAST result will always be the last value in the returned list.
 
         Args:
             max_hosts [int] : The maximum number of hosts to return
@@ -385,38 +429,25 @@ class Transient(MutableMapping):
             useful methods for querying public catalogs for data of the host.
         """
         # first try to get the host information from our local database
+        host = []
         if "host" in self:
-            host = [
-                Host(transient_name=self.default_name, **dict(h)) for h in self["host"]
-            ]
+            max_hosts = min([max_hosts, len(self["host"])])
+            for h in self["host"][:max_hosts]:
+                host.append(Host(transient_name=self.default_name, **dict(h)))
 
-        # then try astro-ghost
-        else:
+        # then try BLAST
+        if search:
             logger.warn(
-                "No host known, trying to find it with astro-ghost. \
-                See https://uiucsnastro-ghost.readthedocs.io/en/latest/index.html"
+                "Trying to find a host with BLAST/astro-ghost. Note\
+                that this won't work for older targets! See https://blast.scimma.org"
             )
 
-            # this import has to be here otherwise the code breaks
-            from astro_ghost.ghostHelperFunctions import getTransientHosts, getGHOST
-
-            getGHOST(real=False, verbose=1)
-            res = getTransientHosts(
-                [self.default_name], [self.get_skycoord()], verbose=False
-            )
-
-            host = [
-                Host(
-                    host_ra=row["raStack"],
-                    host_dec=row["decStack"],
-                    host_ra_units="deg",
-                    host_dec_units="deg",
-                    host_name=row["objName"],
-                    transient_name=self.default_name,
-                    reference=["astro-ghost"],
-                )
-                for i, row in res.iterrows()
-            ]
+            # default_name should always be the TNS name if we have one
+            print(self.default_name)
+            blast_host = Host.query_blast(self.default_name)
+            print(blast_host)
+            if blast_host is not None:
+                host.append(blast_host)
 
         return host
 
@@ -433,6 +464,9 @@ class Transient(MutableMapping):
             raise KeyError(f"This transient does not have {key} associated with it!")
 
         df = pd.DataFrame(self[key])
+        if len(df) == 0:
+            raise KeyError(f"This transient does not have {key} associated with it!")
+
         if filt is not None:
             df = df[eval(filt)]  # apply the filters
 
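The filt argument used here is a raw Python expression evaluated against the local DataFrame df, exactly as written by get_discovery_date above. A minimal, self-contained illustration with invented rows:

import pandas as pd

df = pd.DataFrame(
    [
        {"value": "2018-04-09", "date_type": "discovery"},
        {"value": "2018-05-01", "date_type": "peak"},
    ]
)
filt = 'df["date_type"] == "discovery"'
print(df[eval(filt)])  # keeps only the discovery row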
@@ -446,6 +480,7 @@ class Transient(MutableMapping):
 
         if len(df_filtered) == 0:
             return None
+
         return df_filtered.iloc[0]
 
     def _reformat_coordinate(self, item):
@@ -515,6 +550,9 @@ class Transient(MutableMapping):
             raise IOError("Please choose either value or raw!")
 
         # turn the photometry key into a pandas dataframe
+        if "photometry" not in self:
+            raise FailedQueryError("No photometry for this object!")
+
         dfs = []
         for item in self["photometry"]:
             max_len = 0
@@ -531,9 +569,29 @@ class Transient(MutableMapping):
             df = pd.DataFrame(item)
             dfs.append(df)
 
+        if len(dfs) == 0:
+            raise FailedQueryError("No photometry for this object!")
         c = pd.concat(dfs)
 
+        # extract the filter information and substitute in any missing columns
+        # because of how we handle this later, we just need to make sure the effective
+        # wavelengths are never nan
+        def fill_wave(row):
+            if "wave_eff" not in row or (
+                pd.isna(row.wave_eff) and not pd.isna(row.freq_eff)
+            ):
+                freq_eff = row.freq_eff * u.Unit(row.freq_units)
+                wave_eff = freq_eff.to(u.Unit(wave_unit), equivalencies=u.spectral())
+                return wave_eff.value, wave_unit
+            elif not pd.isna(row.wave_eff):
+                return row.wave_eff, row.wave_units
+            else:
+                raise ValueError("Missing frequency or wavelength information!")
+
         filters = pd.DataFrame(self["filter_alias"])
+        res = filters.apply(fill_wave, axis=1)
+        filters["wave_eff"], filters["wave_units"] = zip(*res)
+        # merge the photometry with the filter information
         df = c.merge(filters, on="filter_key")
 
         # make sure 'by' is in df
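fill_wave leans on astropy's spectral equivalencies to backfill missing effective wavelengths from effective frequencies; a standalone sanity check with an illustrative value:

import astropy.units as u

freq_eff = 5.5 * u.GHz
wave_eff = freq_eff.to(u.AA, equivalencies=u.spectral())
print(wave_eff)  # c / 5.5 GHz, roughly 5.45e8 Angstrom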
@@ -546,6 +604,14 @@ class Transient(MutableMapping):
         # skip rows where 'by' is nan
         df = df[df[by].notna()]
 
+        # remove rows where the flux is less than zero since this is nonphysical
+        # See Mummery et al. (2023) Section 5.2 for why we need to do this when using
+        # ZTF data:
+        # "Because the origin of the negative late-time flux is currently un-
+        # known (and under investigation), we have not attempted to correct
+        # the TDE lightcurves for this systematic effect."
+        df = df[df[by].astype(float) > 0]
+
         # drop irrelevant obs_types before continuing
         if obs_type is not None:
             valid_obs_types = {"radio", "uvoir", "xray"}
@@ -568,6 +634,7 @@ class Transient(MutableMapping):
 
         # Figure out what columns are good to groupby in the photometry
         outdata = []
+
         if "telescope" in df:
             tele = True
             to_grp_by = ["obs_type", by + "_units", "telescope"]
@@ -595,6 +662,12 @@ class Transient(MutableMapping):
             try:
                 if isvegamag:
                     astropy_units = VEGAMAG
+                elif unit == "AB":
+                    # In astropy "AB" is a magnitude SYSTEM not unit and while
+                    # u.Unit("AB") will succeed without error, it will not produce
+                    # the expected result!
+                    # We can assume here that this unit really means astropy's "mag(AB)"
+                    astropy_units = u.Unit("mag(AB)")
                 else:
                     astropy_units = u.Unit(unit)
 
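The distinction the new comment draws can be checked directly in astropy (recent versions assumed): u.Unit("AB") parses to the AB flux-density zero point, a linear unit, while u.Unit("mag(AB)") is the logarithmic magnitude unit the photometry actually means:

import astropy.units as u

print(u.Unit("AB").physical_type)         # spectral flux density, i.e. linear, not a magnitude
print((20 * u.Unit("mag(AB)")).to(u.Jy))  # ~3.63e-05 Jy, the expected AB-magnitude conversion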
@@ -619,24 +692,22 @@ class Transient(MutableMapping):
                 indata_err = np.array(data[by + "_err"].astype(float))
             else:
                 indata_err = np.zeros(len(data))
+
+            # convert to an astropy quantity
             q = indata * u.Unit(astropy_units)
             q_err = indata_err * u.Unit(
                 astropy_units
             )  # assume error and values have the same unit
 
             # get and save the effective wavelength
-            if "freq_eff" in data and not np.isnan(data["freq_eff"].iloc[0]):
-                zz = zip(data["freq_eff"], data["freq_units"])
-                freq_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], freq_unit)
-                wave_eff = freq_eff.to(wave_unit, equivalencies=u.spectral())
+            # because of cleaning we did to the filter dataframe above wave_eff
+            # should NEVER be nan!
+            if np.any(pd.isna(data["wave_eff"])):
+                raise ValueError("Flushing out the effective wavelength array failed!")
 
-            elif "wave_eff" in data and not np.isnan(data["wave_eff"].iloc[0]):
-                zz = zip(data["wave_eff"], data["wave_units"])
-                wave_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], wave_unit)
-                freq_eff = wave_eff.to(freq_unit, equivalencies=u.spectral())
-
-            else:
-                raise ValueError("No known frequency or wavelength, please fix!")
+            zz = zip(data["wave_eff"], data["wave_units"])
+            wave_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], wave_unit)
+            freq_eff = wave_eff.to(freq_unit, equivalencies=u.spectral())
 
             data["converted_wave"] = wave_eff.value
             data["converted_wave_unit"] = wave_unit
@@ -656,7 +727,7 @@ class Transient(MutableMapping):
                 )
             else:
                 raise OtterLimitationError(
-                    "Can not convert x-ray data without a " + "telescope"
+                    "Can not convert x-ray data without a telescope"
                 )
 
         # we also need to make this wave_min and wave_max
@@ -685,24 +756,32 @@ class Transient(MutableMapping):
                     u.Unit(flux_unit),
                     vegaspec=SourceSpectrum.from_vega(),
                     area=area,
-                )
-                f_err = convert_flux(
-                    wave,
-                    xray_point_err,
-                    u.Unit(flux_unit),
-                    vegaspec=SourceSpectrum.from_vega(),
-                    area=area,
+                ).value
+
+                # approximate the uncertainty as dX = dY/Y * X
+                f_err = np.multiply(
+                    f_val, np.divide(xray_point_err.value, xray_point.value)
                 )
 
                 # then we take the average of the minimum and maximum values
                 # computed by syncphot
-                flux.append(np.mean(f_val).value)
-                flux_err.append(np.mean(f_err).value)
+                flux.append(np.mean(f_val))
+                flux_err.append(np.mean(f_err))
 
             else:
                 # this will be faster and cover most cases
-                flux = convert_flux(wave_eff, q, u.Unit(flux_unit))
-                flux_err = convert_flux(wave_eff, q_err, u.Unit(flux_unit))
+                flux = convert_flux(wave_eff, q, u.Unit(flux_unit)).value
+
+                # since the error propagation is different between logarithmic units
+                # and linear units, unfortunately
+                if isinstance(u.Unit(flux_unit), u.LogUnit):
+                    # approximate the uncertainty as dX = dY/Y * |ln(10)/2.5|
+                    prefactor = np.abs(np.log(10) / 2.5)  # this is basically 1
+                else:
+                    # approximate the uncertainty as dX = dY/Y * X
+                    prefactor = flux
+
+                flux_err = np.multiply(prefactor, np.divide(q_err.value, q.value))
 
         flux = np.array(flux) * u.Unit(flux_unit)
         flux_err = np.array(flux_err) * u.Unit(flux_unit)
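The linear-unit prefactor above is first-order propagation through a pure rescaling X = c * Y, for which dX = (X/Y) * dY = X * (dY/Y); a toy numerical check (invented values):

import numpy as np

Y = np.array([10.0, 20.0])
dY = np.array([1.0, 4.0])
X = 3.5 * Y                            # stand-in for any linear unit conversion
dX = np.multiply(X, np.divide(dY, Y))
assert np.allclose(dX, 3.5 * dY)       # relative errors are preserved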
@@ -726,6 +805,26 @@ class Transient(MutableMapping):
         outdata["converted_date"] = times
         outdata["converted_date_unit"] = date_unit
 
+        # compute the upperlimit value based on a 3 sigma detection
+        # this is just for rows where we don't already know if it is an upperlimit
+        if isinstance(u.Unit(flux_unit), u.LogUnit):
+            # this uses the following formula (which is surprising because it means
+            # magnitude upperlimits are independent of the actual measurement!)
+            # sigma_m > (1/3) * (ln(10)/2.5)
+            def is_upperlimit(row):
+                if "upperlimit" in row and pd.isna(row.upperlimit):
+                    return row.converted_flux_err > np.log(10) / (3 * 2.5)
+                else:
+                    return row.upperlimit
+        else:
+
+            def is_upperlimit(row):
+                if "upperlimit" in row and pd.isna(row.upperlimit):
+                    return row.converted_flux < 3 * row.converted_flux_err
+                else:
+                    return row.upperlimit
+
+        outdata["upperlimit"] = outdata.apply(is_upperlimit, axis=1)
         return outdata
 
     def _merge_names(t1, t2, out):  # noqa: N805
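The magnitude threshold in the new upperlimit logic follows from the same propagation convention used earlier in this diff: with sigma_m = (ln(10)/2.5) * (sigma_F / F), the linear non-detection criterion F < 3 * sigma_F is equivalent to sigma_F / F > 1/3, i.e. sigma_m > (1/3) * (ln(10)/2.5), about 0.31 mag, which is why the magnitude cut never needs the measurement itself.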
@@ -831,11 +930,21 @@ class Transient(MutableMapping):
         Just keep whichever schema version is greater
         """
         key = "schema_version/value"
-        if int(t1[key]) > int(t2[key]):
+        if "comment" not in t1["schema_version"]:
+            t1["schema_version/comment"] = ""
+
+        if "comment" not in t2["schema_version"]:
+            t2["schema_version/comment"] = ""
+
+        if key in t1 and key in t2 and int(t1[key]) > int(t2[key]):
             out["schema_version"] = deepcopy(t1["schema_version"])
         else:
             out["schema_version"] = deepcopy(t2["schema_version"])
 
+        out["schema_version"]["comment"] = (
+            t1["schema_version/comment"] + ";" + t2["schema_version/comment"]
+        )
+
     def _merge_photometry(t1, t2, out):  # noqa: N805
         """
         Combine photometry sources
@@ -913,7 +1022,7 @@ class Transient(MutableMapping):
                 item["default"] = False
 
     @staticmethod
-    def _merge_arbitrary(key, t1, t2, out):
+    def _merge_arbitrary(key, t1, t2, out, merge_subkeys=None, groupby_key=None):
         """
         Merge two arbitrary datasets inside the json file using pandas
 
@@ -940,37 +1049,62 @@ class Transient(MutableMapping):
 
         # have to get the indexes to drop using a string rep of the df
         # this is cause we have lists in some cells
-        to_drop = merged_with_dups.astype(str).drop_duplicates().index
-
-        merged = merged_with_dups.iloc[to_drop].reset_index(drop=True)
-
-        outdict = merged.to_dict(orient="records")
+        # We also need to deal with merging the lists of references across rows
+        # that we deem to be duplicates. This solution to do this quickly is from
+        # https://stackoverflow.com/questions/36271413/ \
+        # pandas-merge-nearly-duplicate-rows-based-on-column-value
+        if merge_subkeys is None:
+            merge_subkeys = merged_with_dups.columns.tolist()
+            merge_subkeys.remove("reference")
+        else:
+            for k in merge_subkeys:
+                if k not in merged_with_dups:
+                    merge_subkeys.remove(k)
+
+        merged = (
+            merged_with_dups.astype(str)
+            .groupby(merge_subkeys)["reference"]
+            .apply(lambda x: x.sum())
+            .reset_index()
+        )
 
-        outdict_cleaned = Transient._remove_nans(
-            outdict
-        )  # clear out the nans from pandas conversion
+        # then we have to turn the merged reference strings into a string list
+        merged["reference"] = merged.reference.str.replace("][", ",")
 
-        out[key] = outdict_cleaned
+        # then eval the string of a list to get back an actual list of sources
+        merged["reference"] = merged.reference.apply(
+            lambda v: np.unique(eval(v)).tolist()
+        )
 
-    @staticmethod
-    def _remove_nans(d):
-        """
-        Remove nans from a record dictionary
+        # decide on default values
+        if groupby_key is None:
+            iterate_through = [(0, merged)]
+        else:
+            iterate_through = merged.groupby(groupby_key)
+
+        # we will make whichever value has more references the default
+        outdict = []
+        for data_type, df in iterate_through:
+            lengths = df.reference.map(len)
+            max_idx_arr = np.argmax(lengths)
+
+            if isinstance(max_idx_arr, np.int64):
+                max_idx = max_idx_arr
+            elif len(max_idx_arr) == 0:
+                raise ValueError("Something went wrong with deciding the default")
+            else:
+                max_idx = max_idx_arr[0]  # arbitrarily choose the first
 
-        THIS IS SLOW: O(n^2)!!! WILL NEED TO BE SPED UP LATER
-        """
+            defaults = np.full(len(df), False, dtype=bool)
+            defaults[max_idx] = True
 
-        outd = []
-        for item in d:
-            outsubd = {}
-            for key, val in item.items():
-                if not isinstance(val, float):
-                    # this definitely is not NaN
-                    outsubd[key] = val
+            df["default"] = defaults
+            outdict.append(df)
+        outdict = pd.concat(outdict)
 
-                else:
-                    if not np.isnan(val):
-                        outsubd[key] = val
-            outd.append(outsubd)
+        # from https://stackoverflow.com/questions/52504972/ \
+        # converting-a-pandas-df-to-json-without-nan
+        outdict = outdict.replace("nan", np.nan)
+        outdict_cleaned = [{**x[i]} for i, x in outdict.stack().groupby(level=0)]
 
-        return outd
+        out[key] = outdict_cleaned
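A minimal, self-contained sketch (toy rows, invented reference strings) of the deduplication strategy above: rows that agree on every merge subkey are collapsed, their stringified reference lists are concatenated, patched back into a single list literal, and de-duplicated:

import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        "value": [0.05, 0.05, 0.07],
        "distance_type": ["redshift"] * 3,
        "reference": [["2017ApJ...1B"], ["2020MNRAS..2C"], ["2020MNRAS..2C"]],
    }
)

merged = (
    df.astype(str)
    .groupby(["value", "distance_type"])["reference"]
    .apply(lambda x: x.sum())  # string concatenation: "['a']['b']"
    .reset_index()
)
merged["reference"] = merged.reference.str.replace("][", ",", regex=False)
merged["reference"] = merged.reference.apply(lambda v: np.unique(eval(v)).tolist())
print(merged)  # the 0.05 row keeps both references; the 0.07 row keeps one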
@@ -39,7 +39,7 @@ class OtterPlotter:
         elif self.backend == "plotly.graph_objects":
             self.plot = self._plot_plotly
         else:
-            raise ValueError("Unknown backend!")
+            raise ValueError("Unknown plotting backend!")
 
     def _plot_matplotlib(self, x, y, xerr=None, yerr=None, ax=None, **kwargs):
         """
@@ -53,17 +53,19 @@ class OtterPlotter:
         ax.errorbar(x, y, xerr=xerr, yerr=yerr, **kwargs)
         return ax
 
-    def _plot_plotly(self, x, y, xerr=None, yerr=None, go=None, *args, **kwargs):
+    def _plot_plotly(self, x, y, xerr=None, yerr=None, ax=None, *args, **kwargs):
         """
         General plotting method using plotly, is called by _plotly_light_curve and
         _plotly_sed
         """
 
-        if go is None:
+        if ax is None:
             go = self.plotter.Figure()
+        else:
+            go = ax
 
         fig = go.add_scatter(
-            x=x, y=y, error_x=dict(array=xerr), error_y=dict(array=yerr)
+            x=x, y=y, error_x=dict(array=xerr), error_y=dict(array=yerr), **kwargs
        )
 
         return fig
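A standalone sketch of the pattern _plot_plotly now follows (function and variable names here are illustrative, not the package API): reuse a caller-supplied Figure when ax is given, otherwise create a fresh one, and forward styling kwargs to the trace:

import plotly.graph_objects as pgo

def plot_points(x, y, xerr=None, yerr=None, ax=None, **kwargs):
    # reuse the caller's figure if provided, mirroring matplotlib's ax idiom
    fig = pgo.Figure() if ax is None else ax
    fig.add_scatter(
        x=x, y=y, error_x=dict(array=xerr), error_y=dict(array=yerr), **kwargs
    )
    return fig

fig = plot_points([1, 2, 3], [2.0, 3.1, 2.7], yerr=[0.1, 0.2, 0.15], mode="markers")
fig = plot_points([1, 2, 3], [1.5, 2.4, 3.3], ax=fig, mode="lines")  # layers onto the same figure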