astro-otter 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
otter/io/transient.py ADDED
@@ -0,0 +1,1453 @@
1
+ """
2
+ Class representing a single transient event; it behaves like a dict
3
+ (via MutableMapping) with some overridden item access and merging behavior.
4
+ """
5
+
6
+ from __future__ import annotations
7
+ import warnings
8
+ from copy import deepcopy
9
+ import re
10
+ from collections.abc import MutableMapping
11
+ from typing import Callable
12
+ from typing_extensions import Self
13
+ import logging
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ import astropy.units as u
19
+ from astropy.time import Time
20
+ from astropy.coordinates import SkyCoord
21
+
22
+ from ..exceptions import (
23
+ FailedQueryError,
24
+ IOError,
25
+ OtterLimitationError,
26
+ TransientMergeError,
27
+ )
28
+ from ..util import XRAY_AREAS, _KNOWN_CLASS_ROOTS, _DuplicateFilter
29
+ from .host import Host
30
+
31
+ np.seterr(divide="ignore")
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class Transient(MutableMapping):
36
+ def __init__(self, d=None, name=None):
37
+ """
38
+ Override the dictionary __init__
39
+
40
+ Args:
41
+ d (dict): A transient dictionary
42
+ name (str): The default name of the transient. If None (the default), it is
43
+ inferred from the input dictionary.
44
+ """
45
+ self.data = d if d is not None else {}
46
+
47
+ if "reference_alias" in self:
48
+ self.srcmap = {
49
+ ref["name"]: ref["human_readable_name"]
50
+ for ref in self["reference_alias"]
51
+ }
52
+ self.srcmap["TNS"] = "TNS"
53
+ else:
54
+ self.srcmap = {}
55
+
56
+ if "name" in self:
57
+ if "default_name" in self["name"]:
58
+ self.default_name = self["name"]["default_name"]
59
+ else:
60
+ raise AttributeError("Missing the default name!!")
61
+ elif name is not None:
62
+ self.default_name = name
63
+ else:
64
+ self.default_name = "Missing Default Name"
65
+
66
+ # Make it so all coordinates are astropy skycoords
67
+
68
+ def __getitem__(self, keys):
69
+ """
70
+ Override getitem to recursively access Transient elements
71
+ """
72
+
73
+ if isinstance(keys, (list, tuple)):
74
+ return Transient({key: self[key] for key in keys if key in self})
75
+ elif isinstance(keys, str) and "/" in keys: # this is for a path
76
+ s = "']['".join(keys.split("/"))
77
+ s = "['" + s
78
+ s += "']"
79
+ return eval(f"self{s}")
80
+ elif (
81
+ isinstance(keys, int)
82
+ or keys.isdigit()
83
+ or (keys[0] == "-" and keys[1:].isdigit())
84
+ ):
85
+ # this is for indexing a sublist
86
+ return self.data[int(keys)]
87
+ else:
88
+ return self.data[keys]
89
+
90
+ def __setitem__(self, key, value):
91
+ """
92
+ Override set item to work with the '/' syntax
93
+ """
94
+
95
+ if isinstance(key, str) and "/" in key: # this is for a path
96
+ s = "']['".join(key.split("/"))
97
+ s = "['" + s
98
+ s += "']"
99
+ exec(f"self{s} = value")
100
+ else:
101
+ self.data[key] = value
102
+
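A quick usage sketch of the item-access overrides above (the import path follows the file location; the toy dictionary is illustrative, real OTTER records carry many more keys):

from otter.io.transient import Transient

t = Transient({"name": {"default_name": "2020example", "alias": []}})

t["name/default_name"]                     # nested access via the '/' path syntax
t["name/default_name"] = "AT2020example"   # nested assignment with the same syntax
sub = t[["name"]]                          # a list of keys returns a sub-Transient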
103
+ def __delitem__(self, keys):
104
+ if "/" in keys:
105
+ raise OtterLimitationError(
106
+ "For security, we can not delete with the / syntax!"
107
+ )
108
+ else:
109
+ del self.data[keys]
110
+
111
+ def __iter__(self):
112
+ return iter(self.data)
113
+
114
+ def __len__(self):
115
+ return len(self.data)
116
+
117
+ def __repr__(self):
118
+ return f"Transient(\n\tName: {self.default_name},\n\tKeys: {self.keys()}\n)"
119
+
120
+ def keys(self):
121
+ return self.data.keys()
122
+
123
+ def __add__(self, other, strict_merge=True):
124
+ """
125
+ Merge this transient object with another transient object
126
+
127
+ Args:
128
+ other [Transient]: A Transient object to merge with
129
+ strict_merge [bool]: If True it won't let you merge objects that
130
+ intuitively shouldn't be merged (i.e., different
131
+ transient events).
132
+ """
133
+
134
+ # first check that this object is within a good distance of the other object
135
+ if (
136
+ strict_merge
137
+ and self.get_skycoord().separation(other.get_skycoord()) > 10 * u.arcsec
138
+ ):
139
+ raise TransientMergeError(
140
+ "These two transients are not within 10 arcseconds!"
141
+ + " They probably do not belong together! If they do"
142
+ + " You can set strict_merge=False to override the check"
143
+ )
144
+
145
+ # create set of the allowed keywords
146
+ allowed_keywords = {
147
+ "name",
148
+ "date_reference",
149
+ "coordinate",
150
+ "distance",
151
+ "filter_alias",
152
+ "schema_version",
153
+ "photometry",
154
+ "classification",
155
+ "host",
156
+ }
157
+
158
+ merge_subkeys_map = {
159
+ "name": None,
160
+ "date_reference": ["value", "date_format", "date_type"],
161
+ "coordinate": None, # may need to update this if we run into problems
162
+ "distance": ["value", "distance_type", "unit"],
163
+ "filter_alias": None,
164
+ "schema_version": None,
165
+ "photometry": None,
166
+ "classification": None,
167
+ "host": [
168
+ "host_ra",
169
+ "host_dec",
170
+ "host_ra_units",
171
+ "host_dec_units",
172
+ "host_name",
173
+ ],
174
+ }
175
+
176
+ groupby_key_for_default_map = {
177
+ "name": None,
178
+ "date_reference": "date_type",
179
+ "coordinate": "coordinate_type",
180
+ "distance": "distance_type",
181
+ "filter_alias": None,
182
+ "schema_version": None,
183
+ "photometry": None,
184
+ "classification": None,
185
+ "host": None,
186
+ }
187
+
188
+ # create a blank dictionary since we don't want to overwrite this object
189
+ out = {}
190
+
191
+ # find which keys are shared between the two transients and which are unique
192
+ merge_keys = list(
193
+ self.keys() & other.keys()
194
+ ) # in both t1 and t2 so we need to merge these keys
195
+ only_in_t1 = list(self.keys() - other.keys()) # only in t1
196
+ only_in_t2 = list(other.keys() - self.keys()) # only in t2
197
+
198
+ # now let's handle the merge keys
199
+ for key in merge_keys:
200
+ # reference_alias is special
201
+ # we ALWAYS should combine these two
202
+ if key == "reference_alias":
203
+ out[key] = self[key]
204
+ if self[key] != other[key]:
205
+ # only add t2 values if they aren't already in it
206
+ bibcodes = {ref["name"] for ref in self[key]}
207
+ for val in other[key]:
208
+ if val["name"] not in bibcodes:
209
+ out[key].append(val)
210
+ continue
211
+
212
+ # we can skip this merge process and just add the values from t1
213
+ # if they are equal. We should still add the new reference though!
214
+ if self[key] == other[key]:
215
+ # set the value
216
+ # we don't need to worry about references because this will
217
+ # only be true if the reference is also equal!
218
+ out[key] = deepcopy(self[key])
219
+ continue
220
+
221
+ # There are some special keys that we are expecting
222
+ if key in allowed_keywords:
223
+ Transient._merge_arbitrary(
224
+ key,
225
+ self,
226
+ other,
227
+ out,
228
+ merge_subkeys=merge_subkeys_map[key],
229
+ groupby_key=groupby_key_for_default_map[key],
230
+ )
231
+ else:
232
+ # this is an unexpected key!
233
+ if strict_merge:
234
+ # since this is a strict merge we don't want unexpected data!
235
+ raise TransientMergeError(f"{key} was not expected! Can not merge!")
236
+ else:
237
+ # Throw a warning and only keep the old stuff
238
+ logger.warning(
239
+ f"{key} was not expected! Only keeping the old information!"
240
+ )
241
+ out[key] = deepcopy(self[key])
242
+
243
+ # and now combining out with the stuff only in t1 and t2
244
+ out = out | dict(self[only_in_t1]) | dict(other[only_in_t2])
245
+
246
+ # now return out as a Transient Object
247
+ return Transient(out)
248
+
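A sketch of merging two records of the same event with the overloaded `+`. The dictionaries below are toy records with placeholder bibcodes, trimmed to the minimum this merge path touches; identical coordinates keep the example inside the 10-arcsecond strict-merge check:

from otter.io.transient import Transient

ref = [{"name": "2020ApJ...000....0A", "human_readable_name": "Author et al. (2020)"}]
coord = [{"ra": "12:34:56.7", "dec": "-12:34:56.7",
          "ra_units": "hourangle", "dec_units": "deg",
          "coordinate_type": "equatorial", "default": True,
          "reference": ["2020ApJ...000....0A"]}]

t1 = Transient({"name": {"default_name": "2020abc",
                         "alias": [{"value": "2020abc", "reference": ["2020ApJ...000....0A"]}]},
                "coordinate": coord, "reference_alias": ref})
t2 = Transient({"name": {"default_name": "2020abc",
                         "alias": [{"value": "ZTF20aaexample", "reference": ["2020ApJ...000....0A"]}]},
                "coordinate": coord, "reference_alias": ref})

merged = t1 + t2        # aliases from both records are combined
merged["name/alias"]    # both aliases, each tagged with its references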
249
+ def get_meta(self, keys=None) -> Self:
250
+ """
251
+ Get the metadata (no photometry or spectra)
252
+
253
+ This essentially just wraps __getitem__ but with some checks
254
+
255
+ Args:
256
+ keys (list[str]) : list of keys to get the metadata for from the transient
257
+
258
+ Returns:
259
+ A Transient object of just the meta data
260
+ """
261
+ if keys is None:
262
+ keys = list(self.keys())
263
+
264
+ # note: using the remove method is safe here because dict keys are unique
265
+ if "photometry" in keys:
266
+ keys.remove("photometry")
267
+ if "spectra" in keys:
268
+ keys.remove("spectra")
269
+ else:
270
+ # run some checks
271
+ if "photometry" in keys:
272
+ logger.warning("Not returing the photometry!")
273
+ _ = keys.pop("photometry")
274
+ if "spectra" in keys:
275
+ logger.warning("Not returning the spectra!")
276
+ _ = keys.pop("spectra")
277
+
278
+ curr_keys = self.keys()
279
+ for key in list(keys):  # iterate over a copy since we may remove entries
280
+ if key not in curr_keys:
281
+ keys.remove(key)
282
+ logger.warning(
283
+ f"Not returning {key} because it is not in this transient!"
284
+ )
285
+
286
+ return self[keys]
287
+
288
+ def get_skycoord(self, coord_format="icrs") -> SkyCoord:
289
+ """
290
+ Convert the coordinates to an astropy SkyCoord
291
+
292
+ Args:
293
+ coord_format (str): Astropy coordinate format to convert the SkyCoord to
294
+ defaults to icrs.
295
+
296
+ Returns:
297
+ Astropy.coordinates.SkyCoord of the default coordinate for the transient
298
+ """
299
+
300
+ # now we can generate the SkyCoord
301
+ f = "df['coordinate_type'] == 'equatorial'"
302
+ coord_dict = self._get_default("coordinate", filt=f)
303
+ coordin = self._reformat_coordinate(coord_dict)
304
+ coord = SkyCoord(**coordin).transform_to(coord_format)
305
+
306
+ return coord
307
+
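A minimal sketch of get_skycoord with a single equatorial coordinate entry (real entries also carry references):

from otter.io.transient import Transient

t = Transient(
    {"coordinate": [{"ra": "14:19:29.8", "dec": "+42:32:09",
                     "ra_units": "hourangle", "dec_units": "deg",
                     "coordinate_type": "equatorial", "default": True}]},
    name="2020example",
)
c = t.get_skycoord()        # astropy SkyCoord, ICRS by default
c.to_string("hmsdms")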
308
+ def get_discovery_date(self) -> Time:
309
+ """
310
+ Get the default discovery date for this Transient
311
+
312
+ Returns:
313
+ astropy.time.Time of the default discovery date
314
+ """
315
+ key = "date_reference"
316
+ try:
317
+ date = self._get_default(key, filt='df["date_type"] == "discovery"')
318
+ except KeyError:
319
+ return None
320
+
321
+ if date is None:
322
+ return date
323
+
324
+ if "date_format" in date:
325
+ f = date["date_format"]
326
+ else:
327
+ f = "mjd"
328
+
329
+ return Time(str(date["value"]).strip(), format=f)
330
+
331
+ def get_redshift(self) -> float:
332
+ """
333
+ Get the default redshift of this Transient
334
+
335
+ Returns:
336
+ Float value of the default redshift
337
+ """
338
+ f = "df['distance_type']=='redshift'"
339
+ default = self._get_default("distance", filt=f)
340
+ if default is None:
341
+ return default
342
+ else:
343
+ return default["value"]
344
+
345
+ def get_classification(self) -> tuple[str, float, list]:
346
+ """
347
+ Get the default classification of this Transient.
348
+ This normally corresponds to the highest confidence classification that we have
349
+ stored for the transient.
350
+
351
+ Returns:
352
+ The default object class as a string, the confidence level in that class,
353
+ and a list of the bibcodes corresponding to that classification. Or, None
354
+ if there is no classification.
355
+ """
356
+ default = self._get_default("classification/value")
357
+ if default is None:
358
+ return default
359
+ return default.object_class, default.confidence, default.reference
360
+
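A sketch of the simple getters on a toy record (all values and the bibcode are placeholders):

from otter.io.transient import Transient

t = Transient({
    "name": {"default_name": "2020example", "alias": []},
    "date_reference": [{"value": 59000.0, "date_format": "mjd",
                        "date_type": "discovery", "default": True}],
    "distance": [{"value": 0.05, "distance_type": "redshift",
                  "unit": "dimensionless", "default": True}],
    "classification": {"value": [{"object_class": "TDE", "confidence": 3.0,
                                  "reference": ["2020ApJ...000....0A"], "default": True}]},
})

t.get_discovery_date()    # astropy Time, MJD 59000.0
t.get_redshift()          # 0.05
t.get_classification()    # ("TDE", 3.0, ["2020ApJ...000....0A"])
t.get_meta()              # all of the above; photometry/spectra would be stripped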
361
+ def get_host(self, max_hosts=3, search=False, **kwargs) -> list[Host]:
362
+ """
363
+ Get the default host information of this Transient as a list of otter.Host
364
+ objects. If search=True, it will also check the BLAST host association database
365
+ for the best match and return it as well. Note that if search is True then
366
+ this has the potential to return max_hosts + 1 hosts, if BLAST also returns a result.
367
+ The BLAST result will always be the last value in the returned list.
368
+
369
+ Args:
370
+ max_hosts [int] : The maximum number of hosts to return, default is 3
371
+ search [bool] : If True, also query the BLAST host association database
+ **kwargs : additional keyword arguments (currently unused by this method)
372
+
373
+ Returns:
374
+ A list of otter.Host objects. This is useful because the Host objects have
375
+ useful methods for querying public catalogs for data of the host.
376
+ """
377
+ # first try to get the host information from our local database
378
+ host = []
379
+ if "host" in self:
380
+ max_hosts = min([max_hosts, len(self["host"])])
381
+ for h in self["host"][:max_hosts]:
382
+ # only return hosts with their ra and dec stored
383
+ if (
384
+ "host_ra" not in h
385
+ or "host_dec" not in h
386
+ or "host_ra_units" not in h
387
+ or "host_dec_units" not in h
388
+ ):
389
+ continue
390
+
391
+ # now we can construct a host object from this
392
+ host.append(Host(transient_name=self.default_name, **dict(h)))
393
+
394
+ # then try BLAST
395
+ if search:
396
+ logger.warning(
397
+ "Trying to find a host with BLAST/astro-ghost. Note\
398
+ that this won't work for older targets! See https://blast.scimma.org"
399
+ )
400
+
401
+ # default_name should always be the TNS name if we have one
402
+ logger.debug(self.default_name)
403
+ blast_host = Host.query_blast(self.default_name.replace(" ", ""))
404
+ logger.debug(blast_host)
405
+ if blast_host is not None:
406
+ host.append(blast_host)
407
+
408
+ return host
409
+
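A sketch of get_host on a toy record. The Host keyword arguments mirror the host_* keys consumed above; search=True would additionally query the BLAST web service:

from otter.io.transient import Transient

t = Transient({
    "name": {"default_name": "2020example"},
    "host": [{"host_ra": 215.0, "host_dec": 42.5,
              "host_ra_units": "deg", "host_dec_units": "deg",
              "host_name": "Example Galaxy"}],
})
hosts = t.get_host()    # [Host(...)]; pass search=True to also try BLAST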
410
+ def _get_default(self, key, filt=None):
411
+ """
412
+ Get the default of key
413
+
414
+ Args:
415
+ key [str]: key in self to look for the default of
416
+ filt [str]: a valid pandas dataframe filter to index a pandas dataframe
417
+ called df.
418
+ """
419
+ if key not in self:
420
+ raise KeyError(f"This transient does not have {key} associated with it!")
421
+
422
+ df = pd.DataFrame(self[key])
423
+ if len(df) == 0:
424
+ raise KeyError(f"This transient does not have {key} associated with it!")
425
+
426
+ if filt is not None:
427
+ df = df[eval(filt)] # apply the filters
428
+
429
+ if "default" in df:
430
+ # first try to get the default
431
+ df_filtered = df[df.default == True]
432
+ if len(df_filtered) == 0:
433
+ df_filtered = df
434
+ else:
435
+ df_filtered = df
436
+
437
+ if len(df_filtered) == 0:
438
+ return None
439
+
440
+ return df_filtered.iloc[0]
441
+
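The filt argument must be a pandas boolean expression written against a DataFrame that is literally named df inside this helper. Reusing the toy record from the getter sketch above:

row = t._get_default("distance", filt="df['distance_type'] == 'redshift'")
row["value"]    # 0.05, the default redshift entry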
442
+ def _reformat_coordinate(self, item):
443
+ """
444
+ Reformat the coordinate information in item
445
+ """
446
+ coordin = None
447
+ if "ra" in item and "dec" in item:
448
+ # this is an equatorial coordinate
449
+ coordin = {
450
+ "ra": item["ra"],
451
+ "dec": item["dec"],
452
+ "unit": (item["ra_units"], item["dec_units"]),
453
+ }
454
+ elif "l" in item and "b" in item:
455
+ coordin = {
456
+ "l": item["l"],
457
+ "b": item["b"],
458
+ "unit": (item["l_units"], item["b_units"]),
459
+ "frame": "galactic",
460
+ }
461
+
462
+ return coordin
463
+
464
+ def clean_photometry(
465
+ self,
466
+ flux_unit: u.Unit = "mag(AB)",
467
+ date_unit: u.Unit = "MJD",
468
+ freq_unit: u.Unit = "GHz",
469
+ wave_unit: u.Unit = "nm",
470
+ obs_type: str = None,
471
+ deduplicate: Callable | None = None,
472
+ ) -> pd.DataFrame:
473
+ """
474
+ Ensure the photometry associated with this transient is all in the same
475
+ units/system/etc
476
+
477
+ Args:
478
+ flux_unit (astropy.unit.Unit): Either a valid string to convert
479
+ or an astropy.unit.Unit, this can be either
480
+ flux, flux density, or magnitude unit. This
481
+ supports any base units supported by
482
+ synphot
483
+ (https://synphot.readthedocs.io/en/latest/synphot/units.html#flux-units).
484
+ date_unit (str): Valid astropy date format string. See
485
+ https://docs.astropy.org/en/stable/time/index.html#time-format
486
+ freq_unit (astropy.unit.Unit): The astropy unit or string representation of
487
+ an astropy unit to convert and return the
488
+ frequency as. Must have a base unit of
489
+ 1/time (Hz).
490
+ wave_unit (astropy.unit.Unit): The astropy unit or string representation of
491
+ an astropy unit to convert and return the
492
+ wavelength as. Must have a base unit of
493
+ length.
494
+ obs_type (str): "radio", "xray", or "uvoir". If provided, it only returns
495
+ data taken within that range of wavelengths/frequencies.
496
+ Default is None which will return all of the data.
497
+ deduplicate (Callable|None): A function to be used to remove duplicate
498
+ reductions of the same data that produces
499
+ different flux values. The default is the
500
+ otter.deduplicate_photometry method,
501
+ but you can pass
502
+ any callable that takes the output pandas
503
+ dataframe as input. Set this to False if you
504
+ don't want deduplication to occur.
505
+ Returns:
506
+ A pandas DataFrame of the cleaned up photometry in the requested units
507
+ """
508
+ if deduplicate is None:
509
+ deduplicate = self.deduplicate_photometry
510
+
511
+ warn_filt = _DuplicateFilter()
512
+ logger.addFilter(warn_filt)
513
+
514
+ # these imports need to be here for some reason
515
+ # otherwise the code breaks
516
+ from synphot.units import VEGAMAG, convert_flux
517
+ from synphot.spectrum import SourceSpectrum
518
+
519
+ # variable so this warning only displays a single time each time this
520
+ # function is called
521
+ source_map_warning = True
522
+
523
+ # turn the photometry key into a pandas dataframe
524
+ if "photometry" not in self:
525
+ raise FailedQueryError("No photometry for this object!")
526
+
527
+ dfs = []
528
+ for item in self["photometry"]:
529
+ max_len = 1  # keep at least one row even if every field is a scalar
530
+ for key, val in item.items():
531
+ if isinstance(val, list) and key != "reference":
532
+ max_len = max(max_len, len(val))
533
+
534
+ for key, val in item.items():
535
+ if not isinstance(val, list) or (
536
+ isinstance(val, list) and len(val) != max_len
537
+ ):
538
+ item[key] = [val] * max_len
539
+
540
+ df = pd.DataFrame(item)
541
+ dfs.append(df)
542
+
543
+ if len(dfs) == 0:
544
+ raise FailedQueryError("No photometry for this object!")
545
+ c = pd.concat(dfs)
546
+
547
+ # extract the filter information and substitute in any missing columns
548
+ # because of how we handle this later, we just need to make sure the effective
549
+ # wavelengths are never nan
550
+ def fill_wave(row):
551
+ if "wave_eff" not in row or (
552
+ pd.isna(row.wave_eff) and not pd.isna(row.freq_eff)
553
+ ):
554
+ freq_eff = row.freq_eff * u.Unit(row.freq_units)
555
+ wave_eff = freq_eff.to(u.Unit(wave_unit), equivalencies=u.spectral())
556
+ return wave_eff.value, wave_unit
557
+ elif not pd.isna(row.wave_eff):
558
+ return row.wave_eff, row.wave_units
559
+ else:
560
+ raise ValueError("Missing frequency or wavelength information!")
561
+
562
+ filters = pd.DataFrame(self["filter_alias"])
563
+ res = filters.apply(fill_wave, axis=1)
564
+ filters["wave_eff"], filters["wave_units"] = zip(*res)
565
+ # merge the photometry with the filter information
566
+ df = c.merge(filters, on="filter_key")
567
+
568
+ # drop irrelevant obs_types before continuing
569
+ if obs_type is not None:
570
+ valid_obs_types = {"radio", "uvoir", "xray"}
571
+ if obs_type not in valid_obs_types:
572
+ raise IOError("Please provide a valid obs_type")
573
+ df = df[df.obs_type == obs_type]
574
+
575
+ # add some mockup columns if they don't exist
576
+ if "value" not in df:
577
+ df["value"] = np.nan
578
+ df["value_err"] = np.nan
579
+ df["value_units"] = "NaN"
580
+
581
+ # fix some bad units that are old and no longer recognized by astropy
582
+ with warnings.catch_warnings():
583
+ warnings.filterwarnings("ignore")
584
+ df.raw_units = df.raw_units.str.replace("ergs", "erg")
585
+ df.raw_units = ["mag(AB)" if uu == "AB" else uu for uu in df.raw_units]
586
+ df.value_units = df.value_units.str.replace("ergs", "erg")
587
+ df.value_units = ["mag(AB)" if uu == "AB" else uu for uu in df.value_units]
588
+
589
+ # merge the raw and value keywords based on the requested flux_units
590
+ # first take everything that just has `raw` and not `value`
591
+ df_raw_only = df[df.value.isna()]
592
+ remaining = df[df.value.notna()]
593
+ if len(remaining) == 0:
594
+ df_raw = df_raw_only
595
+ df_value = [] # this tricks the code later
596
+ else:
597
+ # then take the remaining rows and figure out if we want the raw or value
598
+ with warnings.catch_warnings():
599
+ warnings.filterwarnings("ignore")
600
+ flux_unit_astropy = u.Unit(flux_unit)
601
+
602
+ val_unit_filt = np.array(
603
+ [
604
+ u.Unit(uu).is_equivalent(flux_unit_astropy)
605
+ for uu in remaining.value_units
606
+ ]
607
+ )
608
+
609
+ df_value = remaining[val_unit_filt]
610
+ df_raw_and_value = remaining[~val_unit_filt]
611
+
612
+ # then merge the raw dataframes
613
+ df_raw = pd.concat([df_raw_only, df_raw_and_value], axis=0)
614
+
615
+ # then add columns to these dataframes to convert stuff later
616
+ df_raw = df_raw.assign(
617
+ _flux=df_raw["raw"].values,
618
+ _flux_units=df_raw["raw_units"].values,
619
+ _flux_err=(
620
+ df_raw["raw_err"].values
621
+ if "raw_err" in df_raw
622
+ else [np.nan] * len(df_raw)
623
+ ),
624
+ )
625
+
626
+ if len(df_value) == 0:
627
+ df = df_raw
628
+ else:
629
+ df_value = df_value.assign(
630
+ _flux=df_value["value"].values,
631
+ _flux_units=df_value["value_units"].values,
632
+ _flux_err=(
633
+ df_value["value_err"].values
634
+ if "value_err" in df_value
635
+ else [np.nan] * len(df_value)
636
+ ),
637
+ )
638
+
639
+ # then merge df_value and df_raw back into one df
640
+ df = pd.concat([df_raw, df_value], axis=0)
641
+
642
+ # then, for the rest of the code to work, set the "by" variables to _flux
643
+ by = "_flux"
644
+
645
+ # skip rows where 'by' is nan
646
+ df = df[df[by].notna()]
647
+
648
+ # filter out anything that has _flux_units == "ct" because we can't convert that
649
+ try:
650
+ # this is a test case to see if we can convert ct -> flux_unit
651
+ convert_flux(
652
+ [1 * u.nm, 2 * u.nm], 1 * u.ct, u.Unit(flux_unit), area=1 * u.m**2
653
+ )
654
+ except u.UnitsError:
655
+ bad_units = df[df._flux_units == "ct"]
656
+ if len(bad_units) > 0:
657
+ logger.warning(
658
+ f"""Removing {len(bad_units)} photometry points from
659
+ {self.default_name} because we can't convert them from ct ->
660
+ {flux_unit}"""
661
+ )
662
+ df = df[df._flux_units != "ct"]
663
+
664
+ # convert the ads bibcodes to a string of human readable sources here
665
+ def mappedrefs(row):
666
+ if isinstance(row.reference, list):
667
+ return "<br>".join([self.srcmap[bibcode] for bibcode in row.reference])
668
+ else:
669
+ return self.srcmap[row.reference]
670
+
671
+ try:
672
+ df["human_readable_refs"] = df.apply(mappedrefs, axis=1)
673
+ except Exception as exc:
674
+ if source_map_warning:
675
+ source_map_warning = False
676
+ logger.warning(f"Unable to apply the source mapping because {exc}")
677
+
678
+ df["human_readable_refs"] = df.reference
679
+
680
+ # Figure out what columns are good to groupby in the photometry
681
+ outdata = []
682
+
683
+ if "telescope" in df:
684
+ tele = True
685
+ to_grp_by = ["obs_type", by + "_units", "telescope"]
686
+ else:
687
+ tele = False
688
+ to_grp_by = ["obs_type", by + "_units"]
689
+
690
+ # Do the conversion based on what we decided to group by
691
+ for groupedby, data in df.groupby(to_grp_by, dropna=False):
692
+ if tele:
693
+ obstype, unit, telescope = groupedby
694
+ else:
695
+ obstype, unit = groupedby
696
+ telescope = None
697
+
698
+ # get the photometry in the right type
699
+ unit = data[by + "_units"].unique()
700
+ if len(unit) > 1:
701
+ raise OtterLimitationError(
702
+ "Can not apply multiple units for different obs_types"
703
+ )
704
+
705
+ unit = unit[0]
706
+ isvegamag = "vega" in unit.lower()
707
+ try:
708
+ if isvegamag:
709
+ astropy_units = VEGAMAG
710
+ elif unit == "AB":
711
+ # In astropy "AB" is a magnitude SYSTEM not unit and while
712
+ # u.Unit("AB") will succeed without error, it will not produce
713
+ # the expected result!
714
+ # We can assume here that this unit really means astropy's "mag(AB)"
715
+ astropy_units = u.Unit("mag(AB)")
716
+ else:
717
+ with warnings.catch_warnings():
718
+ warnings.simplefilter("ignore")
719
+ astropy_units = u.Unit(unit)
720
+
721
+ except ValueError:
722
+ # this means there is something likely slightly off in the input unit
723
+ # string. Let's try to fix it!
724
+ # here are some common mistakes
725
+ unit = unit.replace("ergs", "erg")
726
+ unit = unit.replace("AB", "mag(AB)")
727
+
+ # retry the conversion; if it still fails, surface a clearer error
+ try:
+ astropy_units = u.Unit(unit)
+ except ValueError:
+ raise ValueError(
+ "Could not coerce your string into astropy unit format!"
+ )
734
+
735
+ # get the flux data and find the type
736
+ indata = np.array(data[by].astype(float))
737
+ err_key = by + "_err"
738
+ if err_key in data:
739
+ indata_err = np.array(data[by + "_err"].astype(float))
740
+ else:
741
+ indata_err = np.zeros(len(data))
742
+
743
+ # convert to an astropy quantity
744
+ with warnings.catch_warnings():
745
+ warnings.filterwarnings("ignore")
746
+ q = indata * u.Unit(astropy_units)
747
+ q_err = indata_err * u.Unit(
748
+ astropy_units
749
+ ) # assume error and values have the same unit
750
+
751
+ # get and save the effective wavelength
752
+ # because of cleaning we did to the filter dataframe above wave_eff
753
+ # should NEVER be nan!
754
+ if np.any(pd.isna(data["wave_eff"])):
755
+ raise ValueError("Flushing out the effective wavelength array failed!")
756
+
757
+ zz = zip(data["wave_eff"], data["wave_units"])
758
+ with warnings.catch_warnings():
759
+ warnings.filterwarnings("ignore")
760
+ wave_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], wave_unit)
761
+ freq_eff = wave_eff.to(freq_unit, equivalencies=u.spectral())
762
+
763
+ data["converted_wave"] = wave_eff.value
764
+ data["converted_wave_unit"] = wave_unit
765
+ data["converted_freq"] = freq_eff.value
766
+ data["converted_freq_unit"] = freq_unit
767
+
768
+ # convert using synphot
769
+ # stuff has to be done slightly differently for xray than for the others
770
+ if obstype == "xray":
771
+ if telescope is not None:
772
+ try:
773
+ area = XRAY_AREAS[telescope.lower()]
774
+ except KeyError:
775
+ raise OtterLimitationError(
776
+ "Did not find an area corresponding to "
777
+ + "this telescope, please add to util!"
778
+ )
779
+ else:
780
+ raise OtterLimitationError(
781
+ "Can not convert x-ray data without a telescope"
782
+ )
783
+
784
+ # we also need to make this wave_min and wave_max
785
+ # instead of just the effective wavelength like for radio and uvoir
786
+ zz = zip(data["wave_min"], data["wave_max"], data["wave_units"])
787
+ with warnings.catch_warnings():
788
+ warnings.filterwarnings("ignore")
789
+ wave_eff = u.Quantity(
790
+ [np.array([m, M]) * u.Unit(uu) for m, M, uu in zz],
791
+ u.Unit(wave_unit),
792
+ )
793
+
794
+ else:
795
+ area = None
796
+
797
+ if obstype == "xray" or isvegamag:
798
+ # we unfortunately have to loop over the points here because
799
+ # synphot does not work with a 2D array of min/max wavelengths
800
+ # for converting counts to other flux units. It also can't convert
801
+ # vega mags with a wavelength array because it interprets that as the
802
+ # wavelengths corresponding to the SourceSpectrum.from_vega()
803
+
804
+ flux, flux_err = [], []
805
+ for wave, xray_point, xray_point_err in zip(wave_eff, q, q_err):
806
+ with warnings.catch_warnings():
807
+ warnings.filterwarnings("ignore")
808
+ f_val = convert_flux(
809
+ wave,
810
+ xray_point,
811
+ u.Unit(flux_unit),
812
+ vegaspec=SourceSpectrum.from_vega(),
813
+ area=area,
814
+ ).value
815
+
816
+ # approximate the uncertainty as dX = dY/Y * X
817
+ f_err = np.multiply(
818
+ f_val, np.divide(xray_point_err.value, xray_point.value)
819
+ )
820
+
821
+ # then we take the average of the minimum and maximum values
822
+ # computed by synphot
823
+ flux.append(np.mean(f_val))
824
+ flux_err.append(np.mean(f_err))
825
+
826
+ else:
827
+ # this will be faster and cover most cases
828
+ with warnings.catch_warnings():
829
+ warnings.filterwarnings("ignore")
830
+ flux = convert_flux(wave_eff, q, u.Unit(flux_unit)).value
831
+
832
+ # since the error propagation is different between logarithmic units
833
+ # and linear units, unfortunately
834
+ if isinstance(u.Unit(flux_unit), u.LogUnit):
835
+ # approximate the uncertainty as dX = dY/Y * |ln(10)/2.5|
836
+ prefactor = np.abs(np.log(10) / 2.5) # this is basically 1
837
+ else:
838
+ # approximate the uncertainty as dX = dY/Y * X
839
+ prefactor = flux
840
+
841
+ flux_err = np.multiply(prefactor, np.divide(q_err.value, q.value))
842
+
843
+ flux = np.array(flux) * u.Unit(flux_unit)
844
+ flux_err = np.array(flux_err) * u.Unit(flux_unit)
845
+
846
+ data["converted_flux"] = flux.value
847
+ data["converted_flux_err"] = flux_err.value
848
+ outdata.append(data)
849
+
850
+ if len(outdata) == 0:
851
+ raise FailedQueryError()
852
+ outdata = pd.concat(outdata)
853
+
854
+ # copy over the flux units
855
+ outdata["converted_flux_unit"] = flux_unit
856
+
857
+ # make sure all the datetimes are in the same format here too!!
858
+ times = [
859
+ Time(d, format=f).to_value(date_unit.lower())
860
+ for d, f in zip(outdata.date, outdata.date_format.str.lower())
861
+ ]
862
+ outdata["converted_date"] = times
863
+ outdata["converted_date_unit"] = date_unit
864
+
865
+ # compute the upperlimit value based on a 3 sigma detection
866
+ # this is just for rows where we don't already know if it is an upperlimit
867
+ if isinstance(u.Unit(flux_unit), u.LogUnit):
868
+ # this uses the following formula (which is surprising because it means
869
+ # magnitude upperlimits are independent of the actual measurement!)
870
+ # sigma_m > (1/3) * (ln(10)/2.5)
871
+ def is_upperlimit(row):
872
+ if "upperlimit" in row and pd.isna(row.upperlimit):
873
+ return row.converted_flux_err > np.log(10) / (3 * 2.5)
874
+ else:
875
+ return row.upperlimit
876
+ else:
877
+
878
+ def is_upperlimit(row):
879
+ if "upperlimit" in row and pd.isna(row.upperlimit):
880
+ return row.converted_flux < 3 * row.converted_flux_err
881
+ else:
882
+ return row.upperlimit
883
+
884
+ outdata["upperlimit"] = outdata.apply(is_upperlimit, axis=1)
885
+
886
+ # perform some more complex deduplication of the dataset
887
+ if deduplicate:
888
+ outdata = deduplicate(outdata)
889
+
890
+ # throw a warning if the output dataframe has UV/Optical/IR or Radio data
891
+ # where we don't know if the dataset has been host corrected or not
892
+ if ("corr_host" not in outdata) or (
893
+ len(outdata[pd.isna(outdata.corr_host) * (outdata.obs_type != "xray")]) > 0
894
+ ):
895
+ logger.warning(
896
+ f"{self.default_name} has at least one photometry point where it is "
897
+ + "unclear if a host subtraction was performed. This can be especially "
898
+ + "detrimental for UV data. Please consider filtering out UV/Optical/IR"
899
+ + " or radio rows where the corr_host column is null/None/NaN."
900
+ )
901
+
902
+ logger.removeFilter(warn_filt)
903
+ return outdata
904
+
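A hedged sketch of pulling a uniform light curve out of a populated record; it assumes `transient` came from an OTTER query and only uses the parameters and converted_* columns defined above:

import astropy.units as u

def radio_light_curve(transient):
    """Cleaned radio photometry of the given transient, in mJy versus MJD."""
    phot = transient.clean_photometry(
        flux_unit=u.mJy, date_unit="MJD", freq_unit="GHz", obs_type="radio"
    )
    return phot[["converted_date", "converted_freq", "converted_flux",
                 "converted_flux_err", "upperlimit"]]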
905
+ @classmethod
906
+ def deduplicate_photometry(cls, phot: pd.DataFrame, date_tol: int | float = 1):
907
+ """
908
+ This deduplicates a pandas dataframe of photometry that could potentially
909
+ have rows/datasets that are the result of different reductions of the same
910
+ data. This is especially relevant for X-ray and UV observations where different
911
+ reductions can produce different flux values from the same observation.
912
+
913
+ The algorithm used here first finds duplicates by normalizing the telescope
914
+ names, then grouping the dataframe by the normalized telescope name,
915
+ filter_key, and obs_type. It then assumes that data from the same
916
+ reference will not produce duplicated data. Finally, it finds the overlapping
917
+ regions within date +/- date_tol (or between date_min and date_max for binned
918
+ data), and uses any data within that region as duplicated. From there, it
919
+ first tries to choose the reduction that is host subtracted (if only one is
920
+ host subtracted), then if neither or more than one of the datasets are host
921
+ subtracted then it just takes the most recent reduction.
922
+
923
+ Args:
924
+ phot (pd.DataFrame): A pandas dataframe of the photometry with keys defined
925
+ by the OTTER schema
926
+ date_tol (int|float): The default tolerance (or "uncertainty") to use on the
927
+ dates in the "date" column of phot. In days. Defaults
928
+ to 1 day.
929
+ """
930
+ # we need to reset the index to keep track of things appropriately
931
+ phot = phot.reset_index(drop=True)
932
+
933
+ if "telescope" not in phot:
934
+ phot["telescope"] = np.nan
935
+
936
+ # we first have to standardize some columns given some basic assumptions
937
+ phot["_ref_str"] = phot.reference.astype(str)
938
+
939
+ # normalize the telescope name so we can group by it
940
+ phot["_norm_tele_name"] = phot.telescope.apply(cls._normalize_tele_name)
941
+
942
+ # now find the duplicated data
943
+ dups = []
944
+ phot_grpby = phot.groupby(
945
+ ["_norm_tele_name", "filter_key", "obs_type"], dropna=False
946
+ )
947
+ for (tele, filter_key, obs_type), grp in phot_grpby:
948
+ # by definition, there can only be dups if the name, telescope, and filter
949
+ # are the same
950
+
951
+ # if there is only one reference in this group of data, there's no way
952
+ # there are duplicate reductions of the same dataset
953
+ if len(grp._ref_str.unique()) <= 1:
954
+ continue
955
+
956
+ # the next trick is that the dates don't need to be the same, but need to
957
+ # fall inside the same range
958
+ grp["_mean_dates"] = grp.apply(cls._convert_dates, axis=1)
959
+
960
+ if "date_min" in grp and not np.all(pd.isna(grp.date_min)):
961
+ grp["min_dates"] = grp.apply(
962
+ lambda row: cls._convert_dates(row, date_key="date_min"), axis=1
963
+ ).astype(float)
964
+ grp["max_dates"] = grp.apply(
965
+ lambda row: cls._convert_dates(row, date_key="date_max"), axis=1
966
+ ).astype(float)
967
+
968
+ # in case any of the min_date and max_date in the grp are nan
969
+ grp.fillna(
970
+ {
971
+ "min_dates": grp._mean_dates - date_tol,
972
+ "max_dates": grp._mean_dates + date_tol,
973
+ },
974
+ inplace=True,
975
+ )
976
+
977
+ elif "date_err" in grp and not np.any(pd.isna(grp.date_err)):
978
+ grp["min_dates"] = (grp._mean_dates - grp.date_err).astype(float)
979
+ grp["max_dates"] = (grp._mean_dates + grp.date_err).astype(float)
980
+ else:
981
+ # then assume some uncertainty on the date
982
+ grp["min_dates"] = (grp._mean_dates - date_tol).astype(float)
983
+ grp["max_dates"] = (grp._mean_dates + date_tol).astype(float)
984
+
985
+ ref_ranges = [
986
+ (subgrp.min_dates.min(), subgrp.max_dates.max())
987
+ for _, subgrp in grp.groupby("_ref_str")
988
+ ]
989
+
990
+ overlaps = cls._find_overlapping_regions(ref_ranges)
991
+
992
+ if len(overlaps) == 0:
993
+ continue # then there are no dups
994
+
995
+ for min_overlap, max_overlap in overlaps:
996
+ dup_data = grp[
997
+ (grp.min_dates >= min_overlap) * (grp.max_dates <= max_overlap)
998
+ ]
999
+
1000
+ if len(dup_data) == 0:
1001
+ continue # no data falls in this range!
1002
+
1003
+ dups.append(dup_data)
1004
+
1005
+ # now that we've found the duplicated datasets, we can iterate through them
1006
+ # and choose the "default"
1007
+ phot_res = deepcopy(phot)
1008
+ undupd = []
1009
+ for dup in dups:
1010
+ try:
1011
+ phot_res = phot_res.drop(dup.index) # we'll append back in the non dup
1012
+ except KeyError:
1013
+ continue # we already deleted these ones
1014
+
1015
+ # first, check if only one of the dup reductions host subtracted
1016
+ if "corr_host" in dup:
1017
+ dup_host_corr = dup[dup.corr_host.astype(bool)]
1018
+ host_corr_refs = dup_host_corr.human_readable_refs.unique()
1019
+ if len(host_corr_refs) == 1:
1020
+ # then one of the reductions is host corrected and the other isn't!
1021
+ undupd.append(dup[dup.human_readable_refs == host_corr_refs[0]])
1022
+ continue
1023
+
1024
+ bibcodes_sorted_by_year = sorted(dup._ref_str.unique(), key=cls._find_year)
1025
+ dataset_to_use = dup[dup._ref_str == bibcodes_sorted_by_year[-1]]  # most recent reduction
1026
+ undupd.append(dataset_to_use)
1027
+
1028
+ # then return the full photometry dataset but with the dups removed!
1029
+ return pd.concat([phot_res] + undupd).reset_index()
1030
+
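A small end-to-end check of the deduplication path (placeholder bibcodes): two reductions of the same Swift/UVOT epoch from different papers, where only the later one is host subtracted, so only that row survives:

import pandas as pd
from otter.io.transient import Transient

phot = pd.DataFrame({
    "reference": ["2019MNRAS.000....1A", "2021ApJ...000....2B"],
    "human_readable_refs": ["Author+19", "Author+21"],
    "telescope": ["Swift/UVOT", "Swift"],
    "filter_key": ["UVW1", "UVW1"],
    "obs_type": ["uvoir", "uvoir"],
    "date": [58800.0, 58800.0],
    "date_format": ["mjd", "mjd"],
    "raw": [18.2, 18.5],
    "raw_units": ["mag(AB)", "mag(AB)"],
    "corr_host": [False, True],
})

deduped = Transient.deduplicate_photometry(phot, date_tol=1)
# -> a single row: the host-subtracted Author+21 reduction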
1031
+ @staticmethod
1032
+ def _normalize_tele_name(tele_name):
1033
+ if pd.isna(tele_name):
1034
+ return tele_name
1035
+
1036
+ common_delims = ["-", "/", " ", "."]
1037
+ for delim in common_delims:
1038
+ tele_name = tele_name.replace(delim, ":*:")
1039
+
1040
+ # this assumes that the telescope name will almost always be first,
1041
+ # before other delimiters
1042
+ return tele_name.split(":*:")[0].lower()
1043
+
1044
+ @staticmethod
1045
+ def _convert_dates(row, date_key="date"):
1046
+ """Make sure the dates are in MJD"""
1047
+ if pd.isna(row[date_key]):
1048
+ return row[date_key]
1049
+
1050
+ return Time(row[date_key], format=row.date_format.lower()).mjd
1051
+
1052
+ @staticmethod
1053
+ def _find_overlapping_regions(intervals):
1054
+ """Find the overlaps in a list of tuples of mins and maxs. This is relatively
1055
+ inefficient but the len(intervals) should be < 10 so it should be fine"""
1056
+ overlap_ranges = []
1057
+ for ii, (start_ii, end_ii) in enumerate(intervals):
1058
+ for jj, (start_jj, end_jj) in enumerate(intervals):
1059
+ if ii <= jj:
1060
+ continue
1061
+
1062
+ if start_ii > start_jj:
1063
+ start = start_ii
1064
+ else:
1065
+ start = start_jj
1066
+
1067
+ if end_ii > end_jj:
1068
+ end = end_jj
1069
+ else:
1070
+ end = end_ii
1071
+
1072
+ if start < end:
1073
+ # then there is an overlap!
1074
+ overlap_ranges.append((start, end))
1075
+
1076
+ return overlap_ranges
1077
+
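For example (the intervals are (min, max) MJD spans, one per reference):

Transient._find_overlapping_regions([(0.0, 5.0), (3.0, 8.0), (10.0, 12.0)])
# -> [(3.0, 5.0)], the only pairwise overlap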
1078
+ @staticmethod
1079
+ def _find_year(s):
1080
+ match = re.search(r"\d{4}", s)
1081
+ return int(match.group()) if match else float("inf")
1082
+
1083
+ def _merge_names(t1, t2, out): # noqa: N805
1084
+ """
1085
+ Private method to merge the name data in t1 and t2 and put it in out
1086
+ """
1087
+ key = "name"
1088
+ out[key] = {}
1089
+
1090
+ # first deal with the default_name key
1091
+ # we are gonna need to use some regex magic to choose a preferred default_name
1092
+ if t1[key]["default_name"] == t2[key]["default_name"]:
1093
+ out[key]["default_name"] = t1[key]["default_name"]
1094
+ else:
1095
+ # we need to decide which default_name is better
1096
+ # it should be the one that matches the TNS style
1097
+ # let's use regex
1098
+ n1 = t1[key]["default_name"]
1099
+ n2 = t2[key]["default_name"]
1100
+
1101
+ # write some discriminating regex expressions
1102
+ # exp1: starts with a number, this is preferred because it is TNS style
1103
+ exp1 = "^[0-9]"
1104
+ # exp2: ends with any character (a weak check that nearly every name passes)
1105
+ exp2 = ".$"
1106
+ # exp3: checks if the name starts with at least three digits (e.g. a year),
1108
+ # this is pretty strict though
1108
+ exp3 = "^[0-9]{3}"
1109
+ # exp4: # checks if it starts with AT like TNS names
1110
+ exp4 = "^AT"
1111
+
1112
+ # combine all the regex expressions, this makes it easier to add more later
1113
+ exps = [exp1, exp2, exp3, exp4]
1114
+
1115
+ # score each default_name based on this
1116
+ score1 = 0
1117
+ score2 = 0
1118
+ for e in exps:
1119
+ re1 = re.findall(e, n1)
1120
+ re2 = re.findall(e, n2)
1121
+ if re1:
1122
+ score1 += 1
1123
+ if re2:
1124
+ score2 += 1
1125
+
1126
+ # assign a default_name based on the score
1127
+ if score1 > score2:
1128
+ out[key]["default_name"] = t1[key]["default_name"]
1129
+ elif score2 > score1:
1130
+ out[key]["default_name"] = t2[key]["default_name"]
1131
+ else:
1132
+ logger.warning(
1133
+ "Names have the same score! Just using the existing default_name"
1134
+ )
1135
+ out[key]["default_name"] = t1[key]["default_name"]
1136
+
1137
+ # now deal with aliases
1138
+ # create a reference mapping for each
1139
+ t1map = {}
1140
+ for val in t1[key]["alias"]:
1141
+ ref = val["reference"]
1142
+ t1map[val["value"]] = [ref] if isinstance(ref, str) else list(ref)
1146
+
1147
+ t2map = {}
1148
+ for val in t2[key]["alias"]:
1149
+ ref = val["reference"]
1150
+ t2map[val["value"]] = [ref] if isinstance(ref, str) else list(ref)
1154
+
1155
+ # figure out which ones we need to be careful with references in
1156
+ inboth = list(
1157
+ t1map.keys() & t2map.keys()
1158
+ ) # in both so we'll have to merge the reference key
1159
+ int1 = list(t1map.keys() - t2map.keys()) # only in t1
1160
+ int2 = list(t2map.keys() - t1map.keys()) # only in t2
1161
+
1162
+ # add ones that are not in both first, these are easy
1163
+ line1 = [{"value": k, "reference": t1map[k]} for k in int1]
1164
+ line2 = [{"value": k, "reference": t2map[k]} for k in int2]
1165
+ bothlines = [{"value": k, "reference": t1map[k] + t2map[k]} for k in inboth]
1166
+ out[key]["alias"] = line2 + line1 + bothlines
1167
+
1168
+ def _merge_filter_alias(t1, t2, out): # noqa: N805
1169
+ """
1170
+ Combine the filter alias lists across the transient objects
1171
+ """
1172
+
1173
+ key = "filter_alias"
1174
+
1175
+ out[key] = deepcopy(t1[key])
1176
+ keys1 = {filt["filter_key"] for filt in t1[key]}
1177
+ for filt in t2[key]:
1178
+ if filt["filter_key"] not in keys1:
1179
+ out[key].append(filt)
1180
+
1181
+ def _merge_schema_version(t1, t2, out): # noqa: N805
1182
+ """
1183
+ Just keep whichever schema version is greater
1184
+ """
1185
+ key = "schema_version/value"
1186
+ if "comment" not in t1["schema_version"]:
1187
+ t1["schema_version/comment"] = ""
1188
+
1189
+ if "comment" not in t2["schema_version"]:
1190
+ t2["schema_version/comment"] = ""
1191
+
1192
+ if key in t1 and key in t2 and int(t1[key]) > int(t2[key]):
1193
+ out["schema_version"] = deepcopy(t1["schema_version"])
1194
+ else:
1195
+ out["schema_version"] = deepcopy(t2["schema_version"])
1196
+
1197
+ out["schema_version"]["comment"] = (
1198
+ t1["schema_version/comment"] + ";" + t2["schema_version/comment"]
1199
+ )
1200
+
1201
+ def _merge_photometry(t1, t2, out): # noqa: N805
1202
+ """
1203
+ Combine photometry sources
1204
+ """
1205
+
1206
+ key = "photometry"
1207
+
1208
+ out[key] = deepcopy(t1[key])
1209
+ # one reference entry per photometry dict, kept index-aligned with out[key]
+ refs = [d["reference"] for d in out[key]]
1218
+
1219
+ for val in t2[key]:
1220
+ # first check if t2's reference is in out
1221
+ if val["reference"] not in refs:
1222
+ # it's not here so we can just append the new photometry!
1223
+ out[key].append(val)
1224
+ else:
1225
+ # we need to merge it with other photometry
1226
+ i1 = np.where(val["reference"] == refs)[0][0]
1227
+ df1 = pd.DataFrame(out[key][i1])
1228
+ df2 = pd.DataFrame(val)
1229
+
1230
+ # only substitute in values that are nan in df1 or new
1231
+ # the combined keys of the two
1232
+ mergeon = list(set(df1.keys()) & set(df2.keys()))
1233
+ df = df1.merge(df2, on=mergeon, how="outer")
1234
+ # convert to a dictionary
1235
+ newdict = df.reset_index().to_dict(orient="list")
1236
+ del newdict["index"]
1237
+
1238
+ newdict["reference"] = newdict["reference"][0]
1239
+
1240
+ out[key][i1] = newdict # replace the dictionary at i1 with the new dict
1241
+
1242
+ def _merge_class(t1, t2, out): # noqa: N805
1243
+ """
1244
+ Combine the classification attribute
1245
+ """
1246
+ key = "classification"
1247
+ subkey = "value"
1248
+ out[key] = deepcopy(t1[key])
1249
+ classes = np.array([item["object_class"] for item in out[key][subkey]])
1250
+
1251
+ for item in t2[key][subkey]:
1252
+ if item["object_class"] in classes:
1253
+ i = np.where(item["object_class"] == classes)[0][0]
1254
+ if int(item["confidence"]) > int(out[key][subkey][i]["confidence"]):
1255
+ out[key][subkey][i]["confidence"] = item[
1256
+ "confidence"
1257
+ ] # we are now more confident
1258
+
1259
+ if not isinstance(out[key][subkey][i]["reference"], list):
1260
+ out[key][subkey][i]["reference"] = [
1261
+ out[key][subkey][i]["reference"]
1262
+ ]
1263
+
1264
+ if not isinstance(item["reference"], list):
1265
+ item["reference"] = [item["reference"]]
1266
+
1267
+ newdata = list(
1268
+ np.unique(out[key][subkey][i]["reference"] + item["reference"])
1269
+ )
1270
+ out[key][subkey][i]["reference"] = newdata
1271
+
1272
+ else:
1273
+ out[key][subkey].append(item)
1274
+
1275
+ # now that we have all of them we need to figure out which one is the default
1276
+ maxconf = max(out[key][subkey], key=lambda d: d["confidence"])
1277
+ for item in out[key][subkey]:
1278
+ if item == maxconf:
1279
+ item["default"] = True
1280
+ else:
1281
+ item["default"] = False
1282
+
1283
+ # then rederive the classification flags
1284
+ out = Transient._derive_classification_flags(out)
1285
+
1286
+ @classmethod
1287
+ def _derive_classification_flags(cls, out):
1288
+ """
1289
+ Derive the classification flags based on the confidence flags. This will find
1290
+ - spec_classed
1291
+ - unambiguous
1292
+
1293
+ See the paper for a detailed description of how this algorithm makes its
1294
+ choices
1295
+ """
1296
+
1297
+ if "classification" not in out or "value" not in out["classification"]:
1298
+ # this means that the transient doesn't have any classifications
1299
+ # just return itself without any changes
1300
+ return out
1301
+
1302
+ # get the confidences of all of the classifications of this transient
1303
+ confs = np.array(
1304
+ [item["confidence"] for item in out["classification"]["value"]]
1305
+ ).astype(float)
1306
+
1307
+ all_class_roots = np.array(
1308
+ [
1309
+ _fuzzy_class_root(item["object_class"])
1310
+ for item in out["classification"]["value"]
1311
+ ]
1312
+ )
1313
+
1314
+ if np.any(confs >= 3):
1315
+ unambiguous = len(np.unique(all_class_roots)) == 1
1316
+ if np.any(confs == 3) or np.any(confs == 3.3):
1317
+ # this is a "gold spectrum"
1318
+ spec_classed = 3
1319
+ elif np.any(confs == 3.2):
1320
+ # this is a silver spectrum
1321
+ spec_classed = 2
1322
+ elif np.any(confs == 3.1):
1323
+ # this is a bronze spectrum
1324
+ spec_classed = 1
1325
+ else:
1326
+ raise ValueError("Not prepared for this confidence flag!")
1327
+
1328
+ elif np.any(confs == 2):
1329
+ # these always have spec_classed = True, by definition
1330
+ # They also have unambiguous = False by definition because they don't
1331
+ # have a peer reviewed citation for their classification
1332
+ spec_classed = 1
1333
+ unambiguous = False
1334
+
1335
+ elif np.any(confs == 1):
1336
+ spec_classed = 0 # by definition
1337
+ unambiguous = len(np.unique(all_class_roots)) == 1
1338
+
1339
+ else:
1340
+ spec_classed = 0
1341
+ unambiguous = False
1342
+
1343
+ # finally, set these keys in the classification dict
1344
+ out["classification"]["spec_classed"] = spec_classed
1345
+ out["classification"]["unambiguous"] = unambiguous
1346
+
1347
+ return out
1348
+
1349
+ @staticmethod
1350
+ def _merge_arbitrary(key, t1, t2, out, merge_subkeys=None, groupby_key=None):
1351
+ """
1352
+ Merge two arbitrary datasets inside the json file using pandas
1353
+
1354
+ The datasets in t1 and t2 in "key" must be able to be forced into
1355
+ a NxM pandas dataframe!
1356
+ """
1357
+
1358
+ if key == "name":
1359
+ t1._merge_names(t2, out)
1360
+ elif key == "filter_alias":
1361
+ t1._merge_filter_alias(t2, out)
1362
+ elif key == "schema_version":
1363
+ t1._merge_schema_version(t2, out)
1364
+ elif key == "photometry":
1365
+ t1._merge_photometry(t2, out)
1366
+ elif key == "classification":
1367
+ t1._merge_class(t2, out)
1368
+ else:
1369
+ # this is where we can standardize some of the merging
1370
+ df1 = pd.DataFrame(t1[key])
1371
+ df2 = pd.DataFrame(t2[key])
1372
+
1373
+ merged_with_dups = pd.concat([df1, df2]).reset_index(drop=True)
1374
+
1375
+ # have to get the indexes to drop using a string rep of the df
1376
+ # this is cause we have lists in some cells
1377
+ # We also need to deal with merging the lists of references across rows
1378
+ # that we deem to be duplicates. This solution to do this quickly is from
1379
+ # https://stackoverflow.com/questions/36271413/ \
1380
+ # pandas-merge-nearly-duplicate-rows-based-on-column-value
1381
+ if merge_subkeys is None:
1382
+ merge_subkeys = merged_with_dups.columns.tolist()
1383
+ merge_subkeys.remove("reference")
1384
+ else:
1385
+ for k in list(merge_subkeys):  # iterate over a copy since we may remove entries
1386
+ if k not in merged_with_dups:
1387
+ merge_subkeys.remove(k)
1388
+
1389
+ merged = (
1390
+ merged_with_dups.astype(str)
1391
+ .groupby(merge_subkeys)["reference"]
1392
+ .apply(lambda x: x.sum())
1393
+ .reset_index()
1394
+ )
1395
+
1396
+ # then we have to turn the merged reference strings into a string list
1397
+ merged["reference"] = merged.reference.str.replace("][", ",")
1398
+
1399
+ # then eval the string of a list to get back an actual list of sources
1400
+ merged["reference"] = merged.reference.apply(
1401
+ lambda v: np.unique(eval(v)).tolist()
1402
+ )
1403
+
1404
+ # decide on default values
1405
+ if groupby_key is None:
1406
+ iterate_through = [(0, merged)]
1407
+ else:
1408
+ iterate_through = merged.groupby(groupby_key)
1409
+
1410
+ # we will make whichever value has more references the default
1411
+ outdict = []
1412
+ for data_type, df in iterate_through:
1413
+ lengths = df.reference.map(len)
1414
+ max_idx_arr = np.argmax(lengths)
1415
+
1416
+ if isinstance(max_idx_arr, np.int64):
1417
+ max_idx = max_idx_arr
1418
+ elif len(max_idx_arr) == 0:
1419
+ raise ValueError("Something went wrong with deciding the default")
1420
+ else:
1421
+ max_idx = max_idx_arr[0] # arbitrarily choose the first
1422
+
1423
+ defaults = np.full(len(df), False, dtype=bool)
1424
+ defaults[max_idx] = True
1425
+
1426
+ df["default"] = defaults
1427
+ outdict.append(df)
1428
+ outdict = pd.concat(outdict)
1429
+
1430
+ # from https://stackoverflow.com/questions/52504972/ \
1431
+ # converting-a-pandas-df-to-json-without-nan
1432
+ outdict = outdict.replace("nan", np.nan)
1433
+ outdict_cleaned = [{**x[i]} for i, x in outdict.stack().groupby(level=0)]
1434
+
1435
+ out[key] = outdict_cleaned
1436
+
1437
+
1438
+ def _fuzzy_class_root(s):
1439
+ """
1440
+ Extract the fuzzy classification root name from the string s
1441
+ """
1442
+ s = s.upper()
1443
+ # first split the class s using regex
1444
+ for root in _KNOWN_CLASS_ROOTS:
1445
+ if s.startswith(root):
1446
+ remaining = s[len(root) :]
1447
+ if remaining and root == "SN":
1448
+ # we want to be able to distinguish between SN Ia and SN II
1449
+ # we will use SN Ia to indicate those and SN to indicate CCSN
1450
+ if "IA" in remaining or "1A" in remaining:
1451
+ return "SN Ia"
1452
+ return root
1453
+ return s
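Illustrative behaviour of _fuzzy_class_root; the actual root list lives in otter.util._KNOWN_CLASS_ROOTS, so the first two outputs assume "SN" and "TDE" are among the known roots:

_fuzzy_class_root("SN Ia-91bg")    # -> "SN Ia"
_fuzzy_class_root("TDE H+He")      # -> "TDE"
_fuzzy_class_root("weird label")   # -> "WEIRD LABEL" (unknown roots fall through, uppercased)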