setiastrosuitepro 1.7.5.post1__py3-none-any.whl → 1.8.0.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of setiastrosuitepro might be problematic. Click here for more details.

Files changed (27) hide show
  1. setiastro/saspro/_generated/build_info.py +2 -2
  2. setiastro/saspro/accel_installer.py +21 -8
  3. setiastro/saspro/accel_workers.py +11 -12
  4. setiastro/saspro/comet_stacking.py +113 -85
  5. setiastro/saspro/cosmicclarity.py +604 -826
  6. setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
  7. setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
  8. setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
  9. setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
  10. setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
  11. setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
  12. setiastro/saspro/gui/main_window.py +14 -0
  13. setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
  14. setiastro/saspro/model_manager.py +324 -0
  15. setiastro/saspro/model_workers.py +102 -0
  16. setiastro/saspro/ops/benchmark.py +320 -0
  17. setiastro/saspro/ops/settings.py +407 -10
  18. setiastro/saspro/remove_stars.py +424 -442
  19. setiastro/saspro/resources.py +73 -10
  20. setiastro/saspro/runtime_torch.py +107 -22
  21. setiastro/saspro/signature_insert.py +14 -3
  22. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
  23. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
  24. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
  25. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
  26. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
  27. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
@@ -0,0 +1,412 @@
1
+ #src/setiastro/saspro/cosmicclarity_engines/superres_engine.py
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ from typing import Optional, Dict, Any, Tuple
6
+
7
+ import numpy as np
8
+ import cv2
9
+ cv2.setNumThreads(1)
10
+ try:
11
+ import onnxruntime as ort
12
+ except Exception:
13
+ ort = None
14
+
15
def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
    """Import torch through SASpro's runtime helper.

    Forwards the CUDA / DirectML preferences and the status callback to
    ``runtime_torch.import_torch``; XPU is never requested by this engine.
    Returns whatever ``import_torch`` returns (the torch module).
    """
    # Imported lazily so merely importing this engine does not pull in torch.
    from setiastro.saspro.runtime_torch import import_torch as _import_torch

    options = dict(
        prefer_cuda=prefer_cuda,
        prefer_xpu=False,
        prefer_dml=prefer_dml,
        status_cb=status_cb,
    )
    return _import_torch(**options)
23
+
24
+
25
+ from setiastro.saspro.resources import get_resources
26
+
27
def _load_torch_superres_model(torch, device, pth_path: str):
    """Build the CosmicClarity super-resolution CNN and load its weights.

    Args:
        torch: the torch module (as returned by ``_get_torch``).
        device: device the model should live on (e.g. ``torch.device("cuda")``,
            ``torch.device("cpu")`` or a torch-directml device).
        pth_path: checkpoint path; either a bare state dict or a dict holding
            the state dict under ``"model_state_dict"``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if ``pth_path`` does not exist.

    NOTE(review): ``torch.load`` unpickles the file, so the checkpoint must
    come from a trusted source (the bundled model files).
    """
    # Fail fast before constructing the network.
    if not os.path.exists(pth_path):
        raise FileNotFoundError(f"SuperRes model not found: {pth_path}")

    nn = torch.nn

    class ResidualBlock(nn.Module):
        """Two 3x3 convs with an identity skip connection, followed by ReLU."""

        def __init__(self, channels: int):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

        def forward(self, x):
            residual = x
            out = self.relu(self.conv1(x))
            out = self.conv2(out)
            out = out + residual
            return self.relu(out)

    class SuperResolutionCNN(nn.Module):
        """Encoder/decoder with skip concatenations; 3 input and 3 sigmoid output channels.

        Layer names and shapes must match the shipped checkpoints exactly,
        otherwise ``load_state_dict`` fails.
        """

        def __init__(self):
            super().__init__()
            self.encoder1 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(16))
            self.encoder2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(32))
            self.encoder3 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=2, dilation=2), nn.ReLU(inplace=True), ResidualBlock(64))
            self.encoder4 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(128))
            self.encoder5 = nn.Sequential(nn.Conv2d(128, 256, 3, padding=2, dilation=2), nn.ReLU(inplace=True), ResidualBlock(256))

            self.decoder5 = nn.Sequential(nn.Conv2d(256 + 128, 128, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(128))
            self.decoder4 = nn.Sequential(nn.Conv2d(128 + 64, 64, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(64))
            self.decoder3 = nn.Sequential(nn.Conv2d(64 + 32, 32, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(32))
            self.decoder2 = nn.Sequential(nn.Conv2d(32 + 16, 16, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(16))
            self.decoder1 = nn.Sequential(nn.Conv2d(16, 3, 3, padding=1), nn.Sigmoid())

        def forward(self, x):
            e1 = self.encoder1(x)
            e2 = self.encoder2(e1)
            e3 = self.encoder3(e2)
            e4 = self.encoder4(e3)
            e5 = self.encoder5(e4)

            d5 = self.decoder5(torch.cat([e5, e4], dim=1))
            d4 = self.decoder4(torch.cat([d5, e3], dim=1))
            d3 = self.decoder3(torch.cat([d4, e2], dim=1))
            d2 = self.decoder2(torch.cat([d3, e1], dim=1))
            return self.decoder1(d2)

    model = SuperResolutionCNN().to(device)

    # Deserialize on the CPU; load_state_dict then copies into the parameters
    # that already live on `device`. This avoids holding a second full copy of
    # the weights on the GPU and avoids map_location on non-standard devices
    # such as torch-directml.
    sd = torch.load(pth_path, map_location="cpu")
    if isinstance(sd, dict) and "model_state_dict" in sd:
        sd = sd["model_state_dict"]
    model.load_state_dict(sd)
    model.eval()
    return model
82
+
83
+
84
+ # ----------------------------
85
+ # Shared helpers (copy from denoise_engine if you want)
86
+ # ----------------------------
87
def stretch_image(image: np.ndarray):
    """Stretch an image so that each channel's median lands at 0.25.

    The global minimum is subtracted first. Each channel with a non-zero median
    is then passed through a rational transfer function that maps that median
    to 0.25. A channel whose median is exactly 0 is only offset.

    For 2-D or (H, W, 1) input the whole array is treated as one channel.
    Otherwise the first three channels are stretched individually.

    Returns:
        (stretched float32 array clipped to [0, 1],
         the subtracted minimum as a float,
         list of per-channel medians measured after the minimum was removed)
    """
    floor = float(np.min(image))
    shifted = image - floor
    target = 0.25

    def _transfer(values, median):
        # Maps `median` -> `target`, 0 -> 0 and 1 -> 1.
        return ((median - 1) * target * values) / (median * (target + values - 1) - target * values)

    mono = image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1)
    if mono:
        median = float(np.median(shifted))
        medians = [median]
        if median != 0:
            shifted = _transfer(shifted, median)
    else:
        medians = []
        for ch in range(3):
            median = float(np.median(shifted[..., ch]))
            medians.append(median)
            if median != 0:
                shifted[..., ch] = _transfer(shifted[..., ch], median)

    return np.clip(shifted, 0, 1).astype(np.float32), floor, medians
110
+
111
+
112
def unstretch_image(image: np.ndarray, original_medians, original_min: float):
    """Undo (approximately) a ``stretch_image`` call.

    Each channel's current median is mapped back to the corresponding entry of
    ``original_medians`` with the same rational transfer function. Then
    ``original_min`` is added back and the result is clipped.

    Args:
        image: stretched image; 2-D, (H, W, 1) or (H, W, C).
        original_medians: per-channel medians returned by ``stretch_image``.
            A channel is left untouched if its current median or its recorded
            median is 0.
        original_min: minimum returned by ``stretch_image``.

    Returns:
        New float32 array clipped to [0, 1]. The input array is not modified.
    """
    image = np.asarray(image)
    # Work on a private floating-point copy. The per-channel branch assigns in
    # place, which previously wrote straight into the caller's array (and
    # would truncate for integer dtypes).
    work = image.astype(np.result_type(image.dtype, np.float32), copy=True)

    is_single = (work.ndim == 2) or (work.ndim == 3 and work.shape[2] == 1)
    if is_single:
        med = float(np.median(work))
        if med != 0 and original_medians[0] != 0:
            target = original_medians[0]
            work = ((med - 1) * target * work) / (med * (target + work - 1) - target * work)
    else:
        # Only channels that have a recorded median can be restored; this also
        # avoids an IndexError for inputs with fewer than three channels.
        n_chan = min(work.shape[2], len(original_medians))
        for c in range(n_chan):
            med = float(np.median(work[..., c]))
            target = original_medians[c]
            if med != 0 and target != 0:
                work[..., c] = ((med - 1) * target * work[..., c]) / (
                    med * (target + work[..., c] - 1) - target * work[..., c]
                )

    work = work + original_min
    return np.clip(work, 0, 1).astype(np.float32)
128
+
129
+
130
def add_border(image: np.ndarray, border_size: int = 16):
    """Pad an image on all four sides with its median value.

    2-D input is padded with its global median. Multi-channel (H, W, C) input
    is padded per channel with that channel's median (computed as float32).
    The output dtype matches the input.
    """
    pad_2d = ((border_size, border_size), (border_size, border_size))

    if image.ndim == 2:
        return np.pad(image, pad_2d, mode="constant", constant_values=float(np.median(image)))

    channel_medians = np.median(image, axis=(0, 1)).astype(np.float32)
    padded = [
        np.pad(image[..., idx], pad_2d, mode="constant", constant_values=float(fill))
        for idx, fill in enumerate(channel_medians)
    ]
    return np.stack(padded, axis=-1)
141
+
142
+
143
def remove_border(image: np.ndarray, border_size: int):
    """Crop ``border_size`` pixels from every side of a 2-D or (H, W, ...) image.

    A border of 0 returns the image unchanged. Previously ``image[0:-0]``
    produced an empty array in that case.

    Raises:
        ValueError: if ``border_size`` is negative.
    """
    if border_size < 0:
        raise ValueError("border_size must be non-negative")
    if border_size == 0:
        return image
    h, w = image.shape[:2]
    # Slicing only the first two axes keeps any trailing channel axis intact.
    return image[border_size:h - border_size, border_size:w - border_size]
147
+
148
+
149
def split_image_into_chunks_with_overlap(image: np.ndarray, chunk_size: int = 256, overlap: int = 64):
    """Tile an image into overlapping chunks.

    Chunk origins are placed every ``chunk_size - overlap`` pixels along both
    axes. Chunks touching the bottom/right edge are truncated, not padded.

    Returns:
        List of ``(patch_view, row, col)`` tuples in row-major order. Each
        patch is a view into ``image``.

    Raises:
        ValueError: unless ``chunk_size > 0`` and ``0 <= overlap < chunk_size``.
            Previously an overlap larger than the chunk silently produced an
            empty list, and an equal overlap failed inside ``range``.
    """
    if chunk_size <= 0 or overlap < 0 or overlap >= chunk_size:
        raise ValueError(
            f"need chunk_size > 0 and 0 <= overlap < chunk_size (got chunk_size={chunk_size}, overlap={overlap})"
        )

    h, w = image.shape[:2]
    step = chunk_size - overlap
    return [
        (image[i:min(i + chunk_size, h), j:min(j + chunk_size, w)], i, j)
        for i in range(0, h, step)
        for j in range(0, w, step)
    ]
160
+
161
+
162
def stitch_chunks_ignore_border(chunks, out_hw: Tuple[int, int], border_size: int = 16):
    """Reassemble 2-D tiles into one (H, W) float32 image, averaging overlaps.

    Each tile's outer ``border_size`` pixels are discarded (capped at half the
    tile size on each axis). Only the remaining core is accumulated. Pixels
    that no core covers stay 0.

    Args:
        chunks: iterable of ``(tile, row, col)`` with the tile's top-left origin.
        out_hw: output ``(H, W)``.
        border_size: pixels trimmed from every tile edge.
    """
    height, width = out_hw
    acc = np.zeros((height, width), dtype=np.float32)
    counts = np.zeros_like(acc)

    for tile, top, left in chunks:
        tile_h, tile_w = tile.shape[:2]
        trim_y = min(border_size, tile_h // 2)
        trim_x = min(border_size, tile_w // 2)

        core = tile[trim_y:tile_h - trim_y, trim_x:tile_w - trim_x]
        region = (
            slice(top + trim_y, top + trim_y + core.shape[0]),
            slice(left + trim_x, left + trim_x + core.shape[1]),
        )
        acc[region] += core
        counts[region] += 1.0

    # max(count, 1) leaves uncovered pixels at 0 instead of dividing by zero.
    return acc / np.maximum(counts, 1.0)
179
+
180
+
181
+ # ----------------------------
182
+ # Model loading (cached)
183
+ # ----------------------------
184
+ from typing import Dict, Any, Tuple
185
+ import os
186
+
187
# Loaded-engine cache. Keys are (_BACKEND_TAG, scale, backend_name), where
# backend_name is one of "cuda", "torch_dml", "dml", "mps", "cpu" as chosen by
# load_superres. Values are the engine dicts load_superres returns.
_cached: dict[tuple[str, int, str], dict[str, Any]] = {}
_BACKEND_TAG = "cc_superres"

# Resource registry, resolved once at import; _superres_paths reads the
# CC_SUPERRES_*_PTH / *_ONNX model paths from it.
R = get_resources()
191
+
192
def _superres_paths(scale: int) -> tuple[str, str]:
    """Return the ``(pytorch_pth_path, onnx_path)`` pair for a 2x/3x/4x model.

    Raises:
        ValueError: for any other scale.
    """
    for factor, pth_attr, onnx_attr in (
        (2, "CC_SUPERRES_2X_PTH", "CC_SUPERRES_2X_ONNX"),
        (3, "CC_SUPERRES_3X_PTH", "CC_SUPERRES_3X_ONNX"),
        (4, "CC_SUPERRES_4X_PTH", "CC_SUPERRES_4X_ONNX"),
    ):
        if scale == factor:
            return (getattr(R, pth_attr), getattr(R, onnx_attr))
    raise ValueError("scale must be 2, 3, or 4")
200
+
201
+
202
def _pick_backend(torch, use_gpu: bool):
    """Choose an inference backend as a ``(backend_name, device)`` pair.

    Priority when ``use_gpu`` is truthy: PyTorch CUDA, then ONNX Runtime
    DirectML (if onnxruntime exposes ``DmlExecutionProvider``), then PyTorch
    MPS. Otherwise, or if none is available, PyTorch CPU.

    NOTE(review): load_superres in this module performs its own, more detailed
    selection (including torch-directml); this helper is not called there.
    """
    if use_gpu:
        if hasattr(torch, "cuda") and torch.cuda.is_available():
            return ("pytorch", torch.device("cuda"))
        if ort is not None and "DmlExecutionProvider" in ort.get_available_providers():
            return ("onnx", "DirectML")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return ("pytorch", torch.device("mps"))
    return ("pytorch", torch.device("cpu"))
211
+
212
+
213
def load_superres(scale: int, use_gpu: bool = True, status_cb=print) -> Dict[str, Any]:
    """Load (or fetch from cache) a super-resolution engine for ``scale``.

    Backend priority when ``use_gpu`` is truthy:
      1. PyTorch CUDA
      2. torch-directml (Windows only)
      3. ONNX Runtime with ``DmlExecutionProvider``
      4. PyTorch MPS
    and PyTorch CPU as the final fallback.

    Engines are cached in ``_cached`` under ``(_BACKEND_TAG, scale, backend)``,
    so a CPU engine from an earlier call is never reused for a GPU request.
    ``status_cb`` receives a human-readable message naming the chosen backend
    on every call, including cache hits.

    Returns:
        dict with keys ``backend`` ("pytorch" or "onnx"), ``device``, ``model``
        (torch module or ONNX InferenceSession), ``scale`` and ``torch``
        (the torch module, or None for the ONNX backend).

    Raises:
        ValueError: if ``scale`` is not 2, 3 or 4.
        FileNotFoundError: if the selected model file is missing.
    """
    scale = int(scale)
    if scale not in (2, 3, 4):
        raise ValueError("scale must be 2, 3, or 4")

    is_windows = (os.name == "nt")
    torch = _get_torch(prefer_cuda=bool(use_gpu), prefer_dml=bool(use_gpu and is_windows), status_cb=status_cb)
    pth_path, onnx_path = _superres_paths(scale)

    def _torch_engine(device, backend_name: str) -> Dict[str, Any]:
        # Shared cache-or-load path for every PyTorch-based backend.
        key = (_BACKEND_TAG, scale, backend_name)
        if key in _cached:
            return _cached[key]
        model = _load_torch_superres_model(torch, device, pth_path)
        engine = {"backend": "pytorch", "device": device, "model": model, "scale": scale, "torch": torch}
        _cached[key] = engine
        return engine

    # 1) CUDA via PyTorch
    if use_gpu and hasattr(torch, "cuda") and torch.cuda.is_available():
        status_cb(f"CosmicClarity SuperRes: using CUDA ({torch.cuda.get_device_name(0)})")
        return _torch_engine(torch.device("cuda"), "cuda")

    # 2) torch-directml (Windows). Best effort: any failure falls through to
    #    the next backend, but unexpected failures are reported, not hidden.
    if use_gpu and is_windows:
        try:
            import torch_directml

            dml = torch_directml.device()
            status_cb("CosmicClarity SuperRes: using DirectML (torch-directml)")
            return _torch_engine(dml, "torch_dml")
        except ImportError:
            pass  # torch-directml not installed; try the next backend quietly
        except Exception as exc:
            status_cb(f"CosmicClarity SuperRes: torch-directml failed ({exc}); trying next backend")

    # 3) DirectML via ONNX Runtime
    if use_gpu and ort is not None and "DmlExecutionProvider" in ort.get_available_providers():
        status_cb("CosmicClarity SuperRes: using DirectML (ONNX Runtime)")
        if not os.path.exists(onnx_path):
            raise FileNotFoundError(f"SuperRes ONNX model not found: {onnx_path}")
        key = (_BACKEND_TAG, scale, "dml")
        if key in _cached:
            return _cached[key]
        sess = ort.InferenceSession(onnx_path, providers=["DmlExecutionProvider"])
        engine = {"backend": "onnx", "device": "DirectML", "model": sess, "scale": scale, "torch": None}
        _cached[key] = engine
        return engine

    # 4) Apple MPS via PyTorch
    if use_gpu and hasattr(torch, "backends") and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        status_cb("CosmicClarity SuperRes: using MPS")
        return _torch_engine(torch.device("mps"), "mps")

    # 5) CPU fallback
    status_cb("CosmicClarity SuperRes: using CPU")
    return _torch_engine(torch.device("cpu"), "cpu")
297
+
298
+
299
def _amp_ok(torch, device) -> bool:
    """Decide whether CUDA mixed-precision autocast should be used.

    True only for a ``torch.device`` of type "cuda" whose compute capability
    major version is at least 8. Any error while querying the device
    properties yields False.
    """
    is_cuda_device = isinstance(device, torch.device) and device.type == "cuda"
    if not is_cuda_device:
        return False

    try:
        return torch.cuda.get_device_properties(device).major >= 8
    except Exception:
        return False
307
+
308
+
309
def _sr_infer_patch(engine: Dict[str, Any], patch_in: np.ndarray, use_amp: bool) -> np.ndarray:
    """Run the super-resolution model on one (S, S, 3) float32 patch and return a 2-D array."""
    if engine["backend"] == "pytorch":
        t = engine["torch"]  # torch module from runtime_torch
        chw = np.ascontiguousarray(patch_in.transpose(2, 0, 1))
        pt = t.from_numpy(chw).unsqueeze(0).to(engine["device"])
        with t.no_grad():
            if use_amp:
                # torch.autocast replaces the deprecated torch.cuda.amp.autocast();
                # float16 is the latter's default dtype.
                with t.autocast(device_type="cuda", dtype=t.float16):
                    out = engine["model"](pt)
            else:
                out = engine["model"](pt)
        out_np = out[0].detach().cpu().numpy()  # (C, S, S)
    else:
        # ONNX Runtime session (DirectML)
        sess = engine["model"]
        inp = np.ascontiguousarray(patch_in.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)
        out_np = sess.run(None, {sess.get_inputs()[0].name: inp})[0].squeeze()

    # Output is 3-channel grayscale; keep the first channel.
    if out_np.ndim == 3 and out_np.shape[0] == 3:
        out_np = out_np[0]
    elif out_np.ndim == 3 and out_np.shape[-1] == 3:
        out_np = out_np[..., 0]
    return out_np


def _sr_process_channel(chan: np.ndarray, scale: int, engine: Dict[str, Any], use_amp: bool,
                        chan_index: int, progress_cb) -> np.ndarray:
    """Super-resolve one 2-D channel.

    Steps: median border, optional stretch, bicubic upscale, tiled model
    inference, stitch, unstretch, border removal.
    """
    border = 16    # px of median padding added before upscaling
    tile = 256     # model input tile size (in upscaled pixels)
    overlap = 64   # overlap between neighbouring tiles

    chan = chan.astype(np.float32, copy=False)
    bordered = add_border(chan, border_size=border)

    # Channels with a very low median are stretched before inference and
    # unstretched after stitching.
    stretched_applied = float(np.median(bordered)) < 0.08
    if stretched_applied:
        work, orig_min, orig_meds = stretch_image(bordered)
    else:
        work = bordered.astype(np.float32, copy=False)

    h, w = work.shape[:2]
    up = cv2.resize(work, (w * scale, h * scale), interpolation=cv2.INTER_CUBIC)

    chunks = split_image_into_chunks_with_overlap(up, chunk_size=tile, overlap=overlap)
    total = len(chunks)
    processed = []
    for done, (patch, i, j) in enumerate(chunks, start=1):
        ph, pw = patch.shape[:2]
        # Zero-pad edge tiles to full size and replicate the plane into 3 inputs.
        patch_in = np.zeros((tile, tile, 3), dtype=np.float32)
        patch_in[:ph, :pw, :] = patch[:, :, np.newaxis]

        out_np = _sr_infer_patch(engine, patch_in, use_amp)
        processed.append((out_np[:ph, :pw].astype(np.float32, copy=False), i, j))

        if progress_cb is not None:
            # Global progress across all three channels.
            progress_cb(chan_index * total + done, 3 * total)

    stitched = stitch_chunks_ignore_border(processed, up.shape[:2], border_size=border)
    if stretched_applied:
        stitched = unstretch_image(stitched, orig_meds, orig_min)

    # The border added before upscaling is border*scale px wide afterwards.
    return remove_border(stitched, border_size=border * scale)


def superres_rgb01(
    img_rgb01: np.ndarray,
    *,
    scale: int,
    use_gpu: bool = True,
    progress_cb=None,  # progress_cb(done:int,total:int)
    status_cb=print,
) -> np.ndarray:
    """Upscale an RGB image with the CosmicClarity super-resolution model.

    The first three channels are processed independently; extra channels are
    ignored.

    Input: float32 RGB in [0..1], shape (H,W,3)
    Output: float32 RGB in [0..1], shape (H*scale,W*scale,3)

    Args:
        img_rgb01: image array of shape (H, W, C) with C >= 3.
        scale: upscale factor, 2, 3 or 4.
        use_gpu: allow GPU backends (see ``load_superres``).
        progress_cb: optional ``progress_cb(done, total)`` called after every
            tile, counting tiles across all three channels.
        status_cb: receives backend-selection messages (defaults to print).

    Raises:
        ValueError: if ``scale`` is invalid or the image is not (H, W, >=3).
    """
    scale = int(scale)
    if scale not in (2, 3, 4):
        raise ValueError("scale must be 2, 3, or 4")

    img = np.asarray(img_rgb01)
    # Validate before loading any model: non-RGB input used to fail deep inside
    # add_border with an unrelated axis error.
    if img.ndim != 3 or img.shape[2] < 3:
        raise ValueError(f"superres_rgb01 expects an (H, W, 3) image, got shape {img.shape}")

    engine = load_superres(scale, use_gpu=use_gpu, status_cb=status_cb)

    # Loop-invariant across channels and tiles.
    use_amp = (engine["backend"] == "pytorch") and _amp_ok(engine["torch"], engine["device"])

    out_chans = [
        _sr_process_channel(img[..., c], scale, engine, use_amp, c, progress_cb)
        for c in range(3)
    ]
    out_rgb = np.stack(out_chans, axis=-1)
    return np.clip(out_rgb, 0.0, 1.0).astype(np.float32, copy=False)
@@ -2347,6 +2347,20 @@ class AstroSuiteProMainWindow(
2347
2347
  # Log error but don't crash if preload fails
2348
2348
  print(f"Error preloading settings: {e}")
2349
2349
 
2350
def _open_benchmark(self):
    """Show the benchmark dialog, bringing it to the front.

    The dialog is created on first use and then reused via
    ``self._bench_dlg_cache``. If the dialog provides ``refresh_ui`` it is
    called before every show.
    """
    from setiastro.saspro.ops.benchmark import BenchmarkDialog

    if getattr(self, "_bench_dlg_cache", None) is None:
        self._bench_dlg_cache = BenchmarkDialog(self)
    dialog = self._bench_dlg_cache

    if hasattr(dialog, "refresh_ui"):
        dialog.refresh_ui()

    for bring_forward in (dialog.show, dialog.raise_, dialog.activateWindow):
        bring_forward()
2363
+
2350
2364
  def _open_settings(self):
2351
2365
  from setiastro.saspro.ops.settings import SettingsDialog
2352
2366
 
@@ -333,6 +333,8 @@ class MenuMixin:
333
333
 
334
334
  m_settings = mb.addMenu(self.tr("&Settings"))
335
335
  m_settings.addAction(self.tr("Preferences..."), self._open_settings)
336
+ m_settings.addSeparator()
337
+ m_settings.addAction(self.tr("Benchmark..."), self._open_benchmark)
336
338
 
337
339
  m_about = mb.addMenu(self.tr("&About"))
338
340
  m_about.addAction(self.act_docs)