setiastrosuitepro 1.7.5.post1__py3-none-any.whl → 1.8.0.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of setiastrosuitepro might be problematic. Click here for more details.

Files changed (27):
  1. setiastro/saspro/_generated/build_info.py +2 -2
  2. setiastro/saspro/accel_installer.py +21 -8
  3. setiastro/saspro/accel_workers.py +11 -12
  4. setiastro/saspro/comet_stacking.py +113 -85
  5. setiastro/saspro/cosmicclarity.py +604 -826
  6. setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
  7. setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
  8. setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
  9. setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
  10. setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
  11. setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
  12. setiastro/saspro/gui/main_window.py +14 -0
  13. setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
  14. setiastro/saspro/model_manager.py +324 -0
  15. setiastro/saspro/model_workers.py +102 -0
  16. setiastro/saspro/ops/benchmark.py +320 -0
  17. setiastro/saspro/ops/settings.py +407 -10
  18. setiastro/saspro/remove_stars.py +424 -442
  19. setiastro/saspro/resources.py +73 -10
  20. setiastro/saspro/runtime_torch.py +107 -22
  21. setiastro/saspro/signature_insert.py +14 -3
  22. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
  23. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
  24. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
  25. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
  26. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
  27. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
@@ -0,0 +1,576 @@
1
+ # src/setiastro/saspro/cosmicclarity_engines/darkstar_engine.py
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, Optional
6
+
7
+ import numpy as np
8
+
9
+ from setiastro.saspro.resources import get_resources
10
+
11
+ # Optional deps
12
+ try:
13
+ import onnxruntime as ort
14
+ except Exception:
15
+ ort = None
16
+
17
+ ProgressCB = Callable[[int, int, str], None] # (done, total, stage)
18
+
19
+
20
+ # ---------------- Torch import (your existing runtime_torch helper) ----------------
21
+
22
def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
    """Import torch through the shared runtime helper, honoring backend preferences.

    XPU is never requested by this engine; CUDA/DirectML preference is passed
    straight through to ``runtime_torch.import_torch``.
    """
    from setiastro.saspro.runtime_torch import import_torch

    prefs = dict(prefer_cuda=prefer_cuda, prefer_xpu=False, prefer_dml=prefer_dml)
    return import_torch(status_cb=status_cb, **prefs)
30
+
31
+
32
+ def _nullcontext():
33
+ from contextlib import nullcontext
34
+ return nullcontext()
35
+
36
+
37
+ def _autocast_context(torch, device) -> Any:
38
+ """
39
+ Use new torch.amp.autocast('cuda') when available.
40
+ Keep your cap >= 8.0 rule.
41
+ """
42
+ try:
43
+ if hasattr(device, "type") and device.type == "cuda":
44
+ major, minor = torch.cuda.get_device_capability()
45
+ cap = float(f"{major}.{minor}")
46
+ if cap >= 8.0:
47
+ # Preferred API (torch >= 1.10-ish; definitely in 2.x)
48
+ if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
49
+ return torch.amp.autocast(device_type="cuda")
50
+ # Fallback for older torch
51
+ return torch.cuda.amp.autocast()
52
+ except Exception:
53
+ pass
54
+ return _nullcontext()
55
+
56
+
57
+ # ---------------- Models (same topology as your script) ----------------
58
+
59
def _build_darkstar_torch_models(torch):
    """
    Build and return the ``CascadedStarRemovalNetCombined`` class bound to the
    given ``torch`` module.

    Defined inside a factory so ``torch.nn`` is only imported when a PyTorch
    backend is actually selected (ONNX-only runs never touch torch).
    """
    import torch.nn as nn

    class RefinementCNN(nn.Module):
        # Dilated-conv refinement head (3 -> channels -> 3). Kept for
        # checkpoint compatibility; the combined net currently returns only
        # the coarse stage-1 output (see CascadedStarRemovalNetCombined.forward).
        def __init__(self, channels: int = 96):
            super().__init__()
            # Dilation grows 1->2->4->8 then shrinks back, giving a wide
            # receptive field at constant spatial resolution.
            self.net = nn.Sequential(
                nn.Conv2d(3, channels, 3, padding=1, dilation=1), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=2, dilation=2), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=4, dilation=4), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=8, dilation=8), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=8, dilation=8), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=4, dilation=4), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=2, dilation=2), nn.ReLU(),
                nn.Conv2d(channels, 3, 3, padding=1, dilation=1), nn.Sigmoid()
            )
        def forward(self, x): return self.net(x)

    class ResidualBlock(nn.Module):
        # Standard 2-conv residual block; note the skip is added *before*
        # the final ReLU (post-activation residual).
        def __init__(self, channels: int):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        def forward(self, x):
            out = self.relu(self.conv1(x))
            out = self.conv2(out)
            return self.relu(out + x)

    class DarkStarCNN(nn.Module):
        # Encoder/decoder star-removal net. There is no down/upsampling:
        # every stage runs at full resolution, and each decoder stage
        # concatenates the matching encoder feature map along channels.
        # Layer names/shapes must stay exactly as-is to match the shipped
        # checkpoint's state_dict.
        def __init__(self):
            super().__init__()
            self.encoder1 = nn.Sequential(
                nn.Conv2d(3, 16, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(16), ResidualBlock(16), ResidualBlock(16),
            )
            self.encoder2 = nn.Sequential(
                nn.Conv2d(16, 32, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(32), ResidualBlock(32), ResidualBlock(32),
            )
            self.encoder3 = nn.Sequential(
                nn.Conv2d(32, 64, 3, padding=2, dilation=2),
                nn.ReLU(inplace=True),
                ResidualBlock(64), ResidualBlock(64),
            )
            self.encoder4 = nn.Sequential(
                nn.Conv2d(64, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(128), ResidualBlock(128),
            )
            self.encoder5 = nn.Sequential(
                nn.Conv2d(128, 256, 3, padding=2, dilation=2),
                nn.ReLU(inplace=True),
                ResidualBlock(256),
            )

            self.decoder5 = nn.Sequential(
                nn.Conv2d(256 + 128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(128), ResidualBlock(128),
            )
            self.decoder4 = nn.Sequential(
                nn.Conv2d(128 + 64, 64, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(64), ResidualBlock(64),
            )
            self.decoder3 = nn.Sequential(
                nn.Conv2d(64 + 32, 32, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(32), ResidualBlock(32), ResidualBlock(32),
            )
            self.decoder2 = nn.Sequential(
                nn.Conv2d(32 + 16, 16, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(16), ResidualBlock(16), ResidualBlock(16),
            )
            self.decoder1 = nn.Sequential(
                nn.Conv2d(16, 16, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(16), ResidualBlock(16),
                nn.Conv2d(16, 3, 3, padding=1),
                nn.Sigmoid(),
            )

        def forward(self, x):
            e1 = self.encoder1(x)
            e2 = self.encoder2(e1)
            e3 = self.encoder3(e2)
            e4 = self.encoder4(e3)
            e5 = self.encoder5(e4)

            # Skip connections: each decoder sees the previous decoder output
            # concatenated with the encoder stage one level shallower.
            d5 = self.decoder5(torch.cat([e5, e4], dim=1))
            d4 = self.decoder4(torch.cat([d5, e3], dim=1))
            d3 = self.decoder3(torch.cat([d4, e2], dim=1))
            d2 = self.decoder2(torch.cat([d3, e1], dim=1))
            return self.decoder1(d2)

    class CascadedStarRemovalNetCombined(nn.Module):
        """Stage-1 DarkStarCNN (always loaded) plus an optional refinement
        stage whose weights are loaded best-effort but not currently applied.

        :param stage1_path: checkpoint path for the coarse DarkStarCNN.
        :param stage2_path: optional checkpoint path for the RefinementCNN.
        """
        def __init__(self, stage1_path: str, stage2_path: str | None = None):
            super().__init__()
            self.stage1 = DarkStarCNN()
            ckpt1 = torch.load(stage1_path, map_location="cpu")

            # strip "stage1." prefix if present (combined-checkpoint layout)
            if isinstance(ckpt1, dict):
                sd1 = {k[len("stage1."):] : v for k, v in ckpt1.items() if k.startswith("stage1.")}
                if sd1:
                    ckpt1 = sd1
            self.stage1.load_state_dict(ckpt1)

            # refinement exists in your code but currently not used (forward returns coarse)
            self.stage2 = RefinementCNN()
            if stage2_path:
                try:
                    ckpt2 = torch.load(stage2_path, map_location="cpu")
                    if isinstance(ckpt2, dict) and "model_state" in ckpt2:
                        ckpt2 = ckpt2["model_state"]
                    self.stage2.load_state_dict(ckpt2)
                except Exception:
                    # Best-effort: a missing/incompatible stage-2 checkpoint
                    # must not prevent stage-1 inference.
                    pass

            # Inference-only model: freeze stage-1 weights.
            for p in self.stage1.parameters():
                p.requires_grad = False

        def forward(self, x):
            with torch.no_grad():
                coarse = self.stage1(x)
            return coarse

    return CascadedStarRemovalNetCombined
191
+
192
+
193
+ # ---------------- Stretch/unstretch + borders (match your other engines) ----------------
194
+
195
def add_border(image: np.ndarray, border_size: int = 5) -> np.ndarray:
    """Pad every edge of the image with its median value.

    2-D inputs are padded with the global median; HxWx3 inputs are padded
    per channel with that channel's median. Any other shape raises.
    """
    pad_spec = ((border_size, border_size), (border_size, border_size))
    if image.ndim == 2:
        fill = float(np.median(image))
        return np.pad(image, pad_spec, mode="constant", constant_values=fill)
    if image.ndim == 3 and image.shape[2] == 3:
        meds = np.median(image, axis=(0, 1)).astype(np.float32)
        padded = [
            np.pad(image[..., c], pad_spec, mode="constant",
                   constant_values=float(meds[c]))
            for c in range(3)
        ]
        return np.stack(padded, axis=-1)
    raise ValueError("add_border expects 2D or HxWx3")
208
+
209
+
210
def remove_border(image: np.ndarray, border_size: int = 5) -> np.ndarray:
    """
    Crop ``border_size`` pixels from every edge (inverse of ``add_border``).

    Handles ``border_size <= 0`` explicitly: the naive ``[0:-0]`` slice would
    return an *empty* array instead of the unchanged image.
    """
    if border_size <= 0:
        return image
    if image.ndim == 2:
        return image[border_size:-border_size, border_size:-border_size]
    return image[border_size:-border_size, border_size:-border_size, :]
214
+
215
+
216
def stretch_image_unlinked_rgb(img_rgb: np.ndarray, target_median: float = 0.25):
    """Per-channel ("unlinked") midtones stretch toward ``target_median``.

    Each channel is pedestal-subtracted by its own minimum, then remapped so
    its median lands at ``target_median``. Returns the stretched image plus
    the per-channel minima and medians needed to invert the transform with
    ``unstretch_image_unlinked_rgb``.
    """
    out = img_rgb.astype(np.float32, copy=True)
    ch_min = out.reshape(-1, 3).min(axis=0)  # per-channel pedestal, shape (3,)
    out = (out - ch_min.reshape(1, 1, 3))
    ch_med = np.median(out, axis=(0, 1)).astype(np.float32)

    for c in range(3):
        m = float(ch_med[c])
        if m != 0:
            # Midtones transfer function: maps m -> target_median.
            out[..., c] = ((m - 1) * target_median * out[..., c]) / (
                m * (target_median + out[..., c] - 1) - target_median * out[..., c]
            )
    out = np.clip(out, 0, 1)
    return out, ch_min.astype(np.float32), ch_med.astype(np.float32)
230
+
231
+
232
def unstretch_image_unlinked_rgb(img_rgb: np.ndarray, orig_meds, orig_min):
    """Invert the unlinked stretch.

    For each channel, remaps the *current* median back to the recorded
    pre-stretch median, then restores the per-channel pedestal and clips
    to [0, 1].
    """
    out = img_rgb.astype(np.float32, copy=True)
    for c in range(3):
        med_now = float(np.median(out[..., c]))
        med_ref = float(orig_meds[c])
        if med_now != 0 and med_ref != 0:
            # Same midtones transfer as the forward stretch, with the roles
            # of source/target median swapped.
            out[..., c] = ((med_now - 1) * med_ref * out[..., c]) / (
                med_now * (med_ref + out[..., c] - 1) - med_ref * out[..., c]
            )
    out = out + orig_min.reshape(1, 1, 3)
    return np.clip(out, 0, 1).astype(np.float32, copy=False)
243
+
244
+
245
+ # ---------------- Chunking & stitch (soft blend like your script) ----------------
246
+
247
def split_image_into_chunks_with_overlap(image: np.ndarray, chunk_size: int, overlap: int):
    """Tile the image into (tile, row, col) triples.

    Tiles are taken on a grid with stride ``chunk_size - overlap``; tiles at
    the right/bottom edges are clipped to the image bounds (views, not copies).
    """
    height, width = image.shape[:2]
    stride = chunk_size - overlap
    return [
        (image[r:min(r + chunk_size, height), c:min(c + chunk_size, width)], r, c)
        for r in range(0, height, stride)
        for c in range(0, width, stride)
    ]
257
+
258
+
259
+ def _blend_weights(chunk_size: int, overlap: int):
260
+ if overlap <= 0:
261
+ return np.ones((chunk_size, chunk_size), dtype=np.float32)
262
+ ramp = np.linspace(0, 1, overlap, dtype=np.float32)
263
+ flat = np.ones(max(chunk_size - 2 * overlap, 1), dtype=np.float32)
264
+ v = np.concatenate([ramp, flat, ramp[::-1]])
265
+ w = np.outer(v, v).astype(np.float32)
266
+ return w
267
+
268
+
269
def stitch_chunks_soft_blend(
    chunks: list[tuple[np.ndarray, int, int]],
    out_shape: tuple[int, int, int],
    *,
    chunk_size: int,
    overlap: int,
    border_size: int = 5,
) -> np.ndarray:
    """
    Reassemble processed tiles into one HxWxC image.

    Each tile is first cropped by up to ``border_size`` pixels on any edge
    that is *interior* to the image (edges touching the image boundary are
    kept), then accumulated with the separable blend mask from
    ``_blend_weights`` and finally normalized by the per-pixel weight sum.

    :param chunks: (tile, row, col) triples as produced by tiling + inference.
    :param out_shape: (H, W, C) of the assembled output.
    :param chunk_size: nominal tile edge used to build the blend mask.
    :param overlap: overlap in pixels used to build the blend mask.
    :param border_size: interior crop applied to each tile before blending.
    """
    H, W, C = out_shape
    out = np.zeros((H, W, C), np.float32)
    wsum = np.zeros((H, W, 1), np.float32)  # per-pixel accumulated weight
    bw_full = _blend_weights(chunk_size, overlap)

    for tile, i, j in chunks:
        th, tw = tile.shape[:2]

        # Adaptive inner crop: trim tile edges that face other tiles, but
        # never more than half the tile, and never at the image boundary.
        top = 0 if i == 0 else min(border_size, th // 2)
        left = 0 if j == 0 else min(border_size, tw // 2)
        bottom = 0 if (i + th) >= H else min(border_size, th // 2)
        right = 0 if (j + tw) >= W else min(border_size, tw // 2)

        inner = tile[top:th-bottom, left:tw-right, :]
        ih, iw = inner.shape[:2]

        # Destination rectangle of the cropped tile in the output image.
        rr0 = i + top
        cc0 = j + left
        rr1 = rr0 + ih
        cc1 = cc0 + iw

        # NOTE(review): the blend mask is sliced from the top-left corner,
        # not shifted by (top, left) — presumably intentional to keep full
        # weight near cropped edges; confirm against the original script.
        bw = bw_full[:ih, :iw].reshape(ih, iw, 1)
        out[rr0:rr1, cc0:cc1, :] += inner * bw
        wsum[rr0:rr1, cc0:cc1, :] += bw

    # Normalize; the epsilon guards pixels no tile contributed weight to.
    out = out / np.maximum(wsum, 1e-8)
    return out
305
+
306
+
307
+ # ---------------- Model loading (cached) ----------------
308
+
309
@dataclass
class DarkStarModels:
    """A loaded Dark Star backend: either a torch module or an ONNX session."""
    device: Any                # torch device, torch-directml device, or the string "DirectML"
    is_onnx: bool              # True when `model` is an onnxruntime InferenceSession
    model: Any                 # torch nn.Module or ONNX InferenceSession
    torch: Any | None = None   # torch module for torch backends; None for ONNX
    chunk_size: int = 512      # used for ONNX fixed shapes


# ---------------- Model loading (cached) ----------------

# Keyed by (model tag, backend id) so enabling the GPU later never reuses a
# previously built CPU model.
_MODELS_CACHE: dict[tuple[str, str], DarkStarModels] = {}  # (tag, backend_id)
321
+
322
def _cached_torch_backend(tag: str, backend_id: str, *, torch, dev, pth, status_cb, message: str) -> DarkStarModels:
    """Return the cached torch-based DarkStarModels for (tag, backend_id),
    building and caching it on first use. Shared by the CUDA, torch-directml,
    MPS and CPU branches, which previously duplicated this stanza."""
    key = (tag, backend_id)
    cached = _MODELS_CACHE.get(key)
    if cached is not None:
        return cached
    status_cb(message)
    Net = _build_darkstar_torch_models(torch)
    net = Net(pth, None).eval().to(dev)
    m = DarkStarModels(device=dev, is_onnx=False, model=net, torch=torch, chunk_size=512)
    _MODELS_CACHE[key] = m
    return m


def load_darkstar_models(*, use_gpu: bool, color: bool, status_cb=print) -> DarkStarModels:
    """
    Backend order:
      1) CUDA (PyTorch)
      2) DirectML (torch-directml)  [Windows]
      3) DirectML (ONNX Runtime)    [Windows]
      4) MPS (PyTorch)              [macOS]
      5) CPU (PyTorch)
    Cache key includes backend_id so we never "stick" on CPU when GPU is later enabled.

    :param use_gpu: try GPU backends before falling back to CPU.
    :param color: pick the color checkpoint/ONNX model instead of mono.
    :param status_cb: callable receiving one status string per backend choice.
    """
    R = get_resources()

    if color:
        pth, onnx, tag = R.CC_DARKSTAR_COLOR_PTH, R.CC_DARKSTAR_COLOR_ONNX, "cc_darkstar_color"
    else:
        pth, onnx, tag = R.CC_DARKSTAR_MONO_PTH, R.CC_DARKSTAR_MONO_ONNX, "cc_darkstar_mono"

    import os
    is_windows = os.name == "nt"

    # Request torch with the right preferences (runtime_torch decides what it can do)
    torch = _get_torch(
        prefer_cuda=bool(use_gpu),
        prefer_dml=bool(use_gpu and is_windows),
        status_cb=status_cb,
    )

    # ---------------- CUDA (torch) ----------------
    if use_gpu and hasattr(torch, "cuda") and torch.cuda.is_available():
        return _cached_torch_backend(
            tag, "cuda", torch=torch, dev=torch.device("cuda"), pth=pth, status_cb=status_cb,
            message=f"Dark Star: using CUDA ({torch.cuda.get_device_name(0)})",
        )

    # ---------------- DirectML (torch-directml) ----------------
    if use_gpu and is_windows:
        try:
            import torch_directml  # optional dependency
            return _cached_torch_backend(
                tag, "torch_dml", torch=torch, dev=torch_directml.device(), pth=pth,
                status_cb=status_cb, message="Dark Star: using DirectML (torch-directml)",
            )
        except Exception:
            # torch-directml missing or failed to initialize: fall through.
            pass

    # ---------------- DirectML (ONNX Runtime) ----------------
    if use_gpu and ort is not None and ("DmlExecutionProvider" in ort.get_available_providers()):
        if onnx and onnx.strip():
            key = (tag, "ort_dml")
            if key in _MODELS_CACHE:
                return _MODELS_CACHE[key]

            status_cb("Dark Star: using DirectML (ONNX Runtime)")
            sess = ort.InferenceSession(onnx, providers=["DmlExecutionProvider"])

            # fixed-ish input: [1,3,H,W]; some exports have static shapes,
            # in which case the export's H becomes the forced chunk size.
            inp = sess.get_inputs()[0]
            cs = 512
            try:
                if getattr(inp, "shape", None) and len(inp.shape) >= 4:
                    if inp.shape[2] not in (None, "None"):
                        cs = int(inp.shape[2])
            except Exception:
                pass

            m = DarkStarModels(device="DirectML", is_onnx=True, model=sess, torch=None, chunk_size=cs)
            _MODELS_CACHE[key] = m
            return m

    # ---------------- MPS (torch) ----------------
    if use_gpu and hasattr(torch, "backends") and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return _cached_torch_backend(
            tag, "mps", torch=torch, dev=torch.device("mps"), pth=pth,
            status_cb=status_cb, message="Dark Star: using MPS",
        )

    # ---------------- CPU (torch) ----------------
    return _cached_torch_backend(
        tag, "cpu", torch=torch, dev=torch.device("cpu"), pth=pth,
        status_cb=status_cb, message="Dark Star: using CPU",
    )
444
+
445
+ # ---------------- Core inference on one HxWx3 image ----------------
446
+
447
+ def _infer_tile(models: DarkStarModels, tile_rgb: np.ndarray) -> np.ndarray:
448
+ tile_rgb = np.asarray(tile_rgb, np.float32)
449
+ h0, w0 = tile_rgb.shape[:2]
450
+
451
+ if models.is_onnx:
452
+ cs = int(models.chunk_size)
453
+
454
+ # pad/crop robustly to (cs,cs,3)
455
+ if (h0 != cs) or (w0 != cs):
456
+ pad = np.zeros((cs, cs, 3), np.float32)
457
+ hh = min(h0, cs)
458
+ ww = min(w0, cs)
459
+ pad[:hh, :ww, :] = tile_rgb[:hh, :ww, :]
460
+ tile_rgb = pad
461
+
462
+ inp = tile_rgb.transpose(2, 0, 1)[None, ...] # 1,3,H,W
463
+ sess = models.model
464
+ out = sess.run(None, {sess.get_inputs()[0].name: inp})[0][0] # 3,H,W
465
+ out = out.transpose(1, 2, 0)
466
+
467
+ hh = min(h0, cs)
468
+ ww = min(w0, cs)
469
+ return out[:hh, :ww, :].astype(np.float32, copy=False)
470
+
471
+ # torch (CUDA / MPS / CPU / torch-directml)
472
+ torch = models.torch
473
+ dev = models.device
474
+ t = torch.from_numpy(tile_rgb.transpose(2, 0, 1)).unsqueeze(0).to(dev)
475
+
476
+ with torch.no_grad(), _autocast_context(torch, dev):
477
+ y = models.model(t)[0].detach().cpu().numpy().transpose(1, 2, 0)
478
+
479
+ return y[:h0, :w0, :].astype(np.float32, copy=False)
480
+
481
+ # ---------------- Public API ----------------
482
+
483
@dataclass
class DarkStarParams:
    """User-facing knobs for Dark Star star removal."""
    use_gpu: bool = True             # try GPU backends before falling back to CPU
    chunk_size: int = 512            # tile edge in pixels (ONNX backends may override)
    overlap_frac: float = 0.125      # tile overlap as a fraction of chunk_size
    mode: str = "unscreen"           # "unscreen" or "additive"
    output_stars_only: bool = False  # also derive a stars-only image
490
+
491
+
492
def darkstar_starremoval_rgb01(
    img_rgb01: np.ndarray,
    *,
    params: DarkStarParams,
    progress_cb: Optional[ProgressCB] = None,
    status_cb=print,
) -> tuple[np.ndarray, Optional[np.ndarray], bool]:
    """
    Remove stars from an image with the Dark Star model.

    Input : float32 image in [0..1], shape HxWx3 or HxWx1 or HxW
    Output: (starless_rgb01, stars_only_rgb01 or None, was_mono)

    Pipeline: normalize to HxWx3 -> optional per-channel stretch -> median
    border pad -> tile, infer, soft-blend stitch -> optional unstretch ->
    crop border -> optional stars-only derivation -> collapse back to mono.

    :param params: see DarkStarParams.
    :param progress_cb: called as (done, total, stage) per processed tile.
    :param status_cb: receives backend-selection status strings.
    """
    if progress_cb is None:
        progress_cb = lambda done, total, stage: None

    img = np.asarray(img_rgb01, np.float32)
    was_mono = (img.ndim == 2) or (img.ndim == 3 and img.shape[2] == 1)

    # normalize shape to HxWx3 (model always runs on 3 channels)
    if img.ndim == 2:
        img3 = np.stack([img, img, img], axis=-1)
    elif img.ndim == 3 and img.shape[2] == 1:
        ch = img[..., 0]
        img3 = np.stack([ch, ch, ch], axis=-1)
    else:
        img3 = img

    img3 = np.clip(img3, 0.0, 1.0)

    # decide "true RGB" vs "3-channel mono": identical channels select the
    # mono model even when the array is nominally 3-channel
    same_rg = np.allclose(img3[..., 0], img3[..., 1], rtol=0, atol=1e-6)
    same_rb = np.allclose(img3[..., 0], img3[..., 2], rtol=0, atol=1e-6)
    is_true_rgb = not (same_rg and same_rb)

    models = load_darkstar_models(use_gpu=params.use_gpu, color=is_true_rgb, status_cb=status_cb)

    # stretch decision: pedestal-aware (matches other engines more closely) —
    # only stretch images that are still linear/dim after pedestal removal
    stretch_needed = float(np.median(img3 - float(np.min(img3)))) < 0.125
    if stretch_needed:
        stretched, orig_min, orig_meds = stretch_image_unlinked_rgb(img3)
    else:
        stretched, orig_min, orig_meds = img3, None, None

    # median border avoids edge artifacts from tiling; removed after stitch
    bordered = add_border(stretched, border_size=5)

    # ONNX may force chunk_size (static-shaped export)
    chunk_size = int(models.chunk_size) if models.is_onnx else int(params.chunk_size)
    overlap = int(round(float(params.overlap_frac) * chunk_size))

    chunks = split_image_into_chunks_with_overlap(bordered, chunk_size=chunk_size, overlap=overlap)
    total = len(chunks)

    out_tiles: list[tuple[np.ndarray, int, int]] = []
    for k, (tile, i, j) in enumerate(chunks, start=1):
        out = _infer_tile(models, tile)
        out_tiles.append((out, i, j))
        progress_cb(k, total, "Dark Star removal")

    starless_b = stitch_chunks_soft_blend(
        out_tiles,
        bordered.shape,
        chunk_size=chunk_size,
        overlap=overlap,
        border_size=5,
    )

    # invert the stretch before cropping so medians are computed on the
    # same (bordered) geometry the stretch saw
    if stretch_needed:
        starless_b = unstretch_image_unlinked_rgb(starless_b, orig_meds, orig_min)

    starless = remove_border(starless_b, border_size=5)
    starless = np.clip(starless, 0.0, 1.0).astype(np.float32, copy=False)

    stars_only = None
    if params.output_stars_only:
        if params.mode == "additive":
            # stars = original minus starless, clipped
            stars_only = np.clip(img3 - starless, 0.0, 1.0).astype(np.float32, copy=False)
        else:  # unscreen: invert screen-blend composition; epsilon guards /0
            denom = np.maximum(1.0 - starless, 1e-6)
            stars_only = np.clip((img3 - starless) / denom, 0.0, 1.0).astype(np.float32, copy=False)

    # collapse back to a single channel (HxWx1) for mono inputs
    if was_mono:
        starless = np.mean(starless, axis=2, keepdims=True).astype(np.float32, copy=False)
        if stars_only is not None:
            stars_only = np.mean(stars_only, axis=2, keepdims=True).astype(np.float32, copy=False)

    return starless, stars_only, was_mono