setiastrosuitepro 1.7.5.post1__py3-none-any.whl → 1.8.0.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of setiastrosuitepro might be problematic. Click here for more details.

Files changed (27) hide show
  1. setiastro/saspro/_generated/build_info.py +2 -2
  2. setiastro/saspro/accel_installer.py +21 -8
  3. setiastro/saspro/accel_workers.py +11 -12
  4. setiastro/saspro/comet_stacking.py +113 -85
  5. setiastro/saspro/cosmicclarity.py +604 -826
  6. setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
  7. setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
  8. setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
  9. setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
  10. setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
  11. setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
  12. setiastro/saspro/gui/main_window.py +14 -0
  13. setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
  14. setiastro/saspro/model_manager.py +324 -0
  15. setiastro/saspro/model_workers.py +102 -0
  16. setiastro/saspro/ops/benchmark.py +320 -0
  17. setiastro/saspro/ops/settings.py +407 -10
  18. setiastro/saspro/remove_stars.py +424 -442
  19. setiastro/saspro/resources.py +73 -10
  20. setiastro/saspro/runtime_torch.py +107 -22
  21. setiastro/saspro/signature_insert.py +14 -3
  22. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
  23. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
  24. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
  25. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
  26. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
  27. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
@@ -0,0 +1,567 @@
1
+ # src/setiastro/saspro/cosmicclarity_engines/denoise_engine.py
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import warnings
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Dict, Any, Tuple
8
+
9
+ import numpy as np
10
+
11
+ import cv2
12
+
13
+
14
+ from setiastro.saspro.resources import get_resources
15
+
16
# NOTE(review): with no module/category arguments this filter silences every
# warning process-wide as soon as this module is imported, not only warnings
# raised by this engine.
warnings.filterwarnings("ignore")

from typing import Callable

# Signature of the per-tile progress hook used by denoise_channel():
# invoked as cb(done, total) after each tile has been processed.
ProgressCB = Callable[[int, int], None]
21
+
22
+ try:
23
+ import onnxruntime as ort
24
+ except Exception:
25
+ ort = None
26
+
27
+
28
+ def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
29
+ from setiastro.saspro.runtime_torch import import_torch
30
+ return import_torch(
31
+ prefer_cuda=prefer_cuda,
32
+ prefer_xpu=False,
33
+ prefer_dml=prefer_dml,
34
+ status_cb=status_cb,
35
+ )
36
+
37
+
38
+ def _nullcontext():
39
+ from contextlib import nullcontext
40
+ return nullcontext()
41
+
42
+
43
+ def _autocast_context(torch, device) -> Any:
44
+ """
45
+ Use new torch.amp.autocast('cuda') when available.
46
+ Keep your cap >= 8.0 rule.
47
+ """
48
+ try:
49
+ if hasattr(device, "type") and device.type == "cuda":
50
+ major, minor = torch.cuda.get_device_capability()
51
+ cap = float(f"{major}.{minor}")
52
+ if cap >= 8.0:
53
+ # Preferred API (torch >= 1.10-ish; definitely in 2.x)
54
+ if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
55
+ return torch.amp.autocast(device_type="cuda")
56
+ # Fallback for older torch
57
+ return torch.cuda.amp.autocast()
58
+ except Exception:
59
+ pass
60
+ return _nullcontext()
61
+
62
+
63
+
64
+ # ----------------------------
65
+ # Model definitions (unchanged)
66
+ # ----------------------------
67
+ def _load_torch_model(torch, device, ckpt_path: str):
68
+ nn = torch.nn
69
+
70
+ class ResidualBlock(nn.Module):
71
+ def __init__(self, channels: int):
72
+ super().__init__()
73
+ self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
74
+ self.relu = nn.ReLU()
75
+ self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
76
+
77
+ def forward(self, x):
78
+ residual = x
79
+ out = self.relu(self.conv1(x))
80
+ out = self.conv2(out)
81
+ out = self.relu(out + residual)
82
+ return out
83
+
84
+ class DenoiseCNN(nn.Module):
85
+ def __init__(self):
86
+ super().__init__()
87
+ self.encoder1 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), ResidualBlock(16))
88
+ self.encoder2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), ResidualBlock(32))
89
+ self.encoder3 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=2, dilation=2), nn.ReLU(), ResidualBlock(64))
90
+ self.encoder4 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), ResidualBlock(128))
91
+ self.encoder5 = nn.Sequential(nn.Conv2d(128, 256, 3, padding=2, dilation=2), nn.ReLU(), ResidualBlock(256))
92
+
93
+ self.decoder5 = nn.Sequential(nn.Conv2d(256 + 128, 128, 3, padding=1), nn.ReLU(), ResidualBlock(128))
94
+ self.decoder4 = nn.Sequential(nn.Conv2d(128 + 64, 64, 3, padding=1), nn.ReLU(), ResidualBlock(64))
95
+ self.decoder3 = nn.Sequential(nn.Conv2d( 64 + 32, 32, 3, padding=1), nn.ReLU(), ResidualBlock(32))
96
+ self.decoder2 = nn.Sequential(nn.Conv2d( 32 + 16, 16, 3, padding=1), nn.ReLU(), ResidualBlock(16))
97
+ self.decoder1 = nn.Sequential(nn.Conv2d(16, 3, 3, padding=1), nn.Sigmoid())
98
+
99
+ def forward(self, x):
100
+ e1 = self.encoder1(x)
101
+ e2 = self.encoder2(e1)
102
+ e3 = self.encoder3(e2)
103
+ e4 = self.encoder4(e3)
104
+ e5 = self.encoder5(e4)
105
+
106
+ d5 = self.decoder5(torch.cat([e5, e4], dim=1))
107
+ d4 = self.decoder4(torch.cat([d5, e3], dim=1))
108
+ d3 = self.decoder3(torch.cat([d4, e2], dim=1))
109
+ d2 = self.decoder2(torch.cat([d3, e1], dim=1))
110
+ return self.decoder1(d2)
111
+
112
+ net = DenoiseCNN().to(device)
113
+ ckpt = torch.load(ckpt_path, map_location=device)
114
+ net.load_state_dict(ckpt.get("model_state_dict", ckpt))
115
+ net.eval()
116
+ return net
117
+
118
+
119
+ # ----------------------------
120
+ # Model cache
121
+ # ----------------------------
122
# Part of every cache key below; presumably identifies the model generation
# so bundles from other engines never collide — confirm.
_BACKEND_TAG = "cc_denoise_ai3_6"

# Resolved model bundles keyed by (backend tag, use_gpu); filled by load_models().
_cached_models: Dict[Tuple[str, bool], Dict[str, Any]] = {}

# Resource accessor (load_models() reads CC_DENOISE_PTH / CC_DENOISE_ONNX from it).
R = get_resources()
126
+
127
+
128
def load_models(use_gpu: bool = True, status_cb=print) -> Dict[str, Any]:
    """
    Resolve, load and cache the denoise model for the best available backend.

    Backends are tried in this order (GPU steps only when ``use_gpu`` is true):
      1. torch on CUDA,
      2. torch on DirectML via ``torch_directml`` (Windows only),
      3. ONNX Runtime with ``DmlExecutionProvider``,
      4. torch on CPU.
    A failure in step 2 or 3 is reported through ``status_cb`` and the next
    step is tried.

    Returns:
        dict with keys ``"device"``, ``"is_onnx"``, ``"mono_model"`` and
        ``"torch"`` (None for the ONNX Runtime backend). The dict is cached
        per ``(backend tag, use_gpu)``; later calls return the same object
        without emitting status messages.
    """
    key = (_BACKEND_TAG, bool(use_gpu))
    cached = _cached_models.get(key)
    if cached is not None:
        return cached

    def _finish(device, is_onnx: bool, mono_model, torch_mod) -> Dict[str, Any]:
        # Single place that assembles, reports and caches the bundle.
        models = {"device": device, "is_onnx": is_onnx, "mono_model": mono_model, "torch": torch_mod}
        status_cb(f"Denoise backend resolved: "
                  f"{'onnx' if is_onnx else 'torch'} / device={device!r}")
        _cached_models[key] = models
        return models

    is_windows = (os.name == "nt")

    torch = _get_torch(
        prefer_cuda=bool(use_gpu),
        prefer_dml=bool(use_gpu and is_windows),
        status_cb=status_cb,
    )

    # 1) CUDA
    if use_gpu and hasattr(torch, "cuda") and torch.cuda.is_available():
        device = torch.device("cuda")
        status_cb(f"CosmicClarity Denoise: using CUDA ({torch.cuda.get_device_name(0)})")
        return _finish(device, False, _load_torch_model(torch, device, R.CC_DENOISE_PTH), torch)

    # 2) Torch-DirectML (Windows)
    if use_gpu and is_windows:
        try:
            import torch_directml
            dml = torch_directml.device()
            status_cb("CosmicClarity Denoise: using DirectML (torch-directml)")
            mono_model = _load_torch_model(torch, dml, R.CC_DENOISE_PTH)
        except Exception as exc:
            # Best effort: say why, then fall through to the next backend.
            status_cb(f"CosmicClarity Denoise: torch-directml unavailable ({exc}); trying next backend")
        else:
            return _finish(dml, False, mono_model, torch)

    # 3) ORT DirectML fallback
    if use_gpu and ort is not None and ("DmlExecutionProvider" in ort.get_available_providers()):
        status_cb("CosmicClarity Denoise: using DirectML (ONNX Runtime)")
        try:
            session = ort.InferenceSession(R.CC_DENOISE_ONNX, providers=["DmlExecutionProvider"])
        except Exception as exc:
            # A broken provider/driver should not prevent the CPU fallback.
            status_cb(f"CosmicClarity Denoise: ONNX Runtime DirectML failed ({exc}); falling back to CPU")
        else:
            return _finish("DirectML", True, session, None)

    # 4) CPU
    device = torch.device("cpu")
    status_cb("CosmicClarity Denoise: using CPU")
    return _finish(device, False, _load_torch_model(torch, device, R.CC_DENOISE_PTH), torch)
186
+
187
+
188
+ # ----------------------------
189
+ # Your helpers: luminance/chroma, chunks, borders, stretch
190
+ # (paste your existing implementations here)
191
+ # ----------------------------
192
def extract_luminance(image: np.ndarray):
    """
    Split an image into BT.601 luma and offset chroma planes.

    Accepts mono ``HxW``, mono ``HxWx1`` or RGB ``HxWx3`` input (values in
    [0, 1] expected); mono input is replicated to three channels first.

    Returns:
        ``(Y, Cb, Cr)`` float32 ``HxW`` planes. Cb and Cr already carry a
        ``+0.5`` offset, so they lie in [0, 1] for in-range input.

    Raises:
        ValueError: for any other shape.
    """
    rgb = np.asarray(image, dtype=np.float32)

    if rgb.ndim == 2:
        rgb = np.stack([rgb, rgb, rgb], axis=-1)
    elif rgb.ndim == 3 and rgb.shape[-1] == 1:
        rgb = np.repeat(rgb, 3, axis=-1)

    if not (rgb.ndim == 3 and rgb.shape[-1] == 3):
        raise ValueError("extract_luminance expects HxW, HxWx1, or HxWx3")

    # Rows yield Y, Cb, Cr (same coefficients as the sharpen engine).
    rgb_to_ycc = np.array(
        [
            [0.299, 0.587, 0.114],
            [-0.168736, -0.331264, 0.5],
            [0.5, -0.418688, -0.081312],
        ],
        dtype=np.float32,
    )
    planes = rgb @ rgb_to_ycc.T

    luma = planes[..., 0]
    chroma_b = planes[..., 1] + 0.5
    chroma_r = planes[..., 2] + 0.5
    return luma, chroma_b, chroma_r
220
+
221
def ycbcr_to_rgb(y: np.ndarray, cb: np.ndarray, cr: np.ndarray) -> np.ndarray:
    """
    Rebuild an RGB image from BT.601 Y and offset Cb/Cr planes.

    ``cb``/``cr`` are expected to carry the +0.5 offset produced by
    extract_luminance(); it is removed here. Returns float32 ``...x3`` clipped
    to [0, 1].
    """
    ycc_to_rgb = np.array(
        [[1.0, 0.0, 1.402],
         [1.0, -0.344136, -0.714136],
         [1.0, 1.772, 0.0]],
        dtype=np.float32,
    )
    luma = np.asarray(y, np.float32)
    centred_b = np.asarray(cb, np.float32) - 0.5
    centred_r = np.asarray(cr, np.float32) - 0.5
    stacked = np.stack([luma, centred_b, centred_r], axis=-1)
    return np.clip(stacked @ ycc_to_rgb.T, 0.0, 1.0)
233
+
234
+
235
def merge_luminance(y: np.ndarray, cb: np.ndarray, cr: np.ndarray) -> np.ndarray:
    """Clip Y, Cb and Cr to [0, 1] each, then convert to RGB with ycbcr_to_rgb()."""
    clipped = [np.clip(plane, 0, 1) for plane in (y, cb, cr)]
    return ycbcr_to_rgb(*clipped)
237
+
238
+
239
def _guided_filter(guide: np.ndarray, src: np.ndarray, radius: int, eps: float) -> np.ndarray:
    """
    Edge-preserving guided filter assembled from OpenCV box filters.

    Args:
        guide: HxW guidance image (float32 in [0, 1] expected).
        src: HxW image to be filtered (float32 in [0, 1] expected).
        radius: neighbourhood radius, clamped to at least 1; the box window
            is ``(2r+1) x (2r+1)``.
        eps: regularization added to the local variance of ``guide``.

    Returns:
        ``box(a) * guide + box(b)`` where ``a = cov(guide, src) / (var(guide) + eps)``
        and ``b = mean(src) - a * mean(guide)``.
    """
    r = max(1, int(radius))
    window = (2 * r + 1, 2 * r + 1)

    def box(img):
        return cv2.boxFilter(img, ddepth=-1, ksize=window, borderType=cv2.BORDER_REFLECT)

    mu_guide = box(guide)
    mu_src = box(src)
    cov_guide_src = box(guide * src) - mu_guide * mu_src
    var_guide = box(guide * guide) - mu_guide * mu_guide

    coef_a = cov_guide_src / (var_guide + eps)
    coef_b = mu_src - coef_a * mu_guide

    return box(coef_a) * guide + box(coef_b)
265
+
266
+
267
+
268
def denoise_chroma(cb: np.ndarray,
                   cr: np.ndarray,
                   strength: float,
                   method: str = "guided",
                   strength_scale: float = 2.0,
                   guide_y: np.ndarray | None = None):
    """
    Smooth only the chroma planes and blend the result back in.

    Args:
        cb, cr: HxW chroma planes (float, [0, 1]).
        strength: user strength. The effective strength is
            ``eff = clip(strength * strength_scale, 0, 1)``.
        method: ``'guided'`` (default), ``'gaussian'`` or ``'bilateral'``.
            ``'guided'`` without ``guide_y`` falls back to ``'gaussian'``.
        strength_scale: multiplier on ``strength``. With the default 2.0,
            ``eff`` saturates at strength 0.5.
        guide_y: luminance plane used as the guided-filter guide (cast to
            float32).

    Returns:
        ``(cb_out, cr_out)`` = ``(1 - eff) * plane + eff * filtered``. When
        ``eff`` is 0 the inputs are returned unchanged (same objects).

    Raises:
        ValueError: if ``method`` is not one of the three names above.
    """
    eff = float(np.clip(strength * strength_scale, 0.0, 1.0))
    if eff <= 0.0:
        return cb, cr

    cb = cb.astype(np.float32, copy=False)
    cr = cr.astype(np.float32, copy=False)

    if method == "guided" and guide_y is None:
        method = "gaussian"  # no guide provided -> fast fallback

    if method == "guided":
        # Radius & eps grow with strength: strong chroma smoothing that stays edge-aware.
        guide = np.asarray(guide_y, dtype=np.float32)
        radius = 2 + int(round(10 * eff))   # ~2..12 (window ~5..25)
        eps = (0.001 + 0.05 * eff) ** 2     # small regularization
        cb_f = _guided_filter(guide, cb, radius, eps)
        cr_f = _guided_filter(guide, cr, radius, eps)
    elif method == "gaussian":
        k = 1 + 2 * int(round(8 * eff))     # odd kernel 1, 3, ..., 17
        sigma = max(0.15, 2.4 * eff)
        cb_f = cv2.GaussianBlur(cb, (k, k), sigmaX=sigma, sigmaY=sigma, borderType=cv2.BORDER_REFLECT)
        cr_f = cv2.GaussianBlur(cr, (k, k), sigmaX=sigma, sigmaY=sigma, borderType=cv2.BORDER_REFLECT)
    elif method == "bilateral":
        # NOTE(review): sigmaColor is in the data's own units; values of
        # ~12.5..100 on [0, 1] planes make the range term nearly inactive.
        d = 5 + 2 * int(round(6 * eff))           # 5..17
        sigma_color = 25.0 * (0.5 + 3.0 * eff)    # ~12.5..100
        sigma_space = 3.0 * (0.5 + 6.0 * eff)     # ~1.5..21
        cb_f = cv2.bilateralFilter(cb, d=d, sigmaColor=sigma_color, sigmaSpace=sigma_space)
        cr_f = cv2.bilateralFilter(cr, d=d, sigmaColor=sigma_color, sigmaSpace=sigma_space)
    else:
        # Without this branch an unknown method surfaced as an
        # UnboundLocalError on cb_f below.
        raise ValueError(
            f"denoise_chroma: unknown method {method!r}; "
            "expected 'guided', 'gaussian' or 'bilateral'"
        )

    # Maskless blend weighted by the effective strength.
    cb_out = (1.0 - eff) * cb + eff * cb_f
    cr_out = (1.0 - eff) * cr + eff * cr_f
    return cb_out, cr_out
317
+
318
+
319
def split_image_into_chunks_with_overlap(image, chunk_size, overlap):
    """
    Tile ``image`` into ``chunk_size`` x ``chunk_size`` windows overlapping by ``overlap`` px.

    Windows start every ``chunk_size - overlap`` pixels along both axes and
    are produced in row-major order. Windows touching the bottom/right edge
    are truncated, not padded.

    Returns:
        list of ``(chunk, i, j)``: ``chunk`` is a view into ``image`` and
        ``(i, j)`` its top-left corner. Empty for an image with no rows/cols.

    Raises:
        ValueError: if ``chunk_size <= 0`` or ``overlap >= chunk_size``. A zero
            stride used to crash inside ``range``; a negative one silently
            produced no tiles, so the stitched output was all zeros.
    """
    step = chunk_size - overlap
    if chunk_size <= 0 or step <= 0:
        raise ValueError(
            f"chunk_size must be positive and overlap < chunk_size "
            f"(got chunk_size={chunk_size}, overlap={overlap})"
        )

    height, width = image.shape[:2]
    return [
        (image[i:min(i + chunk_size, height), j:min(j + chunk_size, width)], i, j)
        for i in range(0, height, step)
        for j in range(0, width, step)
    ]
332
+
333
def blend_images(before, after, amount):
    """Linear mix of two images: ``amount`` 0 gives ``before``, 1 gives ``after``."""
    keep = 1 - amount
    return keep * before + amount * after
335
+
336
def stitch_chunks_ignore_border(chunks, image_shape, border_size: int = 16):
    """
    Reassemble processed 2-D tiles into a float32 image, averaging overlaps.

    Only each tile's interior is used: ``min(border_size, side // 2)`` pixels
    are dropped from every side of every tile. Output pixels that no tile
    interior covers stay 0.

    Args:
        chunks: iterable of ``(chunk, i, j)`` or ``(chunk, i, j, extra)``;
            ``(i, j)`` is the tile's top-left corner and ``extra`` is ignored.
        image_shape: ``(H, W)`` of the output.
        border_size: pixels trimmed from each tile side.
    """
    H, W = image_shape
    acc = np.zeros((H, W), dtype=np.float32)
    hits = np.zeros((H, W), dtype=np.float32)

    for entry in chunks:
        if len(entry) != 3:
            tile, top, left, _ = entry
        else:
            tile, top, left = entry

        th, tw = tile.shape[:2]
        trim_y = min(border_size, th // 2)
        trim_x = min(border_size, tw // 2)

        rows = slice(top + trim_y, top + th - trim_y)
        cols = slice(left + trim_x, left + tw - trim_x)
        acc[rows, cols] += tile[trim_y:th - trim_y, trim_x:tw - trim_x]
        hits[rows, cols] += 1.0

    acc /= np.maximum(hits, 1.0)
    return acc
362
+
363
def replace_border(original_image, processed_image, border_size=16):
    """
    Copy a ``border_size``-pixel frame from ``original_image`` onto ``processed_image``.

    ``processed_image`` is modified in place and also returned. A non-positive
    ``border_size`` leaves it untouched.

    Raises:
        ValueError: if the two images do not have the same shape.
    """
    if original_image.shape != processed_image.shape:
        raise ValueError("Original image and processed image must have the same dimensions.")

    # Guard: with border_size == 0 the `[-0:]` slices below select the WHOLE
    # axis, which would overwrite the entire processed image.
    if border_size <= 0:
        return processed_image

    b = border_size
    processed_image[:b, :] = original_image[:b, :]    # top
    processed_image[-b:, :] = original_image[-b:, :]  # bottom
    processed_image[:, :b] = original_image[:, :b]    # left
    processed_image[:, -b:] = original_image[:, -b:]  # right
    return processed_image
381
+
382
def stretch_image_unlinked(image: np.ndarray, target_median: float = 0.25):
    """
    Per-channel (unlinked) nonlinear stretch that moves each channel's median
    to ``target_median``.

    The single global minimum is subtracted first. Each channel (or the whole
    image for 2-D input) is then mapped through the rational curve
    ``v -> ((m - 1) * t * v) / (m * (t + v - 1) - t * v)``, where ``m`` is that
    channel's median and ``t`` the target; this sends ``m`` to ``t``. Channels
    whose median is 0 are left unchanged. For 3-D input only the first three
    channels are processed.

    Returns:
        ``(stretched, orig_min, orig_meds)``: float32 image clipped to [0, 1],
        the subtracted minimum, and the per-channel medians measured after the
        minimum was removed (as consumed by unstretch_image_unlinked()).
    """
    t = target_median

    def curve(v, m):
        return ((m - 1) * t * v) / (m * (t + v - 1) - t * v)

    work = np.asarray(image, np.float32).copy()
    orig_min = float(np.min(work))
    work -= orig_min

    if work.ndim == 2:
        med = float(np.median(work))
        if med != 0:
            work = curve(work, med)
        return np.clip(work, 0, 1), orig_min, [med]

    # Medians are all measured before any channel is modified.
    orig_meds = [float(np.median(work[..., ch])) for ch in range(3)]
    for ch, m in enumerate(orig_meds):
        if m != 0:
            work[..., ch] = curve(work[..., ch], m)
    return np.clip(work, 0, 1), orig_min, orig_meds
403
+
404
+
405
def unstretch_image_unlinked(image: np.ndarray, orig_meds, orig_min: float):
    """
    Inverse of stretch_image_unlinked().

    For each channel (or the whole 2-D image) the current median is mapped
    back to the recorded original median with the same rational curve, then
    ``orig_min`` is added back. A channel is left unmapped when either its
    current median or its recorded median is 0. For 3-D input only the first
    three channels are remapped.

    Returns:
        float32 image clipped to [0, 1].
    """
    out = np.asarray(image, np.float32).copy()

    def remap(v, now, target):
        return ((now - 1) * target * v) / (now * (target + v - 1) - target * v)

    if out.ndim == 2:
        m_now, m0 = float(np.median(out)), float(orig_meds[0])
        if m_now != 0 and m0 != 0:
            out = remap(out, m_now, m0)
    else:
        for ch in range(3):
            m_now, m0 = float(np.median(out[..., ch])), float(orig_meds[ch])
            if m_now != 0 and m0 != 0:
                out[..., ch] = remap(out[..., ch], m_now, m0)

    out += float(orig_min)
    return np.clip(out, 0, 1)
426
+
427
# Backwards-compatible names kept for denoise_rgb01().
def stretch_image(image: np.ndarray):
    """Alias of stretch_image_unlinked(image) at its default 0.25 target median."""
    result = stretch_image_unlinked(image)
    return result
430
+
431
def unstretch_image(image: np.ndarray, original_medians, original_min: float):
    """Alias of unstretch_image_unlinked(); arguments are forwarded unchanged."""
    return unstretch_image_unlinked(image, orig_meds=original_medians, orig_min=original_min)
433
+
434
def add_border(image, border_size=16):
    """
    Pad all four sides of an image with ``border_size`` pixels of a constant.

    Mono (2-D) images are padded with their global median. RGB (``HxWx3``)
    images are padded per channel with that channel's median, cast to the
    image dtype.

    Raises:
        ValueError: for any other shape.
    """
    pad_width = ((border_size, border_size), (border_size, border_size))

    if image.ndim == 2:
        return np.pad(image, pad_width, mode="constant", constant_values=np.median(image))

    if not (image.ndim == 3 and image.shape[2] == 3):
        raise ValueError("add_border expects mono or RGB image.")

    channel_medians = np.median(image, axis=(0, 1)).astype(image.dtype)
    padded_planes = [
        np.pad(image[..., ch], pad_width, mode="constant",
               constant_values=float(channel_medians[ch]))
        for ch in range(3)
    ]
    return np.stack(padded_planes, axis=-1)
452
+
453
def remove_border(image, border_size: int = 16):
    """
    Crop ``border_size`` pixels from every side of a 2-D or ``HxWxC`` image.

    Returns a view into ``image``; ``border_size=0`` returns the full image.
    """
    b = border_size
    h, w = image.shape[:2]
    # Explicit end indices: the former `image[b:-b]` form yields an EMPTY
    # array for b == 0 because `-0` is just `0`.
    return image[b:h - b, b:w - b]
457
+
458
+
459
+ # ----------------------------
460
+ # Channel denoise (paste + keep)
461
+ # IMPORTANT: remove print() spam; instead accept an optional progress callback
462
+ # ----------------------------
463
def denoise_channel(channel: np.ndarray, models: Dict[str, Any], *, progress_cb: ProgressCB | None = None) -> np.ndarray:
    """
    Denoise one 2-D plane with the CosmicClarity network, tile by tile.

    The plane is cut into 256x256 tiles that overlap by 64 px. Each tile is
    replicated to 3 channels and run through the model: ONNX Runtime when
    ``models["is_onnx"]`` is true, torch otherwise. Channel 0 of the output
    is kept. Tiles are stitched with stitch_chunks_ignore_border(), which
    drops up to 16 px at each tile edge. Pixels within 16 px of the plane's
    own edge therefore come back as 0, so callers pad first (denoise_rgb01()
    uses add_border(..., 16)).

    Args:
        channel: 2-D plane (values presumably in [0, 1]).
        models: bundle returned by load_models().
        progress_cb: optional ``cb(done, total)`` called after every tile.

    Returns:
        float32 array with ``channel.shape``.
    """
    device = models["device"]
    is_onnx = models["is_onnx"]
    model = models["mono_model"]
    torch = None if is_onnx else models["torch"]

    chunk_size = 256
    overlap = 64
    chunks = split_image_into_chunks_with_overlap(channel, chunk_size=chunk_size, overlap=overlap)
    total = len(chunks)

    # Loop-invariant: look the ONNX input name up once instead of per tile.
    input_name = model.get_inputs()[0].name if is_onnx else None

    denoised_chunks = []
    for done, (chunk, i, j) in enumerate(chunks, start=1):
        h, w = chunk.shape[:2]

        if is_onnx:
            # Always feed a full chunk_size tile (partial edge tiles are
            # zero-padded), replicating the plane into all 3 input channels.
            tile = np.zeros((1, 3, chunk_size, chunk_size), dtype=np.float32)
            tile[:, :, :h, :w] = chunk
            out = model.run(None, {input_name: tile})[0]
        else:
            tile = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
            tile = tile.expand(1, 3, tile.shape[2], tile.shape[3])
            with torch.no_grad(), _autocast_context(torch, device):
                # .float(): autocast may yield fp16; stitch in float32.
                out = model(tile).detach().float().cpu().numpy()  # (1,3,H,W)

        denoised_chunk = np.asarray(out[0, 0, :h, :w], dtype=np.float32)
        denoised_chunks.append((denoised_chunk, i, j))

        if progress_cb is not None:
            progress_cb(done, total)

    return stitch_chunks_ignore_border(denoised_chunks, channel.shape, border_size=16)
506
+
507
+ # ----------------------------
508
+ # High-level denoise for a loaded RGB float image (0..1)
509
+ # (this is the “engine API” SASpro will call)
510
+ # ----------------------------
511
def denoise_rgb01(
    img_rgb01: np.ndarray,
    *,
    denoise_strength: float,
    denoise_mode: str = "luminance",  # luminance | full | separate
    separate_channels: bool = False,
    color_denoise_strength: Optional[float] = None,
    use_gpu: bool = True,
    progress_cb=None,
    status_cb=print,
) -> np.ndarray:
    """
    Engine entry point: denoise an RGB float image in [0, 1].

    Modes:
        * ``"separate"`` (or ``separate_channels=True``): each RGB channel
          goes through the network independently.
        * ``"luminance"``: only BT.601 luma goes through the network; chroma
          is left untouched.
        * anything else (``"full"``): luma through the network, and Cb/Cr
          smoothed by denoise_chroma() with its guided method, at
          ``color_denoise_strength`` (defaults to ``denoise_strength``).

    In every mode the network result is blended with its input by
    ``denoise_strength``. Images whose min-subtracted median is below 0.05
    are stretched before processing and unstretched afterwards.

    Args:
        img_rgb01: ``HxWx3`` float image in [0, 1].
        progress_cb: forwarded to denoise_channel(); it restarts from 1 for
            each denoised plane.
        status_cb: forwarded to load_models() for backend messages.

    Returns:
        float32 ``HxWx3`` image in [0, 1] with the input's height and width.
    """
    models = load_models(use_gpu=use_gpu, status_cb=status_cb)

    # Faint data (low median above the floor) is stretched first — presumably
    # because the network was trained on stretched images — and unstretched
    # at the end.
    stretch_needed = (np.median(img_rgb01 - np.min(img_rgb01)) < 0.05)
    if stretch_needed:
        work, original_min, original_medians = stretch_image(img_rgb01)
    else:
        work = img_rgb01.astype(np.float32, copy=False)

    # Pad by the same 16 px that denoise_channel() discards at tile edges.
    padded = add_border(work, border_size=16)

    if separate_channels or denoise_mode == "separate":
        planes = []
        for c in range(3):
            plane = padded[..., c]
            denoised = denoise_channel(plane, models, progress_cb=progress_cb)
            planes.append(blend_images(plane, denoised, denoise_strength))
        den = np.stack(planes, axis=-1)
    else:
        y, cb, cr = extract_luminance(padded)
        den_y = denoise_channel(y, models, progress_cb=progress_cb)
        y2 = blend_images(y, den_y, denoise_strength)

        if denoise_mode == "luminance":
            den = merge_luminance(y2, cb, cr)
        else:
            # full: chroma via guided filter, guided by the un-denoised luma y.
            cs = denoise_strength if color_denoise_strength is None else color_denoise_strength
            cb2, cr2 = denoise_chroma(cb, cr, strength=cs, method="guided", guide_y=y)
            den = merge_luminance(y2, cb2, cr2)

    if stretch_needed:
        den = unstretch_image(den, original_medians, original_min)

    den = remove_border(den, border_size=16)
    return np.clip(den, 0.0, 1.0).astype(np.float32, copy=False)