astro-otter 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


otter/io/transient.py CHANGED
@@ -209,6 +209,36 @@ class Transient(MutableMapping):
             "host",
         }
 
+        merge_subkeys_map = {
+            "name": None,
+            "date_reference": ["value", "date_format", "date_type"],
+            "coordinate": None,  # may need to update this if we run into problems
+            "distance": ["value", "distance_type", "unit"],
+            "filter_alias": None,
+            "schema_version": None,
+            "photometry": None,
+            "classification": None,
+            "host": [
+                "host_ra",
+                "host_dec",
+                "host_ra_units",
+                "host_dec_units",
+                "host_name",
+            ],
+        }
+
+        groupby_key_for_default_map = {
+            "name": None,
+            "date_reference": "date_type",
+            "coordinate": "coordinate_type",
+            "distance": "distance_type",
+            "filter_alias": None,
+            "schema_version": None,
+            "photometry": None,
+            "classification": None,
+            "host": None,
+        }
+
         # create a blank dictionary since we don't want to overwrite this object
         out = {}
 
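The two new maps drive the generalized merge below: merge_subkeys_map lists, per top-level key, the columns that must agree for two rows to count as duplicates (None falls back to every column except reference), and groupby_key_for_default_map names the column whose groups each get their own default row. A minimal sketch of the duplicate rule, with hypothetical data (pandas concatenates list cells when summing, which is what merges the reference lists):

    import pandas as pd

    # two date_reference rows that agree on the merge subkeys
    # ["value", "date_format", "date_type"] but come from different sources
    rows = pd.DataFrame([
        {"value": "2020-01-01", "date_format": "iso", "date_type": "discovery",
         "reference": ["sourceA"]},
        {"value": "2020-01-01", "date_format": "iso", "date_type": "discovery",
         "reference": ["sourceB"]},
    ])
    subkeys = ["value", "date_format", "date_type"]
    merged = rows.groupby(subkeys)["reference"].sum().reset_index()
    print(merged.reference.iloc[0])  # ['sourceA', 'sourceB'] -- one row, both refs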
@@ -244,14 +274,19 @@ class Transient(MutableMapping):
 
             # There are some special keys that we are expecting
             if key in allowed_keywords:
-                Transient._merge_arbitrary(key, self, other, out)
+                Transient._merge_arbitrary(
+                    key,
+                    self,
+                    other,
+                    out,
+                    merge_subkeys=merge_subkeys_map[key],
+                    groupby_key=groupby_key_for_default_map[key],
+                )
             else:
                 # this is an unexpected key!
                 if strict_merge:
                     # since this is a strict merge we don't want unexpected data!
-                    raise TransientMergeError(
-                        f"{key} was not expected! Only keeping the old information!"
-                    )
+                    raise TransientMergeError(f"{key} was not expected! Can not merge!")
                 else:
                     # Throw a warning and only keep the old stuff
                     warnings.warn(
@@ -332,7 +367,14 @@ class Transient(MutableMapping):
             astropy.time.Time of the default discovery date
         """
         key = "date_reference"
-        date = self._get_default(key, filt='df["date_type"] == "discovery"')
+        try:
+            date = self._get_default(key, filt='df["date_type"] == "discovery"')
+        except KeyError:
+            return None
+
+        if date is None:
+            return date
+
         if "date_format" in date:
             f = date["date_format"]
         else:
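With this change get_discovery_date degrades gracefully: missing or empty date_reference data now yields None instead of raising a KeyError. A short usage sketch (the transient instance here is hypothetical):

    date = transient.get_discovery_date()
    if date is None:
        print("no discovery date recorded")
    else:
        print(date.mjd)  # an astropy.time.Time, so .mjd/.jd/.iso all work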
@@ -370,11 +412,13 @@ class Transient(MutableMapping):
             return default
         return default.object_class, default.confidence, default.reference
 
-    def get_host(self, max_hosts=3, **kwargs) -> list[Host]:
+    def get_host(self, max_hosts=3, search=False, **kwargs) -> list[Host]:
         """
         Gets the default host information of this Transient. This returns an otter.Host
-        object. If no host is known in OTTER, it uses astro-ghost to find the best
-        match.
+        object. If search=True, it will also check the BLAST host association database
+        for the best match and return it as well. Note that if search is True then
+        this has the potential to return max_hosts + 1, if BLAST also returns a result.
+        The BLAST result will always be the last value in the returned list.
 
         Args:
             max_hosts [int] : The maximum number of hosts to return
@@ -385,38 +429,25 @@ class Transient(MutableMapping):
             useful methods for querying public catalogs for data of the host.
         """
         # first try to get the host information from our local database
+        host = []
         if "host" in self:
-            host = [
-                Host(transient_name=self.default_name, **dict(h)) for h in self["host"]
-            ]
+            max_hosts = min([max_hosts, len(self["host"])])
+            for h in self["host"][:max_hosts]:
+                host.append(Host(transient_name=self.default_name, **dict(h)))
 
-        # then try astro-ghost
-        else:
+        # then try BLAST
+        if search:
             logger.warn(
-                "No host known, trying to find it with astro-ghost. \
-                See https://uiucsnastro-ghost.readthedocs.io/en/latest/index.html"
-            )
-
-            # this import has to be here otherwise the code breaks
-            from astro_ghost.ghostHelperFunctions import getTransientHosts, getGHOST
-
-            getGHOST(real=False, verbose=1)
-            res = getTransientHosts(
-                [self.default_name], [self.get_skycoord()], verbose=False
+                "Trying to find a host with BLAST/astro-ghost. Note\
+                that this won't work for older targets! See https://blast.scimma.org"
             )
 
-            host = [
-                Host(
-                    host_ra=row["raStack"],
-                    host_dec=row["decStack"],
-                    host_ra_units="deg",
-                    host_dec_units="deg",
-                    host_name=row["objName"],
-                    transient_name=self.default_name,
-                    reference=["astro-ghost"],
-                )
-                for i, row in res.iterrows()
-            ]
+            # default_name should always be the TNS name if we have one
+            print(self.default_name)
+            blast_host = Host.query_blast(self.default_name)
+            print(blast_host)
+            if blast_host is not None:
+                host.append(blast_host)
 
         return host
 
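Usage sketch for the new search flag: the list holds up to max_hosts entries from OTTER itself, plus, when search=True and BLAST has a match, one more Host appended at the end (attribute names here are assumed from the constructor kwargs above):

    hosts = transient.get_host(max_hosts=2, search=True)
    print(len(hosts))  # at most max_hosts + 1 = 3
    for h in hosts:
        print(h.host_name, h.host_ra, h.host_dec)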
@@ -433,6 +464,9 @@ class Transient(MutableMapping):
             raise KeyError(f"This transient does not have {key} associated with it!")
 
         df = pd.DataFrame(self[key])
+        if len(df) == 0:
+            raise KeyError(f"This transient does not have {key} associated with it!")
+
         if filt is not None:
             df = df[eval(filt)]  # apply the filters
 
@@ -446,6 +480,7 @@ class Transient(MutableMapping):
 
         if len(df_filtered) == 0:
             return None
+
         return df_filtered.iloc[0]
 
     def _reformat_coordinate(self, item):
@@ -515,6 +550,9 @@ class Transient(MutableMapping):
             raise IOError("Please choose either value or raw!")
 
         # turn the photometry key into a pandas dataframe
+        if "photometry" not in self:
+            raise FailedQueryError("No photometry for this object!")
+
         dfs = []
         for item in self["photometry"]:
             max_len = 0
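Callers can now distinguish "no photometry" from a real failure by catching FailedQueryError up front instead of hitting an opaque pd.concat error later. A sketch, assuming this method is the package's public photometry cleaner and that FailedQueryError is importable from the package's exception module (both are assumptions):

    from otter.exceptions import FailedQueryError  # import path assumed

    try:
        phot = transient.clean_photometry()  # method name assumed
    except FailedQueryError:
        phot = None  # this object simply has no photometry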
@@ -531,9 +569,29 @@ class Transient(MutableMapping):
             df = pd.DataFrame(item)
             dfs.append(df)
 
+        if len(dfs) == 0:
+            raise FailedQueryError("No photometry for this object!")
         c = pd.concat(dfs)
 
+        # extract the filter information and substitute in any missing columns
+        # because of how we handle this later, we just need to make sure the effective
+        # wavelengths are never nan
+        def fill_wave(row):
+            if "wave_eff" not in row or (
+                pd.isna(row.wave_eff) and not pd.isna(row.freq_eff)
+            ):
+                freq_eff = row.freq_eff * u.Unit(row.freq_units)
+                wave_eff = freq_eff.to(u.Unit(wave_unit), equivalencies=u.spectral())
+                return wave_eff.value, wave_unit
+            elif not pd.isna(row.wave_eff):
+                return row.wave_eff, row.wave_units
+            else:
+                raise ValueError("Missing frequency or wavelength information!")
+
         filters = pd.DataFrame(self["filter_alias"])
+        res = filters.apply(fill_wave, axis=1)
+        filters["wave_eff"], filters["wave_units"] = zip(*res)
+        # merge the photometry with the filter information
         df = c.merge(filters, on="filter_key")
 
         # make sure 'by' is in df
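fill_wave leans on astropy's spectral equivalencies, which convert between frequency and wavelength through lambda = c / nu. A standalone check of the conversion it performs, with a made-up input:

    import astropy.units as u

    freq_eff = 6.3e14 * u.Hz  # roughly a g-band effective frequency
    wave_eff = freq_eff.to(u.AA, equivalencies=u.spectral())
    print(wave_eff)  # ~4758 Angstrom, since lambda = c / nu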
@@ -546,6 +604,14 @@ class Transient(MutableMapping):
         # skip rows where 'by' is nan
         df = df[df[by].notna()]
 
+        # remove rows where the flux is less than zero since this is nonphysical
+        # See Mummery et al. (2023) Section 5.2 for why we need to do this when using
+        # ZTF data:
+        # "Because the origin of the negative late-time flux is currently un-
+        # known (and under investigation), we have not attempted to correct
+        # the TDE lightcurves for this systematic effect."
+        df = df[df[by].astype(float) > 0]
+
         # drop irrelevant obs_types before continuing
         if obs_type is not None:
             valid_obs_types = {"radio", "uvoir", "xray"}
@@ -568,6 +634,7 @@ class Transient(MutableMapping):
 
         # Figure out what columns are good to groupby in the photometry
         outdata = []
+
         if "telescope" in df:
             tele = True
             to_grp_by = ["obs_type", by + "_units", "telescope"]
@@ -619,24 +686,22 @@ class Transient(MutableMapping):
                 indata_err = np.array(data[by + "_err"].astype(float))
             else:
                 indata_err = np.zeros(len(data))
+
+            # convert to an astropy quantity
             q = indata * u.Unit(astropy_units)
             q_err = indata_err * u.Unit(
                 astropy_units
             )  # assume error and values have the same unit
 
             # get and save the effective wavelength
-            if "freq_eff" in data and not np.isnan(data["freq_eff"].iloc[0]):
-                zz = zip(data["freq_eff"], data["freq_units"])
-                freq_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], freq_unit)
-                wave_eff = freq_eff.to(wave_unit, equivalencies=u.spectral())
+            # because of cleaning we did to the filter dataframe above wave_eff
+            # should NEVER be nan!
+            if np.any(pd.isna(data["wave_eff"])):
+                raise ValueError("Flushing out the effective wavelength array failed!")
 
-            elif "wave_eff" in data and not np.isnan(data["wave_eff"].iloc[0]):
-                zz = zip(data["wave_eff"], data["wave_units"])
-                wave_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], wave_unit)
-                freq_eff = wave_eff.to(freq_unit, equivalencies=u.spectral())
-
-            else:
-                raise ValueError("No known frequency or wavelength, please fix!")
+            zz = zip(data["wave_eff"], data["wave_units"])
+            wave_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], wave_unit)
+            freq_eff = wave_eff.to(freq_unit, equivalencies=u.spectral())
 
             data["converted_wave"] = wave_eff.value
             data["converted_wave_unit"] = wave_unit
@@ -656,7 +721,7 @@ class Transient(MutableMapping):
                     )
                 else:
                     raise OtterLimitationError(
-                        "Can not convert x-ray data without a " + "telescope"
+                        "Can not convert x-ray data without a telescope"
                     )
 
             # we also need to make this wave_min and wave_max
@@ -685,24 +750,32 @@ class Transient(MutableMapping):
                         u.Unit(flux_unit),
                         vegaspec=SourceSpectrum.from_vega(),
                         area=area,
-                    )
-                    f_err = convert_flux(
-                        wave,
-                        xray_point_err,
-                        u.Unit(flux_unit),
-                        vegaspec=SourceSpectrum.from_vega(),
-                        area=area,
-                    )
+                    ).value
+
+                    # approximate the uncertainty as dX = dY/Y * X
+                    f_err = np.multiply(
+                        f_val, np.divide(xray_point_err.value, xray_point.value)
+                    )
 
                     # then we take the average of the minimum and maximum values
                     # computed by syncphot
-                    flux.append(np.mean(f_val).value)
-                    flux_err.append(np.mean(f_err).value)
+                    flux.append(np.mean(f_val))
+                    flux_err.append(np.mean(f_err))
 
             else:
                 # this will be faster and cover most cases
-                flux = convert_flux(wave_eff, q, u.Unit(flux_unit))
-                flux_err = convert_flux(wave_eff, q_err, u.Unit(flux_unit))
+                flux = convert_flux(wave_eff, q, u.Unit(flux_unit)).value
+
+                # since the error propagation is different between logarithmic units
+                # and linear units, unfortunately
+                if isinstance(u.Unit(flux_unit), u.LogUnit):
+                    # approximate the uncertainty as dX = dY/Y * |ln(10)/2.5|
+                    prefactor = np.abs(np.log(10) / 2.5)  # this is basically 1
+                else:
+                    # approximate the uncertainty as dX = dY/Y * X
+                    prefactor = flux
+
+                flux_err = np.multiply(prefactor, np.divide(q_err.value, q.value))
 
         flux = np.array(flux) * u.Unit(flux_unit)
         flux_err = np.array(flux_err) * u.Unit(flux_unit)
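The dropped second convert_flux call is replaced by first-order error propagation: for a linear target unit the conversion is X = cY, so dX = X * (dY/Y); for a logarithmic (magnitude-like) unit the relative flux error maps to a roughly constant additive error, here ln(10)/2.5 ≈ 0.921 per unit relative error (note the textbook propagation of m = -2.5 log10 F gives 2.5/ln(10) ≈ 1.086, so both constants are "basically 1" as the comment says). A numeric sanity check:

    import numpy as np

    f, f_err = 2.0e-3, 1.0e-4      # a linear flux and its error, arbitrary units
    x = f * 1.0e3                  # convert to another linear unit
    x_err = x * (f_err / f)        # dX = X * dY/Y: relative error is preserved
    assert np.isclose(x_err / x, f_err / f)

    # logarithmic units: a 5% relative flux error is ~0.05 mag either way
    print(np.log(10) / 2.5 * 0.05)  # 0.0461, the prefactor used in the diff
    print(2.5 / np.log(10) * 0.05)  # 0.0543, standard propagation constant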
@@ -726,6 +799,27 @@ class Transient(MutableMapping):
         outdata["converted_date"] = times
         outdata["converted_date_unit"] = date_unit
 
+        # compute the upperlimit value based on a 3 sigma detection
+        # this is just for rows where we don't already know if it is an upperlimit
+        if isinstance(u.Unit(flux_unit), u.LogUnit):
+            # this uses the following formula (which is surprising because it means
+            # magnitude upperlimits are independent of the actual measurement!)
+            # sigma_m > (1/3) * (ln(10)/2.5)
+            def is_upperlimit(row):
+                if pd.isna(row.upperlimit):
+                    return row.converted_flux_err > np.log(10) / (3 * 2.5)
+                else:
+                    return row.upperlimit
+        else:
+
+            def is_upperlimit(row):
+                if pd.isna(row.upperlimit):
+                    return row.converted_flux < 3 * row.converted_flux_err
+                else:
+                    return row.upperlimit
+
+        outdata["upperlimit"] = outdata.apply(is_upperlimit, axis=1)
+
         return outdata
 
     def _merge_names(t1, t2, out):  # noqa: N805
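The magnitude branch follows from the linear 3-sigma rule plus the propagation used above: a detection requires F > 3*sigma_F, i.e. sigma_F/F < 1/3, and with sigma_m ≈ (ln(10)/2.5) * sigma_F/F that becomes sigma_m > (1/3)(ln(10)/2.5) ≈ 0.307 mag for an upper limit, independent of the measured magnitude itself. A quick check of both branches (column names mirror the diff; the data is made up):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({
        "converted_flux": [10.0, 2.0],
        "converted_flux_err": [1.0, 1.0],
        "upperlimit": [np.nan, np.nan],
    })
    # linear branch: flux below 3*err cannot be claimed as a 3 sigma detection
    print((df.converted_flux < 3 * df.converted_flux_err).tolist())  # [False, True]
    # magnitude branch: the threshold is a constant
    print(np.log(10) / (3 * 2.5))  # ~0.307 mag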
@@ -831,11 +925,21 @@ class Transient(MutableMapping):
         Just keep whichever schema version is greater
         """
         key = "schema_version/value"
-        if int(t1[key]) > int(t2[key]):
+        if "comment" not in t1["schema_version"]:
+            t1["schema_version/comment"] = ""
+
+        if "comment" not in t2["schema_version"]:
+            t2["schema_version/comment"] = ""
+
+        if key in t1 and key in t2 and int(t1[key]) > int(t2[key]):
             out["schema_version"] = deepcopy(t1["schema_version"])
         else:
             out["schema_version"] = deepcopy(t2["schema_version"])
 
+        out["schema_version"]["comment"] = (
+            t1["schema_version/comment"] + ";" + t2["schema_version/comment"]
+        )
+
     def _merge_photometry(t1, t2, out):  # noqa: N805
         """
         Combine photometry sources
@@ -913,7 +1017,7 @@ class Transient(MutableMapping):
                 item["default"] = False
 
     @staticmethod
-    def _merge_arbitrary(key, t1, t2, out):
+    def _merge_arbitrary(key, t1, t2, out, merge_subkeys=None, groupby_key=None):
         """
         Merge two arbitrary datasets inside the json file using pandas
 
@@ -940,37 +1044,62 @@ class Transient(MutableMapping):
 
         # have to get the indexes to drop using a string rep of the df
         # this is cause we have lists in some cells
-        to_drop = merged_with_dups.astype(str).drop_duplicates().index
-
-        merged = merged_with_dups.iloc[to_drop].reset_index(drop=True)
-
-        outdict = merged.to_dict(orient="records")
+        # We also need to deal with merging the lists of references across rows
+        # that we deem to be duplicates. This solution to do this quickly is from
+        # https://stackoverflow.com/questions/36271413/ \
+        # pandas-merge-nearly-duplicate-rows-based-on-column-value
+        if merge_subkeys is None:
+            merge_subkeys = merged_with_dups.columns.tolist()
+            merge_subkeys.remove("reference")
+        else:
+            for k in merge_subkeys:
+                if k not in merged_with_dups:
+                    merge_subkeys.remove(k)
+
+        merged = (
+            merged_with_dups.astype(str)
+            .groupby(merge_subkeys)["reference"]
+            .apply(lambda x: x.sum())
+            .reset_index()
+        )
 
-        outdict_cleaned = Transient._remove_nans(
-            outdict
-        )  # clear out the nans from pandas conversion
+        # then we have to turn the merged reference strings into a string list
+        merged["reference"] = merged.reference.str.replace("][", ",")
 
-        out[key] = outdict_cleaned
+        # then eval the string of a list to get back an actual list of sources
+        merged["reference"] = merged.reference.apply(
+            lambda v: np.unique(eval(v)).tolist()
+        )
 
-    @staticmethod
-    def _remove_nans(d):
-        """
-        Remove nans from a record dictionary
+        # decide on default values
+        if groupby_key is None:
+            iterate_through = [(0, merged)]
+        else:
+            iterate_through = merged.groupby(groupby_key)
+
+        # we will make whichever value has more references the default
+        outdict = []
+        for data_type, df in iterate_through:
+            lengths = df.reference.map(len)
+            max_idx_arr = np.argmax(lengths)
+
+            if isinstance(max_idx_arr, np.int64):
+                max_idx = max_idx_arr
+            elif len(max_idx_arr) == 0:
+                raise ValueError("Something went wrong with deciding the default")
+            else:
+                max_idx = max_idx_arr[0]  # arbitrarily choose the first
 
-        THIS IS SLOW: O(n^2)!!! WILL NEED TO BE SPED UP LATER
-        """
+            defaults = np.full(len(df), False, dtype=bool)
+            defaults[max_idx] = True
 
-        outd = []
-        for item in d:
-            outsubd = {}
-            for key, val in item.items():
-                if not isinstance(val, float):
-                    # this definitely is not NaN
-                    outsubd[key] = val
+            df["default"] = defaults
+            outdict.append(df)
+        outdict = pd.concat(outdict)
 
-                else:
-                    if not np.isnan(val):
-                        outsubd[key] = val
-            outd.append(outsubd)
+        # from https://stackoverflow.com/questions/52504972/ \
+        # converting-a-pandas-df-to-json-without-nan
+        outdict = outdict.replace("nan", np.nan)
+        outdict_cleaned = [{**x[i]} for i, x in outdict.stack().groupby(level=0)]
 
-        return outd
+        out[key] = outdict_cleaned
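The heart of the rewrite is the string-cast groupby: astype(str) makes list-valued cells hashable, rows that agree on every subkey collapse into one group, and the stringified reference lists concatenate into "[...][...]", which the replace/eval pair turns back into a single deduplicated list. A self-contained toy run of that pipeline (note that str.replace needs regex=False on recent pandas, since "][" is not a valid regex):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame([
        {"value": 0.051, "distance_type": "redshift", "unit": None,
         "reference": ["srcA"]},
        {"value": 0.051, "distance_type": "redshift", "unit": None,
         "reference": ["srcB", "srcA"]},
    ])
    subkeys = ["value", "distance_type", "unit"]

    merged = (
        df.astype(str)                  # list cells -> strings, rows now hashable
        .groupby(subkeys)["reference"]
        .apply(lambda x: x.sum())       # "['srcA']" + "['srcB', 'srcA']"
        .reset_index()
    )
    merged["reference"] = merged.reference.str.replace("][", ",", regex=False)
    merged["reference"] = merged.reference.apply(lambda v: np.unique(eval(v)).tolist())
    print(merged.reference.iloc[0])     # ['srcA', 'srcB']

The later stack()/groupby(level=0) round trip rebuilds the records while silently dropping NaN cells (stack drops them by default), which is what replaces the old O(n^2) _remove_nans helper.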
@@ -39,7 +39,7 @@ class OtterPlotter:
         elif self.backend == "plotly.graph_objects":
             self.plot = self._plot_plotly
         else:
-            raise ValueError("Unknown backend!")
+            raise ValueError("Unknown plotting backend!")
 
     def _plot_matplotlib(self, x, y, xerr=None, yerr=None, ax=None, **kwargs):
         """
@@ -53,17 +53,19 @@ class OtterPlotter:
         ax.errorbar(x, y, xerr=xerr, yerr=yerr, **kwargs)
         return ax
 
-    def _plot_plotly(self, x, y, xerr=None, yerr=None, go=None, *args, **kwargs):
+    def _plot_plotly(self, x, y, xerr=None, yerr=None, ax=None, *args, **kwargs):
         """
         General plotting method using plotly, is called by _plotly_light_curve and
         _plotly_sed
         """
 
-        if go is None:
+        if ax is None:
             go = self.plotter.Figure()
+        else:
+            go = ax
 
         fig = go.add_scatter(
-            x=x, y=y, error_x=dict(array=xerr), error_y=dict(array=yerr)
+            x=x, y=y, error_x=dict(array=xerr), error_y=dict(array=yerr), **kwargs
         )
 
         return fig
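With the go → ax rename both backends take an existing canvas through the same keyword, and **kwargs now actually reaches plotly. A minimal standalone sketch of the call pattern the wrapper forwards to:

    import plotly.graph_objects as go

    fig = go.Figure()
    fig.add_scatter(
        x=[1, 2, 3],
        y=[10.0, 9.5, 9.8],
        error_y=dict(array=[0.1, 0.2, 0.1]),  # symmetric y error bars
        mode="markers",                       # the kind of option **kwargs forwards
    )
    fig.show()  # renders in a notebook or browser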