setiastrosuitepro 1.7.5.post1-py3-none-any.whl → 1.8.0.post3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of setiastrosuitepro might be problematic.
- setiastro/saspro/_generated/build_info.py +2 -2
- setiastro/saspro/accel_installer.py +21 -8
- setiastro/saspro/accel_workers.py +11 -12
- setiastro/saspro/comet_stacking.py +113 -85
- setiastro/saspro/cosmicclarity.py +604 -826
- setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
- setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
- setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
- setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
- setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
- setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
- setiastro/saspro/gui/main_window.py +14 -0
- setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
- setiastro/saspro/model_manager.py +324 -0
- setiastro/saspro/model_workers.py +102 -0
- setiastro/saspro/ops/benchmark.py +320 -0
- setiastro/saspro/ops/settings.py +407 -10
- setiastro/saspro/remove_stars.py +424 -442
- setiastro/saspro/resources.py +73 -10
- setiastro/saspro/runtime_torch.py +107 -22
- setiastro/saspro/signature_insert.py +14 -3
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
setiastro/saspro/cosmicclarity_engines/superres_engine.py (new file)

```diff
@@ -0,0 +1,412 @@
+#src/setiastro/saspro/cosmicclarity_engines/superres_engine.py
+from __future__ import annotations
+
+import os
+from typing import Optional, Dict, Any, Tuple
+
+import numpy as np
+import cv2
+cv2.setNumThreads(1)
+try:
+    import onnxruntime as ort
+except Exception:
+    ort = None
+
+def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
+    from setiastro.saspro.runtime_torch import import_torch
+    return import_torch(
+        prefer_cuda=prefer_cuda,
+        prefer_xpu=False,
+        prefer_dml=prefer_dml,
+        status_cb=status_cb,
+    )
+
+
+from setiastro.saspro.resources import get_resources
+
+def _load_torch_superres_model(torch, device, pth_path: str):
+    nn = torch.nn
+
+    class ResidualBlock(nn.Module):
+        def __init__(self, channels: int):
+            super().__init__()
+            self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
+            self.relu = nn.ReLU(inplace=True)
+            self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
+
+        def forward(self, x):
+            residual = x
+            out = self.relu(self.conv1(x))
+            out = self.conv2(out)
+            out = out + residual
+            return self.relu(out)
+
+    class SuperResolutionCNN(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.encoder1 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(16))
+            self.encoder2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(32))
+            self.encoder3 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=2, dilation=2), nn.ReLU(inplace=True), ResidualBlock(64))
+            self.encoder4 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(128))
+            self.encoder5 = nn.Sequential(nn.Conv2d(128, 256, 3, padding=2, dilation=2), nn.ReLU(inplace=True), ResidualBlock(256))
+
+            self.decoder5 = nn.Sequential(nn.Conv2d(256 + 128, 128, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(128))
+            self.decoder4 = nn.Sequential(nn.Conv2d(128 + 64, 64, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(64))
+            self.decoder3 = nn.Sequential(nn.Conv2d(64 + 32, 32, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(32))
+            self.decoder2 = nn.Sequential(nn.Conv2d(32 + 16, 16, 3, padding=1), nn.ReLU(inplace=True), ResidualBlock(16))
+            self.decoder1 = nn.Sequential(nn.Conv2d(16, 3, 3, padding=1), nn.Sigmoid())
+
+        def forward(self, x):
+            e1 = self.encoder1(x)
+            e2 = self.encoder2(e1)
+            e3 = self.encoder3(e2)
+            e4 = self.encoder4(e3)
+            e5 = self.encoder5(e4)
+
+            d5 = self.decoder5(torch.cat([e5, e4], dim=1))
+            d4 = self.decoder4(torch.cat([d5, e3], dim=1))
+            d3 = self.decoder3(torch.cat([d4, e2], dim=1))
+            d2 = self.decoder2(torch.cat([d3, e1], dim=1))
+            return self.decoder1(d2)
+
+    if not os.path.exists(pth_path):
+        raise FileNotFoundError(f"SuperRes model not found: {pth_path}")
+
+    model = SuperResolutionCNN().to(device)
+    sd = torch.load(pth_path, map_location=device)
+    if isinstance(sd, dict) and "model_state_dict" in sd:
+        sd = sd["model_state_dict"]
+    model.load_state_dict(sd)
+    model.eval()
+    return model
+
+
```
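The network is a five-level encoder/decoder with residual blocks and dilated convolutions at levels 3 and 5. Every layer preserves spatial size, so the model refines an already-upscaled tile rather than upscaling it itself. A minimal sketch of that shape contract, assuming `SuperResolutionCNN` and `ResidualBlock` have been factored out of `_load_torch_superres_model` to module scope (in the shipped code they are local classes):

```python
# Sketch only: verify the SuperResolutionCNN shape contract with random data.
import torch

model = SuperResolutionCNN().eval()
x = torch.rand(1, 3, 256, 256)   # one padded 256x256 tile, 3 channels
with torch.no_grad():
    y = model(x)
print(y.shape)                   # torch.Size([1, 3, 256, 256]), same size in and out
```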
```diff
+# ----------------------------
+# Shared helpers (copy from denoise_engine if you want)
+# ----------------------------
+def stretch_image(image: np.ndarray):
+    original_min = float(np.min(image))
+    stretched = image - original_min
+
+    is_single = (image.ndim == 2) or (image.ndim == 3 and image.shape[2] == 1)
+    target_median = 0.25
+
+    if is_single:
+        med = float(np.median(stretched))
+        orig_medians = [med]
+        if med != 0:
+            stretched = ((med - 1) * target_median * stretched) / (med * (target_median + stretched - 1) - target_median * stretched)
+    else:
+        orig_medians = []
+        for c in range(3):
+            med = float(np.median(stretched[..., c]))
+            orig_medians.append(med)
+            if med != 0:
+                stretched[..., c] = ((med - 1) * target_median * stretched[..., c]) / (
+                    med * (target_median + stretched[..., c] - 1) - target_median * stretched[..., c]
+                )
+
+    return np.clip(stretched, 0, 1).astype(np.float32), original_min, orig_medians
+
+
+def unstretch_image(image: np.ndarray, original_medians, original_min: float):
+    is_single = (image.ndim == 2) or (image.ndim == 3 and image.shape[2] == 1)
+    if is_single:
+        med = float(np.median(image))
+        if med != 0 and original_medians[0] != 0:
+            image = ((med - 1) * original_medians[0] * image) / (med * (original_medians[0] + image - 1) - original_medians[0] * image)
+    else:
+        for c in range(3):
+            med = float(np.median(image[..., c]))
+            if med != 0 and original_medians[c] != 0:
+                image[..., c] = ((med - 1) * original_medians[c] * image[..., c]) / (
+                    med * (original_medians[c] + image[..., c] - 1) - original_medians[c] * image[..., c]
+                )
+
+    image = image + original_min
+    return np.clip(image, 0, 1).astype(np.float32)
+
+
```
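`stretch_image` is the familiar midtone-transfer-function remap: it fixes 0 and 1 in place and moves the channel median onto the target median 0.25, and `unstretch_image` runs the same rational map with the roles of the medians swapped to undo it. A quick numeric check, with `mtf` as a hypothetical standalone helper:

```python
# Sketch: the rational remap used by stretch_image fixes 0 and 1 and
# moves the median m onto the target t (0.25 in the engine).
import numpy as np

def mtf(x, m, t=0.25):
    return ((m - 1) * t * x) / (m * (t + x - 1) - t * x)

x = np.array([0.0, 0.05, 1.0])
print(mtf(x, m=0.05))   # approx [0.  0.25  1.]: the median 0.05 lands on 0.25
```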
```diff
+def add_border(image: np.ndarray, border_size: int = 16):
+    if image.ndim == 2:
+        med = float(np.median(image))
+        return np.pad(image, ((border_size, border_size), (border_size, border_size)), mode="constant", constant_values=med)
+    else:
+        meds = np.median(image, axis=(0, 1)).astype(np.float32)
+        chans = []
+        for c in range(image.shape[2]):
+            chans.append(np.pad(image[..., c], ((border_size, border_size), (border_size, border_size)),
+                                mode="constant", constant_values=float(meds[c])))
+        return np.stack(chans, axis=-1)
+
+
+def remove_border(image: np.ndarray, border_size: int):
+    if image.ndim == 2:
+        return image[border_size:-border_size, border_size:-border_size]
+    return image[border_size:-border_size, border_size:-border_size, :]
+
+
+def split_image_into_chunks_with_overlap(image: np.ndarray, chunk_size: int = 256, overlap: int = 64):
+    h, w = image.shape[:2]
+    step = chunk_size - overlap
+    chunks = []
+    for i in range(0, h, step):
+        for j in range(0, w, step):
+            end_i = min(i + chunk_size, h)
+            end_j = min(j + chunk_size, w)
+            patch = image[i:end_i, j:end_j]
+            chunks.append((patch, i, j))
+    return chunks
+
+
+def stitch_chunks_ignore_border(chunks, out_hw: Tuple[int, int], border_size: int = 16):
+    H, W = out_hw
+    stitched = np.zeros((H, W), dtype=np.float32)
+    weight = np.zeros((H, W), dtype=np.float32)
+
+    for patch, i, j in chunks:
+        ph, pw = patch.shape[:2]
+        b_h = min(border_size, ph // 2)
+        b_w = min(border_size, pw // 2)
+
+        inner = patch[b_h:ph - b_h, b_w:pw - b_w]
+        ih, iw = inner.shape[:2]
+
+        stitched[i + b_h:i + b_h + ih, j + b_w:j + b_w + iw] += inner
+        weight[i + b_h:i + b_h + ih, j + b_w:j + b_w + iw] += 1.0
+
+    return stitched / np.maximum(weight, 1.0)
+
+
```
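Tiles are cut 256 px square with 64 px of overlap, each tile's outer 16 px rim is discarded at stitch time, and any remaining overlap is averaged through the weight map. A small round-trip sketch using the two helpers above (the outermost rim of the frame is only covered because the real pipeline pads with `add_border` first):

```python
# Sketch: split + stitch should reproduce the interior of an image exactly,
# since overlapping contributions are identical and get averaged.
import numpy as np

img = np.random.rand(600, 800).astype(np.float32)
chunks = split_image_into_chunks_with_overlap(img, chunk_size=256, overlap=64)
recon = stitch_chunks_ignore_border(chunks, img.shape[:2], border_size=16)
assert np.allclose(recon[16:-16, 16:-16], img[16:-16, 16:-16])
```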
```diff
+# ----------------------------
+# Model loading (cached)
+# ----------------------------
+from typing import Dict, Any, Tuple
+import os
+
+_cached: dict[tuple[str, int, bool], dict[str, Any]] = {}
+_BACKEND_TAG = "cc_superres"
+
+R = get_resources()
+
+def _superres_paths(scale: int) -> tuple[str, str]:
+    if scale == 2:
+        return (R.CC_SUPERRES_2X_PTH, R.CC_SUPERRES_2X_ONNX)
+    if scale == 3:
+        return (R.CC_SUPERRES_3X_PTH, R.CC_SUPERRES_3X_ONNX)
+    if scale == 4:
+        return (R.CC_SUPERRES_4X_PTH, R.CC_SUPERRES_4X_ONNX)
+    raise ValueError("scale must be 2, 3, or 4")
+
+
+def _pick_backend(torch, use_gpu: bool):
+    # Prefer: CUDA (PyTorch) -> DML (ONNX) on Windows -> MPS (PyTorch) -> CPU
+    if use_gpu and hasattr(torch, "cuda") and torch.cuda.is_available():
+        return ("pytorch", torch.device("cuda"))
+    if use_gpu and ort is not None and ("DmlExecutionProvider" in ort.get_available_providers()):
+        return ("onnx", "DirectML")
+    if use_gpu and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return ("pytorch", torch.device("mps"))
+    return ("pytorch", torch.device("cpu"))
+
+
+def load_superres(scale: int, use_gpu: bool = True, status_cb=print) -> Dict[str, Any]:
+    scale = int(scale)
+    if scale not in (2, 3, 4):
+        raise ValueError("scale must be 2, 3, or 4")
+
+    is_windows = (os.name == "nt")
+    torch = _get_torch(prefer_cuda=bool(use_gpu), prefer_dml=bool(use_gpu and is_windows), status_cb=status_cb)
+
+    pth_path, onnx_path = _superres_paths(scale)
+
+    # --- DEBUG (remove later) ---
+    cuda_ok = bool(use_gpu) and hasattr(torch, "cuda") and torch.cuda.is_available()
+    dml_ok = bool(use_gpu) and (ort is not None) and ("DmlExecutionProvider" in ort.get_available_providers())
+
+    # ---------------------------
+
+    # IMPORTANT: key should include the ACTUAL selected backend/device, not just use_gpu
+    # so you can't get stuck reusing CPU from a previous call.
+    # We'll decide backend first, then cache.
+
+    # Prefer torch CUDA if available & allowed (same as sharpen)
+    if cuda_ok:
+        device = torch.device("cuda")
+        status_cb(f"CosmicClarity SuperRes: using CUDA ({torch.cuda.get_device_name(0)})")
+        key = (_BACKEND_TAG, scale, "cuda")
+        if key in _cached:
+            return _cached[key]
+        model = _load_torch_superres_model(torch, device, pth_path)
+        out = {"backend": "pytorch", "device": device, "model": model, "scale": scale, "torch": torch}
+        _cached[key] = out
+        return out
+
+    # Torch-DirectML (Windows)
+    if use_gpu and is_windows:
+        try:
+            import torch_directml
+            dml = torch_directml.device()
+            status_cb("CosmicClarity SuperRes: using DirectML (torch-directml)")
+            key = (_BACKEND_TAG, scale, "torch_dml")
+            if key in _cached:
+                return _cached[key]
+            model = _load_torch_superres_model(torch, dml, pth_path)
+            out = {"backend": "pytorch", "device": dml, "model": model, "scale": scale, "torch": torch}
+            _cached[key] = out
+            return out
+        except Exception:
+            pass
+
+
+    # DirectML ONNX fallback (Windows)
+    if dml_ok:
+        status_cb("CosmicClarity SuperRes: using DirectML (ONNX Runtime)")
+        if not os.path.exists(onnx_path):
+            raise FileNotFoundError(f"SuperRes ONNX model not found: {onnx_path}")
+        key = (_BACKEND_TAG, scale, "dml")
+        if key in _cached:
+            return _cached[key]
+        sess = ort.InferenceSession(onnx_path, providers=["DmlExecutionProvider"])
+        out = {"backend": "onnx", "device": "DirectML", "model": sess, "scale": scale, "torch": None}
+        _cached[key] = out
+        return out
+
+    # MPS (mac)
+    if bool(use_gpu) and hasattr(torch, "backends") and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        device = torch.device("mps")
+        status_cb("CosmicClarity SuperRes: using MPS")
+        key = (_BACKEND_TAG, scale, "mps")
+        if key in _cached:
+            return _cached[key]
+        model = _load_torch_superres_model(torch, device, pth_path)
+        out = {"backend": "pytorch", "device": device, "model": model, "scale": scale, "torch": torch}
+        _cached[key] = out
+        return out
+
+    # CPU
+    device = torch.device("cpu")
+    status_cb("CosmicClarity SuperRes: using CPU")
+    key = (_BACKEND_TAG, scale, "cpu")
+    if key in _cached:
+        return _cached[key]
+    model = _load_torch_superres_model(torch, device, pth_path)
+    out = {"backend": "pytorch", "device": device, "model": model, "scale": scale, "torch": torch}
+    _cached[key] = out
+    return out
+
+
```
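Backend preference is CUDA, then torch-directml, then ONNX Runtime with DirectML, then MPS, then CPU, and each loaded engine is cached under a key naming the backend actually selected (note the `_cached` annotation still says `tuple[str, int, bool]`, but the keys written are really `(tag, scale, backend_str)`). A hedged usage sketch, assuming the SuperRes weight files for the chosen scale are installed:

```python
# Sketch: the engine dict load_superres returns, and the per-backend cache.
engine = load_superres(scale=2, use_gpu=True, status_cb=print)
print(engine["backend"], engine["device"], engine["scale"])
# A repeat call with the same arguments is a cache hit: same dict object back.
assert engine is load_superres(2, use_gpu=True)
```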
```diff
+def _amp_ok(torch, device) -> bool:
+    if not isinstance(device, torch.device) or device.type != "cuda":
+        return False
+    try:
+        props = torch.cuda.get_device_properties(device)
+        return props.major >= 8
+    except Exception:
+        return False
+
+
+def superres_rgb01(
+    img_rgb01: np.ndarray,
+    *,
+    scale: int,
+    use_gpu: bool = True,
+    progress_cb=None,  # progress_cb(done:int, total:int)
+) -> np.ndarray:
+    """
+    Input: float32 RGB in [0..1], shape (H,W,3)
+    Output: float32 RGB in [0..1], shape (H*scale,W*scale,3)
+    """
+    scale = int(scale)
+    if scale not in (2, 3, 4):
+        raise ValueError("scale must be 2, 3, or 4")
+
+    engine = load_superres(scale, use_gpu=use_gpu, status_cb=print)  # or your logger
+
+    # We process each channel independently (matches your current behavior)
+    H, W = img_rgb01.shape[:2]
+    out_chans = []
+
+    # progress accounting: per-channel chunks
+    for c in range(3):
+        chan = img_rgb01[..., c].astype(np.float32, copy=False)
+
+        # border + optional stretch
+        bordered = add_border(chan, border_size=16)
+        if float(np.median(bordered)) < 0.08:
+            stretched, orig_min, orig_meds = stretch_image(bordered)
+            stretched_applied = True
+        else:
+            stretched = bordered.astype(np.float32, copy=False)
+            stretched_applied = False
+            orig_min = float(np.min(bordered))
+            orig_meds = [float(np.median(bordered))]
+
+        # bicubic upscale
+        h, w = stretched.shape[:2]
+        up = cv2.resize(stretched, (w * scale, h * scale), interpolation=cv2.INTER_CUBIC)
+
+        # chunk & infer
+        chunks = split_image_into_chunks_with_overlap(up, chunk_size=256, overlap=64)
+        total = len(chunks)
+        done0 = 0
+
+        processed = []
+        use_amp = (engine["backend"] == "pytorch") and _amp_ok(engine["torch"], engine["device"])
+        dev = engine["device"]
+        dev_type = getattr(dev, "type", None)
+        for idx, (patch, i, j) in enumerate(chunks):
+            ph, pw = patch.shape[:2]
+
+            # build 256x256x3 patch
+            patch_in = np.zeros((256, 256, 3), dtype=np.float32)
+            patch_in[:ph, :pw, 0] = patch[:ph, :pw]
+            patch_in[:ph, :pw, 1] = patch[:ph, :pw]
+            patch_in[:ph, :pw, 2] = patch[:ph, :pw]
+
+            if engine["backend"] == "pytorch":
+                t = engine["torch"]  # torch module from runtime_torch
+
+                pt = t.from_numpy(patch_in.transpose(2, 0, 1)).unsqueeze(0).to(engine["device"])
+
+                with t.no_grad():
+                    if use_amp and dev_type == "cuda":
+                        with t.cuda.amp.autocast():
+                            out = engine["model"](pt)
+                    else:
+                        out = engine["model"](pt)
+                out_np = out[0].detach().cpu().numpy()  # (C,H,W)
+            else:
+                # ONNX (DirectML)
+                inp = np.expand_dims(patch_in.transpose(2, 0, 1), axis=0).astype(np.float32)
+                out_np = engine["model"].run(None, {engine["model"].get_inputs()[0].name: inp})[0].squeeze()
+
+            # output is 3ch grayscale; take first channel
+            if out_np.ndim == 3 and out_np.shape[0] == 3:
+                out_np = out_np[0]
+            elif out_np.ndim == 3 and out_np.shape[-1] == 3:
+                out_np = out_np[..., 0]
+
+            out_np = out_np[:ph, :pw].astype(np.float32, copy=False)
+            processed.append((out_np, i, j))
+
+            done0 += 1
+            if progress_cb is not None:
+                # You can interpret as global progress across all channels:
+                progress_cb((c * total) + done0, 3 * total)
+
+        # stitch
+        stitched = stitch_chunks_ignore_border(processed, up.shape[:2], border_size=16)
+
+        # unstretch if needed
+        if stretched_applied:
+            stitched = unstretch_image(stitched, orig_meds, orig_min)
+
+        # remove scaled border: 16px border became 16*scale after upscaling
+        final_border = int(16 * scale)
+        out_chan = remove_border(stitched, border_size=final_border)
+
+        out_chans.append(out_chan)
+
+    out_rgb = np.stack(out_chans, axis=-1)
+    return np.clip(out_rgb, 0.0, 1.0).astype(np.float32, copy=False)
```
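`superres_rgb01` ties the pieces together per channel: pad with a median border, stretch only if the median is below 0.08, bicubic-upscale, run the network over overlapping 256 px tiles, stitch, unstretch, and crop the border, which is now 16 times `scale` pixels wide. A usage sketch (the weights must already be on disk, presumably fetched via the new model_manager.py also added in this release):

```python
# Sketch: 2x upscale of a float32 RGB image with the new engine.
import numpy as np
from setiastro.saspro.cosmicclarity_engines.superres_engine import superres_rgb01

img = np.random.rand(512, 512, 3).astype(np.float32)   # stand-in for real data
up = superres_rgb01(img, scale=2, use_gpu=True,
                    progress_cb=lambda done, total: print(f"{done}/{total}"))
print(up.shape)   # (1024, 1024, 3)
```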
setiastro/saspro/gui/main_window.py

```diff
@@ -2347,6 +2347,20 @@ class AstroSuiteProMainWindow(
             # Log error but don't crash if preload fails
             print(f"Error preloading settings: {e}")
 
+    def _open_benchmark(self):
+        from setiastro.saspro.ops.benchmark import BenchmarkDialog  # new file below
+
+        if getattr(self, "_bench_dlg_cache", None) is None:
+            self._bench_dlg_cache = BenchmarkDialog(self)
+
+        dlg = self._bench_dlg_cache
+        if hasattr(dlg, "refresh_ui"):
+            dlg.refresh_ui()
+        dlg.show()
+        dlg.raise_()
+        dlg.activateWindow()
+
+
     def _open_settings(self):
         from setiastro.saspro.ops.settings import SettingsDialog
 
```
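The benchmark dialog is built once and reused on later menu invocations, with `refresh_ui()` called defensively before each show. The same lazy-singleton idiom, reduced to its core as a generic sketch (not SAS Pro API):

```python
# Generic sketch of the cached-dialog pattern used by _open_benchmark above:
# construct once, then re-show and re-focus the same instance on later calls.
def open_cached_dialog(owner, attr, factory):
    dlg = getattr(owner, attr, None)
    if dlg is None:
        dlg = factory(owner)     # first open: build the dialog
        setattr(owner, attr, dlg)
    dlg.show()                   # later opens: bring the same one forward
    dlg.raise_()
    dlg.activateWindow()
    return dlg
```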
setiastro/saspro/gui/mixins/menu_mixin.py

```diff
@@ -333,6 +333,8 @@ class MenuMixin:
 
         m_settings = mb.addMenu(self.tr("&Settings"))
         m_settings.addAction(self.tr("Preferences..."), self._open_settings)
+        m_settings.addSeparator()
+        m_settings.addAction(self.tr("Benchmark..."), self._open_benchmark)
 
         m_about = mb.addMenu(self.tr("&About"))
         m_about.addAction(self.act_docs)
```