setiastrosuitepro 1.7.5.post1-py3-none-any.whl → 1.8.0.post3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of setiastrosuitepro might be problematic.

Files changed (27)
  1. setiastro/saspro/_generated/build_info.py +2 -2
  2. setiastro/saspro/accel_installer.py +21 -8
  3. setiastro/saspro/accel_workers.py +11 -12
  4. setiastro/saspro/comet_stacking.py +113 -85
  5. setiastro/saspro/cosmicclarity.py +604 -826
  6. setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
  7. setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
  8. setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
  9. setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
  10. setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
  11. setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
  12. setiastro/saspro/gui/main_window.py +14 -0
  13. setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
  14. setiastro/saspro/model_manager.py +324 -0
  15. setiastro/saspro/model_workers.py +102 -0
  16. setiastro/saspro/ops/benchmark.py +320 -0
  17. setiastro/saspro/ops/settings.py +407 -10
  18. setiastro/saspro/remove_stars.py +424 -442
  19. setiastro/saspro/resources.py +73 -10
  20. setiastro/saspro/runtime_torch.py +107 -22
  21. setiastro/saspro/signature_insert.py +14 -3
  22. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
  23. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
  24. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
  25. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
  26. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
  27. {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
@@ -0,0 +1,620 @@
+# src/setiastro/saspro/cosmicclarity_engines/satellite_engine.py
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import numpy as np
+from setiastro.saspro.resources import get_resources
+
+# Optional deps
+try:
+    import onnxruntime as ort
+except Exception:
+    ort = None
+
+try:
+    from skimage.transform import resize as _sk_resize
+except Exception:
+    _sk_resize = None
+
+ProgressCB = Callable[[int, int], None]  # (done, total)
+
+# ---------- Torch import (updated: CUDA + torch-directml awareness) ----------
+
+def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
+    from setiastro.saspro.runtime_torch import import_torch
+    return import_torch(
+        prefer_cuda=prefer_cuda,
+        prefer_xpu=False,
+        prefer_dml=prefer_dml,
+        status_cb=status_cb,
+    )
+
+
+def _nullcontext():
+    from contextlib import nullcontext
+    return nullcontext()
+
+def _autocast_context(torch, device) -> Any:
+    """
+    Use new torch.amp.autocast('cuda') when available.
+    Keep your cap >= 8.0 rule.
+    """
+    try:
+        if hasattr(device, "type") and device.type == "cuda":
+            major, minor = torch.cuda.get_device_capability()
+            cap = float(f"{major}.{minor}")
+            if cap >= 8.0:
+                # Preferred API (torch >= 1.10-ish; definitely in 2.x)
+                if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
+                    return torch.amp.autocast(device_type="cuda")
+                # Fallback for older torch
+                return torch.cuda.amp.autocast()
+    except Exception:
+        pass
+    return _nullcontext()
+
+
+
+# ----------------------------
+# Models (from standalone)
+# ----------------------------
+
+def _build_torch_models(torch):
+    # Import torch.nn + torchvision lazily, only after torch loads
+    import torch.nn as nn
+
+    try:
+        from torchvision import models
+        from torchvision.models import ResNet18_Weights, MobileNet_V2_Weights
+        from torchvision import transforms
+    except Exception as e:
+        raise RuntimeError(f"torchvision is required for Satellite engine torch backend: {e}")
+
+    class ResidualBlock(nn.Module):
+        def __init__(self, channels: int):
+            super().__init__()
+            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+            self.relu = nn.ReLU()
+            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        def forward(self, x):
+            r = x
+            x = self.relu(self.conv1(x))
+            x = self.conv2(x)
+            x = self.relu(x + r)
+            return x
+
+    class SatelliteRemoverCNN(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.encoder1 = nn.Sequential(
+                nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
+                ResidualBlock(16), ResidualBlock(16),
+            )
+            self.encoder2 = nn.Sequential(
+                nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
+                ResidualBlock(32), ResidualBlock(32),
+            )
+            self.encoder3 = nn.Sequential(
+                nn.Conv2d(32, 64, 3, padding=2, dilation=2), nn.ReLU(),
+                ResidualBlock(64), ResidualBlock(64),
+            )
+            self.encoder4 = nn.Sequential(
+                nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
+                ResidualBlock(128), ResidualBlock(128),
+            )
+            self.encoder5 = nn.Sequential(
+                nn.Conv2d(128, 256, 3, padding=2, dilation=2), nn.ReLU(),
+                ResidualBlock(256), ResidualBlock(256),
+            )
+            self.decoder5 = nn.Sequential(
+                nn.Conv2d(256 + 128, 128, 3, padding=1), nn.ReLU(),
+                ResidualBlock(128), ResidualBlock(128),
+            )
+            self.decoder4 = nn.Sequential(
+                nn.Conv2d(128 + 64, 64, 3, padding=1), nn.ReLU(),
+                ResidualBlock(64), ResidualBlock(64),
+            )
+            self.decoder3 = nn.Sequential(
+                nn.Conv2d(64 + 32, 32, 3, padding=1), nn.ReLU(),
+                ResidualBlock(32), ResidualBlock(32),
+            )
+            self.decoder2 = nn.Sequential(
+                nn.Conv2d(32 + 16, 16, 3, padding=1), nn.ReLU(),
+                ResidualBlock(16), ResidualBlock(16),
+            )
+            self.decoder1 = nn.Sequential(nn.Conv2d(16, 3, 3, padding=1), nn.Sigmoid())
+
+        def forward(self, x):
+            e1 = self.encoder1(x)
+            e2 = self.encoder2(e1)
+            e3 = self.encoder3(e2)
+            e4 = self.encoder4(e3)
+            e5 = self.encoder5(e4)
+            d5 = self.decoder5(torch.cat([e5, e4], dim=1))
+            d4 = self.decoder4(torch.cat([d5, e3], dim=1))
+            d3 = self.decoder3(torch.cat([d4, e2], dim=1))
+            d2 = self.decoder2(torch.cat([d3, e1], dim=1))
+            return self.decoder1(d2)
+
+    class BinaryClassificationCNN(nn.Module):
+        def __init__(self, input_channels: int = 3):
+            super().__init__()
+            self.pre_conv1 = nn.Sequential(
+                nn.Conv2d(input_channels, 32, 3, stride=1, padding=1, bias=False),
+                nn.BatchNorm2d(32),
+                nn.ReLU()
+            )
+            self.pre_conv2 = nn.Sequential(
+                nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False),
+                nn.BatchNorm2d(64),
+                nn.ReLU()
+            )
+            self.features = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
+            self.features.conv1 = nn.Conv2d(64, 64, kernel_size=7, stride=2, padding=3, bias=False)
+            self.features.fc = nn.Linear(self.features.fc.in_features, 1)
+
+        def forward(self, x):
+            x = self.pre_conv1(x)
+            x = self.pre_conv2(x)
+            return self.features(x)
+
+    class BinaryClassificationCNN2(nn.Module):
+        def __init__(self, input_channels: int = 3):
+            super().__init__()
+            self.pre_conv1 = nn.Sequential(
+                nn.Conv2d(input_channels, 32, 3, stride=1, padding=1, bias=False),
+                nn.BatchNorm2d(32),
+                nn.ReLU()
+            )
+            self.pre_conv2 = nn.Sequential(
+                nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False),
+                nn.BatchNorm2d(64),
+                nn.ReLU()
+            )
+            self.mobilenet = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
+            self.mobilenet.features[0][0] = nn.Conv2d(
+                64, 32, kernel_size=3, stride=2, padding=1, bias=False
+            )
+            in_features = self.mobilenet.classifier[-1].in_features
+            self.mobilenet.classifier[-1] = nn.Linear(in_features, 1)
+
+        def forward(self, x):
+            x = self.pre_conv1(x)
+            x = self.pre_conv2(x)
+            return self.mobilenet(x)
+
+    # Also return the torchvision transforms helper you used
+    tfm = transforms.Compose([transforms.ToTensor(), transforms.Resize((256, 256))])
+
+    return nn, SatelliteRemoverCNN, BinaryClassificationCNN, BinaryClassificationCNN2, tfm
+
+
+# ----------------------------
+# Loading helpers
+# ----------------------------
+def _load_model_weights_lenient(torch, nn, model, checkpoint_path: str, device):
+    ckpt = torch.load(checkpoint_path, map_location=device)
+    state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
+    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+
+    msd = model.state_dict()
+    filtered = {k: v for k, v in state_dict.items() if k in msd and msd[k].shape == v.shape}
+    model.load_state_dict(filtered, strict=False)
+    return model
+
+# ---------- Satellite model cache + loader (updated for torch-directml + ORT DML) ----------
+
+_SAT_CACHE: Dict[Tuple[str, str, str, str], Dict[str, Any]] = {}
+
+def get_satellite_models(resources: Any = None, use_gpu: bool = True, status_cb=print) -> Dict[str, Any]:
+    """
+    Backend order:
+      1) CUDA (PyTorch)
+      2) DirectML (torch-directml) [Windows]
+      3) DirectML (ONNX Runtime DML EP) [Windows]
+      4) MPS (PyTorch) [macOS]
+      5) CPU (PyTorch)
+
+    Cache key includes backend tag, so switching GPU on/off never reuses the wrong backend.
+    """
+    import os
+
+    if resources is None:
+        resources = get_resources()
+
+    p_det1 = resources.CC_SAT_DETECT1_PTH
+    p_det2 = resources.CC_SAT_DETECT2_PTH
+    p_rem = resources.CC_SAT_REMOVE_PTH
+
+    o_det1 = resources.CC_SAT_DETECT1_ONNX
+    o_det2 = resources.CC_SAT_DETECT2_ONNX
+    o_rem = resources.CC_SAT_REMOVE_ONNX
+
+    is_windows = (os.name == "nt")
+
+    # ORT DirectML availability
+    ort_dml_ok = bool(use_gpu) and (ort is not None) and ("DmlExecutionProvider" in ort.get_available_providers())
+
+    # Torch: ask runtime_torch to prefer what we want (CUDA first, DML on Windows)
+    torch = None
+    if use_gpu or True:
+        torch = _get_torch(
+            prefer_cuda=bool(use_gpu),
+            prefer_dml=bool(use_gpu and is_windows),
+            status_cb=status_cb,
+        )
+
+    # Decide backend
+    backend = "cpu"
+
+    # 1) CUDA
+    if use_gpu and hasattr(torch, "cuda") and torch.cuda.is_available():
+        backend = "cuda"
+    else:
+        # 2) torch-directml (Windows)
+        if use_gpu and is_windows:
+            try:
+                import torch_directml  # optional
+                _ = torch_directml.device()
+                backend = "torch_dml"
+            except Exception:
+                backend = "ort_dml" if ort_dml_ok else "cpu"
+        else:
+            # 4) MPS (macOS)
+            if use_gpu and hasattr(torch, "backends") and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+                backend = "mps"
+            else:
+                backend = "cpu"
+
+    key = (p_det1, p_det2, p_rem, backend)
+    if key in _SAT_CACHE:
+        return _SAT_CACHE[key]
+
+    # ---------------- DirectML via ONNX Runtime ----------------
+    if backend == "ort_dml":
+        if ort is None:
+            raise RuntimeError("onnxruntime not available, cannot use ORT DirectML backend.")
+        # sanity: need ONNX paths
+        if not (o_det1 and o_det2 and o_rem):
+            raise FileNotFoundError("Satellite ONNX model paths are missing in resources.")
+
+        det1 = ort.InferenceSession(o_det1, providers=["DmlExecutionProvider"])
+        det2 = ort.InferenceSession(o_det2, providers=["DmlExecutionProvider"])
+        rem = ort.InferenceSession(o_rem, providers=["DmlExecutionProvider"])
+
+        out = {
+            "backend": "ort_dml",
+            "detection_model1": det1,
+            "detection_model2": det2,
+            "removal_model": rem,
+            "device": "DirectML",
+            "is_onnx": True,
+        }
+        _SAT_CACHE[key] = out
+        status_cb("CosmicClarity Satellite: using DirectML (ONNX Runtime)")
+        return out
+
+    # ---------------- Torch backends (CUDA / torch-directml / MPS / CPU) ----------------
+    # pick device
+    if backend == "cuda":
+        device = torch.device("cuda")
+        status_cb(f"CosmicClarity Satellite: using CUDA ({torch.cuda.get_device_name(0)})")
+    elif backend == "mps":
+        device = torch.device("mps")
+        status_cb("CosmicClarity Satellite: using MPS")
+    elif backend == "torch_dml":
+        import torch_directml
+        device = torch_directml.device()
+        status_cb("CosmicClarity Satellite: using DirectML (torch-directml)")
+    else:
+        device = torch.device("cpu")
+        status_cb("CosmicClarity Satellite: using CPU")
+
+    nn, SatelliteRemoverCNN, BinaryClassificationCNN, BinaryClassificationCNN2, tfm = _build_torch_models(torch)
+
+    det1 = BinaryClassificationCNN(3).to(device)
+    det1 = _load_model_weights_lenient(torch, nn, det1, p_det1, device).eval()
+
+    det2 = BinaryClassificationCNN2(3).to(device)
+    det2 = _load_model_weights_lenient(torch, nn, det2, p_det2, device).eval()
+
+    rem = SatelliteRemoverCNN().to(device)
+    rem = _load_model_weights_lenient(torch, nn, rem, p_rem, device).eval()
+
+    out = {
+        "backend": backend,
+        "detection_model1": det1,
+        "detection_model2": det2,
+        "removal_model": rem,
+        "device": device,
+        "is_onnx": False,
+        "torch": torch,
+        "tfm": tfm,
+    }
+    _SAT_CACHE[key] = out
+    return out
+
+# ----------------------------
+# Core processing
+# ----------------------------
+
+def _ensure_rgb01(img: np.ndarray) -> Tuple[np.ndarray, bool]:
+    """
+    Input: HxW, HxWx1, or HxWx3; float/uint.
+    Output: HxWx3 float32 [0..1], plus is_mono flag (originally mono-like).
+    """
+    a = np.asarray(img)
+    is_mono = (a.ndim == 2) or (a.ndim == 3 and a.shape[2] == 1)
+
+    a = np.nan_to_num(a.astype(np.float32, copy=False), nan=0.0, posinf=0.0, neginf=0.0)
+
+    # normalize if >1
+    mx = float(np.max(a)) if a.size else 1.0
+    if mx > 1.0:
+        a = a / mx
+
+    if a.ndim == 2:
+        a = np.stack([a, a, a], axis=-1)
+    elif a.ndim == 3 and a.shape[2] == 1:
+        a = np.repeat(a, 3, axis=2)
+    elif a.ndim == 3 and a.shape[2] >= 3:
+        a = a[..., :3]
+    else:
+        raise ValueError(f"Unsupported image shape: {a.shape}")
+
+    a = np.clip(a, 0.0, 1.0)
+    return a, is_mono
+
+
+def _extract_luminance_bt601(rgb01: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    # BT.601 matrix (matches your standalone)
+    M = np.array([[0.299, 0.587, 0.114],
+                  [-0.168736, -0.331264, 0.5],
+                  [0.5, -0.418688, -0.081312]], dtype=np.float32)
+    ycbcr = np.dot(rgb01, M.T)
+    y = ycbcr[..., 0]
+    cb = ycbcr[..., 1] + 0.5
+    cr = ycbcr[..., 2] + 0.5
+    return y, cb, cr
+
+
+def _merge_luminance_bt601(y: np.ndarray, cb: np.ndarray, cr: np.ndarray) -> np.ndarray:
+    y = np.clip(y, 0, 1).astype(np.float32)
+    cb = (np.clip(cb, 0, 1).astype(np.float32) - 0.5)
+    cr = (np.clip(cr, 0, 1).astype(np.float32) - 0.5)
+
+    ycbcr = np.stack([y, cb, cr], axis=-1)
+
+    M = np.array([[1.0, 0.0, 1.402],
+                  [1.0, -0.344136, -0.714136],
+                  [1.0, 1.772, 0.0]], dtype=np.float32)
+    rgb = np.dot(ycbcr, M.T)
+    return np.clip(rgb, 0.0, 1.0).astype(np.float32)
+
+
+def _split_chunks(img: np.ndarray, chunk: int, overlap: int):
+    H, W = img.shape[:2]
+    step = chunk - overlap
+    for y0 in range(0, H, step):
+        for x0 in range(0, W, step):
+            y1 = min(y0 + chunk, H)
+            x1 = min(x0 + chunk, W)
+            yield img[y0:y1, x0:x1], y0, x0
+
+
+def _stitch_ignore_border(chunks, shape_hw3, border: int = 16) -> np.ndarray:
+    H, W, C = shape_hw3
+    acc = np.zeros((H, W, C), np.float32)
+    wgt = np.zeros((H, W, C), np.float32)
+
+    for tile, y0, x0 in chunks:
+        th, tw = tile.shape[:2]
+        bh = min(border, th // 2)
+        bw = min(border, tw // 2)
+
+        inner = tile[bh:th-bh, bw:tw-bw, :]
+        acc[y0+bh:y0+th-bh, x0+bw:x0+tw-bw, :] += inner
+        wgt[y0+bh:y0+th-bh, x0+bw:x0+tw-bw, :] += 1.0
+
+    return acc / np.maximum(wgt, 1.0)
+
+
+def _apply_clip_trail_logic(processed: np.ndarray, original: np.ndarray, sensitivity: float) -> np.ndarray:
+    # exactly your standalone math
+    sattrail_only = original - processed
+    mean_val = float(np.mean(sattrail_only))
+    clipped = np.clip((sattrail_only - mean_val) * 10.0, 0.0, 1.0)
+    mask = np.where(clipped < sensitivity, 0.0, 1.0).astype(np.float32)
+    return np.clip(original - mask, 0.0, 1.0)
+
+# ---------- Torch detection (FIX: tfm expects PIL/ndarray, not Tensor; avoid double ToTensor) ----------
+
+def _torch_detect(tile_rgb01: np.ndarray, models: Dict[str, Any]) -> bool:
+    """
+    Your tfm = ToTensor()+Resize(256,256). It expects HxWxC numpy in [0..1] (or uint8).
+    Do NOT feed it a tensor.
+    """
+    torch = models["torch"]
+    device = models["device"]
+    det1 = models["detection_model1"]
+    det2 = models["detection_model2"]
+    tfm = models["tfm"]
+
+    a = np.asarray(tile_rgb01, np.float32)
+    a = np.clip(a, 0.0, 1.0)
+
+    # torchvision transform pipeline
+    inp = tfm(a)  # -> Tensor [C,H,W], float32
+    inp = inp.unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        o1 = float(det1(inp).item())
+    if o1 <= 0.5:
+        return False
+
+    with torch.no_grad():
+        o2 = float(det2(inp).item())
+    return (o2 > 0.25)
+
+
+
+def _torch_remove(tile_rgb01: np.ndarray, models: Dict[str, Any]) -> np.ndarray:
+    torch = models["torch"]
+    device = models["device"]
+    rem = models["removal_model"]
+
+    x = torch.from_numpy(tile_rgb01).permute(2, 0, 1).unsqueeze(0).to(device=device, dtype=torch.float32)
+
+    with torch.no_grad(), _autocast_context(torch, device):
+        out = rem(x).squeeze(0).detach().cpu().numpy().transpose(1, 2, 0)
+
+    return np.clip(out, 0.0, 1.0).astype(np.float32)
+
+
+def _onnx_detect(tile_rgb01: np.ndarray, sess) -> bool:
+    # Resize to 256x256 like your standalone ONNX path
+    if _sk_resize is None:
+        raise RuntimeError("skimage.transform.resize is required for ONNX satellite detection path.")
+    r = _sk_resize(tile_rgb01, (256, 256, 3), mode="reflect", anti_aliasing=True).astype(np.float32)
+    inp = np.transpose(r, (2, 0, 1))[None, ...]
+    out = sess.run(None, {sess.get_inputs()[0].name: inp})[0]
+    return bool(out[0] > 0.5)
+
+
+def _onnx_remove(tile_rgb01: np.ndarray, sess) -> np.ndarray:
+    if _sk_resize is None:
+        raise RuntimeError("skimage.transform.resize is required for ONNX satellite removal path.")
+    r = _sk_resize(tile_rgb01, (256, 256, 3), mode="reflect", anti_aliasing=True).astype(np.float32)
+    inp = np.transpose(r, (2, 0, 1))[None, ...]
+    out = sess.run(None, {sess.get_inputs()[0].name: inp})[0]
+    pred = np.transpose(out.squeeze(0), (1, 2, 0)).astype(np.float32)
+    # resize back to original tile size
+    pred2 = _sk_resize(pred, tile_rgb01.shape, mode="reflect", anti_aliasing=True).astype(np.float32)
+    return np.clip(pred2, 0.0, 1.0)
+
+
+def satellite_remove_image(
+    image: np.ndarray,
+    models: Dict[str, Any],
+    *,
+    mode: str = "full",  # "full" or "luminance"
+    clip_trail: bool = True,
+    sensitivity: float = 0.1,
+    chunk_size: int = 256,
+    overlap: int = 64,
+    border_size: int = 16,
+    progress_cb: Optional[Callable[[int, int], None]] = None,  # (done, total)
+) -> Tuple[np.ndarray, bool]:
+    """
+    image: input image (any dtype/shape). Expected to be linear-ish in [0..1] for best behavior.
+    Returns: (out_image_same_shape_style, trail_detected_any)
+    """
+    rgb01, was_mono = _ensure_rgb01(image)
+
+    # luminance mode -> process Y only, then merge back
+    if mode.lower() == "luminance":
+        y, cb, cr = _extract_luminance_bt601(rgb01)
+        # treat Y as "mono" but we still run the network as RGB by repeating
+        y3 = np.stack([y, y, y], axis=-1)
+        out3, detected = _satellite_remove_rgb(
+            y3, models,
+            clip_trail=clip_trail, sensitivity=sensitivity,
+            chunk_size=chunk_size, overlap=overlap, border_size=border_size,
+            progress_cb=progress_cb,
+        )
+        out_y = out3[..., 0]
+        out_rgb = _merge_luminance_bt601(out_y, cb, cr)
+    else:
+        out_rgb, detected = _satellite_remove_rgb(
+            rgb01, models,
+            clip_trail=clip_trail, sensitivity=sensitivity,
+            chunk_size=chunk_size, overlap=overlap, border_size=border_size,
+            progress_cb=progress_cb,
+        )
+
+    # If original was mono-like, return HxWx1 (matches your SASpro convention for mono docs)
+    if (np.asarray(image).ndim == 2) or (np.asarray(image).ndim == 3 and np.asarray(image).shape[2] == 1):
+        out_m = out_rgb[..., 0:1].astype(np.float32)
+        return out_m, detected
+
+    return out_rgb.astype(np.float32), detected
+
+# ---------- Satellite remove loop (FIX: use correct ONNX sessions, not det1/rem confusion) ----------
+
+def _satellite_remove_rgb(
+    rgb01: np.ndarray,
+    models: Dict[str, Any],
+    *,
+    clip_trail: bool,
+    sensitivity: float,
+    chunk_size: int,
+    overlap: int,
+    border_size: int,
+    progress_cb: Optional[Callable[[int, int], None]],
+) -> Tuple[np.ndarray, bool]:
+    """
+    Uses BOTH detectors like your torch path.
+    ONNX path now calls det1+det2+rem sessions correctly.
+    """
+    is_onnx = bool(models.get("is_onnx", False))
+
+    H, W = rgb01.shape[:2]
+    trail_any = False
+
+    all_tiles = list(_split_chunks(rgb01, chunk_size, overlap))
+    total = len(all_tiles)
+    out_tiles = []
+
+    for idx, (tile, y0, x0) in enumerate(all_tiles, start=1):
+        orig = tile.astype(np.float32, copy=False)
+
+        if is_onnx:
+            det1_sess = models["detection_model1"]
+            det2_sess = models["detection_model2"]
+            rem_sess = models["removal_model"]
+
+            d1 = _onnx_detect(orig, det1_sess)
+            if d1:
+                d2 = _onnx_detect(orig, det2_sess)
+            else:
+                d2 = False
+
+            detected = bool(d1 and d2)
+            if detected:
+                trail_any = True
+                pred = _onnx_remove(orig, rem_sess)
+                final = _apply_clip_trail_logic(pred, orig, sensitivity) if clip_trail else pred
+            else:
+                final = orig
+
+        else:
+            detected = _torch_detect(orig, models)
+            if detected:
+                trail_any = True
+                pred = _torch_remove(orig, models)
+                final = _apply_clip_trail_logic(pred, orig, sensitivity) if clip_trail else pred
+            else:
+                final = orig
+
+        out_tiles.append((final, y0, x0))
+
+        if progress_cb is not None:
+            progress_cb(idx, total)
+
+    out = _stitch_ignore_border(out_tiles, (H, W, 3), border=border_size)
+
+    # keep edges unchanged
+    if border_size > 0:
+        out[:border_size, :, :] = rgb01[:border_size, :, :]
+        out[-border_size:, :, :] = rgb01[-border_size:, :, :]
+        out[:, :border_size, :] = rgb01[:, :border_size, :]
+        out[:, -border_size:, :] = rgb01[:, -border_size:, :]
+
+    out = np.clip(out, 0.0, 1.0).astype(np.float32)
+
+    if not trail_any:
+        return rgb01.astype(np.float32, copy=False), False
+
+    return out, True
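
For orientation, the sketch below shows how the new module's two entry points, get_satellite_models() and satellite_remove_image(), fit together, based only on the signatures added in this diff. It is an illustrative example and not part of the released code: the random test image and the on_progress callback are placeholders, and the model paths are resolved internally through get_resources().

# Illustrative usage sketch (not part of the package diff).
import numpy as np

from setiastro.saspro.cosmicclarity_engines.satellite_engine import (
    get_satellite_models,
    satellite_remove_image,
)


def on_progress(done: int, total: int) -> None:
    # Matches ProgressCB = Callable[[int, int], None]: (done, total) tile counts.
    print(f"tiles {done}/{total}")


# Load the detector/removal models once; backend selection follows the
# CUDA -> torch-directml -> ORT DirectML -> MPS -> CPU order documented above.
models = get_satellite_models(use_gpu=True, status_cb=print)

# Placeholder input: any HxW, HxWx1, or HxWx3 array works; values are normalized to [0..1].
img = np.random.rand(1024, 1024, 3).astype(np.float32)

cleaned, trail_found = satellite_remove_image(
    img,
    models,
    mode="full",            # or "luminance" to process the Y channel only
    clip_trail=True,
    sensitivity=0.1,
    chunk_size=256,
    overlap=64,
    border_size=16,
    progress_cb=on_progress,
)
print("trail detected:", trail_found, "output shape:", cleaned.shape)

Mono-like inputs (HxW or HxWx1) come back as HxWx1, per the convention noted in satellite_remove_image().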