setiastrosuitepro 1.7.5.post1-py3-none-any.whl → 1.8.0.post3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of setiastrosuitepro might be problematic.
- setiastro/saspro/_generated/build_info.py +2 -2
- setiastro/saspro/accel_installer.py +21 -8
- setiastro/saspro/accel_workers.py +11 -12
- setiastro/saspro/comet_stacking.py +113 -85
- setiastro/saspro/cosmicclarity.py +604 -826
- setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
- setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
- setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
- setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
- setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
- setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
- setiastro/saspro/gui/main_window.py +14 -0
- setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
- setiastro/saspro/model_manager.py +324 -0
- setiastro/saspro/model_workers.py +102 -0
- setiastro/saspro/ops/benchmark.py +320 -0
- setiastro/saspro/ops/settings.py +407 -10
- setiastro/saspro/remove_stars.py +424 -442
- setiastro/saspro/resources.py +73 -10
- setiastro/saspro/runtime_torch.py +107 -22
- setiastro/saspro/signature_insert.py +14 -3
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
@@ -0,0 +1,732 @@
# src/setiastro/saspro/cosmicclarity_engines/benchmark_engine.py
from __future__ import annotations

import os, time, platform
from pathlib import Path
from typing import Callable, Optional, Literal, Dict, Any

import numpy as np
from astropy.io import fits

from numba import njit, prange
import psutil

import cpuinfo

ProgressCB = Callable[[str], None]
CancelCB = Callable[[], bool]
DLProgressCB = Callable[[int, int], None]  # (done_bytes, total_bytes)

from setiastro.saspro.cosmicclarity_engines.sharpen_engine import load_sharpen_models
from setiastro.saspro.resources import get_resources

# -----------------------------
# Paths / cache
# -----------------------------
def benchmark_cache_dir() -> Path:
    # Reuse your runtime dir (same one accel installer uses)
    from setiastro.saspro.runtime_torch import _user_runtime_dir
    rt = Path(_user_runtime_dir())
    d = rt / "benchmarks"
    d.mkdir(parents=True, exist_ok=True)
    return d

def benchmark_image_path() -> Path:
    return benchmark_cache_dir() / "benchmarkimage.fit"

from typing import Sequence, Union
from urllib.parse import urlparse, parse_qs
import re

BENCHMARK_URLS = [
    "https://drive.google.com/file/d/1wgp6Ydn8JgF1j9FVnF6PgjyN-6ptJTnK/view?usp=drive_link",
    "https://drive.google.com/file/d/1QhsmuKjvksAMq45M3aHKHylZgZEd8Nh0/view?usp=drive_link",
    "https://github.com/setiastro/setiastrosuitepro/releases/download/benchmarkFIT/benchmarkimage.fit",
]

# Keep for backwards compat (some code imports BENCHMARK_FITS_URL)
BENCHMARK_FITS_URL = BENCHMARK_URLS[-1]


from urllib.parse import urljoin

def _looks_like_html_prefix(b: bytes) -> bool:
    head = (b or b"").lstrip()[:256].lower()
    return head.startswith(b"<!doctype html") or head.startswith(b"<html") or b"<html" in head

def _parse_gdrive_download_form(html: str) -> tuple[str, dict] | tuple[None, None]:
    """
    Parse the Google Drive virus-scan warning page.
    Extracts:
      - form action URL (often https://drive.usercontent.google.com/download)
      - all hidden input fields needed for the download
    """
    # action="..."
    m = re.search(r'<form[^>]+id="download-form"[^>]+action="([^"]+)"', html)
    if not m:
        return None, None
    action = m.group(1)

    # hidden inputs: <input type="hidden" name="X" value="Y">
    inputs = {}
    for name, val in re.findall(r'<input[^>]+type="hidden"[^>]+name="([^"]+)"[^>]*value="([^"]*)"', html):
        inputs[name] = val

    # Some pages omit value="" explicitly; handle name-only hidden inputs too (rare)
    for name in re.findall(r'<input[^>]+type="hidden"[^>]+name="([^"]+)"(?![^>]*value=)', html):
        inputs.setdefault(name, "")

    return action, inputs


def _gdrive_file_id(url: str) -> Optional[str]:
    """
    Extract Google Drive file id from:
      - https://drive.google.com/file/d/<ID>/view
      - https://drive.google.com/open?id=<ID>
      - https://drive.google.com/uc?id=<ID>&export=download
    """
    try:
        u = urlparse(url)
        if "drive.google.com" not in (u.netloc or ""):
            return None

        # /file/d/<id>/...
        m = re.search(r"/file/d/([^/]+)", u.path or "")
        if m:
            return m.group(1)

        # ?id=<id>
        qs = parse_qs(u.query or "")
        if "id" in qs and qs["id"]:
            return qs["id"][0]
    except Exception:
        pass
    return None


def _gdrive_direct_url(file_id: str) -> str:
    # export=download is essential; confirm token may be appended later if needed
    return f"https://drive.google.com/uc?export=download&id={file_id}"


def _gdrive_confirm_token(html: str) -> Optional[str]:
    """
    When Drive shows the "can't scan for viruses" interstitial,
    it includes a confirm token in a download link.
    """
    # Typical patterns include confirm=<TOKEN>
    m = re.search(r"confirm=([0-9A-Za-z_]+)", html)
    if m:
        return m.group(1)
    return None


def _normalize_download_url(url: str) -> tuple[str, Optional[str]]:
    """
    Returns (normalized_url, label_for_logging).
    If Google Drive view/open link, returns a direct uc?export=download&id= URL.
    """
    fid = _gdrive_file_id(url)
    if fid:
        return _gdrive_direct_url(fid), f"Google Drive ({fid})"
    return url, None
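
For orientation, what the normalizer yields for a Drive "view" link; the file id below is a placeholder, not one of the real ids above:

    # Illustrative only; SOME_ID stands in for a real Drive file id.
    url, label = _normalize_download_url(
        "https://drive.google.com/file/d/SOME_ID/view?usp=drive_link"
    )
    # url   == "https://drive.google.com/uc?export=download&id=SOME_ID"
    # label == "Google Drive (SOME_ID)"
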

def _looks_like_html_prefix(b: bytes) -> bool:
    if not b:
        return False
    head = b.lstrip()[:64].lower()
    return head.startswith(b"<!doctype html") or head.startswith(b"<html") or b"<html" in head

def _is_probably_valid_fits(path: Path, *, min_bytes: int = 1_000_000) -> bool:
    try:
        if path.stat().st_size < min_bytes:
            return False
        with open(path, "rb") as f:
            first = f.read(80)
        # FITS primary header should start with SIMPLE =
        if b"SIMPLE" not in first[:20]:
            return False
        return True
    except Exception:
        return False


# -----------------------------
# Download benchmark FITS
# -----------------------------
def download_benchmark_image(
    url: Union[str, Sequence[str], None] = None,
    dst: Optional[Path] = None,
    *,
    status_cb: Optional[ProgressCB] = None,
    progress_cb: Optional[DLProgressCB] = None,
    cancel_cb: Optional[CancelCB] = None,
    timeout: int = 30,
) -> Path:
    """
    Download benchmarkimage.fit into runtime cache.

    url:
      - None -> try BENCHMARK_URLS in order
      - str -> try that one (but will still handle Drive confirms)
      - list/tuple -> try each in order

    Uses streaming download + atomic replace. Supports cancel.
    """
    if dst is None:
        dst = benchmark_image_path()
    dst = Path(dst)
    tmp = dst.with_suffix(dst.suffix + ".part")

    # Build candidate list
    if url is None:
        candidates = list(BENCHMARK_URLS)
    elif isinstance(url, (list, tuple)):
        candidates = list(url)
    else:
        candidates = [url]

    import requests  # local import keeps startup lighter

    last_err = None

    for idx, raw in enumerate(candidates, start=1):
        try:
            dl_url, label = _normalize_download_url(raw)
            src_label = label or raw

            if status_cb:
                status_cb(f"Downloading benchmark image… (source {idx}/{len(candidates)}: {src_label})")

            # Use a session so Drive confirm/cookies work reliably
            with requests.Session() as s:
                r = s.get(dl_url, stream=True, timeout=timeout, allow_redirects=True)

                ctype = (r.headers.get("Content-Type") or "").lower()

                # If we got HTML, it’s probably the virus-scan warning page
                if "text/html" in ctype:
                    html = r.text  # reads page into memory (small)
                    r.close()

                    action, params = _parse_gdrive_download_form(html)
                    if action and params:
                        # Submit the "Download anyway" form with the SAME session/cookies
                        r = s.get(action, params=params, stream=True, timeout=timeout, allow_redirects=True)
                        ctype = (r.headers.get("Content-Type") or "").lower()
                    else:
                        raise RuntimeError("Google Drive returned an interstitial HTML page, but download form could not be parsed.")

                r.raise_for_status()
                total = int(r.headers.get("Content-Length") or 0)
                done = 0
                t_start = time.time()
                t_last = t_start
                done_last = 0
                ema_bps = None

                tmp.parent.mkdir(parents=True, exist_ok=True)

                first_chunk = True
                with open(tmp, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        if cancel_cb and cancel_cb():
                            try:
                                f.close()
                                tmp.unlink(missing_ok=True)
                            except Exception:
                                pass
                            raise RuntimeError("Download canceled.")

                        if not chunk:
                            continue

                        # SNIFF FIRST BYTES: if it's HTML, abort this source immediately
                        if first_chunk:
                            first_chunk = False
                            if _looks_like_html_prefix(chunk[:256]):
                                raise RuntimeError("Google Drive returned HTML (not the FITS). Link likely requires confirm/permission.")

                        f.write(chunk)
                        done += len(chunk)

                        if progress_cb:
                            progress_cb(done, total)

                        now = time.time()
                        dt = now - t_last
                        if dt >= 0.5:
                            inst_bps = (done - done_last) / max(dt, 1e-9)
                            ema_bps = inst_bps if ema_bps is None else (0.75 * ema_bps + 0.25 * inst_bps)

                            eta = None
                            if total > 0 and ema_bps and ema_bps > 1:
                                eta = (total - done) / ema_bps

                            if status_cb:
                                pct = (done * 100.0 / total) if total > 0 else None
                                if pct is None:
                                    status_cb(f"Downloading… {_fmt_bytes(done)} at {_fmt_bytes(ema_bps)}/s • ETA {_fmt_eta(None)}")
                                else:
                                    status_cb(
                                        f"Downloading… {pct:5.1f}% • {_fmt_bytes(done)}/{_fmt_bytes(total)} "
                                        f"at {_fmt_bytes(ema_bps)}/s • ETA {_fmt_eta(eta)}"
                                    )

                            t_last = now
                            done_last = done

            # atomic replace only after a full success
            os.replace(str(tmp), str(dst))

            # VALIDATE: size + FITS header sanity. If invalid, treat as failure and try next URL.
            if not _is_probably_valid_fits(dst, min_bytes=10_000_000):  # 10MB floor; tune as you like
                raise RuntimeError("Downloaded file is not a valid FITS (too small or missing SIMPLE).")

            if status_cb:
                status_cb(f"Benchmark image ready: {dst}")
            return dst

        except Exception as e:
            last_err = e
            # clean partial
            try:
                tmp.unlink(missing_ok=True)
            except Exception:
                pass
            if status_cb:
                status_cb(f"Source {idx} failed: {e}")

    raise RuntimeError(f"All benchmark download sources failed. Last error: {last_err}")
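
A minimal caller sketch for the downloader; the print-based callbacks are illustrative, not part of the release:

    from setiastro.saspro.cosmicclarity_engines.benchmark_engine import download_benchmark_image

    def show_status(msg: str) -> None:
        print(msg)

    def show_progress(done: int, total: int) -> None:
        # total is 0 when the server sends no Content-Length
        if total > 0:
            print(f"{done}/{total} bytes", end="\r")

    path = download_benchmark_image(status_cb=show_status, progress_cb=show_progress)
    print(f"Benchmark FITS at {path}")
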

def _fmt_bytes(n: float) -> str:
    units = ["B", "KB", "MB", "GB", "TB"]
    n = float(max(0.0, n))
    for u in units:
        if n < 1024.0 or u == units[-1]:
            return f"{n:.1f} {u}" if u != "B" else f"{n:.0f} {u}"
        n /= 1024.0
    return f"{n:.1f} TB"

def _fmt_eta(seconds: Optional[float]) -> str:
    if seconds is None or seconds <= 0 or not np.isfinite(seconds):
        return "—"
    s = int(seconds + 0.5)
    m, s = divmod(s, 60)
    h, m = divmod(m, 60)
    if h:
        return f"{h:d}h {m:02d}m"
    if m:
        return f"{m:d}m {s:02d}s"
    return f"{s:d}s"
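
Worked examples for the two formatters, traced by hand from the code above (assuming the module is importable under the path in its header comment):

    from setiastro.saspro.cosmicclarity_engines.benchmark_engine import _fmt_bytes, _fmt_eta

    assert _fmt_bytes(512) == "512 B"           # below 1024 stays in bytes, no decimals
    assert _fmt_bytes(1536) == "1.5 KB"         # 1536 / 1024 = 1.5
    assert _fmt_bytes(3 * 1024**2) == "3.0 MB"
    assert _fmt_eta(125) == "2m 05s"            # 125 s -> 2 min 5 s
    assert _fmt_eta(3700) == "1h 01m"           # seconds beyond the hour are dropped
    assert _fmt_eta(None) == "—"
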

def _get_stellar_model_for_benchmark(use_gpu: bool, status_cb=None):
    """
    Returns (models, backend_tag)
      - models: SharpenModels (torch or onnx)
      - backend_tag: 'CUDA', 'CPU', 'DirectML', 'MPS', etc.
    """
    models = load_sharpen_models(use_gpu=use_gpu, status_cb=status_cb or (lambda *_: None))

    if models.is_onnx:
        return models, "DirectML"

    dev = models.device
    dev_type = getattr(dev, "type", "")
    if dev_type == "cuda":
        return models, "CUDA"
    if dev_type == "mps":
        return models, "MPS"

    # torch-directml devices don’t have .type == 'dml' typically; handle by string
    if "dml" in str(dev).lower() or "directml" in str(dev).lower():
        return models, "DirectML"

    return models, "CPU"


def torch_benchmark_stellar(
    patches_nchw: np.ndarray,
    *,
    use_gpu: bool,
    progress_cb=None,  # (done,total)->bool
    status_cb=None,
) -> tuple[float, float, str]:
    """
    Torch benchmark using the SAME Stellar model + autocast policy as sharpen_engine.
    """
    status_cb = status_cb or (lambda *_: None)
    models, tag = _get_stellar_model_for_benchmark(use_gpu=use_gpu, status_cb=status_cb)

    if models.is_onnx:
        raise RuntimeError("torch_benchmark_stellar called but models.is_onnx=True")

    torch = models.torch
    device = models.device
    model = models.stellar

    x = torch.from_numpy(patches_nchw).to(device=device, dtype=torch.float32, non_blocking=True)

    # warmup a tiny bit to avoid first-kernel skew
    with torch.no_grad():
        _ = model(x[0:1])
        if device.type == "cuda":
            torch.cuda.synchronize()

    total_ms = 0.0
    n = int(x.shape[0])
    status_cb(f"Benchmarking Stellar model via Torch ({tag})…")

    # IMPORTANT: reuse sharpen_engine autocast policy, not unconditional AMP
    from setiastro.saspro.cosmicclarity_engines.sharpen_engine import _autocast_context

    with torch.no_grad(), _autocast_context(torch, device):
        for i in range(n):
            t0 = time.time()
            _ = model(x[i:i+1])
            if device.type == "cuda":
                torch.cuda.synchronize()
            total_ms += (time.time() - t0) * 1000.0

            if progress_cb and (not progress_cb(i + 1, n)):
                raise RuntimeError("Canceled.")

    return (total_ms / n), total_ms, tag

def onnx_benchmark_stellar(
    patches_nchw: np.ndarray,
    *,
    use_gpu: bool,
    progress_cb=None,
    status_cb=None,
) -> tuple[float, float, str]:
    """
    ONNX benchmark:
      - If sharpen_engine selected ONNX (DirectML), reuse that exact session.
      - Otherwise, on Windows, prefer DirectML provider if available (when use_gpu=True),
        then CUDA EP if present, else CPU.
    """
    status_cb = status_cb or (lambda *_: None)

    # Reuse sharpen_engine session if it already chose ONNX (typically DirectML on Windows)
    models, tag = _get_stellar_model_for_benchmark(use_gpu=use_gpu, status_cb=status_cb)
    if models.is_onnx:
        sess = models.stellar
        # Use the provider that the session actually has (best-effort label)
        try:
            provs = sess.get_providers()
            provider = provs[0] if provs else "ONNX"
        except Exception:
            provider = "DmlExecutionProvider"
    else:
        import onnxruntime as ort
        from setiastro.saspro.model_manager import require_model
        onnx_path = require_model("deep_sharp_stellar_cnn_AI3_5s.onnx")

        providers_avail = ort.get_available_providers()

        # Prefer DirectML if possible (Windows) when GPU requested
        providers = []
        if use_gpu and ("DmlExecutionProvider" in providers_avail):
            providers.append("DmlExecutionProvider")
        # If no DML (or user disabled GPU), try CUDA EP if available
        if use_gpu and ("CUDAExecutionProvider" in providers_avail):
            providers.append("CUDAExecutionProvider")
        # Always end with CPU
        providers.append("CPUExecutionProvider")

        # Build session
        sess = ort.InferenceSession(str(onnx_path), providers=providers)

        # Label by what actually got picked
        try:
            provs = sess.get_providers()
            provider = provs[0] if provs else providers[0]
        except Exception:
            provider = providers[0]

    input_name = sess.get_inputs()[0].name
    total_ms = 0.0
    n = int(patches_nchw.shape[0])

    status_cb(f"Benchmarking Stellar model via ONNX ({provider})…")

    for i in range(n):
        patch = patches_nchw[i:i+1].astype(np.float32, copy=False)
        t0 = time.time()
        sess.run(None, {input_name: patch})
        total_ms += (time.time() - t0) * 1000.0

        if progress_cb and (not progress_cb(i + 1, n)):
            raise RuntimeError("Canceled.")

    return (total_ms / n), total_ms, provider
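
Both benchmark loops share one cancellation contract: progress_cb receives (done, total) and the run aborts with RuntimeError("Canceled.") once the callback returns a falsy value. An illustrative callback, not part of the release:

    canceled = False  # e.g. toggled from a UI "Cancel" button

    def progress(done: int, total: int) -> bool:
        print(f"patch {done}/{total}")
        return not canceled  # returning False makes the benchmark loop raise "Canceled."
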

# -----------------------------
# Load + tile image
# -----------------------------
def _load_benchmark_fits(path: Path) -> np.ndarray:
    with fits.open(str(path), memmap=False) as hdul:
        img = hdul[0].data
    if img is None:
        raise RuntimeError("FITS contains no data.")
    img = np.asarray(img, dtype=np.float32)

    # Expect mono 2D; convert to CHW(3,H,W)
    if img.ndim == 2:
        img = np.stack([img, img, img], axis=0)
    elif img.ndim == 3:
        # If HWC convert -> CHW
        if img.shape[-1] in (3, 4) and img.shape[0] != 3:
            img = np.transpose(img[..., :3], (2, 0, 1))
        # If already CHW with 3 ok; if 1 channel expand
        if img.shape[0] == 1:
            img = np.repeat(img, 3, axis=0)
    else:
        raise RuntimeError(f"Unexpected FITS shape: {img.shape}")
    return img


def tile_chw_image(image_chw: np.ndarray, patch_size: int = 256) -> np.ndarray:
    """
    image_chw: (3,H,W) -> patches: (N,3,patch,patch)
    Only full patches (no padding) to match your old behavior.
    """
    c, h, w = image_chw.shape
    patches = []
    for y in range(0, h, patch_size):
        for x in range(0, w, patch_size):
            p = image_chw[:, y:y+patch_size, x:x+patch_size]
            if p.shape[1] == patch_size and p.shape[2] == patch_size:
                patches.append(p)
    if not patches:
        raise RuntimeError("No full 256x256 patches found in benchmark image.")
    return np.stack(patches, axis=0).astype(np.float32, copy=False)
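
Since only full tiles are kept, the patch count works out to floor(H/patch) * floor(W/patch); a small check under that assumption:

    import numpy as np
    from setiastro.saspro.cosmicclarity_engines.benchmark_engine import tile_chw_image

    img = np.zeros((3, 1000, 1000), dtype=np.float32)
    patches = tile_chw_image(img, 256)
    # floor(1000/256) = 3 full tiles per axis -> 9 patches; edge remainders are dropped
    assert patches.shape == (9, 3, 256, 256)
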

# -----------------------------
# CPU microbenchmarks (Numba)
# -----------------------------
@njit
def _mad_cpu_jit(image_array: np.ndarray, median_val: float) -> float:
    return np.median(np.abs(image_array - median_val))

def mad_cpu(image_array: np.ndarray, runs: int = 3) -> list[float]:
    """
    Return ms timings. First run includes JIT compile.
    image_array can be CHW or HW; we flatten it for fairness.
    """
    arr = np.asarray(image_array, dtype=np.float32).ravel()
    times = []
    for _ in range(runs):
        t0 = time.time()
        med = float(np.median(arr))
        _ = _mad_cpu_jit(arr, med)
        times.append((time.time() - t0) * 1000.0)
    return times

@njit(parallel=True)
def _flat_field_jit(image_array: np.ndarray, flat_frame: np.ndarray, median_flat: float) -> np.ndarray:
    out = np.empty_like(image_array)
    n = image_array.size
    for i in prange(n):
        out[i] = image_array[i] / (flat_frame[i] / median_flat)
    return out

def flat_field_correction(image_array: np.ndarray, flat_frame: np.ndarray, runs: int = 3) -> list[float]:
    arr = np.asarray(image_array, dtype=np.float32).ravel()
    flt = np.asarray(flat_frame, dtype=np.float32).ravel()
    times = []
    for _ in range(runs):
        t0 = time.time()
        med = float(np.median(flt))
        _ = _flat_field_jit(arr, flt, med)
        times.append((time.time() - t0) * 1000.0)
    return times
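
Per the mad_cpu docstring, the first timing includes Numba JIT compilation, so a caller that wants steady-state numbers should discard it; a sketch:

    import numpy as np
    from setiastro.saspro.cosmicclarity_engines.benchmark_engine import mad_cpu

    arr = np.random.rand(2048, 2048).astype(np.float32)
    times = mad_cpu(arr, runs=3)   # times[0] includes the JIT compile
    steady = times[1:]             # later runs reflect steady-state speed
    print(f"first={times[0]:.1f} ms  avg={sum(steady) / len(steady):.1f} ms")
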

# -----------------------------
# System info
# -----------------------------
def get_system_info() -> dict:
    info = {
        "OS": f"{platform.system()} {platform.release()}",
        "CPU": cpuinfo.get_cpu_info().get("brand_raw", "Unknown"),
        "RAM": f"{round(psutil.virtual_memory().total / (1024 ** 3), 1)} GB",
        "Python": f"{platform.python_version()}",
    }
    # torch / onnx details (optional)
    try:
        from setiastro.saspro.runtime_torch import add_runtime_to_sys_path
        add_runtime_to_sys_path(status_cb=lambda *_: None)
        import torch
        info["torch.__version__"] = getattr(torch, "__version__", "Unknown")
        info["torch.version.cuda"] = getattr(getattr(torch, "version", None), "cuda", None)
        info["CUDA Available"] = bool(getattr(torch, "cuda", None) and torch.cuda.is_available())
        info["MPS Available"] = bool(hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
        if info["CUDA Available"]:
            try:
                info["GPU"] = torch.cuda.get_device_name(0)
            except Exception:
                pass
    except Exception:
        pass

    try:
        import onnxruntime as ort
        info["ONNX Providers"] = ort.get_available_providers()
    except Exception:
        pass

    return info

def _fmt_ms(first_ms: float, avg_ms: float) -> str:
    return f"First: {first_ms:.2f} ms | Avg: {avg_ms:.2f} ms"

def _fmt_gpu(avg_ms: float, total_ms: float) -> str:
    return f"Avg: {avg_ms:.2f} ms | Total: {total_ms:.2f} ms"

def _legacy_results_schema(results: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert internal structured results -> legacy/human schema expected by submitter.
    """
    out: Dict[str, Any] = {}

    # --- CPU ---
    cpu_mad = results.get("CPU MAD (Single Core)")
    if isinstance(cpu_mad, dict):
        out["CPU MAD (Single Core)"] = _fmt_ms(cpu_mad.get("first_ms", 0.0), cpu_mad.get("avg_ms", 0.0))
    elif cpu_mad is not None:
        out["CPU MAD (Single Core)"] = str(cpu_mad)

    cpu_flat = results.get("CPU Flat-Field (Multi-Core)")
    if isinstance(cpu_flat, dict):
        out["CPU Flat-Field (Multi-Core)"] = _fmt_ms(cpu_flat.get("first_ms", 0.0), cpu_flat.get("avg_ms", 0.0))
    elif cpu_flat is not None:
        out["CPU Flat-Field (Multi-Core)"] = str(cpu_flat)

    # --- GPU / Torch ---
    # Your internal key looks like "Stellar Torch (CUDA)" or "(CPU)" etc.
    torch_key = next((k for k in results.keys() if k.startswith("Stellar Torch (")), None)
    if torch_key and isinstance(results.get(torch_key), dict):
        backend = torch_key[len("Stellar Torch ("):-1]  # extract tag inside (...)
        t = results[torch_key]
        out[f"GPU Time ({backend})"] = _fmt_gpu(t.get("avg_ms", 0.0), t.get("total_ms", 0.0))

    # --- ONNX ---
    onnx_key = next((k for k in results.keys() if k.startswith("Stellar ONNX")), None)
    if onnx_key is None:
        # you used this on non-windows
        if "Stellar ONNX" in results:
            out["ONNX Time"] = str(results["Stellar ONNX"])
        else:
            out["ONNX Time"] = "ONNX benchmark not run."
    else:
        v = results.get(onnx_key)
        if isinstance(v, dict):
            # keep it under the legacy key name
            out["ONNX Time"] = _fmt_gpu(v.get("avg_ms", 0.0), v.get("total_ms", 0.0))
        else:
            out["ONNX Time"] = str(v)

    # --- System Info ---
    # If you want to match your example more closely, drop Python/torch versions.
    si = results.get("System Info", {})
    if isinstance(si, dict):
        keep = {
            "OS", "CPU", "RAM",
            "CUDA Available", "MPS Available", "ONNX Providers", "GPU"
        }
        out["System Info"] = {k: si[k] for k in keep if k in si}
    else:
        out["System Info"] = si

    return out

def _picked_backend(use_gpu: bool, status_cb=None) -> str:
    models, tag = _get_stellar_model_for_benchmark(use_gpu=use_gpu, status_cb=status_cb)
    return "ONNX" if models.is_onnx else "TORCH"

# -----------------------------
# One-stop runner
# -----------------------------
def run_benchmark(
    *,
    mode: Literal["CPU", "GPU", "Both"] = "Both",
    use_gpu: bool = True,
    benchmark_fits_path: Optional[Path] = None,
    status_cb: Optional[ProgressCB] = None,
    progress_cb: Optional[Callable[[int, int], bool]] = None,
) -> Dict[str, Any]:
    """
    Returns results dict safe to json.dumps.
    progress_cb signature: (done, total)->bool continue
    """
    if benchmark_fits_path is None:
        benchmark_fits_path = benchmark_image_path()
    benchmark_fits_path = Path(benchmark_fits_path)
    if not benchmark_fits_path.exists():
        raise RuntimeError("Benchmark image not downloaded yet.")

    img_chw = _load_benchmark_fits(benchmark_fits_path)
    patches = tile_chw_image(img_chw, 256)  # (N,3,256,256)

    results: Dict[str, Any] = {}
    results["System Info"] = get_system_info()

    if mode in ("CPU", "Both"):
        if status_cb: status_cb("Running CPU benchmarks…")
        cpu_mad = mad_cpu(img_chw)
        cpu_flat = flat_field_correction(img_chw, img_chw)
        results["CPU MAD (Single Core)"] = {
            "first_ms": float(cpu_mad[0]),
            "avg_ms": float(np.mean(cpu_mad[1:])) if len(cpu_mad) > 1 else float(cpu_mad[0]),
        }
        results["CPU Flat-Field (Multi-Core)"] = {
            "first_ms": float(cpu_flat[0]),
            "avg_ms": float(np.mean(cpu_flat[1:])) if len(cpu_flat) > 1 else float(cpu_flat[0]),
        }

    if mode in ("GPU", "Both"):
        if status_cb: status_cb("Running Stellar model benchmark…")

        picked = _picked_backend(use_gpu=use_gpu, status_cb=status_cb)

        if picked == "TORCH":
            avg_ms, total_ms, backend = torch_benchmark_stellar(
                patches,
                use_gpu=use_gpu,
                progress_cb=progress_cb,
                status_cb=status_cb,
            )
            results[f"Stellar Torch ({backend})"] = {"avg_ms": float(avg_ms), "total_ms": float(total_ms)}

            # Optional: also run ONNX on Windows for comparison
            if platform.system() == "Windows":
                avg_o, total_o, provider = onnx_benchmark_stellar(
                    patches,
                    use_gpu=use_gpu,
                    progress_cb=progress_cb,
                    status_cb=status_cb,
                )
                results[f"Stellar ONNX ({provider})"] = {"avg_ms": float(avg_o), "total_ms": float(total_o)}

        else:
            # Picked ONNX (DirectML). Run ONNX benchmark as the “GPU benchmark”.
            if platform.system() == "Windows":
                avg_o, total_o, provider = onnx_benchmark_stellar(
                    patches,
                    use_gpu=use_gpu,
                    progress_cb=progress_cb,
                    status_cb=status_cb,
                )
                results[f"Stellar ONNX ({provider})"] = {"avg_ms": float(avg_o), "total_ms": float(total_o)}
            else:
                results["Stellar ONNX"] = "ONNX benchmark only available on Windows."

    results = _legacy_results_schema(results)
    return results
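
Putting the new module together, a plausible end-to-end invocation (names are from this diff; the print callbacks are illustrative):

    import json
    from setiastro.saspro.cosmicclarity_engines.benchmark_engine import (
        benchmark_image_path,
        download_benchmark_image,
        run_benchmark,
    )

    if not benchmark_image_path().exists():
        download_benchmark_image(status_cb=print)  # tries BENCHMARK_URLS in order

    results = run_benchmark(
        mode="Both",
        use_gpu=True,
        status_cb=print,
        progress_cb=lambda done, total: True,  # return False to cancel mid-run
    )
    print(json.dumps(results, indent=2))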