setiastrosuitepro 1.7.5.post1__py3-none-any.whl → 1.8.0.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of setiastrosuitepro might be problematic. Click here for more details.
- setiastro/saspro/_generated/build_info.py +2 -2
- setiastro/saspro/accel_installer.py +21 -8
- setiastro/saspro/accel_workers.py +11 -12
- setiastro/saspro/comet_stacking.py +113 -85
- setiastro/saspro/cosmicclarity.py +604 -826
- setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
- setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
- setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
- setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
- setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
- setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
- setiastro/saspro/gui/main_window.py +14 -0
- setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
- setiastro/saspro/model_manager.py +324 -0
- setiastro/saspro/model_workers.py +102 -0
- setiastro/saspro/ops/benchmark.py +320 -0
- setiastro/saspro/ops/settings.py +407 -10
- setiastro/saspro/remove_stars.py +424 -442
- setiastro/saspro/resources.py +73 -10
- setiastro/saspro/runtime_torch.py +107 -22
- setiastro/saspro/signature_insert.py +14 -3
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
|
@@ -0,0 +1,587 @@
|
|
|
1
|
+
# src/setiastro/saspro/cosmicclarity_engines/sharpen_engine.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Callable, Optional, Any
|
|
7
|
+
import os
|
|
8
|
+
import numpy as np
|
|
9
|
+
from setiastro.saspro.resources import get_resources
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Optional deps used by auto-PSF
|
|
13
|
+
try:
|
|
14
|
+
import sep
|
|
15
|
+
except Exception:
|
|
16
|
+
sep = None
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
import onnxruntime as ort
|
|
20
|
+
except Exception:
|
|
21
|
+
ort = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Progress callback signature: (done, total, stage_label) -> bool.
# Returning False requests cancellation. The chunk loops in this module test the
# result with ``is False``, so any other return value is treated as "continue".
ProgressCB = Callable[[int, int, str], bool]  # True=continue, False=cancel
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# ---------------- Torch model defs (needed for .pth) ----------------
|
|
28
|
+
def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
|
|
29
|
+
from setiastro.saspro.runtime_torch import import_torch
|
|
30
|
+
return import_torch(
|
|
31
|
+
prefer_cuda=prefer_cuda,
|
|
32
|
+
prefer_xpu=False,
|
|
33
|
+
prefer_dml=prefer_dml,
|
|
34
|
+
status_cb=status_cb,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _nullcontext():
|
|
40
|
+
from contextlib import nullcontext
|
|
41
|
+
return nullcontext()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _autocast_context(torch, device) -> Any:
|
|
45
|
+
"""
|
|
46
|
+
Use new torch.amp.autocast('cuda') when available.
|
|
47
|
+
Keep your cap >= 8.0 rule.
|
|
48
|
+
"""
|
|
49
|
+
try:
|
|
50
|
+
if hasattr(device, "type") and device.type == "cuda":
|
|
51
|
+
major, minor = torch.cuda.get_device_capability()
|
|
52
|
+
cap = float(f"{major}.{minor}")
|
|
53
|
+
if cap >= 8.0:
|
|
54
|
+
# Preferred API (torch >= 1.10-ish; definitely in 2.x)
|
|
55
|
+
if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
|
|
56
|
+
return torch.amp.autocast(device_type="cuda")
|
|
57
|
+
# Fallback for older torch
|
|
58
|
+
return torch.cuda.amp.autocast()
|
|
59
|
+
except Exception:
|
|
60
|
+
pass
|
|
61
|
+
return _nullcontext()
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _to_3ch(image: np.ndarray) -> tuple[np.ndarray, bool]:
|
|
66
|
+
"""Return (img3, was_mono). img3 is HxWx3 float32."""
|
|
67
|
+
if image.ndim == 2:
|
|
68
|
+
img3 = np.stack([image, image, image], axis=-1)
|
|
69
|
+
return img3, True
|
|
70
|
+
if image.ndim == 3 and image.shape[2] == 1:
|
|
71
|
+
img = image[..., 0]
|
|
72
|
+
img3 = np.stack([img, img, img], axis=-1)
|
|
73
|
+
return img3, True
|
|
74
|
+
return image, False
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Your BT.601 luminance extraction / merge
|
|
78
|
+
def extract_luminance_rgb(image_rgb: np.ndarray):
    """Split an RGB image into BT.601 Y, Cb and Cr planes.

    The input is converted to float32 and must have length 3 on its last axis.
    Cb and Cr are offset by +0.5, so a neutral (gray) pixel maps to 0.5.

    Returns:
        ``(y, cb, cr)`` float32 arrays shaped like ``image_rgb[..., 0]``.

    Raises:
        ValueError: if the last axis does not have length 3.
    """
    rgb = np.asarray(image_rgb, dtype=np.float32)
    if rgb.shape[-1] != 3:
        raise ValueError("extract_luminance_rgb expects HxWx3")
    to_ycbcr = np.array(
        [
            [0.299, 0.587, 0.114],
            [-0.168736, -0.331264, 0.5],
            [0.5, -0.418688, -0.081312],
        ],
        dtype=np.float32,
    )
    converted = rgb @ to_ycbcr.T
    luma, chroma_b, chroma_r = (converted[..., k] for k in range(3))
    return luma, chroma_b + 0.5, chroma_r + 0.5
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def ycbcr_to_rgb(y: np.ndarray, cb: np.ndarray, cr: np.ndarray) -> np.ndarray:
    """Rebuild RGB from Y and 0.5-offset Cb/Cr planes (inverse BT.601).

    The three planes are stacked on a new last axis and the result is clipped
    to [0, 1].
    """
    from_ycbcr = np.array(
        [
            [1.0, 0.0, 1.402],
            [1.0, -0.344136, -0.714136],
            [1.0, 1.772, 0.0],
        ],
        dtype=np.float32,
    )
    planes = [
        np.asarray(y, np.float32),
        np.asarray(cb, np.float32) - 0.5,
        np.asarray(cr, np.float32) - 0.5,
    ]
    rgb = np.stack(planes, axis=-1) @ from_ycbcr.T
    return np.clip(rgb, 0.0, 1.0)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def merge_luminance(y: np.ndarray, cb: np.ndarray, cr: np.ndarray) -> np.ndarray:
    """Clip the Y/Cb/Cr planes to [0, 1], then convert back to RGB with ``ycbcr_to_rgb``."""
    clipped = [np.clip(plane, 0, 1) for plane in (y, cb, cr)]
    return ycbcr_to_rgb(*clipped)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# ---------------- Chunking & stitching (your exact behavior) ----------------
|
|
109
|
+
|
|
110
|
+
def split_image_into_chunks_with_overlap(image2d: np.ndarray, chunk_size: int, overlap: int):
    """Tile a 2D plane into overlapping square chunks.

    Tiles start every ``chunk_size - overlap`` pixels along each axis. Tiles at
    the bottom/right edges are clipped to the image, so they may be smaller than
    ``chunk_size``. Each chunk is a basic slice (a view) of ``image2d``.

    Args:
        image2d: 2D array ``(H, W)``.
        chunk_size: nominal tile edge length in pixels.
        overlap: pixels shared by neighbouring tiles; must be < ``chunk_size``.

    Returns:
        List of ``(chunk, i, j, is_edge)`` tuples. ``(i, j)`` is the chunk's
        top-left corner. ``is_edge`` is True when the chunk starts on row/column 0
        or its nominal extent reaches or passes the bottom/right boundary.

    Raises:
        ValueError: if ``overlap >= chunk_size``. A zero stride used to raise an
            opaque ``range()`` error, and a negative stride silently produced an
            empty list (which stitches to an all-zero plane).
    """
    H, W = image2d.shape
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )
    out = []
    for i in range(0, H, step):
        ei = min(i + chunk_size, H)
        for j in range(0, W, step):
            ej = min(j + chunk_size, W)
            is_edge = (i == 0) or (j == 0) or (i + chunk_size >= H) or (j + chunk_size >= W)
            out.append((image2d[i:ei, j:ej], i, j, is_edge))
    return out
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def stitch_chunks_ignore_border(chunks, image_shape, border_size: int = 16):
    """Reassemble processed chunks, discarding a margin on every chunk side.

    For each ``(chunk, i, j, _)`` only the interior (trimmed by
    ``min(border_size, side // 2)`` on every side) is accumulated at its position.
    Overlapping interiors are averaged. Pixels that no interior covers stay 0.

    Returns:
        float32 array of shape *image_shape*.
    """
    acc = np.zeros(image_shape, dtype=np.float32)
    hits = np.zeros(image_shape, dtype=np.float32)

    for tile, row0, col0, _ in chunks:
        th, tw = tile.shape
        pad_r = min(border_size, th // 2)
        pad_c = min(border_size, tw // 2)
        dst = (
            slice(row0 + pad_r, row0 + th - pad_r),
            slice(col0 + pad_c, col0 + tw - pad_c),
        )
        acc[dst] += tile[pad_r:th - pad_r, pad_c:tw - pad_c]
        hits[dst] += 1.0

    acc /= np.maximum(hits, 1.0)
    return acc
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def add_border(image: np.ndarray, border_size: int = 16) -> np.ndarray:
    """Pad the first two (spatial) axes by *border_size* pixels on every side.

    The pad value is the median of the whole image. For a non-2D input, a third
    (channel) axis is added to the pad spec with zero padding.
    """
    fill = float(np.median(image))
    b = border_size
    widths = ((b, b), (b, b))
    if image.ndim != 2:
        widths = widths + ((0, 0),)
    return np.pad(image, widths, mode="constant", constant_values=fill)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def remove_border(image: np.ndarray, border_size: int = 16) -> np.ndarray:
    """Crop *border_size* pixels from every side of the first two axes.

    This is the inverse of ``add_border``. Any trailing channel axis is kept
    whole, and the result is a view (basic slicing).

    The end indices are computed explicitly (``h - b``) rather than written as
    ``-border_size``. With ``border_size == 0``, ``image[0:-0]`` is an EMPTY
    slice; the explicit form returns the full image instead.
    """
    b = border_size
    h, w = image.shape[:2]
    return image[b:h - b, b:w - b]
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def blend_images(before: np.ndarray, after: np.ndarray, amount: float) -> np.ndarray:
    """Linearly mix two images: ``(1 - a) * before + a * after``.

    ``a`` is *amount* clipped to [0, 1]: 0 returns *before*, 1 returns *after*.
    """
    weight = float(np.clip(amount, 0.0, 1.0))
    keep = 1.0 - weight
    return keep * before + weight * after
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# ---------------- Stretch / unstretch (your current logic) ----------------
|
|
161
|
+
|
|
162
|
+
def stretch_image_unlinked_rgb(image_rgb: np.ndarray, target_median: float = 0.25):
    """Apply an unlinked (per-channel) midtones stretch to an RGB image.

    The global minimum is subtracted from all channels. Each channel with a
    non-zero median ``m`` is then remapped with the midtones transfer function
    that sends ``m`` to *target_median*. The result is clipped to [0, 1]. The
    input is not modified.

    Returns:
        ``(stretched, orig_min, orig_meds)``: the float32 stretched copy, the
        subtracted global minimum, and the per-channel medians measured after the
        minimum was removed (needed by ``unstretch_image_unlinked_rgb``).
    """
    t = target_median
    work = image_rgb.astype(np.float32, copy=True)
    orig_min = float(np.min(work))
    work -= orig_min
    orig_meds = [float(np.median(work[..., ch])) for ch in range(3)]

    for ch, med in enumerate(orig_meds):
        if med == 0:
            continue
        plane = work[..., ch]
        work[..., ch] = ((med - 1) * t * plane) / (
            med * (t + plane - 1) - t * plane
        )

    return np.clip(work, 0, 1), orig_min, orig_meds
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def unstretch_image_unlinked_rgb(image_rgb: np.ndarray, orig_meds, orig_min: float, was_mono: bool):
    """Undo ``stretch_image_unlinked_rgb`` on a (processed) RGB image.

    Each channel whose current median and recorded median ``orig_meds[c]`` are
    both non-zero is remapped with the midtones transfer function, using the
    current median as the midpoint and the recorded median as the target.
    ``orig_min`` is then added back and the result is clipped to [0, 1]. For
    mono sources the three channels are averaged into an ``H x W x 1`` array.
    """
    out = image_rgb.astype(np.float32, copy=True)

    for ch in range(3):
        now = float(np.median(out[..., ch]))
        target = float(orig_meds[ch])
        if now == 0 or target == 0:
            continue
        plane = out[..., ch]
        out[..., ch] = ((now - 1) * target * plane) / (
            now * (target + plane - 1) - target * plane
        )

    out += float(orig_min)
    out = np.clip(out, 0, 1)
    # Mono sources come back as H x W x 1 (channel mean, dims kept).
    return np.mean(out, axis=2, keepdims=True) if was_mono else out
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# ---------------- Auto PSF per chunk (SEP) ----------------
|
|
196
|
+
|
|
197
|
+
def measure_psf_fwhm(chunk2d: np.ndarray, default_fwhm: float = 3.0) -> float:
    """Estimate a per-chunk blur radius from SEP source detection.

    Steps:
    1. Subtract the SEP background.
    2. Extract sources at 1.5 x the background RMS and keep those with at least
       5 pixels.
    3. Convert each ``sqrt(a * b)`` to a Gaussian FWHM (x ``2*sqrt(2*ln 2)``).
    4. Return half of the median FWHM.

    Falls back to *default_fwhm* when SEP is not installed, no usable source is
    found, or SEP raises.

    The chunk is first copied into a C-contiguous, native-order float32 buffer.
    Chunks from ``split_image_into_chunks_with_overlap`` are strided views (and
    the luminance plane is itself a strided view), and SEP rejects
    non-C-contiguous arrays. Previously that rejection was swallowed by the
    ``except`` below, so auto-PSF silently returned *default_fwhm* for
    practically every chunk.
    """
    if sep is None:
        return default_fwhm
    try:
        data = np.ascontiguousarray(chunk2d, dtype=np.float32)
        bkg = sep.Background(data)
        sub = data - bkg.back()
        rms = bkg.rms()
        if rms.size == 0:
            return default_fwhm
        objs = sep.extract(sub, 1.5, err=rms)
        fwhms = [
            float(np.sqrt(o["a"] * o["b"])) * 2.0 * np.sqrt(2.0 * np.log(2.0))
            for o in objs
            if o["npix"] >= 5  # ignore tiny detections (likely noise)
        ]
        return float(np.median(fwhms) * 0.5) if fwhms else default_fwhm
    except Exception:
        # Deliberately best-effort: any SEP failure falls back to the default.
        return default_fwhm
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
# ---------------- Model bundle + loading ----------------
|
|
221
|
+
|
|
222
|
+
@dataclass
class SharpenModels:
    """Loaded sharpening networks plus the backend they run on.

    Built by ``_load_torch_models`` (``is_onnx=False``) or ``_load_onnx_models``
    (``is_onnx=True``) and consumed by ``_infer_chunk``.
    """

    device: Any  # torch device (CUDA/CPU/torch-directml) on the torch path; the string "DirectML" on the ONNX path
    is_onnx: bool  # True -> model fields are onnxruntime InferenceSessions; False -> torch modules
    stellar: Any  # stellar sharpening model (stage 1)
    ns1: Any  # non-stellar model used for radius 1.0
    ns2: Any  # non-stellar model used for radius 2.0
    ns4: Any  # non-stellar model used for radius 4.0
    ns8: Any  # non-stellar model used for radius 8.0
    torch: Any | None = None  # the torch module on the torch path; None for ONNX
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
# Module-level cache of loaded model bundles keyed by (backend_tag, use_gpu).
# load_sharpen_models fills it so repeated calls reuse already-loaded networks.
# Nothing in this module evicts entries, and no lock guards access.
_MODELS_CACHE: dict[tuple[str, bool], SharpenModels] = {}  # (backend_tag, use_gpu)
|
|
235
|
+
|
|
236
|
+
def load_sharpen_models(use_gpu: bool, status_cb=print) -> SharpenModels:
    """Resolve a compute backend and return the (cached) sharpening models.

    Backends are tried in this order:

    1. CUDA via torch (``use_gpu`` and ``torch.cuda.is_available()``).
    2. torch-directml (Windows only, ``use_gpu``).
    3. ONNX Runtime with ``DmlExecutionProvider`` (``use_gpu`` and the provider
       is listed by ``ort.get_available_providers()``).
    4. torch on CPU.

    The bundle is stored in ``_MODELS_CACHE`` under ``(backend_tag,
    bool(use_gpu))`` and returned directly on later calls.

    Args:
        use_gpu: whether GPU backends may be tried at all.
        status_cb: receives human-readable status messages.

    Returns:
        A ``SharpenModels`` bundle.
    """
    backend_tag = "cc_sharpen_ai3_5s"
    key = (backend_tag, bool(use_gpu))
    if key in _MODELS_CACHE:
        return _MODELS_CACHE[key]

    is_windows = (os.name == "nt")

    # Ask the runtime to prefer DML only on Windows and only when the GPU is
    # wanted; CUDA is still tried first whenever use_gpu is set.
    torch = _get_torch(
        prefer_cuda=bool(use_gpu),
        prefer_dml=bool(use_gpu and is_windows),
        status_cb=status_cb,
    )

    # 1) CUDA
    if use_gpu and hasattr(torch, "cuda") and torch.cuda.is_available():
        device = torch.device("cuda")
        status_cb(f"CosmicClarity Sharpen: using CUDA ({torch.cuda.get_device_name(0)})")
        models = _load_torch_models(torch, device)
        _MODELS_CACHE[key] = models
        return models

    # 2) Torch-DirectML (Windows). Best-effort: any failure (package missing,
    #    device creation or weight loading) falls through to the next backend.
    #    The reason is reported instead of being silently discarded.
    if use_gpu and is_windows:
        try:
            import torch_directml  # provided by torch-directml
            dml = torch_directml.device()
            status_cb("CosmicClarity Sharpen: using DirectML (torch-directml)")
            models = _load_torch_models(torch, dml)
            _MODELS_CACHE[key] = models
            status_cb(f"torch.__version__={getattr(torch,'__version__',None)}; "
                      f"cuda_available={bool(getattr(torch,'cuda',None) and torch.cuda.is_available())}")
            return models
        except Exception as exc:
            status_cb(f"CosmicClarity Sharpen: torch-directml unavailable ({exc}); trying next backend")

    # 3) ONNX Runtime DirectML fallback
    if use_gpu and ort is not None and "DmlExecutionProvider" in ort.get_available_providers():
        status_cb("CosmicClarity Sharpen: using DirectML (ONNX Runtime)")
        models = _load_onnx_models()
        _MODELS_CACHE[key] = models
        return models

    # 4) CPU
    device = torch.device("cpu")
    status_cb("CosmicClarity Sharpen: using CPU")
    models = _load_torch_models(torch, device)
    _MODELS_CACHE[key] = models
    status_cb(f"Sharpen backend resolved: "
              f"{'onnx' if models.is_onnx else 'torch'} / device={models.device!r}")
    return models
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def _load_onnx_models() -> SharpenModels:
    """Create DirectML ONNX Runtime sessions for every sharpening model.

    Model paths come from the ``CC_*_ONNX`` entries of ``get_resources()``.

    Raises:
        RuntimeError: if onnxruntime could not be imported. This used to be an
            ``assert``, which is stripped under ``python -O`` and would then fail
            later with an AttributeError on ``None``.
    """
    if ort is None:
        raise RuntimeError("onnxruntime is not available; cannot load ONNX sharpen models")
    providers = ["DmlExecutionProvider"]
    R = get_resources()

    def _session(path: str):
        # One InferenceSession per model file, all bound to the DML provider.
        return ort.InferenceSession(path, providers=providers)

    return SharpenModels(
        device="DirectML",
        is_onnx=True,
        stellar=_session(R.CC_STELLAR_SHARP_ONNX),
        ns1=_session(R.CC_NS1_ONNX),
        ns2=_session(R.CC_NS2_ONNX),
        ns4=_session(R.CC_NS4_ONNX),
        ns8=_session(R.CC_NS8_ONNX),
        torch=None,
    )
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def _load_torch_models(torch, device) -> SharpenModels:
    """Build the torch sharpening networks and load their ``.pth`` weights.

    The network classes are defined locally so that ``forward`` uses the
    *torch* module passed in (``torch.cat``). One network is created per
    ``CC_*_PTH`` entry of ``get_resources()``; its state dict is loaded with
    ``map_location=device``, and the network is moved to *device* and set to
    eval mode.

    Args:
        torch: the torch module returned by ``_get_torch``.
        device: target device (``torch.device("cuda")``/``("cpu")`` or a
            torch-directml device, per ``load_sharpen_models``).

    Returns:
        ``SharpenModels`` with ``is_onnx=False`` and ``torch`` set.
    """
    # NOTE(review): this resolves ``torch.nn`` through the import system rather
    # than from the *torch* argument; presumably the same module once the
    # runtime has imported torch -- confirm.
    import torch.nn as nn  # comes from runtime torch env

    class ResidualBlock(nn.Module):
        # conv3x3 -> ReLU -> conv3x3, add the input back, then ReLU.
        # Attribute names (conv1/relu/conv2) must match the checkpoint keys:
        # load_state_dict below runs with its default strict=True.
        def __init__(self, channels):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        def forward(self, x):
            r = x
            x = self.relu(self.conv1(x))
            x = self.conv2(x)
            x = self.relu(x + r)
            return x

    class SharpeningCNN(nn.Module):
        # Encoder/decoder with skip concatenations and no down-sampling: every
        # conv keeps H x W (padding=1, or padding=2 with dilation=2), so the
        # output has the input's spatial size. 3 channels in, 3 channels out in
        # (0, 1) via Sigmoid. Layer names must match the .pth state dicts.
        def __init__(self):
            super().__init__()
            self.encoder1 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), ResidualBlock(16))
            self.encoder2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), ResidualBlock(32))
            self.encoder3 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=2, dilation=2), nn.ReLU(), ResidualBlock(64))
            self.encoder4 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), ResidualBlock(128))
            self.encoder5 = nn.Sequential(nn.Conv2d(128, 256, 3, padding=2, dilation=2), nn.ReLU(), ResidualBlock(256))

            # Decoder input channels = previous stage + concatenated encoder skip.
            self.decoder5 = nn.Sequential(nn.Conv2d(256 + 128, 128, 3, padding=1), nn.ReLU(), ResidualBlock(128))
            self.decoder4 = nn.Sequential(nn.Conv2d(128 + 64, 64, 3, padding=1), nn.ReLU(), ResidualBlock(64))
            self.decoder3 = nn.Sequential(nn.Conv2d( 64 + 32, 32, 3, padding=1), nn.ReLU(), ResidualBlock(32))
            self.decoder2 = nn.Sequential(nn.Conv2d( 32 + 16, 16, 3, padding=1), nn.ReLU(), ResidualBlock(16))
            self.decoder1 = nn.Sequential(nn.Conv2d(16, 3, 3, padding=1), nn.Sigmoid())

        def forward(self, x):
            e1 = self.encoder1(x)
            e2 = self.encoder2(e1)
            e3 = self.encoder3(e2)
            e4 = self.encoder4(e3)
            e5 = self.encoder5(e4)
            # Skip connections: concatenate along the channel axis (dim=1).
            d5 = self.decoder5(torch.cat([e5, e4], dim=1))
            d4 = self.decoder4(torch.cat([d5, e3], dim=1))
            d3 = self.decoder3(torch.cat([d4, e2], dim=1))
            d2 = self.decoder2(torch.cat([d3, e1], dim=1))
            return self.decoder1(d2)

    R = get_resources()

    def m(path: str):
        # NOTE(review): torch.load unpickles the file; only load trusted
        # (bundled) checkpoints.
        net = SharpeningCNN()
        net.load_state_dict(torch.load(path, map_location=device))
        net.eval().to(device)
        return net

    return SharpenModels(
        device=device,
        is_onnx=False,
        stellar=m(R.CC_STELLAR_SHARP_PTH),
        ns1=m(R.CC_NS1_PTH),
        ns2=m(R.CC_NS2_PTH),
        ns4=m(R.CC_NS4_PTH),
        ns8=m(R.CC_NS8_PTH),
        torch=torch,
    )
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
# ---------------- Inference helpers ----------------
|
|
378
|
+
|
|
379
|
+
def _infer_chunk(models: SharpenModels, model: Any, chunk2d: np.ndarray) -> np.ndarray:
    """Run one sharpening model on a single 2D chunk.

    The plane is replicated to 3 channels (the networks take RGB input). Channel
    0 of the output is returned, cropped to the chunk's shape.

    - ONNX path: the input is zero-padded to 256x256 (the chunk size used by
      ``_sharpen_plane``) before the session runs.
    - Torch path: the chunk runs at its own size inside ``_autocast_context``.

    Returns:
        2D float32 array with the same shape as *chunk2d*.
    """
    h0, w0 = chunk2d.shape

    if models.is_onnx:
        inp = chunk2d[np.newaxis, np.newaxis, :, :].astype(np.float32)  # (1,1,H,W)
        inp = np.tile(inp, (1, 3, 1, 1))  # (1,3,H,W)
        h, w = inp.shape[2:]
        if (h != 256) or (w != 256):
            pad = np.zeros((1, 3, 256, 256), dtype=np.float32)
            pad[:, :, :h, :w] = inp
            inp = pad
        name_in = model.get_inputs()[0].name
        name_out = model.get_outputs()[0].name
        out = model.run([name_out], {name_in: inp})[0][0, 0]
        return out[:h0, :w0].astype(np.float32, copy=False)

    # torch path
    torch = models.torch
    dev = models.device
    t = torch.tensor(chunk2d, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(dev)
    with torch.no_grad(), _autocast_context(torch, dev):
        # Select batch 0 / channel 0 explicitly. The previous ``.squeeze()[0]``
        # also collapsed spatial axes of size 1, so a 1-px-high or 1-px-wide edge
        # chunk came back 1D and the crop below raised IndexError.
        y = model(t.repeat(1, 3, 1, 1))[0, 0].detach().cpu().numpy()
    return y[:h0, :w0].astype(np.float32, copy=False)
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
# ---------------- Main API ----------------
|
|
406
|
+
|
|
407
|
+
@dataclass
class SharpenParams:
    """User-facing options consumed by ``sharpen_image_array`` / ``_sharpen_plane``."""

    mode: str  # "Both" | "Stellar Only" | "Non-Stellar Only"
    stellar_amount: float  # 0..1 blend of stellar-model output with its input (clipped by blend_images)
    nonstellar_amount: float  # 0..1 blend of non-stellar-model output with its input
    nonstellar_strength: float  # 1..8 radius (clipped to that range; ignored if auto_detect_psf True)
    sharpen_channels_separately: bool  # colour input only: sharpen R/G/B independently instead of luminance
    auto_detect_psf: bool  # estimate the non-stellar radius per chunk with SEP
    use_gpu: bool  # allow GPU backends in load_sharpen_models
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
def sharpen_image_array(image: np.ndarray,
                        params: SharpenParams,
                        progress_cb: Optional[ProgressCB] = None,
                        status_cb=print) -> tuple[np.ndarray, bool]:
    """Pure in-memory sharpen. Returns ``(out_image, was_mono)``.

    Pipeline:
    1. Promote mono input to 3 channels and clip to [0, 1].
    2. Load (or reuse) the models.
    3. Pad a 16 px median border.
    4. Apply an unlinked stretch when the image is dark (median above its
       minimum < 0.08).
    5. Sharpen either each RGB channel separately (colour input with
       ``params.sharpen_channels_separately``) or only the BT.601 luminance.
    6. Unstretch, remove the border and, for mono input, average back to
       ``H x W x 1``.

    The output is clipped to [0, 1].

    Cancellation: *progress_cb* returning ``False`` stops the run, and the
    partially processed image is returned. Previously, in per-channel mode a
    cancel only stopped the current channel and the remaining channels were
    still fully sharpened.
    """
    if progress_cb is None:
        progress_cb = lambda done, total, stage: True

    cancelled = False

    def _progress(done, total, stage):
        # Forward to the caller's callback and remember an explicit cancel.
        nonlocal cancelled
        keep_going = progress_cb(done, total, stage)
        if keep_going is False:
            cancelled = True
        return keep_going

    img = np.asarray(image)
    if img.dtype != np.float32:
        img = img.astype(np.float32, copy=False)

    img3, was_mono = _to_3ch(img)
    img3 = np.clip(img3, 0.0, 1.0)

    models = load_sharpen_models(use_gpu=params.use_gpu, status_cb=status_cb)

    # border & stretch
    bordered = add_border(img3, border_size=16)
    stretch_needed = (np.median(bordered - np.min(bordered)) < 0.08)

    if stretch_needed:
        stretched, orig_min, orig_meds = stretch_image_unlinked_rgb(bordered)
    else:
        stretched, orig_min, orig_meds = bordered, None, None

    if params.sharpen_channels_separately and (not was_mono):
        # Start from a copy of the input so channels skipped after a cancel keep
        # their input values rather than uninitialised memory.
        sharpened = stretched.copy()
        for c, label in enumerate(("R", "G", "B")):
            if _progress(0, 1, f"Sharpening {label} channel") is False:
                break
            sharpened[..., c] = _sharpen_plane(models, stretched[..., c], params, _progress)
            if cancelled:
                break
    else:
        # luminance pipeline (works for mono too, since mono is in all 3 chans)
        y, cb, cr = extract_luminance_rgb(stretched)
        y2 = _sharpen_plane(models, y, params, _progress)
        sharpened = merge_luminance(y2, cb, cr)

    # unstretch / deborder
    if stretch_needed:
        sharpened = unstretch_image_unlinked_rgb(sharpened, orig_meds, orig_min, was_mono)

    sharpened = remove_border(sharpened, border_size=16)

    # Mono input comes back as HxWx1. The unstretch step may already have done
    # this; otherwise average the 3 identical-ish channels here.
    if was_mono and sharpened.ndim == 3 and sharpened.shape[2] == 3:
        sharpened = np.mean(sharpened, axis=2, keepdims=True).astype(np.float32, copy=False)

    return np.clip(sharpened, 0.0, 1.0), was_mono
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def _sharpen_plane(models: SharpenModels,
                   plane: np.ndarray,
                   params: SharpenParams,
                   progress_cb: ProgressCB) -> np.ndarray:
    """Sharpen one 2D plane with the two-stage (stellar, then non-stellar) pipeline.

    The plane is tiled into 256 px chunks with 64 px overlap.

    - Stage 1 (modes "Stellar Only"/"Both") runs the stellar model on every
      chunk.
    - Stage 2 (modes "Non-Stellar Only"/"Both") runs a non-stellar model chosen
      by radius: ``params.nonstellar_strength``, or a per-chunk SEP estimate
      when ``auto_detect_psf`` is set, clipped to [1, 8]. A radius between the
      trained 1/2/4/8 models linearly mixes the outputs of the two bracketing
      models; an exact trained radius runs a single model.

    Each stage blends model output with its input by the stage's ``*_amount``,
    then stitches chunks back while ignoring a 16 px margin per chunk.

    *progress_cb* is called once per chunk per stage. If it returns ``False``,
    the plane as it stood before the current stage is returned.
    """
    plane = np.asarray(plane, np.float32)
    chunk_size, overlap, border = 256, 64, 16

    # Stage 1: stellar
    if params.mode in ("Stellar Only", "Both"):
        chunks = split_image_into_chunks_with_overlap(plane, chunk_size=chunk_size, overlap=overlap)
        total = len(chunks)
        out_chunks = []
        for k, (chunk, i, j, is_edge) in enumerate(chunks, start=1):
            y = _infer_chunk(models, models.stellar, chunk)
            out_chunks.append((blend_images(chunk, y, params.stellar_amount), i, j, is_edge))
            if progress_cb(k, total, "Stellar sharpening") is False:
                return plane
        plane = stitch_chunks_ignore_border(out_chunks, plane.shape, border_size=border)

        if params.mode == "Stellar Only":
            return plane

    # Stage 2: non-stellar
    if params.mode in ("Non-Stellar Only", "Both"):
        # Re-tile: stage 1 (if it ran) changed the plane.
        chunks = split_image_into_chunks_with_overlap(plane, chunk_size=chunk_size, overlap=overlap)
        total = len(chunks)
        radii = np.array([1.0, 2.0, 4.0, 8.0], dtype=float)
        model_map = {1.0: models.ns1, 2.0: models.ns2, 4.0: models.ns4, 8.0: models.ns8}
        out_chunks = []

        for k, (chunk, i, j, is_edge) in enumerate(chunks, start=1):
            if params.auto_detect_psf:
                fwhm = measure_psf_fwhm(chunk, default_fwhm=3.0)
                r = float(np.clip(fwhm, radii[0], radii[-1]))
            else:
                r = float(np.clip(params.nonstellar_strength, radii[0], radii[-1]))

            # Bracketing trained radii: radii[idx-1] < r <= radii[idx].
            idx = int(np.searchsorted(radii, r, side="left"))
            if idx <= 0:
                lo = hi = radii[0]
            elif idx >= len(radii):
                lo = hi = radii[-1]
            else:
                lo, hi = radii[idx - 1], radii[idx]

            if lo == hi or r == hi:
                # A single trained model covers r exactly (the blend weight
                # would be 1.0), so one inference suffices.
                y = _infer_chunk(models, model_map[hi], chunk)
            else:
                w = (r - lo) / (hi - lo)
                y0 = _infer_chunk(models, model_map[lo], chunk)
                y1 = _infer_chunk(models, model_map[hi], chunk)
                y = (1.0 - w) * y0 + w * y1

            out_chunks.append((blend_images(chunk, y, params.nonstellar_amount), i, j, is_edge))
            # Exactly one progress report per chunk (a duplicate call whose
            # result was ignored used to follow this check).
            if progress_cb(k, total, "Non-stellar sharpening") is False:
                return plane

        plane = stitch_chunks_ignore_border(out_chunks, plane.shape, border_size=border)

    return plane
|
|
541
|
+
|
|
542
|
+
def sharpen_rgb01(
    image_rgb01: np.ndarray,
    *,
    sharpening_mode: str = "Both",
    stellar_amount: float = 0.5,
    nonstellar_amount: float = 0.5,
    nonstellar_strength: float = 3.0,
    auto_detect_psf: bool = True,
    separate_channels: bool = False,
    use_gpu: bool = True,
    progress_cb: Optional[Callable[[int, int], bool]] = None,
    status_cb=print,
) -> np.ndarray:
    """
    Backward-compatible API for SASpro CosmicClarityDialogPro.
    Expects/returns float32 RGB in [0,1]. If input is mono, returns HxWx1.
    progress_cb signature: (done, total) -> UI-friendly.

    Progress adapter semantics: only an explicit falsy return from *progress_cb*
    cancels. A callback that returns nothing (``None``) means "keep going", and
    a callback that raises is ignored. Previously ``bool(None)`` turned a
    no-return UI callback into a cancel after the very first chunk.
    """
    # Adapt UI progress_cb(done,total) -> engine progress_cb(done,total,stage)
    if progress_cb is None:
        def _prog(done, total, stage):  # noqa
            return True
    else:
        def _prog(done, total, stage):
            try:
                answer = progress_cb(int(done), int(total))
            except Exception:
                # A failing UI callback must not abort the sharpen run.
                return True
            return True if answer is None else bool(answer)

    params = SharpenParams(
        mode=str(sharpening_mode),
        stellar_amount=float(stellar_amount),
        nonstellar_amount=float(nonstellar_amount),
        nonstellar_strength=float(nonstellar_strength),
        sharpen_channels_separately=bool(separate_channels),
        auto_detect_psf=bool(auto_detect_psf),
        use_gpu=bool(use_gpu),
    )

    out, _was_mono = sharpen_image_array(
        image_rgb01,
        params=params,
        progress_cb=_prog,
        status_cb=status_cb,
    )
    return np.asarray(out, dtype=np.float32)
|