setiastrosuitepro 1.7.5.post1__py3-none-any.whl → 1.8.0.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of setiastrosuitepro might be problematic.
Files changed (27)
  1. setiastro/saspro/_generated/build_info.py +2 -2
  2. setiastro/saspro/accel_installer.py +21 -8
  3. setiastro/saspro/accel_workers.py +11 -12
  4. setiastro/saspro/comet_stacking.py +113 -85
  5. setiastro/saspro/cosmicclarity.py +604 -826
  6. setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
  7. setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
  8. setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
  9. setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
  10. setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
  11. setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
  12. setiastro/saspro/gui/main_window.py +14 -0
  13. setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
  14. setiastro/saspro/model_manager.py +324 -0
  15. setiastro/saspro/model_workers.py +102 -0
  16. setiastro/saspro/ops/benchmark.py +320 -0
  17. setiastro/saspro/ops/settings.py +407 -10
  18. setiastro/saspro/remove_stars.py +424 -442
  19. setiastro/saspro/resources.py +73 -10
  20. setiastro/saspro/runtime_torch.py +107 -22
  21. setiastro/saspro/signature_insert.py +14 -3
  22. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
  23. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
  24. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
  25. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
  26. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
  27. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
@@ -0,0 +1,587 @@
+ # src/setiastro/saspro/cosmicclarity_engines/sharpen_engine.py
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Callable, Optional, Any
+ import os
+ import numpy as np
+ from setiastro.saspro.resources import get_resources
+
+
+ # Optional deps: sep for auto-PSF, onnxruntime for the DirectML fallback
+ try:
+     import sep
+ except Exception:
+     sep = None
+
+ try:
+     import onnxruntime as ort
+ except Exception:
+     ort = None
+
+
+ ProgressCB = Callable[[int, int, str], bool]  # True=continue, False=cancel
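Any callable matching ProgressCB receives (done, total, stage) and returns False to request cancellation. A minimal console reporter, as a sketch (only the signature above is required by the engine):

    def console_progress(done: int, total: int, stage: str) -> bool:
        # Coarse percentage display; always continue.
        pct = int(100 * done / max(total, 1))
        print(f"{stage}: {pct}%")
        return True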
+
+
+ # ---------------- Torch runtime helpers (needed for the .pth path) ----------------
+ def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
+     from setiastro.saspro.runtime_torch import import_torch
+     return import_torch(
+         prefer_cuda=prefer_cuda,
+         prefer_xpu=False,
+         prefer_dml=prefer_dml,
+         status_cb=status_cb,
+     )
+
+
+ def _nullcontext():
+     from contextlib import nullcontext
+     return nullcontext()
+
+
+ def _autocast_context(torch, device) -> Any:
+     """
+     Use the newer torch.amp.autocast('cuda') API when available.
+     Autocast is only enabled on CUDA devices with compute capability >= 8.0.
+     """
+     try:
+         if hasattr(device, "type") and device.type == "cuda":
+             major, minor = torch.cuda.get_device_capability()
+             cap = float(f"{major}.{minor}")
+             if cap >= 8.0:
+                 # Preferred API (torch >= 1.10-ish; definitely in 2.x)
+                 if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
+                     return torch.amp.autocast(device_type="cuda")
+                 # Fallback for older torch
+                 return torch.cuda.amp.autocast()
+     except Exception:
+         pass
+     return _nullcontext()
+
+
+ def _to_3ch(image: np.ndarray) -> tuple[np.ndarray, bool]:
+     """Return (img3, was_mono). img3 is HxWx3 float32."""
+     if image.ndim == 2:
+         img3 = np.stack([image, image, image], axis=-1)
+         return img3, True
+     if image.ndim == 3 and image.shape[2] == 1:
+         img = image[..., 0]
+         img3 = np.stack([img, img, img], axis=-1)
+         return img3, True
+     return image, False
+
+
+ # BT.601 luminance extraction / merge
+ def extract_luminance_rgb(image_rgb: np.ndarray):
+     image_rgb = np.asarray(image_rgb, dtype=np.float32)
+     if image_rgb.shape[-1] != 3:
+         raise ValueError("extract_luminance_rgb expects HxWx3")
+     M = np.array([[0.299, 0.587, 0.114],
+                   [-0.168736, -0.331264, 0.5],
+                   [0.5, -0.418688, -0.081312]], dtype=np.float32)
+     ycbcr = image_rgb @ M.T
+     y = ycbcr[..., 0]
+     cb = ycbcr[..., 1] + 0.5
+     cr = ycbcr[..., 2] + 0.5
+     return y, cb, cr
+
+
+ def ycbcr_to_rgb(y: np.ndarray, cb: np.ndarray, cr: np.ndarray) -> np.ndarray:
+     y = np.asarray(y, np.float32)
+     cb = np.asarray(cb, np.float32) - 0.5
+     cr = np.asarray(cr, np.float32) - 0.5
+     ycbcr = np.stack([y, cb, cr], axis=-1)
+     M = np.array([[1.0, 0.0, 1.402],
+                   [1.0, -0.344136, -0.714136],
+                   [1.0, 1.772, 0.0]], dtype=np.float32)
+     rgb = ycbcr @ M.T
+     return np.clip(rgb, 0.0, 1.0)
+
+
+ def merge_luminance(y: np.ndarray, cb: np.ndarray, cr: np.ndarray) -> np.ndarray:
+     return ycbcr_to_rgb(np.clip(y, 0, 1), np.clip(cb, 0, 1), np.clip(cr, 0, 1))
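The two matrices above are inverses of each other (modulo the ±0.5 chroma offset), so extraction followed by merging is an identity for in-gamut RGB. A quick sanity sketch:

    rng = np.random.default_rng(0)
    rgb = rng.random((8, 8, 3), dtype=np.float32)
    y, cb, cr = extract_luminance_rgb(rgb)
    back = merge_luminance(y, cb, cr)
    assert np.allclose(back, rgb, atol=1e-3)  # tiny error from float32 matmuls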
+
+
+ # ---------------- Chunking & stitching ----------------
+
+ def split_image_into_chunks_with_overlap(image2d: np.ndarray, chunk_size: int, overlap: int):
+     H, W = image2d.shape
+     step = chunk_size - overlap
+     out = []
+     for i in range(0, H, step):
+         for j in range(0, W, step):
+             ei = min(i + chunk_size, H)
+             ej = min(j + chunk_size, W)
+             chunk = image2d[i:ei, j:ej]
+             is_edge = (i == 0) or (j == 0) or (i + chunk_size >= H) or (j + chunk_size >= W)
+             out.append((chunk, i, j, is_edge))
+     return out
+
+
+ def stitch_chunks_ignore_border(chunks, image_shape, border_size: int = 16):
+     stitched = np.zeros(image_shape, dtype=np.float32)
+     weights = np.zeros(image_shape, dtype=np.float32)
+
+     for chunk, i, j, _is_edge in chunks:
+         h, w = chunk.shape
+         bh = min(border_size, h // 2)
+         bw = min(border_size, w // 2)
+         inner = chunk[bh:h-bh, bw:w-bw]
+         stitched[i+bh:i+h-bh, j+bw:j+w-bw] += inner
+         weights[i+bh:i+h-bh, j+bw:j+w-bw] += 1.0
+
+     stitched /= np.maximum(weights, 1.0)
+     return stitched
+
+
+ def add_border(image: np.ndarray, border_size: int = 16) -> np.ndarray:
+     med = float(np.median(image))
+     if image.ndim == 2:
+         return np.pad(image, ((border_size, border_size), (border_size, border_size)),
+                       mode="constant", constant_values=med)
+     return np.pad(image, ((border_size, border_size), (border_size, border_size), (0, 0)),
+                   mode="constant", constant_values=med)
+
+
+ def remove_border(image: np.ndarray, border_size: int = 16) -> np.ndarray:
+     if image.ndim == 2:
+         return image[border_size:-border_size, border_size:-border_size]
+     return image[border_size:-border_size, border_size:-border_size, :]
+
+
+ def blend_images(before: np.ndarray, after: np.ndarray, amount: float) -> np.ndarray:
+     a = float(np.clip(amount, 0.0, 1.0))
+     return (1.0 - a) * before + a * after
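With chunk_size=256 and overlap=64 the tile origins advance by step = 256 - 64 = 192, so a hypothetical 1000x1500 plane splits into ceil(1000/192) x ceil(1500/192) = 6 x 8 = 48 tiles, and stitching restores the original shape:

    plane = np.zeros((1000, 1500), dtype=np.float32)
    tiles = split_image_into_chunks_with_overlap(plane, chunk_size=256, overlap=64)
    assert len(tiles) == 48
    rebuilt = stitch_chunks_ignore_border(tiles, plane.shape, border_size=16)
    assert rebuilt.shape == plane.shape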
+
+
+ # ---------------- Stretch / unstretch ----------------
+
+ def stretch_image_unlinked_rgb(image_rgb: np.ndarray, target_median: float = 0.25):
+     x = image_rgb.astype(np.float32, copy=True)
+     orig_min = float(np.min(x))
+     x -= orig_min
+     orig_meds = [float(np.median(x[..., c])) for c in range(3)]
+
+     for c in range(3):
+         m = orig_meds[c]
+         if m != 0:
+             x[..., c] = ((m - 1) * target_median * x[..., c]) / (
+                 m * (target_median + x[..., c] - 1) - target_median * x[..., c]
+             )
+     x = np.clip(x, 0, 1)
+     return x, orig_min, orig_meds
+
+
+ def unstretch_image_unlinked_rgb(image_rgb: np.ndarray, orig_meds, orig_min: float, was_mono: bool):
+     x = image_rgb.astype(np.float32, copy=True)
+     for c in range(3):
+         m_now = float(np.median(x[..., c]))
+         m0 = float(orig_meds[c])
+         if m_now != 0 and m0 != 0:
+             x[..., c] = ((m_now - 1) * m0 * x[..., c]) / (
+                 m_now * (m0 + x[..., c] - 1) - m0 * x[..., c]
+             )
+     x += float(orig_min)
+     x = np.clip(x, 0, 1)
+     if was_mono:
+         # mono output keeps a trailing channel axis (HxWx1)
+         x = np.mean(x, axis=2, keepdims=True)
+     return x
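This is the classic midtones transfer function applied per channel: substituting x = m into the formula yields target_median exactly, and the unstretch applies the same function with the medians swapped, making it an exact inverse up to clipping. A sketch of the round trip:

    rng = np.random.default_rng(1)
    x = rng.uniform(0.0, 0.1, size=(64, 64, 3)).astype(np.float32)
    s, mn, meds = stretch_image_unlinked_rgb(x, target_median=0.25)
    # each channel's median now sits at ~0.25
    back = unstretch_image_unlinked_rgb(s, meds, mn, was_mono=False)
    assert np.allclose(back, x, atol=1e-3)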
+
+
+ # ---------------- Auto PSF per chunk (SEP) ----------------
+
+ def measure_psf_fwhm(chunk2d: np.ndarray, default_fwhm: float = 3.0) -> float:
+     if sep is None:
+         return default_fwhm
+     try:
+         data = chunk2d.astype(np.float32, copy=False)
+         bkg = sep.Background(data)
+         sub = data - bkg.back()
+         rms = bkg.rms()
+         if rms.size == 0:
+             return default_fwhm
+         objs = sep.extract(sub, 1.5, err=rms)
+         fwhms = []
+         for o in objs:
+             if o["npix"] < 5:
+                 continue
+             sigma = float(np.sqrt(o["a"] * o["b"]))
+             fwhm = sigma * 2.0 * np.sqrt(2.0 * np.log(2.0))
+             fwhms.append(fwhm)
+         return float(np.median(fwhms) * 0.5) if fwhms else default_fwhm
+     except Exception:
+         return default_fwhm
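For a Gaussian profile, FWHM = 2*sqrt(2*ln 2)*sigma ≈ 2.3548*sigma, with sigma taken here as the geometric mean of SEP's semi-axes a and b; the median FWHM is then halved before the caller clamps it to the 1-8 px range of the non-stellar models. For example, a ≈ b ≈ 1.7 px gives FWHM ≈ 4.0 px and a returned value of ≈ 2.0.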
+
+
+ # ---------------- Model bundle + loading ----------------
+
+ @dataclass
+ class SharpenModels:
+     device: Any  # torch.device or "DirectML"
+     is_onnx: bool
+     stellar: Any
+     ns1: Any
+     ns2: Any
+     ns4: Any
+     ns8: Any
+     torch: Any | None = None  # set for torch path
+
+
+ _MODELS_CACHE: dict[tuple[str, bool], SharpenModels] = {}  # (backend_tag, use_gpu)
+
+
+ def load_sharpen_models(use_gpu: bool, status_cb=print) -> SharpenModels:
+     backend_tag = "cc_sharpen_ai3_5s"
+     key = (backend_tag, bool(use_gpu))
+     if key in _MODELS_CACHE:
+         return _MODELS_CACHE[key]
+
+     is_windows = (os.name == "nt")
+
+     # ask runtime to prefer DML only on Windows + when user wants GPU
+     torch = _get_torch(
+         prefer_cuda=bool(use_gpu),                 # still try CUDA first
+         prefer_dml=bool(use_gpu and is_windows),   # enable DML install/usage path
+         status_cb=status_cb,
+     )
+
+     # 1) CUDA
+     if use_gpu and hasattr(torch, "cuda") and torch.cuda.is_available():
+         device = torch.device("cuda")
+         status_cb(f"CosmicClarity Sharpen: using CUDA ({torch.cuda.get_device_name(0)})")
+         models = _load_torch_models(torch, device)
+         _MODELS_CACHE[key] = models
+         return models
+
+     # 2) Torch-DirectML (Windows)
+     if use_gpu and is_windows:
+         try:
+             import torch_directml  # provided by torch-directml
+             dml = torch_directml.device()
+             status_cb("CosmicClarity Sharpen: using DirectML (torch-directml)")
+             models = _load_torch_models(torch, dml)
+             _MODELS_CACHE[key] = models
+             status_cb(f"torch.__version__={getattr(torch, '__version__', None)}; "
+                       f"cuda_available={bool(getattr(torch, 'cuda', None) and torch.cuda.is_available())}")
+             return models
+         except Exception:
+             pass
+
+     # 3) ONNX Runtime DirectML fallback
+     if use_gpu and ort is not None and "DmlExecutionProvider" in ort.get_available_providers():
+         status_cb("CosmicClarity Sharpen: using DirectML (ONNX Runtime)")
+         models = _load_onnx_models()
+         _MODELS_CACHE[key] = models
+         return models
+
+     # 4) CPU
+     device = torch.device("cpu")
+     status_cb("CosmicClarity Sharpen: using CPU")
+     models = _load_torch_models(torch, device)
+     _MODELS_CACHE[key] = models
+     status_cb(f"Sharpen backend resolved: "
+               f"{'onnx' if models.is_onnx else 'torch'} / device={models.device!r}")
+
+     return models
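The resolution order is therefore CUDA, then torch-directml, then ONNX Runtime DirectML, then CPU, with the resulting bundle cached per (backend_tag, use_gpu). A hypothetical call that routes status lines into a logger:

    import logging
    log = logging.getLogger("saspro.sharpen")
    models = load_sharpen_models(use_gpu=True, status_cb=log.info)
    # a second call with the same flag hits the cache
    assert load_sharpen_models(use_gpu=True, status_cb=log.info) is models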
+
+
+ def _load_onnx_models() -> SharpenModels:
+     assert ort is not None
+     prov = ["DmlExecutionProvider"]
+     R = get_resources()
+
+     def s(path: str):
+         return ort.InferenceSession(path, providers=prov)
+
+     return SharpenModels(
+         device="DirectML",
+         is_onnx=True,
+         stellar=s(R.CC_STELLAR_SHARP_ONNX),
+         ns1=s(R.CC_NS1_ONNX),
+         ns2=s(R.CC_NS2_ONNX),
+         ns4=s(R.CC_NS4_ONNX),
+         ns8=s(R.CC_NS8_ONNX),
+         torch=None,
+     )
+
+
+ def _load_torch_models(torch, device) -> SharpenModels:
+     import torch.nn as nn  # comes from runtime torch env
+
+     class ResidualBlock(nn.Module):
+         def __init__(self, channels):
+             super().__init__()
+             self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
+             self.relu = nn.ReLU()
+             self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
+
+         def forward(self, x):
+             r = x
+             x = self.relu(self.conv1(x))
+             x = self.conv2(x)
+             x = self.relu(x + r)
+             return x
+
+     class SharpeningCNN(nn.Module):
+         def __init__(self):
+             super().__init__()
+             self.encoder1 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), ResidualBlock(16))
+             self.encoder2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), ResidualBlock(32))
+             self.encoder3 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=2, dilation=2), nn.ReLU(), ResidualBlock(64))
+             self.encoder4 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), ResidualBlock(128))
+             self.encoder5 = nn.Sequential(nn.Conv2d(128, 256, 3, padding=2, dilation=2), nn.ReLU(), ResidualBlock(256))
+
+             self.decoder5 = nn.Sequential(nn.Conv2d(256 + 128, 128, 3, padding=1), nn.ReLU(), ResidualBlock(128))
+             self.decoder4 = nn.Sequential(nn.Conv2d(128 + 64, 64, 3, padding=1), nn.ReLU(), ResidualBlock(64))
+             self.decoder3 = nn.Sequential(nn.Conv2d( 64 + 32, 32, 3, padding=1), nn.ReLU(), ResidualBlock(32))
+             self.decoder2 = nn.Sequential(nn.Conv2d( 32 + 16, 16, 3, padding=1), nn.ReLU(), ResidualBlock(16))
+             self.decoder1 = nn.Sequential(nn.Conv2d(16, 3, 3, padding=1), nn.Sigmoid())
+
+         def forward(self, x):
+             e1 = self.encoder1(x)
+             e2 = self.encoder2(e1)
+             e3 = self.encoder3(e2)
+             e4 = self.encoder4(e3)
+             e5 = self.encoder5(e4)
+             d5 = self.decoder5(torch.cat([e5, e4], dim=1))
+             d4 = self.decoder4(torch.cat([d5, e3], dim=1))
+             d3 = self.decoder3(torch.cat([d4, e2], dim=1))
+             d2 = self.decoder2(torch.cat([d3, e1], dim=1))
+             return self.decoder1(d2)
+
+     R = get_resources()
+
+     def m(path: str):
+         net = SharpeningCNN()
+         net.load_state_dict(torch.load(path, map_location=device))
+         net.eval().to(device)
+         return net
+
+     return SharpenModels(
+         device=device,
+         is_onnx=False,
+         stellar=m(R.CC_STELLAR_SHARP_PTH),
+         ns1=m(R.CC_NS1_PTH),
+         ns2=m(R.CC_NS2_PTH),
+         ns4=m(R.CC_NS4_PTH),
+         ns8=m(R.CC_NS8_PTH),
+         torch=torch,
+     )
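Every convolution is stride-1 with shape-preserving padding (the dilation-2 layers pair with padding=2), so each stage keeps the spatial size and only the skip concatenations change the channel count. A shape check, as a sketch with SharpeningCNN hoisted out of _load_torch_models and a plain torch install assumed:

    import torch
    net = SharpeningCNN()
    with torch.no_grad():
        out = net(torch.zeros(1, 3, 256, 256))
    assert out.shape == (1, 3, 256, 256)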
+
+
+ # ---------------- Inference helpers ----------------
+
+ def _infer_chunk(models: SharpenModels, model: Any, chunk2d: np.ndarray) -> np.ndarray:
+     """Returns 2D float32 (cropped to the original chunk shape)."""
+     h0, w0 = chunk2d.shape
+
+     if models.is_onnx:
+         inp = chunk2d[np.newaxis, np.newaxis, :, :].astype(np.float32)  # (1,1,H,W)
+         inp = np.tile(inp, (1, 3, 1, 1))                                # (1,3,H,W)
+         h, w = inp.shape[2:]
+         if (h != 256) or (w != 256):
+             # the ONNX graphs take fixed 256x256 inputs; zero-pad smaller edge chunks
+             pad = np.zeros((1, 3, 256, 256), dtype=np.float32)
+             pad[:, :, :h, :w] = inp
+             inp = pad
+         name_in = model.get_inputs()[0].name
+         name_out = model.get_outputs()[0].name
+         out = model.run([name_out], {name_in: inp})[0][0, 0]
+         return out[:h0, :w0].astype(np.float32, copy=False)
+
+     # torch path
+     torch = models.torch
+     dev = models.device
+     t = torch.tensor(chunk2d, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(dev)
+     with torch.no_grad(), _autocast_context(torch, dev):
+         y = model(t.repeat(1, 3, 1, 1)).squeeze().detach().cpu().numpy()[0]
+     return y[:h0, :w0].astype(np.float32, copy=False)
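Both backends expect three-channel input, so the grayscale chunk is replicated across channels, and the ONNX path zero-pads edge tiles up to the fixed graph size before cropping back. A hypothetical call on an edge tile (models as returned by load_sharpen_models):

    chunk = np.zeros((200, 180), dtype=np.float32)
    y = _infer_chunk(models, models.stellar, chunk)
    assert y.shape == (200, 180)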
+
+
+ # ---------------- Main API ----------------
+
+ @dataclass
+ class SharpenParams:
+     mode: str                   # "Both" | "Stellar Only" | "Non-Stellar Only"
+     stellar_amount: float       # 0..1
+     nonstellar_amount: float    # 0..1
+     nonstellar_strength: float  # 1..8 (ignored if auto_detect_psf is True)
+     sharpen_channels_separately: bool
+     auto_detect_psf: bool
+     use_gpu: bool
+
+
+ def sharpen_image_array(image: np.ndarray,
+                         params: SharpenParams,
+                         progress_cb: Optional[ProgressCB] = None,
+                         status_cb=print) -> tuple[np.ndarray, bool]:
+     """
+     Pure in-memory sharpen. Returns (out_image, was_mono).
+     """
+     if progress_cb is None:
+         progress_cb = lambda done, total, stage: True
+
+     img = np.asarray(image)
+     if img.dtype != np.float32:
+         img = img.astype(np.float32, copy=False)
+
+     img3, was_mono = _to_3ch(img)
+     img3 = np.clip(img3, 0.0, 1.0)
+
+     models = load_sharpen_models(use_gpu=params.use_gpu, status_cb=status_cb)
+
+     # border & stretch
+     bordered = add_border(img3, border_size=16)
+     stretch_needed = (np.median(bordered - np.min(bordered)) < 0.08)
+
+     if stretch_needed:
+         stretched, orig_min, orig_meds = stretch_image_unlinked_rgb(bordered)
+     else:
+         stretched, orig_min, orig_meds = bordered, None, None
+
+     # per-channel sharpening option (color only)
+     if params.sharpen_channels_separately and (not was_mono):
+         out = np.empty_like(stretched)
+         for c, label in enumerate(("R", "G", "B")):
+             progress_cb(0, 1, f"Sharpening {label} channel")
+             out[..., c] = _sharpen_plane(models, stretched[..., c], params, progress_cb)
+         sharpened = out
+     else:
+         # luminance pipeline (works for mono too, since mono is replicated across all 3 channels)
+         y, cb, cr = extract_luminance_rgb(stretched)
+         y2 = _sharpen_plane(models, y, params, progress_cb)
+         sharpened = merge_luminance(y2, cb, cr)
+
+     # unstretch / remove border
+     if stretch_needed:
+         sharpened = unstretch_image_unlinked_rgb(sharpened, orig_meds, orig_min, was_mono)
+
+     sharpened = remove_border(sharpened, border_size=16)
+
+     # return mono as HxWx1 if it came in mono
+     if was_mono:
+         if sharpened.ndim == 3 and sharpened.shape[2] == 3:
+             sharpened = np.mean(sharpened, axis=2, keepdims=True).astype(np.float32, copy=False)
+
+     return np.clip(sharpened, 0.0, 1.0), was_mono
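A minimal end-to-end call, as a sketch with synthetic mono data (assumes the bundled model weights are installed; console_progress is the example callback from above):

    img = np.random.default_rng(2).random((512, 512), dtype=np.float32)
    params = SharpenParams(mode="Both", stellar_amount=0.5, nonstellar_amount=0.5,
                           nonstellar_strength=3.0, sharpen_channels_separately=False,
                           auto_detect_psf=True, use_gpu=False)
    out, was_mono = sharpen_image_array(img, params, progress_cb=console_progress)
    assert was_mono and out.shape == (512, 512, 1)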
+
+
+ def _sharpen_plane(models: SharpenModels,
+                    plane: np.ndarray,
+                    params: SharpenParams,
+                    progress_cb: ProgressCB) -> np.ndarray:
+     """
+     Sharpen a single 2D plane using the two-stage (stellar, then non-stellar) pipeline.
+     """
+     plane = np.asarray(plane, np.float32)
+     chunks = split_image_into_chunks_with_overlap(plane, chunk_size=256, overlap=64)
+     total = len(chunks)
+
+     # Stage 1: stellar
+     if params.mode in ("Stellar Only", "Both"):
+         out_chunks = []
+         for k, (chunk, i, j, is_edge) in enumerate(chunks, start=1):
+             y = _infer_chunk(models, models.stellar, chunk)
+             blended = blend_images(chunk, y, params.stellar_amount)
+             out_chunks.append((blended, i, j, is_edge))
+             if progress_cb(k, total, "Stellar sharpening") is False:
+                 return plane
+         plane = stitch_chunks_ignore_border(out_chunks, plane.shape, border_size=16)
+
+     if params.mode == "Stellar Only":
+         return plane
+
+     # re-chunk the stitched plane for stage 2
+     chunks = split_image_into_chunks_with_overlap(plane, chunk_size=256, overlap=64)
+     total = len(chunks)
+
+     # Stage 2: non-stellar
+     if params.mode in ("Non-Stellar Only", "Both"):
+         out_chunks = []
+         radii = np.array([1.0, 2.0, 4.0, 8.0], dtype=float)
+         model_map = {1.0: models.ns1, 2.0: models.ns2, 4.0: models.ns4, 8.0: models.ns8}
+
+         for k, (chunk, i, j, is_edge) in enumerate(chunks, start=1):
+             if params.auto_detect_psf:
+                 fwhm = measure_psf_fwhm(chunk, default_fwhm=3.0)
+                 r = float(np.clip(fwhm, radii[0], radii[-1]))
+             else:
+                 r = float(np.clip(params.nonstellar_strength, radii[0], radii[-1]))
+
+             idx = int(np.searchsorted(radii, r, side="left"))
+             if idx <= 0:
+                 lo = hi = radii[0]
+             elif idx >= len(radii):
+                 lo = hi = radii[-1]
+             else:
+                 lo, hi = radii[idx-1], radii[idx]
+
+             if lo == hi:
+                 y = _infer_chunk(models, model_map[lo], chunk)
+             else:
+                 w = (r - lo) / (hi - lo)
+                 y0 = _infer_chunk(models, model_map[lo], chunk)
+                 y1 = _infer_chunk(models, model_map[hi], chunk)
+                 y = (1.0 - w) * y0 + w * y1
+
+             blended = blend_images(chunk, y, params.nonstellar_amount)
+             out_chunks.append((blended, i, j, is_edge))
+             if progress_cb(k, total, "Non-stellar sharpening") is False:
+                 return plane
+
+         plane = stitch_chunks_ignore_border(out_chunks, plane.shape, border_size=16)
+
+     return plane
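For a radius between two model strengths the two predictions are linearly blended: r = 3.0 falls between radii 2.0 and 4.0, so w = (3 - 2) / (4 - 2) = 0.5 and the result is an equal mix of the ns2 and ns4 outputs. The index arithmetic, in sketch form:

    radii = np.array([1.0, 2.0, 4.0, 8.0])
    r = 3.0
    idx = int(np.searchsorted(radii, r, side="left"))  # -> 2
    lo, hi = radii[idx - 1], radii[idx]                # -> 2.0, 4.0
    w = (r - lo) / (hi - lo)                           # -> 0.5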
+
+
+ def sharpen_rgb01(
+     image_rgb01: np.ndarray,
+     *,
+     sharpening_mode: str = "Both",
+     stellar_amount: float = 0.5,
+     nonstellar_amount: float = 0.5,
+     nonstellar_strength: float = 3.0,
+     auto_detect_psf: bool = True,
+     separate_channels: bool = False,
+     use_gpu: bool = True,
+     progress_cb: Optional[Callable[[int, int], bool]] = None,
+     status_cb=print,
+ ) -> np.ndarray:
+     """
+     Backward-compatible API for SASpro's CosmicClarityDialogPro.
+     Expects/returns float32 RGB in [0,1]. If the input is mono, returns HxWx1.
+     progress_cb signature: (done, total) -> bool (return True to continue).
+     """
+     # Adapt UI progress_cb(done, total) -> engine progress_cb(done, total, stage)
+     if progress_cb is None:
+         def _prog(done, total, stage):  # noqa
+             return True
+     else:
+         def _prog(done, total, stage):
+             try:
+                 return bool(progress_cb(int(done), int(total)))
+             except Exception:
+                 return True
+
+     params = SharpenParams(
+         mode=str(sharpening_mode),
+         stellar_amount=float(stellar_amount),
+         nonstellar_amount=float(nonstellar_amount),
+         nonstellar_strength=float(nonstellar_strength),
+         sharpen_channels_separately=bool(separate_channels),
+         auto_detect_psf=bool(auto_detect_psf),
+         use_gpu=bool(use_gpu),
+     )
+
+     out, _was_mono = sharpen_image_array(
+         image_rgb01,
+         params=params,
+         progress_cb=_prog,
+         status_cb=status_cb,
+     )
+     return np.asarray(out, dtype=np.float32)
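A sketch of the wrapper driven from a UI-style two-argument callback, with a hypothetical cancel flag (again assuming the model weights are installed):

    cancelled = False

    def ui_progress(done: int, total: int) -> bool:
        return not cancelled  # returning False aborts the current stage

    rgb = np.random.default_rng(3).random((256, 256, 3), dtype=np.float32)
    out = sharpen_rgb01(rgb, use_gpu=False, progress_cb=ui_progress)
    assert out.shape == rgb.shape and out.dtype == np.float32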