astro-otter 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


otter/io/transient.py ADDED
@@ -0,0 +1,883 @@
+ """
+ Class for a transient; essentially a dictionary with some
+ overridden behavior.
+ """
+
+ import warnings
+ from copy import deepcopy
+ import re
+ from collections.abc import MutableMapping
+
+ import numpy as np
+ import pandas as pd
+
+ import astropy.units as u
+ from astropy.time import Time
+ from astropy.coordinates import SkyCoord
+
+ from synphot.units import VEGAMAG, convert_flux
+ from synphot.spectrum import SourceSpectrum
+
+ from ..exceptions import (
+     FailedQueryError,
+     IOError,
+     OtterLimitationError,
+     TransientMergeError,
+ )
+ from ..util import XRAY_AREAS
+
+ warnings.simplefilter("once", RuntimeWarning)
+ warnings.simplefilter("once", UserWarning)
+ np.seterr(divide="ignore")
+
+
+ class Transient(MutableMapping):
+     def __init__(self, d=None, name=None):
+         """
+         Overwrite the dictionary init
+
+         Args:
+             d [dict]: A transient dictionary
+             name [str]: fallback default name if d does not provide one
+         """
+         # use None instead of a mutable default argument
+         self.data = d if d is not None else {}
+
+         if "reference_alias" in self:
+             self.srcmap = {
+                 ref["name"]: ref["human_readable_name"]
+                 for ref in self["reference_alias"]
+             }
+             self.srcmap["TNS"] = "TNS"
+         else:
+             self.srcmap = {}
+
+         if "name" in self:
+             if "default_name" in self["name"]:
+                 self.default_name = self["name"]["default_name"]
+             else:
+                 raise AttributeError("Missing the default name!")
+         elif name is not None:
+             self.default_name = name
+         else:
+             self.default_name = "Missing Default Name"
+
+         # Make it so all coordinates are astropy SkyCoords
+
+     def __getitem__(self, keys):
+         """
+         Override getitem to recursively access Transient elements
+         """
+
+         if isinstance(keys, (list, tuple)):
+             return Transient({key: self[key] for key in keys})
+         elif isinstance(keys, str) and "/" in keys:  # this is for a path
+             s = "']['".join(keys.split("/"))
+             s = "['" + s
+             s += "']"
+             return eval(f"self{s}")
+         elif (
+             isinstance(keys, int)
+             or keys.isdigit()
+             or (keys[0] == "-" and keys[1:].isdigit())
+         ):
+             # this is for indexing a sublist; index the underlying data
+             # directly to avoid infinite recursion on integer keys
+             return self.data[int(keys)]
+         else:
+             return self.data[keys]
+
+     def __setitem__(self, key, value):
+         if isinstance(key, str) and "/" in key:  # this is for a path
+             s = "']['".join(key.split("/"))
+             s = "['" + s
+             s += "']"
+             exec(f"self{s} = value")
+         else:
+             self.data[key] = value
+
+     def __delitem__(self, keys):
+         if "/" in keys:
+             raise OtterLimitationError(
+                 "For security, we can not delete with the / syntax!"
+             )
+         else:
+             del self.data[keys]
+
+     def __iter__(self):
+         return iter(self.data)
+
+     def __len__(self):
+         return len(self.data)
+
+     def __repr__(self, html=False):
+         if not html:
+             return str(self.data)
+         else:
+             html = ""
+
+             coord = self.get_skycoord()
+
+             # add the ra and dec
+             # These are required so no need to check if they are there
+             html += f"""
+             <tr>
+                 <td style="text-align:left">RA [hrs]:</td>
+                 <td style="text-align:left">{coord.ra}
+             </tr>
+             <tr>
+                 <td style="text-align:left">DEC [deg]:</td>
+                 <td style="text-align:left">{coord.dec}
+             </tr>
+             """
+
+             if "date_reference" in self:
+                 discovery = self.get_discovery_date().to_value("datetime")
+                 if discovery is not None:
+                     # add the discovery date
+                     html += f"""
+                     <tr>
+                         <td style="text-align:left">Discovery Date:</td>
+                         <td style="text-align:left">{discovery}
+                     </tr>
+                     """
+
+             if "distance" in self:
+                 # add the redshift
+                 html += f"""
+                 <tr>
+                     <td style="text-align:left">Redshift:</td>
+                     <td style="text-align:left">{self['distance']['redshift'][0]['value']}
+                 </tr>
+                 """
+
+             if "reference_alias" in self:
+                 srcs = ""
+                 for bibcode, src in self.srcmap.items():
+                     # keep a space between the href and target attributes
+                     srcs += f"<a href='https://ui.adsabs.harvard.edu/abs/{bibcode}' "
+                     srcs += f"target='_blank'>{src}</a><br>"
+
+                 html += f"""
+                 <tr>
+                     <td style="text-align:left">Sources:</td>
+                     <td style="text-align:left">{srcs}
+                 </tr>
+                 """
+
+             return html
+
+     def keys(self):
+         return self.data.keys()
+
+     def __add__(self, other, strict_merge=True):
+         """
+         Merge this transient object with another transient object
+
+         Args:
+             other [Transient]: A Transient object to merge with
+             strict_merge [bool]: If True, refuse to merge objects that
+                 intuitively should not be merged (i.e., different
+                 transient events).
+         """
+
+         # first check that this object is within a good distance of the other object
+         if (
+             strict_merge
+             and self.get_skycoord().separation(other.get_skycoord()) > 10 * u.arcsec
+         ):
+             raise TransientMergeError(
+                 "These two transients are not within 10 arcseconds!"
+                 + " They probably do not belong together! If they do,"
+                 + " you can set strict_merge=False to override the check."
+             )
+
+         # create a blank dictionary since we don't want to overwrite this object
+         out = {}
+
+         # sort the keys by which object(s) they appear in
+         merge_keys = list(
+             self.keys() & other.keys()
+         )  # in both t1 and t2 so we need to merge these keys
+         only_in_t1 = list(self.keys() - other.keys())  # only in t1
+         only_in_t2 = list(other.keys() - self.keys())  # only in t2
+
+         # now let's handle the merge keys
+         for key in merge_keys:
+             # reference_alias is special:
+             # we ALWAYS combine these two
+             if key == "reference_alias":
+                 out[key] = self[key]
+                 if self[key] != other[key]:
+                     # only add t2 values if they aren't already in it
+                     bibcodes = {ref["name"] for ref in self[key]}
+                     for val in other[key]:
+                         if val["name"] not in bibcodes:
+                             out[key].append(val)
+                 continue
+
+             # we can skip the merge process and just add the values from t1
+             # if they are equal. We should still add the new reference though!
+             if self[key] == other[key]:
+                 # set the value
+                 # we don't need to worry about references because this will
+                 # only be true if the reference is also equal!
+                 out[key] = deepcopy(self[key])
+                 continue
+
+             # There are some special keys that we are expecting
+             if key == "name":
+                 self._merge_names(other, out)
+             elif key == "coordinate":
+                 self._merge_coords(other, out)
+             elif key == "date_reference":
+                 self._merge_date(other, out)
+             elif key == "distance":
+                 self._merge_distance(other, out)
+             elif key == "filter_alias":
+                 self._merge_filter_alias(other, out)
+             elif key == "schema_version":
+                 self._merge_schema_version(other, out)
+             elif key == "photometry":
+                 self._merge_photometry(other, out)
+             elif key == "spectra":
+                 self._merge_spectra(other, out)
+             elif key == "classification":
+                 self._merge_class(other, out)
+             else:
+                 # this is an unexpected key!
+                 if strict_merge:
+                     # since this is a strict merge we don't want unexpected data!
+                     raise TransientMergeError(
+                         f"{key} was not expected! Only keeping the old information!"
+                     )
+                 else:
+                     # throw a warning and only keep the old information
+                     warnings.warn(
+                         f"{key} was not expected! Only keeping the old information!"
+                     )
+                     out[key] = deepcopy(self[key])
+
+         # and now combine out with the keys only in t1 and t2
+         out = out | dict(self[only_in_t1]) | dict(other[only_in_t2])
+
+         # now return out as a Transient object
+         return Transient(out)
+
+     def get_meta(self, keys=None):
+         """
+         Get the metadata (no photometry or spectra)
+
+         This essentially just wraps __getitem__ with some checks
+
+         Args:
+             keys [list[str]]: list of keys
+         """
+         if keys is None:
+             keys = list(self.keys())
+
+             # note: using the remove method is safe here because dict keys are unique
+             if "photometry" in keys:
+                 keys.remove("photometry")
+             if "spectra" in keys:
+                 keys.remove("spectra")
+         else:
+             # run some checks
+             if "photometry" in keys:
+                 warnings.warn("Not returning the photometry!")
+                 keys.remove("photometry")
+             if "spectra" in keys:
+                 warnings.warn("Not returning the spectra!")
+                 keys.remove("spectra")
+
+             curr_keys = self.keys()
+             # iterate over a copy so removing from keys is safe
+             for key in list(keys):
+                 if key not in curr_keys:
+                     keys.remove(key)
+                     warnings.warn(
+                         f"Not returning {key} because it is not in this transient!"
+                     )
+
+         return self[keys]
+
+     def get_skycoord(self, coord_format="icrs"):
+         """
+         Convert the coordinates to an astropy SkyCoord
+         """
+
+         # generate the SkyCoord from the default equatorial coordinate
+         # (note: "equitorial" is the spelling this module uses throughout)
+         f = "df['coordinate_type'] == 'equitorial'"
+         coord_dict = self._get_default("coordinate", filt=f)
+         coordin = self._reformat_coordinate(coord_dict)
+         coord = SkyCoord(**coordin).transform_to(coord_format)
+
+         return coord
+
+     def get_discovery_date(self):
+         """
+         Get the default discovery date
+         """
+         key = "date_reference"
+         date = self._get_default(key, filt='df["date_type"] == "discovery"')
+         if "date_format" in date:
+             f = date["date_format"]
+         else:
+             f = "mjd"
+
+         return Time(date["value"], format=f)
+
+     def get_redshift(self):
+         """
+         Get the default redshift
+         """
+         f = "df['distance_type']=='redshift'"
+         default = self._get_default("distance", filt=f)
+         if default is None:
+             return default
+         else:
+             return default["value"]
+
+     def _get_default(self, key, filt=""):
+         """
+         Get the default value of key
+
+         Args:
+             key [str]: key in self to look for the default of
+             filt [str]: a valid pandas dataframe filter to index a pandas
+                 dataframe called df.
+         """
+         if key not in self:
+             raise KeyError(f"This transient does not have {key} associated with it!")
+
+         df = pd.DataFrame(self[key])
+         if filt:  # only apply the filter if one was given
+             df = df[eval(filt)]
+
+         if "default" in df:
+             # first try to get the default
+             df_filtered = df[df.default == True]  # noqa: E712
+             if len(df_filtered) == 0:
+                 df_filtered = df
+         else:
+             df_filtered = df
+
+         if len(df_filtered) == 0:
+             return None
+         return df_filtered.iloc[0]
+
+     def _reformat_coordinate(self, item):
+         """
+         Reformat the coordinate information in item
+         """
+         coordin = None
+         if "ra" in item and "dec" in item:
+             # this is an equatorial coordinate
+             coordin = {
+                 "ra": item["ra"],
+                 "dec": item["dec"],
+                 "unit": (item["ra_units"], item["dec_units"]),
+             }
+         elif "l" in item and "b" in item:
+             # this is a galactic coordinate
+             coordin = {
+                 "l": item["l"],
+                 "b": item["b"],
+                 "unit": (item["l_units"], item["b_units"]),
+                 "frame": "galactic",
+             }
+
+         return coordin
+
+     def clean_photometry(
+         self,
+         flux_unit: u.Unit = "mag(AB)",
+         date_unit: u.Unit = "MJD",
+         freq_unit: u.Unit = "GHz",
+         wave_unit: u.Unit = "nm",
+         by: str = "raw",
+         obs_type: str = None,
+     ):
+         """
+         Ensure the photometry associated with this transient is all in the
+         same units/system/etc.
+         """
+
+         # check inputs
+         if by not in {"value", "raw"}:
+             raise IOError("Please choose either value or raw!")
+
+         # turn the photometry key into a pandas dataframe
+         dfs = []
+         for item in self["photometry"]:
+             max_len = 0
+             for key, val in item.items():
+                 if isinstance(val, list) and key != "reference":
+                     max_len = max(max_len, len(val))
+
+             for key, val in item.items():
+                 if not isinstance(val, list) or (
+                     isinstance(val, list) and len(val) != max_len
+                 ):
+                     item[key] = [val] * max_len
+
+             df = pd.DataFrame(item)
+             dfs.append(df)
+
+         c = pd.concat(dfs)
+
+         filters = pd.DataFrame(self["filter_alias"])
+         df = c.merge(filters, on="filter_key")
+
+         # make sure 'by' is in df; fall back to the other column if not
+         if by not in df:
+             if by == "value":
+                 by = "raw"
+             else:
+                 by = "value"
+
+         # drop irrelevant obs_types before continuing
+         if obs_type is not None:
+             valid_obs_types = {"radio", "uvoir", "xray"}
+             if obs_type not in valid_obs_types:
+                 raise IOError("Please provide a valid obs_type")
+             df = df[df.obs_type == obs_type]
+
+         # convert the ADS bibcodes to a string of human readable sources here
+         def mappedrefs(row):
+             if isinstance(row.reference, list):
+                 return "<br>".join([self.srcmap[bibcode] for bibcode in row.reference])
+             else:
+                 return self.srcmap[row.reference]
+
+         try:
+             df["human_readable_refs"] = df.apply(mappedrefs, axis=1)
+         except Exception as exc:
+             warnings.warn(f"Unable to apply the source mapping because {exc}")
+             df["human_readable_refs"] = df.reference
+
+         # Figure out what columns are good to group by in the photometry
+         outdata = []
+         if "telescope" in df:
+             tele = True
+             to_grp_by = ["obs_type", by + "_units", "telescope"]
+         else:
+             tele = False
+             to_grp_by = ["obs_type", by + "_units"]
+
+         # Do the conversion based on what we decided to group by
+         for groupedby, data in df.groupby(to_grp_by, dropna=False):
+             if tele:
+                 obstype, unit, telescope = groupedby
+             else:
+                 obstype, unit = groupedby
+                 telescope = None
+
+             # get the photometry in the right type
+             unit = data[by + "_units"].unique()
+             if len(unit) > 1:
+                 raise OtterLimitationError(
+                     "Can not apply multiple units for different obs_types"
+                 )
+
+             unit = unit[0]
+             try:
+                 if "vega" in unit.lower():
+                     astropy_units = VEGAMAG
+                 else:
+                     astropy_units = u.Unit(unit)
+
+             except ValueError:
+                 # something is likely slightly off in the input unit string.
+                 # Let's try to fix it! Here are some common mistakes.
+                 unit = unit.replace("ergs", "erg")
+                 unit = unit.replace("AB", "mag(AB)")
+
+                 # retry the conversion with the cleaned-up string
+                 try:
+                     astropy_units = u.Unit(unit)
+                 except ValueError:
+                     raise ValueError(
+                         "Could not coerce your string into astropy unit format!"
+                     )
+
+             # get the flux data and find the type
+             indata = np.array(data[by].astype(float))
+             err_key = by + "_err"
+             if err_key in data:
+                 indata_err = np.array(data[err_key].astype(float))
+             else:
+                 indata_err = np.zeros(len(data))
+             q = indata * u.Unit(astropy_units)
+             q_err = indata_err * u.Unit(
+                 astropy_units
+             )  # assume error and values have the same unit
+
+             # get the effective wavelength
+             if "freq_eff" in data and not np.isnan(data["freq_eff"].iloc[0]):
+                 freq_units = data["freq_units"]
+                 if len(np.unique(freq_units)) > 1:
+                     raise OtterLimitationError(
+                         "Can not convert different units to the same unit!"
+                     )
+
+                 freq_eff = np.array(data["freq_eff"]) * u.Unit(freq_units.iloc[0])
+                 wave_eff = freq_eff.to(u.AA, equivalencies=u.spectral())
+
+             elif "wave_eff" in data and not np.isnan(data["wave_eff"].iloc[0]):
+                 wave_units = data["wave_units"]
+                 if len(np.unique(wave_units)) > 1:
+                     raise OtterLimitationError(
+                         "Can not convert different units to the same unit!"
+                     )
+
+                 wave_eff = np.array(data["wave_eff"]) * u.Unit(wave_units.iloc[0])
+
+             # convert using synphot
+             # things have to be done slightly differently for x-ray than the others
+             if obstype == "xray":
+                 if telescope is not None:
+                     try:
+                         area = XRAY_AREAS[telescope.lower()]
+                     except KeyError:
+                         raise OtterLimitationError(
+                             "Did not find an area corresponding to "
+                             + "this telescope, please add to util!"
+                         )
+                 else:
+                     raise OtterLimitationError(
+                         "Can not convert x-ray data without a telescope"
+                     )
+
+                 # we also need to make this wave_min and wave_max
+                 # instead of just the effective wavelength like for radio and
+                 # uvoir; read the units off the data so they are always defined
+                 wave_eff = np.array(
+                     list(zip(data["wave_min"], data["wave_max"]))
+                 ) * u.Unit(data["wave_units"].iloc[0])
+
+             else:
+                 area = None
+
+             # we unfortunately have to loop over the points here because
+             # synphot does not work with a 2D array of min/max wavelengths
+             # for converting counts to other flux units. It also can't convert
+             # vega mags with a wavelength array because it then interprets that
+             # as the wavelengths corresponding to the SourceSpectrum.from_vega()
+             flux, flux_err = [], []
+             for wave, flux_point, flux_point_err in zip(wave_eff, q, q_err):
+                 f_val = convert_flux(
+                     wave,
+                     flux_point,
+                     u.Unit(flux_unit),
+                     vegaspec=SourceSpectrum.from_vega(),
+                     area=area,
+                 )
+                 f_err = convert_flux(
+                     wave,
+                     flux_point_err,
+                     u.Unit(flux_unit),
+                     vegaspec=SourceSpectrum.from_vega(),
+                     area=area,
+                 )
+
+                 # then we take the average of the minimum and maximum values
+                 # computed by synphot
+                 flux.append(np.mean(f_val).value)
+                 flux_err.append(np.mean(f_err).value)
+
+             flux = np.array(flux) * u.Unit(flux_unit)
+             flux_err = np.array(flux_err) * u.Unit(flux_unit)
+
+             data["converted_flux"] = flux.value
+             data["converted_flux_err"] = flux_err.value
+             outdata.append(data)
+
+         if len(outdata) == 0:
+             raise FailedQueryError()
+         outdata = pd.concat(outdata)
+
+         # copy over the flux units
+         outdata["converted_flux_unit"] = [flux_unit] * len(outdata)
+
+         # make sure all the datetimes are in the same format here too!
+         times = [
+             Time(d, format=f).to_value(date_unit.lower())
+             for d, f in zip(outdata.date, outdata.date_format.str.lower())
+         ]
+         outdata["converted_date"] = times
+         outdata["converted_date_unit"] = [date_unit] * len(outdata)
+
+         # same with frequencies and wavelengths; iterate over outdata (not df)
+         # so the computed values line up with the concatenated rows
+         freqs = []
+         waves = []
+
+         for _, row in outdata.iterrows():
+             if "freq_eff" in row and not np.isnan(row["freq_eff"]):
+                 val = row["freq_eff"] * u.Unit(row["freq_units"])
+             elif "wave_eff" in row and not np.isnan(row["wave_eff"]):
+                 val = row["wave_eff"] * u.Unit(row["wave_units"])
+             else:
+                 raise ValueError("No known frequency or wavelength, please fix!")
+
+             freqs.append(val.to(freq_unit, equivalencies=u.spectral()).value)
+             waves.append(val.to(wave_unit, equivalencies=u.spectral()).value)
+
+         outdata["converted_freq"] = freqs
+         outdata["converted_wave"] = waves
+         outdata["converted_wave_unit"] = [wave_unit] * len(outdata)
+         outdata["converted_freq_unit"] = [freq_unit] * len(outdata)
+
+         return outdata
+
+     def _merge_names(t1, t2, out):  # noqa: N805
+         """
+         Private method to merge the name data in t1 and t2 and put it in out
+         """
+         key = "name"
+         out[key] = {}
+
+         # first deal with the default_name key
+         # we need some regex magic to choose a preferred default_name
+         if t1[key]["default_name"] == t2[key]["default_name"]:
+             out[key]["default_name"] = t1[key]["default_name"]
+         else:
+             # we need to decide which default_name is better;
+             # it should be the one that matches the TNS style,
+             # so let's use regex
+             n1 = t1[key]["default_name"]
+             n2 = t2[key]["default_name"]
+
+             # write some discriminating regex expressions
+             # exp1: starts with a number, preferred because it is TNS style
+             exp1 = "^[0-9]"
+             # exp2: ends with any character, i.e. the name is nonempty
+             exp2 = ".$"
+             # exp3: checks if the first three characters are digits,
+             # like a year :), this is pretty strict though
+             exp3 = "^[0-9]{3}"
+             # exp4: checks if it starts with AT like TNS names
+             exp4 = "^AT"
+
+             # combine all the regex expressions; this makes it easier to add more later
+             exps = [exp1, exp2, exp3, exp4]
+
+             # score each default_name based on this
+             score1 = 0
+             score2 = 0
+             for e in exps:
+                 re1 = re.findall(e, n1)
+                 re2 = re.findall(e, n2)
+                 if re1:
+                     score1 += 1
+                 if re2:
+                     score2 += 1
+
+             # assign a default_name based on the score
+             if score1 > score2:
+                 out[key]["default_name"] = t1[key]["default_name"]
+             elif score2 > score1:
+                 out[key]["default_name"] = t2[key]["default_name"]
+             else:
+                 warnings.warn(
+                     "Names have the same score! Just using the existing default_name"
+                 )
+                 out[key]["default_name"] = t1[key]["default_name"]
+
+         # now deal with aliases
+         # create a reference mapping for each, normalizing the reference to a list
+         t1map = {}
+         for val in t1[key]["alias"]:
+             ref = val["reference"]
+             t1map[val["value"]] = [ref] if isinstance(ref, str) else list(ref)
+
+         t2map = {}
+         for val in t2[key]["alias"]:
+             ref = val["reference"]
+             t2map[val["value"]] = [ref] if isinstance(ref, str) else list(ref)
+
+         # figure out which ones we need to be careful with references in
+         inboth = list(
+             t1map.keys() & t2map.keys()
+         )  # in both, so we'll have to merge the reference key
+         int1 = list(t1map.keys() - t2map.keys())  # only in t1
+         int2 = list(t2map.keys() - t1map.keys())  # only in t2
+
+         # add the ones that are not in both first, these are easy
+         line1 = [{"value": k, "reference": t1map[k]} for k in int1]
+         line2 = [{"value": k, "reference": t2map[k]} for k in int2]
+         bothlines = [{"value": k, "reference": t1map[k] + t2map[k]} for k in inboth]
+         out[key]["alias"] = line2 + line1 + bothlines
+
+     def _merge_coords(t1, t2, out):  # noqa: N805
+         """
+         Merge the coordinate subdictionaries of t1 and t2 and put the result in out
+
+         Use pandas to drop any duplicates
+         """
+         key = "coordinate"
+
+         Transient._merge_arbitrary(key, t1, t2, out)
+
+     def _merge_filter_alias(t1, t2, out):  # noqa: N805
+         """
+         Combine the filter alias lists across the transient objects
+         """
+
+         key = "filter_alias"
+
+         out[key] = deepcopy(t1[key])
+         keys1 = {filt["filter_key"] for filt in t1[key]}
+         for filt in t2[key]:
+             if filt["filter_key"] not in keys1:
+                 out[key].append(filt)
+
+     def _merge_schema_version(t1, t2, out):  # noqa: N805
+         """
+         Just keep whichever schema version is greater
+         """
+         key = "schema_version/value"
+         if int(t1[key]) > int(t2[key]):
+             out["schema_version"] = deepcopy(t1["schema_version"])
+         else:
+             out["schema_version"] = deepcopy(t2["schema_version"])
+
+     def _merge_photometry(t1, t2, out):  # noqa: N805
+         """
+         Combine photometry sources
+         """
+
+         key = "photometry"
+
+         out[key] = deepcopy(t1[key])
+         refs = np.array([d["reference"] for d in out[key]])
+         # merge_dups = lambda val: np.sum(val) if np.any(val.isna()) else val.iloc[0]
+
+         for val in t2[key]:
+             # first check if t2's reference is in out
+             if val["reference"] not in refs:
+                 # it's not here so we can just append the new photometry!
+                 out[key].append(val)
+             else:
+                 # we need to merge it with the other photometry
+                 i1 = np.where(val["reference"] == refs)[0][0]
+                 df1 = pd.DataFrame(out[key][i1])
+                 df2 = pd.DataFrame(val)
+
+                 # only substitute in values that are nan in df1 or new;
+                 # merge on the keys shared by the two
+                 mergeon = list(set(df1.keys()) & set(df2.keys()))
+                 df = df1.merge(df2, on=mergeon, how="outer")
+                 # convert to a dictionary
+                 newdict = df.reset_index().to_dict(orient="list")
+                 del newdict["index"]
+
+                 newdict["reference"] = newdict["reference"][0]
+
+                 # replace the dictionary at i1 with the new dict
+                 out[key][i1] = newdict
+
+     def _merge_spectra(t1, t2, out):  # noqa: N805
+         """
+         Combine spectra sources
+
+         Merging t2's spectra is not implemented yet; keep t1's spectra so
+         they are not silently dropped from the merged transient.
+         """
+         out["spectra"] = deepcopy(t1["spectra"])
+
+     def _merge_class(t1, t2, out):  # noqa: N805
+         """
+         Combine the classification attribute
+         """
+         key = "classification"
+         out[key] = deepcopy(t1[key])
+         classes = np.array([item["object_class"] for item in out[key]])
+         for item in t2[key]:
+             if item["object_class"] in classes:
+                 i = np.where(item["object_class"] == classes)[0][0]
+                 if int(item["confidence"]) > int(out[key][i]["confidence"]):
+                     # we are now more confident
+                     out[key][i]["confidence"] = item["confidence"]
+
+                 if not isinstance(out[key][i]["reference"], list):
+                     out[key][i]["reference"] = [out[key][i]["reference"]]
+
+                 if not isinstance(item["reference"], list):
+                     item["reference"] = [item["reference"]]
+
+                 newdata = list(np.unique(out[key][i]["reference"] + item["reference"]))
+                 out[key][i]["reference"] = newdata
+
+             else:
+                 out[key].append(item)
+
+         # now that we have all of them, figure out which one is the default;
+         # cast to int so string confidences don't compare lexicographically
+         maxconf = max(out[key], key=lambda d: int(d["confidence"]))
+         for item in out[key]:
+             if item is maxconf:
+                 item["default"] = True
+             else:
+                 item["default"] = False
+
+     def _merge_date(t1, t2, out):  # noqa: N805
+         """
+         Combine epoch data across two transients and write it to "out"
+         """
+         key = "date_reference"
+
+         Transient._merge_arbitrary(key, t1, t2, out)
+
+     def _merge_distance(t1, t2, out):  # noqa: N805
+         """
+         Combine distance information for these two transients
+         """
+         key = "distance"
+
+         Transient._merge_arbitrary(key, t1, t2, out)
+
+     @staticmethod
+     def _merge_arbitrary(key, t1, t2, out):
+         """
+         Merge two arbitrary datasets inside the json file using pandas
+
+         The datasets at "key" in t1 and t2 must be able to be coerced into
+         an NxM pandas dataframe!
+         """
+
+         df1 = pd.DataFrame(t1[key])
+         df2 = pd.DataFrame(t2[key])
+
+         merged_with_dups = pd.concat([df1, df2]).reset_index(drop=True)
+
+         # we have to find the indexes to keep using a string rep of the df
+         # because some cells hold lists, which are unhashable
+         to_keep = merged_with_dups.astype(str).drop_duplicates().index
+
+         merged = merged_with_dups.iloc[to_keep].reset_index(drop=True)
+
+         outdict = merged.to_dict(orient="records")
+
+         # clear out the NaNs introduced by the pandas conversion
+         outdict_cleaned = Transient._remove_nans(outdict)
+
+         out[key] = outdict_cleaned
+
+     @staticmethod
+     def _remove_nans(d):
+         """
+         Remove NaNs from a record dictionary
+
+         THIS IS SLOW: O(n^2)!!! WILL NEED TO BE SPED UP LATER
+         """
+
+         outd = []
+         for item in d:
+             outsubd = {}
+             for key, val in item.items():
+                 if not isinstance(val, float):
+                     # this definitely is not NaN
+                     outsubd[key] = val
+                 else:
+                     if not np.isnan(val):
+                         outsubd[key] = val
+             outd.append(outsubd)
+
+         return outd
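
For orientation, here is a minimal usage sketch of the Transient class added above (not part of the diff). The dictionary shape is inferred from the accessors in this file (keys like coordinate_type and distance_type), and the names, bibcode, and values are placeholders, not taken from the OTTER schema or any real event.

from otter.io.transient import Transient

t = Transient(
    {
        "name": {
            "default_name": "2018zzz",  # hypothetical TNS-style name
            "alias": [{"value": "EXAMPLE-1", "reference": "placeholder_bibcode"}],
        },
        "coordinate": [
            {
                "ra": "10:06:50.8",
                "dec": "+01:41:34",
                "ra_units": "hour",
                "dec_units": "deg",
                "coordinate_type": "equitorial",  # spelling used by this module
                "default": True,
            }
        ],
        "distance": [
            {"value": 0.05, "distance_type": "redshift", "default": True}
        ],
    }
)

print(t.default_name)           # "2018zzz"
print(t["name/default_name"])   # path-style access to the same value
print(t.get_skycoord())         # SkyCoord built from the default coordinate
print(t.get_redshift())         # 0.05

# two Transients within 10 arcsec of each other can be merged:
# combined = t + other_transient
# phot = t.clean_photometry()  # needs "photometry" and "filter_alias" entries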