setiastrosuitepro 1.7.5.post1-py3-none-any.whl → 1.8.0.post3-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of setiastrosuitepro might be problematic.
- setiastro/saspro/_generated/build_info.py +2 -2
- setiastro/saspro/accel_installer.py +21 -8
- setiastro/saspro/accel_workers.py +11 -12
- setiastro/saspro/comet_stacking.py +113 -85
- setiastro/saspro/cosmicclarity.py +604 -826
- setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
- setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
- setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
- setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
- setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
- setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
- setiastro/saspro/gui/main_window.py +14 -0
- setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
- setiastro/saspro/model_manager.py +324 -0
- setiastro/saspro/model_workers.py +102 -0
- setiastro/saspro/ops/benchmark.py +320 -0
- setiastro/saspro/ops/settings.py +407 -10
- setiastro/saspro/remove_stars.py +424 -442
- setiastro/saspro/resources.py +73 -10
- setiastro/saspro/runtime_torch.py +107 -22
- setiastro/saspro/signature_insert.py +14 -3
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
@@ -0,0 +1,567 @@
# src/setiastro/saspro/cosmicclarity_engines/denoise_engine.py
from __future__ import annotations

import os
import warnings
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple

import numpy as np

import cv2


from setiastro.saspro.resources import get_resources

warnings.filterwarnings("ignore")

from typing import Callable

ProgressCB = Callable[[int, int], None]  # (done, total)

try:
    import onnxruntime as ort
except Exception:
    ort = None


def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
    from setiastro.saspro.runtime_torch import import_torch
    return import_torch(
        prefer_cuda=prefer_cuda,
        prefer_xpu=False,
        prefer_dml=prefer_dml,
        status_cb=status_cb,
    )


def _nullcontext():
    from contextlib import nullcontext
    return nullcontext()


def _autocast_context(torch, device) -> Any:
    """
    Use new torch.amp.autocast('cuda') when available.
    Keep your cap >= 8.0 rule.
    """
    try:
        if hasattr(device, "type") and device.type == "cuda":
            major, minor = torch.cuda.get_device_capability()
            cap = float(f"{major}.{minor}")
            if cap >= 8.0:
                # Preferred API (torch >= 1.10-ish; definitely in 2.x)
                if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
                    return torch.amp.autocast(device_type="cuda")
                # Fallback for older torch
                return torch.cuda.amp.autocast()
    except Exception:
        pass
    return _nullcontext()


# ----------------------------
# Model definitions (unchanged)
# ----------------------------
def _load_torch_model(torch, device, ckpt_path: str):
    nn = torch.nn

    class ResidualBlock(nn.Module):
        def __init__(self, channels: int):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def forward(self, x):
            residual = x
            out = self.relu(self.conv1(x))
            out = self.conv2(out)
            out = self.relu(out + residual)
            return out

    class DenoiseCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder1 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), ResidualBlock(16))
            self.encoder2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), ResidualBlock(32))
            self.encoder3 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=2, dilation=2), nn.ReLU(), ResidualBlock(64))
            self.encoder4 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), ResidualBlock(128))
            self.encoder5 = nn.Sequential(nn.Conv2d(128, 256, 3, padding=2, dilation=2), nn.ReLU(), ResidualBlock(256))

            self.decoder5 = nn.Sequential(nn.Conv2d(256 + 128, 128, 3, padding=1), nn.ReLU(), ResidualBlock(128))
            self.decoder4 = nn.Sequential(nn.Conv2d(128 + 64, 64, 3, padding=1), nn.ReLU(), ResidualBlock(64))
            self.decoder3 = nn.Sequential(nn.Conv2d( 64 + 32, 32, 3, padding=1), nn.ReLU(), ResidualBlock(32))
            self.decoder2 = nn.Sequential(nn.Conv2d( 32 + 16, 16, 3, padding=1), nn.ReLU(), ResidualBlock(16))
            self.decoder1 = nn.Sequential(nn.Conv2d(16, 3, 3, padding=1), nn.Sigmoid())

        def forward(self, x):
            e1 = self.encoder1(x)
            e2 = self.encoder2(e1)
            e3 = self.encoder3(e2)
            e4 = self.encoder4(e3)
            e5 = self.encoder5(e4)

            d5 = self.decoder5(torch.cat([e5, e4], dim=1))
            d4 = self.decoder4(torch.cat([d5, e3], dim=1))
            d3 = self.decoder3(torch.cat([d4, e2], dim=1))
            d2 = self.decoder2(torch.cat([d3, e1], dim=1))
            return self.decoder1(d2)

    net = DenoiseCNN().to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(ckpt.get("model_state_dict", ckpt))
    net.eval()
    return net
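
# Editor's note (not part of the released file): shape sketch for the network above.
# Every encoder/decoder stage keeps the full HxW resolution (stride 1 throughout);
# encoder3/encoder5 grow the receptive field with dilation=2 instead of downsampling,
# and the decoders consume torch.cat skip connections (hence the 256+128, 128+64, ...
# input widths). The final Sigmoid pins the (1,3,H,W) output to [0,1], matching the
# [0,1] float tiles fed in by denoise_channel() below.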


# ----------------------------
# Model cache
# ----------------------------
_cached_models: dict[tuple[str, bool], Dict[str, Any]] = {}  # (backend_tag, use_gpu)
_BACKEND_TAG = "cc_denoise_ai3_6"

R = get_resources()


def load_models(use_gpu: bool = True, status_cb=print) -> Dict[str, Any]:
    key = (_BACKEND_TAG, bool(use_gpu))
    if key in _cached_models:
        return _cached_models[key]

    is_windows = (os.name == "nt")

    torch = _get_torch(
        prefer_cuda=bool(use_gpu),
        prefer_dml=bool(use_gpu and is_windows),
        status_cb=status_cb,
    )

    # 1) CUDA
    if use_gpu and hasattr(torch, "cuda") and torch.cuda.is_available():
        device = torch.device("cuda")
        status_cb(f"CosmicClarity Denoise: using CUDA ({torch.cuda.get_device_name(0)})")
        mono_model = _load_torch_model(torch, device, R.CC_DENOISE_PTH)
        models = {"device": device, "is_onnx": False, "mono_model": mono_model, "torch": torch}
        status_cb(f"Denoise backend resolved: "
                  f"{'onnx' if models['is_onnx'] else 'torch'} / device={models['device']!r}")
        _cached_models[key] = models
        return models

    # 2) Torch-DirectML (Windows)
    if use_gpu and is_windows:
        try:
            import torch_directml
            dml = torch_directml.device()
            status_cb("CosmicClarity Denoise: using DirectML (torch-directml)")
            mono_model = _load_torch_model(torch, dml, R.CC_DENOISE_PTH)
            models = {"device": dml, "is_onnx": False, "mono_model": mono_model, "torch": torch}
            status_cb(f"Denoise backend resolved: "
                      f"{'onnx' if models['is_onnx'] else 'torch'} / device={models['device']!r}")
            _cached_models[key] = models
            return models
        except Exception:
            pass

    # 3) ORT DirectML fallback
    if use_gpu and ort is not None and ("DmlExecutionProvider" in ort.get_available_providers()):
        status_cb("CosmicClarity Denoise: using DirectML (ONNX Runtime)")
        mono_model = ort.InferenceSession(R.CC_DENOISE_ONNX, providers=["DmlExecutionProvider"])
        models = {"device": "DirectML", "is_onnx": True, "mono_model": mono_model, "torch": None}
        status_cb(f"Denoise backend resolved: "
                  f"{'onnx' if models['is_onnx'] else 'torch'} / device={models['device']!r}")
        _cached_models[key] = models
        return models

    # 4) CPU
    device = torch.device("cpu")
    status_cb("CosmicClarity Denoise: using CPU")
    mono_model = _load_torch_model(torch, device, R.CC_DENOISE_PTH)
    models = {"device": device, "is_onnx": False, "mono_model": mono_model, "torch": torch}
    status_cb(f"Denoise backend resolved: "
              f"{'onnx' if models['is_onnx'] else 'torch'} / device={models['device']!r}")
    _cached_models[key] = models
    return models
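
# Editor's note (illustrative sketch; not part of the released file): load_models()
# resolves a backend in the order CUDA -> torch-directml -> ONNX Runtime DirectML ->
# CPU, memoizing the result per (backend tag, use_gpu). A caller can probe what was
# picked without touching the model itself:
def _example_backend_probe():
    models = load_models(use_gpu=True, status_cb=print)
    backend = "onnx" if models["is_onnx"] else "torch"
    print(f"resolved backend: {backend}, device: {models['device']!r}")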


# ----------------------------
# Your helpers: luminance/chroma, chunks, borders, stretch
# (paste your existing implementations here)
# ----------------------------
def extract_luminance(image: np.ndarray):
    """
    Input: mono HxW, mono HxWx1, or RGB HxWx3 float32 in [0,1].
    Output: (Y, Cb, Cr) where:
      - Y is HxW
      - Cb/Cr are HxW in [0,1] (with +0.5 offset already applied)
    """
    x = np.asarray(image, dtype=np.float32)

    # Ensure 3-channel
    if x.ndim == 2:
        x = np.stack([x, x, x], axis=-1)
    elif x.ndim == 3 and x.shape[-1] == 1:
        x = np.repeat(x, 3, axis=-1)

    if x.ndim != 3 or x.shape[-1] != 3:
        raise ValueError("extract_luminance expects HxW, HxWx1, or HxWx3")

    # RGB -> YCbCr (BT.601) (same numbers as your sharpen_engine)
    M = np.array([[0.299, 0.587, 0.114],
                  [-0.168736, -0.331264, 0.5],
                  [0.5, -0.418688, -0.081312]], dtype=np.float32)

    ycbcr = x @ M.T
    y = ycbcr[..., 0]
    cb = ycbcr[..., 1] + 0.5
    cr = ycbcr[..., 2] + 0.5
    return y, cb, cr

def ycbcr_to_rgb(y: np.ndarray, cb: np.ndarray, cr: np.ndarray) -> np.ndarray:
    y = np.asarray(y, np.float32)
    cb = np.asarray(cb, np.float32) - 0.5
    cr = np.asarray(cr, np.float32) - 0.5
    ycbcr = np.stack([y, cb, cr], axis=-1)

    M = np.array([[1.0, 0.0, 1.402],
                  [1.0, -0.344136, -0.714136],
                  [1.0, 1.772, 0.0]], dtype=np.float32)

    rgb = ycbcr @ M.T
    return np.clip(rgb, 0.0, 1.0)


def merge_luminance(y: np.ndarray, cb: np.ndarray, cr: np.ndarray) -> np.ndarray:
    return ycbcr_to_rgb(np.clip(y, 0, 1), np.clip(cb, 0, 1), np.clip(cr, 0, 1))
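
# Editor's note (illustrative sketch; not part of the released file): the two BT.601
# matrices above are near-exact inverses, so merge_luminance(*extract_luminance(rgb))
# reproduces an in-gamut RGB image to float precision:
def _example_ycbcr_roundtrip():
    rgb = np.random.default_rng(0).random((8, 8, 3), dtype=np.float32)
    y, cb, cr = extract_luminance(rgb)
    assert np.allclose(merge_luminance(y, cb, cr), rgb, atol=1e-3)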


def _guided_filter(guide: np.ndarray, src: np.ndarray, radius: int, eps: float) -> np.ndarray:
    """
    Fast guided filter using boxFilter (edge-preserving, very fast).
    guide and src are HxW float32 in [0,1].
    radius is the neighborhood radius; ksize=(2*radius+1).
    eps is the regularization term.
    """
    r = max(1, int(radius))
    ksize = (2*r + 1, 2*r + 1)

    mean_I = cv2.boxFilter(guide, ddepth=-1, ksize=ksize, borderType=cv2.BORDER_REFLECT)
    mean_p = cv2.boxFilter(src, ddepth=-1, ksize=ksize, borderType=cv2.BORDER_REFLECT)
    mean_Ip = cv2.boxFilter(guide * src, ddepth=-1, ksize=ksize, borderType=cv2.BORDER_REFLECT)
    cov_Ip = mean_Ip - mean_I * mean_p

    mean_II = cv2.boxFilter(guide * guide, ddepth=-1, ksize=ksize, borderType=cv2.BORDER_REFLECT)
    var_I = mean_II - mean_I * mean_I

    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I

    mean_a = cv2.boxFilter(a, ddepth=-1, ksize=ksize, borderType=cv2.BORDER_REFLECT)
    mean_b = cv2.boxFilter(b, ddepth=-1, ksize=ksize, borderType=cv2.BORDER_REFLECT)

    q = mean_a * guide + mean_b
    return q
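
# Editor's note (illustrative sketch; not part of the released file): the guided
# filter fits a local linear model q = a*guide + b per window, so with a clean guide
# it strips noise from src while following the guide's edges. With a noisy ramp:
def _example_guided_filter():
    guide = np.tile(np.linspace(0.0, 1.0, 64, dtype=np.float32), (64, 1))
    noise = np.random.default_rng(1).normal(0.0, 0.05, guide.shape).astype(np.float32)
    out = _guided_filter(guide, guide + noise, radius=4, eps=1e-3)
    assert np.var(out - guide) < np.var(noise)  # residual noise is reduced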


def denoise_chroma(cb: np.ndarray,
                   cr: np.ndarray,
                   strength: float,
                   method: str = "guided",
                   strength_scale: float = 2.0,
                   guide_y: np.ndarray | None = None):
    """
    Fast chroma-only denoise for Cb/Cr in [0,1] float32.
    method: 'guided' (default), 'gaussian', 'bilateral'
    strength_scale: lets chroma smoothing go up to ~2× the slider value.
    guide_y: optional luminance guide (Y in [0,1]); 'guided' needs it to work best.
    """
    eff = float(np.clip(strength * strength_scale, 0.0, 1.0))
    if eff <= 0.0:
        return cb, cr

    cb = cb.astype(np.float32, copy=False)
    cr = cr.astype(np.float32, copy=False)

    if method == "guided":
        # Need a guide; if not provided, fall back to Gaussian
        if guide_y is not None:
            # radius & eps scale with strength; tuned for strong chroma smoothing but edge-safe
            radius = 2 + int(round(10 * eff))   # ~2..12 (ksize ~5..25)
            eps = (0.001 + 0.05 * eff) ** 2     # small regularization
            cb_f = _guided_filter(guide_y, cb, radius, eps)
            cr_f = _guided_filter(guide_y, cr, radius, eps)
        else:
            method = "gaussian"  # no guide provided → fast fallback

    if method == "gaussian":
        k = 1 + 2 * int(round(8 * eff))         # 1,3,5,..,17
        sigma = max(0.15, 2.4 * eff)
        cb_f = cv2.GaussianBlur(cb, (k, k), sigmaX=sigma, sigmaY=sigma, borderType=cv2.BORDER_REFLECT)
        cr_f = cv2.GaussianBlur(cr, (k, k), sigmaX=sigma, sigmaY=sigma, borderType=cv2.BORDER_REFLECT)

    elif method == "bilateral":
        # Bilateral is decent but slower than Gaussian; guided is preferred for speed/quality.
        d = 5 + 2 * int(round(6 * eff))         # 5..17
        sigmaC = 25.0 * (0.5 + 3.0 * eff)       # ~12.5..100
        sigmaS = 3.0 * (0.5 + 6.0 * eff)        # ~1.5..21
        cb_f = cv2.bilateralFilter(cb, d=d, sigmaColor=sigmaC, sigmaSpace=sigmaS)
        cr_f = cv2.bilateralFilter(cr, d=d, sigmaColor=sigmaC, sigmaSpace=sigmaS)

    elif method != "guided":
        # Without this guard an unknown method would hit the blend below with
        # cb_f/cr_f undefined and raise a NameError instead of a clear error.
        raise ValueError(f"denoise_chroma: unknown method {method!r}")

    # Blend (maskless)
    w = eff
    cb_out = (1.0 - w) * cb + w * cb_f
    cr_out = (1.0 - w) * cr + w * cr_f
    return cb_out, cr_out
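
# Editor's note (illustrative sketch; not part of the released file): the typical call
# pattern is to split an RGB frame, smooth only the chroma planes against the
# luminance guide, then remerge; luminance detail is untouched:
def _example_chroma_only():
    rgb = np.random.default_rng(2).random((32, 32, 3), dtype=np.float32)
    y, cb, cr = extract_luminance(rgb)
    cb2, cr2 = denoise_chroma(cb, cr, strength=0.5, method="guided", guide_y=y)
    smoothed = merge_luminance(y, cb2, cr2)
    assert smoothed.shape == rgb.shape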


# Function to split an image into chunks with overlap
def split_image_into_chunks_with_overlap(image, chunk_size, overlap):
    height, width = image.shape[:2]
    chunks = []
    step_size = chunk_size - overlap  # how far to advance between overlapping chunks

    for i in range(0, height, step_size):
        for j in range(0, width, step_size):
            end_i = min(i + chunk_size, height)
            end_j = min(j + chunk_size, width)
            chunk = image[i:end_i, j:end_j]
            chunks.append((chunk, i, j))  # keep the chunk and its top-left position
    return chunks

def blend_images(before, after, amount):
    return (1 - amount) * before + amount * after

def stitch_chunks_ignore_border(chunks, image_shape, border_size: int = 16):
    """
    chunks: list of (chunk, i, j) or (chunk, i, j, is_edge)
    image_shape: (H,W)
    """
    H, W = image_shape
    stitched = np.zeros((H, W), dtype=np.float32)
    weights = np.zeros((H, W), dtype=np.float32)

    for entry in chunks:
        # accept both 3-tuple and 4-tuple
        if len(entry) == 3:
            chunk, i, j = entry
        else:
            chunk, i, j, _ = entry

        h, w = chunk.shape[:2]
        bh = min(border_size, h // 2)
        bw = min(border_size, w // 2)

        inner = chunk[bh:h-bh, bw:w-bw]
        stitched[i+bh:i+h-bh, j+bw:j+w-bw] += inner
        weights[i+bh:i+h-bh, j+bw:j+w-bw] += 1.0

    stitched /= np.maximum(weights, 1.0)
    return stitched
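
# Editor's note (illustrative sketch; not part of the released file): with
# chunk_size=256 / overlap=64 the trimmed chunk interiors still cover every pixel
# more than border_size away from the frame edge, so an unmodified split/stitch
# round trip is the identity there:
def _example_tile_roundtrip():
    img = np.random.default_rng(3).random((300, 400), dtype=np.float32)
    chunks = split_image_into_chunks_with_overlap(img, chunk_size=256, overlap=64)
    out = stitch_chunks_ignore_border(chunks, img.shape, border_size=16)
    assert np.allclose(out[16:-16, 16:-16], img[16:-16, 16:-16])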

def replace_border(original_image, processed_image, border_size=16):
    # Ensure the dimensions of both images match
    if original_image.shape != processed_image.shape:
        raise ValueError("Original image and processed image must have the same dimensions.")

    # Replace the top border
    processed_image[:border_size, :] = original_image[:border_size, :]

    # Replace the bottom border
    processed_image[-border_size:, :] = original_image[-border_size:, :]

    # Replace the left border
    processed_image[:, :border_size] = original_image[:, :border_size]

    # Replace the right border
    processed_image[:, -border_size:] = original_image[:, -border_size:]

    return processed_image

def stretch_image_unlinked(image: np.ndarray, target_median: float = 0.25):
    x = np.asarray(image, np.float32).copy()
    orig_min = float(np.min(x))
    x -= orig_min

    if x.ndim == 2:
        med = float(np.median(x))
        orig_meds = [med]
        if med != 0:
            x = ((med - 1) * target_median * x) / (med * (target_median + x - 1) - target_median * x)
        return np.clip(x, 0, 1), orig_min, orig_meds

    # 3ch
    orig_meds = [float(np.median(x[..., c])) for c in range(3)]
    for c in range(3):
        m = orig_meds[c]
        if m != 0:
            x[..., c] = ((m - 1) * target_median * x[..., c]) / (
                m * (target_median + x[..., c] - 1) - target_median * x[..., c]
            )
    return np.clip(x, 0, 1), orig_min, orig_meds


def unstretch_image_unlinked(image: np.ndarray, orig_meds, orig_min: float):
    x = np.asarray(image, np.float32).copy()

    if x.ndim == 2:
        m_now = float(np.median(x))
        m0 = float(orig_meds[0])
        if m_now != 0 and m0 != 0:
            x = ((m_now - 1) * m0 * x) / (m_now * (m0 + x - 1) - m0 * x)
        x += float(orig_min)
        return np.clip(x, 0, 1)

    for c in range(3):
        m_now = float(np.median(x[..., c]))
        m0 = float(orig_meds[c])
        if m_now != 0 and m0 != 0:
            x[..., c] = ((m_now - 1) * m0 * x[..., c]) / (
                m_now * (m0 + x[..., c] - 1) - m0 * x[..., c]
            )

    x += float(orig_min)
    return np.clip(x, 0, 1)

# Backwards-compatible names used by denoise_rgb01()
def stretch_image(image: np.ndarray):
    return stretch_image_unlinked(image)

def unstretch_image(image: np.ndarray, original_medians, original_min: float):
    return unstretch_image_unlinked(image, original_medians, original_min)
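
# Editor's note (illustrative sketch; not part of the released file): the stretch is a
# midtones transfer function that maps each channel's median to target_median, and
# applying the same formula with the two medians swapped inverts it, so unstretch
# recovers the original up to float precision and the final clip:
def _example_stretch_roundtrip():
    img = np.random.default_rng(4).random((17, 17, 3), dtype=np.float32) * 0.05
    stretched, orig_min, orig_meds = stretch_image(img)
    restored = unstretch_image(stretched, orig_meds, orig_min)
    assert np.allclose(restored, img, atol=1e-3)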

def add_border(image, border_size=16):
    if image.ndim == 2:  # mono
        med = np.median(image)
        return np.pad(image,
                      ((border_size, border_size), (border_size, border_size)),
                      mode="constant",
                      constant_values=med)

    elif image.ndim == 3 and image.shape[2] == 3:  # RGB
        meds = np.median(image, axis=(0, 1)).astype(image.dtype)  # (3,)
        padded = [np.pad(image[..., c],
                         ((border_size, border_size), (border_size, border_size)),
                         mode="constant",
                         constant_values=float(meds[c]))
                  for c in range(3)]
        return np.stack(padded, axis=-1)
    else:
        raise ValueError("add_border expects mono or RGB image.")

def remove_border(image, border_size: int = 16):
    if image.ndim == 2:
        return image[border_size:-border_size, border_size:-border_size]
    return image[border_size:-border_size, border_size:-border_size, :]


# ----------------------------
# Channel denoise (paste + keep)
# IMPORTANT: remove print() spam; instead accept an optional progress callback
# ----------------------------
def denoise_channel(channel: np.ndarray, models: Dict[str, Any], *, progress_cb: ProgressCB | None = None) -> np.ndarray:
    device = models["device"]
    is_onnx = models["is_onnx"]
    model = models["mono_model"]

    chunk_size = 256
    overlap = 64
    chunks = split_image_into_chunks_with_overlap(channel, chunk_size=chunk_size, overlap=overlap)

    denoised_chunks = []
    total = len(chunks)

    for idx, (chunk, i, j) in enumerate(chunks):
        original_chunk_shape = chunk.shape

        if is_onnx:
            chunk_input = chunk[np.newaxis, np.newaxis, :, :].astype(np.float32)
            chunk_input = np.tile(chunk_input, (1, 3, 1, 1))
            if chunk_input.shape[2] != chunk_size or chunk_input.shape[3] != chunk_size:
                padded = np.zeros((1, 3, chunk_size, chunk_size), dtype=np.float32)
                padded[:, :, :chunk_input.shape[2], :chunk_input.shape[3]] = chunk_input
                chunk_input = padded

            input_name = model.get_inputs()[0].name
            out = model.run(None, {input_name: chunk_input})[0]
            denoised_chunk = out[0, 0, :original_chunk_shape[0], :original_chunk_shape[1]]

        else:
            torch = models["torch"]
            chunk_tensor = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
            chunk_tensor = chunk_tensor.expand(1, 3, chunk_tensor.shape[2], chunk_tensor.shape[3])

            with torch.no_grad(), _autocast_context(torch, device):
                out = model(chunk_tensor).detach().cpu().numpy()  # (1,3,H,W)

            denoised_chunk = out[0, 0, :original_chunk_shape[0], :original_chunk_shape[1]]

        denoised_chunks.append((denoised_chunk, i, j))

        if progress_cb is not None:
            progress_cb(idx + 1, total)

    return stitch_chunks_ignore_border(denoised_chunks, channel.shape, border_size=16)

# ----------------------------
# High-level denoise for a loaded RGB float image (0..1)
# (this is the “engine API” SASpro will call)
# ----------------------------
def denoise_rgb01(
    img_rgb01: np.ndarray,
    *,
    denoise_strength: float,
    denoise_mode: str = "luminance",  # luminance | full | separate
    separate_channels: bool = False,
    color_denoise_strength: Optional[float] = None,
    use_gpu: bool = True,
    progress_cb=None,
) -> np.ndarray:
    """
    Input: float32 RGB [0..1]
    Output: float32 RGB [0..1]
    """
    models = load_models(use_gpu=use_gpu)

    # Determine stretch necessity (keep your logic)
    stretch_needed = (np.median(img_rgb01 - np.min(img_rgb01)) < 0.05)
    if stretch_needed:
        stretched_core, original_min, original_medians = stretch_image(img_rgb01)
    else:
        stretched_core = img_rgb01.astype(np.float32, copy=False)
        original_min = float(np.min(img_rgb01))
        original_medians = [float(np.median(img_rgb01[..., c])) for c in range(3)]

    stretched = add_border(stretched_core, border_size=16)

    # Process
    if separate_channels or denoise_mode == "separate":
        out_ch = []
        for c in range(3):
            dch = denoise_channel(stretched[..., c], models, progress_cb=progress_cb)
            out_ch.append(blend_images(stretched[..., c], dch, denoise_strength))
        den = np.stack(out_ch, axis=-1)

    elif denoise_mode == "luminance":
        y, cb, cr = extract_luminance(stretched)
        den_y = denoise_channel(y, models, progress_cb=progress_cb)
        y2 = blend_images(y, den_y, denoise_strength)
        den = merge_luminance(y2, cb, cr)

    else:
        # full: L via NN, chroma via guided
        y, cb, cr = extract_luminance(stretched)
        den_y = denoise_channel(y, models, progress_cb=progress_cb)
        y2 = blend_images(y, den_y, denoise_strength)

        cs = denoise_strength if color_denoise_strength is None else color_denoise_strength
        cb2, cr2 = denoise_chroma(cb, cr, strength=cs, method="guided", guide_y=y)
        den = merge_luminance(y2, cb2, cr2)

    # unstretch if needed
    if stretch_needed:
        den = unstretch_image(den, original_medians, original_min)

    den = remove_border(den, border_size=16)
    return np.clip(den, 0.0, 1.0).astype(np.float32, copy=False)
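
# Editor's note (illustrative sketch; not part of the released file): how a caller
# such as SASpro's denoise op might drive the engine API above, with a progress
# callback matching the ProgressCB signature; assumes the bundled model weights
# resolve via get_resources():
def _example_engine_call():
    img = np.random.default_rng(5).random((128, 128, 3), dtype=np.float32) * 0.02

    def on_progress(done: int, total: int) -> None:
        print(f"denoise: chunk {done}/{total}")

    out = denoise_rgb01(
        img,
        denoise_strength=0.9,
        denoise_mode="full",  # NN on luminance, guided filter on chroma
        use_gpu=False,
        progress_cb=on_progress,
    )
    assert out.shape == img.shape and out.dtype == np.float32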