astro_otter-0.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astro_otter-0.6.0.dist-info/METADATA +161 -0
- astro_otter-0.6.0.dist-info/RECORD +18 -0
- astro_otter-0.6.0.dist-info/WHEEL +5 -0
- astro_otter-0.6.0.dist-info/licenses/LICENSE +21 -0
- astro_otter-0.6.0.dist-info/top_level.txt +1 -0
- otter/__init__.py +19 -0
- otter/_version.py +5 -0
- otter/exceptions.py +74 -0
- otter/io/__init__.py +0 -0
- otter/io/data_finder.py +1045 -0
- otter/io/host.py +186 -0
- otter/io/otter.py +1594 -0
- otter/io/transient.py +1453 -0
- otter/plotter/__init__.py +0 -0
- otter/plotter/otter_plotter.py +76 -0
- otter/plotter/plotter.py +266 -0
- otter/schema.py +312 -0
- otter/util.py +850 -0
otter/io/transient.py
ADDED
@@ -0,0 +1,1453 @@
"""
Class for a transient,
basically just inherits the dict properties with some overwriting
"""

from __future__ import annotations
import warnings
from copy import deepcopy
import re
from collections.abc import MutableMapping
from typing import Callable
from typing_extensions import Self
import logging

import numpy as np
import pandas as pd

import astropy.units as u
from astropy.time import Time
from astropy.coordinates import SkyCoord

from ..exceptions import (
    FailedQueryError,
    IOError,
    OtterLimitationError,
    TransientMergeError,
)
from ..util import XRAY_AREAS, _KNOWN_CLASS_ROOTS, _DuplicateFilter
from .host import Host

np.seterr(divide="ignore")
logger = logging.getLogger(__name__)


class Transient(MutableMapping):
    def __init__(self, d={}, name=None):
        """
        Overwrite the dictionary init

        Args:
            d (dict): A transient dictionary
            name (str): The default name of the transient, default is None and it will
                be inferred from the input dictionary.
        """
        self.data = d

        if "reference_alias" in self:
            self.srcmap = {
                ref["name"]: ref["human_readable_name"]
                for ref in self["reference_alias"]
            }
            self.srcmap["TNS"] = "TNS"
        else:
            self.srcmap = {}

        if "name" in self:
            if "default_name" in self["name"]:
                self.default_name = self["name"]["default_name"]
            else:
                raise AttributeError("Missing the default name!!")
        elif name is not None:
            self.default_name = name
        else:
            self.default_name = "Missing Default Name"

        # Make it so all coordinates are astropy skycoords

    def __getitem__(self, keys):
        """
        Override getitem to recursively access Transient elements
        """

        if isinstance(keys, (list, tuple)):
            return Transient({key: self[key] for key in keys if key in self})
        elif isinstance(keys, str) and "/" in keys:  # this is for a path
            s = "']['".join(keys.split("/"))
            s = "['" + s
            s += "']"
            return eval(f"self{s}")
        elif (
            isinstance(keys, int)
            or keys.isdigit()
            or (keys[0] == "-" and keys[1:].isdigit())
        ):
            # this is for indexing a sublist
            return self[int(keys)]
        else:
            return self.data[keys]

    def __setitem__(self, key, value):
        """
        Override set item to work with the '/' syntax
        """

        if isinstance(key, str) and "/" in key:  # this is for a path
            s = "']['".join(key.split("/"))
            s = "['" + s
            s += "']"
            exec(f"self{s} = value")
        else:
            self.data[key] = value

    def __delitem__(self, keys):
        if "/" in keys:
            raise OtterLimitationError(
                "For security, we can not delete with the / syntax!"
            )
        else:
            del self.data[keys]
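
    # A minimal usage sketch of the '/' path syntax handled above (illustrative
    # only; the input dict here is an assumption, not shipped with the package):
    #
    #   >>> t = Transient({"name": {"default_name": "AT2018hyz", "alias": []}})
    #   >>> t["name/default_name"]
    #   'AT2018hyz'
    #   >>> t["name/default_name"] = "AT 2018hyz"   # nested assignment via __setitem__
    #   >>> del t["name/default_name"]              # raises OtterLimitationError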

    def __iter__(self):
        return iter(self.data)

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return f"Transient(\n\tName: {self.default_name},\n\tKeys: {self.keys()}\n)"

    def keys(self):
        return self.data.keys()

    def __add__(self, other, strict_merge=True):
        """
        Merge this transient object with another transient object

        Args:
            other [Transient]: A Transient object to merge with
            strict_merge [bool]: If True it won't let you merge objects that
                                 intuitively shouldn't be merged (ie. different
                                 transient events).
        """

        # first check that this object is within a good distance of the other object
        if (
            strict_merge
            and self.get_skycoord().separation(other.get_skycoord()) > 10 * u.arcsec
        ):
            raise TransientMergeError(
                "These two transients are not within 10 arcseconds!"
                + " They probably do not belong together! If they do"
                + " You can set strict_merge=False to override the check"
            )

        # create set of the allowed keywords
        allowed_keywords = {
            "name",
            "date_reference",
            "coordinate",
            "distance",
            "filter_alias",
            "schema_version",
            "photometry",
            "classification",
            "host",
        }

        merge_subkeys_map = {
            "name": None,
            "date_reference": ["value", "date_format", "date_type"],
            "coordinate": None,  # may need to update this if we run into problems
            "distance": ["value", "distance_type", "unit"],
            "filter_alias": None,
            "schema_version": None,
            "photometry": None,
            "classification": None,
            "host": [
                "host_ra",
                "host_dec",
                "host_ra_units",
                "host_dec_units",
                "host_name",
            ],
        }

        groupby_key_for_default_map = {
            "name": None,
            "date_reference": "date_type",
            "coordinate": "coordinate_type",
            "distance": "distance_type",
            "filter_alias": None,
            "schema_version": None,
            "photometry": None,
            "classification": None,
            "host": None,
        }

        # create a blank dictionary since we don't want to overwrite this object
        out = {}

        # find the keys that are
        merge_keys = list(
            self.keys() & other.keys()
        )  # in both t1 and t2 so we need to merge these keys
        only_in_t1 = list(self.keys() - other.keys())  # only in t1
        only_in_t2 = list(other.keys() - self.keys())  # only in t2

        # now let's handle the merge keys
        for key in merge_keys:
            # reference_alias is special
            # we ALWAYS should combine these two
            if key == "reference_alias":
                out[key] = self[key]
                if self[key] != other[key]:
                    # only add t2 values if they aren't already in it
                    bibcodes = {ref["name"] for ref in self[key]}
                    for val in other[key]:
                        if val["name"] not in bibcodes:
                            out[key].append(val)
                continue

            # we can skip this merge process and just add the values from t1
            # if they are equal. We should still add the new reference though!
            if self[key] == other[key]:
                # set the value
                # we don't need to worry about references because this will
                # only be true if the reference is also equal!
                out[key] = deepcopy(self[key])
                continue

            # There are some special keys that we are expecting
            if key in allowed_keywords:
                Transient._merge_arbitrary(
                    key,
                    self,
                    other,
                    out,
                    merge_subkeys=merge_subkeys_map[key],
                    groupby_key=groupby_key_for_default_map[key],
                )
            else:
                # this is an unexpected key!
                if strict_merge:
                    # since this is a strict merge we don't want unexpected data!
                    raise TransientMergeError(f"{key} was not expected! Can not merge!")
                else:
                    # Throw a warning and only keep the old stuff
                    logger.warning(
                        f"{key} was not expected! Only keeping the old information!"
                    )
                    out[key] = deepcopy(self[key])

        # and now combining out with the stuff only in t1 and t2
        out = out | dict(self[only_in_t1]) | dict(other[only_in_t2])

        # now return out as a Transient Object
        return Transient(out)
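
    # Merging sketch (illustrative; `t1` and `t2` stand for two Transient records
    # of the same event pulled from different sources, they are not defined here):
    #
    #   >>> merged = t1 + t2                             # strict_merge=True by default
    #   >>> merged = t1.__add__(t2, strict_merge=False)  # skip the 10 arcsec check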

    def get_meta(self, keys=None) -> Self:
        """
        Get the metadata (no photometry or spectra)

        This essentially just wraps on __getitem__ but with some checks

        Args:
            keys (list[str]) : list of keys to get the metadata for from the transient

        Returns:
            A Transient object of just the meta data
        """
        if keys is None:
            keys = list(self.keys())

            # note: using the remove method is safe here because dict keys are unique
            if "photometry" in keys:
                keys.remove("photometry")
            if "spectra" in keys:
                keys.remove("spectra")
        else:
            # run some checks
            if "photometry" in keys:
                logger.warning("Not returning the photometry!")
                _ = keys.pop("photometry")
            if "spectra" in keys:
                logger.warning("Not returning the spectra!")
                _ = keys.pop("spectra")

            curr_keys = self.keys()
            for key in keys:
                if key not in curr_keys:
                    keys.remove(key)
                    logger.warning(
                        f"Not returning {key} because it is not in this transient!"
                    )

        return self[keys]

    def get_skycoord(self, coord_format="icrs") -> SkyCoord:
        """
        Convert the coordinates to an astropy SkyCoord

        Args:
            coord_format (str): Astropy coordinate format to convert the SkyCoord to,
                defaults to icrs.

        Returns:
            Astropy.coordinates.SkyCoord of the default coordinate for the transient
        """

        # now we can generate the SkyCoord
        f = "df['coordinate_type'] == 'equatorial'"
        coord_dict = self._get_default("coordinate", filt=f)
        coordin = self._reformat_coordinate(coord_dict)
        coord = SkyCoord(**coordin).transform_to(coord_format)

        return coord

    def get_discovery_date(self) -> Time:
        """
        Get the default discovery date for this Transient

        Returns:
            astropy.time.Time of the default discovery date
        """
        key = "date_reference"
        try:
            date = self._get_default(key, filt='df["date_type"] == "discovery"')
        except KeyError:
            return None

        if date is None:
            return date

        if "date_format" in date:
            f = date["date_format"]
        else:
            f = "mjd"

        return Time(str(date["value"]).strip(), format=f)

    def get_redshift(self) -> float:
        """
        Get the default redshift of this Transient

        Returns:
            Float value of the default redshift
        """
        f = "df['distance_type']=='redshift'"
        default = self._get_default("distance", filt=f)
        if default is None:
            return default
        else:
            return default["value"]

    def get_classification(self) -> tuple[str, float, list]:
        """
        Get the default classification of this Transient.
        This normally corresponds to the highest confidence classification that we have
        stored for the transient.

        Returns:
            The default object class as a string, the confidence level in that class,
            and a list of the bibcodes corresponding to that classification. Or, None
            if there is no classification.
        """
        default = self._get_default("classification/value")
        if default is None:
            return default
        return default.object_class, default.confidence, default.reference

    def get_host(self, max_hosts=3, search=False, **kwargs) -> list[Host]:
        """
        Gets the default host information of this Transient. This returns an otter.Host
        object. If search=True, it will also check the BLAST host association database
        for the best match and return it as well. Note that if search is True then
        this has the potential to return max_hosts + 1, if BLAST also returns a result.
        The BLAST result will always be the last value in the returned list.

        Args:
            max_hosts [int] : The maximum number of hosts to return, default is 3
            **kwargs : keyword arguments to be passed to getGHOST

        Returns:
            A list of otter.Host objects. This is useful because the Host objects have
            useful methods for querying public catalogs for data of the host.
        """
        # first try to get the host information from our local database
        host = []
        if "host" in self:
            max_hosts = min([max_hosts, len(self["host"])])
            for h in self["host"][:max_hosts]:
                # only return hosts with their ra and dec stored
                if (
                    "host_ra" not in h
                    or "host_dec" not in h
                    or "host_ra_units" not in h
                    or "host_dec_units" not in h
                ):
                    continue

                # now we can construct a host object from this
                host.append(Host(transient_name=self.default_name, **dict(h)))

        # then try BLAST
        if search:
            logger.warning(
                "Trying to find a host with BLAST/astro-ghost. Note\
                that this won't work for older targets! See https://blast.scimma.org"
            )

            # default_name should always be the TNS name if we have one
            print(self.default_name)
            blast_host = Host.query_blast(self.default_name.replace(" ", ""))
            print(blast_host)
            if blast_host is not None:
                host.append(blast_host)

        return host
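
    # A sketch of the metadata getters above (illustrative; `t` is assumed to be a
    # Transient built from a populated OTTER record, which is not shown here):
    #
    #   >>> t.get_skycoord()         # default coordinate as an astropy SkyCoord
    #   >>> t.get_discovery_date()   # astropy Time, or None if not recorded
    #   >>> t.get_redshift()         # default redshift value, or None
    #   >>> t.get_classification()   # (object_class, confidence, references) or None
    #   >>> t.get_host(max_hosts=1)  # list of otter.Host objects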

    def _get_default(self, key, filt=None):
        """
        Get the default of key

        Args:
            key [str]: key in self to look for the default of
            filt [str]: a valid pandas dataframe filter to index a pandas dataframe
                called df.
        """
        if key not in self:
            raise KeyError(f"This transient does not have {key} associated with it!")

        df = pd.DataFrame(self[key])
        if len(df) == 0:
            raise KeyError(f"This transient does not have {key} associated with it!")

        if filt is not None:
            df = df[eval(filt)]  # apply the filters

        if "default" in df:
            # first try to get the default
            df_filtered = df[df.default == True]
            if len(df_filtered) == 0:
                df_filtered = df
        else:
            df_filtered = df

        if len(df_filtered) == 0:
            return None

        return df_filtered.iloc[0]

    def _reformat_coordinate(self, item):
        """
        Reformat the coordinate information in item
        """
        coordin = None
        if "ra" in item and "dec" in item:
            # this is an equatorial coordinate
            coordin = {
                "ra": item["ra"],
                "dec": item["dec"],
                "unit": (item["ra_units"], item["dec_units"]),
            }
        elif "l" in item and "b" in item:
            coordin = {
                "l": item["l"],
                "b": item["b"],
                "unit": (item["l_units"], item["b_units"]),
                "frame": "galactic",
            }

        return coordin

    def clean_photometry(
        self,
        flux_unit: u.Unit = "mag(AB)",
        date_unit: u.Unit = "MJD",
        freq_unit: u.Unit = "GHz",
        wave_unit: u.Unit = "nm",
        obs_type: str = None,
        deduplicate: Callable | None = None,
    ) -> pd.DataFrame:
        """
        Ensure the photometry associated with this transient is all in the same
        units/system/etc

        Args:
            flux_unit (astropy.unit.Unit): Either a valid string to convert
                                           or an astropy.unit.Unit, this can be either
                                           flux, flux density, or magnitude unit. This
                                           supports any base units supported by
                                           synphot
                                           (https://synphot.readthedocs.io/en/latest/synphot/units.html#flux-units).
            date_unit (str): Valid astropy date format string. See
                             https://docs.astropy.org/en/stable/time/index.html#time-format
            freq_unit (astropy.unit.Unit): The astropy unit or string representation of
                                           an astropy unit to convert and return the
                                           frequency as. Must have a base unit of
                                           1/time (Hz).
            wave_unit (astropy.unit.Unit): The astropy unit or string representation of
                                           an astropy unit to convert and return the
                                           wavelength as. Must have a base unit of
                                           length.
            obs_type (str): "radio", "xray", or "uvoir". If provided, it only returns
                            data taken within that range of wavelengths/frequencies.
                            Default is None which will return all of the data.
            deduplicate (Callable|None): A function to be used to remove duplicate
                                         reductions of the same data that produces
                                         different flux values. The default is the
                                         otter.deduplicate_photometry method,
                                         but you can pass
                                         any callable that takes the output pandas
                                         dataframe as input. Set this to False if you
                                         don't want deduplication to occur.
        Returns:
            A pandas DataFrame of the cleaned up photometry in the requested units
        """
        if deduplicate is None:
            deduplicate = self.deduplicate_photometry

        warn_filt = _DuplicateFilter()
        logger.addFilter(warn_filt)

        # these imports need to be here for some reason
        # otherwise the code breaks
        from synphot.units import VEGAMAG, convert_flux
        from synphot.spectrum import SourceSpectrum

        # variable so this warning only displays a single time each time this
        # function is called
        source_map_warning = True

        # turn the photometry key into a pandas dataframe
        if "photometry" not in self:
            raise FailedQueryError("No photometry for this object!")

        dfs = []
        for item in self["photometry"]:
            max_len = 0
            for key, val in item.items():
                if isinstance(val, list) and key != "reference":
                    max_len = max(max_len, len(val))

            for key, val in item.items():
                if not isinstance(val, list) or (
                    isinstance(val, list) and len(val) != max_len
                ):
                    item[key] = [val] * max_len

            df = pd.DataFrame(item)
            dfs.append(df)

        if len(dfs) == 0:
            raise FailedQueryError("No photometry for this object!")
        c = pd.concat(dfs)

        # extract the filter information and substitute in any missing columns
        # because of how we handle this later, we just need to make sure the effective
        # wavelengths are never nan
        def fill_wave(row):
            if "wave_eff" not in row or (
                pd.isna(row.wave_eff) and not pd.isna(row.freq_eff)
            ):
                freq_eff = row.freq_eff * u.Unit(row.freq_units)
                wave_eff = freq_eff.to(u.Unit(wave_unit), equivalencies=u.spectral())
                return wave_eff.value, wave_unit
            elif not pd.isna(row.wave_eff):
                return row.wave_eff, row.wave_units
            else:
                raise ValueError("Missing frequency or wavelength information!")

        filters = pd.DataFrame(self["filter_alias"])
        res = filters.apply(fill_wave, axis=1)
        filters["wave_eff"], filters["wave_units"] = zip(*res)
        # merge the photometry with the filter information
        df = c.merge(filters, on="filter_key")

        # drop irrelevant obs_types before continuing
        if obs_type is not None:
            valid_obs_types = {"radio", "uvoir", "xray"}
            if obs_type not in valid_obs_types:
                raise IOError("Please provide a valid obs_type")
            df = df[df.obs_type == obs_type]

        # add some mockup columns if they don't exist
        if "value" not in df:
            df["value"] = np.nan
            df["value_err"] = np.nan
            df["value_units"] = "NaN"

        # fix some bad units that are old and no longer recognized by astropy
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            df.raw_units = df.raw_units.str.replace("ergs", "erg")
            df.raw_units = ["mag(AB)" if uu == "AB" else uu for uu in df.raw_units]
            df.value_units = df.value_units.str.replace("ergs", "erg")
            df.value_units = ["mag(AB)" if uu == "AB" else uu for uu in df.value_units]

        # merge the raw and value keywords based on the requested flux_units
        # first take everything that just has `raw` and not `value`
        df_raw_only = df[df.value.isna()]
        remaining = df[df.value.notna()]
        if len(remaining) == 0:
            df_raw = df_raw_only
            df_value = []  # this tricks the code later
        else:
            # then take the remaining rows and figure out if we want the raw or value
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore")
                flux_unit_astropy = u.Unit(flux_unit)

                val_unit_filt = np.array(
                    [
                        u.Unit(uu).is_equivalent(flux_unit_astropy)
                        for uu in remaining.value_units
                    ]
                )

            df_value = remaining[val_unit_filt]
            df_raw_and_value = remaining[~val_unit_filt]

            # then merge the raw dataframes
            df_raw = pd.concat([df_raw_only, df_raw_and_value], axis=0)

        # then add columns to these dataframes to convert stuff later
        df_raw = df_raw.assign(
            _flux=df_raw["raw"].values,
            _flux_units=df_raw["raw_units"].values,
            _flux_err=(
                df_raw["raw_err"].values
                if "raw_err" in df_raw
                else [np.nan] * len(df_raw)
            ),
        )

        if len(df_value) == 0:
            df = df_raw
        else:
            df_value = df_value.assign(
                _flux=df_value["value"].values,
                _flux_units=df_value["value_units"].values,
                _flux_err=(
                    df_value["value_err"].values
                    if "value_err" in df_value
                    else [np.nan] * len(df_value)
                ),
            )

            # then merge df_value and df_raw back into one df
            df = pd.concat([df_raw, df_value], axis=0)

        # then, for the rest of the code to work, set the "by" variables to _flux
        by = "_flux"

        # skip rows where 'by' is nan
        df = df[df[by].notna()]

        # filter out anything that has _flux_units == "ct" because we can't convert that
        try:
            # this is a test case to see if we can convert ct -> flux_unit
            convert_flux(
                [1 * u.nm, 2 * u.nm], 1 * u.ct, u.Unit(flux_unit), area=1 * u.m**2
            )
        except u.UnitsError:
            bad_units = df[df._flux_units == "ct"]
            if len(bad_units) > 0:
                logger.warning(
                    f"""Removing {len(bad_units)} photometry points from
                    {self.default_name} because we can't convert them from ct ->
                    {flux_unit}"""
                )
            df = df[df._flux_units != "ct"]

        # convert the ads bibcodes to a string of human readable sources here
        def mappedrefs(row):
            if isinstance(row.reference, list):
                return "<br>".join([self.srcmap[bibcode] for bibcode in row.reference])
            else:
                return self.srcmap[row.reference]

        try:
            df["human_readable_refs"] = df.apply(mappedrefs, axis=1)
        except Exception as exc:
            if source_map_warning:
                source_map_warning = False
                logger.warning(f"Unable to apply the source mapping because {exc}")

            df["human_readable_refs"] = df.reference

        # Figure out what columns are good to groupby in the photometry
        outdata = []

        if "telescope" in df:
            tele = True
            to_grp_by = ["obs_type", by + "_units", "telescope"]
        else:
            tele = False
            to_grp_by = ["obs_type", by + "_units"]

        # Do the conversion based on what we decided to group by
        for groupedby, data in df.groupby(to_grp_by, dropna=False):
            if tele:
                obstype, unit, telescope = groupedby
            else:
                obstype, unit = groupedby
                telescope = None

            # get the photometry in the right type
            unit = data[by + "_units"].unique()
            if len(unit) > 1:
                raise OtterLimitationError(
                    "Can not apply multiple units for different obs_types"
                )

            unit = unit[0]
            isvegamag = "vega" in unit.lower()
            try:
                if isvegamag:
                    astropy_units = VEGAMAG
                elif unit == "AB":
                    # In astropy "AB" is a magnitude SYSTEM not unit and while
                    # u.Unit("AB") will succeed without error, it will not produce
                    # the expected result!
                    # We can assume here that this unit really means astropy's "mag(AB)"
                    astropy_units = u.Unit("mag(AB)")
                else:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        astropy_units = u.Unit(unit)

            except ValueError:
                # this means there is something likely slightly off in the input unit
                # string. Let's try to fix it!
                # here are some common mistakes
                unit = unit.replace("ergs", "erg")
                unit = unit.replace("AB", "mag(AB)")

                astropy_units = u.Unit(unit)

            except ValueError:
                raise ValueError(
                    "Could not coerce your string into astropy unit format!"
                )

            # get the flux data and find the type
            indata = np.array(data[by].astype(float))
            err_key = by + "_err"
            if err_key in data:
                indata_err = np.array(data[by + "_err"].astype(float))
            else:
                indata_err = np.zeros(len(data))

            # convert to an astropy quantity
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore")
                q = indata * u.Unit(astropy_units)
                q_err = indata_err * u.Unit(
                    astropy_units
                )  # assume error and values have the same unit

            # get and save the effective wavelength
            # because of cleaning we did to the filter dataframe above wave_eff
            # should NEVER be nan!
            if np.any(pd.isna(data["wave_eff"])):
                raise ValueError("Flushing out the effective wavelength array failed!")

            zz = zip(data["wave_eff"], data["wave_units"])
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore")
                wave_eff = u.Quantity([vv * u.Unit(uu) for vv, uu in zz], wave_unit)
                freq_eff = wave_eff.to(freq_unit, equivalencies=u.spectral())

            data["converted_wave"] = wave_eff.value
            data["converted_wave_unit"] = wave_unit
            data["converted_freq"] = freq_eff.value
            data["converted_freq_unit"] = freq_unit

            # convert using synphot
            # stuff has to be done slightly differently for xray than for the others
            if obstype == "xray":
                if telescope is not None:
                    try:
                        area = XRAY_AREAS[telescope.lower()]
                    except KeyError:
                        raise OtterLimitationError(
                            "Did not find an area corresponding to "
                            + "this telescope, please add to util!"
                        )
                else:
                    raise OtterLimitationError(
                        "Can not convert x-ray data without a telescope"
                    )

                # we also need to make this wave_min and wave_max
                # instead of just the effective wavelength like for radio and uvoir
                zz = zip(data["wave_min"], data["wave_max"], data["wave_units"])
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore")
                    wave_eff = u.Quantity(
                        [np.array([m, M]) * u.Unit(uu) for m, M, uu in zz],
                        u.Unit(wave_unit),
                    )

            else:
                area = None

            if obstype == "xray" or isvegamag:
                # we unfortunately have to loop over the points here because
                # synphot does not work with a 2D array of min max wavelengths
                # for converting counts to other flux units. It also can't convert
                # vega mags with a wavelength array because it interprets that as the
                # wavelengths corresponding to the SourceSpectrum.from_vega()

                flux, flux_err = [], []
                for wave, xray_point, xray_point_err in zip(wave_eff, q, q_err):
                    with warnings.catch_warnings():
                        warnings.filterwarnings("ignore")
                        f_val = convert_flux(
                            wave,
                            xray_point,
                            u.Unit(flux_unit),
                            vegaspec=SourceSpectrum.from_vega(),
                            area=area,
                        ).value

                    # approximate the uncertainty as dX = dY/Y * X
                    f_err = np.multiply(
                        f_val, np.divide(xray_point_err.value, xray_point.value)
                    )

                    # then we take the average of the minimum and maximum values
                    # computed by synphot
                    flux.append(np.mean(f_val))
                    flux_err.append(np.mean(f_err))

            else:
                # this will be faster and cover most cases
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore")
                    flux = convert_flux(wave_eff, q, u.Unit(flux_unit)).value

                # since the error propagation is different between logarithmic units
                # and linear units, unfortunately
                if isinstance(u.Unit(flux_unit), u.LogUnit):
                    # approximate the uncertainty as dX = dY/Y * |ln(10)/2.5|
                    prefactor = np.abs(np.log(10) / 2.5)  # this is basically 1
                else:
                    # approximate the uncertainty as dX = dY/Y * X
                    prefactor = flux

                flux_err = np.multiply(prefactor, np.divide(q_err.value, q.value))

            flux = np.array(flux) * u.Unit(flux_unit)
            flux_err = np.array(flux_err) * u.Unit(flux_unit)

            data["converted_flux"] = flux.value
            data["converted_flux_err"] = flux_err.value
            outdata.append(data)

        if len(outdata) == 0:
            raise FailedQueryError()
        outdata = pd.concat(outdata)

        # copy over the flux units
        outdata["converted_flux_unit"] = flux_unit

        # make sure all the datetimes are in the same format here too!!
        times = [
            Time(d, format=f).to_value(date_unit.lower())
            for d, f in zip(outdata.date, outdata.date_format.str.lower())
        ]
        outdata["converted_date"] = times
        outdata["converted_date_unit"] = date_unit

        # compute the upperlimit value based on a 3 sigma detection
        # this is just for rows where we don't already know if it is an upperlimit
        if isinstance(u.Unit(flux_unit), u.LogUnit):
            # this uses the following formula (which is surprising because it means
            # magnitude upperlimits are independent of the actual measurement!)
            # sigma_m > (1/3) * (ln(10)/2.5)
            def is_upperlimit(row):
                if "upperlimit" in row and pd.isna(row.upperlimit):
                    return row.converted_flux_err > np.log(10) / (3 * 2.5)
                else:
                    return row.upperlimit

        else:

            def is_upperlimit(row):
                if "upperlimit" in row and pd.isna(row.upperlimit):
                    return row.converted_flux < 3 * row.converted_flux_err
                else:
                    return row.upperlimit

        outdata["upperlimit"] = outdata.apply(is_upperlimit, axis=1)

        # perform some more complex deduplication of the dataset
        if deduplicate:
            outdata = deduplicate(outdata)

        # throw a warning if the output dataframe has UV/Optical/IR or Radio data
        # where we don't know if the dataset has been host corrected or not
        if ("corr_host" not in outdata) or (
            len(outdata[pd.isna(outdata.corr_host) * (outdata.obs_type != "xray")]) >= 0
        ):
            logger.warning(
                f"{self.default_name} has at least one photometry point where it is "
                + "unclear if a host subtraction was performed. This can be especially "
                + "detrimental for UV data. Please consider filtering out UV/Optical/IR"
                + " or radio rows where the corr_host column is null/None/NaN."
            )

        logger.removeFilter(warn_filt)
        return outdata
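
    # Typical call pattern for the cleaning step above (illustrative; the units are
    # the documented defaults and the column names come from the code above):
    #
    #   >>> phot = t.clean_photometry(flux_unit="mag(AB)", obs_type="uvoir")
    #   >>> phot[["converted_date", "converted_wave", "converted_flux",
    #   ...       "converted_flux_err", "upperlimit"]]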

    @classmethod
    def deduplicate_photometry(cls, phot: pd.DataFrame, date_tol: int | float = 1):
        """
        This deduplicates a pandas dataframe of photometry that could potentially
        have rows/datasets that are the result of different reductions of the same
        data. This is especially relevant for X-ray and UV observations where different
        reductions can produce different flux values from the same observation.

        The algorithm used here first finds duplicates by normalizing the telescope
        names, then grouping the dataframe by transient name, norm telescope name,
        filter_key, and the obs_type. It then assumes that data from the same
        reference will not produce duplicated data. Finally, it finds the overlapping
        regions within date +/- date_tol (or between date_min and date_max for binned
        data), and uses any data within that region as duplicated. From there, it
        first tries to choose the reduction that is host subtracted (if only one is
        host subtracted), then if neither or more than one of the datasets are host
        subtracted then it just takes the most recent reduction.

        Args:
            phot (pd.DataFrame): A pandas dataframe of the photometry with keys defined
                                 by the OTTER schema
            date_tol (int|float): The default tolerance (or "uncertainty") to use on the
                                  dates in the "date" column of phot. In days. Defaults
                                  to 1 day.
        """
        # we need to reset the index to keep track of things appropriately
        phot = phot.reset_index(drop=True)

        if "telescope" not in phot:
            phot["telescope"] = np.nan

        # we first have to standardize some columns given some basic assumptions
        phot["_ref_str"] = phot.reference.astype(str)

        # normalize the telescope name so we can group by it
        phot["_norm_tele_name"] = phot.telescope.apply(cls._normalize_tele_name)

        # now find the duplicated data
        dups = []
        phot_grpby = phot.groupby(
            ["_norm_tele_name", "filter_key", "obs_type"], dropna=False
        )
        for (tele, filter_key, obs_type), grp in phot_grpby:
            # by definition, there can only be dups if the name, telescope, and filter
            # are the same

            # if there is only one reference in this group of data, there's no way
            # there are duplicate reductions of the same dataset
            if len(grp._ref_str.unique()) <= 1:
                continue

            # the next trick is that the dates don't need to be the same, but need to
            # fall inside the same range
            grp["_mean_dates"] = grp.apply(cls._convert_dates, axis=1)

            if "date_min" in grp and not np.all(pd.isna(grp.date_min)):
                grp["min_dates"] = grp.apply(
                    lambda row: cls._convert_dates(row, date_key="date_min"), axis=1
                ).astype(float)
                grp["max_dates"] = grp.apply(
                    lambda row: cls._convert_dates(row, date_key="date_max"), axis=1
                ).astype(float)

                # in case any of the min_date and max_date in the grp are nan
                grp.fillna(
                    {
                        "min_dates": grp._mean_dates - date_tol,
                        "max_dates": grp._mean_dates + date_tol,
                    },
                    inplace=True,
                )

            elif "date_err" in grp and not np.any(pd.isna(grp.date_err)):
                grp["min_dates"] = (grp._mean_dates - grp.date_err).astype(float)
                grp["max_dates"] = (grp._mean_dates + grp.date_err).astype(float)
            else:
                # then assume some uncertainty on the date
                grp["min_dates"] = (grp._mean_dates - date_tol).astype(float)
                grp["max_dates"] = (grp._mean_dates + date_tol).astype(float)

            ref_ranges = [
                (subgrp.min_dates.min(), subgrp.max_dates.max())
                for _, subgrp in grp.groupby("_ref_str")
            ]

            overlaps = cls._find_overlapping_regions(ref_ranges)

            if len(overlaps) == 0:
                continue  # then there are no dups

            for min_overlap, max_overlap in overlaps:
                dup_data = grp[
                    (grp.min_dates >= min_overlap) * (grp.max_dates <= max_overlap)
                ]

                if len(dup_data) == 0:
                    continue  # no data falls in this range!

                dups.append(dup_data)

        # now that we've found the duplicated datasets, we can iterate through them
        # and choose the "default"
        phot_res = deepcopy(phot)
        undupd = []
        for dup in dups:
            try:
                phot_res = phot_res.drop(dup.index)  # we'll append back in the non dup
            except KeyError:
                continue  # we already deleted these ones

            # first, check if only one of the dup reductions is host subtracted
            if "corr_host" in dup:
                dup_host_corr = dup[dup.corr_host.astype(bool)]
                host_corr_refs = dup_host_corr.human_readable_refs.unique()
                if len(host_corr_refs) == 1:
                    # then one of the reductions is host corrected and the other isn't!
                    undupd.append(dup[dup.human_readable_refs == host_corr_refs[0]])
                    continue

            bibcodes_sorted_by_year = sorted(dup._ref_str.unique(), key=cls._find_year)
            dataset_to_use = dup[dup._ref_str == bibcodes_sorted_by_year[0]]
            undupd.append(dataset_to_use)

        # then return the full photometry dataset but with the dups removed!
        return pd.concat([phot_res] + undupd).reset_index()
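
    # Deduplication sketch (illustrative): given the cleaned dataframe produced by
    # clean_photometry, the classmethod can also be called directly, e.g. with a
    # looser 3-day matching window on the observation dates:
    #
    #   >>> deduped = Transient.deduplicate_photometry(phot, date_tol=3)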

    @staticmethod
    def _normalize_tele_name(tele_name):
        if pd.isna(tele_name):
            return tele_name

        common_delims = ["-", "/", " ", "."]
        for delim in common_delims:
            tele_name = tele_name.replace(delim, ":*:")

        # this assumes that the telescope name will almost always be first,
        # before other delimiters
        return tele_name.split(":*:")[0].lower()

    @staticmethod
    def _convert_dates(row, date_key="date"):
        """Make sure the dates are in MJD"""
        if pd.isna(row[date_key]):
            return row[date_key]

        return Time(row[date_key], format=row.date_format.lower()).mjd

    @staticmethod
    def _find_overlapping_regions(intervals):
        """Find the overlaps in a list of tuples of mins and maxs. This is relatively
        inefficient but the len(intervals) should be < 10 so it should be fine"""
        overlap_ranges = []
        for ii, (start_ii, end_ii) in enumerate(intervals):
            for jj, (start_jj, end_jj) in enumerate(intervals):
                if ii <= jj:
                    continue

                if start_ii > start_jj:
                    start = start_ii
                else:
                    start = start_jj

                if end_ii > end_jj:
                    end = end_jj
                else:
                    end = end_ii

                if start < end:
                    # then there is an overlap!
                    overlap_ranges.append((start, end))

        return overlap_ranges
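
    # Worked example for the helper above (illustrative): for per-reference date
    # ranges [(0, 5), (3, 8), (20, 21)] the only pairwise overlap is (3, 5), so only
    # photometry whose date range falls entirely inside MJD 3-5 is flagged as
    # duplicated.
    #
    #   >>> Transient._find_overlapping_regions([(0, 5), (3, 8), (20, 21)])
    #   [(3, 5)]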

    @staticmethod
    def _find_year(s):
        match = re.search(r"\d{4}", s)
        return int(match.group()) if match else float("inf")

    def _merge_names(t1, t2, out):  # noqa: N805
        """
        Private method to merge the name data in t1 and t2 and put it in out
        """
        key = "name"
        out[key] = {}

        # first deal with the default_name key
        # we are gonna need to use some regex magic to choose a preferred default_name
        if t1[key]["default_name"] == t2[key]["default_name"]:
            out[key]["default_name"] = t1[key]["default_name"]
        else:
            # we need to decide which default_name is better
            # it should be the one that matches the TNS style
            # let's use regex
            n1 = t1[key]["default_name"]
            n2 = t2[key]["default_name"]

            # write some discriminating regex expressions
            # exp1: starts with a number, this is preferred because it is TNS style
            exp1 = "^[0-9]"
            # exp2: starts with any character, also preferred because it is TNS style
            exp2 = ".$"
            # exp3: checks if first four characters are a number, like a year :),
            # this is pretty strict though
            exp3 = "^[0-9]{3}"
            # exp4: checks if it starts with AT like TNS names
            exp4 = "^AT"

            # combine all the regex expressions, this makes it easier to add more later
            exps = [exp1, exp2, exp3, exp4]

            # score each default_name based on this
            score1 = 0
            score2 = 0
            for e in exps:
                re1 = re.findall(e, n1)
                re2 = re.findall(e, n2)
                if re1:
                    score1 += 1
                if re2:
                    score2 += 1

            # assign a default_name based on the score
            if score1 > score2:
                out[key]["default_name"] = t1[key]["default_name"]
            elif score2 > score1:
                out[key]["default_name"] = t2[key]["default_name"]
            else:
                logger.warning(
                    "Names have the same score! Just using the existing default_name"
                )
                out[key]["default_name"] = t1[key]["default_name"]

        # now deal with aliases
        # create a reference mapping for each
        t1map = {}
        for val in t1[key]["alias"]:
            ref = val["reference"]
            if isinstance(ref, str):
                t1map[val["value"]] = [ref] if isinstance(ref, str) else list(ref)
            else:
                t1map[val["value"]] = [ref] if isinstance(ref, str) else list(ref)

        t2map = {}
        for val in t2[key]["alias"]:
            ref = val["reference"]
            if isinstance(ref, str):
                t2map[val["value"]] = [ref] if isinstance(ref, str) else list(ref)
            else:
                t2map[val["value"]] = [ref] if isinstance(ref, str) else list(ref)

        # figure out which ones we need to be careful with references in
        inboth = list(
            t1map.keys() & t2map.keys()
        )  # in both so we'll have to merge the reference key
        int1 = list(t1map.keys() - t2map.keys())  # only in t1
        int2 = list(t2map.keys() - t1map.keys())  # only in t2

        # add ones that are not in both first, these are easy
        line1 = [{"value": k, "reference": t1map[k]} for k in int1]
        line2 = [{"value": k, "reference": t2map[k]} for k in int2]
        bothlines = [{"value": k, "reference": t1map[k] + t2map[k]} for k in inboth]
        out[key]["alias"] = line2 + line1 + bothlines

    def _merge_filter_alias(t1, t2, out):  # noqa: N805
        """
        Combine the filter alias lists across the transient objects
        """

        key = "filter_alias"

        out[key] = deepcopy(t1[key])
        keys1 = {filt["filter_key"] for filt in t1[key]}
        for filt in t2[key]:
            if filt["filter_key"] not in keys1:
                out[key].append(filt)

    def _merge_schema_version(t1, t2, out):  # noqa: N805
        """
        Just keep whichever schema version is greater
        """
        key = "schema_version/value"
        if "comment" not in t1["schema_version"]:
            t1["schema_version/comment"] = ""

        if "comment" not in t2["schema_version"]:
            t2["schema_version/comment"] = ""

        if key in t1 and key in t2 and int(t1[key]) > int(t2[key]):
            out["schema_version"] = deepcopy(t1["schema_version"])
        else:
            out["schema_version"] = deepcopy(t2["schema_version"])

        out["schema_version"]["comment"] = (
            t1["schema_version/comment"] + ";" + t2["schema_version/comment"]
        )

    def _merge_photometry(t1, t2, out):  # noqa: N805
        """
        Combine photometry sources
        """

        key = "photometry"

        out[key] = deepcopy(t1[key])
        refs = []  # np.array([d["reference"] for d in out[key]])
        # merge_dups = lambda val: np.sum(val) if np.any(val.isna()) else val.iloc[0]
        for val in out[key]:
            if isinstance(val, list):
                refs += val
            elif isinstance(val, np.ndarray):
                refs += list(val)
            else:
                refs.append(val)

        for val in t2[key]:
            # first check if t2's reference is in out
            if val["reference"] not in refs:
                # it's not here so we can just append the new photometry!
                out[key].append(val)
            else:
                # we need to merge it with other photometry
                i1 = np.where(val["reference"] == refs)[0][0]
                df1 = pd.DataFrame(out[key][i1])
                df2 = pd.DataFrame(val)

                # only substitute in values that are nan in df1 or new
                # the combined keys of the two
                mergeon = list(set(df1.keys()) & set(df2.keys()))
                df = df1.merge(df2, on=mergeon, how="outer")
                # convert to a dictionary
                newdict = df.reset_index().to_dict(orient="list")
                del newdict["index"]

                newdict["reference"] = newdict["reference"][0]

                out[key][i1] = newdict  # replace the dictionary at i1 with the new dict

    def _merge_class(t1, t2, out):  # noqa: N805
        """
        Combine the classification attribute
        """
        key = "classification"
        subkey = "value"
        out[key] = deepcopy(t1[key])
        classes = np.array([item["object_class"] for item in out[key][subkey]])

        for item in t2[key][subkey]:
            if item["object_class"] in classes:
                i = np.where(item["object_class"] == classes)[0][0]
                if int(item["confidence"]) > int(out[key][subkey][i]["confidence"]):
                    out[key][subkey][i]["confidence"] = item[
                        "confidence"
                    ]  # we are now more confident

                if not isinstance(out[key][subkey][i]["reference"], list):
                    out[key][subkey][i]["reference"] = [
                        out[key][subkey][i]["reference"]
                    ]

                if not isinstance(item["reference"], list):
                    item["reference"] = [item["reference"]]

                newdata = list(
                    np.unique(out[key][subkey][i]["reference"] + item["reference"])
                )
                out[key][subkey][i]["reference"] = newdata

            else:
                out[key][subkey].append(item)

        # now that we have all of them we need to figure out which one is the default
        maxconf = max(out[key][subkey], key=lambda d: d["confidence"])
        for item in out[key][subkey]:
            if item == maxconf:
                item["default"] = True
            else:
                item["default"] = False

        # then rederive the classification flags
        out = Transient._derive_classification_flags(out)

    @classmethod
    def _derive_classification_flags(cls, out):
        """
        Derive the classification flags based on the confidence flags. This will find
        - spec_classed
        - unambiguous

        See the paper for a detailed description of how this algorithm makes its
        choices
        """

        if "classification" not in out or "value" not in out["classification"]:
            # this means that the transient doesn't have any classifications
            # just return itself without any changes
            return out

        # get the confidences of all of the classifications of this transient
        confs = np.array(
            [item["confidence"] for item in out["classification"]["value"]]
        ).astype(float)

        all_class_roots = np.array(
            [
                _fuzzy_class_root(item["object_class"])
                for item in out["classification"]["value"]
            ]
        )

        if np.any(confs >= 3):
            unambiguous = len(np.unique(all_class_roots)) == 1
            if np.any(confs == 3) or np.any(confs == 3.3):
                # this is a "gold spectrum"
                spec_classed = 3
            elif np.any(confs == 3.2):
                # this is a silver spectrum
                spec_classed = 2
            elif np.any(confs == 3.1):
                # this is a bronze spectrum
                spec_classed = 1
            else:
                raise ValueError("Not prepared for this confidence flag!")

        elif np.any(confs == 2):
            # these always have spec_classed = True, by definition
            # They also have unambiguous = False by definition because they don't
            # have a peer reviewed citation for their classification
            spec_classed = 1
            unambiguous = False

        elif np.any(confs == 1):
            spec_classed = 0  # by definition
            unambiguous = len(np.unique(all_class_roots)) == 1

        else:
            spec_classed = 0
            unambiguous = False

        # finally, set these keys in the classification dict
        out["classification"]["spec_classed"] = spec_classed
        out["classification"]["unambiguous"] = unambiguous

        return out
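
    # Flag-derivation sketch (illustrative): with classification confidences
    # [3.2, 1] and matching class roots, the rules above give spec_classed = 2
    # (a "silver" spectrum) and unambiguous = True; a lone confidence of 2 gives
    # spec_classed = 1 and unambiguous = False.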

    @staticmethod
    def _merge_arbitrary(key, t1, t2, out, merge_subkeys=None, groupby_key=None):
        """
        Merge two arbitrary datasets inside the json file using pandas

        The datasets in t1 and t2 in "key" must be able to be forced into
        a NxM pandas dataframe!
        """

        if key == "name":
            t1._merge_names(t2, out)
        elif key == "filter_alias":
            t1._merge_filter_alias(t2, out)
        elif key == "schema_version":
            t1._merge_schema_version(t2, out)
        elif key == "photometry":
            t1._merge_photometry(t2, out)
        elif key == "classification":
            t1._merge_class(t2, out)
        else:
            # this is where we can standardize some of the merging
            df1 = pd.DataFrame(t1[key])
            df2 = pd.DataFrame(t2[key])

            merged_with_dups = pd.concat([df1, df2]).reset_index(drop=True)

            # have to get the indexes to drop using a string rep of the df
            # this is cause we have lists in some cells
            # We also need to deal with merging the lists of references across rows
            # that we deem to be duplicates. This solution to do this quickly is from
            # https://stackoverflow.com/questions/36271413/ \
            # pandas-merge-nearly-duplicate-rows-based-on-column-value
            if merge_subkeys is None:
                merge_subkeys = merged_with_dups.columns.tolist()
                merge_subkeys.remove("reference")
            else:
                for k in merge_subkeys:
                    if k not in merged_with_dups:
                        merge_subkeys.remove(k)

            merged = (
                merged_with_dups.astype(str)
                .groupby(merge_subkeys)["reference"]
                .apply(lambda x: x.sum())
                .reset_index()
            )

            # then we have to turn the merged reference strings into a string list
            merged["reference"] = merged.reference.str.replace("][", ",")

            # then eval the string of a list to get back an actual list of sources
            merged["reference"] = merged.reference.apply(
                lambda v: np.unique(eval(v)).tolist()
            )

            # decide on default values
            if groupby_key is None:
                iterate_through = [(0, merged)]
            else:
                iterate_through = merged.groupby(groupby_key)

            # we will make whichever value has more references the default
            outdict = []
            for data_type, df in iterate_through:
                lengths = df.reference.map(len)
                max_idx_arr = np.argmax(lengths)

                if isinstance(max_idx_arr, np.int64):
                    max_idx = max_idx_arr
                elif len(max_idx_arr) == 0:
                    raise ValueError("Something went wrong with deciding the default")
                else:
                    max_idx = max_idx_arr[0]  # arbitrarily choose the first

                defaults = np.full(len(df), False, dtype=bool)
                defaults[max_idx] = True

                df["default"] = defaults
                outdict.append(df)
            outdict = pd.concat(outdict)

            # from https://stackoverflow.com/questions/52504972/ \
            # converting-a-pandas-df-to-json-without-nan
            outdict = outdict.replace("nan", np.nan)
            outdict_cleaned = [{**x[i]} for i, x in outdict.stack().groupby(level=0)]

            out[key] = outdict_cleaned


def _fuzzy_class_root(s):
    """
    Extract the fuzzy classification root name from the string s
    """
    s = s.upper()
    # first split the class s using regex
    for root in _KNOWN_CLASS_ROOTS:
        if s.startswith(root):
            remaining = s[len(root) :]
            if remaining and root == "SN":
                # we want to be able to distinguish between SN Ia and SN II
                # we will use SN Ia to indicate those and SN to indicate CCSN
                if "IA" in remaining or "1A" in remaining:
                    return "SN Ia"
            return root
    return s