astro-otter 0.0.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of astro-otter might be problematic.

otter/io/transient.py CHANGED
@@ -9,6 +9,7 @@ from copy import deepcopy
 import re
 from collections.abc import MutableMapping
 from typing_extensions import Self
+import logging

 import numpy as np
 import pandas as pd
@@ -17,9 +18,6 @@ import astropy.units as u
 from astropy.time import Time
 from astropy.coordinates import SkyCoord

-from synphot.units import VEGAMAG, convert_flux
-from synphot.spectrum import SourceSpectrum
-
 from ..exceptions import (
     FailedQueryError,
     IOError,
@@ -27,10 +25,12 @@ from ..exceptions import (
     TransientMergeError,
 )
 from ..util import XRAY_AREAS
+from .host import Host

 warnings.simplefilter("once", RuntimeWarning)
 warnings.simplefilter("once", UserWarning)
 np.seterr(divide="ignore")
+logger = logging.getLogger(__name__)


 class Transient(MutableMapping):
@@ -196,6 +196,49 @@ class Transient(MutableMapping):
                     + " You can set strict_merge=False to override the check"
                 )

+        # create set of the allowed keywords
+        allowed_keywords = {
+            "name",
+            "date_reference",
+            "coordinate",
+            "distance",
+            "filter_alias",
+            "schema_version",
+            "photometry",
+            "classification",
+            "host",
+        }
+
+        merge_subkeys_map = {
+            "name": None,
+            "date_reference": ["value", "date_format", "date_type"],
+            "coordinate": None,  # may need to update this if we run into problems
+            "distance": ["value", "distance_type", "unit"],
+            "filter_alias": None,
+            "schema_version": None,
+            "photometry": None,
+            "classification": None,
+            "host": [
+                "host_ra",
+                "host_dec",
+                "host_ra_units",
+                "host_dec_units",
+                "host_name",
+            ],
+        }
+
+        groupby_key_for_default_map = {
+            "name": None,
+            "date_reference": "date_type",
+            "coordinate": "coordinate_type",
+            "distance": "distance_type",
+            "filter_alias": None,
+            "schema_version": None,
+            "photometry": None,
+            "classification": None,
+            "host": None,
+        }
+
         # create a blank dictionary since we don't want to overwrite this object
         out = {}

@@ -230,31 +273,20 @@ class Transient(MutableMapping):
                 continue

             # There are some special keys that we are expecting
-            if key == "name":
-                self._merge_names(other, out)
-            elif key == "coordinate":
-                self._merge_coords(other, out)
-            elif key == "date_reference":
-                self._merge_date(other, out)
-            elif key == "distance":
-                self._merge_distance(other, out)
-            elif key == "filter_alias":
-                self._merge_filter_alias(other, out)
-            elif key == "schema_version":
-                self._merge_schema_version(other, out)
-            elif key == "photometry":
-                self._merge_photometry(other, out)
-            elif key == "spectra":
-                self._merge_spectra(other, out)
-            elif key == "classification":
-                self._merge_class(other, out)
+            if key in allowed_keywords:
+                Transient._merge_arbitrary(
+                    key,
+                    self,
+                    other,
+                    out,
+                    merge_subkeys=merge_subkeys_map[key],
+                    groupby_key=groupby_key_for_default_map[key],
+                )
             else:
                 # this is an unexpected key!
                 if strict_merge:
                     # since this is a strict merge we don't want unexpected data!
-                    raise TransientMergeError(
-                        f"{key} was not expected! Only keeping the old information!"
-                    )
+                    raise TransientMergeError(f"{key} was not expected! Can not merge!")
                 else:
                     # Throw a warning and only keep the old stuff
                     warnings.warn(
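
The two hunks above replace nine hard-coded `_merge_*` branches with a table-driven design: `merge_subkeys_map` names the columns that identify a duplicate record for each key, and `groupby_key_for_default_map` names the column used to pick a default per group. A minimal standalone sketch of the pattern (the `merge_records` helper and the sample records are illustrative, not astro-otter's API):

```python
# Sketch of the table-driven merge: deduplicate two record lists on the
# subkeys that identify a duplicate. merge_records() is hypothetical.
merge_subkeys_map = {"distance": ["value", "distance_type", "unit"]}

def merge_records(records, merge_subkeys):
    seen, merged = set(), []
    for rec in records:
        sig = tuple(rec.get(k) for k in merge_subkeys)
        if sig not in seen:  # keep the first copy of each signature
            seen.add(sig)
            merged.append(rec)
    return merged

old = [{"value": 0.0206, "distance_type": "redshift", "unit": None}]
new = [
    {"value": 0.0206, "distance_type": "redshift", "unit": None},  # duplicate
    {"value": 90.4, "distance_type": "luminosity", "unit": "Mpc"},
]
print(merge_records(old + new, merge_subkeys_map["distance"]))
# the duplicate redshift collapses; two records survive
```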
@@ -335,13 +367,20 @@ class Transient(MutableMapping):
             astropy.time.Time of the default discovery date
         """
         key = "date_reference"
-        date = self._get_default(key, filt='df["date_type"] == "discovery"')
+        try:
+            date = self._get_default(key, filt='df["date_type"] == "discovery"')
+        except KeyError:
+            return None
+
+        if date is None:
+            return date
+
         if "date_format" in date:
             f = date["date_format"]
         else:
             f = "mjd"

-        return Time(date["value"], format=f)
+        return Time(str(date["value"]).strip(), format=f)

     def get_redshift(self) -> float:
         """
@@ -357,7 +396,62 @@ class Transient(MutableMapping):
         else:
             return default["value"]

-    def _get_default(self, key, filt=""):
+    def get_classification(self) -> tuple[str, float, list]:
+        """
+        Get the default classification of this Transient.
+        This normally corresponds to the highest confidence classification that we
+        have stored for the transient.
+
+        Returns:
+            The default object class as a string, the confidence level in that class,
+            and a list of the bibcodes corresponding to that classification. Or, None
+            if there is no classification.
+        """
+        default = self._get_default("classification")
+        if default is None:
+            return default
+        return default.object_class, default.confidence, default.reference
+
+    def get_host(self, max_hosts=3, search=False, **kwargs) -> list[Host]:
+        """
+        Gets the default host information of this Transient. This returns an
+        otter.Host object. If search=True, it will also check the BLAST host
+        association database for the best match and return it as well. Note that if
+        search is True then this has the potential to return max_hosts + 1, if BLAST
+        also returns a result. The BLAST result will always be the last value in the
+        returned list.
+
+        Args:
+            max_hosts [int] : The maximum number of hosts to return
+            **kwargs : keyword arguments to be passed to getGHOST
+
+        Returns:
+            A list of otter.Host objects. This is useful because the Host objects
+            have useful methods for querying public catalogs for data of the host.
+        """
+        # first try to get the host information from our local database
+        host = []
+        if "host" in self:
+            max_hosts = min([max_hosts, len(self["host"])])
+            for h in self["host"][:max_hosts]:
+                host.append(Host(transient_name=self.default_name, **dict(h)))
+
+        # then try BLAST
+        if search:
+            logger.warning(
+                "Trying to find a host with BLAST/astro-ghost. Note"
+                " that this won't work for older targets! See https://blast.scimma.org"
+            )
+
+            # default_name should always be the TNS name if we have one
+            print(self.default_name)
+            blast_host = Host.query_blast(self.default_name)
+            print(blast_host)
+            if blast_host is not None:
+                host.append(blast_host)
+
+        return host
+
+    def _get_default(self, key, filt=None):
         """
         Get the default of key

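For context, here is how the new accessors might be exercised. The record below is a hand-built guess at the schema these methods read (the `classification` and `host` entries are inferred from the diff, and the constructor signature is assumed), so it may not run verbatim against a real OTTER install:

```python
# Hedged usage sketch for get_classification()/get_host(); all field
# values are invented for illustration.
t = Transient({
    "name": {"default_name": "ASASSN-14li", "alias": []},
    "classification": [
        {"object_class": "TDE", "confidence": 1.0,
         "reference": ["2016MNRAS.455.2918H"], "default": True},
    ],
    "host": [
        {"host_name": "PGC 043234", "host_ra": 192.0633, "host_dec": 17.774,
         "host_ra_units": "deg", "host_dec_units": "deg"},
    ],
})

print(t.get_classification())   # expected: ("TDE", 1.0, ["2016MNRAS.455.2918H"])
print(t.get_host(max_hosts=1))  # expected: [Host(...)] built from the "host" entry
```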
@@ -370,7 +464,11 @@
             raise KeyError(f"This transient does not have {key} associated with it!")

         df = pd.DataFrame(self[key])
-        df = df[eval(filt)]  # apply the filters
+        if len(df) == 0:
+            raise KeyError(f"This transient does not have {key} associated with it!")
+
+        if filt is not None:
+            df = df[eval(filt)]  # apply the filters

         if "default" in df:
             # first try to get the default
@@ -382,6 +480,7 @@

         if len(df_filtered) == 0:
             return None
+
         return df_filtered.iloc[0]

     def _reformat_coordinate(self, item):
@@ -441,12 +540,19 @@
         Returns:
             A pandas DataFrame of the cleaned up photometry in the requested units
         """
+        # these imports need to be here for some reason
+        # otherwise the code breaks
+        from synphot.units import VEGAMAG, convert_flux
+        from synphot.spectrum import SourceSpectrum

         # check inputs
         if by not in {"value", "raw"}:
             raise IOError("Please choose either value or raw!")

         # turn the photometry key into a pandas dataframe
+        if "photometry" not in self:
+            raise FailedQueryError("No photometry for this object!")
+
         dfs = []
         for item in self["photometry"]:
             max_len = 0
@@ -463,9 +569,29 @@
             df = pd.DataFrame(item)
             dfs.append(df)

+        if len(dfs) == 0:
+            raise FailedQueryError("No photometry for this object!")
         c = pd.concat(dfs)

+        # extract the filter information and substitute in any missing columns
+        # because of how we handle this later, we just need to make sure the effective
+        # wavelengths are never nan
+        def fill_wave(row):
+            if "wave_eff" not in row or (
+                pd.isna(row.wave_eff) and not pd.isna(row.freq_eff)
+            ):
+                freq_eff = row.freq_eff * u.Unit(row.freq_units)
+                wave_eff = freq_eff.to(u.Unit(wave_unit), equivalencies=u.spectral())
+                return wave_eff.value, wave_unit
+            elif not pd.isna(row.wave_eff):
+                return row.wave_eff, row.wave_units
+            else:
+                raise ValueError("Missing frequency or wavelength information!")
+
         filters = pd.DataFrame(self["filter_alias"])
+        res = filters.apply(fill_wave, axis=1)
+        filters["wave_eff"], filters["wave_units"] = zip(*res)
+        # merge the photometry with the filter information
         df = c.merge(filters, on="filter_key")

         # make sure 'by' is in df
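
`fill_wave` leans on astropy's `u.spectral()` equivalency to backfill a missing effective wavelength from an effective frequency. A standalone check of that conversion:

```python
# Wavelength <-> frequency via u.spectral(); this is the same equivalency
# fill_wave uses to guarantee wave_eff is never NaN.
import astropy.units as u

freq_eff = 1.4 * u.GHz  # e.g. an effective radio frequency
wave_eff = freq_eff.to(u.AA, equivalencies=u.spectral())
print(wave_eff)  # ~2.141e+09 Angstrom

# and the round trip recovers the frequency
print(wave_eff.to(u.GHz, equivalencies=u.spectral()))  # 1.4 GHz
```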
@@ -478,6 +604,14 @@
         # skip rows where 'by' is nan
         df = df[df[by].notna()]

+        # remove rows where the flux is less than zero since this is nonphysical
+        # See Mummery et al. (2023) Section 5.2 for why we need to do this when using
+        # ZTF data:
+        # "Because the origin of the negative late-time flux is currently un-
+        # known (and under investigation), we have not attempted to correct
+        # the TDE lightcurves for this systematic effect."
+        df = df[df[by].astype(float) > 0]
+
         # drop irrelevant obs_types before continuing
         if obs_type is not None:
             valid_obs_types = {"radio", "uvoir", "xray"}
@@ -500,6 +634,7 @@

         # Figure out what columns are good to groupby in the photometry
         outdata = []
+
         if "telescope" in df:
             tele = True
             to_grp_by = ["obs_type", by + "_units", "telescope"]
@@ -523,8 +658,9 @@
                 )

             unit = unit[0]
+            isvegamag = "vega" in unit.lower()
             try:
-                if "vega" in unit.lower():
+                if isvegamag:
                     astropy_units = VEGAMAG
                 else:
                     astropy_units = u.Unit(unit)
@@ -550,30 +686,27 @@
                 indata_err = np.array(data[by + "_err"].astype(float))
             else:
                 indata_err = np.zeros(len(data))
+
+            # convert to an astropy quantity
             q = indata * u.Unit(astropy_units)
             q_err = indata_err * u.Unit(
                 astropy_units
             )  # assume error and values have the same unit

-            # get the effective wavelength
-            if "freq_eff" in data and not np.isnan(data["freq_eff"].iloc[0]):
-                freq_units = data["freq_units"]
-                if len(np.unique(freq_units)) > 1:
-                    raise OtterLimitationError(
-                        "Can not convert different units to the same unit!"
-                    )
-
-                freq_eff = np.array(data["freq_eff"]) * u.Unit(freq_units.iloc[0])
-                wave_eff = freq_eff.to(u.AA, equivalencies=u.spectral())
+            # get and save the effective wavelength
+            # because of cleaning we did to the filter dataframe above wave_eff
+            # should NEVER be nan!
+            if np.any(pd.isna(data["wave_eff"])):
+                raise ValueError("Flushing out the effective wavelength array failed!")

-            elif "wave_eff" in data and not np.isnan(data["wave_eff"].iloc[0]):
-                wave_units = data["wave_units"]
-                if len(np.unique(wave_units)) > 1:
-                    raise OtterLimitationError(
-                        "Can not convert different units to the same unit!"
-                    )
+            zz = zip(data["wave_eff"], data["wave_units"])
+            wave_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], wave_unit)
+            freq_eff = wave_eff.to(freq_unit, equivalencies=u.spectral())

-                wave_eff = np.array(data["wave_eff"]) * u.Unit(wave_units.iloc[0])
+            data["converted_wave"] = wave_eff.value
+            data["converted_wave_unit"] = wave_unit
+            data["converted_freq"] = freq_eff.value
+            data["converted_freq_unit"] = freq_unit

             # convert using synphot
             # stuff has to be done slightly differently for xray than for the others
@@ -588,44 +721,61 @@
                     )
                 else:
                     raise OtterLimitationError(
-                        "Can not convert x-ray data without a " + "telescope"
+                        "Can not convert x-ray data without a telescope"
                     )

                 # we also need to make this wave_min and wave_max
                 # instead of just the effective wavelength like for radio and uvoir
-                wave_eff = np.array(
-                    list(zip(data["wave_min"], data["wave_max"]))
-                ) * u.Unit(wave_units.iloc[0])
+                zz = zip(data["wave_min"], data["wave_max"], data["wave_units"])
+                wave_eff = u.Quantity(
+                    [np.array([m, M]) * u.Unit(uu) for m, M, uu in zz],
+                    u.Unit(wave_unit),
+                )

             else:
                 area = None

-            # we unfortunately have to loop over the points here because
-            # syncphot does not work with a 2D array of min max wavelengths
-            # for converting counts to other flux units. It also can't convert
-            # vega mags with a wavelength array because it then interprets that as the
-            # wavelengths corresponding to the SourceSpectrum.from_vega()
-            flux, flux_err = [], []
-            for wave, xray_point, xray_point_err in zip(wave_eff, q, q_err):
-                f_val = convert_flux(
-                    wave,
-                    xray_point,
-                    u.Unit(flux_unit),
-                    vegaspec=SourceSpectrum.from_vega(),
-                    area=area,
-                )
-                f_err = convert_flux(
-                    wave,
-                    xray_point_err,
-                    u.Unit(flux_unit),
-                    vegaspec=SourceSpectrum.from_vega(),
-                    area=area,
-                )
+            if obstype == "xray" or isvegamag:
+                # we unfortunately have to loop over the points here because
+                # syncphot does not work with a 2D array of min max wavelengths
+                # for converting counts to other flux units. It also can't convert
+                # vega mags with a wavelength array because it interprets that as the
+                # wavelengths corresponding to the SourceSpectrum.from_vega()

+                flux, flux_err = [], []
+                for wave, xray_point, xray_point_err in zip(wave_eff, q, q_err):
+                    f_val = convert_flux(
+                        wave,
+                        xray_point,
+                        u.Unit(flux_unit),
+                        vegaspec=SourceSpectrum.from_vega(),
+                        area=area,
+                    ).value
+
+                    # approximate the uncertainty as dX = dY/Y * X
+                    f_err = np.multiply(
+                        f_val, np.divide(xray_point_err.value, xray_point.value)
+                    )
+
+                    # then we take the average of the minimum and maximum values
+                    # computed by syncphot
+                    flux.append(np.mean(f_val))
+                    flux_err.append(np.mean(f_err))

-                # then we take the average of the minimum and maximum values
-                # computed by syncphot
-                flux.append(np.mean(f_val).value)
-                flux_err.append(np.mean(f_err).value)
+            else:
+                # this will be faster and cover most cases
+                flux = convert_flux(wave_eff, q, u.Unit(flux_unit)).value
+
+                # since the error propagation is different between logarithmic units
+                # and linear units, unfortunately
+                if isinstance(u.Unit(flux_unit), u.LogUnit):
+                    # approximate the uncertainty as dX = dY/Y * |ln(10)/2.5|
+                    prefactor = np.abs(np.log(10) / 2.5)  # this is basically 1
+                else:
+                    # approximate the uncertainty as dX = dY/Y * X
+                    prefactor = flux
+
+                flux_err = np.multiply(prefactor, np.divide(q_err.value, q.value))

             flux = np.array(flux) * u.Unit(flux_unit)
             flux_err = np.array(flux_err) * u.Unit(flux_unit)
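
The uncertainty handling above follows the two first-order rules stated in the diff's comments: for linear output units dX = (dY/Y) * X, and for logarithmic (magnitude-like) units dX = (dY/Y) * |ln(10)/2.5|, a constant relative-to-absolute factor. A NumPy-only sketch with made-up numbers:

```python
# First-order error propagation as used in the diff (values invented).
import numpy as np

q = np.array([2.0e-26, 5.0e-27])      # input fluxes, arbitrary linear unit
q_err = np.array([2.0e-27, 1.0e-27])  # their 1-sigma uncertainties

# linear output unit: the relative error carries over, dX = (dY/Y) * X
flux = q * 1.0e26                # pretend the unit conversion rescaled values
flux_err = flux * (q_err / q)    # -> [0.2, 0.1]

# logarithmic output unit: absolute error is (dY/Y) * |ln 10 / 2.5|,
# independent of the converted value itself
prefactor = np.abs(np.log(10) / 2.5)  # ~0.921, "basically 1"
mag_err = prefactor * (q_err / q)     # -> [0.092, 0.184]
print(flux_err, mag_err)
```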
@@ -639,7 +789,7 @@
         outdata = pd.concat(outdata)

         # copy over the flux units
-        outdata["converted_flux_unit"] = [flux_unit] * len(outdata)
+        outdata["converted_flux_unit"] = flux_unit

         # make sure all the datetimes are in the same format here too!!
         times = [
@@ -647,27 +797,28 @@
             for d, f in zip(outdata.date, outdata.date_format.str.lower())
         ]
         outdata["converted_date"] = times
-        outdata["converted_date_unit"] = [date_unit] * len(outdata)
-
-        # same with frequencies and wavelengths
-        freqs = []
-        waves = []
-
-        for _, row in df.iterrows():
-            if "freq_eff" in row and not np.isnan(row["freq_eff"]):
-                val = row["freq_eff"] * u.Unit(row["freq_units"])
-            elif "wave_eff" in df and not np.isnan(row["wave_eff"]):
-                val = row["wave_eff"] * u.Unit(row["wave_units"])
-            else:
-                raise ValueError("No known frequency or wavelength, please fix!")
+        outdata["converted_date_unit"] = date_unit
+
+        # compute the upperlimit value based on a 3 sigma detection
+        # this is just for rows where we don't already know if it is an upperlimit
+        if isinstance(u.Unit(flux_unit), u.LogUnit):
+            # this uses the following formula (which is surprising because it means
+            # magnitude upperlimits are independent of the actual measurement!)
+            # sigma_m > (1/3) * (ln(10)/2.5)
+            def is_upperlimit(row):
+                if pd.isna(row.upperlimit):
+                    return row.converted_flux_err > np.log(10) / (3 * 2.5)
+                else:
+                    return row.upperlimit
+        else:

-            freqs.append(val.to(freq_unit, equivalencies=u.spectral()).value)
-            waves.append(val.to(wave_unit, equivalencies=u.spectral()).value)
+            def is_upperlimit(row):
+                if pd.isna(row.upperlimit):
+                    return row.converted_flux < 3 * row.converted_flux_err
+                else:
+                    return row.upperlimit

-        outdata["converted_freq"] = freqs
-        outdata["converted_wave"] = waves
-        outdata["converted_wave_unit"] = [wave_unit] * len(outdata)
-        outdata["converted_freq_unit"] = [freq_unit] * len(outdata)
+        outdata["upperlimit"] = outdata.apply(is_upperlimit, axis=1)

         return outdata

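The two `is_upperlimit` branches encode the same 3-sigma criterion in different units. In linear units a point is an upper limit when flux < 3 * flux_err; pushing that through the magnitude error rule above (sigma_m = (ln 10 / 2.5) * sigma_F/F, as the diff's comments state) gives the constant threshold used in the log branch: sigma_m > (1/3) * (ln(10)/2.5) ≈ 0.307 mag. A small check with toy rows:

```python
# Checking the 3-sigma upper-limit rule from the hunk above (toy values).
import numpy as np
import pandas as pd

rows = pd.DataFrame({
    "converted_flux":     [10.0, 2.0],
    "converted_flux_err": [1.0, 1.0],
})

# linear-unit branch: flag anything below 3 sigma
print((rows.converted_flux < 3 * rows.converted_flux_err).tolist())
# -> [False, True]

# log-unit branch: the threshold is a constant magnitude error
print(np.log(10) / (3 * 2.5))  # ~0.307 mag, independent of the measurement
```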
@@ -756,16 +907,6 @@
         bothlines = [{"value": k, "reference": t1map[k] + t2map[k]} for k in inboth]
         out[key]["alias"] = line2 + line1 + bothlines

-    def _merge_coords(t1, t2, out):  # noqa: N805
-        """
-        Merge the coordinates subdictionaries for t1 and t2 and put it in out
-
-        Use pandas to drop any duplicates
-        """
-        key = "coordinate"
-
-        Transient._merge_arbitrary(key, t1, t2, out)
-
     def _merge_filter_alias(t1, t2, out):  # noqa: N805
         """
         Combine the filter alias lists across the transient objects
@@ -784,11 +925,21 @@
         Just keep whichever schema version is greater
         """
         key = "schema_version/value"
-        if int(t1[key]) > int(t2[key]):
+        if "comment" not in t1["schema_version"]:
+            t1["schema_version/comment"] = ""
+
+        if "comment" not in t2["schema_version"]:
+            t2["schema_version/comment"] = ""
+
+        if key in t1 and key in t2 and int(t1[key]) > int(t2[key]):
             out["schema_version"] = deepcopy(t1["schema_version"])
         else:
             out["schema_version"] = deepcopy(t2["schema_version"])

+        out["schema_version"]["comment"] = (
+            t1["schema_version/comment"] + ";" + t2["schema_version/comment"]
+        )
+
     def _merge_photometry(t1, t2, out):  # noqa: N805
         """
         Combine photometry sources
@@ -797,8 +948,15 @@
         key = "photometry"

         out[key] = deepcopy(t1[key])
-        refs = np.array([d["reference"] for d in out[key]])
+        refs = []  # np.array([d["reference"] for d in out[key]])
         # merge_dups = lambda val: np.sum(val) if np.any(val.isna()) else val.iloc[0]
+        for val in out[key]:
+            if isinstance(val, list):
+                refs += val
+            elif isinstance(val, np.ndarray):
+                refs += list(val)
+            else:
+                refs.append(val)

         for val in t2[key]:
             # first check if t2's reference is in out
@@ -823,12 +981,6 @@

         out[key][i1] = newdict  # replace the dictionary at i1 with the new dict

-    def _merge_spectra(t1, t2, out):  # noqa: N805
-        """
-        Combine spectra sources
-        """
-        pass
-
     def _merge_class(t1, t2, out):  # noqa: N805
         """
         Combine the classification attribute
@@ -864,24 +1016,8 @@
         else:
             item["default"] = False

-    def _merge_date(t1, t2, out):  # noqa: N805
-        """
-        Combine epoch data across two transients and write it to "out"
-        """
-        key = "date_reference"
-
-        Transient._merge_arbitrary(key, t1, t2, out)
-
-    def _merge_distance(t1, t2, out):  # noqa: N805
-        """
-        Combine distance information for these two transients
-        """
-        key = "distance"
-
-        Transient._merge_arbitrary(key, t1, t2, out)
-
     @staticmethod
-    def _merge_arbitrary(key, t1, t2, out):
+    def _merge_arbitrary(key, t1, t2, out, merge_subkeys=None, groupby_key=None):
         """
         Merge two arbitrary datasets inside the json file using pandas

@@ -889,44 +1025,81 @@
         a NxM pandas dataframe!
         """

-        df1 = pd.DataFrame(t1[key])
-        df2 = pd.DataFrame(t2[key])
-
-        merged_with_dups = pd.concat([df1, df2]).reset_index(drop=True)
-
-        # have to get the indexes to drop using a string rep of the df
-        # this is cause we have lists in some cells
-        to_drop = merged_with_dups.astype(str).drop_duplicates().index
-
-        merged = merged_with_dups.iloc[to_drop].reset_index(drop=True)
-
-        outdict = merged.to_dict(orient="records")
+        if key == "name":
+            t1._merge_names(t2, out)
+        elif key == "filter_alias":
+            t1._merge_filter_alias(t2, out)
+        elif key == "schema_version":
+            t1._merge_schema_version(t2, out)
+        elif key == "photometry":
+            t1._merge_photometry(t2, out)
+        elif key == "classification":
+            t1._merge_class(t2, out)
+        else:
+            # this is where we can standardize some of the merging
+            df1 = pd.DataFrame(t1[key])
+            df2 = pd.DataFrame(t2[key])
+
+            merged_with_dups = pd.concat([df1, df2]).reset_index(drop=True)
+
+            # have to get the indexes to drop using a string rep of the df
+            # this is cause we have lists in some cells
+            # We also need to deal with merging the lists of references across rows
+            # that we deem to be duplicates. This solution to do this quickly is from
+            # https://stackoverflow.com/questions/36271413/ \
+            # pandas-merge-nearly-duplicate-rows-based-on-column-value
+            if merge_subkeys is None:
+                merge_subkeys = merged_with_dups.columns.tolist()
+                merge_subkeys.remove("reference")
+            else:
+                for k in merge_subkeys:
+                    if k not in merged_with_dups:
+                        merge_subkeys.remove(k)
+
+            merged = (
+                merged_with_dups.astype(str)
+                .groupby(merge_subkeys)["reference"]
+                .apply(lambda x: x.sum())
+                .reset_index()
+            )

-        outdict_cleaned = Transient._remove_nans(
-            outdict
-        )  # clear out the nans from pandas conversion
+            # then we have to turn the merged reference strings into a string list
+            merged["reference"] = merged.reference.str.replace("][", ",")

-        out[key] = outdict_cleaned
+            # then eval the string of a list to get back an actual list of sources
+            merged["reference"] = merged.reference.apply(
+                lambda v: np.unique(eval(v)).tolist()
+            )

-    @staticmethod
-    def _remove_nans(d):
-        """
-        Remove nans from a record dictionary
+            # decide on default values
+            if groupby_key is None:
+                iterate_through = [(0, merged)]
+            else:
+                iterate_through = merged.groupby(groupby_key)
+
+            # we will make whichever value has more references the default
+            outdict = []
+            for data_type, df in iterate_through:
+                lengths = df.reference.map(len)
+                max_idx_arr = np.argmax(lengths)
+
+                if isinstance(max_idx_arr, np.int64):
+                    max_idx = max_idx_arr
+                elif len(max_idx_arr) == 0:
+                    raise ValueError("Something went wrong with deciding the default")
+                else:
+                    max_idx = max_idx_arr[0]  # arbitrarily choose the first

-        THIS IS SLOW: O(n^2)!!! WILL NEED TO BE SPED UP LATER
-        """
+                defaults = np.full(len(df), False, dtype=bool)
+                defaults[max_idx] = True

-        outd = []
-        for item in d:
-            outsubd = {}
-            for key, val in item.items():
-                if not isinstance(val, float):
-                    # this definitely is not NaN
-                    outsubd[key] = val
+                df["default"] = defaults
+                outdict.append(df)
+            outdict = pd.concat(outdict)

-                else:
-                    if not np.isnan(val):
-                        outsubd[key] = val
-            outd.append(outsubd)
+            # from https://stackoverflow.com/questions/52504972/ \
+            # converting-a-pandas-df-to-json-without-nan
+            outdict = outdict.replace("nan", np.nan)
+            outdict_cleaned = [{**x[i]} for i, x in outdict.stack().groupby(level=0)]

-        return outd
+            out[key] = outdict_cleaned
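
The core trick in the new `_merge_arbitrary` is worth seeing in isolation: stringify the frame so list-valued cells become hashable, group near-duplicate rows on the identifying subkeys, concatenate the stringified reference lists, then rebuild a deduplicated Python list via the `"][" -> ","` splice and an `eval`. A standalone pandas demo of just that step (sample rows invented):

```python
# Standalone demo of the reference-merging step in _merge_arbitrary.
import numpy as np
import pandas as pd

df = pd.DataFrame([
    {"value": 0.0206, "distance_type": "redshift", "reference": ["2016A&A...1B"]},
    {"value": 0.0206, "distance_type": "redshift", "reference": ["2019ApJ...2C"]},
])

merged = (
    df.astype(str)                       # lists become their string reprs
    .groupby(["value", "distance_type"])["reference"]
    .apply(lambda x: x.sum())            # "['2016A&A...1B']['2019ApJ...2C']"
    .reset_index()
)
# splice the back-to-back list reprs into one list literal, then eval it
merged["reference"] = merged.reference.str.replace("][", ",", regex=False)
merged["reference"] = merged.reference.apply(lambda v: np.unique(eval(v)).tolist())
print(merged.reference.iloc[0])          # ['2016A&A...1B', '2019ApJ...2C']
```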