timewise 0.5.3__py3-none-any.whl → 1.0.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. timewise/__init__.py +1 -5
  2. timewise/backend/__init__.py +6 -0
  3. timewise/backend/base.py +36 -0
  4. timewise/backend/filesystem.py +80 -0
  5. timewise/chunking.py +50 -0
  6. timewise/cli.py +117 -11
  7. timewise/config.py +34 -0
  8. timewise/io/__init__.py +1 -0
  9. timewise/io/config.py +64 -0
  10. timewise/io/download.py +302 -0
  11. timewise/io/stable_tap.py +121 -0
  12. timewise/plot/__init__.py +3 -0
  13. timewise/plot/diagnostic.py +242 -0
  14. timewise/plot/lightcurve.py +112 -0
  15. timewise/plot/panstarrs.py +260 -0
  16. timewise/plot/sdss.py +109 -0
  17. timewise/process/__init__.py +2 -0
  18. timewise/process/config.py +30 -0
  19. timewise/process/interface.py +143 -0
  20. timewise/process/keys.py +10 -0
  21. timewise/process/stacking.py +310 -0
  22. timewise/process/template.yml +49 -0
  23. timewise/query/__init__.py +6 -0
  24. timewise/query/base.py +45 -0
  25. timewise/query/positional.py +40 -0
  26. timewise/tables/__init__.py +10 -0
  27. timewise/tables/allwise_p3as_mep.py +22 -0
  28. timewise/tables/base.py +9 -0
  29. timewise/tables/neowiser_p1bs_psd.py +22 -0
  30. timewise/types.py +30 -0
  31. timewise/util/backoff.py +12 -0
  32. timewise/util/csv_utils.py +12 -0
  33. timewise/util/error_threading.py +70 -0
  34. timewise/util/visits.py +33 -0
  35. timewise-1.0.0a1.dist-info/METADATA +205 -0
  36. timewise-1.0.0a1.dist-info/RECORD +39 -0
  37. {timewise-0.5.3.dist-info → timewise-1.0.0a1.dist-info}/WHEEL +1 -1
  38. timewise-1.0.0a1.dist-info/entry_points.txt +3 -0
  39. timewise/big_parent_sample.py +0 -106
  40. timewise/config_loader.py +0 -157
  41. timewise/general.py +0 -52
  42. timewise/parent_sample_base.py +0 -89
  43. timewise/point_source_utils.py +0 -68
  44. timewise/utils.py +0 -558
  45. timewise/wise_bigdata_desy_cluster.py +0 -1407
  46. timewise/wise_data_base.py +0 -2027
  47. timewise/wise_data_by_visit.py +0 -672
  48. timewise/wise_flux_conversion_correction.dat +0 -19
  49. timewise-0.5.3.dist-info/METADATA +0 -55
  50. timewise-0.5.3.dist-info/RECORD +0 -17
  51. timewise-0.5.3.dist-info/entry_points.txt +0 -3
  52. {timewise-0.5.3.dist-info → timewise-1.0.0a1.dist-info/licenses}/LICENSE +0 -0
timewise/io/download.py
@@ -0,0 +1,302 @@
+ import time
+ import threading
+ import logging
+ from queue import Empty
+ from typing import Dict, Iterator, cast, Sequence
+ from itertools import product
+ from pathlib import Path
+ from datetime import datetime, timedelta
+
+ import pandas as pd
+ import numpy as np
+ from astropy.table import Table
+ from pyvo.utils.http import create_session
+
+ from .stable_tap import StableTAPService
+ from ..backend import BackendType
+ from ..types import TAPJobMeta, TaskID, TYPE_MAP
+ from ..query import QueryType
+ from ..query.base import Query
+ from ..util.error_threading import ErrorQueue, ExceptionSafeThread
+ from ..chunking import Chunker, Chunk
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class Downloader:
+     def __init__(
+         self,
+         service_url: str,
+         input_csv: Path,
+         chunk_size: int,
+         backend: BackendType,
+         queries: list[QueryType],
+         max_concurrent_jobs: int,
+         poll_interval: float,
+     ):
+         self.backend = backend
+         self.queries = queries
+         self.input_csv = input_csv
+         self.max_concurrent_jobs = max_concurrent_jobs
+         self.poll_interval = poll_interval
+
+         # ----------------------------
+         # concurrency setup
+         # ----------------------------
+         # Shared state
+         self.job_lock = threading.Lock()
+         # (chunk_id, query_hash) -> job meta
+         self.jobs: Dict[TaskID, TAPJobMeta] = {}
+
+         self.stop_event = threading.Event()
+         self.submit_queue: ErrorQueue = ErrorQueue(stop_event=self.stop_event)
+         self.submit_thread = ExceptionSafeThread(
+             error_queue=self.submit_queue, target=self._submission_worker, daemon=True
+         )
+         self.poll_thread = ExceptionSafeThread(
+             error_queue=self.submit_queue, target=self._polling_worker, daemon=True
+         )
+         self.all_chunks_queued = False
+         self.all_chunks_submitted = False
+
+         # ----------------------------
+         # TAP setup
+         # ----------------------------
+         self.session = create_session()
+         self.service: StableTAPService = StableTAPService(
+             service_url, session=self.session
+         )
+
+         self.chunker = Chunker(input_csv=input_csv, chunk_size=chunk_size)
+
+     # ----------------------------
+     # helpers
+     # ----------------------------
+     @staticmethod
+     def get_task_id(chunk: Chunk, query: Query) -> TaskID:
+         return TaskID(
+             namespace="download", key=f"chunk{chunk.chunk_id:04d}_{query.hash}"
+         )
+
+     def iter_tasks(self) -> Iterator[TaskID]:
+         for chunk in self.chunker:
+             for q in self.queries:
+                 yield self.get_task_id(chunk, q)
+
+     def iter_tasks_per_chunk(self) -> Iterator[list[TaskID]]:
+         for chunk in self.chunker:
+             yield [self.get_task_id(chunk, q) for q in self.queries]
+
+     def load_job_meta(self):
+         backend = self.backend
+         for task in self.iter_tasks():
+             if backend.meta_exists(task):
+                 logger.debug(f"found job metadata {task}")
+                 if task not in self.jobs:
+                     try:
+                         jm = TAPJobMeta(**backend.load_meta(task))
+                         logger.debug(f"loaded {jm}")
+                         logger.debug(f"setting {task}")
+                         with self.job_lock:
+                             self.jobs[task] = jm
+                     except Exception:
+                         continue
+
+     # ----------------------------
+     # TAP submission and download
+     # ----------------------------
+     def get_chunk_data(self, chunk: Chunk) -> pd.DataFrame:
+         start = (
+             min(cast(Sequence[int], chunk.row_numbers)) + 1
+         )  # plus one to always skip the header line
+         nrows = (
+             max(cast(Sequence[int], chunk.row_numbers)) - start + 2
+         )  # plus one to undo the header offset in start, plus one to make the range inclusive
+
+         columns = list(pd.read_csv(self.input_csv, nrows=0).columns)
+         return pd.read_csv(
+             filepath_or_buffer=self.input_csv,
+             skiprows=start,
+             nrows=nrows,
+             names=columns,
+         )
+
+     def submit_tap_job(self, query: Query, chunk: Chunk) -> TAPJobMeta:
+         adql = query.adql
+         chunk_df = self.get_chunk_data(chunk)
+
+         assert all(chunk_df.index.isin(chunk.indices)), (
+             "Some inputs loaded from wrong chunk!"
+         )
+         assert all(np.isin(chunk.indices, chunk_df.index)), (
+             f"Some indices are missing in chunk {chunk.chunk_id}!"
+         )
+         logger.debug(f"loaded {len(chunk_df)} objects")
+
+         try:
+             upload = Table(
+                 {
+                     key: np.array(chunk_df[key]).astype(TYPE_MAP[dtype])
+                     for key, dtype in query.input_columns.items()
+                 }
+             )
+         except KeyError as e:
+             print(chunk_df)
+             raise KeyError(e)
+
+         logger.debug(f"uploading {len(upload)} objects.")
+         job = self.service.submit_job(adql, uploads={query.upload_name: upload})
+         job.run()
+         logger.debug(job.url)
+
+         return TAPJobMeta(
+             url=job.url,
+             query=adql,
+             query_config=query.model_dump(),
+             input_length=len(chunk_df),
+             submitted=str(datetime.now()),
+             last_checked=str(datetime.now()),
+             status=job.phase,
+             completed_at="",
+         )
+
+     def check_job_status(self, job_meta: TAPJobMeta) -> str:
+         status = self.service.get_job_from_url(url=job_meta["url"]).phase
+         job_meta["last_checked"] = str(datetime.now())
+         return status
+
+     def download_job_result(self, job_meta: TAPJobMeta) -> Table:
+         logger.info(f"downloading {job_meta['url']}")
+         job = self.service.get_job_from_url(url=job_meta["url"])
+         job.wait()
+         return job.fetch_result().to_table()
+
+     # ----------------------------
+     # Submission thread
+     # ----------------------------
+     def _submission_worker(self):
+         while not self.stop_event.is_set():
+             try:
+                 chunk, query = self.submit_queue.get(timeout=1.0)  # type: Chunk, Query
+             except Empty:
+                 if self.all_chunks_queued:
+                     self.all_chunks_submitted = True
+                     break
+                 continue
+
+             # Wait until we have capacity
+             while not self.stop_event.is_set():
+                 with self.job_lock:
+                     running = sum(
+                         1
+                         for j in self.jobs.values()
+                         if j.get("status") in ("QUEUED", "EXECUTING", "RUNNING")
+                     )
+                 if running < self.max_concurrent_jobs:
+                     break
+                 time.sleep(1.0)
+
+             task = self.get_task_id(chunk, query)
+             logger.info(f"submitting {task}")
+             job_meta = self.submit_tap_job(query, chunk)
+             self.backend.save_meta(task, job_meta)
+             with self.job_lock:
+                 self.jobs[task] = job_meta
+
+             self.submit_queue.task_done()
+
+     # ----------------------------
+     # Polling thread
+     # ----------------------------
+     def _polling_worker(self):
+         logger.debug("starting polling worker")
+         backend = self.backend
+         while not self.stop_event.is_set():
+             # reload job infos
+             self.load_job_meta()
+
+             with self.job_lock:
+                 items = list(self.jobs.items())
+
+             for task, meta in items:  # type: TaskID, TAPJobMeta
+                 if meta.get("status") in ("COMPLETED", "ERROR", "ABORTED"):
+                     logger.debug(f"{task} was already {meta['status']}")
+                     continue
+
+                 status = self.check_job_status(meta)
+                 if status == "COMPLETED":
+                     logger.info(f"completed {task}")
+                     payload_table = self.download_job_result(meta)
+                     logger.debug(payload_table.columns)
+                     backend.save_data(task, payload_table)
+                     meta["status"] = "COMPLETED"
+                     meta["completed_at"] = str(datetime.now())
+                     backend.save_meta(task, meta)
+                     backend.mark_done(task)
+                     with self.job_lock:
+                         self.jobs[task] = meta
+                 elif status in ("ERROR", "ABORTED"):
+                     logger.warning(f"failed {task}: {status}")
+                     meta["status"] = status
+                     with self.job_lock:
+                         self.jobs[task] = meta
+                     backend.save_meta(task, meta)
+                 else:
+                     with self.job_lock:
+                         self.jobs[task]["status"] = status
+                         snapshot = self.jobs[task]
+                     backend.save_meta(task, snapshot)
+
+             if self.all_chunks_submitted:
+                 with self.job_lock:
+                     all_done = (
+                         all(
+                             j.get("status") in ("COMPLETED", "ERROR", "ABORTED")
+                             for j in self.jobs.values()
+                         )
+                         if len(self.jobs) > 0
+                         else False
+                     )
+                 if all_done:
+                     logger.info("All tasks done! Exiting polling thread")
+                     break
+
+             logger.info(
+                 f"Next poll at {datetime.now() + timedelta(seconds=self.poll_interval)}"
+             )
+             time.sleep(self.poll_interval)
+
+     # ----------------------------
+     # Main run loop
+     # ----------------------------
+     def run(self):
+         # load existing job metadata
+         self.load_job_meta()
+
+         # start threads
+         self.submit_thread.start()
+         self.poll_thread.start()
+
+         # enqueue all chunks & queries
+         backend = self.backend
+         for chunk, q in product(self.chunker, self.queries):
+             task = self.get_task_id(chunk, q)
+
+             # skip if the download is done, or the job is queued
+             if backend.is_done(task) or (task in self.jobs):
+                 continue
+
+             self.submit_queue.put((chunk, q))
+         self.all_chunks_queued = True
+         # wait until all jobs are submitted
+         self.submit_queue.join()
+         # wait for the submit thread
+         self.submit_thread.join()
+         # the polling thread will exit once all results are downloaded
+         self.poll_thread.join()
+         # the stop event will also stop the submit thread
+         self.stop_event.set()
+         # if any thread exited with an error, report it
+         self.submit_queue.raise_errors()
+         logger.info("Done running downloader!")
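For orientation, the constructor arguments and `run()` above compose into a short driver script. The sketch below is not part of this diff and is only a hedged illustration: the IRSA TAP endpoint, the numeric settings, and the `backend`/`queries` placeholders are assumptions, since the concrete backend and query classes live in modules not reproduced in this hunk (timewise/backend/, timewise/query/).

```python
# Hypothetical driver, assuming `backend` (a BackendType instance) and
# `queries` (a list of QueryType objects) were constructed elsewhere.
from pathlib import Path
from timewise.io.download import Downloader

downloader = Downloader(
    service_url="https://irsa.ipac.caltech.edu/TAP",  # assumed TAP endpoint
    input_csv=Path("targets.csv"),  # one target position per row
    chunk_size=500,                 # rows uploaded per TAP job
    backend=backend,                # persists job metadata and result tables
    queries=queries,                # each query becomes one task per chunk
    max_concurrent_jobs=4,          # cap on QUEUED/EXECUTING/RUNNING jobs
    poll_interval=60.0,             # seconds between status polls
)
downloader.run()  # blocks until every (chunk, query) task is COMPLETED, ERROR or ABORTED
```

The two daemon threads share state only through `self.jobs` (guarded by `job_lock`) and the `ErrorQueue`, so an exception in either worker is re-raised in the main thread by `raise_errors()` at the end of `run()`.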
timewise/io/stable_tap.py
@@ -0,0 +1,121 @@
+ import logging
+ import backoff
+ import pyvo as vo
+ from xml.etree import ElementTree
+
+ from timewise.util.backoff import backoff_hndlr
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class StableAsyncTAPJob(vo.dal.AsyncTAPJob):
+     """
+     Implements backoff for calls to ``phase``, which otherwise break the code if there are connection issues.
+     Also stores the response of TAPQuery.submit() under ``self.submit_response`` for debugging.
+     """
+
+     def __init__(self, url, *, session=None, delete=True):
+         super(StableAsyncTAPJob, self).__init__(url, session=session, delete=delete)
+         self.submit_response = None
+
+     @classmethod
+     def create(
+         cls,
+         baseurl,
+         query,
+         *,
+         language="ADQL",
+         maxrec=None,
+         uploads=None,
+         session=None,
+         **keywords,
+     ):
+         """
+         Creates an async TAP job on the server under ``baseurl``.
+         Raises requests.HTTPError if TAPQuery.submit() fails.
+
+         Parameters
+         ----------
+         baseurl : str
+             the TAP baseurl
+         query : str
+             the query string
+         language : str
+             specifies the query language, default ADQL.
+             useful for services that allow use of the backend query language.
+         maxrec : int
+             the maximum records to return. defaults to the service default
+         uploads : dict
+             a mapping from table names to objects containing a votable
+         session : object
+             optional session to use for network requests
+         """
+         tapquery = vo.dal.TAPQuery(
+             baseurl,
+             query,
+             mode="async",
+             language=language,
+             maxrec=maxrec,
+             uploads=uploads,
+             session=session,
+             **keywords,
+         )
+         response = tapquery.submit()
+
+         # check if the response is valid
+         response.raise_for_status()
+
+         # check if the response contains an error from the ADQL engine
+         root = ElementTree.fromstring(response.content)
+         info = root.find(".//v:INFO", {"v": "http://www.ivoa.net/xml/VOTable/v1.3"})
+         # explicit None check: Element truthiness is False for childless elements
+         if info is not None and (info.attrib.get("value") == "ERROR"):
+             raise vo.dal.DALQueryError(info.text.strip())
+
+         # create the job instance
+         job = cls(response.url, session=session)
+         job._client_set_maxrec = maxrec
+         job.submit_response = response
+
+         return job
+
+     @property
+     @backoff.on_exception(
+         backoff.expo,
+         (vo.dal.DALServiceError, AttributeError),
+         max_tries=50,
+         on_backoff=backoff_hndlr,
+     )
+     def phase(self):
+         return super(StableAsyncTAPJob, self).phase
+
+
+ class StableTAPService(vo.dal.TAPService):
+     """
+     Implements the StableAsyncTAPJob for job submission
+     """
+
+     @backoff.on_exception(
+         backoff.expo,
+         (vo.dal.DALServiceError, AttributeError, AssertionError),
+         max_tries=5,
+         on_backoff=backoff_hndlr,
+     )
+     def submit_job(
+         self, query, *, language="ADQL", maxrec=None, uploads=None, **keywords
+     ):
+         job = StableAsyncTAPJob.create(
+             self.baseurl,
+             query,
+             language=language,
+             maxrec=maxrec,
+             uploads=uploads,
+             session=self._session,
+             **keywords,
+         )
+         logger.debug(job.url)
+         assert job.phase
+         return job
+
+     def get_job_from_url(self, url):
+         return StableAsyncTAPJob(url, session=self._session)
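Mirroring the calls `Downloader` makes above, a minimal, hedged usage sketch of the wrapper follows; it is not part of this diff, and the endpoint URL, upload table, and ADQL are illustrative placeholders.

```python
# Illustrative only: endpoint, table name, columns and ADQL are assumptions.
from astropy.table import Table
from timewise.io.stable_tap import StableTAPService

service = StableTAPService("https://irsa.ipac.caltech.edu/TAP")
upload = Table({"ra": [10.684], "dec": [41.269]})

# submit_job wraps StableAsyncTAPJob.create in exponential backoff
# (DALServiceError, AttributeError, AssertionError; max_tries=5)
job = service.submit_job(
    "SELECT m.mjd, m.w1mpro_ep, m.w2mpro_ep "
    "FROM allwise_p3as_mep AS m, TAP_UPLOAD.pos AS u "
    "WHERE CONTAINS(POINT('ICRS', m.ra, m.dec), CIRCLE('ICRS', u.ra, u.dec, 0.002)) = 1",
    uploads={"pos": upload},
)
job.run()
job.wait()  # block until the async job leaves the queue
table = job.fetch_result().to_table()

# a job can be re-attached later purely from its URL (e.g. after a restart);
# reading .phase goes through the backoff-protected property
same_job = service.get_job_from_url(job.url)
print(same_job.phase)
```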
timewise/plot/__init__.py
@@ -0,0 +1,3 @@
+ from .lightcurve import plot_lightcurve
+ from .sdss import plot_sdss_cutout
+ from .panstarrs import plot_panstarrs_cutout
timewise/plot/diagnostic.py
@@ -0,0 +1,242 @@
+ from typing import Literal, Dict, Any, Sequence, List, cast
+ from functools import partial
+ from pathlib import Path
+ import logging
+
+ import pandas as pd
+ import numpy as np
+ from numpy import typing as npt
+ from pydantic import BaseModel
+ import matplotlib.pyplot as plt
+ from matplotlib.lines import Line2D
+ from matplotlib.markers import MarkerStyle
+ from matplotlib.transforms import Affine2D
+
+ from timewise.plot import plot_lightcurve, plot_panstarrs_cutout, plot_sdss_cutout
+ from timewise.plot.lightcurve import BAND_PLOT_COLORS
+ from timewise.process import keys
+ from timewise.util.visits import get_visit_map
+ from timewise.config import TimewiseConfig
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class DiagnosticPlotter(BaseModel):
+     cutout: Literal["sdss", "panstarrs"] = "panstarrs"
+     band_colors: Dict[str, str] = BAND_PLOT_COLORS
+     lum_key: str = keys.FLUX_EXT
+
+     def plot_lightcurve(
+         self,
+         stacked_lightcurve: pd.DataFrame | None = None,
+         raw_lightcurve: pd.DataFrame | None = None,
+         ax: plt.Axes | None = None,
+         **kwargs,
+     ):
+         return plot_lightcurve(
+             lum_key=self.lum_key,
+             stacked_lightcurve=stacked_lightcurve,
+             raw_lightcurve=raw_lightcurve,
+             ax=ax,
+             colors=self.band_colors,
+             **kwargs,
+         )
+
+     def plot_cutout(self, ra: float, dec: float, radius_arcsec: float, ax: plt.Axes):
+         if self.cutout == "sdss":
+             plot_cutout = plot_sdss_cutout
+         elif self.cutout == "panstarrs":
+             plot_cutout = partial(plot_panstarrs_cutout, plot_color_image=True)
+         else:
+             raise NotImplementedError  # should never happen
+         return plot_cutout(ra=ra, dec=dec, arcsec=radius_arcsec, ax=ax)
+
+     def make_plot(
+         self,
+         stacked_lightcurve: pd.DataFrame | None,
+         raw_lightcurve: pd.DataFrame,
+         labels: npt.ArrayLike,
+         source_ra: float,
+         source_dec: float,
+         selected_indices: list[Any],
+         highlight_radius: float | None = None,
+     ) -> tuple[plt.Figure, Sequence[plt.Axes]]:
+         fig, axs = plt.subplots(
+             nrows=2, gridspec_kw={"height_ratios": [3, 2]}, figsize=(5, 8)
+         )
+
+         self.plot_cutout(ra=source_ra, dec=source_dec, ax=axs[0], radius_arcsec=20)
+
+         selected_mask = raw_lightcurve.index.isin(selected_indices)
+         plot_lightcurve(
+             raw_lightcurve=raw_lightcurve[~selected_mask],
+             lum_key=self.lum_key,
+             ax=axs[-1],
+             save=False,
+             colors={"w1": "gray", "w2": "lightgray"},
+             add_to_label=" ignored",
+         )
+         self.plot_lightcurve(
+             stacked_lightcurve=stacked_lightcurve,
+             raw_lightcurve=raw_lightcurve[selected_mask],
+             ax=axs[-1],
+             save=False,
+         )
+
+         # set markers for clusters
+         markers_strings = list(Line2D.filled_markers) + [
+             "$1$",
+             "$2$",
+             "$3$",
+             "$4$",
+             "$5$",
+             "$6$",
+             "$7$",
+             "$8$",
+             "$9$",
+         ]
+         markers_straight = [MarkerStyle(im) for im in markers_strings]
+         rot = Affine2D().rotate_deg(180)
+         markers_rotated = [MarkerStyle(im, transform=rot) for im in markers_strings]
+         markers = markers_straight + markers_rotated
+
+         # calculate ra and dec relative to center of cutout
+         ra = (raw_lightcurve.ra - source_ra) * 3600
+         dec = (raw_lightcurve.dec - source_dec) * 3600
+
+         # get visit map
+         visit_map = get_visit_map(raw_lightcurve)
+
+         # for each visit plot the datapoints on the cutout
+         for visit in np.unique(visit_map):
+             m = visit_map == visit
+             label = str(visit)
+             axs[0].plot(
+                 [],
+                 [],
+                 marker=markers[visit],
+                 label=label,
+                 mec="k",
+                 mew=1,
+                 mfc="none",
+                 ls="",
+             )
+
+             for im, mec, zorder in zip(
+                 [selected_mask, ~selected_mask], ["k", "none"], [1, 0]
+             ):
+                 mask = m & im
+
+                 for i_label in np.unique(labels):
+                     label_mask = labels == i_label
+                     final_mask = mask & label_mask
+                     datapoints_label = raw_lightcurve[final_mask]
+                     color = f"C{i_label}" if i_label != -1 else "grey"
+
+                     if ("sigra" in datapoints_label.columns) and (
+                         "sigdec" in datapoints_label.columns
+                     ):
+                         has_sig = (
+                             ~datapoints_label.sigra.isna()
+                             & ~datapoints_label.sigdec.isna()
+                         )
+                         _ra = ra[final_mask]
+                         _dec = dec[final_mask]
+
+                         axs[0].errorbar(
+                             _ra[has_sig],
+                             _dec[has_sig],
+                             xerr=datapoints_label.sigra[has_sig] / 3600,
+                             yerr=datapoints_label.sigdec[has_sig] / 3600,
+                             marker=markers[visit],
+                             ls="",
+                             color=color,
+                             zorder=zorder,
+                             ms=10,
+                             mec=mec,
+                             mew=0.1,
+                         )
+                         axs[0].scatter(
+                             _ra[~has_sig],
+                             _dec[~has_sig],
+                             marker=markers[visit],
+                             color=color,
+                             zorder=zorder,
+                             edgecolors=mec,
+                             linewidths=0.1,
+                         )
+                     else:
+                         axs[0].scatter(
+                             ra[final_mask],
+                             dec[final_mask],
+                             marker=markers[visit],
+                             color=color,
+                             zorder=zorder,
+                             edgecolors=mec,
+                             linewidths=0.1,
+                         )
+
+         if highlight_radius:
+             circle = plt.Circle(
+                 (0, 0),
+                 highlight_radius,
+                 color="g",
+                 fill=False,
+                 ls="-",
+                 lw=3,
+                 zorder=0,
+             )
+             axs[0].add_artist(circle)
+
+         # formatting
+         title = axs[0].get_title()
+         axs[-1].set_ylabel("Apparent Vega Magnitude")
+         axs[-1].grid(ls=":", alpha=0.5)
+         axs[0].set_title("")
+         axs[0].legend(
+             ncol=5,
+             bbox_to_anchor=(0, 1, 1, 0),
+             loc="lower left",
+             mode="expand",
+             title=title,
+         )
+         axs[0].set_aspect(1, adjustable="box")
+
+         return fig, axs
+
+
+ def make_plot(
+     config_path: Path,
+     cutout: Literal["sdss", "panstarrs"],
+     indices: List[int],
+     output_directory: Path,
+ ):
+     cfg = TimewiseConfig.from_yaml(config_path)
+     ampel_interface = cfg.build_ampel_interface()
+     input_data = pd.read_csv(cfg.download.input_csv).set_index(
+         ampel_interface.orig_id_key
+     )
+     plotter = DiagnosticPlotter(cutout=cutout)
+     for index in indices:
+         stacked_lightcurve = ampel_interface.extract_stacked_lightcurve(stock_id=index)
+         raw_lightcurve = ampel_interface.extract_datapoints(stock_id=index)
+         selected_dp_ids = ampel_interface.extract_selected_datapoint_ids(stock_id=index)
+         labels = [0] * len(raw_lightcurve)
+         source = input_data.loc[index]
+         ra: float = cast(float, source.ra)
+         dec: float = cast(float, source.dec)
+
+         fig, axs = plotter.make_plot(
+             stacked_lightcurve=stacked_lightcurve,
+             raw_lightcurve=raw_lightcurve,
+             labels=labels,
+             source_ra=ra,
+             source_dec=dec,
+             selected_indices=selected_dp_ids,
+         )
+         fn = output_directory / f"{index}.pdf"
+         logger.info(f"Saving plot to {fn}")
+         fig.savefig(fn)
+         plt.close()