setiastrosuitepro 1.8.0__py3-none-any.whl → 1.8.1.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of setiastrosuitepro might be problematic.

@@ -7,6 +7,7 @@ from typing import Callable, Optional, Any
 import os
 import numpy as np
 from setiastro.saspro.resources import get_resources
+from setiastro.saspro.runtime_torch import _user_runtime_dir, _venv_paths, _check_cuda_in_venv
 
 
 # Optional deps used by auto-PSF
@@ -20,6 +21,24 @@ try:
 except Exception:
     ort = None
 
+import sys, time, tempfile, traceback
+
+_DEBUG_SHARPEN = True  # flip False later
+
+def _dbg(msg: str, status_cb=print):
+    pass
+    #"""Debug print to terminal; also tries status_cb."""
+    #if not _DEBUG_SHARPEN:
+    #    return
+
+    #ts = time.strftime("%H:%M:%S")
+    #line = f"[SharpenDBG {ts}] {msg}"
+    #print(line, flush=True)
+    #try:
+    #    status_cb(line)
+    #except Exception:
+    #    pass
+
 
 ProgressCB = Callable[[int, int, str], bool]  # True=continue, False=cancel
 
@@ -51,17 +70,23 @@ def _autocast_context(torch, device) -> Any:
             major, minor = torch.cuda.get_device_capability()
             cap = float(f"{major}.{minor}")
             if cap >= 8.0:
-                # Preferred API (torch >= 1.10-ish; definitely in 2.x)
                 if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
                     return torch.amp.autocast(device_type="cuda")
-                # Fallback for older torch
                 return torch.cuda.amp.autocast()
+
+        elif hasattr(device, "type") and device.type == "mps":
+            # MPS often benefits from autocast in newer torch versions
+            if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
+                return torch.amp.autocast(device_type="mps")
+
     except Exception:
         pass
+
     return _nullcontext()
 
 
 
+
 def _to_3ch(image: np.ndarray) -> tuple[np.ndarray, bool]:
     """Return (img3, was_mono). img3 is HxWx3 float32."""
     if image.ndim == 2:
@@ -115,6 +140,8 @@ def split_image_into_chunks_with_overlap(image2d: np.ndarray, chunk_size: int, o
         for j in range(0, W, step):
             ei = min(i + chunk_size, H)
             ej = min(j + chunk_size, W)
+            if ei <= i or ej <= j:
+                continue
             chunk = image2d[i:ei, j:ej]
             is_edge = (i == 0) or (j == 0) or (i + chunk_size >= H) or (j + chunk_size >= W)
             out.append((chunk, i, j, is_edge))
@@ -122,21 +149,55 @@ def split_image_into_chunks_with_overlap(image2d: np.ndarray, chunk_size: int, o
 
 
 def stitch_chunks_ignore_border(chunks, image_shape, border_size: int = 16):
-    stitched = np.zeros(image_shape, dtype=np.float32)
-    weights = np.zeros(image_shape, dtype=np.float32)
+    H, W = image_shape
+    stitched = np.zeros((H, W), dtype=np.float32)
+    weights = np.zeros((H, W), dtype=np.float32)
 
     for chunk, i, j, _is_edge in chunks:
         h, w = chunk.shape
+        if h <= 0 or w <= 0:
+            continue
+
         bh = min(border_size, h // 2)
         bw = min(border_size, w // 2)
+
+        y0 = i + bh
+        y1 = i + h - bh
+        x0 = j + bw
+        x1 = j + w - bw
+
+        # empty inner region? skip
+        if y1 <= y0 or x1 <= x0:
+            continue
+
         inner = chunk[bh:h-bh, bw:w-bw]
-        stitched[i+bh:i+h-bh, j+bw:j+w-bw] += inner
-        weights[i+bh:i+h-bh, j+bw:j+w-bw] += 1.0
+
+        # clip destination to image bounds
+        yy0 = max(0, y0)
+        yy1 = min(H, y1)
+        xx0 = max(0, x0)
+        xx1 = min(W, x1)
+
+        # did clipping erase it?
+        if yy1 <= yy0 or xx1 <= xx0:
+            continue
+
+        # clip source to match clipped destination
+        sy0 = yy0 - y0
+        sy1 = sy0 + (yy1 - yy0)
+        sx0 = xx0 - x0
+        sx1 = sx0 + (xx1 - xx0)
+
+        src = inner[sy0:sy1, sx0:sx1]
+
+        stitched[yy0:yy1, xx0:xx1] += src
+        weights[yy0:yy1, xx0:xx1] += 1.0
 
     stitched /= np.maximum(weights, 1.0)
     return stitched
 
 
+
 def add_border(image: np.ndarray, border_size: int = 16) -> np.ndarray:
     med = float(np.median(image))
     if image.ndim == 2:
@@ -231,62 +292,180 @@ class SharpenModels:
     torch: Any | None = None  # set for torch path
 
 
-_MODELS_CACHE: dict[tuple[str, bool], SharpenModels] = {}  # (backend_tag, use_gpu)
+# Cache by (backend_tag, resolved_backend)
+_MODELS_CACHE: dict[tuple[str, str], SharpenModels] = {}  # (backend_tag, resolved)
 
 def load_sharpen_models(use_gpu: bool, status_cb=print) -> SharpenModels:
     backend_tag = "cc_sharpen_ai3_5s"
-    key = (backend_tag, bool(use_gpu))
-    if key in _MODELS_CACHE:
-        return _MODELS_CACHE[key]
-
     is_windows = (os.name == "nt")
+    _dbg(f"ENTER load_sharpen_models(use_gpu={use_gpu})", status_cb=status_cb)
 
-    # ask runtime to prefer DML only on Windows + when user wants GPU
-    torch = _get_torch(
-        prefer_cuda=bool(use_gpu),  # still try CUDA first
-        prefer_dml=bool(use_gpu and is_windows),  # enable DML install/usage path
-        status_cb=status_cb,
-    )
+    # ---- torch runtime import ----
+    t0 = time.time()
+    try:
+        _dbg("Calling _get_torch(...)", status_cb=status_cb)
+        torch = _get_torch(
+            prefer_cuda=bool(use_gpu),
+            prefer_dml=False,
+            status_cb=status_cb,
+        )
+        _dbg(f"_get_torch OK in {time.time()-t0:.3f}s; torch={getattr(torch,'__version__',None)} file={getattr(torch,'__file__',None)}",
+             status_cb=status_cb)
+    except Exception as e:
+        _dbg("ERROR in _get_torch:\n" + "".join(traceback.format_exception(type(e), e, e.__traceback__)),
+             status_cb=status_cb)
+        raise
+
+    # ---- runtime venv CUDA probe (subprocess) ----
+    try:
+        rt = _user_runtime_dir()
+        vpy = _venv_paths(rt)["python"]
+        _dbg(f"Runtime dir={rt} venv_python={vpy}", status_cb=status_cb)
+        t1 = time.time()
+        ok, cuda_tag, err = _check_cuda_in_venv(vpy, status_cb=status_cb)
+        _dbg(f"CUDA probe finished in {time.time()-t1:.3f}s: ok={ok}, torch.version.cuda={cuda_tag}, err={err!r}",
+             status_cb=status_cb)
+    except Exception as e:
+        _dbg("Runtime CUDA probe FAILED:\n" + "".join(traceback.format_exception(type(e), e, e.__traceback__)),
+             status_cb=status_cb)
+
+    # ---- CUDA branch ----
+    if use_gpu:
+        _dbg("Checking torch.cuda.is_available() ...", status_cb=status_cb)
+        try:
+            t2 = time.time()
+            cuda_ok = bool(getattr(torch, "cuda", None) and torch.cuda.is_available())
+            _dbg(f"torch.cuda.is_available()={cuda_ok} (took {time.time()-t2:.3f}s)", status_cb=status_cb)
+        except Exception as e:
+            _dbg("torch.cuda.is_available() raised:\n" + "".join(traceback.format_exception(type(e), e, e.__traceback__)),
+                 status_cb=status_cb)
+            cuda_ok = False
+
+        if cuda_ok:
+            cache_key = (backend_tag, "cuda")
+            if cache_key in _MODELS_CACHE:
+                _dbg("Returning cached CUDA models.", status_cb=status_cb)
+                return _MODELS_CACHE[cache_key]
 
-    # 1) CUDA
-    if use_gpu and hasattr(torch, "cuda") and torch.cuda.is_available():
-        device = torch.device("cuda")
-        status_cb(f"CosmicClarity Sharpen: using CUDA ({torch.cuda.get_device_name(0)})")
-        models = _load_torch_models(torch, device)
-        _MODELS_CACHE[key] = models
-        return models
+            try:
+                _dbg("Creating torch.device('cuda') ...", status_cb=status_cb)
+                device = torch.device("cuda")
+                _dbg("Querying torch.cuda.get_device_name(0) ...", status_cb=status_cb)
+                name = torch.cuda.get_device_name(0)
+                _dbg(f"Using CUDA device: {name}", status_cb=status_cb)
+
+                _dbg("Loading torch .pth models (CUDA) ...", status_cb=status_cb)
+                t3 = time.time()
+                models = _load_torch_models(torch, device)
+                _dbg(f"Loaded CUDA models in {time.time()-t3:.3f}s", status_cb=status_cb)
+
+                _MODELS_CACHE[cache_key] = models
+                return models
+            except Exception as e:
+                _dbg("CUDA path failed:\n" + "".join(traceback.format_exception(type(e), e, e.__traceback__)),
+                     status_cb=status_cb)
+                # fall through to DML/ORT/CPU
+    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    # ADD THE MPS BLOCK RIGHT HERE (after CUDA, before DirectML)
+    # ---- MPS branch (macOS Apple Silicon) ----
+    if use_gpu:
+        try:
+            mps_ok = bool(
+                hasattr(torch, "backends")
+                and hasattr(torch.backends, "mps")
+                and torch.backends.mps.is_available()
+            )
+        except Exception:
+            mps_ok = False
+
+        if mps_ok:
+            cache_key = (backend_tag, "mps")
+            if cache_key in _MODELS_CACHE:
+                _dbg("Returning cached MPS models.", status_cb=status_cb)
+                return _MODELS_CACHE[cache_key]
 
-    # 2) Torch-DirectML (Windows)
+            try:
+                device = torch.device("mps")
+                _dbg("CosmicClarity Sharpen: using MPS", status_cb=status_cb)
+
+                t_m = time.time()
+                models = _load_torch_models(torch, device)
+                _dbg(f"Loaded MPS models in {time.time()-t_m:.3f}s", status_cb=status_cb)
+
+                _MODELS_CACHE[cache_key] = models
+                return models
+            except Exception as e:
+                _dbg("MPS path failed:\n" + "".join(traceback.format_exception(type(e), e, e.__traceback__)),
+                     status_cb=status_cb)
+                # fall through
+    # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
+    # ---- Torch DirectML branch ----
     if use_gpu and is_windows:
+        _dbg("Trying torch-directml path ...", status_cb=status_cb)
         try:
-            import torch_directml  # provided by torch-directml
+            import torch_directml
+            cache_key = (backend_tag, "dml_torch")
+            if cache_key in _MODELS_CACHE:
+                _dbg("Returning cached DML-torch models.", status_cb=status_cb)
+                return _MODELS_CACHE[cache_key]
+
+            t4 = time.time()
             dml = torch_directml.device()
-            status_cb("CosmicClarity Sharpen: using DirectML (torch-directml)")
+            _dbg(f"torch_directml.device() OK in {time.time()-t4:.3f}s", status_cb=status_cb)
+
+            _dbg("Loading torch .pth models (DirectML) ...", status_cb=status_cb)
+            t5 = time.time()
             models = _load_torch_models(torch, dml)
-            _MODELS_CACHE[key] = models
-            status_cb(f"torch.__version__={getattr(torch,'__version__',None)}; "
-                      f"cuda_available={bool(getattr(torch,'cuda',None) and torch.cuda.is_available())}")
+            _dbg(f"Loaded DML-torch models in {time.time()-t5:.3f}s", status_cb=status_cb)
+
+            _MODELS_CACHE[cache_key] = models
             return models
-
-        except Exception:
-            pass
+        except Exception as e:
+            _dbg("DirectML (torch-directml) failed:\n" + "".join(traceback.format_exception(type(e), e, e.__traceback__)),
+                 status_cb=status_cb)
 
-    # 3) ONNX Runtime DirectML fallback
-    if use_gpu and ort is not None and "DmlExecutionProvider" in ort.get_available_providers():
-        status_cb("CosmicClarity Sharpen: using DirectML (ONNX Runtime)")
-        models = _load_onnx_models()
-        _MODELS_CACHE[key] = models
-        return models
+    # ---- ONNX Runtime DirectML fallback ----
+    if use_gpu and ort is not None:
+        _dbg("Checking ORT providers ...", status_cb=status_cb)
+        try:
+            prov = ort.get_available_providers()
+            _dbg(f"ORT providers: {prov}", status_cb=status_cb)
+            if "DmlExecutionProvider" in prov:
+                cache_key = (backend_tag, "dml_ort")
+                if cache_key in _MODELS_CACHE:
+                    _dbg("Returning cached DML-ORT models.", status_cb=status_cb)
+                    return _MODELS_CACHE[cache_key]
+
+                _dbg("Loading ONNX models (DML EP) ...", status_cb=status_cb)
+                t6 = time.time()
+                models = _load_onnx_models()
+                _dbg(f"Loaded ONNX models in {time.time()-t6:.3f}s", status_cb=status_cb)
+
+                _MODELS_CACHE[cache_key] = models
+                return models
+        except Exception as e:
+            _dbg("ORT provider check/load failed:\n" + "".join(traceback.format_exception(type(e), e, e.__traceback__)),
+                 status_cb=status_cb)
+
+    # ---- CPU fallback ----
+    cache_key = (backend_tag, "cpu")
+    if cache_key in _MODELS_CACHE:
+        _dbg("Returning cached CPU models.", status_cb=status_cb)
+        return _MODELS_CACHE[cache_key]
 
-    # 4) CPU
-    device = torch.device("cpu")
-    status_cb("CosmicClarity Sharpen: using CPU")
-    models = _load_torch_models(torch, device)
-    _MODELS_CACHE[key] = models
-    status_cb(f"Sharpen backend resolved: "
-              f"{'onnx' if models.is_onnx else 'torch'} / device={models.device!r}")
+    try:
+        _dbg("Falling back to CPU torch models ...", status_cb=status_cb)
+        device = torch.device("cpu")
+        t7 = time.time()
+        models = _load_torch_models(torch, device)
+        _dbg(f"Loaded CPU models in {time.time()-t7:.3f}s", status_cb=status_cb)
+        _MODELS_CACHE[cache_key] = models
+        return models
+    except Exception as e:
+        _dbg("CPU model load failed:\n" + "".join(traceback.format_exception(type(e), e, e.__traceback__)),
+             status_cb=status_cb)
+        raise
 
-    return models
 
 
 def _load_onnx_models() -> SharpenModels:
@@ -381,6 +560,7 @@ def _infer_chunk(models: SharpenModels, model: Any, chunk2d: np.ndarray) -> np.n
     h0, w0 = chunk2d.shape
 
     if models.is_onnx:
+        t0 = time.time()
         inp = chunk2d[np.newaxis, np.newaxis, :, :].astype(np.float32)  # (1,1,H,W)
        inp = np.tile(inp, (1, 3, 1, 1))  # (1,3,H,W)
         h, w = inp.shape[2:]
@@ -391,15 +571,54 @@ def _infer_chunk(models: SharpenModels, model: Any, chunk2d: np.ndarray) -> np.n
         name_in = model.get_inputs()[0].name
         name_out = model.get_outputs()[0].name
         out = model.run([name_out], {name_in: inp})[0][0, 0]
-        return out[:h0, :w0].astype(np.float32, copy=False)
+        y = out[:h0, :w0].astype(np.float32, copy=False)
+        if _DEBUG_SHARPEN:
+            _dbg(f"ORT infer OK {h0}x{w0} in {time.time()-t0:.3f}s", status_cb=lambda *_: None)
+        return y
 
     # torch path
     torch = models.torch
     dev = models.device
-    t = torch.tensor(chunk2d, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(dev)
-    with torch.no_grad(), _autocast_context(torch, dev):
-        y = model(t.repeat(1, 3, 1, 1)).squeeze().detach().cpu().numpy()[0]
-    return y[:h0, :w0].astype(np.float32, copy=False)
+
+    t0 = time.time()
+    if _DEBUG_SHARPEN:
+        _dbg(f"Torch infer start chunk={h0}x{w0} dev={getattr(dev,'type',dev)}", status_cb=lambda *_: None)
+
+    try:
+        # tensor creation (CPU)
+        t_cpu = torch.tensor(chunk2d, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
+
+        # move to device (this can hang if CUDA context is broken)
+        if _DEBUG_SHARPEN:
+            _dbg(" moving tensor to device ...", status_cb=lambda *_: None)
+        t = t_cpu.to(dev)
+
+        # optional: force CUDA sync right after first transfer
+        if _DEBUG_SHARPEN and hasattr(dev, "type") and dev.type == "cuda":
+            _dbg(" cuda.synchronize after .to(dev) ...", status_cb=lambda *_: None)
+            torch.cuda.synchronize()
+
+        with torch.no_grad(), _autocast_context(torch, dev):
+            if _DEBUG_SHARPEN:
+                _dbg(" running model forward ...", status_cb=lambda *_: None)
+            y = model(t.repeat(1, 3, 1, 1))
+
+        if _DEBUG_SHARPEN and hasattr(dev, "type") and dev.type == "cuda":
+            _dbg(" cuda.synchronize after forward ...", status_cb=lambda *_: None)
+            torch.cuda.synchronize()
+
+        y = y.squeeze().detach().cpu().numpy()[0]
+
+        out = y[:h0, :w0].astype(np.float32, copy=False)
+        if _DEBUG_SHARPEN:
+            _dbg(f"Torch infer OK in {time.time()-t0:.3f}s", status_cb=lambda *_: None)
+        return out
+
+    except Exception as e:
+        if _DEBUG_SHARPEN:
+            _dbg("Torch infer ERROR:\n" + "".join(traceback.format_exception(type(e), e, e.__traceback__)),
+                 status_cb=lambda *_: None)
+        raise
 
 
 # ---------------- Main API ----------------
@@ -419,79 +638,123 @@ def sharpen_image_array(image: np.ndarray,
                         params: SharpenParams,
                         progress_cb: Optional[ProgressCB] = None,
                         status_cb=print) -> tuple[np.ndarray, bool]:
-    """
-    Pure in-memory sharpen. Returns (out_image, was_mono).
-    """
     if progress_cb is None:
         progress_cb = lambda done, total, stage: True
 
-    img = np.asarray(image)
-    if img.dtype != np.float32:
-        img = img.astype(np.float32, copy=False)
+    _dbg("ENTER sharpen_image_array()", status_cb=status_cb)
+    t_all = time.time()
 
-    img3, was_mono = _to_3ch(img)
-    img3 = np.clip(img3, 0.0, 1.0)
-
-    models = load_sharpen_models(use_gpu=params.use_gpu, status_cb=status_cb)
-
-    # border & stretch
-    bordered = add_border(img3, border_size=16)
-    stretch_needed = (np.median(bordered - np.min(bordered)) < 0.08)
-
-    if stretch_needed:
-        stretched, orig_min, orig_meds = stretch_image_unlinked_rgb(bordered)
-    else:
-        stretched, orig_min, orig_meds = bordered, None, None
-
-    # per-channel sharpening option (color only)
-    if params.sharpen_channels_separately and (not was_mono):
-        out = np.empty_like(stretched)
-        for c, label in enumerate(("R", "G", "B")):
-            progress_cb(0, 1, f"Sharpening {label} channel")
-            out[..., c] = _sharpen_plane(models, stretched[..., c], params, progress_cb)
-        sharpened = out
-    else:
-        # luminance pipeline (works for mono too, since mono is in all 3 chans)
-        y, cb, cr = extract_luminance_rgb(stretched)
-        y2 = _sharpen_plane(models, y, params, progress_cb)
-        sharpened = merge_luminance(y2, cb, cr)
+    try:
+        img = np.asarray(image)
+        _dbg(f"Input shape={img.shape} dtype={img.dtype} min={float(np.nanmin(img)):.6f} max={float(np.nanmax(img)):.6f}",
+             status_cb=status_cb)
 
-    # unstretch / deborder
-    if stretch_needed:
-        sharpened = unstretch_image_unlinked_rgb(sharpened, orig_meds, orig_min, was_mono)
+        if img.dtype != np.float32:
+            img = img.astype(np.float32, copy=False)
+            _dbg("Converted input to float32", status_cb=status_cb)
 
-    sharpened = remove_border(sharpened, border_size=16)
+        img3, was_mono = _to_3ch(img)
+        img3 = np.clip(img3, 0.0, 1.0)
+        _dbg(f"After _to_3ch: shape={img3.shape} was_mono={was_mono}", status_cb=status_cb)
 
-    # return mono as HxWx1 if it came in mono (matches your CC behavior)
-    if was_mono:
-        if sharpened.ndim == 3 and sharpened.shape[2] == 3:
-            sharpened = np.mean(sharpened, axis=2, keepdims=True).astype(np.float32, copy=False)
+        # prove progress_cb works
+        try:
+            progress_cb(0, 1, "Loading models")
+        except Exception:
+            pass
 
-    return np.clip(sharpened, 0.0, 1.0), was_mono
+        models = load_sharpen_models(use_gpu=params.use_gpu, status_cb=status_cb)
+        _dbg(f"Models loaded: is_onnx={models.is_onnx} device={models.device!r}", status_cb=status_cb)
+
+        # border & stretch
+        bordered = add_border(img3, border_size=16)
+        med_metric = float(np.median(bordered - np.min(bordered)))
+        stretch_needed = (med_metric < 0.08)
+        _dbg(f"Bordered shape={bordered.shape}; stretch_metric={med_metric:.6f}; stretch_needed={stretch_needed}",
+             status_cb=status_cb)
+
+        if stretch_needed:
+            _dbg("Stretching unlinked RGB ...", status_cb=status_cb)
+            stretched, orig_min, orig_meds = stretch_image_unlinked_rgb(bordered)
+        else:
+            stretched, orig_min, orig_meds = bordered, None, None
+
+        # per-channel sharpening option (color only)
+        if params.sharpen_channels_separately and (not was_mono):
+            _dbg("Sharpen per-channel path", status_cb=status_cb)
+            out = np.empty_like(stretched)
+            for c, label in enumerate(("R", "G", "B")):
+                progress_cb(0, 1, f"Sharpening {label} channel")
+                out[..., c] = _sharpen_plane(models, stretched[..., c], params, progress_cb)
+            sharpened = out
+        else:
+            _dbg("Sharpen luminance path", status_cb=status_cb)
+            y, cb, cr = extract_luminance_rgb(stretched)
+            y2 = _sharpen_plane(models, y, params, progress_cb)
+            sharpened = merge_luminance(y2, cb, cr)
+
+        # unstretch / deborder
+        if stretch_needed:
+            _dbg("Unstretching ...", status_cb=status_cb)
+            sharpened = unstretch_image_unlinked_rgb(sharpened, orig_meds, orig_min, was_mono)
+
+        sharpened = remove_border(sharpened, border_size=16)
+
+        if was_mono:
+            if sharpened.ndim == 3 and sharpened.shape[2] == 3:
+                sharpened = np.mean(sharpened, axis=2, keepdims=True).astype(np.float32, copy=False)
+
+        out = np.clip(sharpened, 0.0, 1.0)
+        _dbg(f"EXIT sharpen_image_array total_time={time.time()-t_all:.2f}s out_shape={out.shape}", status_cb=status_cb)
+        return out, was_mono
+
+    except Exception as e:
+        _dbg("sharpen_image_array ERROR:\n" + "".join(traceback.format_exception(type(e), e, e.__traceback__)),
+             status_cb=status_cb)
+        raise
 
 
 
 def _sharpen_plane(models: SharpenModels,
                    plane: np.ndarray,
                    params: SharpenParams,
                    progress_cb: ProgressCB) -> np.ndarray:
-    """
-    Sharpen a single 2D plane using your two-stage pipeline.
-    """
     plane = np.asarray(plane, np.float32)
     chunks = split_image_into_chunks_with_overlap(plane, chunk_size=256, overlap=64)
     total = len(chunks)
+    # prove we got here
+    try:
+        progress_cb(0, max(total, 1), f"Sharpen start ({total} chunks)")
+    except Exception:
+        pass
+
+    _dbg(f"_sharpen_plane: mode={params.mode} total_chunks={total} auto_psf={params.auto_detect_psf} dev={models.device!r}",
+         status_cb=lambda *_: None)
+
+    def _every(n: int) -> bool:
+        return _DEBUG_SHARPEN and (n == 1 or n % 10 == 0 or n == total)
+
     # Stage 1: stellar
     if params.mode in ("Stellar Only", "Both"):
+        _dbg("Stage 1: stellar BEGIN", status_cb=lambda *_: None)
         out_chunks = []
+        t_stage = time.time()
+
         for k, (chunk, i, j, is_edge) in enumerate(chunks, start=1):
+            t0 = time.time()
             y = _infer_chunk(models, models.stellar, chunk)
             blended = blend_images(chunk, y, params.stellar_amount)
             out_chunks.append((blended, i, j, is_edge))
+
+            if _every(k):
+                _dbg(f" stellar chunk {k}/{total} ({time.time()-t0:.3f}s)", status_cb=lambda *_: None)
+
             if progress_cb(k, total, "Stellar sharpening") is False:
-
+                _dbg("Stage 1: stellar CANCELLED", status_cb=lambda *_: None)
                 return plane
+
         plane = stitch_chunks_ignore_border(out_chunks, plane.shape, border_size=16)
+        _dbg(f"Stage 1: stellar END ({time.time()-t_stage:.2f}s)", status_cb=lambda *_: None)
 
     if params.mode == "Stellar Only":
         return plane
@@ -502,11 +765,14 @@ def _sharpen_plane(models: SharpenModels,
 
     # Stage 2: non-stellar
     if params.mode in ("Non-Stellar Only", "Both"):
+        _dbg("Stage 2: non-stellar BEGIN", status_cb=lambda *_: None)
         out_chunks = []
         radii = np.array([1.0, 2.0, 4.0, 8.0], dtype=float)
         model_map = {1.0: models.ns1, 2.0: models.ns2, 4.0: models.ns4, 8.0: models.ns8}
+        t_stage = time.time()
 
         for k, (chunk, i, j, is_edge) in enumerate(chunks, start=1):
+            t0 = time.time()
             if params.auto_detect_psf:
                 fwhm = measure_psf_fwhm(chunk, default_fwhm=3.0)
                 r = float(np.clip(fwhm, radii[0], radii[-1]))
@@ -531,14 +797,21 @@ def _sharpen_plane(models: SharpenModels,
 
             blended = blend_images(chunk, y, params.nonstellar_amount)
             out_chunks.append((blended, i, j, is_edge))
+
+            if _every(k):
+                _dbg(f" nonstellar chunk {k}/{total} r={r:.2f} lo={lo} hi={hi} ({time.time()-t0:.3f}s)",
+                     status_cb=lambda *_: None)
+
             if progress_cb(k, total, "Non-stellar sharpening") is False:
+                _dbg("Stage 2: non-stellar CANCELLED", status_cb=lambda *_: None)
                 return plane
-            progress_cb(k, total, "Non-stellar sharpening")
 
         plane = stitch_chunks_ignore_border(out_chunks, plane.shape, border_size=16)
+        _dbg(f"Stage 2: non-stellar END ({time.time()-t_stage:.2f}s)", status_cb=lambda *_: None)
 
     return plane
 
+
 def sharpen_rgb01(
     image_rgb01: np.ndarray,
     *,
@@ -20,6 +20,7 @@ def _get_torch(*, prefer_cuda: bool, prefer_dml: bool, status_cb=print):
         prefer_dml=prefer_dml,
         status_cb=status_cb,
     )
+from setiastro.saspro.runtime_torch import _user_runtime_dir, _venv_paths, _check_cuda_in_venv
 
 from setiastro.saspro.resources import get_resources
 from setiastro.saspro.resources import get_resources