setiastrosuitepro 1.7.5__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- setiastro/saspro/_generated/build_info.py +2 -2
- setiastro/saspro/accel_installer.py +21 -8
- setiastro/saspro/accel_workers.py +11 -12
- setiastro/saspro/blink_comparator_pro.py +146 -2
- setiastro/saspro/comet_stacking.py +113 -85
- setiastro/saspro/cosmicclarity.py +604 -826
- setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +715 -0
- setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
- setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
- setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
- setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
- setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
- setiastro/saspro/gui/main_window.py +14 -0
- setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
- setiastro/saspro/model_manager.py +306 -0
- setiastro/saspro/model_workers.py +65 -0
- setiastro/saspro/ops/benchmark.py +320 -0
- setiastro/saspro/ops/settings.py +308 -9
- setiastro/saspro/remove_stars.py +424 -442
- setiastro/saspro/resources.py +73 -10
- setiastro/saspro/runtime_torch.py +107 -22
- setiastro/saspro/signature_insert.py +14 -3
- setiastro/saspro/stacking_suite.py +539 -115
- {setiastrosuitepro-1.7.5.dist-info → setiastrosuitepro-1.8.0.dist-info}/METADATA +2 -1
- {setiastrosuitepro-1.7.5.dist-info → setiastrosuitepro-1.8.0.dist-info}/RECORD +29 -20
- {setiastrosuitepro-1.7.5.dist-info → setiastrosuitepro-1.8.0.dist-info}/WHEEL +0 -0
- {setiastrosuitepro-1.7.5.dist-info → setiastrosuitepro-1.8.0.dist-info}/entry_points.txt +0 -0
- {setiastrosuitepro-1.7.5.dist-info → setiastrosuitepro-1.8.0.dist-info}/licenses/LICENSE +0 -0
- {setiastrosuitepro-1.7.5.dist-info → setiastrosuitepro-1.8.0.dist-info}/licenses/license.txt +0 -0
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
# src/setiastro/saspro/model_manager.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
import json
|
|
7
|
+
import time
|
|
8
|
+
import shutil
|
|
9
|
+
import hashlib
|
|
10
|
+
import zipfile
|
|
11
|
+
import tempfile
|
|
12
|
+
from typing import Optional, Callable
|
|
13
|
+
from urllib.parse import urlparse, parse_qs
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
# Application folder name. NOTE(review): not referenced anywhere else in this
# module; app_data_root() takes its location from runtime_torch instead.
APP_FOLDER_NAME = "SetiAstroSuitePro" # keep stable
# Optional status callback: receives one human-readable message string per update.
ProgressCB = Optional[Callable[[str], None]]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def app_data_root() -> str:
    """
    Return the persistent, frozen-safe data directory, creating it if missing.

    The location comes from ``runtime_torch._user_runtime_dir()`` so that it
    is the same base directory the benchmark cache uses.
    Example on Windows:
        C:\\Users\\YOU\\AppData\\Local\\SASpro
    """
    from setiastro.saspro.runtime_torch import _user_runtime_dir

    data_dir = Path(_user_runtime_dir())
    data_dir.mkdir(exist_ok=True, parents=True)
    return os.fspath(data_dir)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def models_root() -> str:
    """Return the ``models`` folder under app_data_root(), creating it if needed."""
    folder = Path(app_data_root()).joinpath("models")
    folder.mkdir(exist_ok=True, parents=True)
    return os.fspath(folder)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def installed_manifest_path() -> str:
    """Return the path of ``manifest.json`` inside models_root()."""
    manifest_file = Path(models_root()).joinpath("manifest.json")
    return str(manifest_file)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def read_installed_manifest() -> dict:
    """
    Load the installed-models manifest.

    Returns:
        The parsed manifest dict, or ``{}`` when the file is missing,
        unreadable, not valid JSON, or valid JSON that is not an object.
    """
    try:
        with open(installed_manifest_path(), "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception:
        # Deliberately best-effort: any failure means "no manifest installed".
        return {}
    # Valid JSON such as a list or null would break callers expecting a dict.
    return data if isinstance(data, dict) else {}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def write_installed_manifest(d: dict) -> None:
    """
    Best-effort write of ``d`` as indented JSON to installed_manifest_path().

    The JSON goes to a sibling ``.tmp`` file first and is then renamed over
    the manifest with ``os.replace``. A failure mid-write (for example a value
    json cannot serialize) therefore never leaves a truncated manifest behind.
    Errors are swallowed; read_installed_manifest() returns {} when no
    usable manifest exists.
    """
    tmp_path = None
    try:
        path = installed_manifest_path()
        tmp_path = path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(d, f, indent=2)
        os.replace(tmp_path, path)
    except Exception:
        # Deliberately best-effort; only discard the half-written temp file.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# ---------------- Google Drive helpers ----------------

# Compiled patterns for Drive file ids in "/file/d/<ID>" paths and "?id=<ID>"
# query strings. NOTE(review): not referenced elsewhere in this module;
# extract_drive_file_id() uses its own inline patterns.
_DRIVE_FILE_RE = re.compile(r"/file/d/([a-zA-Z0-9_-]+)")
_DRIVE_ID_RE = re.compile(r"[?&]id=([a-zA-Z0-9_-]+)")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def extract_drive_file_id(url_or_id: str) -> Optional[str]:
    """
    Extract a Google Drive file id from a share link, or accept a bare id.

    Accepted inputs:
      * a bare id: 10 or more characters from ``[0-9A-Za-z_-]``;
      * ``…/file/d/<ID>/…`` links;
      * links carrying ``?id=<ID>`` (``/uc``, ``/open``, usercontent ``/download``).
    Links must be on drive.google.com, docs.google.com or
    drive.usercontent.google.com (subdomains allowed). A missing
    ``https://`` scheme is tolerated, because pasted links often omit it.

    Returns:
        The file id, or None if nothing usable was found.
    """
    allowed_hosts = ("drive.google.com", "docs.google.com", "drive.usercontent.google.com")

    s = (url_or_id or "").strip()
    if not s:
        return None

    # raw id
    if re.fullmatch(r"[0-9A-Za-z_-]{10,}", s):
        return s

    # Without a scheme urlparse() puts the host into .path and leaves netloc empty.
    if "://" not in s and not s.startswith("//"):
        s = "https://" + s

    try:
        u = urlparse(s)
        # Compare the real hostname (port/userinfo stripped) instead of doing a
        # substring test, which also accepted e.g. "drive.google.com.evil.example".
        host = (u.hostname or "").lower()
        if not any(host == h or host.endswith("." + h) for h in allowed_hosts):
            return None

        m = re.search(r"/file/d/([^/]+)", u.path or "")
        if m:
            return m.group(1)

        ids = parse_qs(u.query or "").get("id")
        if ids:
            return ids[0]
    except ValueError:
        # urlparse/parse_qs signal malformed input (e.g. bad IPv6 brackets) this way.
        return None

    return None
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _looks_like_html_prefix(b: bytes) -> bool:
|
|
93
|
+
head = (b or b"").lstrip()[:256].lower()
|
|
94
|
+
return head.startswith(b"<!doctype html") or head.startswith(b"<html") or (b"<html" in head)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _parse_gdrive_download_form(html: str) -> tuple[Optional[str], Optional[dict]]:
|
|
98
|
+
m = re.search(r'<form[^>]+id="download-form"[^>]+action="([^"]+)"', html)
|
|
99
|
+
if not m:
|
|
100
|
+
return None, None
|
|
101
|
+
action = m.group(1)
|
|
102
|
+
params: dict[str, str] = {}
|
|
103
|
+
|
|
104
|
+
for name, val in re.findall(
|
|
105
|
+
r'<input[^>]+type="hidden"[^>]+name="([^"]+)"[^>]*value="([^"]*)"', html
|
|
106
|
+
):
|
|
107
|
+
params[name] = val
|
|
108
|
+
|
|
109
|
+
for name in re.findall(
|
|
110
|
+
r'<input[^>]+type="hidden"[^>]+name="([^"]+)"(?![^>]*value=)', html
|
|
111
|
+
):
|
|
112
|
+
params.setdefault(name, "")
|
|
113
|
+
|
|
114
|
+
return action, params
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def download_google_drive_file(
    file_id: str,
    dst_path: str | os.PathLike,
    *,
    progress_cb: ProgressCB = None,
    should_cancel=None,  # callable -> bool
    timeout: int = 60,
    chunk_size: int = 1024 * 1024,
) -> Path:
    """
    Download a Google Drive file by id, handling the virus-scan interstitial HTML.

    ``file_id`` may be a bare id or any link accepted by extract_drive_file_id().
    Data is streamed to ``<dst>.part`` and renamed onto ``dst`` only after a
    complete, non-empty download whose first bytes do not look like HTML.
    On any failure (HTTP/network error, HTML instead of data, cancellation),
    the ``.part`` file is removed and ``dst`` is left untouched.

    Args:
        file_id: Drive file id or share link.
        dst_path: final destination; parent directories are created.
        progress_cb: optional callable receiving status strings (at most ~2/s
            while streaming).
        should_cancel: optional zero-arg callable; when it returns True the
            download stops with ``RuntimeError("Download canceled.")``.
        timeout: timeout in seconds passed to each ``requests`` call.
        chunk_size: streaming chunk size in bytes.

    Returns:
        ``dst`` as a Path.

    Raises:
        RuntimeError: no id, unparseable interstitial page, HTML received
            instead of the file, empty download, or cancellation.
        requests.HTTPError: the server answered with an error status.
        requests.RequestException: other network failures or timeouts.
    """
    import requests  # local import to keep import cost down

    fid = extract_drive_file_id(file_id) or file_id
    if not fid:
        raise RuntimeError("No Google Drive file id provided.")

    dst = Path(dst_path)
    tmp = dst.with_suffix(dst.suffix + ".part")
    tmp.parent.mkdir(parents=True, exist_ok=True)

    # The “uc” endpoint is best for download
    url = f"https://drive.google.com/uc?export=download&id={fid}"

    def log(msg: str):
        if progress_cb:
            progress_cb(msg)

    def discard_partial():
        # Best-effort removal of the .part file; must never mask the real error.
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            pass

    # Clean any old partial
    discard_partial()

    with requests.Session() as s:
        log("Connecting to Google Drive…")
        r = s.get(url, stream=True, timeout=timeout, allow_redirects=True)
        ctype = (r.headers.get("Content-Type") or "").lower()

        # If HTML, parse the interstitial "download anyway" form and re-request.
        if "text/html" in ctype:
            try:
                # Error statuses (403/404…) are HTML too; report them as HTTP
                # errors instead of as an unparseable interstitial page.
                r.raise_for_status()
                html = r.text
            finally:
                r.close()
            action, params = _parse_gdrive_download_form(html)
            if not action or not params:
                raise RuntimeError(
                    "Google Drive returned an interstitial HTML page, but the download form could not be parsed."
                )
            log("Google Drive interstitial detected — confirming download…")
            r = s.get(action, params=params, stream=True, timeout=timeout, allow_redirects=True)

        try:
            r.raise_for_status()

            total = int(r.headers.get("Content-Length") or 0)
            done = 0
            t_last = time.time()
            done_last = 0

            with open(tmp, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if should_cancel and should_cancel():
                        raise RuntimeError("Download canceled.")

                    if not chunk:
                        continue

                    # extra safety on the first data chunk: even if content-type lies
                    if done == 0 and _looks_like_html_prefix(chunk[:256]):
                        raise RuntimeError(
                            "Google Drive returned HTML instead of the file (permission/confirm issue)."
                        )

                    f.write(chunk)
                    done += len(chunk)

                    now = time.time()
                    if now - t_last >= 0.5:
                        if total > 0:
                            pct = (done * 100.0) / total
                            log(f"Downloading… {pct:5.1f}% ({done}/{total} bytes)")
                        else:
                            bps = (done - done_last) / max(now - t_last, 1e-9)
                            log(f"Downloading… {done} bytes ({bps/1024/1024:.1f} MB/s)")
                        t_last = now
                        done_last = done

            # An empty body would otherwise replace dst with a 0-byte "success".
            if done == 0:
                raise RuntimeError("Google Drive returned an empty file.")

            os.replace(str(tmp), str(dst))
        except BaseException:
            # Cancellation, HTML detection, network or disk errors: never leave
            # a stale .part file behind.
            discard_partial()
            raise
        finally:
            r.close()

    log(f"Download complete: {dst}")
    return dst
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def install_models_zip(
    zip_path: str | os.PathLike,
    *,
    progress_cb: ProgressCB = None,
    manifest: dict | None = None,
) -> None:
    """
    Extract a models zip and install its contents into models_root().

    The previous contents of models_root(), including any manifest.json, are
    removed first (best-effort). The models folder itself is kept so its path
    stays stable. If the archive wraps everything in a single top-level folder
    (ignoring a macOS ``__MACOSX`` metadata folder), that folder's contents
    are installed instead.

    Args:
        zip_path: path to the models .zip archive.
        progress_cb: optional callable receiving human-readable status strings.
        manifest: if truthy, written as manifest.json after the copy.

    Raises:
        RuntimeError: the archive contains no .pth/.onnx files. In that case
            the destination is left untouched.
        zipfile.BadZipFile: ``zip_path`` is not a valid zip archive.
        OSError: extraction or copying failed.
    """
    dst = Path(models_root())

    def log(msg: str):
        if progress_cb:
            progress_cb(msg)

    # mkdtemp guarantees a fresh, unique directory. The former pid+seconds
    # naming could collide for two installs started within the same second.
    work = Path(tempfile.mkdtemp(prefix="saspro_models_extract_"))
    try:
        log("Extracting models zip…")
        with zipfile.ZipFile(str(zip_path), "r") as z:
            z.extractall(work)

        # Some zips contain a top-level folder; normalize. Zips created on
        # macOS also carry a __MACOSX folder, which must not defeat that.
        root = work
        kids = [p for p in root.iterdir() if p.name != "__MACOSX"]
        if len(kids) == 1 and kids[0].is_dir():
            root = kids[0]

        # sanity: must contain at least one model file
        any_model = any(
            p.is_file() and p.suffix.lower() in (".pth", ".onnx") for p in root.rglob("*")
        )
        if not any_model:
            raise RuntimeError("Models zip did not contain any .pth/.onnx files.")

        log(f"Installing to: {dst}")

        # Clear destination contents (keep dst folder stable); best-effort.
        dst.mkdir(parents=True, exist_ok=True)
        for item in dst.iterdir():
            try:
                if item.is_dir() and not item.is_symlink():
                    shutil.rmtree(item, ignore_errors=True)
                else:
                    item.unlink(missing_ok=True)
            except OSError:
                pass

        # Copy the extracted contents straight into dst. The former extra
        # "staging" copy produced an identical tree and only doubled the I/O.
        for item in root.iterdir():
            if item.name == "__MACOSX":
                continue
            target = dst / item.name
            if item.is_dir():
                shutil.copytree(item, target, dirs_exist_ok=True)
            else:
                shutil.copy2(item, target)

        if manifest:
            log("Writing manifest…")
            write_installed_manifest(manifest)

        log("Models installed.")
    finally:
        shutil.rmtree(work, ignore_errors=True)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def sha256_file(path: str | os.PathLike, *, chunk_size: int = 1024 * 1024) -> str:
    """
    Return the lowercase hex SHA-256 digest of the file at ``path``.

    The file is read in ``chunk_size``-byte pieces, so memory use stays
    bounded regardless of file size.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # iter() with a b"" sentinel stops at end-of-file.
        for block in iter(lambda: fh.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# src/setiastro/saspro/model_workers.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
from PyQt6.QtCore import QObject, pyqtSignal
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import tempfile
|
|
7
|
+
import zipfile
|
|
8
|
+
|
|
9
|
+
from setiastro.saspro.model_manager import (
|
|
10
|
+
extract_drive_file_id,
|
|
11
|
+
download_google_drive_file,
|
|
12
|
+
install_models_zip,
|
|
13
|
+
sha256_file,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
class ModelsDownloadWorker(QObject):
    """
    Download the models zip from Google Drive, optionally verify its SHA-256,
    and install it with install_models_zip().

    The primary link is tried first. The backup is used only if the primary
    fails for a reason other than a user cancel, and only if it names a
    different file. ``run()`` performs blocking network and disk I/O, so it
    should not be called on the GUI thread.

    Signals:
        progress(str): human-readable status lines.
        finished(bool, str): (success, message) when ``run()`` ends.
    """

    progress = pyqtSignal(str)
    finished = pyqtSignal(bool, str)

    def __init__(self, primary: str, backup: str, expected_sha256: str | None = None, should_cancel=None):
        super().__init__()
        self.primary = primary
        self.backup = backup
        self.expected_sha256 = (expected_sha256 or "").strip() or None
        self.should_cancel = should_cancel  # callable -> bool

    def _canceled(self) -> bool:
        """True if the optional cancel callback reports a user cancel."""
        return bool(self.should_cancel and self.should_cancel())

    def run(self):
        import shutil  # local: only needed for temp-dir cleanup

        work_dir = None
        try:
            # The inputs should be FILE links (or IDs), not folder links.
            fid = extract_drive_file_id(self.primary) or extract_drive_file_id(self.backup)
            if not fid:
                raise RuntimeError(
                    "Models URL is not a Google Drive *file* link or id.\n"
                    "Please provide a shared file link (…/file/d/<ID>/view) to the models zip."
                )

            # A private temp dir avoids clashing with another running instance,
            # and lets the (large) zip be removed once installed.
            work_dir = tempfile.mkdtemp(prefix="saspro_models_dl_")
            tmp = os.path.join(work_dir, "saspro_models_latest.zip")

            def emit_progress(s: str):
                self.progress.emit(s)

            used_fid = fid
            try:
                self.progress.emit("Downloading from primary…")
                download_google_drive_file(fid, tmp, progress_cb=emit_progress, should_cancel=self.should_cancel)
            except Exception as e:
                # Try backup if primary fails AND backup has a different file id,
                # but never after a user cancel.
                fid2 = extract_drive_file_id(self.backup)
                if self._canceled() or not fid2 or fid2 == fid:
                    raise
                self.progress.emit("Primary failed. Trying backup…")
                try:
                    download_google_drive_file(fid2, tmp, progress_cb=emit_progress, should_cancel=self.should_cancel)
                except Exception as e2:
                    if self._canceled():
                        raise
                    # Keep both reasons; otherwise the primary's error is lost.
                    raise RuntimeError(
                        f"Primary download failed: {e}\nBackup download failed: {e2}"
                    ) from e2
                used_fid = fid2

            if self.expected_sha256:
                self.progress.emit("Verifying checksum…")
                got = sha256_file(tmp)
                if got.lower() != self.expected_sha256.lower():
                    raise RuntimeError(f"SHA256 mismatch.\nExpected: {self.expected_sha256}\nGot: {got}")

            manifest = {
                "source": "google_drive",
                # Record the id that was actually downloaded (backup may have been used).
                "file_id": used_fid,
                "sha256": self.expected_sha256,
            }
            install_models_zip(tmp, progress_cb=emit_progress, manifest=manifest)

            self.finished.emit(True, "Models updated successfully.")
        except Exception as e:
            self.finished.emit(False, str(e))
        finally:
            if work_dir:
                shutil.rmtree(work_dir, ignore_errors=True)
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
# src/setiastro/saspro/ops/benchmark.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
from PyQt6.QtCore import QThread, pyqtSignal, QObject
|
|
6
|
+
from PyQt6.QtWidgets import (
|
|
7
|
+
QDialog, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QComboBox,
|
|
8
|
+
QProgressBar, QTextEdit, QMessageBox, QApplication
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
from setiastro.saspro.cosmicclarity_engines.benchmark_engine import (
|
|
12
|
+
benchmark_image_path, download_benchmark_image, run_benchmark, # keep your run_benchmark
|
|
13
|
+
)
|
|
14
|
+
from setiastro.saspro.cosmicclarity_engines.benchmark_engine import BENCHMARK_FITS_URL # or define here
|
|
15
|
+
|
|
16
|
+
class _BenchWorker(QObject):
    """
    Worker object that calls ``run_benchmark`` and reports through signals.

    Signals:
        log(str): status text from the engine.
        prog(int, int): (done, total) progress counts.
        done(bool, dict, str): (ok, results, error). On failure the error text
            is replaced by "Canceled." whenever it mentions "cancel".
    """

    log = pyqtSignal(str)
    prog = pyqtSignal(int, int)  # done,total
    done = pyqtSignal(bool, dict, str)  # ok, results, err

    def __init__(self, mode: str, use_gpu: bool):
        super().__init__()
        self.mode, self.use_gpu = mode, use_gpu

    def run(self):
        try:
            def status_cb(s: str):
                self.log.emit(str(s))

            def progress_cb(done: int, total: int) -> bool:
                self.prog.emit(int(done), int(total))
                # Becomes False once thread interruption is requested;
                # presumably the engine treats that as "stop" — TODO confirm.
                interrupted = QThread.currentThread().isInterruptionRequested()
                return not interrupted

            self.done.emit(
                True,
                run_benchmark(
                    mode=self.mode,
                    use_gpu=self.use_gpu,
                    status_cb=status_cb,
                    progress_cb=progress_cb,
                ),
                "",
            )
        except Exception as exc:
            text = str(exc)
            reason = "Canceled." if "cancel" in text.lower() else text
            self.done.emit(False, {}, reason)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class _DownloadWorker(QObject):
    """
    Worker object that calls ``download_benchmark_image`` and reports through signals.

    Signals:
        log(str): status text.
        prog(int, int): (bytes_done, bytes_total).
        done(bool, str): (ok, downloaded path on success or error text).
    """

    log = pyqtSignal(str)
    prog = pyqtSignal(int, int)  # bytes_done, bytes_total
    done = pyqtSignal(bool, str)  # ok, message/path

    def __init__(self, url: str):
        super().__init__()
        # NOTE(review): stored but not used by run(), which passes None to
        # download_benchmark_image(); confirm None selects the intended URL.
        self.url = url

    def run(self):
        try:
            image_path = download_benchmark_image(
                None,
                status_cb=lambda s: self.log.emit(str(s)),
                progress_cb=lambda done, total: self.prog.emit(int(done), int(total)),
                cancel_cb=lambda: QThread.currentThread().isInterruptionRequested(),
            )
            self.done.emit(True, str(image_path))
        except Exception as exc:
            self.done.emit(False, str(exc))
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class BenchmarkDialog(QDialog):
    """
    Non-modal dialog that downloads the benchmark image, runs the benchmark
    on a worker QThread, and lets the user copy, save, or submit the results.

    The dialog tracks at most one worker thread in ``self._thread``. Starting
    a new download or run first interrupts and joins the previous one, and
    results from such a stopped ("stale") thread are ignored.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Seti Astro Benchmark")
        self.setModal(False)
        self.setMinimumSize(560, 520)

        self._results = None  # last successful benchmark results dict
        self._thread = None  # QThread of the task currently owned by the dialog

        outer = QVBoxLayout(self)
        outer.setContentsMargins(10, 10, 10, 10)
        outer.setSpacing(8)

        # Top row: image status + download
        top = QHBoxLayout()
        self.lbl_img = QLabel(self)
        self.btn_dl = QPushButton("Download Benchmark Image…", self)
        self.btn_dl.clicked.connect(self._download_image)
        top.addWidget(self.lbl_img, 1)
        top.addWidget(self.btn_dl)
        outer.addLayout(top)

        # Mode row
        row = QHBoxLayout()
        row.addWidget(QLabel("Run:", self))
        self.cmb = QComboBox(self)
        self.cmb.addItems(["CPU", "GPU", "Both"])
        self.cmb.setCurrentText("Both")
        row.addWidget(self.cmb)

        self.btn_run = QPushButton("Run Benchmark", self)
        self.btn_run.clicked.connect(self._run_benchmark)
        row.addWidget(self.btn_run)
        row.addStretch(1)
        outer.addLayout(row)

        # Progress
        self.pbar = QProgressBar(self)
        self.pbar.setRange(0, 100)
        outer.addWidget(self.pbar)

        # Log / results
        self.txt = QTextEdit(self)
        self.txt.setReadOnly(True)
        outer.addWidget(self.txt, 1)

        # Bottom buttons
        bot = QHBoxLayout()
        self.btn_copy = QPushButton("Copy JSON", self)
        self.btn_copy.setEnabled(False)
        self.btn_copy.clicked.connect(self._copy_json)

        self.btn_save = QPushButton("Save Locally", self)
        self.btn_save.setEnabled(False)
        self.btn_save.clicked.connect(self._save_local)

        self.btn_submit = QPushButton("Submit…", self)
        self.btn_submit.clicked.connect(self._submit)

        self.btn_close = QPushButton("Close", self)
        self.btn_close.clicked.connect(self.close)

        bot.addWidget(self.btn_copy)
        bot.addWidget(self.btn_save)
        bot.addStretch(1)
        bot.addWidget(self.btn_submit)
        bot.addWidget(self.btn_close)
        outer.addLayout(bot)

        self.refresh_ui()

    def refresh_ui(self):
        """Update the image label / Run button from whether the image exists."""
        p = benchmark_image_path()
        if p.exists():
            self.lbl_img.setText(f"Benchmark image: Ready ({p.name})")
            self.btn_run.setEnabled(True)
        else:
            self.lbl_img.setText("Benchmark image: Not downloaded")
            self.btn_run.setEnabled(False)

        self.pbar.setValue(0)

    # ---------- thread helpers ----------
    def _reset_pbar(self):
        """Return the progress bar to a determinate 0-100 range with % text."""
        self.pbar.setRange(0, 100)
        self.pbar.setFormat("%p%")
        self.pbar.setValue(0)

    def _finish_thread(self, t: QThread) -> bool:
        """
        Join worker thread ``t`` after its worker emitted ``done`` and schedule
        the QThread for deletion.

        Returns True if ``t`` is still the dialog's current thread. Returns
        False for a thread that _stop_thread_if_any() already stopped: its
        ``done`` may arrive after a newer task started (or after the dialog
        closed), so it must not touch buttons, progress, or show popups.
        """
        t.quit()
        t.wait()
        # Parented to the dialog; without this one QThread leaks per task.
        t.deleteLater()
        if self._thread is not t:
            return False
        self._thread = None
        return True

    # ---------- download ----------
    def _download_image(self):
        self._stop_thread_if_any()

        self.txt.append("Starting download…")
        self.btn_dl.setEnabled(False)
        self.btn_run.setEnabled(False)
        self._reset_pbar()

        t = QThread(self)
        w = _DownloadWorker(BENCHMARK_FITS_URL)
        w.moveToThread(t)

        w.log.connect(self._log)
        w.prog.connect(self._dl_progress)
        w.done.connect(lambda ok, msg: self._dl_done(ok, msg, t, w))
        t.started.connect(w.run)
        t.start()
        self._thread = t

    def _dl_progress(self, done: int, total: int):
        if total > 0:
            pct = int(done * 100 / total)
            self.pbar.setRange(0, 100)
            self.pbar.setValue(max(0, min(100, pct)))
            # Show transferred megabytes next to the percentage.
            self.pbar.setFormat(f"{pct}% ({done/1e6:.0f}/{total/1e6:.0f} MB)")
        else:
            # unknown length: busy indicator plus byte count
            self.pbar.setRange(0, 0)
            self.pbar.setFormat(f"{done/1e6:.0f} MB")

    def _dl_done(self, ok: bool, msg: str, t: QThread, w: QObject):
        if not self._finish_thread(t):
            return  # stale result from a stopped download

        self.pbar.setRange(0, 100)
        self.pbar.setFormat("%p%")
        self.btn_dl.setEnabled(True)

        if ok:
            self._log(f"✅ Downloaded: {msg}")
            self.refresh_ui()
        else:
            self._log(f"❌ Download failed: {msg}")
            QMessageBox.warning(self, "Download failed", msg)
            self.refresh_ui()

    def closeEvent(self, e):
        self._stop_thread_if_any()
        self.pbar.setRange(0, 100)
        super().closeEvent(e)

    # ---------- run benchmark ----------
    def _run_benchmark(self):
        p = benchmark_image_path()
        if not p.exists():
            QMessageBox.information(self, "Benchmark image missing", "Please download the benchmark image first.")
            return

        self._stop_thread_if_any()
        self.btn_run.setEnabled(False)
        self.btn_dl.setEnabled(False)
        self._results = None
        self.btn_copy.setEnabled(False)
        self.btn_save.setEnabled(False)

        self.txt.clear()
        self._log("Running benchmark…")
        # A stopped download may have left the bar busy or with a MB format.
        self._reset_pbar()

        mode = self.cmb.currentText()
        use_gpu = True  # benchmark engine will pick CPU if no CUDA/DML

        t = QThread(self)
        w = _BenchWorker(mode=mode, use_gpu=use_gpu)
        w.moveToThread(t)

        w.log.connect(self._log)
        w.prog.connect(self._bench_progress)
        w.done.connect(lambda ok, results, err: self._bench_done(ok, results, err, t, w))
        t.started.connect(w.run)
        t.start()
        self._thread = t

    def _bench_progress(self, done: int, total: int):
        # Delivered by signal in the GUI thread, so the UI repaints normally.
        # No processEvents(): re-entering the event loop here could dispatch
        # _bench_done while this slot is still running.
        if total > 0:
            self.pbar.setValue(max(0, min(100, int(done * 100 / total))))

    def _bench_done(self, ok: bool, results: dict, err: str, t: QThread, w: QObject):
        if not self._finish_thread(t):
            return  # stale result from a stopped run

        self.pbar.setValue(100 if ok else 0)
        self.btn_run.setEnabled(True)
        self.btn_dl.setEnabled(True)
        if not ok:
            self._log(f"❌ Benchmark failed: {err}")
            if str(err).strip().lower().startswith("canceled"):
                # no scary dialog for user-cancel
                self.pbar.setValue(0)
                return
            QMessageBox.warning(self, "Benchmark failed", err)
            return

        self._results = results
        self._log("✅ Benchmark complete.\n")
        self._log(json.dumps([results], indent=2))

        self.btn_copy.setEnabled(True)
        self.btn_save.setEnabled(True)

    # ---------- actions ----------
    def _copy_json(self):
        if not self._results:
            return
        s = json.dumps([self._results], indent=4)
        QApplication.clipboard().setText(s)
        QMessageBox.information(self, "Copied", "Benchmark JSON copied to clipboard.")

    def _save_local(self):
        """Append the latest results to benchmark_results.json in the working directory."""
        if not self._results:
            return
        import os

        fn = "benchmark_results.json"
        try:
            allr = []
            if os.path.exists(fn):
                with open(fn, "r", encoding="utf-8") as f:
                    try:
                        allr = json.load(f)
                    except Exception:
                        allr = []
            # A file holding one object (or other non-list JSON) would make
            # append() fail; keep a lone object as the first entry.
            if isinstance(allr, dict):
                allr = [allr]
            elif not isinstance(allr, list):
                allr = []
            allr.append(self._results)
            with open(fn, "w", encoding="utf-8") as f:
                json.dump(allr, f, indent=4)
            # Show the absolute path: the relative name depends on the current working directory.
            self._log(f"\n✅ Saved to {os.path.abspath(fn)}")
        except Exception as e:
            QMessageBox.warning(self, "Save failed", str(e))

    def _submit(self):
        import webbrowser
        webbrowser.open("https://setiastro.com/benchmark-submit")

    def _log(self, s: str):
        self.txt.append(str(s))

    def _stop_thread_if_any(self):
        """Interrupt and join the current worker thread, then forget it."""
        if self._thread is not None and self._thread.isRunning():
            self._thread.requestInterruption()
            self._thread.quit()
            self._thread.wait()
        self._thread = None
|