timewise 0.5.3__py3-none-any.whl → 1.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- timewise/__init__.py +1 -5
- timewise/backend/__init__.py +6 -0
- timewise/backend/base.py +36 -0
- timewise/backend/filesystem.py +80 -0
- timewise/chunking.py +50 -0
- timewise/cli.py +117 -11
- timewise/config.py +34 -0
- timewise/io/__init__.py +1 -0
- timewise/io/config.py +64 -0
- timewise/io/download.py +302 -0
- timewise/io/stable_tap.py +121 -0
- timewise/plot/__init__.py +3 -0
- timewise/plot/diagnostic.py +242 -0
- timewise/plot/lightcurve.py +112 -0
- timewise/plot/panstarrs.py +260 -0
- timewise/plot/sdss.py +109 -0
- timewise/process/__init__.py +2 -0
- timewise/process/config.py +30 -0
- timewise/process/interface.py +143 -0
- timewise/process/keys.py +10 -0
- timewise/process/stacking.py +310 -0
- timewise/process/template.yml +49 -0
- timewise/query/__init__.py +6 -0
- timewise/query/base.py +45 -0
- timewise/query/positional.py +40 -0
- timewise/tables/__init__.py +10 -0
- timewise/tables/allwise_p3as_mep.py +22 -0
- timewise/tables/base.py +9 -0
- timewise/tables/neowiser_p1bs_psd.py +22 -0
- timewise/types.py +30 -0
- timewise/util/backoff.py +12 -0
- timewise/util/csv_utils.py +12 -0
- timewise/util/error_threading.py +70 -0
- timewise/util/visits.py +33 -0
- timewise-1.0.0a1.dist-info/METADATA +205 -0
- timewise-1.0.0a1.dist-info/RECORD +39 -0
- {timewise-0.5.3.dist-info → timewise-1.0.0a1.dist-info}/WHEEL +1 -1
- timewise-1.0.0a1.dist-info/entry_points.txt +3 -0
- timewise/big_parent_sample.py +0 -106
- timewise/config_loader.py +0 -157
- timewise/general.py +0 -52
- timewise/parent_sample_base.py +0 -89
- timewise/point_source_utils.py +0 -68
- timewise/utils.py +0 -558
- timewise/wise_bigdata_desy_cluster.py +0 -1407
- timewise/wise_data_base.py +0 -2027
- timewise/wise_data_by_visit.py +0 -672
- timewise/wise_flux_conversion_correction.dat +0 -19
- timewise-0.5.3.dist-info/METADATA +0 -55
- timewise-0.5.3.dist-info/RECORD +0 -17
- timewise-0.5.3.dist-info/entry_points.txt +0 -3
- {timewise-0.5.3.dist-info → timewise-1.0.0a1.dist-info/licenses}/LICENSE +0 -0
timewise/io/download.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import threading
|
|
3
|
+
import logging
|
|
4
|
+
from queue import Empty
|
|
5
|
+
from typing import Dict, Iterator, cast, Sequence
|
|
6
|
+
from itertools import product
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from datetime import datetime, timedelta
|
|
9
|
+
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import numpy as np
|
|
12
|
+
from astropy.table import Table
|
|
13
|
+
from pyvo.utils.http import create_session
|
|
14
|
+
|
|
15
|
+
from .stable_tap import StableTAPService
|
|
16
|
+
from ..backend import BackendType
|
|
17
|
+
from ..types import TAPJobMeta, TaskID, TYPE_MAP
|
|
18
|
+
from ..query import QueryType
|
|
19
|
+
from ..query.base import Query
|
|
20
|
+
from ..util.error_threading import ErrorQueue, ExceptionSafeThread
|
|
21
|
+
from ..chunking import Chunker, Chunk
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# module-level logger named after this module
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Downloader:
    """Submit chunked TAP queries for an input CSV and download the results.

    Every (chunk, query) pair of the input CSV is one task. A submission
    thread submits async TAP jobs while keeping at most
    ``max_concurrent_jobs`` of them active. A polling thread checks the job
    phases every ``poll_interval`` seconds. It stores the results of
    completed jobs and the job metadata through ``backend``. Tasks that the
    backend reports as done, or that already have job metadata, are not
    resubmitted by ``run``.
    """

    # job phases counted against ``max_concurrent_jobs``
    _ACTIVE_STATUSES = ("QUEUED", "EXECUTING", "RUNNING")
    # job phases after which a job is no longer polled
    _FINAL_STATUSES = ("COMPLETED", "ERROR", "ABORTED")

    def __init__(
        self,
        service_url: str,
        input_csv: Path,
        chunk_size: int,
        backend: BackendType,
        queries: list[QueryType],
        max_concurrent_jobs: int,
        poll_interval: float,
    ):
        self.backend = backend
        self.queries = queries
        self.input_csv = input_csv
        self.max_concurrent_jobs = max_concurrent_jobs
        self.poll_interval = poll_interval

        # ----------------------------
        # concurrency setup
        # ----------------------------
        # Shared state
        self.job_lock = threading.Lock()
        # (chunk_id, query_hash) -> job meta
        self.jobs: Dict[TaskID, TAPJobMeta] = {}

        self.stop_event = threading.Event()
        self.submit_queue: ErrorQueue = ErrorQueue(stop_event=self.stop_event)
        self.submit_thread = ExceptionSafeThread(
            error_queue=self.submit_queue, target=self._submission_worker, daemon=True
        )
        self.poll_thread = ExceptionSafeThread(
            error_queue=self.submit_queue, target=self._polling_worker, daemon=True
        )
        self.all_chunks_queued = False
        self.all_chunks_submitted = False

        # ----------------------------
        # TAP setup
        # ----------------------------
        self.session = create_session()
        self.service: StableTAPService = StableTAPService(
            service_url, session=self.session
        )

        self.chunker = Chunker(input_csv=input_csv, chunk_size=chunk_size)

    # ----------------------------
    # helpers
    # ----------------------------
    @staticmethod
    def get_task_id(chunk: Chunk, query: Query) -> TaskID:
        """Return the ``download`` task ID for ``query`` run on ``chunk``."""
        return TaskID(
            namespace="download", key=f"chunk{chunk.chunk_id:04d}_{query.hash}"
        )

    def iter_tasks(self) -> Iterator[TaskID]:
        """Yield the task ID of every (chunk, query) pair."""
        for chunk in self.chunker:
            for q in self.queries:
                yield self.get_task_id(chunk, q)

    def iter_tasks_per_chunk(self) -> Iterator[list[TaskID]]:
        """Yield, per chunk, the list of task IDs of all queries."""
        for chunk in self.chunker:
            yield [self.get_task_id(chunk, q) for q in self.queries]

    def load_job_meta(self):
        """Load job metadata stored in the backend for tasks not yet in ``self.jobs``.

        Loading is best effort. Metadata that cannot be read is skipped and
        will be retried on the next call, but the failure is logged rather
        than silently dropped.
        """
        backend = self.backend
        for task in self.iter_tasks():
            if not backend.meta_exists(task):
                continue
            logger.debug(f"found job metadata {task}")
            if task in self.jobs:
                continue
            try:
                jm = TAPJobMeta(**backend.load_meta(task))
            except Exception:
                # possibly metadata that is being written concurrently by the
                # submission thread; the next call tries again
                logger.warning(
                    f"could not load job metadata for {task}", exc_info=True
                )
                continue
            logger.debug(f"loaded {jm}")
            logger.debug(f"setting {task}")
            with self.job_lock:
                self.jobs[task] = jm

    # ----------------------------
    # TAP submission and download
    # ----------------------------
    def get_chunk_data(self, chunk: Chunk) -> pd.DataFrame:
        """Read the input CSV rows belonging to ``chunk``.

        ``chunk.row_numbers`` are presumably 0-based data-row positions
        (header not counted). The contiguous range from the smallest to the
        largest row number is read.
        """
        row_numbers = cast(Sequence[int], chunk.row_numbers)
        first_row = min(row_numbers)
        last_row = max(row_numbers)
        # plus one to always skip the header line (data row r is file line r + 1)
        start = first_row + 1
        # read every row from first_row up to and including last_row
        nrows = last_row - first_row + 1

        # skiprows also drops the header, so pass the column names explicitly
        columns = list(pd.read_csv(self.input_csv, nrows=0).columns)
        return pd.read_csv(
            filepath_or_buffer=self.input_csv,
            skiprows=start,
            nrows=nrows,
            names=columns,
        )

    def submit_tap_job(self, query: Query, chunk: Chunk) -> TAPJobMeta:
        """Upload the input rows of ``chunk`` and start the ADQL job of ``query``.

        Returns:
            The metadata of the started job.

        Raises:
            AssertionError: if the loaded rows do not match ``chunk.indices``
                (checks are skipped under ``python -O``).
            KeyError: if a column in ``query.input_columns`` is missing from the
                chunk data, or its dtype has no entry in ``TYPE_MAP``.
        """
        adql = query.adql
        chunk_df = self.get_chunk_data(chunk)

        assert all(chunk_df.index.isin(chunk.indices)), (
            "Some inputs loaded from wrong chunk!"
        )
        assert all(np.isin(chunk.indices, chunk_df.index)), (
            f"Some indices are missing in chunk {chunk.chunk_id}!"
        )
        logger.debug(f"loaded {len(chunk_df)} objects")

        try:
            upload = Table(
                {
                    key: np.array(chunk_df[key]).astype(TYPE_MAP[dtype])
                    for key, dtype in query.input_columns.items()
                }
            )
        except KeyError:
            # log the offending data and re-raise the original KeyError,
            # keeping its message and traceback instead of wrapping it
            logger.error(
                f"could not build upload table for chunk {chunk.chunk_id} "
                f"(columns: {list(chunk_df.columns)}):\n{chunk_df}"
            )
            raise

        logger.debug(f"uploading {len(upload)} objects.")
        job = self.service.submit_job(adql, uploads={query.upload_name: upload})
        job.run()
        logger.debug(job.url)

        return TAPJobMeta(
            url=job.url,
            query=adql,
            query_config=query.model_dump(),
            input_length=len(chunk_df),
            submitted=str(datetime.now()),
            last_checked=str(datetime.now()),
            status=job.phase,
            completed_at="",
        )

    def check_job_status(self, job_meta: TAPJobMeta) -> str:
        """Return the current phase of the job and update ``job_meta['last_checked']``."""
        status = self.service.get_job_from_url(url=job_meta["url"]).phase
        job_meta["last_checked"] = str(datetime.now())
        return status

    def download_job_result(self, job_meta: TAPJobMeta) -> Table:
        """Wait for the job at ``job_meta['url']`` and return its result table."""
        logger.info(f"downloading {job_meta['url']}")
        job = self.service.get_job_from_url(url=job_meta["url"])
        job.wait()
        return job.fetch_result().to_table()

    # ----------------------------
    # Submission thread
    # ----------------------------
    def _wait_for_capacity(self) -> bool:
        """Block until fewer than ``max_concurrent_jobs`` jobs are active.

        Returns:
            True once a slot is free; False if ``stop_event`` was set first.
        """
        while not self.stop_event.is_set():
            with self.job_lock:
                running = sum(
                    1
                    for j in self.jobs.values()
                    if j.get("status") in self._ACTIVE_STATUSES
                )
            if running < self.max_concurrent_jobs:
                return True
            # wait() instead of sleep() so a stop request ends the pause early
            self.stop_event.wait(1.0)
        return False

    def _submission_worker(self):
        """Submit queued (chunk, query) pairs while respecting the job limit.

        Sets ``all_chunks_submitted`` and exits once the queue is empty after
        ``run`` has queued everything. Also exits when ``stop_event`` is set.
        """
        while not self.stop_event.is_set():
            try:
                chunk, query = self.submit_queue.get(timeout=1.0)  # type: Chunk, Query
            except Empty:
                if self.all_chunks_queued:
                    self.all_chunks_submitted = True
                    break
                continue

            try:
                # Wait until we have capacity; do not submit new jobs if a
                # stop was requested while waiting
                if not self._wait_for_capacity():
                    break

                task = self.get_task_id(chunk, query)
                logger.info(f"submitting {task}")
                job_meta = self.submit_tap_job(query, chunk)
                self.backend.save_meta(task, job_meta)
                with self.job_lock:
                    self.jobs[task] = job_meta
            finally:
                # mark every dequeued item as handled, even on error or stop,
                # so the queue's task accounting stays correct
                self.submit_queue.task_done()

    # ----------------------------
    # Polling thread
    # ----------------------------
    def _polling_worker(self):
        """Poll unfinished jobs, store results of completed ones, persist metadata.

        Exits once all tasks were submitted and every known job is in a final
        phase (or there were no jobs at all), or when ``stop_event`` is set.
        """
        logger.debug("starting polling worker")
        backend = self.backend
        while not self.stop_event.is_set():
            # reload job infos
            self.load_job_meta()

            with self.job_lock:
                items = list(self.jobs.items())

            for task, meta in items:  # type: TaskID, TAPJobMeta
                if meta.get("status") in self._FINAL_STATUSES:
                    logger.debug(f"{task} was already {meta['status']}")
                    continue

                status = self.check_job_status(meta)
                if status == "COMPLETED":
                    logger.info(f"completed {task}")
                    payload_table = self.download_job_result(meta)
                    logger.debug(payload_table.columns)
                    backend.save_data(task, payload_table)
                    meta["status"] = "COMPLETED"
                    meta["completed_at"] = str(datetime.now())
                    backend.save_meta(task, meta)
                    backend.mark_done(task)
                    with self.job_lock:
                        self.jobs[task] = meta
                elif status in ("ERROR", "ABORTED"):
                    logger.warning(f"failed {task}: {status}")
                    meta["status"] = status
                    with self.job_lock:
                        self.jobs[task] = meta
                    backend.save_meta(task, meta)
                else:
                    with self.job_lock:
                        self.jobs[task]["status"] = status
                        snapshot = self.jobs[task]
                    backend.save_meta(task, snapshot)

            if self.all_chunks_submitted:
                with self.job_lock:
                    # all() is True for no jobs: if nothing had to be
                    # submitted there is nothing to wait for (previously this
                    # case looped forever)
                    all_done = all(
                        j.get("status") in self._FINAL_STATUSES
                        for j in self.jobs.values()
                    )
                if all_done:
                    logger.info("All tasks done! Exiting polling thread")
                    break

            next_poll = datetime.now() + timedelta(seconds=self.poll_interval)
            logger.info(f"Next poll at {next_poll}")
            # wait() instead of sleep() so a stop request ends the pause early
            self.stop_event.wait(self.poll_interval)

    # ----------------------------
    # Main run loop
    # ----------------------------
    def run(self):
        """Submit all outstanding tasks and block until the polling thread finishes.

        At the end, ``submit_queue.raise_errors()`` is called to surface
        errors reported by the worker threads.
        """
        # load existing job metadata so known jobs are not resubmitted
        self.load_job_meta()

        # start threads
        self.submit_thread.start()
        self.poll_thread.start()

        # enqueue all chunks & queries
        backend = self.backend
        for chunk, q in product(self.chunker, self.queries):
            task = self.get_task_id(chunk, q)

            # skip if the download is done, or the job is already known
            if backend.is_done(task) or (task in self.jobs):
                continue

            self.submit_queue.put((chunk, q))
        self.all_chunks_queued = True
        # wait until all jobs are submitted
        self.submit_queue.join()
        # wait for the submit thread
        self.submit_thread.join()
        # the polling thread will exit once all results are downloaded
        self.poll_thread.join()
        # signal the end of the run to anything still waiting on the event
        self.stop_event.set()
        # if any thread exited with an error report it
        self.submit_queue.raise_errors()
        logger.info("Done running downloader!")
|
|
timewise/io/stable_tap.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import backoff
|
|
3
|
+
import pyvo as vo
|
|
4
|
+
from xml.etree import ElementTree
|
|
5
|
+
|
|
6
|
+
from timewise.util.backoff import backoff_hndlr
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# module-level logger named after this module
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class StableAsyncTAPJob(vo.dal.AsyncTAPJob):
    """
    Async TAP job whose ``phase`` lookup is retried on connection problems.

    Without the retry, a transient connection issue while reading the phase
    would break the calling code. The response of ``TAPQuery.submit()`` is
    kept under ``self.submit_response`` for debugging.
    """

    def __init__(self, url, *, session=None, delete=True):
        super(StableAsyncTAPJob, self).__init__(url, session=session, delete=delete)
        self.submit_response = None

    @classmethod
    def create(
        cls,
        baseurl,
        query,
        *,
        language="ADQL",
        maxrec=None,
        uploads=None,
        session=None,
        **keywords,
    ):
        """
        Create an async TAP job on the server under ``baseurl``.

        Raises requests.HTTPError if TAPQuery.submit() fails, and
        pyvo.dal.DALQueryError if the response carries a VOTable INFO element
        with value ``ERROR``.

        Parameters
        ----------
        baseurl : str
            the TAP baseurl
        query : str
            the query string
        language : str
            specifies the query language, default ADQL.
            useful for services which allow to use the backend query language.
        maxrec : int
            the maximum records to return. defaults to the service default
        uploads : dict
            a mapping from table names to objects containing a votable
        session : object
            optional session to use for network requests
        """
        tapquery = vo.dal.TAPQuery(
            baseurl,
            query,
            mode="async",
            language=language,
            maxrec=maxrec,
            uploads=uploads,
            session=session,
            **keywords,
        )
        response = tapquery.submit()

        # check if the response is valid
        response.raise_for_status()

        # check if the response contains an error from the ADQL engine
        root = ElementTree.fromstring(response.content)
        info = root.find(".//v:INFO", {"v": "http://www.ivoa.net/xml/VOTable/v1.3"})
        # compare against None explicitly: an Element's truth value depends on
        # whether it has children, so a childless <INFO> would be falsy and
        # the error would go unnoticed
        if info is not None and info.attrib.get("value") == "ERROR":
            raise vo.dal.DALQueryError((info.text or "").strip())

        # create the job instance
        job = cls(response.url, session=session)
        job._client_set_maxrec = maxrec
        job.submit_response = response

        return job

    @property
    @backoff.on_exception(
        backoff.expo,
        (vo.dal.DALServiceError, AttributeError),
        max_tries=50,
        on_backoff=backoff_hndlr,
    )
    def phase(self):
        """The job phase, retried with exponential backoff (up to 50 tries)."""
        return super(StableAsyncTAPJob, self).phase
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class StableTAPService(vo.dal.TAPService):
    """
    TAP service that submits jobs as ``StableAsyncTAPJob``.
    """

    @backoff.on_exception(
        backoff.expo,
        (vo.dal.DALServiceError, AttributeError, AssertionError),
        max_tries=5,
        on_backoff=backoff_hndlr,
    )
    def submit_job(
        self, query, *, language="ADQL", maxrec=None, uploads=None, **keywords
    ):
        """Submit ``query`` as an async job and return the created job.

        The whole call is retried up to 5 times with exponential backoff on
        DALServiceError, AttributeError or AssertionError (raised when the new
        job reports no phase). Note that each retry calls ``create`` again,
        i.e. submits a new job.
        """
        job = StableAsyncTAPJob.create(
            self.baseurl,
            query,
            language=language,
            maxrec=maxrec,
            uploads=uploads,
            session=self._session,
            **keywords,
        )
        logger.debug(job.url)
        # explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently disable the retry on this case
        if not job.phase:
            raise AssertionError(f"job {job.url} reports no phase")
        return job

    def get_job_from_url(self, url):
        """Return a ``StableAsyncTAPJob`` for an existing job URL."""
        return StableAsyncTAPJob(url, session=self._session)
|
|
timewise/plot/diagnostic.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
1
|
+
from typing import Literal, Dict, Any, Sequence, List, cast
|
|
2
|
+
from functools import partial
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import numpy as np
|
|
8
|
+
from numpy import typing as npt
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
11
|
+
from matplotlib.lines import Line2D
|
|
12
|
+
from matplotlib.markers import MarkerStyle
|
|
13
|
+
from matplotlib.transforms import Affine2D
|
|
14
|
+
|
|
15
|
+
from timewise.plot import plot_lightcurve, plot_panstarrs_cutout, plot_sdss_cutout
|
|
16
|
+
from timewise.plot.lightcurve import BAND_PLOT_COLORS
|
|
17
|
+
from timewise.process import keys
|
|
18
|
+
from timewise.util.visits import get_visit_map
|
|
19
|
+
from timewise.config import TimewiseConfig
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# module-level logger named after this module
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DiagnosticPlotter(BaseModel):
    """Diagnostic figure: image cutout with datapoint positions above a lightcurve."""

    cutout: Literal["sdss", "panstarrs"] = "panstarrs"
    band_colors: Dict[str, str] = BAND_PLOT_COLORS
    lum_key: str = keys.FLUX_EXT

    def plot_lightcurve(
        self,
        stacked_lightcurve: pd.DataFrame | None = None,
        raw_lightcurve: pd.DataFrame | None = None,
        ax: plt.Axes | None = None,
        **kwargs,
    ):
        """Plot the lightcurve(s) with this plotter's ``lum_key`` and band colors."""
        return plot_lightcurve(
            lum_key=self.lum_key,
            stacked_lightcurve=stacked_lightcurve,
            raw_lightcurve=raw_lightcurve,
            ax=ax,
            colors=self.band_colors,
            **kwargs,
        )

    def plot_cutout(self, ra: float, dec: float, radius_arcsec: float, ax: plt.Axes):
        """Plot an SDSS or PanSTARRS (color) cutout around ``ra``/``dec`` into ``ax``."""
        if self.cutout == "sdss":
            plot_cutout = plot_sdss_cutout
        elif self.cutout == "panstarrs":
            plot_cutout = partial(plot_panstarrs_cutout, plot_color_image=True)
        else:
            raise NotImplementedError  # should never happen
        return plot_cutout(ra=ra, dec=dec, arcsec=radius_arcsec, ax=ax)

    def make_plot(
        self,
        stacked_lightcurve: pd.DataFrame | None,
        raw_lightcurve: pd.DataFrame,
        labels: npt.ArrayLike,
        source_ra: float,
        source_dec: float,
        selected_indices: list[Any],
        highlight_radius: float | None = None,
    ) -> tuple[plt.Figure, Sequence[plt.Axes]]:
        """Create the diagnostic figure for one source.

        The top panel shows a 20 arcsec cutout with the datapoint positions,
        using one marker per visit and one color per label (grey for label
        -1). Datapoints in ``selected_indices`` get black edges. The bottom
        panel shows the lightcurve, with non-selected datapoints in grey.

        Args:
            stacked_lightcurve: stacked lightcurve, or None.
            raw_lightcurve: datapoints with ``ra``/``dec`` in degrees and
                optionally ``sigra``/``sigdec``.
            labels: one cluster label per row of ``raw_lightcurve``.
            source_ra: cutout centre RA in degrees.
            source_dec: cutout centre Dec in degrees.
            selected_indices: index labels of the selected datapoints.
            highlight_radius: if given, draw a circle of this radius (arcsec).

        Returns:
            The figure and its two axes.
        """
        # work on an array: comparing a plain list with a scalar
        # (``labels == i_label``) gives a single bool, not an element-wise mask
        labels = np.asarray(labels)

        fig, axs = plt.subplots(
            nrows=2, gridspec_kw={"height_ratios": [3, 2]}, figsize=(5, 8)
        )

        self.plot_cutout(ra=source_ra, dec=source_dec, ax=axs[0], radius_arcsec=20)

        selected_mask = raw_lightcurve.index.isin(selected_indices)
        plot_lightcurve(
            raw_lightcurve=raw_lightcurve[~selected_mask],
            lum_key=self.lum_key,
            ax=axs[-1],
            save=False,
            colors={"w1": "gray", "w2": "lightgray"},
            add_to_label=" ignored",
        )
        self.plot_lightcurve(
            stacked_lightcurve=stacked_lightcurve,
            raw_lightcurve=raw_lightcurve[selected_mask],
            ax=axs[-1],
            save=False,
        )

        # set markers for clusters
        markers_strings = list(Line2D.filled_markers) + [
            "$1$",
            "$2$",
            "$3$",
            "$4$",
            "$5$",
            "$6$",
            "$7$",
            "$8$",
            "$9$",
        ]
        markers_straight = [MarkerStyle(im) for im in markers_strings]
        rot = Affine2D().rotate_deg(180)
        markers_rotated = [MarkerStyle(im, transform=rot) for im in markers_strings]
        markers = markers_straight + markers_rotated

        # on-sky offsets from the cutout centre in arcsec: wrap the RA
        # difference into [-180, 180) deg so sources near RA = 0 are not placed
        # ~360 deg away, and scale it by cos(dec) to get an angular offset
        delta_ra = (raw_lightcurve.ra - source_ra + 180) % 360 - 180
        ra = delta_ra * np.cos(np.radians(source_dec)) * 3600
        dec = (raw_lightcurve.dec - source_dec) * 3600

        # get visit map
        visit_map = get_visit_map(raw_lightcurve)

        # for each visit plot the datapoints on the cutout
        for visit in np.unique(visit_map):
            m = visit_map == visit
            # cycle through the markers if there are more visits than markers
            marker = markers[visit % len(markers)]
            axs[0].plot(
                [],
                [],
                marker=marker,
                label=str(visit),
                mec="k",
                mew=1,
                mfc="none",
                ls="",
            )

            # selected datapoints: black edge, drawn on top
            for im, mec, zorder in zip(
                [selected_mask, ~selected_mask], ["k", "none"], [1, 0]
            ):
                mask = m & im

                for i_label in np.unique(labels):
                    label_mask = labels == i_label
                    final_mask = mask & label_mask
                    datapoints_label = raw_lightcurve[final_mask]
                    color = f"C{i_label}" if i_label != -1 else "grey"

                    if ("sigra" in datapoints_label.columns) and (
                        "sigdec" in datapoints_label.columns
                    ):
                        has_sig = (
                            ~datapoints_label.sigra.isna()
                            & ~datapoints_label.sigdec.isna()
                        )
                        _ra = ra[final_mask]
                        _dec = dec[final_mask]

                        # sigra/sigdec are given in arcsec in the WISE tables,
                        # the same unit as the offsets, so no conversion
                        axs[0].errorbar(
                            _ra[has_sig],
                            _dec[has_sig],
                            xerr=datapoints_label.sigra[has_sig],
                            yerr=datapoints_label.sigdec[has_sig],
                            marker=marker,
                            ls="",
                            color=color,
                            zorder=zorder,
                            ms=10,
                            mec=mec,
                            mew=0.1,
                        )
                        axs[0].scatter(
                            _ra[~has_sig],
                            _dec[~has_sig],
                            marker=marker,
                            color=color,
                            zorder=zorder,
                            edgecolors=mec,
                            linewidths=0.1,
                        )
                    else:
                        axs[0].scatter(
                            ra[final_mask],
                            dec[final_mask],
                            marker=marker,
                            color=color,
                            zorder=zorder,
                            edgecolors=mec,
                            linewidths=0.1,
                        )

        if highlight_radius:
            circle = plt.Circle(
                (0, 0),
                highlight_radius,
                color="g",
                fill=False,
                ls="-",
                lw=3,
                zorder=0,
            )
            axs[0].add_artist(circle)

        # formatting: move the cutout title into the visit legend
        title = axs[0].get_title()
        axs[-1].set_ylabel("Apparent Vega Magnitude")
        axs[-1].grid(ls=":", alpha=0.5)
        axs[0].set_title("")
        axs[0].legend(
            ncol=5,
            bbox_to_anchor=(0, 1, 1, 0),
            loc="lower left",
            mode="expand",
            title=title,
        )
        axs[0].set_aspect(1, adjustable="box")

        return fig, axs
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def make_plot(
    config_path: Path,
    cutout: Literal["sdss", "panstarrs"],
    indices: List[int],
    output_directory: Path,
):
    """Save one diagnostic plot per object in ``indices`` as ``<index>.pdf``.

    Args:
        config_path: timewise YAML config.
        cutout: cutout survey used for the image panel.
        indices: object IDs, looked up in the input CSV's
            ``orig_id_key`` column and in the AMPEL interface.
        output_directory: directory for the PDFs; created if missing.
    """
    cfg = TimewiseConfig.from_yaml(config_path)
    ampel_interface = cfg.build_ampel_interface()
    input_data = pd.read_csv(cfg.download.input_csv).set_index(
        ampel_interface.orig_id_key
    )
    plotter = DiagnosticPlotter(cutout=cutout)
    output_directory.mkdir(parents=True, exist_ok=True)

    for index in indices:
        stacked_lightcurve = ampel_interface.extract_stacked_lightcurve(stock_id=index)
        raw_lightcurve = ampel_interface.extract_datapoints(stock_id=index)
        selected_dp_ids = ampel_interface.extract_selected_datapoint_ids(stock_id=index)
        # one common cluster label; an array (not a list) so that element-wise
        # comparisons against a label yield a mask
        labels = np.zeros(len(raw_lightcurve), dtype=int)
        source = input_data.loc[index]
        ra: float = cast(float, source.ra)
        dec: float = cast(float, source.dec)

        fig, _ = plotter.make_plot(
            stacked_lightcurve=stacked_lightcurve,
            raw_lightcurve=raw_lightcurve,
            labels=labels,
            source_ra=ra,
            source_dec=dec,
            selected_indices=selected_dp_ids,
        )
        try:
            fn = output_directory / f"{index}.pdf"
            logger.info(f"Saving plot to {fn}")
            fig.savefig(fn)
        finally:
            # close this specific figure, also if saving fails
            plt.close(fig)
|