setiastrosuitepro 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of setiastrosuitepro might be problematic. Click here for more details.

Files changed (174) hide show
  1. setiastro/__init__.py +2 -0
  2. setiastro/saspro/__init__.py +20 -0
  3. setiastro/saspro/__main__.py +784 -0
  4. setiastro/saspro/_generated/__init__.py +7 -0
  5. setiastro/saspro/_generated/build_info.py +2 -0
  6. setiastro/saspro/abe.py +1295 -0
  7. setiastro/saspro/abe_preset.py +196 -0
  8. setiastro/saspro/aberration_ai.py +694 -0
  9. setiastro/saspro/aberration_ai_preset.py +224 -0
  10. setiastro/saspro/accel_installer.py +218 -0
  11. setiastro/saspro/accel_workers.py +30 -0
  12. setiastro/saspro/add_stars.py +621 -0
  13. setiastro/saspro/astrobin_exporter.py +1007 -0
  14. setiastro/saspro/astrospike.py +153 -0
  15. setiastro/saspro/astrospike_python.py +1839 -0
  16. setiastro/saspro/autostretch.py +196 -0
  17. setiastro/saspro/backgroundneutral.py +560 -0
  18. setiastro/saspro/batch_convert.py +325 -0
  19. setiastro/saspro/batch_renamer.py +519 -0
  20. setiastro/saspro/blemish_blaster.py +488 -0
  21. setiastro/saspro/blink_comparator_pro.py +2923 -0
  22. setiastro/saspro/bundles.py +61 -0
  23. setiastro/saspro/bundles_dock.py +114 -0
  24. setiastro/saspro/cheat_sheet.py +168 -0
  25. setiastro/saspro/clahe.py +342 -0
  26. setiastro/saspro/comet_stacking.py +1377 -0
  27. setiastro/saspro/config.py +38 -0
  28. setiastro/saspro/config_bootstrap.py +40 -0
  29. setiastro/saspro/config_manager.py +316 -0
  30. setiastro/saspro/continuum_subtract.py +1617 -0
  31. setiastro/saspro/convo.py +1397 -0
  32. setiastro/saspro/convo_preset.py +414 -0
  33. setiastro/saspro/copyastro.py +187 -0
  34. setiastro/saspro/cosmicclarity.py +1564 -0
  35. setiastro/saspro/cosmicclarity_preset.py +407 -0
  36. setiastro/saspro/crop_dialog_pro.py +948 -0
  37. setiastro/saspro/crop_preset.py +189 -0
  38. setiastro/saspro/curve_editor_pro.py +2544 -0
  39. setiastro/saspro/curves_preset.py +375 -0
  40. setiastro/saspro/debayer.py +670 -0
  41. setiastro/saspro/debug_utils.py +29 -0
  42. setiastro/saspro/dnd_mime.py +35 -0
  43. setiastro/saspro/doc_manager.py +2634 -0
  44. setiastro/saspro/exoplanet_detector.py +2166 -0
  45. setiastro/saspro/file_utils.py +284 -0
  46. setiastro/saspro/fitsmodifier.py +744 -0
  47. setiastro/saspro/free_torch_memory.py +48 -0
  48. setiastro/saspro/frequency_separation.py +1343 -0
  49. setiastro/saspro/function_bundle.py +1594 -0
  50. setiastro/saspro/ghs_dialog_pro.py +660 -0
  51. setiastro/saspro/ghs_preset.py +284 -0
  52. setiastro/saspro/graxpert.py +634 -0
  53. setiastro/saspro/graxpert_preset.py +287 -0
  54. setiastro/saspro/gui/__init__.py +0 -0
  55. setiastro/saspro/gui/main_window.py +8494 -0
  56. setiastro/saspro/gui/mixins/__init__.py +33 -0
  57. setiastro/saspro/gui/mixins/dock_mixin.py +263 -0
  58. setiastro/saspro/gui/mixins/file_mixin.py +445 -0
  59. setiastro/saspro/gui/mixins/geometry_mixin.py +403 -0
  60. setiastro/saspro/gui/mixins/header_mixin.py +441 -0
  61. setiastro/saspro/gui/mixins/mask_mixin.py +421 -0
  62. setiastro/saspro/gui/mixins/menu_mixin.py +361 -0
  63. setiastro/saspro/gui/mixins/theme_mixin.py +367 -0
  64. setiastro/saspro/gui/mixins/toolbar_mixin.py +1324 -0
  65. setiastro/saspro/gui/mixins/update_mixin.py +309 -0
  66. setiastro/saspro/gui/mixins/view_mixin.py +435 -0
  67. setiastro/saspro/halobgon.py +462 -0
  68. setiastro/saspro/header_viewer.py +445 -0
  69. setiastro/saspro/headless_utils.py +88 -0
  70. setiastro/saspro/histogram.py +753 -0
  71. setiastro/saspro/history_explorer.py +939 -0
  72. setiastro/saspro/image_combine.py +414 -0
  73. setiastro/saspro/image_peeker_pro.py +1596 -0
  74. setiastro/saspro/imageops/__init__.py +37 -0
  75. setiastro/saspro/imageops/mdi_snap.py +292 -0
  76. setiastro/saspro/imageops/scnr.py +36 -0
  77. setiastro/saspro/imageops/starbasedwhitebalance.py +210 -0
  78. setiastro/saspro/imageops/stretch.py +244 -0
  79. setiastro/saspro/isophote.py +1179 -0
  80. setiastro/saspro/layers.py +208 -0
  81. setiastro/saspro/layers_dock.py +714 -0
  82. setiastro/saspro/lazy_imports.py +193 -0
  83. setiastro/saspro/legacy/__init__.py +2 -0
  84. setiastro/saspro/legacy/image_manager.py +2226 -0
  85. setiastro/saspro/legacy/numba_utils.py +3659 -0
  86. setiastro/saspro/legacy/xisf.py +1071 -0
  87. setiastro/saspro/linear_fit.py +534 -0
  88. setiastro/saspro/live_stacking.py +1830 -0
  89. setiastro/saspro/log_bus.py +5 -0
  90. setiastro/saspro/logging_config.py +460 -0
  91. setiastro/saspro/luminancerecombine.py +309 -0
  92. setiastro/saspro/main_helpers.py +201 -0
  93. setiastro/saspro/mask_creation.py +928 -0
  94. setiastro/saspro/masks_core.py +56 -0
  95. setiastro/saspro/mdi_widgets.py +353 -0
  96. setiastro/saspro/memory_utils.py +666 -0
  97. setiastro/saspro/metadata_patcher.py +75 -0
  98. setiastro/saspro/mfdeconv.py +3826 -0
  99. setiastro/saspro/mfdeconv_earlystop.py +71 -0
  100. setiastro/saspro/mfdeconvcudnn.py +3263 -0
  101. setiastro/saspro/mfdeconvsport.py +2382 -0
  102. setiastro/saspro/minorbodycatalog.py +567 -0
  103. setiastro/saspro/morphology.py +382 -0
  104. setiastro/saspro/multiscale_decomp.py +1290 -0
  105. setiastro/saspro/nbtorgb_stars.py +531 -0
  106. setiastro/saspro/numba_utils.py +3044 -0
  107. setiastro/saspro/numba_warmup.py +141 -0
  108. setiastro/saspro/ops/__init__.py +9 -0
  109. setiastro/saspro/ops/command_help_dialog.py +623 -0
  110. setiastro/saspro/ops/command_runner.py +217 -0
  111. setiastro/saspro/ops/commands.py +1594 -0
  112. setiastro/saspro/ops/script_editor.py +1102 -0
  113. setiastro/saspro/ops/scripts.py +1413 -0
  114. setiastro/saspro/ops/settings.py +560 -0
  115. setiastro/saspro/parallel_utils.py +554 -0
  116. setiastro/saspro/pedestal.py +121 -0
  117. setiastro/saspro/perfect_palette_picker.py +1053 -0
  118. setiastro/saspro/pipeline.py +110 -0
  119. setiastro/saspro/pixelmath.py +1600 -0
  120. setiastro/saspro/plate_solver.py +2435 -0
  121. setiastro/saspro/project_io.py +797 -0
  122. setiastro/saspro/psf_utils.py +136 -0
  123. setiastro/saspro/psf_viewer.py +549 -0
  124. setiastro/saspro/pyi_rthook_astroquery.py +95 -0
  125. setiastro/saspro/remove_green.py +314 -0
  126. setiastro/saspro/remove_stars.py +1625 -0
  127. setiastro/saspro/remove_stars_preset.py +404 -0
  128. setiastro/saspro/resources.py +472 -0
  129. setiastro/saspro/rgb_combination.py +207 -0
  130. setiastro/saspro/rgb_extract.py +19 -0
  131. setiastro/saspro/rgbalign.py +723 -0
  132. setiastro/saspro/runtime_imports.py +7 -0
  133. setiastro/saspro/runtime_torch.py +754 -0
  134. setiastro/saspro/save_options.py +72 -0
  135. setiastro/saspro/selective_color.py +1552 -0
  136. setiastro/saspro/sfcc.py +1425 -0
  137. setiastro/saspro/shortcuts.py +2807 -0
  138. setiastro/saspro/signature_insert.py +1099 -0
  139. setiastro/saspro/stacking_suite.py +17712 -0
  140. setiastro/saspro/star_alignment.py +7420 -0
  141. setiastro/saspro/star_alignment_preset.py +329 -0
  142. setiastro/saspro/star_metrics.py +49 -0
  143. setiastro/saspro/star_spikes.py +681 -0
  144. setiastro/saspro/star_stretch.py +470 -0
  145. setiastro/saspro/stat_stretch.py +502 -0
  146. setiastro/saspro/status_log_dock.py +78 -0
  147. setiastro/saspro/subwindow.py +3267 -0
  148. setiastro/saspro/supernovaasteroidhunter.py +1712 -0
  149. setiastro/saspro/swap_manager.py +99 -0
  150. setiastro/saspro/torch_backend.py +89 -0
  151. setiastro/saspro/torch_rejection.py +434 -0
  152. setiastro/saspro/view_bundle.py +1555 -0
  153. setiastro/saspro/wavescale_hdr.py +624 -0
  154. setiastro/saspro/wavescale_hdr_preset.py +100 -0
  155. setiastro/saspro/wavescalede.py +657 -0
  156. setiastro/saspro/wavescalede_preset.py +228 -0
  157. setiastro/saspro/wcs_update.py +374 -0
  158. setiastro/saspro/whitebalance.py +456 -0
  159. setiastro/saspro/widgets/__init__.py +48 -0
  160. setiastro/saspro/widgets/common_utilities.py +305 -0
  161. setiastro/saspro/widgets/graphics_views.py +122 -0
  162. setiastro/saspro/widgets/image_utils.py +518 -0
  163. setiastro/saspro/widgets/preview_dialogs.py +280 -0
  164. setiastro/saspro/widgets/spinboxes.py +275 -0
  165. setiastro/saspro/widgets/themed_buttons.py +13 -0
  166. setiastro/saspro/widgets/wavelet_utils.py +299 -0
  167. setiastro/saspro/window_shelf.py +185 -0
  168. setiastro/saspro/xisf.py +1123 -0
  169. setiastrosuitepro-1.6.0.dist-info/METADATA +266 -0
  170. setiastrosuitepro-1.6.0.dist-info/RECORD +174 -0
  171. setiastrosuitepro-1.6.0.dist-info/WHEEL +4 -0
  172. setiastrosuitepro-1.6.0.dist-info/entry_points.txt +6 -0
  173. setiastrosuitepro-1.6.0.dist-info/licenses/LICENSE +674 -0
  174. setiastrosuitepro-1.6.0.dist-info/licenses/license.txt +2580 -0
@@ -0,0 +1,3826 @@
1
+ # pro/mfdeconv.py non-sport normal version
2
+ from __future__ import annotations
3
+ import os, sys
4
+ import math
5
+ import re
6
+ import numpy as np
7
+ import time
8
+ from astropy.io import fits
9
+ from PyQt6.QtCore import QObject, pyqtSignal
10
+ from setiastro.saspro.psf_utils import compute_psf_kernel_for_image
11
+ from PyQt6.QtWidgets import QApplication
12
+ from PyQt6.QtCore import QThread
13
+ import contextlib
14
+ from threadpoolctl import threadpool_limits
15
+ from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
16
+ _USE_PROCESS_POOL_FOR_ASSETS = not getattr(sys, "frozen", False)
17
+ import gc
18
+ try:
19
+ import sep
20
+ except Exception:
21
+ sep = None
22
+ from setiastro.saspro.free_torch_memory import _free_torch_memory
23
+ from setiastro.saspro.mfdeconv_earlystop import EarlyStopper
24
+ torch = None # filled by runtime loader if available
25
+ TORCH_OK = False
26
+ NO_GRAD = contextlib.nullcontext # fallback
27
+ _XISF_READERS = []
28
+ try:
29
+ # e.g. your legacy module
30
+ from setiastro.saspro.legacy import xisf as _legacy_xisf
31
+ if hasattr(_legacy_xisf, "read"):
32
+ _XISF_READERS.append(lambda p: _legacy_xisf.read(p))
33
+ elif hasattr(_legacy_xisf, "open"):
34
+ _XISF_READERS.append(lambda p: _legacy_xisf.open(p)[0])
35
+ except Exception:
36
+ pass
37
+ try:
38
+ # sometimes projects expose a generic load_image
39
+ from setiastro.saspro.legacy.image_manager import load_image as _generic_load_image # adjust if needed
40
+ _XISF_READERS.append(lambda p: _generic_load_image(p)[0])
41
+ except Exception:
42
+ pass
43
+
44
+ from pathlib import Path
45
+
46
+ # at top of file with the other imports
47
+
48
+ from queue import SimpleQueue
49
+ from setiastro.saspro.memory_utils import LRUDict
50
+
51
+ # ── XISF decode cache → memmap on disk ─────────────────────────────────
52
+ import tempfile
53
+ import threading
54
+ import uuid
55
+ import atexit
56
+ _XISF_CACHE = LRUDict(50)
57
+ _XISF_LOCK = threading.Lock()
58
+ _XISF_TMPFILES = []
59
+
60
+ from collections import OrderedDict
61
+
62
# ── CHW LRU (float32) built on top of FITS memmap & XISF memmap ────────────────
class _FrameCHWLRU:
    """Small LRU cache mapping (path, Ht, Wt, color_mode) → decoded CHW float32 frame.

    XISF files are decoded through the disk-backed memmap cache
    (_xisf_cached_array); FITS files are read via astropy memmap.  Frames are
    center-cropped to (Ht, Wt) and converted to channel-first layout before
    being cached.  Eviction is least-recently-used once capacity is exceeded.
    """

    def __init__(self, capacity=8):
        # Clamp to at least one slot so get() can always insert.
        self.cap = int(max(1, capacity))
        self.od = OrderedDict()  # key -> CHW float32 array, in LRU order

    def clear(self):
        """Drop every cached frame (memory frees once outside refs die)."""
        self.od.clear()

    def get(self, path, Ht, Wt, color_mode):
        """Return the (C,H,W) float32 frame for *path*, decoding on cache miss.

        color_mode "luma" collapses RGB to a single 1×H×W plane; any other
        value keeps per-channel data (mono data still yields 1×H×W).
        Raises ValueError when a FITS file contains no image HDU.
        """
        key = (path, Ht, Wt, str(color_mode).lower())
        hit = self.od.get(key)
        if hit is not None:
            self.od.move_to_end(key)  # mark as most recently used
            return hit

        # Load backing array cheaply (memmap for FITS, cached memmap for XISF)
        ext = os.path.splitext(path)[1].lower()
        if ext == ".xisf":
            a = _xisf_cached_array(path)  # float32, HW/HWC/CHW
        else:
            # FITS path: use astropy memmap (no data copy)
            with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
                arr = None
                for h in hdul:
                    if getattr(h, "data", None) is not None:
                        arr = h.data
                        break
                if arr is None:
                    raise ValueError(f"No image data in {path}")
                a = np.asarray(arr)
            # dtype normalize once; keep float32 (ints scaled into [0, 1])
            if a.dtype.kind in "ui":
                a = a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0)
            else:
                a = a.astype(np.float32, copy=False)

        # Center-crop to (Ht, Wt) and convert to CHW
        a = np.asarray(a)  # float32
        # NOTE(review): _center_crop / _as_chw / _to_luma_local are defined
        # elsewhere in this module (not visible in this chunk).
        a = _center_crop(a, Ht, Wt)

        # Respect color_mode: “luma” → 1×H×W, “PerChannel” → 3×H×W if RGB present
        cm = str(color_mode).lower()
        if cm == "luma":
            a_chw = _as_chw(_to_luma_local(a)).astype(np.float32, copy=False)
        else:
            a_chw = _as_chw(a).astype(np.float32, copy=False)
            if a_chw.shape[0] == 1 and cm != "luma":
                # still OK (mono data)
                pass

        # LRU insert
        self.od[key] = a_chw
        if len(self.od) > self.cap:
            self.od.popitem(last=False)  # evict least-recently-used entry
        return a_chw

_FRAME_LRU = _FrameCHWLRU(capacity=8)  # tune if you like
120
+
121
def _clear_all_caches():
    """Best-effort teardown of every decode cache (XISF memmaps + CHW frame LRU).

    Each step is wrapped independently so a failure in one cache never
    prevents clearing the other; failures are logged at DEBUG level only.
    """
    import logging  # local import keeps module import-time light
    try:
        _clear_xisf_cache()
    except Exception as e:
        logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
    try:
        _FRAME_LRU.clear()
    except Exception as e:
        logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
130
+
131
+
132
+ def _normalize_to_float32(a: np.ndarray) -> np.ndarray:
133
+ if a.dtype.kind in "ui":
134
+ return (a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0))
135
+ if a.dtype == np.float32:
136
+ return a
137
+ return a.astype(np.float32, copy=False)
138
+
139
def _xisf_cached_array(path: str) -> np.memmap:
    """
    Decode an XISF image exactly once and back it by a read-only float32 memmap.
    Returns a memmap that can be sliced cheaply for tiles.

    The cache maps path -> (temp filename, shape); each caller gets a fresh
    read-only memmap handle over the same file.  Temp files are tracked in
    _XISF_TMPFILES and removed by _clear_xisf_cache (registered atexit).
    """
    with _XISF_LOCK:
        hit = _XISF_CACHE.get(path)
        if hit is not None:
            fn, shape = hit
            return np.memmap(fn, dtype=np.float32, mode="r", shape=shape)

        # Decode once
        arr, _ = _load_image_array(path)  # your existing loader
        if arr is None:
            raise ValueError(f"XISF loader returned None for {path}")
        arr = np.asarray(arr)
        arrf = _normalize_to_float32(arr)

        # Create a temp file-backed memmap
        tmpdir = tempfile.gettempdir()
        fn = os.path.join(tmpdir, f"xisf_cache_{uuid.uuid4().hex}.mmap")
        mm = np.memmap(fn, dtype=np.float32, mode="w+", shape=arrf.shape)
        mm[...] = arrf[...]
        mm.flush()
        del mm  # close writer handle; re-open below as read-only

        _XISF_CACHE[path] = (fn, arrf.shape)
        _XISF_TMPFILES.append(fn)  # remember for cleanup at exit
        return np.memmap(fn, dtype=np.float32, mode="r", shape=arrf.shape)
168
+
169
def _clear_xisf_cache():
    """Delete every temp memmap file and reset the XISF decode cache.

    File removal is best-effort (the file may already be gone or still be
    mapped on Windows); failures are logged at DEBUG level only.
    """
    import logging  # hoisted once instead of per-exception
    with _XISF_LOCK:
        for fn in _XISF_TMPFILES:
            try:
                os.remove(fn)
            except Exception as e:
                logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
        _XISF_CACHE.clear()
        _XISF_TMPFILES.clear()

atexit.register(_clear_xisf_cache)
180
+
181
+
182
+ def _is_xisf(path: str) -> bool:
183
+ return os.path.splitext(path)[1].lower() == ".xisf"
184
+
185
def _read_xisf_numpy(path: str) -> np.ndarray:
    """Decode *path* with the first registered XISF reader that succeeds.

    Raises RuntimeError when no reader is registered or all of them fail
    (the last failure is included in the message).
    """
    if not _XISF_READERS:
        raise RuntimeError(
            "No XISF readers registered. Ensure one of "
            "legacy.xisf.read/open or *.image_io.load_image is importable."
        )
    last_err = None
    for reader in _XISF_READERS:
        try:
            out = reader(path)
            # some readers hand back (array, metadata)
            if isinstance(out, tuple):
                out = out[0]
            return np.asarray(out)
        except Exception as e:
            last_err = e
    raise RuntimeError(f"All XISF readers failed for {path}: {last_err}")
201
+
202
def _fits_open_data(path: str):
    """Open a FITS file (memmapped) and return (data, header) of the first image HDU.

    ignore_missing_simple=True lets us open headers missing SIMPLE.
    Raises ValueError when no HDU carries image data — previously the
    primary's None data fell through to np.asanyarray(None), yielding a
    useless 0-d object array.
    """
    with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
        hdu = hdul[0]
        if hdu.data is None:
            # find first image HDU if primary is header-only
            for h in hdul[1:]:
                if getattr(h, "data", None) is not None:
                    hdu = h
                    break
            else:
                raise ValueError(f"No image data in {path}")
        data = np.asanyarray(hdu.data)
        hdr = hdu.header
        return data, hdr
215
+
216
def _load_image_array(path: str) -> tuple[np.ndarray, "fits.Header | None"]:
    """
    Return (numpy array, fits.Header or None). Color-last if 3D.
    dtype left as-is; callers cast to float32. Array is C-contig & writeable.

    XISF files go through the registered readers and have no FITS header;
    FITS files come from _fits_open_data.
    """
    if _is_xisf(path):
        arr = _read_xisf_numpy(path)
        hdr = None  # XISF carries no FITS header
    else:
        arr, hdr = _fits_open_data(path)

    a = np.asarray(arr)
    # Move color axis to last if 3D with a leading channel axis
    if a.ndim == 3 and a.shape[0] in (1, 3) and a.shape[-1] not in (1, 3):
        a = np.moveaxis(a, 0, -1)
    # Ensure contiguous, writeable float32 decisions happen later; here we just ensure writeable
    if (not a.flags.c_contiguous) or (not a.flags.writeable):
        a = np.array(a, copy=True)
    return a, hdr
235
+
236
def _probe_hw(path: str) -> tuple[int, int, int | None]:
    """
    Returns (H, W, C_or_None) without changing data. Moves color to last if needed.
    """
    img, _ = _load_image_array(path)

    if img.ndim == 2:
        rows, cols = img.shape
        return rows, cols, None

    if img.ndim == 3:
        rows, cols, chans = img.shape
        # channel-first layout: relocate channels to the trailing axis
        if chans not in (1, 3) and img.shape[0] in (1, 3):
            img = np.moveaxis(img, 0, -1)
            rows, cols, chans = img.shape
        return rows, cols, (chans if chans in (1, 3) else None)

    raise ValueError(f"Unsupported ndim={img.ndim} for {path}")
251
+
252
def _common_hw_from_paths(paths: list[str]) -> tuple[int, int]:
    """
    Replacement for the old FITS-only version: min(H), min(W) across files.
    """
    sizes = [_probe_hw(p) for p in paths]
    min_h = min(int(h) for h, _, _ in sizes)
    min_w = min(int(w) for _, w, _ in sizes)
    return min_h, min_w
261
+
262
+ def _to_chw_float32(img: np.ndarray, color_mode: str) -> np.ndarray:
263
+ """
264
+ Convert to CHW float32:
265
+ - mono → (1,H,W)
266
+ - RGB → (3,H,W) if 'PerChannel'; (1,H,W) if 'luma'
267
+ """
268
+ x = np.asarray(img)
269
+ if x.ndim == 2:
270
+ y = x.astype(np.float32, copy=False)[None, ...] # (1,H,W)
271
+ return y
272
+ if x.ndim == 3:
273
+ # color-last (H,W,C) expected
274
+ if x.shape[-1] == 1:
275
+ return x[..., 0].astype(np.float32, copy=False)[None, ...]
276
+ if x.shape[-1] == 3:
277
+ if str(color_mode).lower() in ("perchannel", "per_channel", "perchannelrgb"):
278
+ r, g, b = x[..., 0], x[..., 1], x[..., 2]
279
+ return np.stack([r.astype(np.float32, copy=False),
280
+ g.astype(np.float32, copy=False),
281
+ b.astype(np.float32, copy=False)], axis=0)
282
+ # luma
283
+ r, g, b = x[..., 0].astype(np.float32, copy=False), x[..., 1].astype(np.float32, copy=False), x[..., 2].astype(np.float32, copy=False)
284
+ L = 0.2126*r + 0.7152*g + 0.0722*b
285
+ return L[None, ...]
286
+ # rare mono-3D
287
+ if x.shape[0] in (1, 3) and x.shape[-1] not in (1, 3):
288
+ x = np.moveaxis(x, 0, -1)
289
+ return _to_chw_float32(x, color_mode)
290
+ raise ValueError(f"Unsupported image shape {x.shape}")
291
+
292
+ def _center_crop_hw(img: np.ndarray, Ht: int, Wt: int) -> np.ndarray:
293
+ h, w = img.shape[:2]
294
+ y0 = max(0, (h - Ht)//2); x0 = max(0, (w - Wt)//2)
295
+ return img[y0:y0+Ht, x0:x0+Wt, ...].copy() if (Ht < h or Wt < w) else img
296
+
297
def _stack_loader_memmap(paths: list[str], Ht: int, Wt: int, color_mode: str):
    """
    Drop-in replacement of the old FITS-only helper.
    Returns (ys, hdrs):
      ys   : list of C-contiguous writeable CHW float32 arrays cropped to (Ht,Wt)
      hdrs : list of fits.Header or None (XISF)

    Integer data is normalized to [0,1] by its dtype max; float (and any
    other) data is simply cast to float32 — the original had duplicated
    elif/else branches doing the identical cast, collapsed here.
    """
    ys, hdrs = [], []
    for p in paths:
        arr, hdr = _load_image_array(p)
        arr = _center_crop_hw(arr, Ht, Wt)
        # normalize integer data to [0,1] like the rest of your code
        if arr.dtype.kind in "ui":
            mx = np.float32(np.iinfo(arr.dtype).max)
            arr = arr.astype(np.float32, copy=False) / (mx if mx > 0 else 1.0)
        else:
            arr = arr.astype(np.float32, copy=False)

        y = _to_chw_float32(arr, color_mode)
        if (not y.flags.c_contiguous) or (not y.flags.writeable):
            y = np.ascontiguousarray(y.astype(np.float32, copy=True))
        ys.append(y)
        hdrs.append(hdr if isinstance(hdr, fits.Header) else None)
    return ys, hdrs
323
+
324
def _safe_primary_header(path: str) -> fits.Header:
    """Primary FITS header for *path*; synthesizes a minimal header for XISF
    (which has none) and returns an empty header on any read failure."""
    if not _is_xisf(path):
        try:
            return fits.getheader(path, ext=0, ignore_missing_simple=True)
        except Exception:
            return fits.Header()
    # best-effort synthetic header
    h = fits.Header()
    h["SIMPLE"] = (True, "created by MFDeconv")
    h["BITPIX"] = -32
    h["NAXIS"] = 2
    return h
336
+
337
+ # --- CUDA busy/unavailable detector (runtime fallback helper) ---
338
+ def _is_cuda_busy_error(e: Exception) -> bool:
339
+ """
340
+ Return True for 'device busy/unavailable' style CUDA errors that can pop up
341
+ mid-run on shared Linux systems. We *exclude* OOM here (handled elsewhere).
342
+ """
343
+ em = str(e).lower()
344
+ if "out of memory" in em:
345
+ return False # OOM handled by your batch size backoff
346
+ return (
347
+ ("cuda" in em and ("busy" in em or "unavailable" in em))
348
+ or ("device-side" in em and "assert" in em)
349
+ or ("driver shutting down" in em)
350
+ )
351
+
352
+
353
def _compute_frame_assets(i, arr, hdr, *, make_masks, make_varmaps,
                          star_mask_cfg, varmap_cfg, status_sink=lambda s: None):
    """
    Worker function: compute PSF and optional star mask / varmap for one frame.
    Returns (index, psf, mask_or_None, var_or_None, log_lines)

    NOTE(review): status_sink is currently unused — messages are accumulated
    in the returned log_lines list instead of being streamed.
    """
    logs = []
    def log(s): logs.append(s)

    # --- PSF sizing by FWHM ---
    f_hdr = _estimate_fwhm_from_header(hdr)
    f_img = _estimate_fwhm_from_image(arr)
    # prefer the header estimate when it is finite
    f_whm = f_hdr if (np.isfinite(f_hdr)) else f_img
    if not np.isfinite(f_whm) or f_whm <= 0:
        f_whm = 2.5  # safe default when nothing could be measured
    k_auto = _auto_ksize_from_fwhm(f_whm)

    # --- Star-derived PSF with retries ---
    # Try a descending ladder of kernel sizes until star extraction succeeds.
    tried, psf = [], None
    for k_try in [k_auto, max(k_auto - 4, 11), 21, 17, 15, 13, 11]:
        if k_try in tried: continue
        tried.append(k_try)
        try:
            out = compute_psf_kernel_for_image(arr, ksize=k_try, det_sigma=6.0, max_stars=80)
            # helper may return (kernel, extras) or just the kernel
            psf_try = out[0] if (isinstance(out, tuple) and len(out) >= 1) else out
            if psf_try is not None:
                psf = psf_try
                break
        except Exception:
            psf = None
    if psf is None:
        # fall back to an analytic Gaussian PSF of the estimated FWHM
        psf = _gaussian_psf(f_whm, ksize=k_auto)
    psf = _soften_psf(_normalize_psf(psf.astype(np.float32, copy=False)), sigma_px=0.0)

    mask = None
    var = None

    if make_masks or make_varmaps:
        # one background per frame (reused by both)
        luma = _to_luma_local(arr)
        vmc = (varmap_cfg or {})
        sky_map, rms_map, err_scalar = _sep_background_precompute(
            luma, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
        )

        if make_masks:
            smc = star_mask_cfg or {}
            mask = _star_mask_from_precomputed(
                luma, sky_map, err_scalar,
                thresh_sigma = smc.get("thresh_sigma", THRESHOLD_SIGMA),
                max_objs     = smc.get("max_objs", STAR_MASK_MAXOBJS),
                grow_px      = smc.get("grow_px", GROW_PX),
                ellipse_scale= smc.get("ellipse_scale", ELLIPSE_SCALE),
                soft_sigma   = smc.get("soft_sigma", SOFT_SIGMA),
                max_radius_px= smc.get("max_radius_px", MAX_STAR_RADIUS),
                keep_floor   = smc.get("keep_floor", KEEP_FLOOR),
                max_side     = smc.get("max_side", STAR_MASK_MAXSIDE),
                status_cb    = log,
            )

        if make_varmaps:
            vmc = varmap_cfg or {}
            var = _variance_map_from_precomputed(
                luma, sky_map, rms_map, hdr,
                smooth_sigma = vmc.get("smooth_sigma", 1.0),
                floor        = vmc.get("floor", 1e-8),
                status_cb    = log,
            )

    # small per-frame summary
    fwhm_est = _psf_fwhm_px(psf)
    logs.insert(0, f"MFDeconv: PSF{i}: ksize={psf.shape[0]} | FWHM≈{fwhm_est:.2f}px")

    return i, psf, mask, var, logs
427
+
428
def _compute_one_worker(args):
    """
    Top-level picklable worker for ProcessPoolExecutor.
    args: (i, path, make_masks_in_worker, make_varmaps, star_mask_cfg, varmap_cfg)
    Returns (i, psf, mask, var, logs)
    """
    (i, path, make_masks_in_worker, make_varmaps, star_mask_cfg, varmap_cfg) = args
    # avoid BLAS/OMP storm inside each process
    with threadpool_limits(limits=1):
        arr, hdr = _load_image_array(path)  # FITS or XISF
        arr = np.asarray(arr, dtype=np.float32, order="C")
        # collapse a trailing singleton channel to plain 2D
        if arr.ndim == 3 and arr.shape[-1] == 1:
            arr = np.squeeze(arr, axis=-1)
        if not isinstance(hdr, fits.Header):  # synthesize FITS-like header for XISF
            hdr = _safe_primary_header(path)
        return _compute_frame_assets(
            i, arr, hdr,
            make_masks=bool(make_masks_in_worker),
            make_varmaps=bool(make_varmaps),
            star_mask_cfg=star_mask_cfg,
            varmap_cfg=varmap_cfg,
        )
450
+
451
+
452
def _build_psf_and_assets(
    paths,                                    # list[str]
    make_masks=False,
    make_varmaps=False,
    status_cb=lambda s: None,
    save_dir: str | None = None,
    star_mask_cfg: dict | None = None,
    varmap_cfg: dict | None = None,
    max_workers: int | None = None,
    star_mask_ref_path: str | None = None,    # build one mask from this frame if provided
    # NEW (passed from multiframe_deconv so we don’t re-probe/convert):
    Ht: int | None = None,
    Wt: int | None = None,
    color_mode: str = "luma",
):
    """
    Parallel PSF + (optional) star mask + variance map per frame.

    Changes from the original:
      • Reuses the decoded frame cache (_FRAME_LRU) for FITS/XISF so we never re-decode.
      • Automatically switches to threads for XISF (so memmaps are shared across workers).
      • Builds a single reference star mask (if requested) from the cached frame and
        center-pads/crops it for all frames (no extra I/O).
      • Preserves return order and streams worker logs back to the UI.

    Returns (psfs, masks, vars_) — lists in the order of *paths*; masks/vars_
    are None when the corresponding feature was not requested.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    n = len(paths)

    # Resolve target intersection size if caller didn't pass it
    if Ht is None or Wt is None:
        Ht, Wt = _common_hw_from_paths(paths)

    # Sensible default worker count (cap at 8)
    if max_workers is None:
        try:
            hw = os.cpu_count() or 4
        except Exception:
            hw = 4
        max_workers = max(1, min(8, hw))

    # Decide executor: for any XISF, prefer threads so the memmap/cache is shared
    any_xisf = any(os.path.splitext(p)[1].lower() == ".xisf" for p in paths)
    use_proc_pool = (not any_xisf) and _USE_PROCESS_POOL_FOR_ASSETS
    Executor = ProcessPoolExecutor if use_proc_pool else ThreadPoolExecutor
    pool_kind = "process" if use_proc_pool else "thread"
    status_cb(f"MFDeconv: measuring PSFs/masks/varmaps with {max_workers} {pool_kind}s…")

    # ---- helper: pad-or-crop a 2D array to (Ht,Wt), centered ----
    # (local Ht/Wt parameters intentionally shadow the outer names; the helper
    #  is always called with explicit arguments)
    def _center_pad_or_crop_2d(a2d: np.ndarray, Ht: int, Wt: int, fill: float = 1.0) -> np.ndarray:
        a2d = np.asarray(a2d, dtype=np.float32)
        H, W = int(a2d.shape[0]), int(a2d.shape[1])
        # crop first if bigger
        y0 = max(0, (H - Ht) // 2); x0 = max(0, (W - Wt) // 2)
        y1 = min(H, y0 + Ht); x1 = min(W, x0 + Wt)
        cropped = a2d[y0:y1, x0:x1]
        ch, cw = cropped.shape
        if ch == Ht and cw == Wt:
            return np.ascontiguousarray(cropped, dtype=np.float32)
        # pad if smaller
        out = np.full((Ht, Wt), float(fill), dtype=np.float32)
        oy = (Ht - ch) // 2; ox = (Wt - cw) // 2
        out[oy:oy+ch, ox:ox+cw] = cropped
        return out

    # ---- optional: build one mask from the reference frame and reuse ----
    base_ref_mask = None
    if make_masks and star_mask_ref_path:
        try:
            status_cb(f"Star mask: using reference frame for all masks → {os.path.basename(star_mask_ref_path)}")
            # Pull from the shared frame cache as luma on (Ht,Wt)
            ref_chw = _FRAME_LRU.get(star_mask_ref_path, Ht, Wt, "luma")  # (1,H,W) or (H,W)
            L = ref_chw[0] if (ref_chw.ndim == 3) else ref_chw  # 2D float32

            vmc = (varmap_cfg or {})
            sky_map, rms_map, err_scalar = _sep_background_precompute(
                L, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
            )
            smc = (star_mask_cfg or {})
            base_ref_mask = _star_mask_from_precomputed(
                L, sky_map, err_scalar,
                thresh_sigma = smc.get("thresh_sigma", THRESHOLD_SIGMA),
                max_objs     = smc.get("max_objs", STAR_MASK_MAXOBJS),
                grow_px      = smc.get("grow_px", GROW_PX),
                ellipse_scale= smc.get("ellipse_scale", ELLIPSE_SCALE),
                soft_sigma   = smc.get("soft_sigma", SOFT_SIGMA),
                max_radius_px= smc.get("max_radius_px", MAX_STAR_RADIUS),
                keep_floor   = smc.get("keep_floor", KEEP_FLOOR),
                max_side     = smc.get("max_side", STAR_MASK_MAXSIDE),
                status_cb    = status_cb,
            )
        except Exception as e:
            status_cb(f"⚠️ Star mask (reference) failed: {e}. Falling back to per-frame masks.")
            base_ref_mask = None

    # for GUI safety, queue logs from workers and flush in the main thread
    log_queue: SimpleQueue = SimpleQueue()

    def enqueue_logs(lines):
        for s in lines:
            log_queue.put(s)

    # results indexed by frame position (futures complete out of order)
    psfs = [None] * n
    masks = ([None] * n) if make_masks else None
    vars_ = ([None] * n) if make_varmaps else None
    # skip per-frame masks when a single reference mask will be reused
    make_masks_in_worker = bool(make_masks and (base_ref_mask is None))

    # --- thread worker: get frame from cache and compute assets ---
    def _compute_one(i: int, path: str):
        # avoid heavy BLAS oversubscription inside each worker
        with threadpool_limits(limits=1):
            # Pull frame from cache honoring color_mode & target (Ht,Wt)
            img_chw = _FRAME_LRU.get(path, Ht, Wt, color_mode)  # (C,H,W) float32
            # For PSF/mask/varmap we operate on a 2D plane (luma/mono)
            arr2d = img_chw[0] if (img_chw.ndim == 3) else img_chw  # (H,W) float32

            # Header: synthesize a safe FITS-like header (works for XISF too)
            try:
                hdr = _safe_primary_header(path)
            except Exception:
                hdr = fits.Header()

            return _compute_frame_assets(
                i, arr2d, hdr,
                make_masks=bool(make_masks_in_worker),
                make_varmaps=bool(make_varmaps),
                star_mask_cfg=star_mask_cfg,
                varmap_cfg=varmap_cfg,
            )

    # --- submit jobs ---
    with Executor(max_workers=max_workers) as ex:
        futs = []
        for i, p in enumerate(paths, start=1):
            status_cb(f"MFDeconv: measuring PSF {i}/{n} …")
            if use_proc_pool:
                # Process-safe path: worker re-loads inside the subprocess
                futs.append(ex.submit(
                    _compute_one_worker,
                    (i, p, bool(make_masks_in_worker), bool(make_varmaps), star_mask_cfg, varmap_cfg)
                ))
            else:
                # Thread path: hits the shared cache (fast path for XISF/FITS)
                futs.append(ex.submit(_compute_one, i, p))

        done_cnt = 0
        for fut in as_completed(futs):
            i, psf, m, v, logs = fut.result()
            idx = i - 1  # futures are 1-based by construction above
            psfs[idx] = psf
            if masks is not None:
                masks[idx] = m
            if vars_ is not None:
                vars_[idx] = v
            enqueue_logs(logs)

            done_cnt += 1
            # flush queued logs to the UI every few completions
            if (done_cnt % 4) == 0 or done_cnt == n:
                while not log_queue.empty():
                    try:
                        status_cb(log_queue.get_nowait())
                    except Exception:
                        break

    # If we built a single reference mask, apply it to every frame (center pad/crop)
    if base_ref_mask is not None and masks is not None:
        for idx in range(n):
            masks[idx] = _center_pad_or_crop_2d(base_ref_mask, int(Ht), int(Wt), fill=1.0)

    # final flush of any remaining logs
    while not log_queue.empty():
        try:
            status_cb(log_queue.get_nowait())
        except Exception:
            break

    # save PSFs if requested
    if save_dir:
        for i, k in enumerate(psfs, start=1):
            if k is not None:
                fits.PrimaryHDU(k.astype(np.float32, copy=False)).writeto(
                    os.path.join(save_dir, f"psf_{i:03d}.fit"), overwrite=True
                )

    return psfs, masks, vars_
638
+
639
+
640
_ALLOWED = re.compile(r"[^A-Za-z0-9_-]+")  # runs of filename-unsafe characters (anything not alnum/underscore/hyphen)
641
+
642
+ # known FITS-style multi-extensions (rightmost-first match)
643
+ _KNOWN_EXTS = [
644
+ ".fits.fz", ".fit.fz", ".fits.gz", ".fit.gz",
645
+ ".fz", ".gz",
646
+ ".fits", ".fit"
647
+ ]
648
+
649
+ def _sanitize_token(s: str) -> str:
650
+ s = _ALLOWED.sub("_", s)
651
+ s = re.sub(r"_+", "_", s).strip("_")
652
+ return s
653
+
654
+ def _split_known_exts(p: Path) -> tuple[str, str]:
655
+ """
656
+ Return (name_body, full_ext) where full_ext is a REAL extension block
657
+ (e.g. '.fits.fz'). Any junk like '.0s (1310x880)_MFDeconv' stays in body.
658
+ """
659
+ name = p.name
660
+ for ext in _KNOWN_EXTS:
661
+ if name.lower().endswith(ext):
662
+ body = name[:-len(ext)]
663
+ return body, ext
664
+ # fallback: single suffix
665
+ return p.stem, "".join(p.suffixes)
666
+
667
+ _SIZE_RE = re.compile(r"\(?\s*(\d{2,5})x(\d{2,5})\s*\)?", re.IGNORECASE)
668
+ _EXP_RE = re.compile(r"(?<![A-Za-z0-9])(\d+(?:\.\d+)?)\s*s\b", re.IGNORECASE)
669
+ _RX_RE = re.compile(r"(?<![A-Za-z0-9])(\d+)x\b", re.IGNORECASE)
670
+
671
+ def _extract_size(body: str) -> str | None:
672
+ m = _SIZE_RE.search(body)
673
+ return f"{m.group(1)}x{m.group(2)}" if m else None
674
+
675
+ def _extract_exposure_secs(body: str) -> str | None:
676
+ m = _EXP_RE.search(body)
677
+ if not m:
678
+ return None
679
+ secs = int(round(float(m.group(1))))
680
+ return f"{secs}s"
681
+
682
def _strip_metadata_from_base(body: str) -> str:
    """
    Reduce a filename body to a clean base token: normalize ' - ' separators,
    drop a trailing '_MFDeconv' marker, '(1)'-style copy counters, size tokens,
    exposure tokens ('0.5s', '45 s'), and '_2x' multipliers, then sanitize.
    Falls back to 'output' when nothing survives.
    """
    cleaned = body.replace(" - ", "_")                     # normalize common separators
    cleaned = re.sub(r"(?i)[\s_]+MFDeconv$", "", cleaned)  # known trailing marker
    cleaned = re.sub(r"\(\s*\d+\s*\)$", "", cleaned)       # parenthetical copy counters
    cleaned = _SIZE_RE.sub("", cleaned)                    # size, with or without parens
    cleaned = _EXP_RE.sub("", cleaned)                     # exposure tokens
    cleaned = _RX_RE.sub("", cleaned)                      # '_#x' tokens
    cleaned = re.sub(r"[\s]+", "_", cleaned)               # collapse whitespace
    cleaned = _sanitize_token(cleaned)
    return cleaned if cleaned else "output"
707
+
708
def _canonical_out_name_prefix(base: str, r: int, size: str | None,
                               exposure_secs: str | None, tag: str = "MFDeconv") -> str:
    """Assemble '<tag>_<base>[_<HxW>][_<secs>s][_<r>x]' from sanitized pieces."""
    pieces = [_sanitize_token(tag), _sanitize_token(base)]
    for optional in (size, exposure_secs):
        if optional:
            pieces.append(_sanitize_token(optional))
    # The 'Nx' multiplier is appended only for an effective factor above 1.
    if int(max(1, r)) > 1:
        pieces.append(f"{int(r)}x")
    return "_".join(pieces)
718
+
719
def _sr_out_path(out_path: str, r: int) -> Path:
    """
    Build: MFDeconv_<base>[_<HxW>][_<secs>s][_2x], preserving REAL extensions.
    """
    src = Path(out_path)
    body, real_ext = _split_known_exts(src)

    # Harvest metadata from the whole body (not Path.stem), clean the base,
    # and reassemble the canonical stem.
    new_stem = _canonical_out_name_prefix(
        _strip_metadata_from_base(body),
        r=int(max(1, r)),
        size=_extract_size(body),
        exposure_secs=_extract_exposure_secs(body),
        tag="MFDeconv",
    )
    return src.with_name(f"{new_stem}{real_ext}")
735
+
736
def _nonclobber_path(path: str) -> str:
    """
    Version collisions as '_v2', '_v3', ... (no spaces/parentheses).
    """
    p = Path(path)
    if not p.exists():
        return str(p)

    # Keep the true extension(s) intact while versioning the body.
    body, real_ext = _split_known_exts(p)

    # Resume from an existing '_vN' counter when present, otherwise start at 2.
    hit = re.search(r"(.*)_v(\d+)$", body)
    if hit:
        stem, n = hit.group(1), int(hit.group(2)) + 1
    else:
        stem, n = body, 2

    candidate = p.with_name(f"{stem}_v{n}{real_ext}")
    while candidate.exists():
        n += 1
        candidate = p.with_name(f"{stem}_v{n}{real_ext}")
    return str(candidate)
759
+
760
+ def _iter_folder(basefile: str) -> str:
761
+ d, fname = os.path.split(basefile)
762
+ root, ext = os.path.splitext(fname)
763
+ tgt = os.path.join(d, f"{root}.iters")
764
+ if not os.path.exists(tgt):
765
+ try:
766
+ os.makedirs(tgt, exist_ok=True)
767
+ except Exception:
768
+ # last resort: suffix (n)
769
+ n = 1
770
+ while True:
771
+ cand = os.path.join(d, f"{root}.iters ({n})")
772
+ try:
773
+ os.makedirs(cand, exist_ok=True)
774
+ return cand
775
+ except Exception:
776
+ n += 1
777
+ return tgt
778
+
779
def _save_iter_image(arr, hdr_base, folder, tag, color_mode):
    """
    Write one MFDeconv intermediate as '<folder>/<tag>.fit' and return its path.

    arr: numpy array (H,W) or (C,H,W) float32; a channel-last (H,W,C) input is
         moved to channel-first, and a singleton leading channel is squeezed.
    tag: 'seed' or 'iter_###'
    """
    data = arr
    # Channel-last -> channel-first when the trailing axis looks like channels.
    if data.ndim == 3 and data.shape[0] not in (1, 3) and data.shape[-1] in (1, 3):
        data = np.moveaxis(data, -1, 0)
    # Collapse a singleton channel axis to a plain 2D plane.
    if data.ndim == 3 and data.shape[0] == 1:
        data = data[0]

    hdr = fits.Header(hdr_base) if isinstance(hdr_base, fits.Header) else fits.Header()
    hdr['MF_PART'] = (str(tag), 'MFDeconv intermediate (seed/iter)')
    hdr['MF_COLOR'] = (str(color_mode), 'Color mode used')

    out_path = os.path.join(folder, f"{tag}.fit")
    # Overwrite allowed inside the dedicated folder.
    fits.PrimaryHDU(data=data.astype(np.float32, copy=False), header=hdr).writeto(out_path, overwrite=True)
    return out_path
796
+
797
+
798
def _process_gui_events_safely():
    """Pump pending Qt events, but only from the GUI thread.

    processEvents() is only called when the current thread is the
    QApplication's own thread; on worker threads (or with no app
    instance) this is a no-op.
    """
    app = QApplication.instance()
    if app and QThread.currentThread() is app.thread():
        app.processEvents()
802
+
803
EPS = 1e-6  # numerical floor to guard divisions/normalizations (re-declared later in this module with the same value)
804
+
805
+ # -----------------------------
806
+ # Helpers: image prep / shapes
807
+ # -----------------------------
808
+
809
+ # new: lightweight loader that yields one frame at a time
810
def _iter_fits(paths):
    """Yield (float32 array, header copy) for each FITS path, one at a time.

    memmap is disabled and the data copied so every yielded array stays valid
    after its HDU list is closed. Trailing singleton channel axes are squeezed.
    """
    for path in paths:
        with fits.open(path, memmap=False) as hdul:
            frame = np.array(hdul[0].data, dtype=np.float32, copy=True)
            header = hdul[0].header.copy()
        if frame.ndim == 3 and frame.shape[-1] == 1:
            frame = np.squeeze(frame, axis=-1)
        yield frame, header
818
+
819
+ def _to_luma_local(a: np.ndarray) -> np.ndarray:
820
+ a = np.asarray(a, dtype=np.float32)
821
+ if a.ndim == 2:
822
+ return a
823
+ # (H,W,3) or (3,H,W)
824
+ if a.ndim == 3 and a.shape[-1] == 3:
825
+ r, g, b = a[..., 0], a[..., 1], a[..., 2]
826
+ return (0.2126*r + 0.7152*g + 0.0722*b).astype(np.float32, copy=False)
827
+ if a.ndim == 3 and a.shape[0] == 3:
828
+ r, g, b = a[0], a[1], a[2]
829
+ return (0.2126*r + 0.7152*g + 0.0722*b).astype(np.float32, copy=False)
830
+ return a.mean(axis=-1).astype(np.float32, copy=False)
831
+
832
def _stack_loader(paths):
    """Eagerly load every FITS path; return ([float32 arrays], [headers]).

    memmap is off and data is copied inside the 'with' so each array outlives
    the closed file. Trailing singleton channel axes are squeezed.
    """
    frames, headers = [], []
    for path in paths:
        with fits.open(path, memmap=False) as hdul:
            data = np.array(hdul[0].data, dtype=np.float32, copy=True)
            headers.append(hdul[0].header.copy())
        if data.ndim == 3 and data.shape[-1] == 1:
            data = np.squeeze(data, axis=-1)
        frames.append(data)
    return frames, headers
843
+
844
+ def _normalize_layout_single(a, color_mode):
845
+ """
846
+ Coerce to:
847
+ - 'luma' -> (H, W)
848
+ - 'perchannel' -> (C, H, W); mono stays (1,H,W), RGB → (3,H,W)
849
+ Accepts (H,W), (H,W,3), or (3,H,W).
850
+ """
851
+ a = np.asarray(a, dtype=np.float32)
852
+
853
+ if color_mode == "luma":
854
+ return _to_luma_local(a) # returns (H,W)
855
+
856
+ # perchannel
857
+ if a.ndim == 2:
858
+ return a[None, ...] # (1,H,W) ← keep mono as 1 channel
859
+ if a.ndim == 3 and a.shape[-1] == 3:
860
+ return np.moveaxis(a, -1, 0) # (3,H,W)
861
+ if a.ndim == 3 and a.shape[0] in (1, 3):
862
+ return a # already (1,H,W) or (3,H,W)
863
+ # fallback: average any weird shape into luma 1×H×W
864
+ l = _to_luma_local(a)
865
+ return l[None, ...]
866
+
867
+
868
def _normalize_layout_batch(arrs, color_mode):
    """Normalize the layout of every array in *arrs* (see _normalize_layout_single)."""
    out = []
    for a in arrs:
        out.append(_normalize_layout_single(a, color_mode))
    return out
870
+
871
+ def _common_hw(data_list):
872
+ """Return minimal (H,W) across items; items are (H,W) or (C,H,W)."""
873
+ Hs, Ws = [], []
874
+ for a in data_list:
875
+ if a.ndim == 2:
876
+ H, W = a.shape
877
+ else:
878
+ _, H, W = a.shape
879
+ Hs.append(H); Ws.append(W)
880
+ return int(min(Hs)), int(min(Ws))
881
+
882
+ def _center_crop(arr, Ht, Wt):
883
+ """Center-crop arr (H,W) or (C,H,W) to (Ht,Wt)."""
884
+ if arr.ndim == 2:
885
+ H, W = arr.shape
886
+ if H == Ht and W == Wt:
887
+ return arr
888
+ y0 = max(0, (H - Ht) // 2)
889
+ x0 = max(0, (W - Wt) // 2)
890
+ return arr[y0:y0+Ht, x0:x0+Wt]
891
+ else:
892
+ C, H, W = arr.shape
893
+ if H == Ht and W == Wt:
894
+ return arr
895
+ y0 = max(0, (H - Ht) // 2)
896
+ x0 = max(0, (W - Wt) // 2)
897
+ return arr[:, y0:y0+Ht, x0:x0+Wt]
898
+
899
+ def _sanitize_numeric(a):
900
+ """Replace NaN/Inf, clip negatives, make contiguous float32."""
901
+ a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
902
+ a = np.clip(a, 0.0, None).astype(np.float32, copy=False)
903
+ return np.ascontiguousarray(a)
904
+
905
+ # -----------------------------
906
+ # PSF utilities
907
+ # -----------------------------
908
+
909
def _gaussian_psf(fwhm_px: float, ksize: int) -> np.ndarray:
    """Unit-sum isotropic Gaussian kernel of side *ksize* for the given FWHM (px)."""
    sigma = max(fwhm_px, 1.0) / 2.3548                 # FWHM -> sigma
    half = (ksize - 1) / 2
    yy, xx = np.mgrid[-half:half+1, -half:half+1]
    kernel = np.exp(-(xx*xx + yy*yy) / (2*sigma*sigma))
    kernel /= (np.sum(kernel) + EPS)
    return kernel.astype(np.float32, copy=False)
916
+
917
+ def _estimate_fwhm_from_header(hdr) -> float:
918
+ for key in ("FWHM", "FWHM_PIX", "PSF_FWHM"):
919
+ if key in hdr:
920
+ try:
921
+ val = float(hdr[key])
922
+ if np.isfinite(val) and val > 0:
923
+ return val
924
+ except Exception:
925
+ pass
926
+ return float("nan")
927
+
928
def _estimate_fwhm_from_image(arr) -> float:
    """Fast FWHM estimate from SEP 'a','b' parameters (≈ sigma in px)."""
    if sep is None:
        return float("nan")
    try:
        luma = _contig(_to_luma_local(arr))           # SEP needs C-contiguous float32
        background = sep.Background(luma)
        residual = _contig(luma - background.back())  # keep residual C-contiguous too
        try:
            noise = background.globalrms
        except Exception:
            noise = float(np.median(background.rms()))
        found = sep.extract(residual, 6.0, err=noise)
        if found is None or len(found) == 0:
            return float("nan")
        semi = (np.asarray(found["a"], dtype=np.float32) +
                np.asarray(found["b"], dtype=np.float32)) * 0.5
        sigma = float(np.median(semi[np.isfinite(semi) & (semi > 0)]))
        if not (np.isfinite(sigma) and sigma > 0):
            return float("nan")
        return 2.3548 * sigma                         # FWHM ≈ 2.355 * sigma
    except Exception:
        return float("nan")
952
+
953
+ def _auto_ksize_from_fwhm(fwhm_px: float, kmin: int = 11, kmax: int = 51) -> int:
954
+ """
955
+ Choose odd kernel size to cover about ±4σ.
956
+ """
957
+ sigma = max(fwhm_px, 1.0) / 2.3548
958
+ r = int(math.ceil(4.0 * sigma))
959
+ k = 2 * r + 1
960
+ k = max(kmin, min(k, kmax))
961
+ if (k % 2) == 0:
962
+ k += 1
963
+ return k
964
+
965
+ def _flip_kernel(psf):
966
+ # PyTorch dislikes negative strides; make it contiguous.
967
+ return np.flip(np.flip(psf, -1), -2).copy()
968
+
969
def _conv_same_np(img, psf):
    """
    NumPy FFT-based SAME convolution for (H,W) or (C,H,W).
    IMPORTANT: ifftshift the PSF so its peak is at [0,0] before FFT.
    """
    import numpy as _np
    import numpy.fft as _fft

    kh, kw = psf.shape
    crop_y, crop_x = (kh - 1) // 2, (kw - 1) // 2

    def _conv_plane(plane):
        # plane: (1,H,W) -> SAME-size result via zero-padded linear FFT conv
        H, W = plane.shape[-2:]
        fftH, fftW = _fftshape_same(H, W, kh, kw)
        spec = _fft.rfftn(plane, s=(fftH, fftW), axes=(-2, -1))
        kspec = _fft.rfftn(_np.fft.ifftshift(psf), s=(fftH, fftW), axes=(-2, -1))
        full = _fft.irfftn(spec * kspec, s=(fftH, fftW), axes=(-2, -1))
        return full[..., crop_y:crop_y + H, crop_x:crop_x + W]

    if img.ndim == 2:
        return _conv_plane(img[None])[0]
    # per-channel convolution for (C,H,W)
    planes = [_conv_plane(img[c:c + 1])[0] for c in range(img.shape[0])]
    return _np.stack(planes, axis=0)
994
+
995
+ def _normalize_psf(psf):
996
+ psf = np.maximum(psf, 0.0).astype(np.float32, copy=False)
997
+ s = float(psf.sum())
998
+ if not np.isfinite(s) or s <= 1e-6:
999
+ return psf
1000
+ return (psf / s).astype(np.float32, copy=False)
1001
+
1002
def _soften_psf(psf, sigma_px=0.25):
    """Blur the PSF with a small Gaussian of *sigma_px* pixels; no-op when sigma_px <= 0."""
    if sigma_px <= 0:
        return psf
    radius = int(max(1, round(3 * sigma_px)))
    yy, xx = np.mgrid[-radius:radius+1, -radius:radius+1]
    blur = np.exp(-(xx*xx + yy*yy) / (2 * sigma_px * sigma_px)).astype(np.float32)
    blur /= blur.sum() + 1e-6
    return _conv_same_np(psf[None], blur)[0]
1010
+
1011
def _psf_fwhm_px(psf: np.ndarray) -> float:
    """Approximate FWHM (pixels) from second moments of a normalized kernel."""
    psf = np.maximum(psf, 0).astype(np.float32, copy=False)
    total = float(psf.sum())
    if total <= EPS:
        return float("nan")
    side = psf.shape[0]
    yy, xx = np.mgrid[:side, :side].astype(np.float32)
    cy = float((psf * yy).sum() / total)
    cx = float((psf * xx).sum() / total)
    # Average the two axis variances into one isotropic sigma.
    var_y = float((psf * (yy - cy) ** 2).sum() / total)
    var_x = float((psf * (xx - cx) ** 2).sum() / total)
    sigma = math.sqrt(max(0.0, 0.5 * (var_x + var_y)))
    return 2.3548 * sigma  # FWHM ≈ 2.355 * sigma
1025
+
1026
# --- Star-mask / variance-map tuning constants ---
STAR_MASK_MAXSIDE = 2048   # detection image is downscaled so its max side <= this
STAR_MASK_MAXOBJS = 2000   # cap number of objects
VARMAP_SAMPLE_STRIDE = 8   # (kept for compat; currently unused internally)
THRESHOLD_SIGMA = 2.0      # starting detection threshold (sigma over background)
KEEP_FLOOR = 0.20          # minimum keep-weight inside masked stars
GROW_PX = 8                # extra radius (px) grown around each detection
MAX_STAR_RADIUS = 16       # hard cap on a drawn star radius (px)
SOFT_SIGMA = 2.0           # Gaussian feather sigma for mask edges (px)
ELLIPSE_SCALE = 1.2        # scale applied to SEP semi-axes before drawing
1035
+
1036
def _sep_background_precompute(img_2d: np.ndarray, bw: int = 64, bh: int = 64):
    """One-time SEP background build; returns (sky_map, rms_map, err_scalar)."""
    if sep is None:
        # Robust fallback without SEP: flat median sky + MAD-derived rms.
        med = float(np.median(img_2d))
        mad = float(np.median(np.abs(img_2d - med))) + 1e-6
        sky = np.full_like(img_2d, med, dtype=np.float32)
        rms = np.full_like(img_2d, 1.4826 * mad, dtype=np.float32)
        return sky, rms, float(np.median(rms))

    buf = np.ascontiguousarray(img_2d.astype(np.float32))
    bg = sep.Background(buf, bw=int(bw), bh=int(bh), fw=3, fh=3)
    sky = np.asarray(bg.back(), dtype=np.float32)
    try:
        rms = np.asarray(bg.rms(), dtype=np.float32)
        err = float(bg.globalrms)
    except Exception:
        rms = np.full_like(buf, float(np.median(bg.rms())), dtype=np.float32)
        err = float(np.median(rms))
    return sky, rms, err
1056
+
1057
+
1058
def _auto_star_mask_sep(
    img_2d: np.ndarray,
    thresh_sigma: float = THRESHOLD_SIGMA,
    grow_px: int = GROW_PX,
    max_objs: int = STAR_MASK_MAXOBJS,
    max_side: int = STAR_MASK_MAXSIDE,
    ellipse_scale: float = ELLIPSE_SCALE,
    soft_sigma: float = SOFT_SIGMA,
    max_semiaxis_px: float | None = None,   # kept for API compat; unused
    max_area_px2: float | None = None,      # kept for API compat; unused
    max_radius_px: int = MAX_STAR_RADIUS,
    keep_floor: float = KEEP_FLOOR,
    status_cb=lambda s: None
) -> np.ndarray:
    """
    Build a KEEP weight map (float32 in [0,1]) using SEP detections only.
    **Never writes to img_2d** and draws only into a fresh mask buffer.

    Pipeline: background-subtract -> detect on a (possibly downscaled) copy
    over an escalating threshold ladder -> keep the brightest max_objs ->
    paint grown circles at full resolution -> Gaussian feather -> invert to
    keep-weights, remapped so the minimum weight is keep_floor.
    Returns an all-ones map when SEP is unavailable or nothing is detected.
    """
    if sep is None:
        return np.ones_like(img_2d, dtype=np.float32, order="C")

    # Optional OpenCV path for fast draw/blur
    try:
        import cv2 as _cv2
        _HAS_CV2 = True
    except Exception:
        _HAS_CV2 = False
        _cv2 = None  # type: ignore

    h, w = map(int, img_2d.shape)

    # Background / residual on our own contiguous buffer
    data = np.ascontiguousarray(img_2d.astype(np.float32))
    bkg = sep.Background(data)
    data_sub = np.ascontiguousarray(data - bkg.back(), dtype=np.float32)
    try:
        err_scalar = float(bkg.globalrms)
    except Exception:
        err_scalar = float(np.median(np.asarray(bkg.rms(), dtype=np.float32)))

    # Downscale for detection only (drawing always happens at full resolution)
    det = data_sub
    scale = 1.0
    if max_side and max(h, w) > int(max_side):
        scale = float(max(h, w)) / float(max_side)
        if _HAS_CV2:
            det = _cv2.resize(
                det,
                (max(1, int(round(w / scale))), max(1, int(round(h / scale)))),
                interpolation=_cv2.INTER_AREA
            )
        else:
            # NumPy fallback: integer block-mean downsample
            s = int(max(1, round(scale)))
            det = det[:(h // s) * s, :(w // s) * s].reshape(h // s, s, w // s, s).mean(axis=(1, 3))
            scale = float(s)

    # When averaging down by 'scale', per-pixel noise scales ~ 1/scale
    err_det = float(err_scalar) / float(max(1.0, scale))

    # Escalating threshold ladder: retry at higher sigma if a level finds
    # nothing or an implausibly large number of sources (> max_objs*12).
    thresholds = [thresh_sigma, thresh_sigma*2, thresh_sigma*4,
                  thresh_sigma*8, thresh_sigma*16]
    objs = None; used = float("nan"); raw = 0
    for t in thresholds:
        try:
            cand = sep.extract(det, thresh=float(t), err=err_det)
        except Exception:
            cand = None
        n = 0 if cand is None else len(cand)
        if n == 0: continue
        if n > max_objs*12: continue
        objs, raw, used = cand, n, float(t)
        break

    if objs is None or len(objs) == 0:
        # Last attempt at the highest threshold with a minimum-area filter
        try:
            cand = sep.extract(det, thresh=float(thresholds[-1]), err=err_det, minarea=9)
        except Exception:
            cand = None
        if cand is None or len(cand) == 0:
            status_cb("Star mask: no sources found (mask disabled for this frame).")
            return np.ones((h, w), dtype=np.float32, order="C")
        objs, raw, used = cand, len(cand), float(thresholds[-1])

    # Keep brightest max_objs
    if "flux" in objs.dtype.names:
        idx = np.argsort(objs["flux"])[-int(max_objs):]
        objs = objs[idx]
    else:
        objs = objs[:int(max_objs)]
    kept_after_cap = int(len(objs))

    # Draw on full-res fresh buffer; s_back maps detection coords back up
    mask_u8 = np.zeros((h, w), dtype=np.uint8, order="C")
    s_back = float(scale)
    MR = int(max(1, max_radius_px))
    G = int(max(0, grow_px))
    ES = float(max(0.1, ellipse_scale))

    drawn = 0
    if _HAS_CV2:
        for o in objs:
            x = int(round(float(o["x"]) * s_back))
            y = int(round(float(o["y"]) * s_back))
            if not (0 <= x < w and 0 <= y < h):
                continue
            a = float(o["a"]) * s_back
            b = float(o["b"]) * s_back
            # radius: scaled major semi-axis, grown by G, capped at MR
            r = int(math.ceil(ES * max(a, b)))
            r = min(max(r, 0) + G, MR)
            if r <= 0:
                continue
            _cv2.circle(mask_u8, (x, y), r, 1, thickness=-1, lineType=_cv2.LINE_8)
            drawn += 1
    else:
        # NumPy fallback: rasterize each disk into its bounding box
        yy, xx = np.ogrid[:h, :w]
        for o in objs:
            x = int(round(float(o["x"]) * s_back))
            y = int(round(float(o["y"]) * s_back))
            if not (0 <= x < w and 0 <= y < h):
                continue
            a = float(o["a"]) * s_back
            b = float(o["b"]) * s_back
            r = int(math.ceil(ES * max(a, b)))
            r = min(max(r, 0) + G, MR)
            if r <= 0:
                continue
            y0 = max(0, y - r); y1 = min(h, y + r + 1)
            x0 = max(0, x - r); x1 = min(w, x + r + 1)
            ys = yy[y0:y1] - y
            xs = xx[x0:x1] - x
            disk = (ys * ys + xs * xs) <= (r * r)
            mask_u8[y0:y1, x0:x1][disk] = 1
            drawn += 1

    masked_px_hard = int(mask_u8.sum())

    # Feather and convert to KEEP weights in [0,1]
    m = mask_u8.astype(np.float32, copy=False)
    if soft_sigma and soft_sigma > 0.0:
        try:
            if _HAS_CV2:
                k = int(max(1, math.ceil(soft_sigma * 3)) * 2 + 1)
                m = _cv2.GaussianBlur(m, (k, k), float(soft_sigma),
                                      borderType=_cv2.BORDER_REFLECT)
            else:
                from scipy.ndimage import gaussian_filter
                m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
        except Exception:
            pass
    np.clip(m, 0.0, 1.0, out=m)

    # keep=1 outside stars, keep_floor at fully-masked pixels
    keep = 1.0 - m
    kf = float(max(0.0, min(0.99, keep_floor)))
    keep = kf + (1.0 - kf) * keep
    np.clip(keep, 0.0, 1.0, out=keep)

    status_cb(
        f"Star mask: thresh={used:.3g} | detected={raw} | kept={kept_after_cap} | "
        f"drawn={drawn} | masked_px={masked_px_hard} | grow_px={G} | soft_sigma={soft_sigma} | keep_floor={keep_floor}"
    )
    return np.ascontiguousarray(keep, dtype=np.float32)
1219
+
1220
+
1221
+
1222
def _auto_variance_map(
    img_2d: np.ndarray,
    hdr,
    status_cb=lambda s: None,
    sample_stride: int = VARMAP_SAMPLE_STRIDE,  # kept for signature compat; not used
    bw: int = 64,            # SEP background box width (pixels)
    bh: int = 64,            # SEP background box height (pixels)
    smooth_sigma: float = 1.0,   # Gaussian sigma (px) to smooth the variance map
    floor: float = 1e-8,     # hard floor to prevent blow-up in 1/var
) -> np.ndarray:
    """
    Build a per-pixel variance map in DN^2:

        var_DN ≈ (object_only_DN)/gain + var_bg_DN^2

    where:
      - object_only_DN = max(img - sky_DN, 0)
      - var_bg_DN^2 comes from SEP's local background rms (Poisson(sky)+readnoise)
      - if GAIN is missing, estimate 1/gain ≈ median(var_bg)/median(sky)

    Returns float32 array, clipped below by `floor`, optionally smoothed with a
    small Gaussian to stabilize weights. Emits a summary status line.
    """
    img = np.clip(np.asarray(img_2d, dtype=np.float32), 0.0, None)

    # --- Parse header for camera params (optional) ---
    gain = None
    for k in ("EGAIN", "GAIN", "GAIN1", "GAIN2"):
        if k in hdr:
            try:
                g = float(hdr[k])
                if np.isfinite(g) and g > 0:
                    gain = g
                    break
            except Exception:
                pass

    # readnoise is only reported in the status line below, not used in the math
    readnoise = None
    for k in ("RDNOISE", "READNOISE", "RN"):
        if k in hdr:
            try:
                rn = float(hdr[k])
                if np.isfinite(rn) and rn >= 0:
                    readnoise = rn
                    break
            except Exception:
                pass

    # --- Local background (full-res); falls back to flat median sky +
    # MAD-derived rms when SEP is missing or fails ---
    if sep is not None:
        try:
            b = sep.Background(img, bw=int(bw), bh=int(bh), fw=3, fh=3)
            sky_dn_map = np.asarray(b.back(), dtype=np.float32)
            try:
                rms_dn_map = np.asarray(b.rms(), dtype=np.float32)
            except Exception:
                rms_dn_map = np.full_like(img, float(np.median(b.rms())), dtype=np.float32)
        except Exception:
            sky_dn_map = np.full_like(img, float(np.median(img)), dtype=np.float32)
            med = float(np.median(img))
            mad = float(np.median(np.abs(img - med))) + 1e-6
            rms_dn_map = np.full_like(img, float(1.4826 * mad), dtype=np.float32)
    else:
        sky_dn_map = np.full_like(img, float(np.median(img)), dtype=np.float32)
        med = float(np.median(img))
        mad = float(np.median(np.abs(img - med))) + 1e-6
        rms_dn_map = np.full_like(img, float(1.4826 * mad), dtype=np.float32)

    # Background variance in DN^2
    var_bg_dn2 = np.maximum(rms_dn_map, 1e-6) ** 2

    # Object-only DN
    obj_dn = np.clip(img - sky_dn_map, 0.0, None)

    # Shot-noise coefficient (≈ 1/gain)
    if gain is not None and np.isfinite(gain) and gain > 0:
        a_shot = 1.0 / gain
    else:
        sky_med = float(np.median(sky_dn_map))
        varbg_med = float(np.median(var_bg_dn2))
        if sky_med > 1e-6:
            a_shot = np.clip(varbg_med / sky_med, 0.0, 10.0)  # ~ 1/gain estimate
        else:
            a_shot = 0.0

    # Total variance: background + shot noise from object-only flux
    v = var_bg_dn2 + a_shot * obj_dn
    v_raw = v.copy()  # NOTE(review): kept but not used afterwards in this function

    # Optional mild smoothing: OpenCV, then SciPy, then an FFT-conv fallback
    if smooth_sigma and smooth_sigma > 0:
        try:
            import cv2 as _cv2
            k = int(max(1, int(round(3 * float(smooth_sigma)))) * 2 + 1)
            v = _cv2.GaussianBlur(v, (k, k), float(smooth_sigma), borderType=_cv2.BORDER_REFLECT)
        except Exception:
            try:
                from scipy.ndimage import gaussian_filter
                v = gaussian_filter(v, sigma=float(smooth_sigma), mode="reflect")
            except Exception:
                r = int(max(1, round(3 * float(smooth_sigma))))
                yy, xx = np.mgrid[-r:r+1, -r:r+1].astype(np.float32)
                gk = np.exp(-(xx*xx + yy*yy) / (2.0 * float(smooth_sigma) * float(smooth_sigma))).astype(np.float32)
                gk /= (gk.sum() + EPS)
                v = _conv_same_np(v, gk)

    # Clip to avoid zero/negative variances
    v = np.clip(v, float(floor), None).astype(np.float32, copy=False)

    # Emit telemetry (best-effort; never fails the call)
    try:
        sky_med = float(np.median(sky_dn_map))
        rms_med = float(np.median(np.sqrt(var_bg_dn2)))
        floor_pct = float((v <= floor).mean() * 100.0)
        status_cb(
            "Variance map: "
            f"sky_med={sky_med:.3g} DN | rms_med={rms_med:.3g} DN | "
            f"gain={(gain if gain is not None else 'NA')} | rn={(readnoise if readnoise is not None else 'NA')} | "
            f"smooth_sigma={smooth_sigma} | floor={floor} ({floor_pct:.2f}% at floor)"
        )
    except Exception:
        pass

    return v
1346
+
1347
+
1348
def _star_mask_from_precomputed(
    img_2d: np.ndarray,
    sky_map: np.ndarray,
    err_scalar: float,
    *,
    thresh_sigma: float,
    max_objs: int,
    grow_px: int,
    ellipse_scale: float,
    soft_sigma: float,
    max_radius_px: int,
    keep_floor: float,
    max_side: int,
    status_cb=lambda s: None
) -> np.ndarray:
    """
    Build a KEEP weight map using a *downscaled detection / full-res draw* path.
    **Never writes to img_2d**; all drawing happens in a fresh `mask_u8`.

    Unlike _auto_star_mask_sep, the sky map and global rms are supplied by the
    caller (see _sep_background_precompute), so only extraction, drawing and
    feathering happen here. NOTE(review): err_scalar is used as-is for the
    downscaled detection image (no 1/scale correction as in
    _auto_star_mask_sep) — confirm this asymmetry is intended.
    """
    # Optional OpenCV fast path
    try:
        import cv2 as _cv2
        _HAS_CV2 = True
    except Exception:
        _HAS_CV2 = False
        _cv2 = None  # type: ignore

    H, W = map(int, img_2d.shape)

    # Residual for detection (contiguous, separate buffer)
    data_sub = np.ascontiguousarray((img_2d - sky_map).astype(np.float32))

    # Downscale *detection only* to speed up, never the draw step
    det = data_sub
    scale = 1.0
    if max_side and max(H, W) > int(max_side):
        scale = float(max(H, W)) / float(max_side)
        if _HAS_CV2:
            det = _cv2.resize(
                det,
                (max(1, int(round(W / scale))), max(1, int(round(H / scale)))),
                interpolation=_cv2.INTER_AREA
            )
        else:
            # NumPy fallback: integer block-mean downsample
            s = int(max(1, round(scale)))
            det = det[:(H // s) * s, :(W // s) * s].reshape(H // s, s, W // s, s).mean(axis=(1, 3))
            scale = float(s)

    # Threshold ladder: escalate when a level yields nothing or far too many
    thresholds = [thresh_sigma, thresh_sigma*2, thresh_sigma*4,
                  thresh_sigma*8, thresh_sigma*16]
    objs = None; used = float("nan"); raw = 0
    for t in thresholds:
        cand = sep.extract(det, thresh=float(t), err=float(err_scalar))
        n = 0 if cand is None else len(cand)
        if n == 0: continue
        if n > max_objs*12: continue
        objs, raw, used = cand, n, float(t)
        break

    if objs is None or len(objs) == 0:
        # Last attempt at the highest threshold with a minimum-area filter
        try:
            cand = sep.extract(det, thresh=thresholds[-1], err=float(err_scalar), minarea=9)
        except Exception:
            cand = None
        if cand is None or len(cand) == 0:
            status_cb("Star mask: no sources found (mask disabled for this frame).")
            return np.ones((H, W), dtype=np.float32, order="C")
        objs, raw, used = cand, len(cand), float(thresholds[-1])

    # Brightest max_objs
    if "flux" in objs.dtype.names:
        idx = np.argsort(objs["flux"])[-int(max_objs):]
        objs = objs[idx]
    else:
        objs = objs[:int(max_objs)]
    kept = len(objs)

    # ---- draw back on full-res into a brand-new buffer ----
    mask_u8 = np.zeros((H, W), dtype=np.uint8, order="C")
    s_back = float(scale)
    MR = int(max(1, max_radius_px))
    G = int(max(0, grow_px))
    ES = float(max(0.1, ellipse_scale))

    drawn = 0
    if _HAS_CV2:
        for o in objs:
            x = int(round(float(o["x"]) * s_back))
            y = int(round(float(o["y"]) * s_back))
            if not (0 <= x < W and 0 <= y < H):
                continue
            a = float(o["a"]) * s_back
            b = float(o["b"]) * s_back
            # radius: scaled major semi-axis, grown by G, capped at MR
            r = int(math.ceil(ES * max(a, b)))
            r = min(max(r, 0) + G, MR)
            if r <= 0:
                continue
            _cv2.circle(mask_u8, (x, y), r, 1, thickness=-1, lineType=_cv2.LINE_8)
            drawn += 1
    else:
        # NumPy fallback: rasterize each disk inside its bounding box
        for o in objs:
            x = int(round(float(o["x"]) * s_back))
            y = int(round(float(o["y"]) * s_back))
            if not (0 <= x < W and 0 <= y < H):
                continue
            a = float(o["a"]) * s_back
            b = float(o["b"]) * s_back
            r = int(math.ceil(ES * max(a, b)))
            r = min(max(r, 0) + G, MR)
            if r <= 0:
                continue
            y0 = max(0, y - r); y1 = min(H, y + r + 1)
            x0 = max(0, x - r); x1 = min(W, x + r + 1)
            yy, xx = np.ogrid[y0:y1, x0:x1]
            disk = (yy - y)*(yy - y) + (xx - x)*(xx - x) <= r*r
            mask_u8[y0:y1, x0:x1][disk] = 1
            drawn += 1

    # Feather + convert to keep weights (1 outside stars, keep_floor inside)
    m = mask_u8.astype(np.float32, copy=False)
    if soft_sigma > 0:
        try:
            if _HAS_CV2:
                k = int(max(1, int(round(3*soft_sigma)))*2 + 1)
                m = _cv2.GaussianBlur(m, (k, k), float(soft_sigma),
                                      borderType=_cv2.BORDER_REFLECT)
            else:
                from scipy.ndimage import gaussian_filter
                m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
        except Exception:
            pass
    np.clip(m, 0.0, 1.0, out=m)

    keep = 1.0 - m
    kf = float(max(0.0, min(0.99, keep_floor)))
    keep = kf + (1.0 - kf) * keep
    np.clip(keep, 0.0, 1.0, out=keep)

    status_cb(f"Star mask: thresh={used:.3g} | detected={raw} | kept={kept} | drawn={drawn} | keep_floor={keep_floor}")
    return np.ascontiguousarray(keep, dtype=np.float32)
1489
+
1490
+
1491
def _variance_map_from_precomputed(
    img_2d: np.ndarray,
    sky_map: np.ndarray,
    rms_map: np.ndarray,
    hdr,
    *,
    smooth_sigma: float,
    floor: float,
    status_cb=lambda s: None
) -> np.ndarray:
    """Build a per-pixel noise-variance map (DN^2) from precomputed sky/RMS maps.

    Model: var = rms_map^2 + a_shot * max(img - sky_map, 0), where a_shot is
    the shot-noise coefficient 1/gain when a positive gain is found in the
    header (EGAIN/GAIN/GAIN1/GAIN2); otherwise a_shot is estimated as
    median(rms^2)/median(sky), clipped to [0, 10].

    The map is optionally Gaussian-smoothed (cv2 if available, else scipy,
    else left unsmoothed) and floored at `floor`.

    Parameters
    ----------
    img_2d : 2D frame in DN (negatives are clipped to 0 first).
    sky_map, rms_map : precomputed background and background-RMS maps
        (same shape as img_2d — assumed, not checked here).
    hdr : FITS-header-like mapping searched for a gain keyword.
    smooth_sigma : Gaussian sigma in pixels; <=0 disables smoothing.
    floor : lower bound applied to the final variance map.
    status_cb : optional callable receiving a one-line summary string.

    Returns
    -------
    np.ndarray : float32 variance map, same shape as img_2d.
    """
    img = np.clip(np.asarray(img_2d, dtype=np.float32), 0.0, None)
    # Background variance in DN^2; RMS floored at 1e-6 to avoid zero variance.
    var_bg_dn2 = np.maximum(rms_map, 1e-6) ** 2
    # Object signal above sky (only positive excess contributes shot noise).
    obj_dn = np.clip(img - sky_map, 0.0, None)

    gain = None
    for k in ("EGAIN", "GAIN", "GAIN1", "GAIN2"):
        if k in hdr:
            try:
                g = float(hdr[k]); gain = g if (np.isfinite(g) and g > 0) else None
                if gain is not None: break
            except Exception as e:
                import logging
                logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")

    if gain is not None:
        a_shot = 1.0 / gain
    else:
        # No usable gain keyword: infer the shot-noise slope from the data.
        sky_med = float(np.median(sky_map))
        varbg_med= float(np.median(var_bg_dn2))
        a_shot = (varbg_med / sky_med) if sky_med > 1e-6 else 0.0
        a_shot = float(np.clip(a_shot, 0.0, 10.0))

    v = var_bg_dn2 + a_shot * obj_dn
    if smooth_sigma > 0:
        try:
            import cv2 as _cv2
            # Odd kernel side covering ~±3 sigma.
            k = int(max(1, int(round(3*smooth_sigma)))*2 + 1)
            v = _cv2.GaussianBlur(v, (k,k), float(smooth_sigma), borderType=_cv2.BORDER_REFLECT)
        except Exception:
            try:
                from scipy.ndimage import gaussian_filter
                v = gaussian_filter(v, sigma=float(smooth_sigma), mode="reflect")
            except Exception:
                pass

    np.clip(v, float(floor), None, out=v)
    try:
        rms_med = float(np.median(np.sqrt(var_bg_dn2)))
        status_cb(f"Variance map: sky_med={float(np.median(sky_map)):.3g} DN | rms_med={rms_med:.3g} DN | smooth_sigma={smooth_sigma} | floor={floor}")
    except Exception:
        pass
    return v.astype(np.float32, copy=False)
1543
+
1544
+
1545
+
1546
+ # -----------------------------
1547
+ # Robust weighting (Huber)
1548
+ # -----------------------------
1549
+
1550
+ EPS = 1e-6
1551
+
1552
+ def _estimate_scalar_variance(a):
1553
+ med = np.median(a)
1554
+ mad = np.median(np.abs(a - med)) + 1e-6
1555
+ return float((1.4826 * mad) ** 2)
1556
+
1557
def _weight_map(y, pred, huber_delta, var_map=None, mask=None):
    """
    Robust per-pixel weights: W = [psi(r)/r] * 1/(var + eps) * mask, psi=Huber.

    A negative huber_delta selects an adaptive threshold,
    delta = |huber_delta| * RMS(residual), with RMS estimated via MAD.
    y, pred are (H,W) or (C,H,W); var_map/mask are 2D and are broadcast
    over channels when needed.
    """
    resid = y - pred

    # Resolve the Huber threshold (negative => data-driven via MAD).
    if huber_delta < 0:
        center = np.median(resid)
        mad = np.median(np.abs(resid - center)) + 1e-6
        rms = 1.4826 * mad
        delta = (-huber_delta) * max(rms, 1e-6)
    else:
        delta = huber_delta

    abs_resid = np.abs(resid)
    if float(delta) > 0:
        psi_over_r = np.where(abs_resid <= delta, 1.0, delta / (abs_resid + EPS)).astype(np.float32)
    else:
        # delta == 0 disables the Huber taper entirely (pure least squares).
        psi_over_r = np.ones_like(resid, dtype=np.float32)

    # Inverse-variance factor: robust scalar estimate when no map is given.
    if var_map is None:
        v = _estimate_scalar_variance(resid)
    else:
        v = var_map
        if v.ndim == 2 and resid.ndim == 3:
            v = v[None, ...]  # broadcast the 2D variance map over channels

    w = psi_over_r / (v + EPS)
    if mask is not None:
        m = mask if mask.ndim == w.ndim else (mask[None, ...] if w.ndim == 3 else mask)
        w = w * m
    return w
1590
+
1591
+
1592
+ # -----------------------------
1593
+ # Torch / conv
1594
+ # -----------------------------
1595
+
1596
+ def _fftshape_same(H, W, kh, kw):
1597
+ return H + kh - 1, W + kw - 1
1598
+
1599
+ # ---------- Torch FFT helpers (FIXED: carry padH/padW) ----------
1600
def _precompute_torch_psf_ffts(psfs, flip_psf, H, W, device, dtype):
    """
    Precompute rFFTs of each PSF (and its flipped twin) for SAME-size FFT
    convolution on an H×W image.

    Pack (Kf,padH,padW,kh,kw) so the conv can crop correctly to SAME.
    Kernel is ifftshifted before padding.

    Parameters
    ----------
    psfs, flip_psf : sequences of 2D numpy kernels; flip_psf holds the
        flipped kernels used for the transpose/adjoint convolution.
    H, W : image spatial size the FFT grid is sized for.
    device, dtype : torch placement for the kernel tensors.

    Returns
    -------
    (psf_fft, psfT_fft) : two lists of packs aligned with the inputs.
    """
    tfft = torch.fft
    psf_fft, psfT_fft = [], []
    for k, kT in zip(psfs, flip_psf):
        kh, kw = k.shape
        # Linear-convolution size (H+kh-1, W+kw-1) avoids circular wraparound.
        padH, padW = _fftshape_same(H, W, kh, kw)
        # ifftshift moves the kernel center to index (0,0) before zero-padding.
        k_small = torch.as_tensor(np.fft.ifftshift(k), device=device, dtype=dtype)
        kT_small = torch.as_tensor(np.fft.ifftshift(kT), device=device, dtype=dtype)
        Kf = tfft.rfftn(k_small, s=(padH, padW))
        KTf = tfft.rfftn(kT_small, s=(padH, padW))
        psf_fft.append((Kf, padH, padW, kh, kw))
        psfT_fft.append((KTf, padH, padW, kh, kw))
    return psf_fft, psfT_fft
1617
+
1618
+
1619
+ def _fft_conv_same_torch(x, Kf_pack, out_spatial):
1620
+ tfft = torch.fft
1621
+ Kf, padH, padW, kh, kw = Kf_pack
1622
+ H, W = x.shape[-2], x.shape[-1]
1623
+ if x.ndim == 2:
1624
+ X = tfft.rfftn(x, s=(padH, padW))
1625
+ y = tfft.irfftn(X * Kf, s=(padH, padW))
1626
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1627
+ out_spatial.copy_(y[sh:sh+H, sw:sw+W])
1628
+ return out_spatial
1629
+ else:
1630
+ X = tfft.rfftn(x, s=(padH, padW), dim=(-2,-1))
1631
+ y = tfft.irfftn(X * Kf, s=(padH, padW), dim=(-2,-1))
1632
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1633
+ out_spatial.copy_(y[..., sh:sh+H, sw:sw+W])
1634
+ return out_spatial
1635
+
1636
+ # ---------- NumPy FFT helpers ----------
1637
def _precompute_np_psf_ffts(psfs, flip_psf, H, W):
    """
    NumPy counterpart of _precompute_torch_psf_ffts: precompute rFFTs of each
    PSF and its flipped twin on the SAME-size grid for an H×W image.

    Returns
    -------
    (Kfs, KTfs, meta) : forward spectra, adjoint spectra, and per-kernel
        metadata tuples (kh, kw, fftH, fftW) needed by _fft_conv_same_np.
    """
    import numpy.fft as fft
    meta, Kfs, KTfs = [], [], []
    for k, kT in zip(psfs, flip_psf):
        kh, kw = k.shape
        # Linear-convolution grid; ifftshift centers the kernel at (0,0).
        fftH, fftW = _fftshape_same(H, W, kh, kw)
        Kfs.append( fft.rfftn(np.fft.ifftshift(k), s=(fftH, fftW)) )
        KTfs.append(fft.rfftn(np.fft.ifftshift(kT), s=(fftH, fftW)) )
        meta.append((kh, kw, fftH, fftW))
    return Kfs, KTfs, meta
1647
+
1648
+ def _fft_conv_same_np(a, Kf, kh, kw, fftH, fftW, out):
1649
+ import numpy.fft as fft
1650
+ if a.ndim == 2:
1651
+ A = fft.rfftn(a, s=(fftH, fftW))
1652
+ y = fft.irfftn(A * Kf, s=(fftH, fftW))
1653
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1654
+ out[...] = y[sh:sh+a.shape[0], sw:sw+a.shape[1]]
1655
+ return out
1656
+ else:
1657
+ C, H, W = a.shape
1658
+ acc = []
1659
+ for c in range(C):
1660
+ A = fft.rfftn(a[c], s=(fftH, fftW))
1661
+ y = fft.irfftn(A * Kf, s=(fftH, fftW))
1662
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1663
+ acc.append(y[sh:sh+H, sw:sw+W])
1664
+ out[...] = np.stack(acc, 0)
1665
+ return out
1666
+
1667
+
1668
+
1669
def _torch_device():
    """
    Pick the torch device in preference order: CUDA, MPS, a DirectML device
    stashed in module globals (dml_ok/dml_device), else CPU.

    NOTE(review): when TORCH_OK is false or torch is None this falls through
    and implicitly returns None — callers appear to only invoke it on the
    torch path; verify against call sites.
    """
    if TORCH_OK and (torch is not None):
        if hasattr(torch, "cuda") and torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        # DirectML: we passed dml_device from outer scope; keep a module-global
        if globals().get("dml_ok", False) and globals().get("dml_device", None) is not None:
            return globals()["dml_device"]
        return torch.device("cpu")
1679
+
1680
def _to_t(x: np.ndarray):
    """
    Wrap a NumPy array as a torch tensor on the preferred device.

    On CPU the tensor shares memory with x (torch.from_numpy, no copy);
    on any other device the data is moved with .to(device).

    Raises
    ------
    RuntimeError : if torch is unavailable (TORCH_OK false or torch is None).
    """
    if not (TORCH_OK and (torch is not None)):
        raise RuntimeError("Torch path requested but torch is unavailable")
    device = _torch_device()
    t = torch.from_numpy(x)
    # DirectML wants explicit .to(device)
    return t.to(device, non_blocking=True) if str(device) != "cpu" else t
1687
+
1688
+ def _contig(x):
1689
+ return np.ascontiguousarray(x, dtype=np.float32)
1690
+
1691
+ def _conv_same_torch(img_t, psf_t):
1692
+ """
1693
+ img_t: torch tensor on DEVICE, (H,W) or (C,H,W)
1694
+ psf_t: torch tensor on DEVICE, (1,1,kh,kw) (single kernel)
1695
+ Pads with 'reflect' to avoid zero-padding ringing.
1696
+ """
1697
+ kh, kw = psf_t.shape[-2:]
1698
+ pad = (kw // 2, kw - kw // 2 - 1, # left, right
1699
+ kh // 2, kh - kh // 2 - 1) # top, bottom
1700
+
1701
+ if img_t.ndim == 2:
1702
+ x = img_t[None, None]
1703
+ x = torch.nn.functional.pad(x, pad, mode="reflect")
1704
+ y = torch.nn.functional.conv2d(x, psf_t, padding=0)
1705
+ return y[0, 0]
1706
+ else:
1707
+ C = img_t.shape[0]
1708
+ x = img_t[None]
1709
+ x = torch.nn.functional.pad(x, pad, mode="reflect")
1710
+ w = psf_t.repeat(C, 1, 1, 1)
1711
+ y = torch.nn.functional.conv2d(x, w, padding=0, groups=C)
1712
+ return y[0]
1713
+
1714
def _safe_inference_context():
    """
    Return a valid, working no-grad context:
      - prefer torch.inference_mode() if it exists *and* can be entered,
      - otherwise fall back to torch.no_grad(),
      - if torch is unavailable, return NO_GRAD.

    Returns a context-manager *factory* (callable), not an entered context.
    NO_GRAD is a module-level fallback defined elsewhere in this file.
    """
    if not (TORCH_OK and (torch is not None)):
        return NO_GRAD

    cm = getattr(torch, "inference_mode", None)
    if cm is None:
        # Old torch builds predating inference_mode.
        return torch.no_grad

    # Probe inference_mode once; if it explodes on this build, fall back.
    try:
        with cm():
            pass
        return cm
    except Exception:
        return torch.no_grad
1735
+
1736
+ def _ensure_mask_list(masks, data):
1737
+ # 1s where valid, 0s where invalid (soft edges allowed)
1738
+ if masks is None:
1739
+ return [np.ones_like(a if a.ndim==2 else a[0], dtype=np.float32) for a in data]
1740
+ out = []
1741
+ for a, m in zip(data, masks):
1742
+ base = a if a.ndim==2 else a[0] # mask is 2D; shared across channels
1743
+ if m is None:
1744
+ out.append(np.ones_like(base, dtype=np.float32))
1745
+ else:
1746
+ mm = np.asarray(m, dtype=np.float32)
1747
+ if mm.ndim == 3: # tolerate (1,H,W) or (C,H,W)
1748
+ mm = mm[0]
1749
+ if mm.shape != base.shape:
1750
+ # center crop to match (common intersection already applied)
1751
+ Ht, Wt = base.shape
1752
+ mm = _center_crop(mm, Ht, Wt)
1753
+ # keep as float weights in [0,1] (do not threshold!)
1754
+ out.append(np.clip(mm.astype(np.float32, copy=False), 0.0, 1.0))
1755
+ return out
1756
+
1757
+ def _ensure_var_list(variances, data):
1758
+ # If None, we’ll estimate a robust scalar per frame on-the-fly.
1759
+ if variances is None:
1760
+ return [None]*len(data)
1761
+ out = []
1762
+ for a, v in zip(data, variances):
1763
+ if v is None:
1764
+ out.append(None)
1765
+ else:
1766
+ vv = np.asarray(v, dtype=np.float32)
1767
+ if vv.ndim == 3:
1768
+ vv = vv[0]
1769
+ base = a if a.ndim==2 else a[0]
1770
+ if vv.shape != base.shape:
1771
+ Ht, Wt = base.shape
1772
+ vv = _center_crop(vv, Ht, Wt)
1773
+ # clip tiny/negatives
1774
+ vv = np.nan_to_num(vv, nan=1e-8, posinf=1e8, neginf=1e8, copy=False)
1775
+ vv = np.clip(vv, 1e-8, None).astype(np.float32, copy=False)
1776
+ out.append(vv)
1777
+ return out
1778
+
1779
+ # ---- SR operators (downsample / upsample-sum) ----
1780
+ def _downsample_avg(img, r: int):
1781
+ """Average-pool over non-overlapping r×r blocks. Works for (H,W) or (C,H,W)."""
1782
+ if r <= 1:
1783
+ return img
1784
+ a = np.asarray(img, dtype=np.float32)
1785
+ if a.ndim == 2:
1786
+ H, W = a.shape
1787
+ Hs, Ws = (H // r) * r, (W // r) * r
1788
+ a = a[:Hs, :Ws].reshape(Hs//r, r, Ws//r, r).mean(axis=(1,3))
1789
+ return a
1790
+ else:
1791
+ C, H, W = a.shape
1792
+ Hs, Ws = (H // r) * r, (W // r) * r
1793
+ a = a[:, :Hs, :Ws].reshape(C, Hs//r, r, Ws//r, r).mean(axis=(2,4))
1794
+ return a
1795
+
1796
+ def _upsample_sum(img, r: int, target_hw: tuple[int,int] | None = None):
1797
+ """Adjoint of average-pooling: replicate-sum each pixel into an r×r block.
1798
+ For (H,W) or (C,H,W). If target_hw given, center-crop/pad to that size.
1799
+ """
1800
+ if r <= 1:
1801
+ return img
1802
+ a = np.asarray(img, dtype=np.float32)
1803
+ if a.ndim == 2:
1804
+ H, W = a.shape
1805
+ out = np.kron(a, np.ones((r, r), dtype=np.float32))
1806
+ else:
1807
+ C, H, W = a.shape
1808
+ out = np.stack([np.kron(a[c], np.ones((r, r), dtype=np.float32)) for c in range(C)], axis=0)
1809
+ if target_hw is not None:
1810
+ Ht, Wt = target_hw
1811
+ out = _center_crop(out, Ht, Wt)
1812
+ return out
1813
+
1814
def _gaussian2d(ksize: int, sigma: float) -> np.ndarray:
    """Normalized 2D Gaussian kernel with side `ksize` (sum ≈ 1)."""
    half = (ksize - 1) // 2
    yy, xx = np.mgrid[-half:half + 1, -half:half + 1].astype(np.float32)
    kernel = np.exp(-(xx * xx + yy * yy) / (2.0 * sigma * sigma)).astype(np.float32)
    kernel /= kernel.sum() + EPS
    return kernel
1820
+
1821
def _conv2_same_np(a: np.ndarray, k: np.ndarray) -> np.ndarray:
    """SAME-size 2D convolution for (H,W) or (C,H,W), delegating to _conv_same_np."""
    if a.ndim == 2:
        return _conv_same_np(a[None], k)[0]
    return _conv_same_np(a, k)
1824
+
1825
def _solve_super_psf_from_native(f_native: np.ndarray, r: int, sigma: float = 1.1,
                                 iters: int = 500, lr: float = 0.1) -> np.ndarray:
    """
    Solve: h* = argmin_h || f_native - (D(h) * g_sigma) ||_2^2,
    where h is (r*k)×(r*k) if f_native is k×k. Returns normalized h (sum=1).

    D is r×r average-pooling, g_sigma a native-grid Gaussian. Optimizes with
    Adam (torch path, projected to h>=0 and sum=1 each step) when TORCH_OK,
    otherwise — or if the torch path raises — with projected gradient descent
    in NumPy. The input PSF is first sanitized to a 2D, square, odd-sized
    kernel.

    Parameters
    ----------
    f_native : native-resolution PSF (squeezed/cropped to 2D odd square).
    r : super-resolution factor (h lives on the r×-finer grid).
    sigma : Gaussian blur sigma used in the native-space forward model.
    iters : optimizer iteration budget (floored at 10 torch / 50 NumPy).
    lr : Adam learning rate (torch) / initial step size (NumPy, decayed 0.995x).
    """
    f = np.asarray(f_native, dtype=np.float32)

    # NEW: sanitize to 2D odd square before anything else
    if f.ndim != 2:
        f = np.squeeze(f)
        if f.ndim != 2:
            raise ValueError(f"PSF must be 2D, got shape {f.shape}")

    H, W = int(f.shape[0]), int(f.shape[1])
    k_sq = min(H, W)
    # center-crop to square if needed
    if H != W:
        y0 = (H - k_sq) // 2
        x0 = (W - k_sq) // 2
        f = f[y0:y0 + k_sq, x0:x0 + k_sq]
        H = W = k_sq

    # enforce odd size (required by SAME padding math)
    if (H % 2) == 0:
        # drop one pixel border to make it odd (centered)
        f = f[1:, 1:]
        H = W = f.shape[0]

    k = int(H)  # k is now odd and square
    kr = int(k * r)


    g = _gaussian2d(k, max(sigma, 1e-3)).astype(np.float32)

    # Initialize h by spreading f onto the SR grid (one sample per r-th pixel).
    h0 = np.zeros((kr, kr), dtype=np.float32)
    h0[::r, ::r] = f
    h0 = _normalize_psf(h0)

    if TORCH_OK:
        import torch.nn.functional as F
        dev = _torch_device()

        # (1) Make sure Gaussian kernel is odd-sized for SAME conv padding
        g_pad = g
        if (g.shape[-1] % 2) == 0:
            # ensure odd + renormalize
            gg = _pad_kernel_to(g, g.shape[-1] + 1)
            g_pad = gg.astype(np.float32, copy=False)

        t = torch.tensor(h0, device=dev, dtype=torch.float32, requires_grad=True)
        f_t = torch.tensor(f, device=dev, dtype=torch.float32)
        g_t = torch.tensor(g_pad, device=dev, dtype=torch.float32)
        opt = torch.optim.Adam([t], lr=lr)

        # Helpful assertion avoids silent shape traps
        H, W = t.shape
        assert (H % r) == 0 and (W % r) == 0, f"h shape {t.shape} not divisible by r={r}"
        Hr, Wr = H // r, W // r

        try:
            for _ in range(max(10, iters)):
                opt.zero_grad(set_to_none=True)

                # (2) Downsample with avg_pool2d instead of reshape/mean
                blk = t.narrow(0, 0, Hr * r).narrow(1, 0, Wr * r).contiguous()
                th = F.avg_pool2d(blk[None, None], kernel_size=r, stride=r)[0, 0]  # (k,k)

                # (3) Native-space blur with guaranteed-odd g_t
                pad = g_t.shape[-1] // 2
                conv = F.conv2d(th[None, None], g_t[None, None], padding=pad)[0, 0]

                loss = torch.mean((conv - f_t) ** 2)
                loss.backward()
                opt.step()
                # Project onto the simplex-ish constraint: nonnegative, unit sum.
                with torch.no_grad():
                    t.clamp_(min=0.0)
                    t /= (t.sum() + 1e-8)
            h = t.detach().cpu().numpy().astype(np.float32)
        except Exception:
            # (4) Conservative safety net: if a backend balks (commonly at r=2),
            # fall back to the NumPy solver *just for this kernel*.
            h = None

    if not TORCH_OK or h is None:
        # NumPy fallback (unchanged)
        h = h0.copy()
        eta = float(lr)
        for _ in range(max(50, iters)):
            Dh = _downsample_avg(h, r)
            conv = _conv2_same_np(Dh, g)
            resid = (conv - f)
            # Gradient w.r.t. D(h): correlate residual with g (flipped conv).
            grad_Dh = _conv2_same_np(resid, np.flip(np.flip(g, 0), 1))
            # Chain rule through the downsampler: adjoint is upsample-sum.
            grad_h = _upsample_sum(grad_Dh, r, target_hw=h.shape)
            h = np.clip(h - eta * grad_h, 0.0, None)
            s = float(h.sum()); h /= (s + 1e-8)
            eta *= 0.995

    return _normalize_psf(h)
1924
+
1925
+
1926
+ def _downsample_avg_t(x, r: int):
1927
+ if r <= 1:
1928
+ return x
1929
+ if x.ndim == 2:
1930
+ H, W = x.shape
1931
+ Hr, Wr = (H // r) * r, (W // r) * r
1932
+ if Hr == 0 or Wr == 0:
1933
+ return x
1934
+ x2 = x[:Hr, :Wr]
1935
+ # ❌ .view → ✅ .reshape
1936
+ return x2.reshape(Hr // r, r, Wr // r, r).mean(dim=(1, 3))
1937
+ else:
1938
+ C, H, W = x.shape
1939
+ Hr, Wr = (H // r) * r, (W // r) * r
1940
+ if Hr == 0 or Wr == 0:
1941
+ return x
1942
+ x2 = x[:, :Hr, :Wr]
1943
+ # ❌ .view → ✅ .reshape
1944
+ return x2.reshape(C, Hr // r, r, Wr // r, r).mean(dim=(2, 4))
1945
+
1946
+ def _upsample_sum_t(x, r: int):
1947
+ if r <= 1:
1948
+ return x
1949
+ if x.ndim == 2:
1950
+ return x.repeat_interleave(r, dim=0).repeat_interleave(r, dim=1)
1951
+ else:
1952
+ return x.repeat_interleave(r, dim=-2).repeat_interleave(r, dim=-1)
1953
+
1954
def _sep_bg_rms(frames):
    """Return a robust background RMS using SEP's background model on the first frame.

    Returns None when the optional `sep` dependency is missing, `frames` is
    empty, or any step fails (best-effort helper).
    """
    if sep is None or not frames:
        return None
    try:
        y0 = frames[0] if frames[0].ndim == 2 else frames[0][0]  # use luma/first channel
        a = np.ascontiguousarray(y0, dtype=np.float32)
        # 64×64 background mesh with a 3×3 median filter over the mesh.
        b = sep.Background(a, bw=64, bh=64, fw=3, fh=3)
        try:
            rms_val = float(b.globalrms)
        except Exception:
            # some SEP builds don’t expose globalrms; fall back to the map’s median
            rms_val = float(np.median(np.asarray(b.rms(), dtype=np.float32)))
        return rms_val
    except Exception:
        return None
1970
+
1971
+ # =========================
1972
+ # Memory/streaming helpers
1973
+ # =========================
1974
+
1975
+ def _approx_bytes(arr_like_shape, dtype=np.float32):
1976
+ """Rough byte estimator for a given shape/dtype."""
1977
+ return int(np.prod(arr_like_shape)) * np.dtype(dtype).itemsize
1978
+
1979
def _mem_model(
    grid_hw: tuple[int,int],
    r: int,
    ksize: int,
    channels: int,
    mem_target_mb: int,
    prefer_tiles: bool = False,
    min_tile: int = 256,
    max_tile: int = 2048,
) -> dict:
    """
    Pick a batch size (#frames) and optional tile size (HxW) given a memory budget.
    Very conservative — aims to bound peak working-set on CPU/GPU.

    Returns dict(batch_frames, tiles=(t,t)|None, halo, ksize).

    Fix: the original clamped B_full = max(1, ...) *before* testing
    `B_full < 1`, so the tile fallback was unreachable unless prefer_tiles
    was set (and likewise `B_tile >= 1` was always true, so the tile side
    never shrank). We now test the raw, un-clamped estimates.
    """
    Hs, Ws = grid_hw
    halo = (ksize // 2) * max(1, r)  # SR grid halo if r>1
    C = max(1, channels)

    # working-set per *full-frame* conv scratch (num/den/tmp/etc.)
    per_frame_fft_like = 3 * _approx_bytes((C, Hs, Ws))  # tmp/pred + in/out buffers
    global_accum = 2 * _approx_bytes((C, Hs, Ws))        # num + den

    budget = int(mem_target_mb * 1024 * 1024)

    # Raw (possibly < 1) estimate of how many full frames fit in the budget.
    B_full_raw = (budget - global_accum) // max(per_frame_fft_like, 1)
    use_tiles = prefer_tiles or (B_full_raw < 1)

    if not use_tiles:
        return dict(batch_frames=int(max(1, B_full_raw)), tiles=None, halo=int(halo), ksize=int(ksize))

    # Tile mode: pick a square tile side t that fits
    # scratch per tile ~ 3*C*(t+2h)^2 + accum(core) ~ small
    # try descending from max_tile
    t = int(min(max_tile, max(min_tile, 1 << int(np.floor(np.log2(min(Hs, Ws)))))))
    while t >= min_tile:
        th = t + 2 * halo
        per_tile = 3 * _approx_bytes((C, th, th))
        B_tile_raw = (budget - global_accum) // max(per_tile, 1)
        if B_tile_raw >= 1:
            return dict(batch_frames=int(B_tile_raw), tiles=(t, t), halo=int(halo), ksize=int(ksize))
        t //= 2

    # Worst case: 1 frame, minimal tile
    return dict(batch_frames=1, tiles=(min_tile, min_tile), halo=int(halo), ksize=int(ksize))
2024
+
2025
def _build_seed_running_mu_sigma_from_paths(
    paths, Ht, Wt, color_mode,
    *, bootstrap_frames=24, clip_sigma=3.5,  # clip_sigma used for streaming updates
    status_cb=lambda s: None, progress_cb=None
):
    """
    Seed:
      1) Load first B frames -> mean0
      2) MAD around mean0 -> ±4·MAD mask -> masked-mean seed (one mini-iteration)
      3) Stream remaining frames with σ-clipped Welford updates (unchanged behavior)
    Returns float32 image in (H,W) or (C,H,W) matching color_mode.

    Parameters
    ----------
    paths : frame file paths, loaded one at a time via _stack_loader_memmap.
    Ht, Wt : target spatial size passed to the loader.
    color_mode : forwarded to the loader; determines (H,W) vs (C,H,W) output.
    bootstrap_frames : number of frames loaded up-front for the MAD seed.
    clip_sigma : k in the |x - μ| <= k·σ acceptance test while streaming.
    status_cb / progress_cb : optional status string / (fraction, msg) hooks.
    """
    def p(frac, msg):
        # Progress fraction is clamped to [0, 1] before forwarding.
        if progress_cb:
            progress_cb(float(max(0.0, min(1.0, frac))), msg)

    n_total = len(paths)
    B = int(max(1, min(int(bootstrap_frames), n_total)))
    status_cb(f"MFDeconv: Seed bootstrap {B} frame(s) with ±4·MAD clip on the average…")
    p(0.00, f"bootstrap load 0/{B}")

    # ---------- load first B frames ----------
    boot = []
    for i, pth in enumerate(paths[:B], start=1):
        ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
        boot.append(ys[0].astype(np.float32, copy=False))
        if (i == B) or (i % 4 == 0):
            p(0.25 * (i / float(B)), f"bootstrap load {i}/{B}")

    stack = np.stack(boot, axis=0)  # (B,H,W) or (B,C,H,W)
    del boot

    # ---------- mean0 ----------
    mean0 = np.mean(stack, axis=0, dtype=np.float32)
    p(0.28, "bootstrap mean computed")

    # ---------- ±4·MAD clip around mean0, then masked mean (one pass) ----------
    # MAD per-pixel: median(|x - mean0|)
    abs_dev = np.abs(stack - mean0[None, ...])
    mad = np.median(abs_dev, axis=0).astype(np.float32, copy=False)

    thr = 4.0 * mad + EPS
    mask = (abs_dev <= thr)

    # masked mean with fallback to mean0 where all rejected
    m = mask.astype(np.float32, copy=False)
    sum_acc = np.sum(stack * m, axis=0, dtype=np.float32)
    cnt_acc = np.sum(m, axis=0, dtype=np.float32)
    seed = mean0.copy()
    np.divide(sum_acc, np.maximum(cnt_acc, 1.0), out=seed, where=(cnt_acc > 0.5))
    p(0.36, "±4·MAD masked mean computed")

    # ---------- initialize Welford state from seed ----------
    # Start μ=seed, set an initial variance envelope from the bootstrap dispersion
    dif = stack - seed[None, ...]
    M2 = np.sum(dif * dif, axis=0, dtype=np.float32)
    cnt = np.full_like(seed, float(B), dtype=np.float32)
    mu = seed.astype(np.float32, copy=False)
    # Free the bootstrap working set before streaming the rest.
    del stack, abs_dev, mad, m, sum_acc, cnt_acc, dif

    p(0.40, "seed initialized; streaming refinements…")

    # ---------- stream remaining frames with σ-clipped Welford updates ----------
    remain = n_total - B
    if remain > 0:
        status_cb(f"MFDeconv: Seed μ–σ clipping {remain} remaining frame(s) (k={clip_sigma:.2f})…")

    k = float(clip_sigma)
    for j, pth in enumerate(paths[B:], start=1):
        ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
        x = ys[0].astype(np.float32, copy=False)

        # Current per-pixel variance estimate from the running M2/count.
        var = M2 / np.maximum(cnt - 1.0, 1.0)
        sigma = np.sqrt(np.maximum(var, 1e-12, dtype=np.float32))

        # Per-pixel acceptance: only in-range samples update μ/M2 (acc ∈ {0,1}).
        accept = (np.abs(x - mu) <= (k * sigma))
        acc = accept.astype(np.float32, copy=False)

        n_new = cnt + acc
        delta = x - mu
        mu_n = mu + (acc * delta) / np.maximum(n_new, 1.0)
        M2 = M2 + acc * delta * (x - mu_n)

        mu, cnt = mu_n, n_new

        if (j == remain) or (j % 8 == 0):
            p(0.40 + 0.60 * (j / float(remain)), f"μ–σ refine {j}/{remain}")

    p(1.0, "seed ready")
    return np.clip(mu, 0.0, None).astype(np.float32, copy=False)
2115
+
2116
+
2117
+ def _chunk(seq, n):
2118
+ """Yield chunks of size n from seq."""
2119
+ for i in range(0, len(seq), n):
2120
+ yield seq[i:i+n]
2121
+
2122
def _read_shape_fast(path) -> tuple[int,int,int]:
    """
    Return (channels, H, W) for a FITS/XISF file without reading pixel data
    beyond header/metadata access.

    Fix: for FITS, search *all* HDUs for the first one carrying data —
    consistent with _read_tile_fits/_read_tile_fits_any — instead of only
    hdul[0], which may be a header-only primary HDU.
    """
    if _is_xisf(path):
        a, _ = _load_image_array(path)
        if a is None:
            raise ValueError(f"No data in {path}")
        a = np.asarray(a)
    else:
        with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
            a = None
            # First HDU with data (the primary may be header-only).
            for h in hdul:
                if getattr(h, "data", None) is not None:
                    a = h.data
                    break
            if a is None:
                raise ValueError(f"No data in {path}")

    # common logic for both XISF and FITS (only shape metadata is used)
    if a.ndim == 2:
        H, W = a.shape
        return (1, int(H), int(W))
    if a.ndim == 3:
        if a.shape[-1] in (1, 3):  # HWC
            C = int(a.shape[-1]); H = int(a.shape[0]); W = int(a.shape[1])
            return (1 if C == 1 else 3, H, W)
        if a.shape[0] in (1, 3):   # CHW
            return (int(a.shape[0]), int(a.shape[1]), int(a.shape[2]))
    # Fallback: assume the last two axes are spatial.
    s = tuple(map(int, a.shape))
    H, W = s[-2], s[-1]
    return (1, H, W)
2147
+
2148
+
2149
+
2150
+
2151
+ def _tiles_of(hw: tuple[int,int], tile_hw: tuple[int,int], halo: int):
2152
+ """
2153
+ Yield tiles as dicts: {y0,y1,x0,x1,yc0,yc1,xc0,xc1}
2154
+ (outer region includes halo; core (yc0:yc1, xc0:xc1) excludes halo).
2155
+ """
2156
+ H, W = hw
2157
+ th, tw = tile_hw
2158
+ th = max(1, int(th)); tw = max(1, int(tw))
2159
+ for y in range(0, H, th):
2160
+ for x in range(0, W, tw):
2161
+ yc0 = y; yc1 = min(y + th, H)
2162
+ xc0 = x; xc1 = min(x + tw, W)
2163
+ y0 = max(0, yc0 - halo); y1 = min(H, yc1 + halo)
2164
+ x0 = max(0, xc0 - halo); x1 = min(W, xc1 + halo)
2165
+ yield dict(y0=y0, y1=y1, x0=x0, x1=x1, yc0=yc0, yc1=yc1, xc0=xc0, xc1=xc1)
2166
+
2167
+ def _extract_with_halo(a, tile):
2168
+ """
2169
+ Slice 'a' ((H,W) or (C,H,W)) to [y0:y1, x0:x1] with channel kept.
2170
+ """
2171
+ y0,y1,x0,x1 = tile["y0"], tile["y1"], tile["x0"], tile["x1"]
2172
+ if a.ndim == 2:
2173
+ return a[y0:y1, x0:x1]
2174
+ else:
2175
+ return a[:, y0:y1, x0:x1]
2176
+
2177
+ def _add_core(accum, tile_val, tile):
2178
+ """
2179
+ Add tile_val core into accum at (yc0:yc1, xc0:xc1).
2180
+ Shapes match (2D) or (C,H,W).
2181
+ """
2182
+ yc0,yc1,xc0,xc1 = tile["yc0"], tile["yc1"], tile["xc0"], tile["xc1"]
2183
+ if accum.ndim == 2:
2184
+ h0 = yc0 - tile["y0"]; h1 = h0 + (yc1 - yc0)
2185
+ w0 = xc0 - tile["x0"]; w1 = w0 + (xc1 - xc0)
2186
+ accum[yc0:yc1, xc0:xc1] += tile_val[h0:h1, w0:w1]
2187
+ else:
2188
+ h0 = yc0 - tile["y0"]; h1 = h0 + (yc1 - yc0)
2189
+ w0 = xc0 - tile["x0"]; w1 = w0 + (xc1 - xc0)
2190
+ accum[:, yc0:yc1, xc0:xc1] += tile_val[:, h0:h1, w0:w1]
2191
+
2192
def _prepare_np_fft_packs_batch(psfs, flip_psf, Hs, Ws):
    """Precompute rFFT packs on current grid for NumPy path; returns lists aligned to batch psfs.

    Returns (Kfs, KTfs, meta) where meta holds (kh, kw, fftH, fftW) per
    kernel — the arguments _fft_conv_same_np expects.
    """
    Kfs, KTfs, meta = [], [], []
    import numpy.fft as fft
    for k, kT in zip(psfs, flip_psf):
        kh, kw = k.shape
        # Linear-convolution grid; ifftshift centers the kernel at (0,0).
        fftH, fftW = _fftshape_same(Hs, Ws, kh, kw)
        Kfs.append(fft.rfftn(np.fft.ifftshift(k), s=(fftH, fftW)))
        KTfs.append(fft.rfftn(np.fft.ifftshift(kT), s=(fftH, fftW)))
        meta.append((kh, kw, fftH, fftW))
    return Kfs, KTfs, meta
2203
+
2204
def _prepare_torch_fft_packs_batch(psfs, flip_psf, Hs, Ws, device, dtype):
    """Torch FFT packs per PSF on current grid; mirrors your existing packer.

    Thin alias over _precompute_torch_psf_ffts kept so the batch path has a
    name parallel to _prepare_np_fft_packs_batch.
    """
    return _precompute_torch_psf_ffts(psfs, flip_psf, Hs, Ws, device, dtype)
2207
+
2208
+ def _as_chw(np_img: np.ndarray) -> np.ndarray:
2209
+ x = np.asarray(np_img, dtype=np.float32, order="C")
2210
+ if x.size == 0:
2211
+ raise RuntimeError(f"Empty image array after load; raw shape={np_img.shape}")
2212
+ if x.ndim == 2:
2213
+ return x[None, ...] # 1,H,W
2214
+ if x.ndim == 3 and x.shape[0] in (1, 3):
2215
+ if x.shape[0] == 0:
2216
+ raise RuntimeError(f"Zero channels in CHW array; shape={x.shape}")
2217
+ return x
2218
+ if x.ndim == 3 and x.shape[-1] in (1, 3):
2219
+ if x.shape[-1] == 0:
2220
+ raise RuntimeError(f"Zero channels in HWC array; shape={x.shape}")
2221
+ return np.moveaxis(x, -1, 0)
2222
+ # last resort: treat first dim as channels, but reject zero
2223
+ if x.shape[0] == 0:
2224
+ raise RuntimeError(f"Zero channels in array; shape={x.shape}")
2225
+ return x
2226
+
2227
+
2228
+
2229
def _conv_same_np_spatial(a: np.ndarray, k: np.ndarray, out: np.ndarray | None = None):
    """
    SAME-size spatial convolution via OpenCV (reflect border) for (H,W) or
    (C,H,W). Returns None when cv2 is unavailable so the caller can fall
    back to the FFT path.
    """
    try:
        import cv2
    except Exception:
        return None  # no opencv -> caller falls back to FFT

    # OpenCV's filter2D performs correlation; flip both axes to get convolution.
    kf = np.flip(np.flip(np.ascontiguousarray(k.astype(np.float32)), 0), 1)

    if a.ndim == 2:
        res = cv2.filter2D(a, -1, kf, borderType=cv2.BORDER_REFLECT)
        if out is None:
            return res
        out[...] = res
        return out

    if out is None:
        out = np.empty_like(a)
    for c in range(a.shape[0]):
        out[c] = cv2.filter2D(a[c], -1, kf, borderType=cv2.BORDER_REFLECT)
    return out
2250
+
2251
+ def _grouped_conv_same_torch_per_sample(x_bc_hw, w_b1kk, B, C):
2252
+ F = torch.nn.functional
2253
+ x_bc_hw = x_bc_hw.to(memory_format=torch.contiguous_format).contiguous()
2254
+ w_b1kk = w_b1kk.to(memory_format=torch.contiguous_format).contiguous()
2255
+
2256
+ kh, kw = int(w_b1kk.shape[-2]), int(w_b1kk.shape[-1])
2257
+ pad = (kw // 2, kw - kw // 2 - 1, kh // 2, kh - kh // 2 - 1)
2258
+
2259
+ # unified path (CUDA/CPU/MPS): one grouped conv with G=B*C
2260
+ G = int(B * C)
2261
+ x_1ghw = x_bc_hw.reshape(1, G, x_bc_hw.shape[-2], x_bc_hw.shape[-1])
2262
+ x_1ghw = F.pad(x_1ghw, pad, mode="reflect")
2263
+ w_g1kk = w_b1kk.repeat_interleave(C, dim=0) # (G,1,kh,kw)
2264
+ y_1ghw = F.conv2d(x_1ghw, w_g1kk, padding=0, groups=G)
2265
+ return y_1ghw.reshape(B, C, y_1ghw.shape[-2], y_1ghw.shape[-1]).contiguous()
2266
+
2267
+
2268
+ # put near other small helpers
2269
+ def _robust_med_mad_t(x, max_elems_per_sample: int = 2_000_000):
2270
+ """
2271
+ x: (B, C, H, W) tensor on device.
2272
+ Returns (median[B,1,1,1], mad[B,1,1,1]) computed on a strided subsample
2273
+ to avoid 'quantile() input tensor is too large'.
2274
+ """
2275
+ import math
2276
+ import torch
2277
+ B = x.shape[0]
2278
+ flat = x.reshape(B, -1)
2279
+ N = flat.shape[1]
2280
+ if N > max_elems_per_sample:
2281
+ stride = int(math.ceil(N / float(max_elems_per_sample)))
2282
+ flat = flat[:, ::stride] # strided subsample
2283
+ med = torch.quantile(flat, 0.5, dim=1, keepdim=True)
2284
+ mad = torch.quantile((flat - med).abs(), 0.5, dim=1, keepdim=True) + 1e-6
2285
+ return med.view(B,1,1,1), mad.view(B,1,1,1)
2286
+
2287
+ def _torch_should_use_spatial(psf_ksize: int) -> bool:
2288
+ # Prefer spatial on non-CUDA backends and for modest kernels.
2289
+ try:
2290
+ dev = _torch_device()
2291
+ if dev.type in ("mps", "privateuseone"): # privateuseone = DirectML
2292
+ return True
2293
+ if dev.type == "cuda":
2294
+ return psf_ksize <= 51 # typical PSF sizes; spatial is fast & stable
2295
+ except Exception:
2296
+ pass
2297
+ # Allow override via env
2298
+ import os as _os
2299
+ if _os.environ.get("MF_SPATIAL", "") == "1":
2300
+ return True
2301
+ return False
2302
+
2303
def _read_tile_fits(path: str, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
    """Return a (H,W) or (H,W,3|1) tile via FITS memmap, without loading whole image.

    (y0:y1, x0:x1) index the first two (spatial) axes; any trailing channel
    axis is kept, except a trailing singleton which is squeezed away.
    The returned array is an owned copy, safe to cast/normalize in place.
    """
    with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
        hdu = hdul[0]
        a = hdu.data
        if a is None:
            # find first image HDU if primary is header-only
            for h in hdul[1:]:
                if getattr(h, "data", None) is not None:
                    a = h.data; break
        a = np.asarray(a)  # still lazy until sliced
        # squeeze trailing singleton if present to keep your conventions
        if a.ndim == 3 and a.shape[-1] == 1:
            a = np.squeeze(a, axis=-1)
        tile = a[y0:y1, x0:x1, ...]
        # copy so we own the buffer (we will cast/normalize)
        return np.array(tile, copy=True)
2320
+
2321
def _read_tile_fits_any(path: str, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
    """FITS/XISF-aware tile read: returns spatial tile; supports 2D, HWC, and CHW.

    (y0:y1, x0:x1) always index the spatial axes regardless of layout.
    Single-channel planes are squeezed to 2D; multi-channel layouts keep
    their original axis order (HWC stays HWC, CHW stays CHW). The returned
    array is an owned copy.

    Raises
    ------
    ValueError : no image data, or an unsupported array layout.
    """
    ext = os.path.splitext(path)[1].lower()

    if ext == ".xisf":
        a = _xisf_cached_array(path)  # float32 memmap; cheap slicing
        # a is HW, HWC, or CHW (whatever _load_image_array returned)
        if a.ndim == 2:
            return np.array(a[y0:y1, x0:x1], copy=True)
        elif a.ndim == 3:
            if a.shape[-1] in (1, 3):  # HWC
                out = a[y0:y1, x0:x1, :]
                if out.shape[-1] == 1: out = out[..., 0]
                return np.array(out, copy=True)
            elif a.shape[0] in (1, 3):  # CHW
                out = a[:, y0:y1, x0:x1]
                if out.shape[0] == 1: out = out[0]
                return np.array(out, copy=True)
            else:
                raise ValueError(f"Unsupported XISF 3D shape {a.shape} in {path}")
        else:
            raise ValueError(f"Unsupported XISF ndim {a.ndim} in {path}")

    # FITS
    with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
        a = None
        # First HDU carrying data (primary may be header-only).
        for h in hdul:
            if getattr(h, "data", None) is not None:
                a = h.data
                break
        if a is None:
            raise ValueError(f"No image data in {path}")

        a = np.asarray(a)

        if a.ndim == 2:  # HW
            return np.array(a[y0:y1, x0:x1], copy=True)

        if a.ndim == 3:
            # NOTE: CHW is checked before HWC here (FITS convention puts
            # planes first), the opposite order from the XISF branch above.
            if a.shape[0] in (1, 3):  # CHW (planes, rows, cols)
                out = a[:, y0:y1, x0:x1]
                if out.shape[0] == 1:
                    out = out[0]
                return np.array(out, copy=True)
            if a.shape[-1] in (1, 3):  # HWC
                out = a[y0:y1, x0:x1, :]
                if out.shape[-1] == 1:
                    out = out[..., 0]
                return np.array(out, copy=True)

        # Fallback: assume last two axes are spatial (…, H, W)
        try:
            out = a[(..., slice(y0, y1), slice(x0, x1))]
            return np.array(out, copy=True)
        except Exception:
            raise ValueError(f"Unsupported FITS data shape {a.shape} in {path}")
2377
+
2378
def _infer_channels_from_tile(p: str, Ht: int, Wt: int) -> int:
    """Read a 1×1 corner tile and infer the channel count (HW / HWC / CHW)."""
    tile = _read_tile_fits_any(p, 0, min(1, Ht), 0, min(1, Wt))

    if tile.ndim == 2:
        return 1
    if tile.ndim != 3:
        return 1

    # Prefer the axis that unambiguously carries the color planes.
    first_is_planes = tile.shape[0] in (1, 3)
    last_is_planes = tile.shape[-1] in (1, 3)
    if first_is_planes and not last_is_planes:
        return int(tile.shape[0])
    if last_is_planes and not first_is_planes:
        return int(tile.shape[-1])

    # Ambiguous tiny tile (e.g. CHW 3×1×1 or HWC 1×1×3): assume RGB if any
    # candidate axis is 3, otherwise mono.
    return 3 if (tile.shape[0] == 3 or tile.shape[-1] == 3) else 1
2402
+
2403
+
2404
+
2405
def _seed_median_streaming(
    paths,
    Ht,
    Wt,
    *,
    color_mode="luma",
    tile_hw=(256, 256),
    status_cb=lambda s: None,
    progress_cb=lambda f, m="": None,
    use_torch: bool | None = None,  # auto by default; False forces the NumPy path
):
    """
    Exact per-pixel median via tiling; RAM-bounded.
    Now shows per-tile progress and uses Torch on GPU if available.
    Parallelizes per-tile slab reads to hide I/O and luma work.

    Parameters
    ----------
    paths : sequence of str
        Aligned frame files (FITS/XISF — whatever `_read_tile_fits_any` accepts).
    Ht, Wt : int
        Common intersection height/width of all frames.
    color_mode : str
        "luma" → mono seed; anything else → per-channel seed (channel count is
        probed from the first frame via `_infer_channels_from_tile`).
    tile_hw : (int, int)
        Tile height/width; bounds peak RAM at roughly n_frames * th * tw * 4 bytes.
    status_cb, progress_cb : callables
        Telemetry hooks; progress_cb receives a fraction in [0, 1] and a message.
    use_torch : bool | None
        None/True → try a GPU (CUDA/MPS) median per tile; False → NumPy only.

    Returns
    -------
    np.ndarray float32, shape (Ht, Wt) for mono or (C, Ht, Wt) per-channel.
    """

    th, tw = int(tile_hw[0]), int(tile_hw[1])
    # old: want_c = 1 if str(color_mode).lower() == "luma" else (3 if _read_shape_fast(paths[0])[0] == 3 else 1)
    if str(color_mode).lower() == "luma":
        want_c = 1
    else:
        # Probe an actual 1×1 tile rather than trusting header metadata.
        want_c = _infer_channels_from_tile(paths[0], Ht, Wt)
    seed = np.zeros((Ht, Wt), np.float32) if want_c == 1 else np.zeros((want_c, Ht, Wt), np.float32)
    # Tile grid over the full (Ht, Wt) canvas; edge tiles are clipped.
    tiles = [(y, min(y + th, Ht), x, min(x + tw, Wt)) for y in range(0, Ht, th) for x in range(0, Wt, tw)]
    total = len(tiles)
    n_frames = len(paths)

    # Choose a sensible number of I/O workers (bounded by frames)
    try:
        _cpu = (os.cpu_count() or 4)
    except Exception:
        _cpu = 4
    io_workers = max(1, min(8, _cpu, n_frames))

    # Torch autodetect (once). CPU torch is intentionally NOT used: NumPy's
    # median is faster there, so only CUDA/MPS devices enable the Torch path.
    TORCH_OK = False
    device = None
    if use_torch is not False:
        try:
            from setiastro.saspro.runtime_torch import import_torch
            _t = import_torch(prefer_cuda=True, status_cb=status_cb)
            dev = None
            if hasattr(_t, "cuda") and _t.cuda.is_available():
                dev = _t.device("cuda")
            elif hasattr(_t.backends, "mps") and _t.backends.mps.is_available():
                dev = _t.device("mps")
            else:
                dev = None  # CPU tensors slower than NumPy for median; only use if forced
            if dev is not None:
                TORCH_OK = True
                device = dev
                status_cb(f"Median seed: using Torch device {device}")
        except Exception as e:
            status_cb(f"Median seed: Torch unavailable → NumPy fallback ({e})")
            TORCH_OK = False
            device = None

    def _tile_msg(ti, tn):
        return f"median tiles {ti}/{tn}"

    done = 0

    for (y0, y1, x0, x1) in tiles:
        h, w = (y1 - y0), (x1 - x0)

        # per-tile slab reader with incremental progress (parallel)
        def _read_slab_for_channel(csel=None):
            """
            Returns slab of shape (N, h, w) float32 in [0,1]-ish (normalized if input was integer).
            If csel is None and luma is requested, computes luma.
            """
            # parallel worker returns (i, tile2d)
            def _load_one(i):
                t = _read_tile_fits_any(paths[i], y0, y1, x0, x1)

                # normalize dtype: integer data is scaled by its dtype max
                # (the `or 1.0` guards a pathological zero max)
                if t.dtype.kind in "ui":
                    t = t.astype(np.float32) / (float(np.iinfo(t.dtype).max) or 1.0)
                else:
                    t = t.astype(np.float32, copy=False)

                # luma / channel selection — always reduce to a 2-D (h, w) tile
                if want_c == 1:
                    if t.ndim == 3:
                        t = _to_luma_local(t)
                    elif t.ndim != 2:
                        t = _to_luma_local(t)
                else:
                    if t.ndim == 2:
                        pass
                    elif t.ndim == 3 and t.shape[-1] == 3:  # HWC
                        t = t[..., csel]
                    elif t.ndim == 3 and t.shape[0] == 3:  # CHW
                        t = t[csel]
                    else:
                        # unexpected layout → collapse to luma rather than crash
                        t = _to_luma_local(t)
                return i, np.ascontiguousarray(t, dtype=np.float32)

            slab = np.empty((n_frames, h, w), np.float32)
            done_local = 0
            # cap workers by n_frames so we don't spawn useless threads
            with ThreadPoolExecutor(max_workers=min(io_workers, n_frames)) as ex:
                futures = [ex.submit(_load_one, i) for i in range(n_frames)]
                for fut in as_completed(futures):
                    i, t2d = fut.result()
                    # quick sanity (avoid silent mis-shapes)
                    if t2d.shape != (h, w):
                        raise RuntimeError(
                            f"Tile read mismatch at frame {i}: got {t2d.shape}, expected {(h, w)} "
                            f"tile={(y0,y1,x0,x1)}"
                        )
                    slab[i] = t2d
                    done_local += 1
                    # report every 8th frame (and the last) to keep callbacks cheap;
                    # I/O accounts for ~80% of each tile's progress span
                    if (done_local & 7) == 0 or done_local == n_frames:
                        tile_base = done / total
                        tile_span = 1.0 / total
                        inner = done_local / n_frames
                        progress_cb(tile_base + 0.8 * tile_span * inner, _tile_msg(done + 1, total))
            return slab

        try:
            if want_c == 1:
                t0 = time.perf_counter()
                slab = _read_slab_for_channel()
                t1 = time.perf_counter()
                if TORCH_OK:
                    import torch as _t
                    slab_t = _t.as_tensor(slab, device=device, dtype=_t.float32)  # one H2D
                    med_t = slab_t.median(dim=0).values
                    med_np = med_t.detach().cpu().numpy().astype(np.float32, copy=False)
                    # no per-tile empty_cache() — avoid forced syncs
                else:
                    med_np = np.median(slab, axis=0).astype(np.float32, copy=False)
                t2 = time.perf_counter()
                seed[y0:y1, x0:x1] = med_np
                # lightweight telemetry to confirm bottleneck
                status_cb(f"seed tile {y0}:{y1},{x0}:{x1} I/O={t1-t0:.3f}s median={'GPU' if TORCH_OK else 'CPU'}={t2-t1:.3f}s")
            else:
                # per-channel: re-read the slab once per channel (trades I/O for RAM)
                for c in range(want_c):
                    slab = _read_slab_for_channel(csel=c)
                    if TORCH_OK:
                        import torch as _t
                        slab_t = _t.as_tensor(slab, device=device, dtype=_t.float32)
                        med_t = slab_t.median(dim=0).values
                        med_np = med_t.detach().cpu().numpy().astype(np.float32, copy=False)
                    else:
                        med_np = np.median(slab, axis=0).astype(np.float32, copy=False)
                    seed[c, y0:y1, x0:x1] = med_np

        except RuntimeError as e:
            # Per-tile GPU OOM or device issues → fallback to NumPy for this tile
            msg = str(e).lower()
            if TORCH_OK and ("out of memory" in msg or "resource" in msg or "alloc" in msg):
                status_cb(f"Median seed: GPU OOM on tile ({h}x{w}); falling back to NumPy for this tile.")
                if want_c == 1:
                    slab = _read_slab_for_channel()
                    seed[y0:y1, x0:x1] = np.median(slab, axis=0).astype(np.float32, copy=False)
                else:
                    for c in range(want_c):
                        slab = _read_slab_for_channel(csel=c)
                        seed[c, y0:y1, x0:x1] = np.median(slab, axis=0).astype(np.float32, copy=False)
            else:
                raise

        done += 1
        # report tile completion (remaining 20% of tile span reserved for median compute)
        progress_cb(done / total, _tile_msg(done, total))

        # keep the GUI responsive every 4 tiles
        if (done & 3) == 0:
            _process_gui_events_safely()
    status_cb(f"Median seed: want_c={want_c}, seed_shape={seed.shape}")
    return seed
2578
+
2579
def _seed_bootstrap_streaming(paths, Ht, Wt, color_mode,
                              bootstrap_frames: int = 20,
                              clip_sigma: float = 5.0,
                              status_cb=lambda s: None,
                              progress_cb=None):
    """
    Seed = average first B frames, estimate global MAD threshold, masked-mean on B,
    then stream the remaining frames with σ-clipped Welford μ–σ updates.

    Four streaming passes (each re-reads frames via `_stack_loader_memmap`,
    so peak RAM is one frame plus the accumulators):
      1. Welford running mean μ over the first B frames.
      2. Global MAD of |x − μ| from strided samples of those B frames.
      3. Masked mean over the first B frames, rejecting |x − μ| > 4·MAD.
      4. σ-clipped Welford μ/σ refinement over the remaining frames.

    Returns float32, CHW for per-channel mode, or CHW(1,...) for luma (your caller later squeezes if needed).
    """
    # progress helper: clamp to [0, 1] and forward only if a callback was given
    def p(frac, msg):
        if progress_cb:
            progress_cb(float(max(0.0, min(1.0, frac))), msg)

    n = len(paths)
    B = int(max(1, min(int(bootstrap_frames), n)))  # bootstrap size, clamped to [1, n]
    status_cb(f"Seed: bootstrap={B}, clip_sigma={clip_sigma}")

    # ---------- pass 1: running mean over the first B frames (Welford μ only) ----------
    cnt = None
    mu = None
    for i, pth in enumerate(paths[:B], 1):
        ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
        x = ys[0].astype(np.float32, copy=False)
        if mu is None:
            mu = x.copy()
            cnt = np.ones_like(x, dtype=np.float32)
        else:
            delta = x - mu
            cnt += 1.0
            mu += delta / cnt
        if (i == B) or (i % 4) == 0:
            p(0.20 * (i / float(B)), f"bootstrap mean {i}/{B}")

    # ---------- pass 2: estimate a *global* MAD around μ using strided samples ----------
    # coarser stride on big frames keeps the sample buffer small
    stride = 4 if max(Ht, Wt) > 1500 else 2
    # CHW or HW → build a slice that keeps C and strides H,W
    samp_slices = (slice(None) if mu.ndim == 3 else slice(None),
                   slice(None, None, stride),
                   slice(None, None, stride)) if mu.ndim == 3 else \
                  (slice(None, None, stride), slice(None, None, stride))

    mad_samples = []
    for i, pth in enumerate(paths[:B], 1):
        ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
        x = ys[0].astype(np.float32, copy=False)
        d = np.abs(x - mu)
        mad_samples.append(d[samp_slices].ravel())
        if (i == B) or (i % 8) == 0:
            p(0.35 + 0.10 * (i / float(B)), f"bootstrap MAD {i}/{B}")

    # robust MAD estimate (scalar) → 4·MAD clip band (floored to avoid a zero band)
    mad_est = float(np.median(np.concatenate(mad_samples).astype(np.float32)))
    thr = 4.0 * max(mad_est, 1e-6)

    # ---------- pass 3: masked mean over first B frames using the global threshold ----------
    sum_acc = np.zeros_like(mu, dtype=np.float32)
    cnt_acc = np.zeros_like(mu, dtype=np.float32)
    for i, pth in enumerate(paths[:B], 1):
        ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
        x = ys[0].astype(np.float32, copy=False)
        m = (np.abs(x - mu) <= thr).astype(np.float32, copy=False)  # {0,1} accept mask
        sum_acc += x * m
        cnt_acc += m
        if (i == B) or (i % 4) == 0:
            p(0.48 + 0.12 * (i / float(B)), f"masked mean {i}/{B}")

    # pixels where everything was rejected keep the plain mean μ (via `where=`)
    seed = mu.copy()
    np.divide(sum_acc, np.maximum(cnt_acc, 1.0), out=seed, where=(cnt_acc > 0.5))

    # ---------- pass 4: μ–σ streaming on the remaining frames (σ-clipped Welford) ----------
    M2 = np.zeros_like(seed, dtype=np.float32)  # sum of squared diffs
    cnt = np.full_like(seed, float(B), dtype=np.float32)  # effective sample count per pixel
    mu = seed.astype(np.float32, copy=False)

    remain = n - B
    k = float(clip_sigma)
    for j, pth in enumerate(paths[B:], 1):
        ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
        x = ys[0].astype(np.float32, copy=False)

        # current per-pixel σ from the Welford accumulators
        var = M2 / np.maximum(cnt - 1.0, 1.0)
        sigma = np.sqrt(np.maximum(var, 1e-12, dtype=np.float32))

        acc = (np.abs(x - mu) <= (k * sigma)).astype(np.float32, copy=False)  # {0,1} mask
        # masked Welford update: rejected pixels leave μ, M2, cnt untouched
        n_new = cnt + acc
        delta = x - mu
        mu_n = mu + (acc * delta) / np.maximum(n_new, 1.0)
        M2 = M2 + acc * delta * (x - mu_n)

        mu, cnt = mu_n, n_new

        if (j == remain) or (j % 8) == 0:
            p(0.60 + 0.40 * (j / float(max(1, remain))), f"μ–σ refine {j}/{remain}")

    # clamp negatives introduced by the arithmetic; image data is non-negative
    return np.clip(mu, 0.0, None).astype(np.float32, copy=False)
2676
+
2677
+
2678
+ def _coerce_sr_factor(srf, *, default_on_bad=2):
2679
+ """
2680
+ Parse super-res factor robustly:
2681
+ - accepts 2, '2', '2x', ' 2 X ', 2.0
2682
+ - clamps to integers >= 1
2683
+ - if invalid/missing → returns default_on_bad (we want 2 by your request)
2684
+ """
2685
+ if srf is None:
2686
+ return int(default_on_bad)
2687
+ if isinstance(srf, (float, int)):
2688
+ r = int(round(float(srf)))
2689
+ return int(r if r >= 1 else default_on_bad)
2690
+ s = str(srf).strip().lower()
2691
+ # common GUIs pass e.g. "2x", "3×", etc.
2692
+ s = s.replace("×", "x")
2693
+ if s.endswith("x"):
2694
+ s = s[:-1]
2695
+ try:
2696
+ r = int(round(float(s)))
2697
+ return int(r if r >= 1 else default_on_bad)
2698
+ except Exception:
2699
+ return int(default_on_bad)
2700
+
2701
+ def _pad_kernel_to(k: np.ndarray, K: int) -> np.ndarray:
2702
+ """Pad/center an odd-sized kernel to K×K (K odd)."""
2703
+ k = np.asarray(k, dtype=np.float32)
2704
+ kh, kw = int(k.shape[0]), int(k.shape[1])
2705
+ assert (kh % 2 == 1) and (kw % 2 == 1)
2706
+ if kh == K and kw == K:
2707
+ return k
2708
+ out = np.zeros((K, K), dtype=np.float32)
2709
+ y0 = (K - kh)//2; x0 = (K - kw)//2
2710
+ out[y0:y0+kh, x0:x0+kw] = k
2711
+ s = float(out.sum())
2712
+ return out if s <= 0 else (out / s).astype(np.float32, copy=False)
2713
+
2714
+ # -----------------------------
2715
+ # Core
2716
+ # -----------------------------
2717
+ def multiframe_deconv(
2718
+ paths,
2719
+ out_path,
2720
+ iters=20,
2721
+ kappa=2.0,
2722
+ color_mode="luma",
2723
+ seed_mode: str = "robust",
2724
+ huber_delta=0.0,
2725
+ masks=None,
2726
+ variances=None,
2727
+ rho="huber",
2728
+ status_cb=lambda s: None,
2729
+ min_iters: int = 3,
2730
+ use_star_masks: bool = False,
2731
+ use_variance_maps: bool = False,
2732
+ star_mask_cfg: dict | None = None,
2733
+ varmap_cfg: dict | None = None,
2734
+ save_intermediate: bool = False,
2735
+ save_every: int = 1,
2736
+ # SR options
2737
+ super_res_factor: int = 1,
2738
+ sr_sigma: float = 1.1,
2739
+ sr_psf_opt_iters: int = 250,
2740
+ sr_psf_opt_lr: float = 0.1,
2741
+ # NEW
2742
+ batch_frames: int | None = None,
2743
+ # GPU tuning (optional knobs)
2744
+ mixed_precision: bool | None = None, # default: auto (True on CUDA/MPS)
2745
+ fft_kernel_threshold: int = 1024, # switch to FFT if K >= this (or lower if SR)
2746
+ prefetch_batches: bool = True, # CPU→GPU double-buffer prefetch
2747
+ use_channels_last: bool | None = None, # default: auto (True on CUDA/MPS)
2748
+ force_cpu: bool = False,
2749
+ star_mask_ref_path: str | None = None,
2750
+ low_mem: bool = False,
2751
+ ):
2752
+ """
2753
+ Streaming multi-frame deconvolution with optional SR (r>1).
2754
+ Optimized GPU path: AMP for convs, channels-last, pinned-memory prefetch, optional FFT for large kernels.
2755
+ """
2756
+ mixed_precision = False
2757
+ DEBUG_FLAT_WEIGHTS = False
2758
+ # ---------- local helpers (kept self-contained) ----------
2759
+ def _emit_pct(pct: float, msg: str | None = None):
2760
+ pct = float(max(0.0, min(1.0, pct)))
2761
+ status_cb(f"__PROGRESS__ {pct:.4f}" + (f" {msg}" if msg else ""))
2762
+
2763
+ def _pad_kernel_to(k: np.ndarray, K: int) -> np.ndarray:
2764
+ """Pad/center an odd-sized kernel to K×K (K odd)."""
2765
+ k = np.asarray(k, dtype=np.float32)
2766
+ kh, kw = int(k.shape[0]), int(k.shape[1])
2767
+ assert (kh % 2 == 1) and (kw % 2 == 1)
2768
+ if kh == K and kw == K:
2769
+ return k
2770
+ out = np.zeros((K, K), dtype=np.float32)
2771
+ y0 = (K - kh)//2; x0 = (K - kw)//2
2772
+ out[y0:y0+kh, x0:x0+kw] = k
2773
+ s = float(out.sum())
2774
+ return out if s <= 0 else (out / s).astype(np.float32, copy=False)
2775
+
2776
+ max_iters = max(1, int(iters))
2777
+ min_iters = max(1, int(min_iters))
2778
+ if min_iters > max_iters:
2779
+ min_iters = max_iters
2780
+
2781
+ n_frames = len(paths)
2782
+ status_cb(f"MFDeconv: scanning {n_frames} aligned frames (memmap)…")
2783
+ _emit_pct(0.02, "scanning")
2784
+
2785
+ # choose common intersection size without loading full pixels
2786
+ Ht, Wt = _common_hw_from_paths(paths)
2787
+ _emit_pct(0.05, "preparing")
2788
+
2789
+ # --- LOW-MEM PATCH (begin) ---
2790
+ if low_mem:
2791
+ # Cap decoded-frame LRU to keep peak RAM sane on 16 GB laptops
2792
+ try:
2793
+ _FRAME_LRU.cap = max(1, min(getattr(_FRAME_LRU, "cap", 8), 2))
2794
+ except Exception:
2795
+ pass
2796
+
2797
+ # Disable CPU→GPU prefetch to avoid double-buffering allocations
2798
+ prefetch_batches = False
2799
+
2800
+ # Relax SEP background grid & star detection canvas when requested
2801
+ if use_variance_maps:
2802
+ varmap_cfg = {**(varmap_cfg or {})}
2803
+ # fewer, larger tiles → fewer big temporaries
2804
+ varmap_cfg.setdefault("bw", 96)
2805
+ varmap_cfg.setdefault("bh", 96)
2806
+
2807
+ if use_star_masks:
2808
+ star_mask_cfg = {**(star_mask_cfg or {})}
2809
+ # shrink detection canvas to limit temp buffers inside SEP/mask draw
2810
+ star_mask_cfg["max_side"] = int(min(1024, int(star_mask_cfg.get("max_side", 2048))))
2811
+ # --- LOW-MEM PATCH (end) ---
2812
+
2813
+
2814
+ if any(os.path.splitext(p)[1].lower() == ".xisf" for p in paths):
2815
+ status_cb("MFDeconv: priming XISF cache (one-time decode per frame)…")
2816
+ for i, p in enumerate(paths, 1):
2817
+ try:
2818
+ _ = _xisf_cached_array(p) # decode once, store memmap
2819
+ except Exception as e:
2820
+ status_cb(f"XISF cache failed for {p}: {e}")
2821
+ if (i & 7) == 0 or i == len(paths):
2822
+ _process_gui_events_safely()
2823
+
2824
+ # per-frame loader & sequence view (closures capture Ht/Wt/color_mode/paths)
2825
+ def _load_frame_chw(i: int):
2826
+ return _FRAME_LRU.get(paths[i], Ht, Wt, color_mode)
2827
+
2828
+ class _FrameSeq:
2829
+ def __len__(self): return len(paths)
2830
+ def __getitem__(self, i): return _load_frame_chw(i)
2831
+ data = _FrameSeq()
2832
+
2833
+ # ---- torch detection (optional) ----
2834
+ global torch, TORCH_OK
2835
+ torch = None
2836
+ TORCH_OK = False
2837
+ cuda_ok = mps_ok = dml_ok = False
2838
+ dml_device = None
2839
+
2840
+ try:
2841
+ from setiastro.saspro.runtime_torch import import_torch
2842
+ torch = import_torch(prefer_cuda=True, status_cb=status_cb)
2843
+ TORCH_OK = True
2844
+ try: cuda_ok = hasattr(torch, "cuda") and torch.cuda.is_available()
2845
+ except Exception as e:
2846
+ import logging
2847
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2848
+ try: mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
2849
+ except Exception as e:
2850
+ import logging
2851
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2852
+ try:
2853
+ import torch_directml
2854
+ dml_device = torch_directml.device()
2855
+ _ = (torch.ones(1, device=dml_device) + 1).item()
2856
+ dml_ok = True
2857
+ # NEW: expose to _torch_device()
2858
+ globals()["dml_ok"] = True
2859
+ globals()["dml_device"] = dml_device
2860
+ except Exception:
2861
+ dml_ok = False
2862
+ globals()["dml_ok"] = False
2863
+ globals()["dml_device"] = None
2864
+
2865
+ if cuda_ok:
2866
+ status_cb(f"PyTorch CUDA available: True | device={torch.cuda.get_device_name(0)}")
2867
+ elif mps_ok:
2868
+ status_cb("PyTorch MPS (Apple) available: True")
2869
+ elif dml_ok:
2870
+ status_cb("PyTorch DirectML (Windows) available: True")
2871
+ else:
2872
+ status_cb("PyTorch present, using CPU backend.")
2873
+ status_cb(
2874
+ f"PyTorch {getattr(torch,'__version__','?')} backend: "
2875
+ + ("CUDA" if cuda_ok else "MPS" if mps_ok else "DirectML" if dml_ok else "CPU")
2876
+ )
2877
+
2878
+ try:
2879
+ # keep cuDNN autotune on
2880
+ if getattr(torch.backends, "cudnn", None) is not None:
2881
+ torch.backends.cudnn.benchmark = True
2882
+ torch.backends.cudnn.allow_tf32 = False
2883
+ except Exception:
2884
+ pass
2885
+
2886
+ try:
2887
+ # disable TF32 matmul shortcuts if present (CUDA-only; safe no-op elsewhere)
2888
+ if getattr(getattr(torch.backends, "cuda", None), "matmul", None) is not None:
2889
+ torch.backends.cuda.matmul.allow_tf32 = False
2890
+ except Exception:
2891
+ pass
2892
+
2893
+ try:
2894
+ # prefer highest FP32 precision on PT 2.x
2895
+ if hasattr(torch, "set_float32_matmul_precision"):
2896
+ torch.set_float32_matmul_precision("highest")
2897
+ except Exception:
2898
+ pass
2899
+ except Exception as e:
2900
+ TORCH_OK = False
2901
+ status_cb(f"PyTorch not available → CPU path. ({e})")
2902
+
2903
+ if force_cpu:
2904
+ status_cb("⚠️ CPU-only debug mode: disabling PyTorch path.")
2905
+ TORCH_OK = False
2906
+
2907
+ _process_gui_events_safely()
2908
+
2909
+ # ---- PSFs + optional assets (computed in parallel, streaming I/O) ----
2910
+ psfs, masks_auto, vars_auto = _build_psf_and_assets(
2911
+ paths,
2912
+ make_masks=bool(use_star_masks),
2913
+ make_varmaps=bool(use_variance_maps),
2914
+ status_cb=status_cb,
2915
+ save_dir=None,
2916
+ star_mask_cfg=star_mask_cfg,
2917
+ varmap_cfg=varmap_cfg,
2918
+ star_mask_ref_path=star_mask_ref_path,
2919
+ # NEW:
2920
+ Ht=Ht, Wt=Wt, color_mode=color_mode,
2921
+ )
2922
+
2923
+ try:
2924
+ import gc as _gc
2925
+ _gc.collect()
2926
+ except Exception:
2927
+ pass
2928
+
2929
+ psfs_native = psfs # keep a reference to the original native PSFs for fallback
2930
+
2931
+ # ---- SR lift of PSFs if needed ----
2932
+ r_req = _coerce_sr_factor(super_res_factor, default_on_bad=2)
2933
+ status_cb(f"MFDeconv: SR factor requested={super_res_factor!r} → using r={r_req}")
2934
+ r = int(r_req)
2935
+
2936
+ if r > 1:
2937
+ status_cb(f"MFDeconv: Super-resolution r={r} with σ={sr_sigma} — solving SR PSFs…")
2938
+ _process_gui_events_safely()
2939
+
2940
+ def _naive_sr_from_native(f_nat: np.ndarray, r_: int) -> np.ndarray:
2941
+ f2 = np.asarray(f_nat, np.float32)
2942
+ H, W = f2.shape[:2]
2943
+ k_sq = min(H, W)
2944
+ if H != W:
2945
+ y0 = (H - k_sq) // 2; x0 = (W - k_sq) // 2
2946
+ f2 = f2[y0:y0+k_sq, x0:x0+k_sq]
2947
+ if (f2.shape[0] % 2) == 0:
2948
+ f2 = f2[1:, 1:] # make odd
2949
+ f2 = _normalize_psf(f2)
2950
+ h0 = np.zeros((f2.shape[0]*r_, f2.shape[1]*r_), np.float32)
2951
+ h0[::r_, ::r_] = f2
2952
+ return _normalize_psf(h0)
2953
+
2954
+ sr_psfs = []
2955
+ for i, k_native in enumerate(psfs, start=1):
2956
+ try:
2957
+ status_cb(f" SR-PSF{i}: native shape={np.asarray(k_native).shape}")
2958
+ h = _solve_super_psf_from_native(
2959
+ k_native, r=r, sigma=float(sr_sigma),
2960
+ iters=int(sr_psf_opt_iters), lr=float(sr_psf_opt_lr)
2961
+ )
2962
+ except Exception as e:
2963
+ status_cb(f" SR-PSF{i} failed: {e!r} → using naïve upsample")
2964
+ h = _naive_sr_from_native(k_native, r)
2965
+
2966
+ # guarantee odd size for downstream SAME-padding math
2967
+ if (h.shape[0] % 2) == 0:
2968
+ h = h[:-1, :-1]
2969
+ sr_psfs.append(h.astype(np.float32, copy=False))
2970
+ status_cb(f" SR-PSF{i}: native {np.asarray(k_native).shape[0]} → {h.shape[0]} (sum={h.sum():.6f})")
2971
+ psfs = sr_psfs
2972
+
2973
+
2974
+
2975
+ # ---- Seed (streaming) with robust bootstrap already in file helpers ----
2976
+ _emit_pct(0.25, "Calculating Seed Image...")
2977
+ def _seed_progress(frac, msg):
2978
+ _emit_pct(0.25 + 0.15 * float(frac), f"seed: {msg}")
2979
+
2980
+ seed_mode_s = str(seed_mode).lower().strip()
2981
+ if seed_mode_s not in ("robust", "median"):
2982
+ seed_mode_s = "robust"
2983
+ if seed_mode_s == "median":
2984
+ status_cb("MFDeconv: Building median seed (tiled, streaming)…")
2985
+ seed_native = _seed_median_streaming(
2986
+ paths, Ht, Wt,
2987
+ color_mode=color_mode,
2988
+ tile_hw=(256, 256),
2989
+ status_cb=status_cb,
2990
+ progress_cb=_seed_progress,
2991
+ use_torch=TORCH_OK, # ← auto: GPU if available, else NumPy
2992
+ )
2993
+ else:
2994
+ seed_native = _seed_bootstrap_streaming(
2995
+ paths, Ht, Wt, color_mode,
2996
+ bootstrap_frames=20, clip_sigma=5,
2997
+ status_cb=status_cb, progress_cb=_seed_progress
2998
+ )
2999
+
3000
+ # lift seed if SR
3001
+ if r > 1:
3002
+ target_hw = (Ht * r, Wt * r)
3003
+ if seed_native.ndim == 2:
3004
+ x = _upsample_sum(seed_native / (r*r), r, target_hw=target_hw)
3005
+ else:
3006
+ C, Hn, Wn = seed_native.shape
3007
+ x = np.stack(
3008
+ [_upsample_sum(seed_native[c] / (r*r), r, target_hw=target_hw) for c in range(C)],
3009
+ axis=0
3010
+ )
3011
+ else:
3012
+ x = seed_native
3013
+
3014
+ # FINAL SHAPE CHECKS (auto-correct if a GUI sent something odd)
3015
+ if x.ndim == 2: x = x[None, ...]
3016
+ Hs, Ws = x.shape[-2], x.shape[-1]
3017
+ if r > 1:
3018
+ expected_H, expected_W = Ht * r, Wt * r
3019
+ if (Hs, Ws) != (expected_H, expected_W):
3020
+ status_cb(f"SR seed grid mismatch: got {(Hs, Ws)}, expected {(expected_H, expected_W)} → correcting")
3021
+ # Rebuild from the native mean to ensure exact SR size
3022
+ x = _upsample_sum(x if x.ndim==2 else x[0], r, target_hw=(expected_H, expected_W))
3023
+ if x.ndim == 2: x = x[None, ...]
3024
+ try: del seed_native
3025
+ except Exception as e:
3026
+ import logging
3027
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3028
+
3029
+ try:
3030
+ import gc as _gc
3031
+ _gc.collect()
3032
+ except Exception:
3033
+ pass
3034
+
3035
+ flip_psf = [_flip_kernel(k) for k in psfs]
3036
+ _emit_pct(0.20, "PSF Ready")
3037
+
3038
+ # --- Harmonize seed channels with the actual frames ---
3039
+ # Probe the first frame's channel count (CHW from the cache)
3040
+ try:
3041
+ y0_probe = _load_frame_chw(0) # CHW float32
3042
+ C_ref = 1 if y0_probe.ndim == 2 else int(y0_probe.shape[0])
3043
+ except Exception:
3044
+ # Fallback: infer from seed if probing fails
3045
+ C_ref = 1 if x.ndim == 2 else int(x.shape[0])
3046
+
3047
+ # If median seed came back mono but frames are RGB, broadcast it
3048
+ if x.ndim == 2 and C_ref == 3:
3049
+ x = np.stack([x] * 3, axis=0)
3050
+ elif x.ndim == 3 and x.shape[0] == 1 and C_ref == 3:
3051
+ x = np.repeat(x, 3, axis=0)
3052
+ # If seed is RGB but frames are mono (rare), collapse to luma
3053
+ elif x.ndim == 3 and x.shape[0] == 3 and C_ref == 1:
3054
+ # ITU-R BT.709 luma (safe in float)
3055
+ x = (0.2126 * x[0] + 0.7152 * x[1] + 0.0722 * x[2]).astype(np.float32)
3056
+
3057
+ # Ensure CHW shape for the rest of the pipeline
3058
+ if x.ndim == 2:
3059
+ x = x[None, ...]
3060
+ elif x.ndim == 3 and x.shape[0] not in (1, 3) and x.shape[-1] in (1, 3):
3061
+ x = np.moveaxis(x, -1, 0)
3062
+
3063
+ # Set the expected channel count from the FRAMES, not from the seed
3064
+ C_EXPECTED = int(C_ref)
3065
+ _, Hs, Ws = x.shape
3066
+
3067
+
3068
+
3069
+ # ---- choose default batch size ----
3070
+ if batch_frames is None:
3071
+ px = Hs * Ws
3072
+ if px >= 16_000_000: auto_B = 2
3073
+ elif px >= 8_000_000: auto_B = 4
3074
+ else: auto_B = 8
3075
+ else:
3076
+ auto_B = int(max(1, batch_frames))
3077
+
3078
+ # --- LOW-MEM PATCH: clamp batch size hard ---
3079
+ if low_mem:
3080
+ auto_B = max(1, min(auto_B, 2))
3081
+
3082
+ # ---- background/MAD telemetry (first frame) ----
3083
+ status_cb("MFDeconv: Calculating Backgrounds and MADs…")
3084
+ _process_gui_events_safely()
3085
+ try:
3086
+ y0 = data[0]; y0l = y0 if y0.ndim == 2 else y0[0]
3087
+ med = float(np.median(y0l)); mad = float(np.median(np.abs(y0l - med))) + 1e-6
3088
+ bg_est = 1.4826 * mad
3089
+ except Exception:
3090
+ bg_est = 0.0
3091
+ status_cb(f"MFDeconv: color_mode={color_mode}, huber_delta={huber_delta} (bg RMS~{bg_est:.3g})")
3092
+
3093
+ # ---- mask/variance accessors ----
3094
+ def _mask_for(i, like_img):
3095
+ src = (masks if masks is not None else masks_auto)
3096
+ if src is None: # no masks at all
3097
+ return np.ones((like_img.shape[-2], like_img.shape[-1]), dtype=np.float32)
3098
+ m = src[i]
3099
+ if m is None:
3100
+ return np.ones((like_img.shape[-2], like_img.shape[-1]), dtype=np.float32)
3101
+ m = np.asarray(m, dtype=np.float32)
3102
+ if m.ndim == 3: m = m[0]
3103
+ return _center_crop(m, like_img.shape[-2], like_img.shape[-1]).astype(np.float32, copy=False)
3104
+
3105
+ def _var_for(i, like_img):
3106
+ src = (variances if variances is not None else vars_auto)
3107
+ if src is None: return None
3108
+ v = src[i]
3109
+ if v is None: return None
3110
+ v = np.asarray(v, dtype=np.float32)
3111
+ if v.ndim == 3: v = v[0]
3112
+ v = _center_crop(v, like_img.shape[-2], like_img.shape[-1])
3113
+ return np.clip(v, 1e-8, None).astype(np.float32, copy=False)
3114
+
3115
+ # ---- NumPy path conv helper (keep as-is) ----
3116
+ def _conv_np_same(a, k, out=None):
3117
+ y = _conv_same_np_spatial(a, k, out)
3118
+ if y is not None:
3119
+ return y
3120
+ # No OpenCV → always use the ifftshifted FFT path
3121
+ import numpy as _np
3122
+ import numpy.fft as _fft
3123
+ H, W = a.shape[-2:]
3124
+ kh, kw = k.shape
3125
+ fftH, fftW = _fftshape_same(H, W, kh, kw)
3126
+ Kf = _fft.rfftn(_np.fft.ifftshift(k), s=(fftH, fftW))
3127
+ if out is None:
3128
+ out = _np.empty_like(a, dtype=_np.float32)
3129
+ return _fft_conv_same_np(a, Kf, kh, kw, fftH, fftW, out)
3130
+
3131
+ # ---- allocate scratch & prepare PSF tensors if torch ----
3132
+ relax = 0.7
3133
+ use_torch = bool(TORCH_OK)
3134
+ cm = _safe_inference_context() if use_torch else NO_GRAD
3135
+ rho_is_l2 = (str(rho).lower() == "l2")
3136
+ local_delta = 0.0 if rho_is_l2 else huber_delta
3137
+
3138
+ if use_torch:
3139
+ F = torch.nn.functional
3140
+ device = _torch_device()
3141
+
3142
+ # Force FP32 tensors
3143
+ x_t = _to_t(_contig(x)).to(torch.float32)
3144
+ num = torch.zeros_like(x_t, dtype=torch.float32)
3145
+ den = torch.zeros_like(x_t, dtype=torch.float32)
3146
+
3147
+ # channels-last preference kept (does not change dtype)
3148
+ if use_channels_last is None:
3149
+ use_channels_last = bool(cuda_ok) # <- never enable on MPS
3150
+ if mps_ok:
3151
+ use_channels_last = False # <- force NCHW on MPS
3152
+
3153
+ # PSF tensors strictly FP32
3154
+ psf_t = [_to_t(_contig(k))[None, None] for k in psfs] # (1,1,kh,kw)
3155
+ psfT_t = [_to_t(_contig(kT))[None, None] for kT in flip_psf]
3156
+ use_spatial = _torch_should_use_spatial(psf_t[0].shape[-1])
3157
+
3158
+ # No mixed precision, no autocast
3159
+ use_amp = False
3160
+ amp_cm = None
3161
+ amp_kwargs = {}
3162
+
3163
+ # FFT gate (worth it for large kernels / SR); still FP32
3164
+ # Decide FFT vs spatial FIRST, based on native kernel sizes (no global padding!)
3165
+ K_native_max = max(int(k.shape[0]) for k in psfs)
3166
+ use_fft = False
3167
+ if device.type == "cuda" and ((K_native_max >= int(fft_kernel_threshold)) or (r > 1 and K_native_max >= max(21, int(fft_kernel_threshold) - 4))):
3168
+ use_fft = True
3169
+ # Force NCHW for FFT branch to keep channel slices contiguous/sane
3170
+ use_channels_last = False
3171
+ # Precompute FFT packs…
3172
+ psf_fft, psfT_fft = _precompute_torch_psf_ffts(
3173
+ psfs, flip_psf, Hs, Ws, device=x_t.device, dtype=torch.float32
3174
+ )
3175
+ else:
3176
+ psf_fft = psfT_fft = None
3177
+ Kmax = K_native_max
3178
+ if (Kmax % 2) == 0:
3179
+ Kmax += 1
3180
+ if any(int(k.shape[0]) != Kmax for k in psfs):
3181
+ status_cb(f"MFDeconv: normalizing PSF sizes → {Kmax}×{Kmax}")
3182
+ psfs = [_pad_kernel_to(k, Kmax) for k in psfs]
3183
+ flip_psf = [_flip_kernel(k) for k in psfs]
3184
+
3185
+ # (Re)build spatial kernels strictly as contiguous FP32 tensors
3186
+ psf_t = [_to_t(_contig(k))[None, None].to(torch.float32).contiguous() for k in psfs]
3187
+ psfT_t = [_to_t(_contig(kT))[None, None].to(torch.float32).contiguous() for kT in flip_psf]
3188
+ else:
3189
+ x_t = _contig(x).astype(np.float32, copy=False)
3190
+ num = np.zeros_like(x_t, dtype=np.float32)
3191
+ den = np.zeros_like(x_t, dtype=np.float32)
3192
+ use_amp = False
3193
+ use_fft = False
3194
+
3195
+ # ---- batched torch helper (grouped depthwise per-sample) ----
3196
+ if use_torch:
3197
+ # inside `if use_torch:` block in multiframe_deconv — replace the whole inner helper
3198
+ def _grouped_conv_same_torch_per_sample(x_bc_hw, w_b1kk, B, C):
3199
+ """
3200
+ x_bc_hw : (B,C,H,W), torch.float32 on device
3201
+ w_b1kk : (B,1,kh,kw), torch.float32 on device
3202
+ Returns (B,C,H,W) contiguous (NCHW).
3203
+ """
3204
+ F = torch.nn.functional
3205
+
3206
+ # Force standard NCHW contiguous tensors
3207
+ x_bc_hw = x_bc_hw.to(memory_format=torch.contiguous_format).contiguous()
3208
+ w_b1kk = w_b1kk.to(memory_format=torch.contiguous_format).contiguous()
3209
+
3210
+ kh, kw = int(w_b1kk.shape[-2]), int(w_b1kk.shape[-1])
3211
+ pad = (kw // 2, kw - kw // 2 - 1, kh // 2, kh - kh // 2 - 1)
3212
+
3213
+ if x_bc_hw.device.type == "mps":
3214
+ # Safe, slower path: convolve each channel separately, no groups
3215
+ ys = []
3216
+ for j in range(B): # per sample
3217
+ xj = x_bc_hw[j:j+1] # (1,C,H,W)
3218
+ # reflect pad once per sample
3219
+ xj = F.pad(xj, pad, mode="reflect")
3220
+ cj_out = []
3221
+ # one shared kernel per sample j: (1,1,kh,kw)
3222
+ kj = w_b1kk[j:j+1] # keep shape (1,1,kh,kw)
3223
+ for c in range(C):
3224
+ # slice that channel as its own (1,1,H,W) tensor
3225
+ xjc = xj[:, c:c+1, ...]
3226
+ yjc = F.conv2d(xjc, kj, padding=0, groups=1) # no groups
3227
+ cj_out.append(yjc)
3228
+ ys.append(torch.cat(cj_out, dim=1)) # (1,C,H,W)
3229
+ return torch.stack([y[0] for y in ys], 0).contiguous()
3230
+
3231
+
3232
+ # ---- FAST PATH (CUDA/CPU): single grouped conv with G=B*C ----
3233
+ G = int(B * C)
3234
+ x_1ghw = x_bc_hw.reshape(1, G, x_bc_hw.shape[-2], x_bc_hw.shape[-1])
3235
+ x_1ghw = F.pad(x_1ghw, pad, mode="reflect")
3236
+ w_g1kk = w_b1kk.repeat_interleave(C, dim=0) # (G,1,kh,kw)
3237
+ y_1ghw = F.conv2d(x_1ghw, w_g1kk, padding=0, groups=G)
3238
+ return y_1ghw.reshape(B, C, y_1ghw.shape[-2], y_1ghw.shape[-1]).contiguous()
3239
+
3240
+
3241
+ def _downsample_avg_bt_t(x, r_):
3242
+ if r_ <= 1:
3243
+ return x
3244
+ B, C, H, W = x.shape
3245
+ Hr, Wr = (H // r_) * r_, (W // r_) * r_
3246
+ if Hr == 0 or Wr == 0:
3247
+ return x
3248
+ return x[:, :, :Hr, :Wr].reshape(B, C, Hr // r_, r_, Wr // r_, r_).mean(dim=(3, 5))
3249
+
3250
+ def _upsample_sum_bt_t(x, r_):
3251
+ if r_ <= 1:
3252
+ return x
3253
+ return x.repeat_interleave(r_, dim=-2).repeat_interleave(r_, dim=-1)
3254
+
3255
def _make_pinned_batch(idx, C_expected, to_device_dtype):
    """Load frames `idx` (plus masks / variance maps) and ship them to the
    device as float32 NCHW batches.

    Args:
        idx: iterable of frame indices.
        C_expected: required channel count, or None to skip the check.
        to_device_dtype: unused (kept for call-site compatibility); data is
            always moved as float32.

    Returns:
        (yb, mb, vb): image batch, mask batch, and variance batch on the
        device; vb is None unless every frame supplied a variance map.

    Raises:
        RuntimeError: on a non-CHW frame or a channel-count mismatch.
    """
    frames, masks, variances = [], [], []
    for fi in idx:
        y_chw = _load_frame_chw(fi)  # CHW float32 from cache
        if y_chw.ndim != 3:
            raise RuntimeError(f"Frame {fi}: expected CHW, got {tuple(y_chw.shape)}")
        C_here = int(y_chw.shape[0])
        if C_expected is not None and C_here != C_expected:
            raise RuntimeError(f"Mixed channel counts: expected C={C_expected}, got C={C_here} (frame {fi})")
        frames.append(y_chw)
        masks.append(_mask_for(fi, y_chw))
        variances.append(_var_for(fi, y_chw))

    # Stack on the CPU first (NCHW, float32, contiguous).
    y_cpu = torch.from_numpy(np.stack(frames, 0)).to(torch.float32).contiguous()
    m_cpu = torch.from_numpy(np.stack(masks, 0)).to(torch.float32).contiguous()
    vb_cpu = None
    if all(v is not None for v in variances):
        vb_cpu = torch.from_numpy(np.stack(variances, 0)).to(torch.float32).contiguous()

    # Pinned host memory makes the non_blocking H2D copies below truly
    # asynchronous on CUDA.
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        y_cpu = y_cpu.pin_memory()
        m_cpu = m_cpu.pin_memory()
        if vb_cpu is not None:
            vb_cpu = vb_cpu.pin_memory()

    # Move to device in FP32, keeping NCHW contiguous layout.
    yb = y_cpu.to(x_t.device, dtype=torch.float32, non_blocking=True).contiguous()
    mb = m_cpu.to(x_t.device, dtype=torch.float32, non_blocking=True).contiguous()
    vb = None
    if vb_cpu is not None:
        vb = vb_cpu.to(x_t.device, dtype=torch.float32, non_blocking=True).contiguous()
    return yb, mb, vb
3286
+
3287
+
3288
+
3289
+ # ---- intermediates folder ----
3290
+ iter_dir = None
3291
+ hdr0_seed = None
3292
+ if save_intermediate:
3293
+ iter_dir = _iter_folder(out_path)
3294
+ status_cb(f"MFDeconv: Intermediate outputs → {iter_dir}")
3295
+ try:
3296
+ hdr0 = _safe_primary_header(paths[0])
3297
+ except Exception:
3298
+ hdr0 = fits.Header()
3299
+ _save_iter_image(x, hdr0, iter_dir, "seed", color_mode)
3300
+
3301
+ # ---- iterative loop ----
3302
+
3303
+
3304
+
3305
+ auto_delta_cache = None
3306
+ if use_torch and (huber_delta < 0) and (not rho_is_l2):
3307
+ auto_delta_cache = [None] * n_frames
3308
+ early = EarlyStopper(
3309
+ tol_upd_floor=2e-4,
3310
+ tol_rel_floor=5e-4,
3311
+ early_frac=0.40,
3312
+ ema_alpha=0.5,
3313
+ patience=2,
3314
+ min_iters=min_iters,
3315
+ status_cb=status_cb
3316
+ )
3317
+
3318
+ used_iters = 0
3319
+ early_stopped = False
3320
+ with cm():
3321
+ for it in range(1, max_iters + 1):
3322
+ # reset accumulators
3323
+ if use_torch:
3324
+ num.zero_(); den.zero_()
3325
+ else:
3326
+ num.fill(0.0); den.fill(0.0)
3327
+
3328
+ if use_torch:
3329
+ # ---- batched GPU path (hardened with busy→CPU fallback) ----
3330
+ retry_cpu = False # if True, we’ll redo this iteration’s accumulators on CPU
3331
+
3332
+ try:
3333
+ frame_idx = list(range(n_frames))
3334
+ B_cur = int(max(1, auto_B))
3335
+ ci = 0
3336
+ Cn = None
3337
+
3338
+ # AMP context helper
3339
def _maybe_amp():
    """AMP autocast context when enabled, otherwise a do-nothing context."""
    from contextlib import nullcontext
    if not use_amp or amp_cm is None:
        return nullcontext()
    return amp_cm(**amp_kwargs)
3344
+
3345
+ # prefetch first batch (unchanged)
3346
+ if prefetch_batches:
3347
+ idx0 = frame_idx[ci:ci+B_cur]
3348
+ if idx0:
3349
+ Cn = C_EXPECTED
3350
+ yb_next, mb_next, vb_next = _make_pinned_batch(idx0, Cn, x_t.dtype)
3351
+ else:
3352
+ yb_next = mb_next = vb_next = None
3353
+
3354
+ while ci < n_frames:
3355
+ idx = frame_idx[ci:ci + B_cur]
3356
+ B = len(idx)
3357
+ try:
3358
+ # ---- existing gather/conv/weights/backproject code (UNCHANGED) ----
3359
+ if prefetch_batches and (yb_next is not None):
3360
+ yb, mb, vb = yb_next, mb_next, vb_next
3361
+ ci2 = ci + B_cur
3362
+ if ci2 < n_frames:
3363
+ idx2 = frame_idx[ci2:ci2+B_cur]
3364
+ Cn = Cn or (_as_chw(_sanitize_numeric(data[idx[0]])).shape[0])
3365
+ yb_next, mb_next, vb_next = _make_pinned_batch(idx2, Cn, torch.float32)
3366
+ else:
3367
+ yb_next = mb_next = vb_next = None
3368
+ else:
3369
+ Cn = C_EXPECTED
3370
+ yb, mb, vb = _make_pinned_batch(idx, Cn, torch.float32)
3371
+ if use_channels_last:
3372
+ yb = yb.contiguous(memory_format=torch.channels_last)
3373
+
3374
+ wk = torch.cat([psf_t[fi] for fi in idx], dim=0).to(memory_format=torch.contiguous_format).contiguous()
3375
+ wkT = torch.cat([psfT_t[fi] for fi in idx], dim=0).to(memory_format=torch.contiguous_format).contiguous()
3376
+
3377
+ # --- predict on SR grid ---
3378
+ x_bc_hw = x_t.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
3379
+ if use_channels_last:
3380
+ x_bc_hw = x_bc_hw.contiguous(memory_format=torch.channels_last)
3381
+
3382
+ if use_fft:
3383
+ pred_list = []
3384
+ chan_tmp = torch.empty_like(x_t[0], dtype=torch.float32) # (H,W) single-channel
3385
+ for j, fi in enumerate(idx):
3386
+ C_here = x_bc_hw.shape[1]
3387
+ pred_j = torch.empty_like(x_t, dtype=torch.float32) # (C,H,W)
3388
+ Kf_pack = psf_fft[fi]
3389
+ for c in range(C_here):
3390
+ _fft_conv_same_torch(x_bc_hw[j, c], Kf_pack, out_spatial=chan_tmp) # write into chan_tmp
3391
+ pred_j[c].copy_(chan_tmp) # copy once
3392
+ pred_list.append(pred_j)
3393
+ pred_super = torch.stack(pred_list, 0).contiguous()
3394
+ else:
3395
+ pred_super = _grouped_conv_same_torch_per_sample(
3396
+ x_bc_hw, wk, B, Cn
3397
+ ).contiguous()
3398
+
3399
+ pred_low = _downsample_avg_bt_t(pred_super, r) if r > 1 else pred_super
3400
+
3401
+ # --- robust weights (UNCHANGED) ---
3402
+ rnat = yb - pred_low
3403
+
3404
+ # Build/estimate variance map on the native grid (per batch)
3405
+ if vb is None:
3406
+ # robust per-batch variance estimate
3407
+ med, mad = _robust_med_mad_t(rnat, max_elems_per_sample=2_000_000)
3408
+ vmap = (1.4826 * mad) ** 2
3409
+ # repeat across channels
3410
+ vmap = vmap.repeat(1, Cn, rnat.shape[-2], rnat.shape[-1]).contiguous()
3411
+ else:
3412
+ vmap = vb.unsqueeze(1).repeat(1, Cn, 1, 1).contiguous()
3413
+
3414
+ if str(rho).lower() == "l2":
3415
+ # L2 ⇒ psi/r = 1 (no robust clipping). Weighted LS with mask/variance.
3416
+ psi_over_r = torch.ones_like(rnat, dtype=torch.float32, device=x_t.device)
3417
+ else:
3418
+ # Huber ⇒ auto-delta or fixed delta
3419
+ if huber_delta < 0:
3420
+ # ensure we have auto deltas
3421
+ if (auto_delta_cache is None) or any(auto_delta_cache[fi] is None for fi in idx) or (it % 5 == 1):
3422
+ Btmp, C_here, H0, W0 = rnat.shape
3423
+ med, mad = _robust_med_mad_t(rnat, max_elems_per_sample=2_000_000)
3424
+ rms = 1.4826 * mad
3425
+ # store per-frame deltas
3426
+ if auto_delta_cache is None:
3427
+ # should not happen because we gated creation earlier, but be safe
3428
+ auto_delta_cache = [None] * n_frames
3429
+ for j, fi in enumerate(idx):
3430
+ auto_delta_cache[fi] = float((-huber_delta) * torch.clamp(rms[j, 0, 0, 0], min=1e-6).item())
3431
+ deltas = torch.tensor([auto_delta_cache[fi] for fi in idx],
3432
+ device=x_t.device, dtype=torch.float32).view(B, 1, 1, 1)
3433
+ else:
3434
+ deltas = torch.tensor(float(huber_delta), device=x_t.device,
3435
+ dtype=torch.float32).view(1, 1, 1, 1)
3436
+
3437
+ absr = rnat.abs()
3438
+ psi_over_r = torch.where(absr <= deltas, torch.ones_like(absr, dtype=torch.float32),
3439
+ deltas / (absr + EPS))
3440
+
3441
+ # compose weights with mask and variance
3442
+ if DEBUG_FLAT_WEIGHTS:
3443
+ wmap_low = mb.unsqueeze(1) # debug: mask-only weighting
3444
+ else:
3445
+ m1 = mb.unsqueeze(1).repeat(1, Cn, 1, 1).contiguous()
3446
+ wmap_low = (psi_over_r / (vmap + EPS)) * m1
3447
+ wmap_low = torch.nan_to_num(wmap_low, nan=0.0, posinf=0.0, neginf=0.0)
3448
+
3449
+ # --- adjoint + backproject (UNCHANGED) ---
3450
+ if r > 1:
3451
+ up_y = _upsample_sum_bt_t(wmap_low * yb, r)
3452
+ up_pred = _upsample_sum_bt_t(wmap_low * pred_low, r)
3453
+ else:
3454
+ up_y, up_pred = wmap_low * yb, wmap_low * pred_low
3455
+
3456
+ if use_fft:
3457
+ back_num_list, back_den_list = [], []
3458
+ for j, fi in enumerate(idx):
3459
+ C_here = up_y.shape[1]
3460
+ bn_j = torch.empty_like(x_t, dtype=torch.float32)
3461
+ bd_j = torch.empty_like(x_t, dtype=torch.float32)
3462
+ KTf_pack = psfT_fft[fi]
3463
+ for c in range(C_here):
3464
+ _fft_conv_same_torch(up_y[j, c], KTf_pack, out_spatial=bn_j[c])
3465
+ _fft_conv_same_torch(up_pred[j, c], KTf_pack, out_spatial=bd_j[c])
3466
+ back_num_list.append(bn_j)
3467
+ back_den_list.append(bd_j)
3468
+ back_num = torch.stack(back_num_list, 0).sum(dim=0)
3469
+ back_den = torch.stack(back_den_list, 0).sum(dim=0)
3470
+ else:
3471
+ back_num = _grouped_conv_same_torch_per_sample(up_y, wkT, B, Cn).sum(dim=0)
3472
+ back_den = _grouped_conv_same_torch_per_sample(up_pred, wkT, B, Cn).sum(dim=0)
3473
+
3474
+ num += back_num
3475
+ den += back_den
3476
+ ci += B
3477
+
3478
+ if os.environ.get("MFDECONV_DEBUG_SYNC", "0") == "1":
3479
+ try:
3480
+ torch.cuda.synchronize()
3481
+ except Exception:
3482
+ pass
3483
+
3484
+ except RuntimeError as e:
3485
+ emsg = str(e).lower()
3486
+ # existing OOM backoff stays the same
3487
+ if ("out of memory" in emsg or "resource" in emsg or "alloc" in emsg) and B_cur > 1:
3488
+ B_cur = max(1, B_cur // 2)
3489
+ status_cb(f"GPU OOM: reducing batch_frames → {B_cur} and retrying this chunk.")
3490
+ if prefetch_batches:
3491
+ yb_next = mb_next = vb_next = None
3492
+ continue
3493
+ # NEW: busy/unavailable → trigger CPU retry for this iteration
3494
+ if _is_cuda_busy_error(e):
3495
+ status_cb("CUDA became unavailable mid-run → retrying this iteration on CPU and switching to CPU for the rest.")
3496
+ retry_cpu = True
3497
+ break # break the inner while; we'll recompute on CPU
3498
+ raise
3499
+
3500
+ _process_gui_events_safely()
3501
+
3502
+ except RuntimeError as e:
3503
+ # Catch outer-level CUDA busy/unavailable too (e.g., from syncs)
3504
+ if _is_cuda_busy_error(e):
3505
+ status_cb("CUDA error indicates device busy/unavailable → retrying this iteration on CPU and switching to CPU for the rest.")
3506
+ retry_cpu = True
3507
+ else:
3508
+ raise
3509
+
3510
+ if retry_cpu:
3511
+ # flip to CPU for this and subsequent iterations
3512
+ globals()["cuda_usable"] = False
3513
+ use_torch = False
3514
+ try:
3515
+ torch.cuda.empty_cache()
3516
+ except Exception:
3517
+ pass
3518
+
3519
+ try:
3520
+ import gc as _gc
3521
+ _gc.collect()
3522
+ except Exception:
3523
+ pass
3524
+
3525
+ # Rebuild accumulators on CPU for this iteration:
3526
+ # 1) convert x_t (seed/current estimate) to NumPy
3527
+ x_t = x_t.detach().cpu().numpy()
3528
+ # 2) reset accumulators as NumPy arrays
3529
+ num = np.zeros_like(x_t, dtype=np.float32)
3530
+ den = np.zeros_like(x_t, dtype=np.float32)
3531
+
3532
+ # 3) RUN THE EXISTING NUMPY ACCUMULATION (same as your else: block below)
3533
+ # 3) recompute this iteration’s accumulators on CPU
3534
+ for fi in range(n_frames):
3535
+ # load one frame as CHW (float32, sanitized)
3536
+
3537
+ y_chw = _load_frame_chw(fi) # CHW float32 from cache
3538
+
3539
+ # forward predict (super grid) with this frame's PSF
3540
+ pred_super = _conv_np_same(x_t, psfs[fi])
3541
+
3542
+ # downsample to native if SR was on
3543
+ pred_low = _downsample_avg(pred_super, r) if r > 1 else pred_super
3544
+
3545
+ # per-frame mask/variance
3546
+ m2d = _mask_for(fi, y_chw) # 2D, [0..1]
3547
+ v2d = _var_for(fi, y_chw) # 2D or None
3548
+
3549
+ # robust weights per pixel/channel
3550
+ w = _weight_map(
3551
+ y=y_chw, pred=pred_low,
3552
+ huber_delta=local_delta, # 0.0 for L2, else huber_delta
3553
+ var_map=v2d, mask=m2d
3554
+ ).astype(np.float32, copy=False)
3555
+
3556
+ # adjoint (backproject) on super grid
3557
+ if r > 1:
3558
+ up_y = _upsample_sum(w * y_chw, r, target_hw=x_t.shape[-2:])
3559
+ up_pred = _upsample_sum(w * pred_low, r, target_hw=x_t.shape[-2:])
3560
+ else:
3561
+ up_y, up_pred = (w * y_chw), (w * pred_low)
3562
+
3563
+ num += _conv_np_same(up_y, flip_psf[fi])
3564
+ den += _conv_np_same(up_pred, flip_psf[fi])
3565
+
3566
+ # ensure strictly positive denominator
3567
+ den = np.clip(den, 1e-8, None).astype(np.float32, copy=False)
3568
+
3569
+ # switch everything to NumPy for the remainder of the run
3570
+ psf_fft = psfT_fft = None # (Torch packs no longer used)
3571
+
3572
+ # ---- multiplicative RL/MM step with clamping ----
3573
+ if use_torch:
3574
+ ratio = num / (den + EPS)
3575
+ neutral = (den.abs() < 1e-12) & (num.abs() < 1e-12)
3576
+ ratio = torch.where(neutral, torch.ones_like(ratio), ratio)
3577
+ upd = torch.clamp(ratio, 1.0 / kappa, kappa)
3578
+ x_next = torch.clamp(x_t * upd, min=0.0)
3579
+ # Robust scalars
3580
+ upd_med = torch.median(torch.abs(upd - 1))
3581
+ rel_change = (torch.median(torch.abs(x_next - x_t)) /
3582
+ (torch.median(torch.abs(x_t)) + 1e-8))
3583
+
3584
+ um = float(upd_med.detach().item())
3585
+ rc = float(rel_change.detach().item())
3586
+
3587
+ if early.step(it, max_iters, um, rc):
3588
+ x_t = x_next
3589
+ used_iters = it
3590
+ early_stopped = True
3591
+ status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
3592
+ _process_gui_events_safely()
3593
+ break
3594
+
3595
+ x_t = (1.0 - relax) * x_t + relax * x_next
3596
+ else:
3597
+ ratio = num / (den + EPS)
3598
+ neutral = (np.abs(den) < 1e-12) & (np.abs(num) < 1e-12)
3599
+ if np.any(neutral): ratio[neutral] = 1.0
3600
+ upd = np.clip(ratio, 1.0 / kappa, kappa)
3601
+ x_next = np.clip(x_t * upd, 0.0, None)
3602
+
3603
+ um = float(np.median(np.abs(upd - 1.0)))
3604
+ rc = float(np.median(np.abs(x_next - x_t)) / (np.median(np.abs(x_t)) + 1e-8))
3605
+
3606
+ if early.step(it, max_iters, um, rc):
3607
+ x_t = x_next
3608
+ used_iters = it
3609
+ early_stopped = True
3610
+ status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
3611
+ _process_gui_events_safely()
3612
+ break
3613
+
3614
+ x_t = (1.0 - relax) * x_t + relax * x_next
3615
+
3616
+ # --- LOW-MEM CLEANUP (per-iteration) ---
3617
+ if low_mem:
3618
+ # Torch temporaries we created in the iteration (best-effort deletes)
3619
+ to_kill = [
3620
+ "pred_super", "pred_low", "wmap_low", "yb", "mb", "vb", "wk", "wkT",
3621
+ "back_num", "back_den", "pred_list", "back_num_list", "back_den_list",
3622
+ "x_bc_hw", "up_y", "up_pred", "psi_over_r", "vmap", "rnat", "deltas", "chan_tmp"
3623
+ ]
3624
+ loc = locals()
3625
+ for _name in to_kill:
3626
+ if _name in loc:
3627
+ try:
3628
+ del loc[_name]
3629
+ except Exception:
3630
+ pass
3631
+
3632
+ # Proactively release CUDA cache every other iter
3633
+ if use_torch:
3634
+ try:
3635
+ dev = _torch_device()
3636
+ if dev.type == "cuda" and (it % 2) == 0:
3637
+ import torch as _t
3638
+ _t.cuda.empty_cache()
3639
+ except Exception:
3640
+ pass
3641
+
3642
+ # Encourage Python to return big NumPy buffers to the OS sooner
3643
+ import gc as _gc
3644
+ _gc.collect()
3645
+
3646
+
3647
+ # ---- save intermediates ----
3648
+ if save_intermediate and (it % int(max(1, save_every)) == 0):
3649
+ try:
3650
+ x_np = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
3651
+ _save_iter_image(x_np, hdr0_seed, iter_dir, f"iter_{it:03d}", color_mode)
3652
+ except Exception as _e:
3653
+ status_cb(f"Intermediate save failed at iter {it}: {_e}")
3654
+
3655
+ frac = 0.25 + 0.70 * (it / float(max_iters))
3656
+ _emit_pct(frac, f"Iteration {it}/{max_iters}")
3657
+ status_cb(f"Iter {it}/{max_iters}")
3658
+ _process_gui_events_safely()
3659
+
3660
+ if not early_stopped:
3661
+ used_iters = max_iters
3662
+
3663
+ # ---- save result ----
3664
+ _emit_pct(0.97, "saving")
3665
+ x_final = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
3666
+ if x_final.ndim == 3:
3667
+ if x_final.shape[0] not in (1, 3) and x_final.shape[-1] in (1, 3):
3668
+ x_final = np.moveaxis(x_final, -1, 0)
3669
+ if x_final.shape[0] == 1:
3670
+ x_final = x_final[0]
3671
+
3672
+ try:
3673
+ hdr0 = _safe_primary_header(paths[0])
3674
+ except Exception:
3675
+ hdr0 = fits.Header()
3676
+
3677
+ hdr0['MFDECONV'] = (True, 'Seti Astro multi-frame deconvolution')
3678
+ hdr0['MF_COLOR'] = (str(color_mode), 'Color mode used')
3679
+ hdr0['MF_RHO'] = (str(rho), 'Loss: huber|l2')
3680
+ hdr0['MF_HDEL'] = (float(huber_delta), 'Huber delta (>0 abs, <0 autoxRMS)')
3681
+ hdr0['MF_MASK'] = (bool(use_star_masks), 'Used auto star masks')
3682
+ hdr0['MF_VAR'] = (bool(use_variance_maps), 'Used auto variance maps')
3683
+ r = int(max(1, super_res_factor))
3684
+ hdr0['MF_SR'] = (int(r), 'Super-resolution factor (1 := native)')
3685
+ if r > 1:
3686
+ hdr0['MF_SRSIG'] = (float(sr_sigma), 'Gaussian sigma for SR PSF fit (native px)')
3687
+ hdr0['MF_SRIT'] = (int(sr_psf_opt_iters), 'SR-PSF solver iters')
3688
+ hdr0['MF_ITMAX'] = (int(max_iters), 'Requested max iterations')
3689
+ hdr0['MF_ITERS'] = (int(used_iters), 'Actual iterations run')
3690
+ hdr0['MF_ESTOP'] = (bool(early_stopped), 'Early stop triggered')
3691
+
3692
+ if isinstance(x_final, np.ndarray):
3693
+ if x_final.ndim == 2:
3694
+ hdr0['MF_SHAPE'] = (f"{x_final.shape[0]}x{x_final.shape[1]}", 'Saved as 2D image (HxW)')
3695
+ elif x_final.ndim == 3:
3696
+ C, H, W = x_final.shape
3697
+ hdr0['MF_SHAPE'] = (f"{C}x{H}x{W}", 'Saved as 3D cube (CxHxW)')
3698
+
3699
+ save_path = _sr_out_path(out_path, super_res_factor)
3700
+ safe_out_path = _nonclobber_path(str(save_path))
3701
+ if safe_out_path != str(save_path):
3702
+ status_cb(f"Output exists — saving as: {safe_out_path}")
3703
+ fits.PrimaryHDU(data=x_final, header=hdr0).writeto(safe_out_path, overwrite=False)
3704
+
3705
+ status_cb(f"✅ MFDeconv saved: {safe_out_path} (iters used: {used_iters}{', early stop' if early_stopped else ''})")
3706
+ _emit_pct(1.00, "done")
3707
+ _process_gui_events_safely()
3708
+
3709
+ try:
3710
+ if use_torch:
3711
+ try: del num, den
3712
+ except Exception as e:
3713
+ import logging
3714
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3715
+ try: del psf_t, psfT_t
3716
+ except Exception as e:
3717
+ import logging
3718
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3719
+ _free_torch_memory()
3720
+ except Exception:
3721
+ pass
3722
+
3723
+ try:
3724
+ _clear_all_caches()
3725
+ except Exception:
3726
+ pass
3727
+
3728
+ return safe_out_path
3729
+
3730
+ # -----------------------------
3731
+ # Worker
3732
+ # -----------------------------
3733
+
3734
+ class MultiFrameDeconvWorker(QObject):
3735
+ progress = pyqtSignal(str)
3736
+ finished = pyqtSignal(bool, str, str) # success, message, out_path
3737
+
3738
def __init__(self, parent, aligned_paths, output_path, iters, kappa, color_mode,
             huber_delta, min_iters, use_star_masks=False, use_variance_maps=False, rho="huber",
             star_mask_cfg: dict | None = None, varmap_cfg: dict | None = None,
             save_intermediate: bool = False,
             seed_mode: str = "robust",
             # NEW SR params
             super_res_factor: int = 1,
             sr_sigma: float = 1.1,
             sr_psf_opt_iters: int = 250,
             sr_psf_opt_lr: float = 0.1,
             star_mask_ref_path: str | None = None):
    """Capture every multiframe_deconv parameter for later use in run()."""
    super().__init__(parent)

    # Inputs / outputs
    self.aligned_paths = aligned_paths
    self.output_path = output_path

    # Core solver settings
    self.iters = iters
    self.kappa = kappa
    self.color_mode = color_mode
    self.huber_delta = huber_delta
    self.min_iters = min_iters  # NEW
    self.rho = rho
    self.seed_mode = seed_mode

    # Masking / variance options (default to empty configs)
    self.star_mask_cfg = star_mask_cfg or {}
    self.varmap_cfg = varmap_cfg or {}
    self.use_star_masks = use_star_masks
    self.use_variance_maps = use_variance_maps
    self.star_mask_ref_path = star_mask_ref_path

    # Output options
    self.save_intermediate = save_intermediate

    # Super-resolution settings, coerced to concrete numeric types
    self.super_res_factor = int(super_res_factor)
    self.sr_sigma = float(sr_sigma)
    self.sr_psf_opt_iters = int(sr_psf_opt_iters)
    self.sr_psf_opt_lr = float(sr_psf_opt_lr)
3770
+
3771
+
3772
def _log(self, s):
    """Forward a status line to the progress signal."""
    self.progress.emit(s)
3773
+
3774
def run(self):
    """Run multiframe_deconv with the captured settings.

    Emits finished(success, message, out_path) exactly once, then performs
    a best-effort hard cleanup (drop big refs, free torch / GPU memory).
    """
    try:
        solver_kwargs = dict(
            iters=self.iters,
            kappa=self.kappa,
            color_mode=self.color_mode,
            seed_mode=self.seed_mode,
            huber_delta=self.huber_delta,
            use_star_masks=self.use_star_masks,
            use_variance_maps=self.use_variance_maps,
            rho=self.rho,
            min_iters=self.min_iters,
            status_cb=self._log,
            star_mask_cfg=self.star_mask_cfg,
            varmap_cfg=self.varmap_cfg,
            save_intermediate=self.save_intermediate,
            super_res_factor=self.super_res_factor,
            sr_sigma=self.sr_sigma,
            sr_psf_opt_iters=self.sr_psf_opt_iters,
            sr_psf_opt_lr=self.sr_psf_opt_lr,
            star_mask_ref_path=self.star_mask_ref_path,
        )
        out_path = multiframe_deconv(self.aligned_paths, self.output_path, **solver_kwargs)
        self.finished.emit(True, "MF deconvolution complete.", out_path)
        _process_gui_events_safely()
    except Exception as e:
        self.finished.emit(False, f"MF deconvolution failed: {e}", "")
    finally:
        # Hard cleanup: drop references + free GPU memory.
        try:
            # Drop big Python refs that might keep tensors alive indirectly.
            self.aligned_paths = []
            self.star_mask_cfg = {}
            self.varmap_cfg = {}
        except Exception:
            pass
        try:
            _free_torch_memory()  # helper: del tensors, gc.collect(), etc.
        except Exception:
            pass
        try:
            import torch as _t
            if hasattr(_t, "cuda") and _t.cuda.is_available():
                _t.cuda.synchronize()
                _t.cuda.empty_cache()
            if hasattr(_t, "mps") and getattr(_t.backends, "mps", None) and _t.backends.mps.is_available():
                # PyTorch 2.x exposes this.
                if hasattr(_t.mps, "empty_cache"):
                    _t.mps.empty_cache()
            # DirectML usually frees on GC; nothing special to call.
        except Exception:
            pass