setiastrosuitepro 1.7.5.post1__py3-none-any.whl → 1.8.0.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of setiastrosuitepro might be problematic. Click here for more details.
- setiastro/saspro/_generated/build_info.py +2 -2
- setiastro/saspro/accel_installer.py +21 -8
- setiastro/saspro/accel_workers.py +11 -12
- setiastro/saspro/comet_stacking.py +113 -85
- setiastro/saspro/cosmicclarity.py +604 -826
- setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
- setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
- setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
- setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
- setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
- setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
- setiastro/saspro/gui/main_window.py +14 -0
- setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
- setiastro/saspro/model_manager.py +324 -0
- setiastro/saspro/model_workers.py +102 -0
- setiastro/saspro/ops/benchmark.py +320 -0
- setiastro/saspro/ops/settings.py +407 -10
- setiastro/saspro/remove_stars.py +424 -442
- setiastro/saspro/resources.py +73 -10
- setiastro/saspro/runtime_torch.py +107 -22
- setiastro/saspro/signature_insert.py +14 -3
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
|
@@ -0,0 +1,620 @@
|
|
|
1
|
+
# src/setiastro/saspro/cosmicclarity_engines/satellite_engine.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any, Callable, Dict, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from setiastro.saspro.resources import get_resources
|
|
9
|
+
|
|
10
|
+
# Optional deps
|
|
11
|
+
try:
|
|
12
|
+
import onnxruntime as ort
|
|
13
|
+
except Exception:
|
|
14
|
+
ort = None
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
from skimage.transform import resize as _sk_resize
|
|
18
|
+
except Exception:
|
|
19
|
+
_sk_resize = None
|
|
20
|
+
|
|
21
|
+
# Type alias for a progress callback; it receives two ints, (done, total).
ProgressCB = Callable[[int, int], None] # (done, total)
|
|
22
|
+
|
|
23
|
+
# ---------- Torch import (updated: CUDA + torch-directml awareness) ----------
|
|
24
|
+
|
|
25
|
+
def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
    """
    Obtain the torch module through the project's runtime helper.

    The CUDA and DirectML preferences are forwarded unchanged; XPU is always
    requested as not preferred. ``status_cb`` is handed to the helper as-is.
    Returns whatever ``import_torch`` returns.
    """
    # Deferred import: the runtime helper is only pulled in when torch is needed.
    from setiastro.saspro.runtime_torch import import_torch

    preferences = {
        "prefer_cuda": prefer_cuda,
        "prefer_xpu": False,
        "prefer_dml": prefer_dml,
        "status_cb": status_cb,
    }
    return import_torch(**preferences)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _nullcontext():
    """Return a new do-nothing context manager (``contextlib.nullcontext``)."""
    import contextlib

    return contextlib.nullcontext()
|
|
38
|
+
|
|
39
|
+
def _autocast_context(torch, device) -> Any:
    """
    Return the context manager to wrap inference in.

    CUDA devices with compute capability >= 8.0 get an autocast context
    (``torch.amp.autocast('cuda')`` when available, otherwise the older
    ``torch.cuda.amp.autocast()``). Every other case, including any failure
    while probing the device, gets a no-op context.

    Args:
        torch: the imported torch module.
        device: the device inference will run on; only its ``type`` attribute
            is inspected, objects without one are treated as non-CUDA.

    Returns:
        A context manager; never raises.
    """
    from contextlib import nullcontext

    try:
        if getattr(device, "type", None) == "cuda":
            # Ask about the device we were given (not just the current one) and
            # compare as a tuple instead of parsing "major.minor" as a float.
            capability = tuple(torch.cuda.get_device_capability(device))
            if capability >= (8, 0):
                amp = getattr(torch, "amp", None)
                # Preferred API (torch >= 1.10-ish; definitely in 2.x)
                if amp is not None and hasattr(amp, "autocast"):
                    return amp.autocast(device_type="cuda")
                # Fallback for older torch
                return torch.cuda.amp.autocast()
    except Exception:
        # Deliberately best-effort: if probing fails, run without autocast
        # rather than aborting the inference call.
        pass
    return nullcontext()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# ----------------------------
|
|
61
|
+
# Models (from standalone)
|
|
62
|
+
# ----------------------------
|
|
63
|
+
|
|
64
|
+
def _build_torch_models(torch):
    """
    Define the satellite network classes against an already-imported torch.

    The classes are created inside this function so that ``torch.nn`` and
    torchvision are only imported once torch itself has loaded.

    Args:
        torch: the imported torch module (used by the models' ``forward``).

    Returns:
        Tuple ``(nn, SatelliteRemoverCNN, BinaryClassificationCNN,
        BinaryClassificationCNN2, tfm)`` where ``tfm`` is the
        ``ToTensor`` + ``Resize((256, 256))`` transform pipeline.

    Raises:
        RuntimeError: if torchvision cannot be imported; the original import
            error is chained as ``__cause__``.
    """
    # Import torch.nn + torchvision lazily, only after torch loads
    import torch.nn as nn

    try:
        from torchvision import models
        from torchvision.models import ResNet18_Weights, MobileNet_V2_Weights
        from torchvision import transforms
    except Exception as e:
        # Chain the cause so the real import failure stays visible in tracebacks.
        raise RuntimeError(
            f"torchvision is required for Satellite engine torch backend: {e}"
        ) from e

    # NOTE: attribute names and Sequential positions below determine the
    # state_dict keys the checkpoint loader matches against; keep them stable.

    class ResidualBlock(nn.Module):
        """Two 3x3 convolutions with an identity skip connection."""

        def __init__(self, channels: int):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.relu = nn.ReLU()
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def forward(self, x):
            r = x
            x = self.relu(self.conv1(x))
            x = self.conv2(x)
            x = self.relu(x + r)
            return x

    class SatelliteRemoverCNN(nn.Module):
        """Five-stage encoder/decoder with skip concatenations; 3 channels in, 3 out."""

        def __init__(self):
            super().__init__()
            self.encoder1 = nn.Sequential(
                nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                ResidualBlock(16), ResidualBlock(16),
            )
            self.encoder2 = nn.Sequential(
                nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
                ResidualBlock(32), ResidualBlock(32),
            )
            self.encoder3 = nn.Sequential(
                nn.Conv2d(32, 64, 3, padding=2, dilation=2), nn.ReLU(),
                ResidualBlock(64), ResidualBlock(64),
            )
            self.encoder4 = nn.Sequential(
                nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
                ResidualBlock(128), ResidualBlock(128),
            )
            self.encoder5 = nn.Sequential(
                nn.Conv2d(128, 256, 3, padding=2, dilation=2), nn.ReLU(),
                ResidualBlock(256), ResidualBlock(256),
            )
            # Each decoder stage consumes the previous output concatenated with
            # the matching encoder output, hence the summed input channel counts.
            self.decoder5 = nn.Sequential(
                nn.Conv2d(256 + 128, 128, 3, padding=1), nn.ReLU(),
                ResidualBlock(128), ResidualBlock(128),
            )
            self.decoder4 = nn.Sequential(
                nn.Conv2d(128 + 64, 64, 3, padding=1), nn.ReLU(),
                ResidualBlock(64), ResidualBlock(64),
            )
            self.decoder3 = nn.Sequential(
                nn.Conv2d(64 + 32, 32, 3, padding=1), nn.ReLU(),
                ResidualBlock(32), ResidualBlock(32),
            )
            self.decoder2 = nn.Sequential(
                nn.Conv2d(32 + 16, 16, 3, padding=1), nn.ReLU(),
                ResidualBlock(16), ResidualBlock(16),
            )
            self.decoder1 = nn.Sequential(nn.Conv2d(16, 3, 3, padding=1), nn.Sigmoid())

        def forward(self, x):
            e1 = self.encoder1(x)
            e2 = self.encoder2(e1)
            e3 = self.encoder3(e2)
            e4 = self.encoder4(e3)
            e5 = self.encoder5(e4)
            d5 = self.decoder5(torch.cat([e5, e4], dim=1))
            d4 = self.decoder4(torch.cat([d5, e3], dim=1))
            d3 = self.decoder3(torch.cat([d4, e2], dim=1))
            d2 = self.decoder2(torch.cat([d3, e1], dim=1))
            return self.decoder1(d2)

    class BinaryClassificationCNN(nn.Module):
        """Two-conv stem feeding a ResNet-18 body with a single-output head."""

        def __init__(self, input_channels: int = 3):
            super().__init__()
            self.pre_conv1 = nn.Sequential(
                nn.Conv2d(input_channels, 32, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU()
            )
            self.pre_conv2 = nn.Sequential(
                nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU()
            )
            # NOTE(review): requesting IMAGENET1K_V1 presumably makes torchvision
            # fetch those weights when not cached locally — confirm this is
            # wanted given a checkpoint is loaded afterwards.
            self.features = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
            # The stem above outputs 64 channels, so the first ResNet conv is replaced.
            self.features.conv1 = nn.Conv2d(64, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.features.fc = nn.Linear(self.features.fc.in_features, 1)

        def forward(self, x):
            x = self.pre_conv1(x)
            x = self.pre_conv2(x)
            return self.features(x)

    class BinaryClassificationCNN2(nn.Module):
        """Two-conv stem feeding a MobileNetV2 body with a single-output head."""

        def __init__(self, input_channels: int = 3):
            super().__init__()
            self.pre_conv1 = nn.Sequential(
                nn.Conv2d(input_channels, 32, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU()
            )
            self.pre_conv2 = nn.Sequential(
                nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU()
            )
            self.mobilenet = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
            # First MobileNet conv is replaced to accept the stem's 64 channels.
            self.mobilenet.features[0][0] = nn.Conv2d(
                64, 32, kernel_size=3, stride=2, padding=1, bias=False
            )
            in_features = self.mobilenet.classifier[-1].in_features
            self.mobilenet.classifier[-1] = nn.Linear(in_features, 1)

        def forward(self, x):
            x = self.pre_conv1(x)
            x = self.pre_conv2(x)
            return self.mobilenet(x)

    # Also return the torchvision transforms helper used by the detectors
    tfm = transforms.Compose([transforms.ToTensor(), transforms.Resize((256, 256))])

    return nn, SatelliteRemoverCNN, BinaryClassificationCNN, BinaryClassificationCNN2, tfm
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# ----------------------------
|
|
195
|
+
# Loading helpers
|
|
196
|
+
# ----------------------------
|
|
197
|
+
def _load_model_weights_lenient(torch, nn, model, checkpoint_path: str, device):
    """
    Load whatever checkpoint tensors fit ``model`` and ignore the rest.

    The checkpoint may be a bare state dict or a dict holding one under
    ``"state_dict"``. A tensor is loaded only when its (resolved) key exists in
    the model and the shapes match; everything else is skipped silently.

    Key resolution, in order:
      1. the key exactly as stored,
      2. the key with one leading ``"module."`` prefix removed,
      3. the key with every ``"module."`` occurrence removed (previous behaviour).
    Steps 1-2 prevent names that merely contain ``"module."`` (for example
    ``"submodule.weight"``) from being mangled and then dropped.

    Args:
        torch: imported torch module.
        nn: ``torch.nn`` (unused here; kept for the existing call signature).
        model: module to populate in place.
        checkpoint_path: path to the checkpoint file.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` instance.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version/defaults — only point this at trusted checkpoint files.
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt

    msd = model.state_dict()
    prefix = "module."

    def _resolve(key: str) -> str:
        if key in msd:
            return key
        if key.startswith(prefix) and key[len(prefix):] in msd:
            return key[len(prefix):]
        # Legacy fallback: strip the marker wherever it appears.
        return key.replace(prefix, "")

    filtered = {}
    for raw_key, tensor in state_dict.items():
        key = _resolve(raw_key)
        if key in msd and msd[key].shape == tensor.shape:
            filtered[key] = tensor

    # strict=False: parameters missing from the checkpoint keep their current values.
    model.load_state_dict(filtered, strict=False)
    return model
|
|
206
|
+
|
|
207
|
+
# ---------- Satellite model cache + loader (updated for torch-directml + ORT DML) ----------
|
|
208
|
+
|
|
209
|
+
# Module-level cache of loaded model bundles. get_satellite_models keys it by
# (detect1 .pth path, detect2 .pth path, removal .pth path, backend tag).
_SAT_CACHE: Dict[Tuple[str, str, str, str], Dict[str, Any]] = {}
|
|
210
|
+
|
|
211
|
+
def get_satellite_models(resources: Any = None, use_gpu: bool = True, status_cb=print) -> Dict[str, Any]:
    """
    Load the satellite detection/removal models, or return the cached bundle.

    Backend order:
      1) CUDA (PyTorch)
      2) DirectML (torch-directml)          [Windows]
      3) DirectML (ONNX Runtime DML EP)     [Windows]
      4) MPS (PyTorch)                      [macOS]
      5) CPU (PyTorch)

    Cache key includes backend tag, so switching GPU on/off never reuses the wrong backend.

    Args:
        resources: object exposing the ``CC_SAT_*_PTH`` / ``CC_SAT_*_ONNX`` paths;
            defaults to ``get_resources()``.
        use_gpu: when False only the CPU backend is considered.
        status_cb: callable receiving human-readable status strings.

    Returns:
        Dict with ``backend``, ``detection_model1``, ``detection_model2``,
        ``removal_model``, ``device`` and ``is_onnx``; torch backends also carry
        ``torch`` and ``tfm``.

    Raises:
        RuntimeError: ORT DirectML was selected but onnxruntime is unavailable.
        FileNotFoundError: ORT DirectML was selected but an ONNX path is missing.
    """
    import os

    if resources is None:
        resources = get_resources()

    p_det1 = resources.CC_SAT_DETECT1_PTH
    p_det2 = resources.CC_SAT_DETECT2_PTH
    p_rem = resources.CC_SAT_REMOVE_PTH

    o_det1 = resources.CC_SAT_DETECT1_ONNX
    o_det2 = resources.CC_SAT_DETECT2_ONNX
    o_rem = resources.CC_SAT_REMOVE_ONNX

    is_windows = (os.name == "nt")

    # ORT DirectML availability
    ort_dml_ok = bool(use_gpu) and (ort is not None) and ("DmlExecutionProvider" in ort.get_available_providers())

    # Torch is always imported (the CPU fallback needs it too); the preferences
    # only steer which accelerator runtime_torch tries to set up.
    torch = _get_torch(
        prefer_cuda=bool(use_gpu),
        prefer_dml=bool(use_gpu and is_windows),
        status_cb=status_cb,
    )

    # Decide backend; anything not matched below stays on CPU.
    backend = "cpu"
    if use_gpu:
        if hasattr(torch, "cuda") and torch.cuda.is_available():
            # 1) CUDA
            backend = "cuda"
        elif is_windows:
            # 2) torch-directml, else 3) ORT DirectML, else CPU
            try:
                import torch_directml  # optional
                _ = torch_directml.device()
                backend = "torch_dml"
            except Exception:
                backend = "ort_dml" if ort_dml_ok else "cpu"
        elif hasattr(torch, "backends") and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            # 4) MPS (macOS)
            backend = "mps"

    key = (p_det1, p_det2, p_rem, backend)
    if key in _SAT_CACHE:
        return _SAT_CACHE[key]

    # ---------------- DirectML via ONNX Runtime ----------------
    if backend == "ort_dml":
        if ort is None:
            raise RuntimeError("onnxruntime not available, cannot use ORT DirectML backend.")
        # sanity: need ONNX paths
        if not (o_det1 and o_det2 and o_rem):
            raise FileNotFoundError("Satellite ONNX model paths are missing in resources.")

        det1 = ort.InferenceSession(o_det1, providers=["DmlExecutionProvider"])
        det2 = ort.InferenceSession(o_det2, providers=["DmlExecutionProvider"])
        rem = ort.InferenceSession(o_rem, providers=["DmlExecutionProvider"])

        out = {
            "backend": "ort_dml",
            "detection_model1": det1,
            "detection_model2": det2,
            "removal_model": rem,
            "device": "DirectML",
            "is_onnx": True,
        }
        _SAT_CACHE[key] = out
        status_cb("CosmicClarity Satellite: using DirectML (ONNX Runtime)")
        return out

    # ---------------- Torch backends (CUDA / torch-directml / MPS / CPU) ----------------
    # pick device
    if backend == "cuda":
        device = torch.device("cuda")
        status_cb(f"CosmicClarity Satellite: using CUDA ({torch.cuda.get_device_name(0)})")
    elif backend == "mps":
        device = torch.device("mps")
        status_cb("CosmicClarity Satellite: using MPS")
    elif backend == "torch_dml":
        import torch_directml
        device = torch_directml.device()
        status_cb("CosmicClarity Satellite: using DirectML (torch-directml)")
    else:
        device = torch.device("cpu")
        status_cb("CosmicClarity Satellite: using CPU")

    nn, SatelliteRemoverCNN, BinaryClassificationCNN, BinaryClassificationCNN2, tfm = _build_torch_models(torch)

    # Each model is moved to the device, populated from its checkpoint and
    # switched to eval mode.
    det1 = BinaryClassificationCNN(3).to(device)
    det1 = _load_model_weights_lenient(torch, nn, det1, p_det1, device).eval()

    det2 = BinaryClassificationCNN2(3).to(device)
    det2 = _load_model_weights_lenient(torch, nn, det2, p_det2, device).eval()

    rem = SatelliteRemoverCNN().to(device)
    rem = _load_model_weights_lenient(torch, nn, rem, p_rem, device).eval()

    out = {
        "backend": backend,
        "detection_model1": det1,
        "detection_model2": det2,
        "removal_model": rem,
        "device": device,
        "is_onnx": False,
        "torch": torch,
        "tfm": tfm,
    }
    _SAT_CACHE[key] = out
    return out
|
|
338
|
+
|
|
339
|
+
# ----------------------------
|
|
340
|
+
# Core processing
|
|
341
|
+
# ----------------------------
|
|
342
|
+
|
|
343
|
+
def _ensure_rgb01(img: np.ndarray) -> Tuple[np.ndarray, bool]:
    """
    Coerce an image to HxWx3 float32 clipped to [0, 1].

    Accepts HxW, HxWx1 or HxWxC (C >= 3; only the first three channels are
    kept). Non-finite values become 0. If the largest value exceeds 1 the whole
    array is divided by it.

    Returns:
        (rgb, is_mono) where ``is_mono`` is True for HxW / HxWx1 input.

    Raises:
        ValueError: for any other shape.
    """
    arr = np.asarray(img)
    channels = arr.shape[2] if arr.ndim == 3 else None
    mono_like = arr.ndim == 2 or channels == 1

    data = np.nan_to_num(
        arr.astype(np.float32, copy=False), nan=0.0, posinf=0.0, neginf=0.0
    )

    # Only rescale when values go above 1; an empty array counts as peak 1.
    peak = float(data.max()) if data.size else 1.0
    if peak > 1.0:
        data = data / peak

    if data.ndim == 2:
        rgb = np.stack((data, data, data), axis=-1)
    elif channels == 1:
        rgb = np.repeat(data, 3, axis=2)
    elif channels is not None and channels >= 3:
        rgb = data[..., :3]
    else:
        raise ValueError(f"Unsupported image shape: {data.shape}")

    return np.clip(rgb, 0.0, 1.0), mono_like
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _extract_luminance_bt601(rgb01: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Split an RGB array (last axis = R, G, B) into Y, Cb, Cr planes.

    Uses the BT.601 coefficients below; Cb and Cr are shifted by +0.5 so
    a neutral pixel maps to 0.5.
    """
    luma_row = [0.299, 0.587, 0.114]
    cb_row = [-0.168736, -0.331264, 0.5]
    cr_row = [0.5, -0.418688, -0.081312]
    to_ycbcr = np.array([luma_row, cb_row, cr_row], dtype=np.float32)

    planes = np.dot(rgb01, to_ycbcr.T)
    luma = planes[..., 0]
    chroma_b = planes[..., 1] + 0.5
    chroma_r = planes[..., 2] + 0.5
    return luma, chroma_b, chroma_r
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def _merge_luminance_bt601(y: np.ndarray, cb: np.ndarray, cr: np.ndarray) -> np.ndarray:
    """
    Rebuild an RGB array from Y, Cb, Cr planes.

    Each plane is clipped to [0, 1] first; Cb/Cr are then re-centred by
    subtracting 0.5. The result is clipped to [0, 1] and returned as float32
    with R, G, B on the last axis.
    """
    def _unit(plane):
        # Clip to [0, 1] and force float32, as done for every input plane.
        return np.clip(plane, 0, 1).astype(np.float32)

    luma = _unit(y)
    chroma_b = _unit(cb) - 0.5
    chroma_r = _unit(cr) - 0.5
    stacked = np.stack([luma, chroma_b, chroma_r], axis=-1)

    to_rgb = np.array(
        [
            [1.0, 0.0, 1.402],
            [1.0, -0.344136, -0.714136],
            [1.0, 1.772, 0.0],
        ],
        dtype=np.float32,
    )
    merged = np.dot(stacked, to_rgb.T)
    return np.clip(merged, 0.0, 1.0).astype(np.float32)
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def _split_chunks(img: np.ndarray, chunk: int, overlap: int):
    """
    Yield overlapping tiles of ``img`` in row-major order.

    Tile origins advance by ``chunk - overlap`` along each axis; tiles at the
    bottom/right edges are truncated to the image bounds.

    Args:
        img: array whose first two axes are (height, width).
        chunk: nominal tile edge length.
        overlap: overlap between neighbouring tiles; must be smaller than ``chunk``.

    Yields:
        ``(tile_view, y0, x0)`` where ``tile_view`` is a slice of ``img``.

    Raises:
        ValueError: if ``chunk - overlap`` is not positive. Raised when iteration
            starts, since this is a generator. A negative step would otherwise
            produce no tiles at all and silently yield a blank result downstream.
    """
    H, W = img.shape[:2]
    step = chunk - overlap
    if step <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk ({chunk})"
        )
    for y0 in range(0, H, step):
        y1 = min(y0 + chunk, H)
        for x0 in range(0, W, step):
            x1 = min(x0 + chunk, W)
            yield img[y0:y1, x0:x1], y0, x0
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def _stitch_ignore_border(chunks, shape_hw3, border: int = 16) -> np.ndarray:
    """
    Average tiles back into one image, discarding each tile's outer margin.

    For every ``(tile, y0, x0)`` the outer ``border`` pixels on each side are
    dropped (capped at half the tile size per axis) and the remainder is
    accumulated at its position. Overlaps are averaged; pixels no tile
    contributed to come out as 0.

    Args:
        chunks: iterable of ``(tile, y0, x0)``.
        shape_hw3: ``(H, W, C)`` of the output.
        border: margin to drop from each tile edge.
    """
    height, width, depth = shape_hw3
    total = np.zeros((height, width, depth), np.float32)
    hits = np.zeros((height, width, depth), np.float32)
    every_channel = slice(None)

    for piece, top, left in chunks:
        rows, cols = piece.shape[:2]
        trim_y = min(border, rows // 2)
        trim_x = min(border, cols // 2)

        src = (
            slice(trim_y, rows - trim_y),
            slice(trim_x, cols - trim_x),
            every_channel,
        )
        dst = (
            slice(top + trim_y, top + rows - trim_y),
            slice(left + trim_x, left + cols - trim_x),
            every_channel,
        )
        total[dst] += piece[src]
        hits[dst] += 1.0

    # Untouched pixels have a hit count of 0; flooring at 1 avoids dividing by zero.
    return total / np.maximum(hits, 1.0)
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def _apply_clip_trail_logic(processed: np.ndarray, original: np.ndarray, sensitivity: float) -> np.ndarray:
    """
    Zero out the pixels where the removal network changed the tile the most.

    The residual ``original - processed`` is mean-subtracted, scaled by 10 and
    clipped to [0, 1]. Pixels whose value is not below ``sensitivity`` form a
    binary mask that is subtracted from ``original``; the result is clipped to
    [0, 1].
    """
    residual = original - processed
    baseline = float(residual.mean())
    strength = np.clip((residual - baseline) * 10.0, 0.0, 1.0)

    # 1.0 wherever the strength is NOT below the sensitivity, 0.0 elsewhere.
    below = strength < sensitivity
    trail_mask = (~below).astype(np.float32)

    return np.clip(original - trail_mask, 0.0, 1.0)
|
|
431
|
+
|
|
432
|
+
# ---------- Torch detection (FIX: tfm expects PIL/ndarray, not Tensor; avoid double ToTensor) ----------
|
|
433
|
+
|
|
434
|
+
def _torch_detect(tile_rgb01: np.ndarray, models: Dict[str, Any]) -> bool:
    """
    Run the two-stage torch detector on one tile.

    The tile is converted to float32, clipped to [0, 1] and passed as a numpy
    array to ``models["tfm"]`` (do NOT hand that transform a tensor). The second
    detector only runs when the first one's output exceeds 0.5; the tile counts
    as a detection when the second output exceeds 0.25.
    """
    torch = models["torch"]
    device = models["device"]
    first_stage = models["detection_model1"]
    second_stage = models["detection_model2"]
    tfm = models["tfm"]

    tile = np.clip(np.asarray(tile_rgb01, np.float32), 0.0, 1.0)

    # torchvision transform pipeline -> tensor, then add the batch axis
    batch = tfm(tile).unsqueeze(0).to(device)

    with torch.no_grad():
        first_score = float(first_stage(batch).item())
        if first_score <= 0.5:
            return False
        second_score = float(second_stage(batch).item())

    return second_score > 0.25
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def _torch_remove(tile_rgb01: np.ndarray, models: Dict[str, Any]) -> np.ndarray:
    """
    Run the torch removal network on one HxWx3 tile.

    The tile is sent to the model's device as a 1x3xHxW float32 tensor, the
    network runs without gradients (inside the autocast context chosen for the
    device), and the output comes back as HxWx3 float32 clipped to [0, 1].
    """
    torch = models["torch"]
    device = models["device"]
    remover = models["removal_model"]

    chw = torch.from_numpy(tile_rgb01).permute(2, 0, 1)
    batch = chw.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad(), _autocast_context(torch, device):
        prediction = remover(batch)
        as_numpy = prediction.squeeze(0).detach().cpu().numpy()
        hwc = as_numpy.transpose(1, 2, 0)

    return np.clip(hwc, 0.0, 1.0).astype(np.float32)
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def _onnx_detect(tile_rgb01: np.ndarray, sess, threshold: float = 0.5) -> bool:
    """
    Run one ONNX detector session on a tile.

    The tile is resized to 256x256x3, fed as a 1x3x256x256 float32 batch to the
    session's first input, and the first value of the first output is compared
    with ``threshold``.

    Args:
        tile_rgb01: HxWx3 tile.
        sess: object exposing ``get_inputs()`` and ``run()`` like an ONNX
            Runtime inference session.
        threshold: score above which the tile counts as a detection. The
            default of 0.5 is the value that used to be hard-coded.

    Returns:
        True when the score is strictly greater than ``threshold``.

    Raises:
        RuntimeError: if scikit-image's ``resize`` could not be imported.
    """
    # Resize to 256x256 like the standalone ONNX path
    if _sk_resize is None:
        raise RuntimeError("skimage.transform.resize is required for ONNX satellite detection path.")
    r = _sk_resize(tile_rgb01, (256, 256, 3), mode="reflect", anti_aliasing=True).astype(np.float32)
    inp = np.transpose(r, (2, 0, 1))[None, ...]
    out = sess.run(None, {sess.get_inputs()[0].name: inp})[0]
    # Read the score from the flattened output so the result does not depend on
    # the output's exact rank or on truth-testing a one-element array.
    score = float(np.asarray(out).reshape(-1)[0])
    return score > threshold
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def _onnx_remove(tile_rgb01: np.ndarray, sess) -> np.ndarray:
    """
    Run the ONNX removal session on one tile.

    The tile is resized to 256x256x3 and fed as a 1x3x256x256 float32 batch.
    The prediction is resized back to the tile's original shape and clipped
    to [0, 1].

    Raises:
        RuntimeError: if scikit-image's ``resize`` could not be imported.
    """
    if _sk_resize is None:
        raise RuntimeError("skimage.transform.resize is required for ONNX satellite removal path.")

    resize_opts = {"mode": "reflect", "anti_aliasing": True}

    net_in = _sk_resize(tile_rgb01, (256, 256, 3), **resize_opts).astype(np.float32)
    batch = np.transpose(net_in, (2, 0, 1))[None, ...]

    input_name = sess.get_inputs()[0].name
    raw = sess.run(None, {input_name: batch})[0]
    prediction = np.transpose(raw.squeeze(0), (1, 2, 0)).astype(np.float32)

    # resize back to original tile size
    restored = _sk_resize(prediction, tile_rgb01.shape, **resize_opts).astype(np.float32)
    return np.clip(restored, 0.0, 1.0)
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def satellite_remove_image(
    image: np.ndarray,
    models: Dict[str, Any],
    *,
    mode: str = "full",  # "full" or "luminance"
    clip_trail: bool = True,
    sensitivity: float = 0.1,
    chunk_size: int = 256,
    overlap: int = 64,
    border_size: int = 16,
    progress_cb: Optional[Callable[[int, int], None]] = None,  # (done, total)
) -> Tuple[np.ndarray, bool]:
    """
    Detect and remove satellite trails from an image.

    Args:
        image: input image (any dtype; HxW, HxWx1 or HxWxC with C >= 3).
            Expected to be linear-ish in [0..1] for best behavior.
        models: bundle returned by ``get_satellite_models``.
        mode: ``"luminance"`` (case-insensitive) processes only the BT.601 Y
            channel and merges it back with the original chroma; any other
            value processes the RGB image directly.
        clip_trail: apply the clip-trail masking to tiles where a trail was found.
        sensitivity: threshold used by the clip-trail masking.
        chunk_size, overlap, border_size: tiling parameters.
        progress_cb: optional callback invoked as ``(done, total)`` per tile.

    Returns:
        ``(out_image, trail_detected_any)``. ``out_image`` is float32; mono-like
        input (HxW or HxWx1) is returned as HxWx1, everything else as HxWx3.
    """
    rgb01, was_mono = _ensure_rgb01(image)

    # Tiling / masking options shared by both modes.
    options = dict(
        clip_trail=clip_trail,
        sensitivity=sensitivity,
        chunk_size=chunk_size,
        overlap=overlap,
        border_size=border_size,
        progress_cb=progress_cb,
    )

    if mode.lower() == "luminance":
        # luminance mode -> process Y only, then merge back
        y, cb, cr = _extract_luminance_bt601(rgb01)
        # treat Y as "mono" but still run the network as RGB by repeating it
        y3 = np.stack([y, y, y], axis=-1)
        out3, detected = _satellite_remove_rgb(y3, models, **options)
        out_rgb = _merge_luminance_bt601(out3[..., 0], cb, cr)
    else:
        out_rgb, detected = _satellite_remove_rgb(rgb01, models, **options)

    # If original was mono-like, return HxWx1 (SASpro convention for mono docs).
    # _ensure_rgb01 already derived this flag from the input's shape.
    if was_mono:
        return out_rgb[..., 0:1].astype(np.float32), detected

    return out_rgb.astype(np.float32), detected
|
|
543
|
+
|
|
544
|
+
# ---------- Satellite remove loop (FIX: use correct ONNX sessions, not det1/rem confusion) ----------
|
|
545
|
+
|
|
546
|
+
def _satellite_remove_rgb(
    rgb01: np.ndarray,
    models: Dict[str, Any],
    *,
    clip_trail: bool,
    sensitivity: float,
    chunk_size: int,
    overlap: int,
    border_size: int,
    progress_cb: Optional[Callable[[int, int], None]],
) -> Tuple[np.ndarray, bool]:
    """
    Tile an HxWx3 image, run detection and removal per tile, and stitch the result.

    Both backends use BOTH detectors: a tile is processed only when the first
    and then the second detector fire. Tiles without a detection pass through
    unchanged. ``progress_cb`` (if given) is called as ``(done, total)`` after
    every tile.

    Returns:
        ``(image, trail_found)``. When no tile triggered, the input is returned
        as float32 without stitching; otherwise the stitched float32 image with
        the outer ``border_size`` pixels copied from the input.
    """
    is_onnx = bool(models.get("is_onnx", False))

    # Resolve the backend once instead of per tile.
    if is_onnx:
        det1_sess = models["detection_model1"]
        det2_sess = models["detection_model2"]
        rem_sess = models["removal_model"]

        # NOTE(review): the torch path uses 0.25 for its second detector, while
        # both ONNX detectors here go through _onnx_detect's built-in
        # threshold — confirm the difference is intended.
        def _detect(tile):
            return bool(_onnx_detect(tile, det1_sess) and _onnx_detect(tile, det2_sess))

        def _remove(tile):
            return _onnx_remove(tile, rem_sess)
    else:
        def _detect(tile):
            return _torch_detect(tile, models)

        def _remove(tile):
            return _torch_remove(tile, models)

    H, W = rgb01.shape[:2]
    trail_any = False

    all_tiles = list(_split_chunks(rgb01, chunk_size, overlap))
    total = len(all_tiles)
    out_tiles = []

    for idx, (tile, y0, x0) in enumerate(all_tiles, start=1):
        orig = tile.astype(np.float32, copy=False)

        if _detect(orig):
            trail_any = True
            pred = _remove(orig)
            final = _apply_clip_trail_logic(pred, orig, sensitivity) if clip_trail else pred
        else:
            final = orig

        out_tiles.append((final, y0, x0))

        if progress_cb is not None:
            progress_cb(idx, total)

    # Nothing detected: hand back the input and skip the full-image stitch,
    # whose result would be discarded anyway.
    if not trail_any:
        return rgb01.astype(np.float32, copy=False), False

    # NOTE(review): the stitcher trims `border_size` from every tile edge, so
    # if overlap < 2 * border_size there are presumably interior strips no tile
    # covers (left at 0) — verify callers never pass such a combination.
    out = _stitch_ignore_border(out_tiles, (H, W, 3), border=border_size)

    # keep edges unchanged
    if border_size > 0:
        out[:border_size, :, :] = rgb01[:border_size, :, :]
        out[-border_size:, :, :] = rgb01[-border_size:, :, :]
        out[:, :border_size, :] = rgb01[:, :border_size, :]
        out[:, -border_size:, :] = rgb01[:, -border_size:, :]

    out = np.clip(out, 0.0, 1.0).astype(np.float32)
    return out, True
|