setiastrosuitepro 1.7.5.post1__py3-none-any.whl → 1.8.0.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of setiastrosuitepro might be problematic. Click here for more details.

Files changed (27) hide show
  1. setiastro/saspro/_generated/build_info.py +2 -2
  2. setiastro/saspro/accel_installer.py +21 -8
  3. setiastro/saspro/accel_workers.py +11 -12
  4. setiastro/saspro/comet_stacking.py +113 -85
  5. setiastro/saspro/cosmicclarity.py +604 -826
  6. setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
  7. setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
  8. setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
  9. setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
  10. setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
  11. setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
  12. setiastro/saspro/gui/main_window.py +14 -0
  13. setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
  14. setiastro/saspro/model_manager.py +324 -0
  15. setiastro/saspro/model_workers.py +102 -0
  16. setiastro/saspro/ops/benchmark.py +320 -0
  17. setiastro/saspro/ops/settings.py +407 -10
  18. setiastro/saspro/remove_stars.py +424 -442
  19. setiastro/saspro/resources.py +73 -10
  20. setiastro/saspro/runtime_torch.py +107 -22
  21. setiastro/saspro/signature_insert.py +14 -3
  22. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
  23. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
  24. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
  25. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
  26. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
  27. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
@@ -0,0 +1,732 @@
1
+ # src/setiastro/saspro/cosmicclarity_engines/benchmark_engine.py
2
+ from __future__ import annotations
3
+
4
+ import os, time, platform
5
+ from pathlib import Path
6
+ from typing import Callable, Optional, Literal, Dict, Any
7
+
8
+ import numpy as np
9
+ from astropy.io import fits
10
+
11
+ from numba import njit, prange
12
+ import psutil
13
+
14
+ import cpuinfo
15
+
16
# Callback type aliases used throughout this module.
ProgressCB = Callable[[str], None]  # receives a human-readable status line
CancelCB = Callable[[], bool]  # returns True to request cancellation
DLProgressCB = Callable[[int, int], None]  # (done_bytes, total_bytes)
19
+
20
+ from setiastro.saspro.cosmicclarity_engines.sharpen_engine import load_sharpen_models
21
+ from setiastro.saspro.resources import get_resources
22
+
23
+ # -----------------------------
24
+ # Paths / cache
25
+ # -----------------------------
26
def benchmark_cache_dir() -> Path:
    """Directory for cached benchmark assets (under the shared runtime dir).

    Reuses the same user runtime directory the accel installer uses, and
    creates the ``benchmarks`` subfolder on first access.
    """
    from setiastro.saspro.runtime_torch import _user_runtime_dir

    cache = Path(_user_runtime_dir()) / "benchmarks"
    cache.mkdir(parents=True, exist_ok=True)
    return cache
33
+
34
def benchmark_image_path() -> Path:
    """Full path of the cached benchmark FITS file."""
    return benchmark_cache_dir() / "benchmarkimage.fit"
36
+
37
+ from typing import Sequence, Union
38
+ from urllib.parse import urlparse, parse_qs
39
+ import re
40
+
41
# Candidate download sources for the benchmark FITS, tried in order:
# two Google Drive mirrors first, then the GitHub release asset.
BENCHMARK_URLS = [
    "https://drive.google.com/file/d/1wgp6Ydn8JgF1j9FVnF6PgjyN-6ptJTnK/view?usp=drive_link",
    "https://drive.google.com/file/d/1QhsmuKjvksAMq45M3aHKHylZgZEd8Nh0/view?usp=drive_link",
    "https://github.com/setiastro/setiastrosuitepro/releases/download/benchmarkFIT/benchmarkimage.fit",
]

# Keep for backwards compat (some code imports BENCHMARK_FITS_URL)
BENCHMARK_FITS_URL = BENCHMARK_URLS[-1]
49
+
50
+
51
+ from urllib.parse import urljoin
52
+
53
+ def _looks_like_html_prefix(b: bytes) -> bool:
54
+ head = (b or b"").lstrip()[:256].lower()
55
+ return head.startswith(b"<!doctype html") or head.startswith(b"<html") or b"<html" in head
56
+
57
+ def _parse_gdrive_download_form(html: str) -> tuple[str, dict] | tuple[None, None]:
58
+ """
59
+ Parse the Google Drive virus-scan warning page.
60
+ Extracts:
61
+ - form action URL (often https://drive.usercontent.google.com/download)
62
+ - all hidden input fields needed for the download
63
+ """
64
+ # action="..."
65
+ m = re.search(r'<form[^>]+id="download-form"[^>]+action="([^"]+)"', html)
66
+ if not m:
67
+ return None, None
68
+ action = m.group(1)
69
+
70
+ # hidden inputs: <input type="hidden" name="X" value="Y">
71
+ inputs = {}
72
+ for name, val in re.findall(r'<input[^>]+type="hidden"[^>]+name="([^"]+)"[^>]*value="([^"]*)"', html):
73
+ inputs[name] = val
74
+
75
+ # Some pages omit value="" explicitly; handle name-only hidden inputs too (rare)
76
+ for name in re.findall(r'<input[^>]+type="hidden"[^>]+name="([^"]+)"(?![^>]*value=)', html):
77
+ inputs.setdefault(name, "")
78
+
79
+ return action, inputs
80
+
81
+
82
+ def _gdrive_file_id(url: str) -> Optional[str]:
83
+ """
84
+ Extract Google Drive file id from:
85
+ - https://drive.google.com/file/d/<ID>/view
86
+ - https://drive.google.com/open?id=<ID>
87
+ - https://drive.google.com/uc?id=<ID>&export=download
88
+ """
89
+ try:
90
+ u = urlparse(url)
91
+ if "drive.google.com" not in (u.netloc or ""):
92
+ return None
93
+
94
+ # /file/d/<id>/...
95
+ m = re.search(r"/file/d/([^/]+)", u.path or "")
96
+ if m:
97
+ return m.group(1)
98
+
99
+ # ?id=<id>
100
+ qs = parse_qs(u.query or "")
101
+ if "id" in qs and qs["id"]:
102
+ return qs["id"][0]
103
+ except Exception:
104
+ pass
105
+ return None
106
+
107
+
108
+ def _gdrive_direct_url(file_id: str) -> str:
109
+ # export=download is essential; confirm token may be appended later if needed
110
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
111
+
112
+
113
+ def _gdrive_confirm_token(html: str) -> Optional[str]:
114
+ """
115
+ When Drive shows the "can't scan for viruses" interstitial,
116
+ it includes a confirm token in a download link.
117
+ """
118
+ # Typical patterns include confirm=<TOKEN>
119
+ m = re.search(r"confirm=([0-9A-Za-z_]+)", html)
120
+ if m:
121
+ return m.group(1)
122
+ return None
123
+
124
+
125
def _normalize_download_url(url: str) -> tuple[str, Optional[str]]:
    """Return (download_url, log_label).

    Google Drive view/open links are rewritten to a direct
    ``uc?export=download&id=`` URL; anything else passes through unchanged
    with a None label.
    """
    file_id = _gdrive_file_id(url)
    if not file_id:
        return url, None
    return _gdrive_direct_url(file_id), f"Google Drive ({file_id})"
134
+
135
+ def _looks_like_html_prefix(b: bytes) -> bool:
136
+ if not b:
137
+ return False
138
+ head = b.lstrip()[:64].lower()
139
+ return head.startswith(b"<!doctype html") or head.startswith(b"<html") or b"<html" in head
140
+
141
+ def _is_probably_valid_fits(path: Path, *, min_bytes: int = 1_000_000) -> bool:
142
+ try:
143
+ if path.stat().st_size < min_bytes:
144
+ return False
145
+ with open(path, "rb") as f:
146
+ first = f.read(80)
147
+ # FITS primary header should start with SIMPLE =
148
+ if b"SIMPLE" not in first[:20]:
149
+ return False
150
+ return True
151
+ except Exception:
152
+ return False
153
+
154
+
155
+ # -----------------------------
156
+ # Download benchmark FITS
157
+ # -----------------------------
158
def download_benchmark_image(
    url: Union[str, Sequence[str], None] = None,
    dst: Optional[Path] = None,
    *,
    status_cb: Optional[ProgressCB] = None,
    progress_cb: Optional[DLProgressCB] = None,
    cancel_cb: Optional[CancelCB] = None,
    timeout: int = 30,
) -> Path:
    """
    Download benchmarkimage.fit into runtime cache.

    url:
      - None -> try BENCHMARK_URLS in order
      - str -> try that one (but will still handle Drive confirms)
      - list/tuple -> try each in order

    dst: destination path; defaults to benchmark_image_path().
    status_cb: receives human-readable status lines.
    progress_cb: receives (done_bytes, total_bytes) per chunk.
    cancel_cb: polled per chunk; returning True aborts the download.
    timeout: per-request timeout in seconds passed to requests.

    Uses streaming download + atomic replace. Supports cancel.
    Raises RuntimeError when every candidate source fails.
    """
    if dst is None:
        dst = benchmark_image_path()
    dst = Path(dst)
    # Partial data is staged in a ".part" sibling and atomically renamed on success.
    tmp = dst.with_suffix(dst.suffix + ".part")

    # Build candidate list
    if url is None:
        candidates = list(BENCHMARK_URLS)
    elif isinstance(url, (list, tuple)):
        candidates = list(url)
    else:
        candidates = [url]

    import requests  # local import keeps startup lighter

    last_err = None

    for idx, raw in enumerate(candidates, start=1):
        try:
            # Drive view/open links are rewritten to direct download URLs.
            dl_url, label = _normalize_download_url(raw)
            src_label = label or raw

            if status_cb:
                status_cb(f"Downloading benchmark image… (source {idx}/{len(candidates)}: {src_label})")

            # Use a session so Drive confirm/cookies work reliably
            with requests.Session() as s:
                r = s.get(dl_url, stream=True, timeout=timeout, allow_redirects=True)

                ctype = (r.headers.get("Content-Type") or "").lower()

                # If we got HTML, it’s probably the virus-scan warning page
                if "text/html" in ctype:
                    html = r.text  # reads page into memory (small)
                    r.close()

                    action, params = _parse_gdrive_download_form(html)
                    if action and params:
                        # Submit the "Download anyway" form with the SAME session/cookies
                        r = s.get(action, params=params, stream=True, timeout=timeout, allow_redirects=True)
                        ctype = (r.headers.get("Content-Type") or "").lower()
                    else:
                        raise RuntimeError("Google Drive returned an interstitial HTML page, but download form could not be parsed.")

                r.raise_for_status()
                # Content-Length may be absent (total == 0); percent/ETA are then skipped.
                total = int(r.headers.get("Content-Length") or 0)
                done = 0
                t_start = time.time()
                t_last = t_start
                done_last = 0
                # Exponential moving average of bytes/sec, used for the ETA estimate.
                ema_bps = None

                tmp.parent.mkdir(parents=True, exist_ok=True)

                first_chunk = True
                with open(tmp, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        if cancel_cb and cancel_cb():
                            try:
                                f.close()
                                tmp.unlink(missing_ok=True)
                            except Exception:
                                pass
                            raise RuntimeError("Download canceled.")

                        if not chunk:
                            continue

                        # SNIFF FIRST BYTES: if it's HTML, abort this source immediately
                        if first_chunk:
                            first_chunk = False
                            if _looks_like_html_prefix(chunk[:256]):
                                raise RuntimeError("Google Drive returned HTML (not the FITS). Link likely requires confirm/permission.")

                        f.write(chunk)
                        done += len(chunk)

                        if progress_cb:
                            progress_cb(done, total)

                        # Throttle status updates to ~2/sec and smooth the rate with an EMA.
                        now = time.time()
                        dt = now - t_last
                        if dt >= 0.5:
                            inst_bps = (done - done_last) / max(dt, 1e-9)
                            ema_bps = inst_bps if ema_bps is None else (0.75 * ema_bps + 0.25 * inst_bps)

                            eta = None
                            if total > 0 and ema_bps and ema_bps > 1:
                                eta = (total - done) / ema_bps

                            if status_cb:
                                pct = (done * 100.0 / total) if total > 0 else None
                                if pct is None:
                                    status_cb(f"Downloading… {_fmt_bytes(done)} at {_fmt_bytes(ema_bps)}/s • ETA {_fmt_eta(None)}")
                                else:
                                    status_cb(
                                        f"Downloading… {pct:5.1f}% • {_fmt_bytes(done)}/{_fmt_bytes(total)} "
                                        f"at {_fmt_bytes(ema_bps)}/s • ETA {_fmt_eta(eta)}"
                                    )

                            t_last = now
                            done_last = done

            # atomic replace only after a full success
            os.replace(str(tmp), str(dst))

            # VALIDATE: size + FITS header sanity. If invalid, treat as failure and try next URL.
            if not _is_probably_valid_fits(dst, min_bytes=10_000_000):  # 10MB floor; tune as you like
                raise RuntimeError("Downloaded file is not a valid FITS (too small or missing SIMPLE).")

            if status_cb:
                status_cb(f"Benchmark image ready: {dst}")
            return dst

        except Exception as e:
            # Remember the failure, clean up any partial file, and try the next source.
            last_err = e
            # clean partial
            try:
                tmp.unlink(missing_ok=True)
            except Exception:
                pass
            if status_cb:
                status_cb(f"Source {idx} failed: {e}")

    raise RuntimeError(f"All benchmark download sources failed. Last error: {last_err}")
303
+
304
+ def _fmt_bytes(n: float) -> str:
305
+ units = ["B", "KB", "MB", "GB", "TB"]
306
+ n = float(max(0.0, n))
307
+ for u in units:
308
+ if n < 1024.0 or u == units[-1]:
309
+ return f"{n:.1f} {u}" if u != "B" else f"{n:.0f} {u}"
310
+ n /= 1024.0
311
+ return f"{n:.1f} TB"
312
+
313
+ def _fmt_eta(seconds: Optional[float]) -> str:
314
+ if seconds is None or seconds <= 0 or not np.isfinite(seconds):
315
+ return "—"
316
+ s = int(seconds + 0.5)
317
+ m, s = divmod(s, 60)
318
+ h, m = divmod(m, 60)
319
+ if h:
320
+ return f"{h:d}h {m:02d}m"
321
+ if m:
322
+ return f"{m:d}m {s:02d}s"
323
+ return f"{s:d}s"
324
+
325
+
326
def _get_stellar_model_for_benchmark(use_gpu: bool, status_cb=None):
    """Load the sharpen models and label their backend.

    Returns (models, backend_tag) where models is SharpenModels (torch or
    onnx) and backend_tag is one of 'CUDA', 'MPS', 'DirectML', or 'CPU'.
    """
    models = load_sharpen_models(use_gpu=use_gpu, status_cb=status_cb or (lambda *_: None))

    if models.is_onnx:
        return models, "DirectML"

    device = models.device
    device_type = getattr(device, "type", "")
    if device_type == "cuda":
        return models, "CUDA"
    if device_type == "mps":
        return models, "MPS"

    # torch-directml devices usually don't expose .type == 'dml';
    # fall back to a string match on the device repr.
    device_str = str(device).lower()
    if "dml" in device_str or "directml" in device_str:
        return models, "DirectML"

    return models, "CPU"
349
+
350
+
351
+
352
def torch_benchmark_stellar(
    patches_nchw: np.ndarray,
    *,
    use_gpu: bool,
    progress_cb=None,  # (done, total) -> bool; return False to cancel
    status_cb=None,
) -> tuple[float, float, str]:
    """Benchmark the Stellar model through Torch.

    Uses the SAME Stellar model and autocast policy as sharpen_engine.
    Returns (avg_ms_per_patch, total_ms, backend_tag).
    Raises RuntimeError if sharpen_engine selected the ONNX backend,
    or when progress_cb requests cancellation.
    """
    status_cb = status_cb or (lambda *_: None)
    models, tag = _get_stellar_model_for_benchmark(use_gpu=use_gpu, status_cb=status_cb)

    if models.is_onnx:
        raise RuntimeError("torch_benchmark_stellar called but models.is_onnx=True")

    torch = models.torch
    device = models.device
    model = models.stellar

    batch = torch.from_numpy(patches_nchw).to(device=device, dtype=torch.float32, non_blocking=True)

    # Small warm-up so the first timed iteration isn't skewed by lazy init.
    with torch.no_grad():
        _ = model(batch[0:1])
        if device.type == "cuda":
            torch.cuda.synchronize()

    total_ms = 0.0
    count = int(batch.shape[0])
    status_cb(f"Benchmarking Stellar model via Torch ({tag})…")

    # IMPORTANT: reuse sharpen_engine's autocast policy, not unconditional AMP.
    from setiastro.saspro.cosmicclarity_engines.sharpen_engine import _autocast_context

    with torch.no_grad(), _autocast_context(torch, device):
        for i in range(count):
            started = time.time()
            _ = model(batch[i:i+1])
            if device.type == "cuda":
                # Force completion so the wall-clock delta covers the kernel.
                torch.cuda.synchronize()
            total_ms += (time.time() - started) * 1000.0

            if progress_cb and (not progress_cb(i + 1, count)):
                raise RuntimeError("Canceled.")

    return (total_ms / count), total_ms, tag
399
+
400
def onnx_benchmark_stellar(
    patches_nchw: np.ndarray,
    *,
    use_gpu: bool,
    progress_cb=None,
    status_cb=None,
) -> tuple[float, float, str]:
    """Benchmark the Stellar model through ONNX Runtime.

    If sharpen_engine already selected ONNX (typically DirectML on Windows),
    that exact session is reused. Otherwise a new session is built preferring
    DirectML (when use_gpu), then the CUDA EP if present, with CPU as the
    final fallback. Returns (avg_ms_per_patch, total_ms, provider_label).
    """
    status_cb = status_cb or (lambda *_: None)

    # Reuse sharpen_engine's session when it already chose ONNX.
    models, _tag = _get_stellar_model_for_benchmark(use_gpu=use_gpu, status_cb=status_cb)
    if models.is_onnx:
        sess = models.stellar
        # Best-effort label: whatever provider the session actually runs on.
        try:
            active = sess.get_providers()
            provider = active[0] if active else "ONNX"
        except Exception:
            provider = "DmlExecutionProvider"
    else:
        import onnxruntime as ort
        from setiastro.saspro.model_manager import require_model
        onnx_path = require_model("deep_sharp_stellar_cnn_AI3_5s.onnx")

        available = ort.get_available_providers()

        # Ordered preference: DirectML, then CUDA EP, always ending with CPU.
        wanted = []
        if use_gpu and ("DmlExecutionProvider" in available):
            wanted.append("DmlExecutionProvider")
        if use_gpu and ("CUDAExecutionProvider" in available):
            wanted.append("CUDAExecutionProvider")
        wanted.append("CPUExecutionProvider")

        sess = ort.InferenceSession(str(onnx_path), providers=wanted)

        # Label by what actually got picked.
        try:
            active = sess.get_providers()
            provider = active[0] if active else wanted[0]
        except Exception:
            provider = wanted[0]

    input_name = sess.get_inputs()[0].name
    total_ms = 0.0
    count = int(patches_nchw.shape[0])

    status_cb(f"Benchmarking Stellar model via ONNX ({provider})…")

    for i in range(count):
        sample = patches_nchw[i:i+1].astype(np.float32, copy=False)
        started = time.time()
        sess.run(None, {input_name: sample})
        total_ms += (time.time() - started) * 1000.0

        if progress_cb and (not progress_cb(i + 1, count)):
            raise RuntimeError("Canceled.")

    return (total_ms / count), total_ms, provider
468
+
469
+ # -----------------------------
470
+ # Load + tile image
471
+ # -----------------------------
472
def _load_benchmark_fits(path: Path) -> np.ndarray:
    """Load the benchmark FITS and return a float32 CHW (3,H,W) array.

    Mono 2D data is replicated to 3 channels; HWC data is transposed to CHW
    (alpha dropped); a single-channel CHW is repeated to 3 channels.
    Raises RuntimeError on empty data or an unexpected shape.
    """
    with fits.open(str(path), memmap=False) as hdul:
        data = hdul[0].data
        if data is None:
            raise RuntimeError("FITS contains no data.")
        data = np.asarray(data, dtype=np.float32)

    if data.ndim == 2:
        return np.stack([data, data, data], axis=0)

    if data.ndim == 3:
        # HWC -> CHW when the channel axis is clearly last
        if data.shape[-1] in (3, 4) and data.shape[0] != 3:
            data = np.transpose(data[..., :3], (2, 0, 1))
        # Already CHW with 3 channels is fine; expand a single channel.
        if data.shape[0] == 1:
            data = np.repeat(data, 3, axis=0)
        return data

    raise RuntimeError(f"Unexpected FITS shape: {data.shape}")
492
+
493
+
494
def tile_chw_image(image_chw: np.ndarray, patch_size: int = 256) -> np.ndarray:
    """Cut a (3,H,W) image into full (N,3,patch,patch) tiles.

    Partial edge tiles are dropped (no padding), matching the old behavior.
    Raises RuntimeError when the image yields no complete tile.
    """
    _, height, width = image_chw.shape
    tiles = [
        image_chw[:, row:row + patch_size, col:col + patch_size]
        for row in range(0, height, patch_size)
        for col in range(0, width, patch_size)
        if row + patch_size <= height and col + patch_size <= width
    ]
    if not tiles:
        raise RuntimeError("No full 256x256 patches found in benchmark image.")
    return np.stack(tiles, axis=0).astype(np.float32, copy=False)
509
+
510
+
511
+ # -----------------------------
512
+ # CPU microbenchmarks (Numba)
513
+ # -----------------------------
514
@njit
def _mad_cpu_jit(image_array: np.ndarray, median_val: float) -> float:
    # Median absolute deviation about a precomputed median (numba-compiled).
    deviations = np.abs(image_array - median_val)
    return np.median(deviations)
517
+
518
def mad_cpu(image_array: np.ndarray, runs: int = 3) -> list[float]:
    """Time median-absolute-deviation on the flattened image.

    Returns per-run timings in ms; the first run includes numba JIT compile.
    The input may be CHW or HW — it is flattened for fairness.
    """
    flat = np.asarray(image_array, dtype=np.float32).ravel()
    timings: list[float] = []
    for _ in range(runs):
        started = time.time()
        center = float(np.median(flat))
        _mad_cpu_jit(flat, center)
        timings.append((time.time() - started) * 1000.0)
    return timings
531
+
532
@njit(parallel=True)
def _flat_field_jit(image_array: np.ndarray, flat_frame: np.ndarray, median_flat: float) -> np.ndarray:
    # Parallel per-element flat-field division: img / (flat / median_flat).
    result = np.empty_like(image_array)
    total = image_array.size
    for idx in prange(total):
        result[idx] = image_array[idx] / (flat_frame[idx] / median_flat)
    return result
539
+
540
def flat_field_correction(image_array: np.ndarray, flat_frame: np.ndarray, runs: int = 3) -> list[float]:
    """Time a multi-core flat-field correction pass.

    Returns per-run timings in ms; the first run includes numba JIT compile.
    Both inputs are flattened to 1-D float32 before timing.
    """
    img = np.asarray(image_array, dtype=np.float32).ravel()
    flat = np.asarray(flat_frame, dtype=np.float32).ravel()
    timings: list[float] = []
    for _ in range(runs):
        started = time.time()
        flat_median = float(np.median(flat))
        _flat_field_jit(img, flat, flat_median)
        timings.append((time.time() - started) * 1000.0)
    return timings
550
+
551
+
552
+ # -----------------------------
553
+ # System info
554
+ # -----------------------------
555
def get_system_info() -> dict:
    """Collect OS/CPU/RAM/Python details plus optional torch/onnxruntime capabilities.

    Torch and onnxruntime are probed best-effort; any failure simply leaves
    the corresponding keys out of the returned dict.
    """
    info = {
        "OS": f"{platform.system()} {platform.release()}",
        "CPU": cpuinfo.get_cpu_info().get("brand_raw", "Unknown"),
        "RAM": f"{round(psutil.virtual_memory().total / (1024 ** 3), 1)} GB",
        "Python": f"{platform.python_version()}",
    }

    # torch details (optional — requires the accel runtime on sys.path)
    try:
        from setiastro.saspro.runtime_torch import add_runtime_to_sys_path
        add_runtime_to_sys_path(status_cb=lambda *_: None)
        import torch

        info["torch.__version__"] = getattr(torch, "__version__", "Unknown")
        info["torch.version.cuda"] = getattr(getattr(torch, "version", None), "cuda", None)
        cuda_ok = bool(getattr(torch, "cuda", None) and torch.cuda.is_available())
        info["CUDA Available"] = cuda_ok
        info["MPS Available"] = bool(hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
        if cuda_ok:
            try:
                info["GPU"] = torch.cuda.get_device_name(0)
            except Exception:
                pass
    except Exception:
        pass

    # onnxruntime details (optional)
    try:
        import onnxruntime as ort
        info["ONNX Providers"] = ort.get_available_providers()
    except Exception:
        pass

    return info
586
+
587
+ def _fmt_ms(first_ms: float, avg_ms: float) -> str:
588
+ return f"First: {first_ms:.2f} ms | Avg: {avg_ms:.2f} ms"
589
+
590
+ def _fmt_gpu(avg_ms: float, total_ms: float) -> str:
591
+ return f"Avg: {avg_ms:.2f} ms | Total: {total_ms:.2f} ms"
592
+
593
+ def _legacy_results_schema(results: Dict[str, Any]) -> Dict[str, Any]:
594
+ """
595
+ Convert internal structured results -> legacy/human schema expected by submitter.
596
+ """
597
+ out: Dict[str, Any] = {}
598
+
599
+ # --- CPU ---
600
+ cpu_mad = results.get("CPU MAD (Single Core)")
601
+ if isinstance(cpu_mad, dict):
602
+ out["CPU MAD (Single Core)"] = _fmt_ms(cpu_mad.get("first_ms", 0.0), cpu_mad.get("avg_ms", 0.0))
603
+ elif cpu_mad is not None:
604
+ out["CPU MAD (Single Core)"] = str(cpu_mad)
605
+
606
+ cpu_flat = results.get("CPU Flat-Field (Multi-Core)")
607
+ if isinstance(cpu_flat, dict):
608
+ out["CPU Flat-Field (Multi-Core)"] = _fmt_ms(cpu_flat.get("first_ms", 0.0), cpu_flat.get("avg_ms", 0.0))
609
+ elif cpu_flat is not None:
610
+ out["CPU Flat-Field (Multi-Core)"] = str(cpu_flat)
611
+
612
+ # --- GPU / Torch ---
613
+ # Your internal key looks like "Stellar Torch (CUDA)" or "(CPU)" etc.
614
+ torch_key = next((k for k in results.keys() if k.startswith("Stellar Torch (")), None)
615
+ if torch_key and isinstance(results.get(torch_key), dict):
616
+ backend = torch_key[len("Stellar Torch ("):-1] # extract tag inside (...)
617
+ t = results[torch_key]
618
+ out[f"GPU Time ({backend})"] = _fmt_gpu(t.get("avg_ms", 0.0), t.get("total_ms", 0.0))
619
+
620
+ # --- ONNX ---
621
+ onnx_key = next((k for k in results.keys() if k.startswith("Stellar ONNX")), None)
622
+ if onnx_key is None:
623
+ # you used this on non-windows
624
+ if "Stellar ONNX" in results:
625
+ out["ONNX Time"] = str(results["Stellar ONNX"])
626
+ else:
627
+ out["ONNX Time"] = "ONNX benchmark not run."
628
+ else:
629
+ v = results.get(onnx_key)
630
+ if isinstance(v, dict):
631
+ # keep it under the legacy key name
632
+ out["ONNX Time"] = _fmt_gpu(v.get("avg_ms", 0.0), v.get("total_ms", 0.0))
633
+ else:
634
+ out["ONNX Time"] = str(v)
635
+
636
+ # --- System Info ---
637
+ # If you want to match your example more closely, drop Python/torch versions.
638
+ si = results.get("System Info", {})
639
+ if isinstance(si, dict):
640
+ keep = {
641
+ "OS", "CPU", "RAM",
642
+ "CUDA Available", "MPS Available", "ONNX Providers", "GPU"
643
+ }
644
+ out["System Info"] = {k: si[k] for k in keep if k in si}
645
+ else:
646
+ out["System Info"] = si
647
+
648
+ return out
649
+
650
def _picked_backend(use_gpu: bool, status_cb=None) -> str:
    """Report which engine sharpen_engine would use: 'ONNX' or 'TORCH'.

    Loads the models as a side effect just to inspect the chosen backend.
    """
    models, _tag = _get_stellar_model_for_benchmark(use_gpu=use_gpu, status_cb=status_cb)
    if models.is_onnx:
        return "ONNX"
    return "TORCH"
653
+
654
+ # -----------------------------
655
+ # One-stop runner
656
+ # -----------------------------
657
def run_benchmark(
    *,
    mode: Literal["CPU", "GPU", "Both"] = "Both",
    use_gpu: bool = True,
    benchmark_fits_path: Optional[Path] = None,
    status_cb: Optional[ProgressCB] = None,
    progress_cb: Optional[Callable[[int, int], bool]] = None,
) -> Dict[str, Any]:
    """
    Run the CPU and/or Stellar-model benchmarks and return a results dict
    safe to json.dumps (already converted to the legacy schema).

    mode: which benchmark families to run.
    use_gpu: passed through to backend selection.
    benchmark_fits_path: FITS to benchmark on; defaults to the cached image.
    status_cb: receives human-readable status lines.
    progress_cb signature: (done, total)->bool continue

    Raises RuntimeError when the benchmark image has not been downloaded.
    """
    if benchmark_fits_path is None:
        benchmark_fits_path = benchmark_image_path()
    benchmark_fits_path = Path(benchmark_fits_path)
    if not benchmark_fits_path.exists():
        raise RuntimeError("Benchmark image not downloaded yet.")

    img_chw = _load_benchmark_fits(benchmark_fits_path)
    patches = tile_chw_image(img_chw, 256)  # (N,3,256,256)

    results: Dict[str, Any] = {}
    results["System Info"] = get_system_info()

    if mode in ("CPU", "Both"):
        if status_cb: status_cb("Running CPU benchmarks…")
        cpu_mad = mad_cpu(img_chw)
        cpu_flat = flat_field_correction(img_chw, img_chw)
        # First run includes numba JIT compile; the average covers the rest.
        results["CPU MAD (Single Core)"] = {
            "first_ms": float(cpu_mad[0]),
            "avg_ms": float(np.mean(cpu_mad[1:])) if len(cpu_mad) > 1 else float(cpu_mad[0]),
        }
        results["CPU Flat-Field (Multi-Core)"] = {
            "first_ms": float(cpu_flat[0]),
            "avg_ms": float(np.mean(cpu_flat[1:])) if len(cpu_flat) > 1 else float(cpu_flat[0]),
        }

    if mode in ("GPU", "Both"):
        if status_cb: status_cb("Running Stellar model benchmark…")

        # Mirror sharpen_engine's backend choice so results reflect real usage.
        picked = _picked_backend(use_gpu=use_gpu, status_cb=status_cb)

        if picked == "TORCH":
            avg_ms, total_ms, backend = torch_benchmark_stellar(
                patches,
                use_gpu=use_gpu,
                progress_cb=progress_cb,
                status_cb=status_cb,
            )
            results[f"Stellar Torch ({backend})"] = {"avg_ms": float(avg_ms), "total_ms": float(total_ms)}

            # Optional: also run ONNX on Windows for comparison
            if platform.system() == "Windows":
                avg_o, total_o, provider = onnx_benchmark_stellar(
                    patches,
                    use_gpu=use_gpu,
                    progress_cb=progress_cb,
                    status_cb=status_cb,
                )
                results[f"Stellar ONNX ({provider})"] = {"avg_ms": float(avg_o), "total_ms": float(total_o)}

        else:
            # Picked ONNX (DirectML). Run ONNX benchmark as the “GPU benchmark”.
            if platform.system() == "Windows":
                avg_o, total_o, provider = onnx_benchmark_stellar(
                    patches,
                    use_gpu=use_gpu,
                    progress_cb=progress_cb,
                    status_cb=status_cb,
                )
                results[f"Stellar ONNX ({provider})"] = {"avg_ms": float(avg_o), "total_ms": float(total_o)}
            else:
                results["Stellar ONNX"] = "ONNX benchmark only available on Windows."

    # Convert to the legacy/human-readable schema before returning.
    results = _legacy_results_schema(results)
    return results