setiastrosuitepro 1.7.5.post1__py3-none-any.whl → 1.8.0.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of setiastrosuitepro might be problematic. Click here for more details.
- setiastro/saspro/_generated/build_info.py +2 -2
- setiastro/saspro/accel_installer.py +21 -8
- setiastro/saspro/accel_workers.py +11 -12
- setiastro/saspro/comet_stacking.py +113 -85
- setiastro/saspro/cosmicclarity.py +604 -826
- setiastro/saspro/cosmicclarity_engines/benchmark_engine.py +732 -0
- setiastro/saspro/cosmicclarity_engines/darkstar_engine.py +576 -0
- setiastro/saspro/cosmicclarity_engines/denoise_engine.py +567 -0
- setiastro/saspro/cosmicclarity_engines/satellite_engine.py +620 -0
- setiastro/saspro/cosmicclarity_engines/sharpen_engine.py +587 -0
- setiastro/saspro/cosmicclarity_engines/superres_engine.py +412 -0
- setiastro/saspro/gui/main_window.py +14 -0
- setiastro/saspro/gui/mixins/menu_mixin.py +2 -0
- setiastro/saspro/model_manager.py +324 -0
- setiastro/saspro/model_workers.py +102 -0
- setiastro/saspro/ops/benchmark.py +320 -0
- setiastro/saspro/ops/settings.py +407 -10
- setiastro/saspro/remove_stars.py +424 -442
- setiastro/saspro/resources.py +73 -10
- setiastro/saspro/runtime_torch.py +107 -22
- setiastro/saspro/signature_insert.py +14 -3
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/METADATA +2 -1
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/RECORD +27 -18
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/WHEEL +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/entry_points.txt +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/LICENSE +0 -0
- {setiastrosuitepro-1.7.5.post1.dist-info → setiastrosuitepro-1.8.0.post3.dist-info}/licenses/license.txt +0 -0
|
@@ -0,0 +1,576 @@
|
|
|
1
|
+
# src/setiastro/saspro/cosmicclarity_engines/darkstar_engine.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any, Callable, Optional
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from setiastro.saspro.resources import get_resources
|
|
10
|
+
|
|
11
|
+
# Optional deps
|
|
12
|
+
try:
|
|
13
|
+
import onnxruntime as ort
|
|
14
|
+
except Exception:
|
|
15
|
+
ort = None
|
|
16
|
+
|
|
17
|
+
ProgressCB = Callable[[int, int, str], None] # (done, total, stage)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# ---------------- Torch import (your existing runtime_torch helper) ----------------
|
|
21
|
+
|
|
22
|
+
def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
    """Import torch through the shared runtime helper, honoring backend preferences.

    XPU is always disabled for this engine; CUDA/DirectML preference is
    forwarded as-is. ``status_cb`` receives progress/diagnostic strings.
    """
    from setiastro.saspro.runtime_torch import import_torch

    prefs = {
        "prefer_cuda": prefer_cuda,
        "prefer_xpu": False,
        "prefer_dml": prefer_dml,
        "status_cb": status_cb,
    }
    return import_torch(**prefs)
|
31
|
+
|
|
32
|
+
def _nullcontext():
|
|
33
|
+
from contextlib import nullcontext
|
|
34
|
+
return nullcontext()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _autocast_context(torch, device) -> Any:
|
|
38
|
+
"""
|
|
39
|
+
Use new torch.amp.autocast('cuda') when available.
|
|
40
|
+
Keep your cap >= 8.0 rule.
|
|
41
|
+
"""
|
|
42
|
+
try:
|
|
43
|
+
if hasattr(device, "type") and device.type == "cuda":
|
|
44
|
+
major, minor = torch.cuda.get_device_capability()
|
|
45
|
+
cap = float(f"{major}.{minor}")
|
|
46
|
+
if cap >= 8.0:
|
|
47
|
+
# Preferred API (torch >= 1.10-ish; definitely in 2.x)
|
|
48
|
+
if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
|
|
49
|
+
return torch.amp.autocast(device_type="cuda")
|
|
50
|
+
# Fallback for older torch
|
|
51
|
+
return torch.cuda.amp.autocast()
|
|
52
|
+
except Exception:
|
|
53
|
+
pass
|
|
54
|
+
return _nullcontext()
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# ---------------- Models (same topology as your script) ----------------
|
|
58
|
+
|
|
59
|
+
def _build_darkstar_torch_models(torch):
    """
    Define the Dark Star network topology and return the combined wrapper class.

    The nn.Module subclasses are defined inside this function so that torch is
    only required when a Torch backend is actually selected. Attribute names
    (``encoder1``..., ``decoder1``..., ``stage1``/``stage2``, ``net``, ``conv1``...)
    must not be renamed: they are the keys of the checkpoint state dicts
    loaded in ``CascadedStarRemovalNetCombined.__init__``.
    """
    import torch.nn as nn

    class RefinementCNN(nn.Module):
        # Dilated conv stack (dilation 1-2-4-8-8-4-2-1) with a sigmoid output.
        # Present for checkpoint compatibility; currently unused by forward()
        # of the combined net below.
        def __init__(self, channels: int = 96):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(3, channels, 3, padding=1, dilation=1), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=2, dilation=2), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=4, dilation=4), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=8, dilation=8), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=8, dilation=8), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=4, dilation=4), nn.ReLU(),
                nn.Conv2d(channels, channels, 3, padding=2, dilation=2), nn.ReLU(),
                nn.Conv2d(channels, 3, 3, padding=1, dilation=1), nn.Sigmoid()
            )
        def forward(self, x): return self.net(x)

    class ResidualBlock(nn.Module):
        # Standard conv-relu-conv residual block with a post-sum ReLU.
        def __init__(self, channels: int):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        def forward(self, x):
            out = self.relu(self.conv1(x))
            out = self.conv2(out)
            return self.relu(out + x)

    class DarkStarCNN(nn.Module):
        # U-Net-like net without down/up-sampling: all stages keep full spatial
        # resolution, so encoder outputs can be concatenated with decoder
        # inputs directly (see forward()).
        def __init__(self):
            super().__init__()
            self.encoder1 = nn.Sequential(
                nn.Conv2d(3, 16, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(16), ResidualBlock(16), ResidualBlock(16),
            )
            self.encoder2 = nn.Sequential(
                nn.Conv2d(16, 32, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(32), ResidualBlock(32), ResidualBlock(32),
            )
            self.encoder3 = nn.Sequential(
                nn.Conv2d(32, 64, 3, padding=2, dilation=2),
                nn.ReLU(inplace=True),
                ResidualBlock(64), ResidualBlock(64),
            )
            self.encoder4 = nn.Sequential(
                nn.Conv2d(64, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(128), ResidualBlock(128),
            )
            self.encoder5 = nn.Sequential(
                nn.Conv2d(128, 256, 3, padding=2, dilation=2),
                nn.ReLU(inplace=True),
                ResidualBlock(256),
            )

            self.decoder5 = nn.Sequential(
                nn.Conv2d(256 + 128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(128), ResidualBlock(128),
            )
            self.decoder4 = nn.Sequential(
                nn.Conv2d(128 + 64, 64, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(64), ResidualBlock(64),
            )
            self.decoder3 = nn.Sequential(
                nn.Conv2d(64 + 32, 32, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(32), ResidualBlock(32), ResidualBlock(32),
            )
            self.decoder2 = nn.Sequential(
                nn.Conv2d(32 + 16, 16, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(16), ResidualBlock(16), ResidualBlock(16),
            )
            self.decoder1 = nn.Sequential(
                nn.Conv2d(16, 16, 3, padding=1),
                nn.ReLU(inplace=True),
                ResidualBlock(16), ResidualBlock(16),
                nn.Conv2d(16, 3, 3, padding=1),
                nn.Sigmoid(),
            )

        def forward(self, x):
            e1 = self.encoder1(x)
            e2 = self.encoder2(e1)
            e3 = self.encoder3(e2)
            e4 = self.encoder4(e3)
            e5 = self.encoder5(e4)

            # Skip connections: each decoder sees the previous decoder output
            # concatenated with the matching-resolution encoder output.
            d5 = self.decoder5(torch.cat([e5, e4], dim=1))
            d4 = self.decoder4(torch.cat([d5, e3], dim=1))
            d3 = self.decoder3(torch.cat([d4, e2], dim=1))
            d2 = self.decoder2(torch.cat([d3, e1], dim=1))
            return self.decoder1(d2)

    class CascadedStarRemovalNetCombined(nn.Module):
        # Wraps the coarse DarkStarCNN (stage1) plus an optional refinement
        # stage2 checkpoint; forward() currently returns only the coarse result.
        def __init__(self, stage1_path: str, stage2_path: str | None = None):
            super().__init__()
            self.stage1 = DarkStarCNN()
            ckpt1 = torch.load(stage1_path, map_location="cpu")

            # strip "stage1." prefix if present
            if isinstance(ckpt1, dict):
                sd1 = {k[len("stage1."):] : v for k, v in ckpt1.items() if k.startswith("stage1.")}
                if sd1:
                    ckpt1 = sd1
            self.stage1.load_state_dict(ckpt1)

            # refinement exists in your code but currently not used (forward returns coarse)
            self.stage2 = RefinementCNN()
            if stage2_path:
                try:
                    ckpt2 = torch.load(stage2_path, map_location="cpu")
                    if isinstance(ckpt2, dict) and "model_state" in ckpt2:
                        ckpt2 = ckpt2["model_state"]
                    self.stage2.load_state_dict(ckpt2)
                except Exception:
                    # Best-effort: a bad/missing stage2 checkpoint leaves the
                    # (unused) refinement net randomly initialized.
                    pass

            # Inference-only: freeze the coarse stage.
            for p in self.stage1.parameters():
                p.requires_grad = False

        def forward(self, x):
            with torch.no_grad():
                coarse = self.stage1(x)
            return coarse

    return CascadedStarRemovalNetCombined
194
|
+
|
|
195
|
+
def add_border(image: np.ndarray, border_size: int = 5) -> np.ndarray:
    """
    Pad an image on all four sides with its median value.

    2-D input uses the global median; HxWx3 input pads each channel with that
    channel's own median. Raises ValueError for any other shape.
    """
    pad_spec = ((border_size, border_size), (border_size, border_size))

    if image.ndim == 2:
        fill = float(np.median(image))
        return np.pad(image, pad_spec, mode="constant", constant_values=fill)

    if image.ndim == 3 and image.shape[2] == 3:
        meds = np.median(image, axis=(0, 1)).astype(np.float32)
        padded = [
            np.pad(image[..., c], pad_spec, mode="constant",
                   constant_values=float(meds[c]))
            for c in range(3)
        ]
        return np.stack(padded, axis=-1)

    raise ValueError("add_border expects 2D or HxWx3")
+
|
|
210
|
+
def remove_border(image: np.ndarray, border_size: int = 5) -> np.ndarray:
    """Crop off the frame previously added by ``add_border`` (same border_size)."""
    crop = slice(border_size, -border_size)
    if image.ndim == 2:
        return image[crop, crop]
    return image[crop, crop, :]
|
|
216
|
+
def stretch_image_unlinked_rgb(img_rgb: np.ndarray, target_median: float = 0.25):
    """
    Apply an unlinked (per-channel) midtone stretch toward ``target_median``.

    The per-channel pedestal (minimum) is removed first, then each channel is
    stretched so its median lands near the target.

    Returns (stretched, orig_min, orig_meds) — the latter two are needed by
    unstretch_image_unlinked_rgb() to invert the transform.
    """
    work = img_rgb.astype(np.float32, copy=True)
    orig_min = work.reshape(-1, 3).min(axis=0)   # per-channel pedestal, shape (3,)
    work = (work - orig_min.reshape(1, 1, 3))
    orig_meds = np.median(work, axis=(0, 1)).astype(np.float32)

    for ch in range(3):
        med = float(orig_meds[ch])
        if med == 0:
            continue  # flat channel — the stretch formula would divide by zero
        v = work[..., ch]
        # Midtone-transfer-function stretch mapping median -> target_median.
        work[..., ch] = ((med - 1) * target_median * v) / (
            med * (target_median + v - 1) - target_median * v
        )
    work = np.clip(work, 0, 1)
    return work, orig_min.astype(np.float32), orig_meds.astype(np.float32)
231
|
+
|
|
232
|
+
def unstretch_image_unlinked_rgb(img_rgb: np.ndarray, orig_meds, orig_min):
    """
    Invert stretch_image_unlinked_rgb(): map each channel's current median back
    to its original median, then restore the per-channel pedestal.
    """
    work = img_rgb.astype(np.float32, copy=True)
    for ch in range(3):
        med_now = float(np.median(work[..., ch]))
        med_orig = float(orig_meds[ch])
        if med_now == 0 or med_orig == 0:
            continue  # degenerate channel — leave untouched
        v = work[..., ch]
        work[..., ch] = ((med_now - 1) * med_orig * v) / (
            med_now * (med_orig + v - 1) - med_orig * v
        )
    work = work + orig_min.reshape(1, 1, 3)
    return np.clip(work, 0, 1).astype(np.float32, copy=False)
|
245
|
+
# ---------------- Chunking & stitch (soft blend like your script) ----------------
|
|
246
|
+
|
|
247
|
+
def split_image_into_chunks_with_overlap(image: np.ndarray, chunk_size: int, overlap: int):
    """
    Tile an image into overlapping chunks.

    Returns a list of (tile, row, col) triples where (row, col) is the tile's
    top-left corner. Edge tiles may be smaller than chunk_size; the stride is
    chunk_size - overlap.
    """
    H, W = image.shape[:2]
    step = chunk_size - overlap
    return [
        (image[r:min(r + chunk_size, H), c:min(c + chunk_size, W)], r, c)
        for r in range(0, H, step)
        for c in range(0, W, step)
    ]
|
|
259
|
+
def _blend_weights(chunk_size: int, overlap: int):
|
|
260
|
+
if overlap <= 0:
|
|
261
|
+
return np.ones((chunk_size, chunk_size), dtype=np.float32)
|
|
262
|
+
ramp = np.linspace(0, 1, overlap, dtype=np.float32)
|
|
263
|
+
flat = np.ones(max(chunk_size - 2 * overlap, 1), dtype=np.float32)
|
|
264
|
+
v = np.concatenate([ramp, flat, ramp[::-1]])
|
|
265
|
+
w = np.outer(v, v).astype(np.float32)
|
|
266
|
+
return w
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def stitch_chunks_soft_blend(
    chunks: list[tuple[np.ndarray, int, int]],
    out_shape: tuple[int, int, int],
    *,
    chunk_size: int,
    overlap: int,
    border_size: int = 5,
) -> np.ndarray:
    """
    Reassemble processed tiles into one HxWxC image with soft blending.

    Each tile first has up to ``border_size`` pixels trimmed from any edge
    that is interior to the image (edges touching the image boundary are
    kept), then the remaining pixels are accumulated with feathered weights
    from _blend_weights() and normalized by the accumulated weight sum.

    ``chunks`` are (tile, row, col) triples as produced by
    split_image_into_chunks_with_overlap(); ``out_shape`` is (H, W, C).
    """
    H, W, C = out_shape
    out = np.zeros((H, W, C), np.float32)
    wsum = np.zeros((H, W, 1), np.float32)   # per-pixel weight accumulator
    bw_full = _blend_weights(chunk_size, overlap)

    for tile, i, j in chunks:
        th, tw = tile.shape[:2]

        # adaptive inner crop like your script:
        # trim only edges not flush with the image border; never trim more
        # than half the tile in either direction.
        top = 0 if i == 0 else min(border_size, th // 2)
        left = 0 if j == 0 else min(border_size, tw // 2)
        bottom = 0 if (i + th) >= H else min(border_size, th // 2)
        right = 0 if (j + tw) >= W else min(border_size, tw // 2)

        inner = tile[top:th-bottom, left:tw-right, :]
        ih, iw = inner.shape[:2]

        # Destination rectangle of the cropped tile in the output image.
        rr0 = i + top
        cc0 = j + left
        rr1 = rr0 + ih
        cc1 = cc0 + iw

        # NOTE(review): edge tiles use the top-left corner of the full weight
        # map rather than a weight map rebuilt for the smaller size.
        bw = bw_full[:ih, :iw].reshape(ih, iw, 1)
        out[rr0:rr1, cc0:cc1, :] += inner * bw
        wsum[rr0:rr1, cc0:cc1, :] += bw

    # Normalize; epsilon guards pixels no tile contributed any weight to.
    out = out / np.maximum(wsum, 1e-8)
    return out
306
|
+
|
|
307
|
+
# ---------------- Model loading (cached) ----------------
|
|
308
|
+
|
|
309
|
+
@dataclass
class DarkStarModels:
    """A loaded Dark Star inference backend: a Torch module or an ONNX session."""
    device: Any               # torch.device, torch-directml device, or the string "DirectML"
    is_onnx: bool             # True when `model` is an onnxruntime InferenceSession
    model: Any                # eval-mode nn.Module (torch) or InferenceSession (onnx)
    torch: Any | None = None  # the torch module for torch backends; None for ONNX
    chunk_size: int = 512  # used for ONNX fixed shapes
+
|
|
318
|
+
# ---------------- Model loading (cached) ----------------
|
|
319
|
+
|
|
320
|
+
# Keyed by (model tag, backend id) so switching GPU on/off picks the right entry.
_MODELS_CACHE: dict[tuple[str, str], DarkStarModels] = {}  # (tag, backend_id)

def load_darkstar_models(*, use_gpu: bool, color: bool, status_cb=print) -> DarkStarModels:
    """
    Load (or fetch from cache) the Dark Star model for the best available backend.

    Backend order:
      1) CUDA (PyTorch)
      2) DirectML (torch-directml) [Windows]
      3) DirectML (ONNX Runtime)   [Windows]
      4) MPS (PyTorch)             [macOS]
      5) CPU (PyTorch)
    Cache key includes backend_id so we never "stick" on CPU when GPU is later enabled.

    ``color`` selects the RGB vs mono checkpoint; ``status_cb`` receives a
    one-line message naming the chosen backend.
    """
    R = get_resources()

    # Pick checkpoint paths + cache tag for the requested variant.
    if color:
        pth = R.CC_DARKSTAR_COLOR_PTH
        onnx = R.CC_DARKSTAR_COLOR_ONNX
        tag = "cc_darkstar_color"
    else:
        pth = R.CC_DARKSTAR_MONO_PTH
        onnx = R.CC_DARKSTAR_MONO_ONNX
        tag = "cc_darkstar_mono"

    import os
    is_windows = os.name == "nt"

    # Request torch with the right preferences (runtime_torch decides what it can do)
    torch = _get_torch(
        prefer_cuda=bool(use_gpu),
        prefer_dml=bool(use_gpu and is_windows),
        status_cb=status_cb,
    )

    # ---------------- CUDA (torch) ----------------
    if use_gpu and hasattr(torch, "cuda") and torch.cuda.is_available():
        backend_id = "cuda"
        key = (tag, backend_id)
        if key in _MODELS_CACHE:
            return _MODELS_CACHE[key]

        dev = torch.device("cuda")
        status_cb(f"Dark Star: using CUDA ({torch.cuda.get_device_name(0)})")
        Net = _build_darkstar_torch_models(torch)
        net = Net(pth, None).eval().to(dev)

        m = DarkStarModels(device=dev, is_onnx=False, model=net, torch=torch, chunk_size=512)
        _MODELS_CACHE[key] = m
        return m

    # ---------------- DirectML (torch-directml) ----------------
    if use_gpu and is_windows:
        try:
            import torch_directml  # optional
            backend_id = "torch_dml"
            key = (tag, backend_id)
            if key in _MODELS_CACHE:
                return _MODELS_CACHE[key]

            dev = torch_directml.device()
            status_cb("Dark Star: using DirectML (torch-directml)")
            Net = _build_darkstar_torch_models(torch)
            net = Net(pth, None).eval().to(dev)

            m = DarkStarModels(device=dev, is_onnx=False, model=net, torch=torch, chunk_size=512)
            _MODELS_CACHE[key] = m
            return m
        except Exception:
            # torch-directml missing or failed to init — try the next backend.
            pass

    # ---------------- DirectML (ONNX Runtime) ----------------
    if use_gpu and ort is not None and ("DmlExecutionProvider" in ort.get_available_providers()):
        if onnx and onnx.strip():
            backend_id = "ort_dml"
            key = (tag, backend_id)
            if key in _MODELS_CACHE:
                return _MODELS_CACHE[key]

            status_cb("Dark Star: using DirectML (ONNX Runtime)")
            sess = ort.InferenceSession(onnx, providers=["DmlExecutionProvider"])

            # fixed-ish input: [1,3,H,W]; some exports have static shapes
            inp = sess.get_inputs()[0]
            cs = 512
            try:
                # If the export pins a static spatial size, honor it as the
                # tile size; symbolic/dynamic dims fall back to 512.
                if getattr(inp, "shape", None) and len(inp.shape) >= 4:
                    if inp.shape[2] not in (None, "None"):
                        cs = int(inp.shape[2])
            except Exception:
                pass

            m = DarkStarModels(device="DirectML", is_onnx=True, model=sess, torch=None, chunk_size=cs)
            _MODELS_CACHE[key] = m
            return m

    # ---------------- MPS (torch) ----------------
    if use_gpu and hasattr(torch, "backends") and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        backend_id = "mps"
        key = (tag, backend_id)
        if key in _MODELS_CACHE:
            return _MODELS_CACHE[key]

        dev = torch.device("mps")
        status_cb("Dark Star: using MPS")
        Net = _build_darkstar_torch_models(torch)
        net = Net(pth, None).eval().to(dev)

        m = DarkStarModels(device=dev, is_onnx=False, model=net, torch=torch, chunk_size=512)
        _MODELS_CACHE[key] = m
        return m

    # ---------------- CPU (torch) ----------------
    backend_id = "cpu"
    key = (tag, backend_id)
    if key in _MODELS_CACHE:
        return _MODELS_CACHE[key]

    dev = torch.device("cpu")
    status_cb("Dark Star: using CPU")
    Net = _build_darkstar_torch_models(torch)
    net = Net(pth, None).eval().to(dev)

    m = DarkStarModels(device=dev, is_onnx=False, model=net, torch=torch, chunk_size=512)
    _MODELS_CACHE[key] = m
    return m
+
# ---------------- Core inference on one HxWx3 image ----------------
|
|
446
|
+
|
|
447
|
+
def _infer_tile(models: DarkStarModels, tile_rgb: np.ndarray) -> np.ndarray:
    """
    Run one HxWx3 tile through the loaded backend.

    ONNX sessions may have a static [1,3,cs,cs] input, so the tile is
    zero-padded (or cropped) to that size and the result cropped back.
    Returns float32 HxWx3 matching the input's spatial size (capped at the
    ONNX chunk size).
    """
    arr = np.asarray(tile_rgb, np.float32)
    h0, w0 = arr.shape[:2]

    if models.is_onnx:
        cs = int(models.chunk_size)
        hh, ww = min(h0, cs), min(w0, cs)

        # Fit the tile onto a cs x cs canvas when sizes differ.
        if (h0, w0) != (cs, cs):
            canvas = np.zeros((cs, cs, 3), np.float32)
            canvas[:hh, :ww, :] = arr[:hh, :ww, :]
            arr = canvas

        sess = models.model
        feed = {sess.get_inputs()[0].name: arr.transpose(2, 0, 1)[None, ...]}  # 1,3,H,W
        pred = sess.run(None, feed)[0][0]      # 3,H,W
        pred = pred.transpose(1, 2, 0)         # H,W,3
        return pred[:hh, :ww, :].astype(np.float32, copy=False)

    # Torch path (CUDA / MPS / CPU / torch-directml).
    torch = models.torch
    dev = models.device
    batch = torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(dev)

    with torch.no_grad(), _autocast_context(torch, dev):
        pred = models.model(batch)[0].detach().cpu().numpy().transpose(1, 2, 0)

    return pred[:h0, :w0, :].astype(np.float32, copy=False)
481
|
+
# ---------------- Public API ----------------
|
|
482
|
+
|
|
483
|
+
@dataclass
class DarkStarParams:
    """User-facing options for Dark Star star removal."""
    use_gpu: bool = True          # prefer GPU backends when available
    chunk_size: int = 512         # tile size for torch backends (ONNX may override)
    overlap_frac: float = 0.125   # tile overlap as a fraction of chunk_size
    mode: str = "unscreen"  # "unscreen" or "additive" stars-only recovery
    output_stars_only: bool = False  # also compute/return the stars-only image
+
|
|
492
|
+
def darkstar_starremoval_rgb01(
    img_rgb01: np.ndarray,
    *,
    params: DarkStarParams,
    progress_cb: Optional[ProgressCB] = None,
    status_cb=print,
) -> tuple[np.ndarray, Optional[np.ndarray], bool]:
    """
    Input : float32 image in [0..1], shape HxWx3 or HxWx1 or HxW
    Output: (starless_rgb01, stars_only_rgb01 or None, was_mono)

    Pipeline: normalize to HxWx3 -> optional per-channel midtone stretch ->
    median border pad -> overlapped tiling -> per-tile inference ->
    soft-blend stitch -> unstretch -> crop border. Mono inputs are collapsed
    back to a single channel (HxWx1) at the end.
    """
    if progress_cb is None:
        progress_cb = lambda done, total, stage: None

    img = np.asarray(img_rgb01, np.float32)
    was_mono = (img.ndim == 2) or (img.ndim == 3 and img.shape[2] == 1)

    # normalize shape to HxWx3
    if img.ndim == 2:
        img3 = np.stack([img, img, img], axis=-1)
    elif img.ndim == 3 and img.shape[2] == 1:
        ch = img[..., 0]
        img3 = np.stack([ch, ch, ch], axis=-1)
    else:
        img3 = img

    img3 = np.clip(img3, 0.0, 1.0)

    # decide "true RGB" vs "3-channel mono"
    # (identical channels means it was a mono frame expanded to 3 channels)
    same_rg = np.allclose(img3[..., 0], img3[..., 1], rtol=0, atol=1e-6)
    same_rb = np.allclose(img3[..., 0], img3[..., 2], rtol=0, atol=1e-6)
    is_true_rgb = not (same_rg and same_rb)

    models = load_darkstar_models(use_gpu=params.use_gpu, color=is_true_rgb, status_cb=status_cb)

    # stretch decision: pedestal-aware (matches other engines more closely)
    # A low pedestal-subtracted median means the image is still linear.
    stretch_needed = float(np.median(img3 - float(np.min(img3)))) < 0.125
    if stretch_needed:
        stretched, orig_min, orig_meds = stretch_image_unlinked_rgb(img3)
    else:
        stretched, orig_min, orig_meds = img3, None, None

    bordered = add_border(stretched, border_size=5)

    # ONNX may force chunk_size
    chunk_size = int(models.chunk_size) if models.is_onnx else int(params.chunk_size)
    overlap = int(round(float(params.overlap_frac) * chunk_size))

    chunks = split_image_into_chunks_with_overlap(bordered, chunk_size=chunk_size, overlap=overlap)
    total = len(chunks)

    out_tiles: list[tuple[np.ndarray, int, int]] = []
    for k, (tile, i, j) in enumerate(chunks, start=1):
        out = _infer_tile(models, tile)
        out_tiles.append((out, i, j))
        progress_cb(k, total, "Dark Star removal")

    starless_b = stitch_chunks_soft_blend(
        out_tiles,
        bordered.shape,
        chunk_size=chunk_size,
        overlap=overlap,
        border_size=5,
    )

    # Undo the stretch before cropping so values line up with the input range.
    if stretch_needed:
        starless_b = unstretch_image_unlinked_rgb(starless_b, orig_meds, orig_min)

    starless = remove_border(starless_b, border_size=5)
    starless = np.clip(starless, 0.0, 1.0).astype(np.float32, copy=False)

    stars_only = None
    if params.output_stars_only:
        if params.mode == "additive":
            # Stars are whatever the net subtracted from the input.
            stars_only = np.clip(img3 - starless, 0.0, 1.0).astype(np.float32, copy=False)
        else:  # unscreen
            # Invert screen blend: input = stars + starless - stars*starless.
            denom = np.maximum(1.0 - starless, 1e-6)
            stars_only = np.clip((img3 - starless) / denom, 0.0, 1.0).astype(np.float32, copy=False)

    # Collapse back to single-channel for mono inputs (keepdims -> HxWx1).
    if was_mono:
        starless = np.mean(starless, axis=2, keepdims=True).astype(np.float32, copy=False)
        if stars_only is not None:
            stars_only = np.mean(stars_only, axis=2, keepdims=True).astype(np.float32, copy=False)

    return starless, stars_only, was_mono