setiastrosuitepro-1.6.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in those public registries.

Potentially problematic release.


This version of setiastrosuitepro might be problematic.

Files changed (174)
  1. setiastro/__init__.py +2 -0
  2. setiastro/saspro/__init__.py +20 -0
  3. setiastro/saspro/__main__.py +784 -0
  4. setiastro/saspro/_generated/__init__.py +7 -0
  5. setiastro/saspro/_generated/build_info.py +2 -0
  6. setiastro/saspro/abe.py +1295 -0
  7. setiastro/saspro/abe_preset.py +196 -0
  8. setiastro/saspro/aberration_ai.py +694 -0
  9. setiastro/saspro/aberration_ai_preset.py +224 -0
  10. setiastro/saspro/accel_installer.py +218 -0
  11. setiastro/saspro/accel_workers.py +30 -0
  12. setiastro/saspro/add_stars.py +621 -0
  13. setiastro/saspro/astrobin_exporter.py +1007 -0
  14. setiastro/saspro/astrospike.py +153 -0
  15. setiastro/saspro/astrospike_python.py +1839 -0
  16. setiastro/saspro/autostretch.py +196 -0
  17. setiastro/saspro/backgroundneutral.py +560 -0
  18. setiastro/saspro/batch_convert.py +325 -0
  19. setiastro/saspro/batch_renamer.py +519 -0
  20. setiastro/saspro/blemish_blaster.py +488 -0
  21. setiastro/saspro/blink_comparator_pro.py +2923 -0
  22. setiastro/saspro/bundles.py +61 -0
  23. setiastro/saspro/bundles_dock.py +114 -0
  24. setiastro/saspro/cheat_sheet.py +168 -0
  25. setiastro/saspro/clahe.py +342 -0
  26. setiastro/saspro/comet_stacking.py +1377 -0
  27. setiastro/saspro/config.py +38 -0
  28. setiastro/saspro/config_bootstrap.py +40 -0
  29. setiastro/saspro/config_manager.py +316 -0
  30. setiastro/saspro/continuum_subtract.py +1617 -0
  31. setiastro/saspro/convo.py +1397 -0
  32. setiastro/saspro/convo_preset.py +414 -0
  33. setiastro/saspro/copyastro.py +187 -0
  34. setiastro/saspro/cosmicclarity.py +1564 -0
  35. setiastro/saspro/cosmicclarity_preset.py +407 -0
  36. setiastro/saspro/crop_dialog_pro.py +948 -0
  37. setiastro/saspro/crop_preset.py +189 -0
  38. setiastro/saspro/curve_editor_pro.py +2544 -0
  39. setiastro/saspro/curves_preset.py +375 -0
  40. setiastro/saspro/debayer.py +670 -0
  41. setiastro/saspro/debug_utils.py +29 -0
  42. setiastro/saspro/dnd_mime.py +35 -0
  43. setiastro/saspro/doc_manager.py +2634 -0
  44. setiastro/saspro/exoplanet_detector.py +2166 -0
  45. setiastro/saspro/file_utils.py +284 -0
  46. setiastro/saspro/fitsmodifier.py +744 -0
  47. setiastro/saspro/free_torch_memory.py +48 -0
  48. setiastro/saspro/frequency_separation.py +1343 -0
  49. setiastro/saspro/function_bundle.py +1594 -0
  50. setiastro/saspro/ghs_dialog_pro.py +660 -0
  51. setiastro/saspro/ghs_preset.py +284 -0
  52. setiastro/saspro/graxpert.py +634 -0
  53. setiastro/saspro/graxpert_preset.py +287 -0
  54. setiastro/saspro/gui/__init__.py +0 -0
  55. setiastro/saspro/gui/main_window.py +8494 -0
  56. setiastro/saspro/gui/mixins/__init__.py +33 -0
  57. setiastro/saspro/gui/mixins/dock_mixin.py +263 -0
  58. setiastro/saspro/gui/mixins/file_mixin.py +445 -0
  59. setiastro/saspro/gui/mixins/geometry_mixin.py +403 -0
  60. setiastro/saspro/gui/mixins/header_mixin.py +441 -0
  61. setiastro/saspro/gui/mixins/mask_mixin.py +421 -0
  62. setiastro/saspro/gui/mixins/menu_mixin.py +361 -0
  63. setiastro/saspro/gui/mixins/theme_mixin.py +367 -0
  64. setiastro/saspro/gui/mixins/toolbar_mixin.py +1324 -0
  65. setiastro/saspro/gui/mixins/update_mixin.py +309 -0
  66. setiastro/saspro/gui/mixins/view_mixin.py +435 -0
  67. setiastro/saspro/halobgon.py +462 -0
  68. setiastro/saspro/header_viewer.py +445 -0
  69. setiastro/saspro/headless_utils.py +88 -0
  70. setiastro/saspro/histogram.py +753 -0
  71. setiastro/saspro/history_explorer.py +939 -0
  72. setiastro/saspro/image_combine.py +414 -0
  73. setiastro/saspro/image_peeker_pro.py +1596 -0
  74. setiastro/saspro/imageops/__init__.py +37 -0
  75. setiastro/saspro/imageops/mdi_snap.py +292 -0
  76. setiastro/saspro/imageops/scnr.py +36 -0
  77. setiastro/saspro/imageops/starbasedwhitebalance.py +210 -0
  78. setiastro/saspro/imageops/stretch.py +244 -0
  79. setiastro/saspro/isophote.py +1179 -0
  80. setiastro/saspro/layers.py +208 -0
  81. setiastro/saspro/layers_dock.py +714 -0
  82. setiastro/saspro/lazy_imports.py +193 -0
  83. setiastro/saspro/legacy/__init__.py +2 -0
  84. setiastro/saspro/legacy/image_manager.py +2226 -0
  85. setiastro/saspro/legacy/numba_utils.py +3659 -0
  86. setiastro/saspro/legacy/xisf.py +1071 -0
  87. setiastro/saspro/linear_fit.py +534 -0
  88. setiastro/saspro/live_stacking.py +1830 -0
  89. setiastro/saspro/log_bus.py +5 -0
  90. setiastro/saspro/logging_config.py +460 -0
  91. setiastro/saspro/luminancerecombine.py +309 -0
  92. setiastro/saspro/main_helpers.py +201 -0
  93. setiastro/saspro/mask_creation.py +928 -0
  94. setiastro/saspro/masks_core.py +56 -0
  95. setiastro/saspro/mdi_widgets.py +353 -0
  96. setiastro/saspro/memory_utils.py +666 -0
  97. setiastro/saspro/metadata_patcher.py +75 -0
  98. setiastro/saspro/mfdeconv.py +3826 -0
  99. setiastro/saspro/mfdeconv_earlystop.py +71 -0
  100. setiastro/saspro/mfdeconvcudnn.py +3263 -0
  101. setiastro/saspro/mfdeconvsport.py +2382 -0
  102. setiastro/saspro/minorbodycatalog.py +567 -0
  103. setiastro/saspro/morphology.py +382 -0
  104. setiastro/saspro/multiscale_decomp.py +1290 -0
  105. setiastro/saspro/nbtorgb_stars.py +531 -0
  106. setiastro/saspro/numba_utils.py +3044 -0
  107. setiastro/saspro/numba_warmup.py +141 -0
  108. setiastro/saspro/ops/__init__.py +9 -0
  109. setiastro/saspro/ops/command_help_dialog.py +623 -0
  110. setiastro/saspro/ops/command_runner.py +217 -0
  111. setiastro/saspro/ops/commands.py +1594 -0
  112. setiastro/saspro/ops/script_editor.py +1102 -0
  113. setiastro/saspro/ops/scripts.py +1413 -0
  114. setiastro/saspro/ops/settings.py +560 -0
  115. setiastro/saspro/parallel_utils.py +554 -0
  116. setiastro/saspro/pedestal.py +121 -0
  117. setiastro/saspro/perfect_palette_picker.py +1053 -0
  118. setiastro/saspro/pipeline.py +110 -0
  119. setiastro/saspro/pixelmath.py +1600 -0
  120. setiastro/saspro/plate_solver.py +2435 -0
  121. setiastro/saspro/project_io.py +797 -0
  122. setiastro/saspro/psf_utils.py +136 -0
  123. setiastro/saspro/psf_viewer.py +549 -0
  124. setiastro/saspro/pyi_rthook_astroquery.py +95 -0
  125. setiastro/saspro/remove_green.py +314 -0
  126. setiastro/saspro/remove_stars.py +1625 -0
  127. setiastro/saspro/remove_stars_preset.py +404 -0
  128. setiastro/saspro/resources.py +472 -0
  129. setiastro/saspro/rgb_combination.py +207 -0
  130. setiastro/saspro/rgb_extract.py +19 -0
  131. setiastro/saspro/rgbalign.py +723 -0
  132. setiastro/saspro/runtime_imports.py +7 -0
  133. setiastro/saspro/runtime_torch.py +754 -0
  134. setiastro/saspro/save_options.py +72 -0
  135. setiastro/saspro/selective_color.py +1552 -0
  136. setiastro/saspro/sfcc.py +1425 -0
  137. setiastro/saspro/shortcuts.py +2807 -0
  138. setiastro/saspro/signature_insert.py +1099 -0
  139. setiastro/saspro/stacking_suite.py +17712 -0
  140. setiastro/saspro/star_alignment.py +7420 -0
  141. setiastro/saspro/star_alignment_preset.py +329 -0
  142. setiastro/saspro/star_metrics.py +49 -0
  143. setiastro/saspro/star_spikes.py +681 -0
  144. setiastro/saspro/star_stretch.py +470 -0
  145. setiastro/saspro/stat_stretch.py +502 -0
  146. setiastro/saspro/status_log_dock.py +78 -0
  147. setiastro/saspro/subwindow.py +3267 -0
  148. setiastro/saspro/supernovaasteroidhunter.py +1712 -0
  149. setiastro/saspro/swap_manager.py +99 -0
  150. setiastro/saspro/torch_backend.py +89 -0
  151. setiastro/saspro/torch_rejection.py +434 -0
  152. setiastro/saspro/view_bundle.py +1555 -0
  153. setiastro/saspro/wavescale_hdr.py +624 -0
  154. setiastro/saspro/wavescale_hdr_preset.py +100 -0
  155. setiastro/saspro/wavescalede.py +657 -0
  156. setiastro/saspro/wavescalede_preset.py +228 -0
  157. setiastro/saspro/wcs_update.py +374 -0
  158. setiastro/saspro/whitebalance.py +456 -0
  159. setiastro/saspro/widgets/__init__.py +48 -0
  160. setiastro/saspro/widgets/common_utilities.py +305 -0
  161. setiastro/saspro/widgets/graphics_views.py +122 -0
  162. setiastro/saspro/widgets/image_utils.py +518 -0
  163. setiastro/saspro/widgets/preview_dialogs.py +280 -0
  164. setiastro/saspro/widgets/spinboxes.py +275 -0
  165. setiastro/saspro/widgets/themed_buttons.py +13 -0
  166. setiastro/saspro/widgets/wavelet_utils.py +299 -0
  167. setiastro/saspro/window_shelf.py +185 -0
  168. setiastro/saspro/xisf.py +1123 -0
  169. setiastrosuitepro-1.6.0.dist-info/METADATA +266 -0
  170. setiastrosuitepro-1.6.0.dist-info/RECORD +174 -0
  171. setiastrosuitepro-1.6.0.dist-info/WHEEL +4 -0
  172. setiastrosuitepro-1.6.0.dist-info/entry_points.txt +6 -0
  173. setiastrosuitepro-1.6.0.dist-info/licenses/LICENSE +674 -0
  174. setiastrosuitepro-1.6.0.dist-info/licenses/license.txt +2580 -0
@@ -0,0 +1,3263 @@
1
+ # pro/mfdeconvsport.py
2
+ from __future__ import annotations
3
+ import os, sys
4
+ import math
5
+ import re
6
+ import numpy as np
7
+ import tempfile
8
+ import uuid
9
+ import atexit
10
+ from astropy.io import fits
11
+ from PyQt6.QtCore import QObject, pyqtSignal
12
+ from setiastro.saspro.psf_utils import compute_psf_kernel_for_image
13
+ from PyQt6.QtWidgets import QApplication
14
+ from PyQt6.QtCore import QThread
15
+ from threadpoolctl import threadpool_limits
16
+ from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
17
+ _USE_PROCESS_POOL_FOR_ASSETS = not getattr(sys, "frozen", False)
18
+ import numpy.fft as _fft
19
+ import contextlib
20
+ from setiastro.saspro.mfdeconv_earlystop import EarlyStopper
21
+
22
+ import gc
23
+ try:
24
+ import sep
25
+ except Exception:
26
+ sep = None
27
+ from setiastro.saspro.free_torch_memory import _free_torch_memory
28
+ torch = None # filled by runtime loader if available
29
+ TORCH_OK = False
30
+ NO_GRAD = contextlib.nullcontext # fallback
31
+
32
+ _SCRATCH_MMAPS = []
33
+ _XISF_READERS = []
34
+ try:
35
+ # e.g. your legacy module
36
+ from setiastro.saspro.legacy import xisf as _legacy_xisf
37
+ if hasattr(_legacy_xisf, "read"):
38
+ _XISF_READERS.append(lambda p: _legacy_xisf.read(p))
39
+ elif hasattr(_legacy_xisf, "open"):
40
+ _XISF_READERS.append(lambda p: _legacy_xisf.open(p)[0])
41
+ except Exception:
42
+ pass
43
+ try:
44
+ # sometimes projects expose a generic load_image
45
+ from setiastro.saspro.legacy.image_manager import load_image as _generic_load_image # adjust if needed
46
+ _XISF_READERS.append(lambda p: _generic_load_image(p)[0])
47
+ except Exception:
48
+ pass
49
+
50
+ # at top of file with the other imports
51
+ from concurrent.futures import ThreadPoolExecutor, as_completed
52
+ from queue import SimpleQueue
53
+ from setiastro.saspro.memory_utils import LRUDict
54
+
55
+ # ── XISF decode cache → memmap on disk ─────────────────────────────────
56
+ import tempfile
57
+ import threading
58
+ import uuid
59
+ import atexit
60
+ _XISF_CACHE = LRUDict(50)
61
+ _XISF_LOCK = threading.Lock()
62
+ _XISF_TMPFILES = []
63
+
64
+ from collections import OrderedDict
65
+
66
+ # ─────────────────────────────────────────────────────────────────────────────
67
+ # Unified image I/O for MFDeconv (FITS + XISF)
68
+ # ─────────────────────────────────────────────────────────────────────────────
69
+
70
+ from pathlib import Path
71
+
72
+
73
+ from collections import OrderedDict
74
+
75
+ _VAR_DTYPE = np.float32
76
+
77
+ def _mm_create(shape, dtype, scratch_dir=None, tag="scratch"):
78
+ """Create a disk-backed memmap array, zero-initialized."""
79
+ scratch_dir = scratch_dir or tempfile.gettempdir()
80
+ fn = os.path.join(scratch_dir, f"mfdeconv_{tag}_{uuid.uuid4().hex}.mmap")
81
+ mm = np.memmap(fn, mode="w+", dtype=dtype, shape=tuple(map(int, shape)))
82
+ mm[...] = 0
83
+ mm.flush()
84
+ _SCRATCH_MMAPS.append(fn)
85
+ return mm
86
+
87
+ def _maybe_memmap(shape, dtype=np.float32, *,
88
+ force_mm=False, threshold_mb=512, scratch_dir=None, tag="scratch"):
89
+ """Return either np.zeros(...) or a zeroed memmap based on size/flags."""
90
+ nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
91
+ if force_mm or (nbytes >= threshold_mb * 1024 * 1024):
92
+ return _mm_create(shape, dtype, scratch_dir, tag)
93
+ return np.zeros(shape, dtype=dtype)
94
+
95
+ def _cleanup_scratch_mm():
96
+ for fn in _SCRATCH_MMAPS[:]:
97
+ try: os.remove(fn)
98
+ except Exception as e:
99
+ import logging
100
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
101
+ _SCRATCH_MMAPS.clear()
102
+
103
+ atexit.register(_cleanup_scratch_mm)
104
+
105
+ # ── CHW LRU (float32) built on top of FITS memmap & XISF memmap ────────────────
106
+ class _FrameCHWLRU:
107
+ def __init__(self, capacity=8):
108
+ self.cap = int(max(1, capacity))
109
+ self.od = OrderedDict()
110
+
111
+ def clear(self):
112
+ self.od.clear()
113
+
114
+ def get(self, path, Ht, Wt, color_mode):
115
+ key = (path, Ht, Wt, str(color_mode).lower())
116
+ hit = self.od.get(key)
117
+ if hit is not None:
118
+ self.od.move_to_end(key)
119
+ return hit
120
+
121
+ # Load backing array cheaply (memmap for FITS, cached memmap for XISF)
122
+ ext = os.path.splitext(path)[1].lower()
123
+ if ext == ".xisf":
124
+ a = _xisf_cached_array(path) # float32, HW/HWC/CHW
125
+ else:
126
+ # FITS path: use astropy memmap (no data copy)
127
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
128
+ arr = None
129
+ for h in hdul:
130
+ if getattr(h, "data", None) is not None:
131
+ arr = h.data
132
+ break
133
+ if arr is None:
134
+ raise ValueError(f"No image data in {path}")
135
+ a = np.asarray(arr)
136
+ # dtype normalize once; keep float32
137
+ if a.dtype.kind in "ui":
138
+ a = a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0)
139
+ else:
140
+ a = a.astype(np.float32, copy=False)
141
+
142
+ # Center-crop to (Ht, Wt) and convert to CHW
143
+ a = np.asarray(a) # float32
144
+ a = _center_crop(a, Ht, Wt)
145
+
146
+ # Respect color_mode: “luma” → 1×H×W, “PerChannel” → 3×H×W if RGB present
147
+ cm = str(color_mode).lower()
148
+ if cm == "luma":
149
+ a_chw = _as_chw(_to_luma_local(a)).astype(np.float32, copy=False)
150
+ else:
151
+ a_chw = _as_chw(a).astype(np.float32, copy=False)
152
+ if a_chw.shape[0] == 1 and cm != "luma":
153
+ # still OK (mono data)
154
+ pass
155
+
156
+ # LRU insert
157
+ self.od[key] = a_chw
158
+ if len(self.od) > self.cap:
159
+ self.od.popitem(last=False)
160
+ return a_chw
161
+
162
+ _FRAME_LRU = _FrameCHWLRU(capacity=8) # tune if you like
163
+
164
+ def _clear_all_caches():
165
+ try: _clear_xisf_cache()
166
+ except Exception as e:
167
+ import logging
168
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
169
+ try: _FRAME_LRU.clear()
170
+ except Exception as e:
171
+ import logging
172
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
173
+
174
+ def _as_chw(np_img: np.ndarray) -> np.ndarray:
175
+ x = np.asarray(np_img, dtype=np.float32, order="C")
176
+ if x.size == 0:
177
+ raise RuntimeError(f"Empty image array after load; raw shape={np_img.shape}")
178
+ if x.ndim == 2:
179
+ return x[None, ...] # 1,H,W
180
+ if x.ndim == 3 and x.shape[0] in (1, 3):
181
+ if x.shape[0] == 0:
182
+ raise RuntimeError(f"Zero channels in CHW array; shape={x.shape}")
183
+ return x
184
+ if x.ndim == 3 and x.shape[-1] in (1, 3):
185
+ if x.shape[-1] == 0:
186
+ raise RuntimeError(f"Zero channels in HWC array; shape={x.shape}")
187
+ return np.moveaxis(x, -1, 0)
188
+ # last resort: treat first dim as channels, but reject zero
189
+ if x.shape[0] == 0:
190
+ raise RuntimeError(f"Zero channels in array; shape={x.shape}")
191
+ return x
192
+
193
+ def _normalize_to_float32(a: np.ndarray) -> np.ndarray:
194
+ if a.dtype.kind in "ui":
195
+ return (a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0))
196
+ if a.dtype == np.float32:
197
+ return a
198
+ return a.astype(np.float32, copy=False)
199
+
200
+ def _xisf_cached_array(path: str) -> np.memmap:
201
+ """
202
+ Decode an XISF image exactly once and back it by a read-only float32 memmap.
203
+ Returns a memmap that can be sliced cheaply for tiles.
204
+ """
205
+ with _XISF_LOCK:
206
+ hit = _XISF_CACHE.get(path)
207
+ if hit is not None:
208
+ fn, shape = hit
209
+ return np.memmap(fn, dtype=np.float32, mode="r", shape=shape)
210
+
211
+ # Decode once
212
+ arr, _ = _load_image_array(path) # your existing loader
213
+ if arr is None:
214
+ raise ValueError(f"XISF loader returned None for {path}")
215
+ arr = np.asarray(arr)
216
+ arrf = _normalize_to_float32(arr)
217
+
218
+ # Create a temp file-backed memmap
219
+ tmpdir = tempfile.gettempdir()
220
+ fn = os.path.join(tmpdir, f"xisf_cache_{uuid.uuid4().hex}.mmap")
221
+ mm = np.memmap(fn, dtype=np.float32, mode="w+", shape=arrf.shape)
222
+ mm[...] = arrf[...]
223
+ mm.flush()
224
+ del mm # close writer handle; re-open below as read-only
225
+
226
+ _XISF_CACHE[path] = (fn, arrf.shape)
227
+ _XISF_TMPFILES.append(fn)
228
+ return np.memmap(fn, dtype=np.float32, mode="r", shape=arrf.shape)
229
+
230
+ def _clear_xisf_cache():
231
+ with _XISF_LOCK:
232
+ for fn in _XISF_TMPFILES:
233
+ try: os.remove(fn)
234
+ except Exception as e:
235
+ import logging
236
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
237
+ _XISF_CACHE.clear()
238
+ _XISF_TMPFILES.clear()
239
+
240
+ atexit.register(_clear_xisf_cache)
241
+
242
+
243
+ def _is_xisf(path: str) -> bool:
244
+ return os.path.splitext(path)[1].lower() == ".xisf"
245
+
246
+ def _read_xisf_numpy(path: str) -> np.ndarray:
247
+ if not _XISF_READERS:
248
+ raise RuntimeError(
249
+ "No XISF readers registered. Ensure one of "
250
+ "legacy.xisf.read/open or *.image_io.load_image is importable."
251
+ )
252
+ last_err = None
253
+ for fn in _XISF_READERS:
254
+ try:
255
+ arr = fn(path)
256
+ if isinstance(arr, tuple):
257
+ arr = arr[0]
258
+ return np.asarray(arr)
259
+ except Exception as e:
260
+ last_err = e
261
+ raise RuntimeError(f"All XISF readers failed for {path}: {last_err}")
262
+
263
+ def _fits_open_data(path: str):
264
+ # ignore_missing_simple=True lets us open headers missing SIMPLE
265
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
266
+ hdu = hdul[0]
267
+ if hdu.data is None:
268
+ # find first image HDU if primary is header-only
269
+ for h in hdul[1:]:
270
+ if getattr(h, "data", None) is not None:
271
+ hdu = h
272
+ break
273
+ data = np.asanyarray(hdu.data)
274
+ hdr = hdu.header
275
+ return data, hdr
276
+
277
+ def _load_image_array(path: str) -> tuple[np.ndarray, "fits.Header | None"]:
278
+ """
279
+ Return (numpy array, fits.Header or None). Color-last if 3D.
280
+ dtype left as-is; callers cast to float32. Array is C-contig & writeable.
281
+ """
282
+ if _is_xisf(path):
283
+ arr = _read_xisf_numpy(path)
284
+ hdr = None
285
+ else:
286
+ arr, hdr = _fits_open_data(path)
287
+
288
+ a = np.asarray(arr)
289
+ # Move color axis to last if 3D with a leading channel axis
290
+ if a.ndim == 3 and a.shape[0] in (1, 3) and a.shape[-1] not in (1, 3):
291
+ a = np.moveaxis(a, 0, -1)
292
+ # Ensure contiguous, writeable float32 decisions happen later; here we just ensure writeable
293
+ if (not a.flags.c_contiguous) or (not a.flags.writeable):
294
+ a = np.array(a, copy=True)
295
+ return a, hdr
296
+
297
+ def _probe_hw(path: str) -> tuple[int, int, int | None]:
298
+ """
299
+ Returns (H, W, C_or_None) without changing data. Moves color to last if needed.
300
+ """
301
+ a, _ = _load_image_array(path)
302
+ if a.ndim == 2:
303
+ return a.shape[0], a.shape[1], None
304
+ if a.ndim == 3:
305
+ h, w, c = a.shape
306
+ # treat mono-3D as (H,W,1)
307
+ if c not in (1, 3) and a.shape[0] in (1, 3):
308
+ a = np.moveaxis(a, 0, -1)
309
+ h, w, c = a.shape
310
+ return h, w, c if c in (1, 3) else None
311
+ raise ValueError(f"Unsupported ndim={a.ndim} for {path}")
312
+
313
+ def _common_hw_from_paths(paths: list[str]) -> tuple[int, int]:
314
+ """
315
+ Replacement for the old FITS-only version: min(H), min(W) across files.
316
+ """
317
+ Hs, Ws = [], []
318
+ for p in paths:
319
+ h, w, _ = _probe_hw(p)
320
+ Hs.append(int(h)); Ws.append(int(w))
321
+ return int(min(Hs)), int(min(Ws))
322
+
323
+ def _to_chw_float32(img: np.ndarray, color_mode: str) -> np.ndarray:
324
+ """
325
+ Convert to CHW float32:
326
+ - mono → (1,H,W)
327
+ - RGB → (3,H,W) if 'PerChannel'; (1,H,W) if 'luma'
328
+ """
329
+ x = np.asarray(img)
330
+ if x.ndim == 2:
331
+ y = x.astype(np.float32, copy=False)[None, ...] # (1,H,W)
332
+ return y
333
+ if x.ndim == 3:
334
+ # color-last (H,W,C) expected
335
+ if x.shape[-1] == 1:
336
+ return x[..., 0].astype(np.float32, copy=False)[None, ...]
337
+ if x.shape[-1] == 3:
338
+ if str(color_mode).lower() in ("perchannel", "per_channel", "perchannelrgb"):
339
+ r, g, b = x[..., 0], x[..., 1], x[..., 2]
340
+ return np.stack([r.astype(np.float32, copy=False),
341
+ g.astype(np.float32, copy=False),
342
+ b.astype(np.float32, copy=False)], axis=0)
343
+ # luma
344
+ r, g, b = x[..., 0].astype(np.float32, copy=False), x[..., 1].astype(np.float32, copy=False), x[..., 2].astype(np.float32, copy=False)
345
+ L = 0.2126*r + 0.7152*g + 0.0722*b
346
+ return L[None, ...]
347
+ # rare mono-3D
348
+ if x.shape[0] in (1, 3) and x.shape[-1] not in (1, 3):
349
+ x = np.moveaxis(x, 0, -1)
350
+ return _to_chw_float32(x, color_mode)
351
+ raise ValueError(f"Unsupported image shape {x.shape}")
352
+
353
+ def _center_crop_hw(img: np.ndarray, Ht: int, Wt: int) -> np.ndarray:
354
+ h, w = img.shape[:2]
355
+ y0 = max(0, (h - Ht)//2); x0 = max(0, (w - Wt)//2)
356
+ return img[y0:y0+Ht, x0:x0+Wt, ...].copy() if (Ht < h or Wt < w) else img
357
+
358
+ def _stack_loader_memmap(paths: list[str], Ht: int, Wt: int, color_mode: str):
359
+ """
360
+ Drop-in replacement of the old FITS-only helper.
361
+ Returns (ys, hdrs):
362
+ ys : list of CHW float32 arrays cropped to (Ht,Wt)
363
+ hdrs : list of fits.Header or None (XISF)
364
+ """
365
+ ys, hdrs = [], []
366
+ for p in paths:
367
+ arr, hdr = _load_image_array(p)
368
+ arr = _center_crop_hw(arr, Ht, Wt)
369
+ # normalize integer data to [0,1] like the rest of your code
370
+ if arr.dtype.kind in "ui":
371
+ mx = np.float32(np.iinfo(arr.dtype).max)
372
+ arr = arr.astype(np.float32, copy=False) / (mx if mx > 0 else 1.0)
373
+ elif arr.dtype.kind == "f":
374
+ arr = arr.astype(np.float32, copy=False)
375
+ else:
376
+ arr = arr.astype(np.float32, copy=False)
377
+
378
+ y = _to_chw_float32(arr, color_mode)
379
+ if (not y.flags.c_contiguous) or (not y.flags.writeable):
380
+ y = np.ascontiguousarray(y.astype(np.float32, copy=True))
381
+ ys.append(y)
382
+ hdrs.append(hdr if isinstance(hdr, fits.Header) else None)
383
+ return ys, hdrs
384
+
385
+ def _safe_primary_header(path: str) -> fits.Header:
386
+ if _is_xisf(path):
387
+ # best-effort synthetic header
388
+ h = fits.Header()
389
+ h["SIMPLE"] = (True, "created by MFDeconv")
390
+ h["BITPIX"] = -32
391
+ h["NAXIS"] = 2
392
+ return h
393
+ try:
394
+ return fits.getheader(path, ext=0, ignore_missing_simple=True)
395
+ except Exception:
396
+ return fits.Header()
397
+
398
+
399
+ def _compute_frame_assets(i, arr, hdr, *, make_masks, make_varmaps,
400
+ star_mask_cfg, varmap_cfg, status_sink=lambda s: None):
401
+ """
402
+ Worker function: compute PSF and optional star mask / varmap for one frame.
403
+ Returns (index, psf, mask_or_None, var_or_None, var_path_or_None, log_lines)
404
+ """
405
+
406
+
407
+ logs = []
408
+ def log(s): logs.append(s)
409
+
410
+ # --- PSF sizing by FWHM ---
411
+ f_hdr = _estimate_fwhm_from_header(hdr)
412
+ f_img = _estimate_fwhm_from_image(arr)
413
+ f_whm = f_hdr if (np.isfinite(f_hdr)) else f_img
414
+ if not np.isfinite(f_whm) or f_whm <= 0:
415
+ f_whm = 2.5
416
+ k_auto = _auto_ksize_from_fwhm(f_whm)
417
+
418
+ # --- Star-derived PSF with retries ---
419
+ tried, psf = [], None
420
+ for k_try in [k_auto, max(k_auto - 4, 11), 21, 17, 15, 13, 11]:
421
+ if k_try in tried:
422
+ continue
423
+ tried.append(k_try)
424
+ try:
425
+ out = compute_psf_kernel_for_image(arr, ksize=k_try, det_sigma=6.0, max_stars=80)
426
+ psf_try = out[0] if (isinstance(out, tuple) and len(out) >= 1) else out
427
+ if psf_try is not None:
428
+ psf = psf_try
429
+ break
430
+ except Exception:
431
+ psf = None
432
+ if psf is None:
433
+ psf = _gaussian_psf(f_whm, ksize=k_auto)
434
+ psf = _soften_psf(_normalize_psf(psf.astype(np.float32, copy=False)), sigma_px=0.25)
435
+
436
+ mask = None
437
+ var = None
438
+ var_path = None
439
+
440
+ if make_masks or make_varmaps:
441
+ luma = _to_luma_local(arr)
442
+ vmc = (varmap_cfg or {})
443
+ sky_map, rms_map, err_scalar = _sep_background_precompute(
444
+ luma, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
445
+ )
446
+
447
+ # ---------- Star mask ----------
448
+ if make_masks:
449
+ smc = star_mask_cfg or {}
450
+ mask = _star_mask_from_precomputed(
451
+ luma, sky_map, err_scalar,
452
+ thresh_sigma = smc.get("thresh_sigma", THRESHOLD_SIGMA),
453
+ max_objs = smc.get("max_objs", STAR_MASK_MAXOBJS),
454
+ grow_px = smc.get("grow_px", GROW_PX),
455
+ ellipse_scale= smc.get("ellipse_scale", ELLIPSE_SCALE),
456
+ soft_sigma = smc.get("soft_sigma", SOFT_SIGMA),
457
+ max_radius_px= smc.get("max_radius_px", MAX_STAR_RADIUS),
458
+ keep_floor = smc.get("keep_floor", KEEP_FLOOR),
459
+ max_side = smc.get("max_side", STAR_MASK_MAXSIDE),
460
+ status_cb = log,
461
+ )
462
+ # keep masks compact
463
+ if mask is not None and mask.dtype != np.uint8:
464
+ mask = (mask > 0.5).astype(np.uint8, copy=False)
465
+
466
+ # ---------- Variance map (memmap path; Option B) ----------
467
+ if make_varmaps:
468
+ vmc = varmap_cfg or {}
469
+ def _vprog(frac: float, msg: str = ""):
470
+ try: log(f"__PROGRESS__ {0.16 + 0.02*float(frac):.4f} {msg}")
471
+ except Exception as e:
472
+ import logging
473
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
474
+
475
+ var_path = _variance_map_from_precomputed_memmap(
476
+ luma, sky_map, rms_map, hdr,
477
+ smooth_sigma = float(vmc.get("smooth_sigma", 1.0)),
478
+ floor = float(vmc.get("floor", 1e-8)),
479
+ tile_hw = tuple(vmc.get("tile_hw", (512, 512))),
480
+ scratch_dir = vmc.get("scratch_dir", None),
481
+ tag = f"varmap_{i:04d}",
482
+ status_cb = log,
483
+ progress_cb = _vprog,
484
+ )
485
+ var = None # Option B: don't keep an open memmap handle
486
+
487
+ # 🔻 free heavy temporaries immediately
488
+ try:
489
+ del luma
490
+ del sky_map
491
+ del rms_map
492
+ except Exception:
493
+ pass
494
+ gc.collect()
495
+
496
+ # per-frame summary
497
+ fwhm_est = _psf_fwhm_px(psf)
498
+ logs.insert(0, f"MFDeconv: PSF{i}: ksize={psf.shape[0]} | FWHM≈{fwhm_est:.2f}px")
499
+ return i, psf, mask, var, var_path, logs
500
+
501
+ def _compute_one_worker(args):
502
+ """
503
+ Process-safe worker wrapper.
504
+ Args tuple: (i, path, make_masks, make_varmaps, star_mask_cfg, varmap_cfg, Ht, Wt, color_mode)
505
+ Returns: (i, psf, mask, var_or_None, var_path_or_None, logs)
506
+ """
507
+ (i, path, make_masks, make_varmaps, star_mask_cfg, varmap_cfg, Ht, Wt, color_mode) = (
508
+ args if len(args) == 9 else (*args, None, None, None) # allow old callers
509
+ )[:9]
510
+
511
+ # lightweight load (center-crop to Ht,Wt, get header)
512
+ try:
513
+ hdr = _safe_primary_header(path)
514
+ except Exception:
515
+ hdr = fits.Header()
516
+
517
+ # read full image then crop center; keep float32 luma/mono 2D
518
+ ext = os.path.splitext(path)[1].lower()
519
+ if ext == ".xisf":
520
+ arr_all, _ = _load_image_array(path)
521
+ arr_all = np.asarray(arr_all)
522
+ else:
523
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
524
+ arr_all = np.asarray(hdul[0].data)
525
+
526
+ # to luma/2D
527
+ if arr_all.ndim == 3:
528
+ if arr_all.shape[0] in (1, 3): # CHW → take first/luma
529
+ arr2d = arr_all[0].astype(np.float32, copy=False)
530
+ elif arr_all.shape[-1] in (1, 3): # HWC → to luma then 2D
531
+ arr2d = _to_luma_local(arr_all).astype(np.float32, copy=False)
532
+ else:
533
+ arr2d = _to_luma_local(arr_all).astype(np.float32, copy=False)
534
+ else:
535
+ arr2d = np.asarray(arr_all, dtype=np.float32)
536
+
537
+ # center-crop/pad to (Ht,Wt) if needed
538
+ H, W = arr2d.shape
539
+ y0 = max(0, (H - Ht) // 2); x0 = max(0, (W - Wt) // 2)
540
+ y1 = min(H, y0 + Ht); x1 = min(W, x0 + Wt)
541
+ arr2d = np.ascontiguousarray(arr2d[y0:y1, x0:x1], dtype=np.float32)
542
+ if arr2d.shape != (Ht, Wt):
543
+ out = np.zeros((Ht, Wt), dtype=np.float32)
544
+ oy = (Ht - arr2d.shape[0]) // 2; ox = (Wt - arr2d.shape[1]) // 2
545
+ out[oy:oy+arr2d.shape[0], ox:ox+arr2d.shape[1]] = arr2d
546
+ arr2d = out
547
+
548
+ # compute assets
549
+ i2, psf, mask, var, var_path, logs = _compute_frame_assets(
550
+ i, arr2d, hdr,
551
+ make_masks=bool(make_masks),
552
+ make_varmaps=bool(make_varmaps),
553
+ star_mask_cfg=star_mask_cfg,
554
+ varmap_cfg=varmap_cfg,
555
+ )
556
+
557
+ # Force Option B behavior: close any memmap and only pass a path
558
+ if isinstance(var, np.memmap):
559
+ try: var.flush()
560
+ except Exception as e:
561
+ import logging
562
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
563
+ var = None
564
+
565
+ return i2, psf, mask, var, var_path, logs
566
+
567
+
568
+ def _normalize_assets_result(res):
569
+ """
570
+ Accept worker results in legacy and new shapes and normalize to:
571
+ (i, psf, mask, var_or_None, var_path_or_None, logs_list)
572
+ Supported inputs:
573
+ - (i, psf, mask, var, logs)
574
+ - (i, psf, mask, var, var_path, logs)
575
+ - (i, psf, mask, var, var_mm, var_path, logs) # legacy where both returned
576
+ """
577
+ if not isinstance(res, (tuple, list)) or len(res) < 5:
578
+ raise ValueError(f"Unexpected assets result: {type(res)} len={len(res) if hasattr(res,'__len__') else 'na'}")
579
+
580
+ i = res[0]
581
+ psf = res[1]
582
+ mask = res[2]
583
+ logs = res[-1]
584
+
585
+ middle = res[3:-1] # everything between mask and logs
586
+ var = None
587
+ var_path = None
588
+
589
+ # Try to recover var (np.ndarray/np.memmap) and a path (str)
590
+ for x in middle:
591
+ if var is None and hasattr(x, "shape"): # ndarray/memmap/torch?
592
+ var = x
593
+ if var_path is None and isinstance(x, str):
594
+ var_path = x
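A quick illustration of the size-based switch in `_maybe_memmap` above (a sketch, assuming the helpers defined in this file are in scope; the shapes are invented):

```python
import numpy as np

# ~4 MB: well under the 512 MB default threshold, so a plain np.zeros array is returned.
small = _maybe_memmap((1024, 1024), np.float32)
# ~768 MB: over the threshold, so a zero-filled, disk-backed memmap is created instead
# (its file is registered in _SCRATCH_MMAPS and removed at exit).
large = _maybe_memmap((8192, 8192, 3), np.float32)
print(isinstance(small, np.memmap), isinstance(large, np.memmap))  # False True
```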
595
+
596
+ # Back-compat for 5-tuple
597
+ if len(res) == 5:
598
+ var = middle[0] if middle else None
599
+ var_path = None
600
+
601
+ return i, psf, mask, var, var_path, logs
602
+
603
+
604
+ def _build_psf_and_assets(
605
+ paths, # list[str]
606
+ make_masks=False,
607
+ make_varmaps=False,
608
+ status_cb=lambda s: None,
609
+ save_dir: str | None = None,
610
+ star_mask_cfg: dict | None = None,
611
+ varmap_cfg: dict | None = None,
612
+ max_workers: int | None = None,
613
+ star_mask_ref_path: str | None = None, # build one mask from this frame if provided
614
+ # NEW (passed from multiframe_deconv so we don’t re-probe/convert):
615
+ Ht: int | None = None,
616
+ Wt: int | None = None,
617
+ color_mode: str = "luma",
618
+ ):
619
+ """
620
+ Parallel PSF + (optional) star mask + variance map per frame.
621
+
622
+ Changes:
623
+ • Variance maps are written to disk as memmaps (Option B) and **paths** are returned.
624
+ • If a single reference star mask is requested, it is built once and reused.
625
+ • Returns: (psfs, masks, vars_, var_paths) — vars_ contains None for varmaps.
626
+ • RAM bounded: frees per-frame temporaries, drains logs, trims frame cache.
627
+ """
628
+
629
+ # Local helpers expected from your module scope:
630
+ # _FRAME_LRU, _common_hw_from_paths, _safe_primary_header, fits, _gaussian_psf, etc.
631
+
632
+ if save_dir:
633
+ os.makedirs(save_dir, exist_ok=True)
634
+
635
+ n = len(paths)
636
+
637
+ # Resolve target intersection size if caller didn't pass it
638
+ if Ht is None or Wt is None:
639
+ Ht, Wt = _common_hw_from_paths(paths)
640
+
641
+ # Conservative default worker count to cap concurrent RAM
642
+ if max_workers is None:
643
+ try:
644
+ hw = os.cpu_count() or 4
645
+ except Exception:
646
+ hw = 4
647
+ # half the cores, max 4 (keeps sky/rms/luma concurrency modest)
648
+ max_workers = max(1, min(4, hw // 2))
649
+
650
+ # Decide executor: for any XISF, prefer threads so the memmap/cache is shared
651
+ any_xisf = any(os.path.splitext(p)[1].lower() == ".xisf" for p in paths)
652
+ use_proc_pool = (not any_xisf) and _USE_PROCESS_POOL_FOR_ASSETS
653
+ Executor = ProcessPoolExecutor if use_proc_pool else ThreadPoolExecutor
654
+ pool_kind = "process" if use_proc_pool else "thread"
655
+ status_cb(f"MFDeconv: measuring PSFs/masks/varmaps with {max_workers} {pool_kind}s…")
656
+
657
+ # ---- helper: pad-or-crop a 2D array to (Ht,Wt), centered ----
658
+ def _center_pad_or_crop_2d(a2d: np.ndarray, Ht: int, Wt: int, fill: float = 1.0) -> np.ndarray:
659
+ a2d = np.asarray(a2d, dtype=np.float32)
660
+ H, W = int(a2d.shape[0]), int(a2d.shape[1])
661
+ # crop first if bigger
662
+ y0 = max(0, (H - Ht) // 2); x0 = max(0, (W - Wt) // 2)
663
+ y1 = min(H, y0 + Ht); x1 = min(W, x0 + Wt)
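A hypothetical use of the shared `_FRAME_LRU` cache above (the path and crop size are invented for illustration):

```python
# The first call reads, normalizes and center-crops the frame, keyed on
# (path, Ht, Wt, color_mode); the second call is an LRU hit returning the same array.
chw = _FRAME_LRU.get("/data/lights/frame_001.fit", 2048, 3072, "luma")
hit = _FRAME_LRU.get("/data/lights/frame_001.fit", 2048, 3072, "luma")
assert hit is chw and chw.ndim == 3   # (1, 2048, 3072) for "luma"
_FRAME_LRU.clear()                    # release cached frames when the phase is done
```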
664
+ cropped = a2d[y0:y1, x0:x1]
665
+ ch, cw = cropped.shape
666
+ if ch == Ht and cw == Wt:
667
+ return np.ascontiguousarray(cropped, dtype=np.float32)
668
+ # pad if smaller
669
+ out = np.full((Ht, Wt), float(fill), dtype=np.float32)
670
+ oy = (Ht - ch) // 2; ox = (Wt - cw) // 2
671
+ out[oy:oy+ch, ox:ox+cw] = cropped
672
+ return out
673
+
674
+ # ---- optional: build one mask from the reference frame and reuse ----
675
+ base_ref_mask = None
676
+ if make_masks and star_mask_ref_path:
677
+ try:
678
+ status_cb(f"Star mask: using reference frame for all masks → {os.path.basename(star_mask_ref_path)}")
679
+ ref_chw = _FRAME_LRU.get(star_mask_ref_path, Ht, Wt, "luma") # (1,H,W) or (H,W)
680
+ L = ref_chw[0] if (ref_chw.ndim == 3) else ref_chw # 2D float32
681
+
682
+ vmc = (varmap_cfg or {})
683
+ sky_map, rms_map, err_scalar = _sep_background_precompute(
684
+ L, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
685
+ )
686
+ smc = (star_mask_cfg or {})
687
+ base_ref_mask = _star_mask_from_precomputed(
688
+ L, sky_map, err_scalar,
689
+ thresh_sigma = smc.get("thresh_sigma", THRESHOLD_SIGMA),
690
+ max_objs = smc.get("max_objs", STAR_MASK_MAXOBJS),
691
+ grow_px = smc.get("grow_px", GROW_PX),
692
+ ellipse_scale= smc.get("ellipse_scale", ELLIPSE_SCALE),
693
+ soft_sigma = smc.get("soft_sigma", SOFT_SIGMA),
694
+ max_radius_px= smc.get("max_radius_px", MAX_STAR_RADIUS),
695
+ keep_floor = smc.get("keep_floor", KEEP_FLOOR),
696
+ max_side = smc.get("max_side", STAR_MASK_MAXSIDE),
697
+ status_cb = status_cb,
698
+ )
699
+ # keep mask compact
700
+ if base_ref_mask is not None and base_ref_mask.dtype != np.uint8:
701
+ base_ref_mask = (base_ref_mask > 0.5).astype(np.uint8, copy=False)
702
+ # free temps
703
+ del L, sky_map, rms_map
704
+ gc.collect()
705
+ except Exception as e:
706
+ status_cb(f"⚠️ Star mask (reference) failed: {e}. Falling back to per-frame masks.")
707
+ base_ref_mask = None
708
+
709
+ # for GUI safety, queue logs from workers and flush in the main thread
710
+ from queue import SimpleQueue
711
+ log_queue: SimpleQueue = SimpleQueue()
712
+
713
+ def enqueue_logs(lines):
714
+ for s in lines:
715
+ log_queue.put(s)
716
+
717
+ psfs = [None] * n
718
+ masks = ([None] * n) if make_masks else None
719
+ vars_ = ([None] * n) if make_varmaps else None # Option B: keep None placeholders
720
+ var_paths = ([None] * n) if make_varmaps else None # on-disk paths for varmaps
721
+ make_masks_in_worker = bool(make_masks and (base_ref_mask is None))
722
+
723
+ # --- worker thunk (thread mode hits shared cache; process mode has its own worker fn) ---
724
+ def _compute_one(i: int, path: str):
725
+ with threadpool_limits(limits=1):
726
+ img_chw = _FRAME_LRU.get(path, Ht, Wt, color_mode) # (C,H,W) float32
727
+ arr2d = img_chw[0] if (img_chw.ndim == 3) else img_chw # (H,W) float32
728
+ try:
729
+ hdr = _safe_primary_header(path)
730
+ except Exception:
731
+ hdr = fits.Header()
732
+ return _compute_frame_assets(
733
+ i, arr2d, hdr,
734
+ make_masks=bool(make_masks_in_worker),
735
+ make_varmaps=bool(make_varmaps),
736
+ star_mask_cfg=star_mask_cfg,
737
+ varmap_cfg=varmap_cfg,
738
+ )
739
+
740
+ # Optional: tighten the frame cache during this phase if your LRU supports it
741
+ try:
742
+ _FRAME_LRU.set_limits(max_items=max(4, max_workers + 2))
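The decode-once pattern above can be exercised like this (sketch only; the .xisf path is hypothetical):

```python
# The first call decodes the XISF and spills it to a read-only float32 memmap on disk;
# later calls just re-open that memmap, so tile access stays cheap in RAM.
mm_a = _xisf_cached_array("/data/lights/sub_001.xisf")
mm_b = _xisf_cached_array("/data/lights/sub_001.xisf")   # cache hit: same backing file
tile = np.array(mm_b[:512, :512])                        # copy out one tile (mono frame)
```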
743
+ except Exception:
744
+ pass
745
+
746
+ # --- submit jobs ---
747
+ with Executor(max_workers=max_workers) as ex:
748
+ futs = []
749
+ for i, p in enumerate(paths, start=1):
750
+ status_cb(f"MFDeconv: measuring PSF {i}/{n} …")
751
+ if use_proc_pool:
752
+ futs.append(ex.submit(
753
+ _compute_one_worker,
754
+ (i, p, bool(make_masks_in_worker), bool(make_varmaps), star_mask_cfg, varmap_cfg, Ht, Wt, color_mode)
755
+ ))
756
+ else:
757
+ futs.append(ex.submit(_compute_one, i, p))
758
+
759
+ done_cnt = 0
760
+ for fut in as_completed(futs):
761
+ res = fut.result()
762
+ i, psf, m, v, vpath, logs = _normalize_assets_result(res)
763
+
764
+ idx = i - 1
765
+ psfs[idx] = psf
766
+ if masks is not None:
767
+ masks[idx] = m
768
+ if vars_ is not None:
769
+ # Option B: don't hold open memmaps in RAM
770
+ vars_[idx] = None
771
+ if var_paths is not None:
772
+ var_paths[idx] = vpath
773
+
774
+ enqueue_logs(logs)
775
+
776
+ try:
777
+ _FRAME_LRU.drop(paths[idx])
778
+ except Exception:
779
+ pass
780
+
781
+ done_cnt += 1
782
+ while not log_queue.empty():
783
+ try:
784
+ status_cb(log_queue.get_nowait())
785
+ except Exception:
786
+ break
787
+
788
+ if (done_cnt % 8) == 0:
789
+ gc.collect()
790
+
791
+ # If we built a single reference mask, apply it to every frame (center pad/crop)
792
+ if base_ref_mask is not None and masks is not None:
793
+ for idx in range(n):
794
+ masks[idx] = _center_pad_or_crop_2d(base_ref_mask, int(Ht), int(Wt), fill=1.0)
795
+
796
+ # final flush of any remaining logs
797
+ while not log_queue.empty():
798
+ try:
799
+ status_cb(log_queue.get_nowait())
800
+ except Exception:
801
+ break
802
+
803
+ # save PSFs if requested
804
+ if save_dir:
805
+ for i, k in enumerate(psfs, start=1):
806
+ if k is not None:
807
+ fits.PrimaryHDU(k.astype(np.float32, copy=False)).writeto(
808
+ os.path.join(save_dir, f"psf_{i:03d}.fit"), overwrite=True
809
+ )
810
+
811
+ return psfs, masks, vars_, var_paths
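For orientation, a hypothetical call that mirrors the signature and return contract documented above (`frame_paths` is assumed to be a list of registered FITS/XISF subs):

```python
psfs, masks, vars_, var_paths = _build_psf_and_assets(
    frame_paths,
    make_masks=True,
    make_varmaps=True,
    status_cb=print,
    color_mode="luma",
)
# Option B: vars_ holds None placeholders; each variance map lives on disk at var_paths[i].
```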
812
+ _ALLOWED = re.compile(r"[^A-Za-z0-9_-]+")
813
+
814
+ # known FITS-style multi-extensions (rightmost-first match)
815
+ _KNOWN_EXTS = [
816
+ ".fits.fz", ".fit.fz", ".fits.gz", ".fit.gz",
817
+ ".fz", ".gz",
818
+ ".fits", ".fit"
819
+ ]
820
+
821
+ def _sanitize_token(s: str) -> str:
822
+ s = _ALLOWED.sub("_", s)
823
+ s = re.sub(r"_+", "_", s).strip("_")
824
+ return s
825
+
826
+ def _split_known_exts(p: Path) -> tuple[str, str]:
827
+ """
828
+ Return (name_body, full_ext) where full_ext is a REAL extension block
829
+ (e.g. '.fits.fz'). Any junk like '.0s (1310x880)_MFDeconv' stays in body.
830
+ """
831
+ name = p.name
832
+ for ext in _KNOWN_EXTS:
833
+ if name.lower().endswith(ext):
834
+ body = name[:-len(ext)]
835
+ return body, ext
836
+ # fallback: single suffix
837
+ return p.stem, "".join(p.suffixes)
838
+
839
+ _SIZE_RE = re.compile(r"\(?\s*(\d{2,5})x(\d{2,5})\s*\)?", re.IGNORECASE)
840
+ _EXP_RE = re.compile(r"(?<![A-Za-z0-9])(\d+(?:\.\d+)?)\s*s\b", re.IGNORECASE)
841
+ _RX_RE = re.compile(r"(?<![A-Za-z0-9])(\d+)x\b", re.IGNORECASE)
842
+
843
+ def _extract_size(body: str) -> str | None:
844
+ m = _SIZE_RE.search(body)
845
+ return f"{m.group(1)}x{m.group(2)}" if m else None
846
+
847
+ def _extract_exposure_secs(body: str) -> str | None:
848
+ m = _EXP_RE.search(body)
849
+ if not m:
850
+ return None
851
+ secs = int(round(float(m.group(1))))
852
+ return f"{secs}s"
853
+
854
+ def _strip_metadata_from_base(body: str) -> str:
855
+ s = body
856
+
857
+ # normalize common separators first
858
+ s = s.replace(" - ", "_")
859
+
860
+ # remove known trailing marker '_MFDeconv'
861
+ s = re.sub(r"(?i)[\s_]+MFDeconv$", "", s)
862
+
863
+ # remove parenthetical copy counters e.g. '(1)'
864
+ s = re.sub(r"\(\s*\d+\s*\)$", "", s)
865
+
866
+ # remove size (with or without parens) anywhere
867
+ s = _SIZE_RE.sub("", s)
868
+
869
+ # remove exposures like '0s', '0.5s', ' 45 s' (even if preceded by a dot)
870
+ s = _EXP_RE.sub("", s)
871
+
872
+ # remove any _#x tokens
873
+ s = _RX_RE.sub("", s)
874
+
875
+ # collapse whitespace/underscores and sanitize
876
+ s = re.sub(r"[\s]+", "_", s)
877
+ s = _sanitize_token(s)
878
+ return s or "output"
879
+
880
+ def _canonical_out_name_prefix(base: str, r: int, size: str | None,
881
+ exposure_secs: str | None, tag: str = "MFDeconv") -> str:
882
+ parts = [_sanitize_token(tag), _sanitize_token(base)]
883
+ if size:
884
+ parts.append(_sanitize_token(size))
885
+ if exposure_secs:
886
+ parts.append(_sanitize_token(exposure_secs))
887
+ if int(max(1, r)) > 1:
888
+ parts.append(f"{int(r)}x")
889
+ return "_".join(parts)
890
+
891
+ def _sr_out_path(out_path: str, r: int) -> Path:
892
+ """
893
+ Build: MFDeconv_<base>[_<HxW>][_<secs>s][_2x], preserving REAL extensions.
894
+ """
895
+ p = Path(out_path)
896
+ body, real_ext = _split_known_exts(p)
897
+
898
+ # harvest metadata from the whole body (not Path.stem)
899
+ size = _extract_size(body)
900
+ ex_sec = _extract_exposure_secs(body)
901
+
902
+ # clean base
903
+ base = _strip_metadata_from_base(body)
904
+
905
+ new_stem = _canonical_out_name_prefix(base, r=int(max(1, r)), size=size, exposure_secs=ex_sec, tag="MFDeconv")
906
+ return p.with_name(f"{new_stem}{real_ext}")
907
+
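A worked example of the naming scheme implemented by the helpers above (the input filename is hypothetical):

```python
out = _sr_out_path("stacked_0s (1310x880)_MFDeconv.fit", r=2)
# size "1310x880" and exposure "0s" are harvested from the body, the stale
# "_MFDeconv" marker is stripped, and the 2x factor is appended:
print(out.name)   # MFDeconv_stacked_1310x880_0s_2x.fit
```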
908
+ def _nonclobber_path(path: str) -> str:
909
+ """
910
+ Version collisions as '_v2', '_v3', ... (no spaces/parentheses).
911
+ """
912
+ p = Path(path)
913
+ if not p.exists():
914
+ return str(p)
915
+
916
+ # keep the true extension(s)
917
+ body, real_ext = _split_known_exts(p)
918
+
919
+ # if already has _vN, bump it
920
+ m = re.search(r"(.*)_v(\d+)$", body)
921
+ if m:
922
+ base = m.group(1); n = int(m.group(2)) + 1
923
+ else:
924
+ base = body; n = 2
925
+
926
+ while True:
927
+ candidate = p.with_name(f"{base}_v{n}{real_ext}")
928
+ if not candidate.exists():
929
+ return str(candidate)
930
+ n += 1
931
+
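And the collision behaviour of `_nonclobber_path`, sketched under the assumption that the first output name already exists on disk:

```python
# If "MFDeconv_stacked_1310x880_0s_2x.fit" already exists, versions are appended as
# _v2, _v3, ... while the real (possibly compound) extension is preserved.
path = _nonclobber_path("MFDeconv_stacked_1310x880_0s_2x.fit")
# -> "MFDeconv_stacked_1310x880_0s_2x_v2.fit"   (when the original exists)
```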
932
+ def _iter_folder(basefile: str) -> str:
933
+ d, fname = os.path.split(basefile)
934
+ root, ext = os.path.splitext(fname)
935
+ tgt = os.path.join(d, f"{root}.iters")
936
+ if not os.path.exists(tgt):
937
+ try:
938
+ os.makedirs(tgt, exist_ok=True)
939
+ except Exception:
940
+ # last resort: suffix (n)
941
+ n = 1
942
+ while True:
943
+ cand = os.path.join(d, f"{root}.iters ({n})")
944
+ try:
945
+ os.makedirs(cand, exist_ok=True)
946
+ return cand
947
+ except Exception:
948
+ n += 1
949
+ return tgt
950
+
951
+ def _save_iter_image(arr, hdr_base, folder, tag, color_mode):
952
+ """
953
+ arr: numpy array (H,W) or (C,H,W) float32
954
+ tag: 'seed' or 'iter_###'
955
+ """
956
+ if arr.ndim == 3 and arr.shape[0] not in (1, 3) and arr.shape[-1] in (1, 3):
957
+ arr = np.moveaxis(arr, -1, 0)
958
+ if arr.ndim == 3 and arr.shape[0] == 1:
959
+ arr = arr[0]
960
+
961
+ hdr = fits.Header(hdr_base) if isinstance(hdr_base, fits.Header) else fits.Header()
962
+ hdr['MF_PART'] = (str(tag), 'MFDeconv intermediate (seed/iter)')
963
+ hdr['MF_COLOR'] = (str(color_mode), 'Color mode used')
964
+ path = os.path.join(folder, f"{tag}.fit")
965
+ # overwrite allowed inside the dedicated folder
966
+ fits.PrimaryHDU(data=arr.astype(np.float32, copy=False), header=hdr).writeto(path, overwrite=True)
967
+ return path
968
+
969
+
970
+ def _process_gui_events_safely():
971
+ app = QApplication.instance()
972
+ if app and QThread.currentThread() is app.thread():
973
+ app.processEvents()
974
+
975
+ EPS = 1e-6
976
+
977
+ # -----------------------------
978
+ # Helpers: image prep / shapes
979
+ # -----------------------------
980
+
981
+ # new: lightweight loader that yields one frame at a time
982
+
983
+ def _to_luma_local(a: np.ndarray) -> np.ndarray:
984
+ a = np.asarray(a, dtype=np.float32)
985
+ if a.ndim == 2:
986
+ return a
987
+ if a.ndim == 3:
988
+ # mono fast paths
989
+ if a.shape[-1] == 1: # HWC mono
990
+ return a[..., 0].astype(np.float32, copy=False)
991
+ if a.shape[0] == 1: # CHW mono
992
+ return a[0].astype(np.float32, copy=False)
993
+ # RGB
994
+ if a.shape[-1] == 3: # HWC RGB
995
+ r, g, b = a[..., 0], a[..., 1], a[..., 2]
996
+ return (0.2126*r + 0.7152*g + 0.0722*b).astype(np.float32, copy=False)
997
+ if a.shape[0] == 3: # CHW RGB
998
+ r, g, b = a[0], a[1], a[2]
999
+ return (0.2126*r + 0.7152*g + 0.0722*b).astype(np.float32, copy=False)
1000
+ # fallback: average last axis
1001
+ return a.mean(axis=-1).astype(np.float32, copy=False)
1002
+
1003
+ def _normalize_layout_single(a, color_mode):
1004
+ """
1005
+ Coerce to:
1006
+ - 'luma' -> (H, W)
1007
+ - 'perchannel' -> (C, H, W); mono stays (1,H,W), RGB → (3,H,W)
1008
+ Accepts (H,W), (H,W,3), or (3,H,W).
1009
+ """
1010
+ a = np.asarray(a, dtype=np.float32)
1011
+
1012
+ if color_mode == "luma":
1013
+ return _to_luma_local(a) # returns (H,W)
1014
+
1015
+ # perchannel
1016
+ if a.ndim == 2:
1017
+ return a[None, ...] # (1,H,W) ← keep mono as 1 channel
1018
+ if a.ndim == 3 and a.shape[-1] == 3:
1019
+ return np.moveaxis(a, -1, 0) # (3,H,W)
1020
+ if a.ndim == 3 and a.shape[0] in (1, 3):
1021
+ return a # already (1,H,W) or (3,H,W)
1022
+ # fallback: average any weird shape into luma 1×H×W
1023
+ l = _to_luma_local(a)
1024
+ return l[None, ...]
1025
+
1026
+
1027
+ def _normalize_layout_batch(arrs, color_mode):
1028
+ return [_normalize_layout_single(a, color_mode) for a in arrs]
1029
+
1030
+ def _common_hw(data_list):
1031
+ """Return minimal (H,W) across items; items are (H,W) or (C,H,W)."""
1032
+ Hs, Ws = [], []
1033
+ for a in data_list:
1034
+ if a.ndim == 2:
1035
+ H, W = a.shape
1036
+ else:
1037
+ _, H, W = a.shape
1038
+ Hs.append(H); Ws.append(W)
1039
+ return int(min(Hs)), int(min(Ws))
1040
+
1041
+ def _center_crop(arr, Ht, Wt):
1042
+ """Center-crop arr (H,W) or (C,H,W) to (Ht,Wt)."""
1043
+ if arr.ndim == 2:
1044
+ H, W = arr.shape
1045
+ if H == Ht and W == Wt:
1046
+ return arr
1047
+ y0 = max(0, (H - Ht) // 2)
1048
+ x0 = max(0, (W - Wt) // 2)
1049
+ return arr[y0:y0+Ht, x0:x0+Wt]
1050
+ else:
1051
+ C, H, W = arr.shape
1052
+ if H == Ht and W == Wt:
1053
+ return arr
1054
+ y0 = max(0, (H - Ht) // 2)
1055
+ x0 = max(0, (W - Wt) // 2)
1056
+ return arr[:, y0:y0+Ht, x0:x0+Wt]
1057
+
1058
+ def _sanitize_numeric(a):
1059
+ """Replace NaN/Inf, clip negatives, make contiguous float32."""
1060
+ a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
1061
+ a = np.clip(a, 0.0, None).astype(np.float32, copy=False)
1062
+ return np.ascontiguousarray(a)
1063
+
1064
+ # -----------------------------
1065
+ # PSF utilities
1066
+ # -----------------------------
1067
+
1068
+ def _gaussian_psf(fwhm_px: float, ksize: int) -> np.ndarray:
1069
+ sigma = max(fwhm_px, 1.0) / 2.3548
1070
+ r = (ksize - 1) / 2
1071
+ y, x = np.mgrid[-r:r+1, -r:r+1]
1072
+ g = np.exp(-(x*x + y*y) / (2*sigma*sigma))
1073
+ g /= (np.sum(g) + EPS)
1074
+ return g.astype(np.float32, copy=False)
1075
+
1076
+ def _estimate_fwhm_from_header(hdr) -> float:
1077
+ for key in ("FWHM", "FWHM_PIX", "PSF_FWHM"):
1078
+ if key in hdr:
1079
+ try:
1080
+ val = float(hdr[key])
1081
+ if np.isfinite(val) and val > 0:
1082
+ return val
1083
+ except Exception:
1084
+ pass
1085
+ return float("nan")
1086
+
1087
+ def _estimate_fwhm_from_image(arr) -> float:
1088
+ """Fast FWHM estimate from SEP 'a','b' parameters (≈ sigma in px)."""
1089
+ if sep is None:
1090
+ return float("nan")
1091
+ try:
1092
+ img = _contig(_to_luma_local(arr)) # ← ensure C-contig float32
1093
+ bkg = sep.Background(img)
1094
+ data = _contig(img - bkg.back()) # ← ensure data is C-contig
1095
+ try:
1096
+ err = bkg.globalrms
1097
+ except Exception:
1098
+ err = float(np.median(bkg.rms()))
1099
+ sources = sep.extract(data, 6.0, err=err)
1100
+ if sources is None or len(sources) == 0:
1101
+ return float("nan")
1102
+ a = np.asarray(sources["a"], dtype=np.float32)
1103
+ b = np.asarray(sources["b"], dtype=np.float32)
1104
+ ab = (a + b) * 0.5
1105
+ sigma = float(np.median(ab[np.isfinite(ab) & (ab > 0)]))
1106
+ if not np.isfinite(sigma) or sigma <= 0:
1107
+ return float("nan")
1108
+ return 2.3548 * sigma
1109
+ except Exception:
1110
+ return float("nan")
1111
+
1112
+ def _auto_ksize_from_fwhm(fwhm_px: float, kmin: int = 11, kmax: int = 51) -> int:
1113
+ """
1114
+ Choose odd kernel size to cover about ±4σ.
1115
+ """
1116
+ sigma = max(fwhm_px, 1.0) / 2.3548
1117
+ r = int(math.ceil(4.0 * sigma))
1118
+ k = 2 * r + 1
1119
+ k = max(kmin, min(k, kmax))
1120
+ if (k % 2) == 0:
1121
+ k += 1
1122
+ return k
1123
+
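Two worked examples of the ±4σ sizing rule above:

```python
_auto_ksize_from_fwhm(2.5)   # σ ≈ 1.06 px → r = ceil(4σ) = 5 → k = 11 (the kmin floor)
_auto_ksize_from_fwhm(6.0)   # σ ≈ 2.55 px → r = 11 → k = 23
```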
1124
+ def _flip_kernel(psf):
1125
+ # PyTorch dislikes negative strides; make it contiguous.
1126
+ return np.flip(np.flip(psf, -1), -2).copy()
1127
+
1128
+ def _conv_same_np(img, psf):
1129
+ # img: (H,W) or (C,H,W) numpy
1130
+ import numpy.fft as fft
1131
+ def fftconv2(a, k):
1132
+ H, W = a.shape[-2:]
1133
+ kh, kw = k.shape
1134
+ pad_h, pad_w = H + kh - 1, W + kw - 1
1135
+ A = fft.rfftn(a, s=(pad_h, pad_w), axes=(-2, -1))
1136
+ K = fft.rfftn(k, s=(pad_h, pad_w), axes=(-2, -1))
1137
+ Y = A * K
1138
+ y = fft.irfftn(Y, s=(pad_h, pad_w), axes=(-2, -1))
1139
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1140
+ return y[..., sh:sh+H, sw:sw+W]
1141
+ if img.ndim == 2:
1142
+ return fftconv2(img[None], psf)[0]
1143
+ else:
1144
+ return np.stack([fftconv2(img[c:c+1], psf)[0] for c in range(img.shape[0])], axis=0)
1145
+
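For reference, `_conv_same_np` is plain linear convolution via the FFT convolution theorem: both operands are zero-padded to \((H + k_h - 1,\ W + k_w - 1)\), multiplied in the frequency domain, and the result \(y = \mathcal{F}^{-1}\{\mathcal{F}(a)\cdot\mathcal{F}(k)\}\) is cropped back to "same" size with offset \(((k_h-1)/2,\ (k_w-1)/2)\).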
1146
+ def _normalize_psf(psf):
1147
+ psf = np.maximum(psf, 0.0).astype(np.float32, copy=False)
1148
+ s = float(psf.sum())
1149
+ if not np.isfinite(s) or s <= EPS:
1150
+ return psf
1151
+ return (psf / s).astype(np.float32, copy=False)
1152
+
1153
+ def _soften_psf(psf, sigma_px=0.25):
1154
+ # optional tiny Gaussian soften to reduce ringing; sigma<=0 disables
1155
+ if sigma_px <= 0:
1156
+ return psf
1157
+ r = int(max(1, round(3 * sigma_px)))
1158
+ y, x = np.mgrid[-r:r+1, -r:r+1]
1159
+ g = np.exp(-(x*x + y*y) / (2 * sigma_px * sigma_px)).astype(np.float32)
1160
+ g /= g.sum() + EPS
1161
+ return _conv_same_np(psf[None], g)[0]
1162
+
1163
+ def _psf_fwhm_px(psf: np.ndarray) -> float:
1164
+ """Approximate FWHM (pixels) from second moments of a normalized kernel."""
1165
+ psf = np.maximum(psf, 0).astype(np.float32, copy=False)
1166
+ s = float(psf.sum())
1167
+ if s <= EPS:
1168
+ return float("nan")
1169
+ k = psf.shape[0]
1170
+ y, x = np.mgrid[:k, :k].astype(np.float32)
1171
+ cy = float((psf * y).sum() / s)
1172
+ cx = float((psf * x).sum() / s)
1173
+ var_y = float((psf * (y - cy) ** 2).sum() / s)
1174
+ var_x = float((psf * (x - cx) ** 2).sum() / s)
1175
+ sigma = math.sqrt(max(0.0, 0.5 * (var_x + var_y)))
1176
+ return 2.3548 * sigma # FWHM≈2.355σ
1177
+
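The constant 2.3548 used throughout this file is the Gaussian FWHM-to-sigma conversion factor, \(\mathrm{FWHM} = 2\sqrt{2\ln 2}\,\sigma \approx 2.3548\,\sigma\); `_psf_fwhm_px` therefore reports 2.3548 times the moment-based σ estimate of the kernel.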
1178
+ STAR_MASK_MAXSIDE = 2048
1179
+ STAR_MASK_MAXOBJS = 2000 # cap number of objects
1180
+ VARMAP_SAMPLE_STRIDE = 8 # (kept for compat; currently unused internally)
1181
+ THRESHOLD_SIGMA = 2.0
1182
+ KEEP_FLOOR = 0.20
1183
+ GROW_PX = 8
1184
+ MAX_STAR_RADIUS = 16
1185
+ SOFT_SIGMA = 2.0
1186
+ ELLIPSE_SCALE = 1.2
1187
+
1188
+ def _sep_background_precompute(img_2d: np.ndarray, bw: int = 64, bh: int = 64):
1189
+ """One-time SEP background build; returns (sky_map, rms_map, err_scalar)."""
1190
+ if sep is None:
1191
+ # robust fallback
1192
+ med = float(np.median(img_2d))
1193
+ mad = float(np.median(np.abs(img_2d - med))) + 1e-6
1194
+ sky = np.full_like(img_2d, med, dtype=np.float32)
1195
+ rmsm = np.full_like(img_2d, 1.4826 * mad, dtype=np.float32)
1196
+ return sky, rmsm, float(np.median(rmsm))
1197
+
1198
+ a = np.ascontiguousarray(img_2d.astype(np.float32))
1199
+ b = sep.Background(a, bw=int(bw), bh=int(bh), fw=3, fh=3)
1200
+ sky = np.asarray(b.back(), dtype=np.float32)
1201
+ try:
1202
+ rmsm = np.asarray(b.rms(), dtype=np.float32)
1203
+ err = float(b.globalrms)
1204
+ except Exception:
1205
+ rmsm = np.full_like(a, float(np.median(b.rms())), dtype=np.float32)
1206
+ err = float(np.median(rmsm))
1207
+ return sky, rmsm, err
1208
+
1209
+
1210
+ def _star_mask_from_precomputed(
1211
+ img_2d: np.ndarray,
1212
+ sky_map: np.ndarray,
1213
+ err_scalar: float,
1214
+ *,
1215
+ thresh_sigma: float,
1216
+ max_objs: int,
1217
+ grow_px: int,
1218
+ ellipse_scale: float,
1219
+ soft_sigma: float,
1220
+ max_radius_px: int,
1221
+ keep_floor: float,
1222
+ max_side: int,
1223
+ status_cb=lambda s: None
1224
+ ) -> np.ndarray:
1225
+ """
1226
+ Build a KEEP weight map using a *downscaled detection / full-res draw* path.
1227
+ **Never writes to img_2d**; all drawing happens in a fresh `mask_u8`.
1228
+ """
1229
+ # Optional OpenCV fast path
1230
+ try:
1231
+ import cv2 as _cv2
1232
+ _HAS_CV2 = True
1233
+ except Exception:
1234
+ _HAS_CV2 = False
1235
+ _cv2 = None # type: ignore
1236
+
1237
+ H, W = map(int, img_2d.shape)
1238
+
1239
+ # Residual for detection (contiguous, separate buffer)
1240
+ data_sub = np.ascontiguousarray((img_2d - sky_map).astype(np.float32))
1241
+
1242
+ # Downscale *detection only* to speed up, never the draw step
1243
+ det = data_sub
1244
+ scale = 1.0
1245
+ if max_side and max(H, W) > int(max_side):
1246
+ scale = float(max(H, W)) / float(max_side)
1247
+ if _HAS_CV2:
1248
+ det = _cv2.resize(
1249
+ det,
1250
+ (max(1, int(round(W / scale))), max(1, int(round(H / scale)))),
1251
+ interpolation=_cv2.INTER_AREA
1252
+ )
1253
+ else:
1254
+ s = int(max(1, round(scale)))
1255
+ det = det[:(H // s) * s, :(W // s) * s].reshape(H // s, s, W // s, s).mean(axis=(1, 3))
1256
+ scale = float(s)
1257
+
1258
+ # Threshold ladder
1259
+ thresholds = [thresh_sigma, thresh_sigma*2, thresh_sigma*4,
1260
+ thresh_sigma*8, thresh_sigma*16]
1261
+ objs = None; used = float("nan"); raw = 0
1262
+ for t in thresholds:
1263
+ cand = sep.extract(det, thresh=float(t), err=float(err_scalar))
1264
+ n = 0 if cand is None else len(cand)
1265
+ if n == 0: continue
1266
+ if n > max_objs*12: continue
1267
+ objs, raw, used = cand, n, float(t)
1268
+ break
1269
+
1270
+ if objs is None or len(objs) == 0:
1271
+ try:
1272
+ cand = sep.extract(det, thresh=thresholds[-1], err=float(err_scalar), minarea=9)
1273
+ except Exception:
1274
+ cand = None
1275
+ if cand is None or len(cand) == 0:
1276
+ status_cb("Star mask: no sources found (mask disabled for this frame).")
1277
+ return np.ones((H, W), dtype=np.float32, order="C")
1278
+ objs, raw, used = cand, len(cand), float(thresholds[-1])
1279
+
1280
+ # Brightest max_objs
1281
+ if "flux" in objs.dtype.names:
1282
+ idx = np.argsort(objs["flux"])[-int(max_objs):]
1283
+ objs = objs[idx]
1284
+ else:
1285
+ objs = objs[:int(max_objs)]
1286
+ kept = len(objs)
1287
+
1288
+ # ---- draw back on full-res into a brand-new buffer ----
1289
+ mask_u8 = np.zeros((H, W), dtype=np.uint8, order="C")
1290
+ s_back = float(scale)
1291
+ MR = int(max(1, max_radius_px))
1292
+ G = int(max(0, grow_px))
1293
+ ES = float(max(0.1, ellipse_scale))
1294
+
1295
+ drawn = 0
1296
+ if _HAS_CV2:
1297
+ for o in objs:
1298
+ x = int(round(float(o["x"]) * s_back))
1299
+ y = int(round(float(o["y"]) * s_back))
1300
+ if not (0 <= x < W and 0 <= y < H):
1301
+ continue
1302
+ a = float(o["a"]) * s_back
1303
+ b = float(o["b"]) * s_back
1304
+ r = int(math.ceil(ES * max(a, b)))
1305
+ r = min(max(r, 0) + G, MR)
1306
+ if r <= 0:
1307
+ continue
1308
+ _cv2.circle(mask_u8, (x, y), r, 1, thickness=-1, lineType=_cv2.LINE_8)
1309
+ drawn += 1
1310
+ else:
1311
+ for o in objs:
1312
+ x = int(round(float(o["x"]) * s_back))
1313
+ y = int(round(float(o["y"]) * s_back))
1314
+ if not (0 <= x < W and 0 <= y < H):
1315
+ continue
1316
+ a = float(o["a"]) * s_back
1317
+ b = float(o["b"]) * s_back
1318
+ r = int(math.ceil(ES * max(a, b)))
1319
+ r = min(max(r, 0) + G, MR)
1320
+ if r <= 0:
1321
+ continue
1322
+ y0 = max(0, y - r); y1 = min(H, y + r + 1)
1323
+ x0 = max(0, x - r); x1 = min(W, x + r + 1)
1324
+ yy, xx = np.ogrid[y0:y1, x0:x1]
1325
+ disk = (yy - y)*(yy - y) + (xx - x)*(xx - x) <= r*r
1326
+ mask_u8[y0:y1, x0:x1][disk] = 1
1327
+ drawn += 1
1328
+
1329
+ # Feather + convert to keep weights
1330
+ m = mask_u8.astype(np.float32, copy=False)
1331
+ if soft_sigma > 0:
1332
+ try:
1333
+ if _HAS_CV2:
1334
+ k = int(max(1, int(round(3*soft_sigma)))*2 + 1)
1335
+ m = _cv2.GaussianBlur(m, (k, k), float(soft_sigma),
1336
+ borderType=_cv2.BORDER_REFLECT)
1337
+ else:
1338
+ from scipy.ndimage import gaussian_filter
1339
+ m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
1340
+ except Exception:
1341
+ pass
1342
+ np.clip(m, 0.0, 1.0, out=m)
1343
+
1344
+ keep = 1.0 - m
1345
+ kf = float(max(0.0, min(0.99, keep_floor)))
1346
+ keep = kf + (1.0 - kf) * keep
1347
+ np.clip(keep, 0.0, 1.0, out=keep)
1348
+
1349
+ status_cb(f"Star mask: thresh={used:.3g} | detected={raw} | kept={kept} | drawn={drawn} | keep_floor={keep_floor}")
1350
+ return np.ascontiguousarray(keep, dtype=np.float32)
1351
+
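The keep-floor mapping at the end (keep = kf + (1 - kf) * (1 - m)) guarantees that even fully masked star cores retain a minimum weight. A tiny standalone check of that mapping, using illustrative values only:

import numpy as np

m = np.array([0.0, 0.5, 1.0], dtype=np.float32)   # feathered star-mask values
kf = 0.15                                          # keep_floor
keep = kf + (1.0 - kf) * (1.0 - m)
print(keep)   # [1.0, 0.575, 0.15]: background keeps full weight, star cores keep kf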
1352
+
1353
+ def _ensure_scratch_dir(scratch_dir: str | None) -> str:
1354
+ """Ensure a writable scratch directory exists; default to system temp."""
1355
+ if scratch_dir is None or not isinstance(scratch_dir, str) or not scratch_dir.strip():
1356
+ scratch_dir = tempfile.gettempdir()
1357
+ os.makedirs(scratch_dir, exist_ok=True)
1358
+ return scratch_dir
1359
+
1360
+ def _mm_unique_path(scratch_dir: str, tag: str, ext: str = ".mm") -> str:
1361
+ """Return a unique file path (closed fd) for a memmap file."""
1362
+ fd, path = tempfile.mkstemp(prefix=f"sas_{tag}_", suffix=ext, dir=scratch_dir)
1363
+ try:
1364
+ os.close(fd)
1365
+ except Exception:
1366
+ pass
1367
+ return path
1368
+
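A minimal sketch of how these two helpers compose with numpy.memmap; the tag and shape here are arbitrary illustration values:

import numpy as np

scratch = _ensure_scratch_dir(None)             # falls back to the system temp dir
path = _mm_unique_path(scratch, tag="demo")     # unique path, file descriptor already closed

mm = np.memmap(path, mode="w+", dtype=np.float32, shape=(64, 64))
mm[:] = 0.0
mm.flush()
del mm                                          # drop the handle; the file stays on disk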
1369
+ def _variance_map_from_precomputed(
1370
+ img_2d: np.ndarray,
1371
+ sky_map: np.ndarray,
1372
+ rms_map: np.ndarray,
1373
+ hdr,
1374
+ *,
1375
+ smooth_sigma: float,
1376
+ floor: float,
1377
+ status_cb=lambda s: None
1378
+ ) -> np.ndarray:
1379
+ img = np.clip(np.asarray(img_2d, dtype=np.float32), 0.0, None)
1380
+ var_bg_dn2 = np.maximum(rms_map, 1e-6) ** 2
1381
+ obj_dn = np.clip(img - sky_map, 0.0, None)
1382
+
1383
+ gain = None
1384
+ for k in ("EGAIN", "GAIN", "GAIN1", "GAIN2"):
1385
+ if k in hdr:
1386
+ try:
1387
+ g = float(hdr[k]); gain = g if (np.isfinite(g) and g > 0) else None
1388
+ if gain is not None: break
1389
+ except Exception as e:
1390
+ import logging
1391
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
1392
+
1393
+ if gain is not None:
1394
+ a_shot = 1.0 / gain
1395
+ else:
1396
+ sky_med = float(np.median(sky_map))
1397
+ varbg_med = float(np.median(var_bg_dn2))
1398
+ a_shot = (varbg_med / sky_med) if sky_med > 1e-6 else 0.0
1399
+ a_shot = float(np.clip(a_shot, 0.0, 10.0))
1400
+
1401
+ v = var_bg_dn2 + a_shot * obj_dn
1402
+ if smooth_sigma > 0:
1403
+ try:
1404
+ import cv2 as _cv2
1405
+ k = int(max(1, int(round(3*smooth_sigma)))*2 + 1)
1406
+ v = _cv2.GaussianBlur(v, (k,k), float(smooth_sigma), borderType=_cv2.BORDER_REFLECT)
1407
+ except Exception:
1408
+ try:
1409
+ from scipy.ndimage import gaussian_filter
1410
+ v = gaussian_filter(v, sigma=float(smooth_sigma), mode="reflect")
1411
+ except Exception:
1412
+ pass
1413
+
1414
+ np.clip(v, float(floor), None, out=v)
1415
+ try:
1416
+ rms_med = float(np.median(np.sqrt(var_bg_dn2)))
1417
+ status_cb(f"Variance map: sky_med={float(np.median(sky_map)):.3g} DN | rms_med={rms_med:.3g} DN | smooth_sigma={smooth_sigma} | floor={floor}")
1418
+ except Exception:
1419
+ pass
1420
+ return v.astype(np.float32, copy=False)
1421
+
1422
+
1423
+ def _variance_map_from_precomputed_memmap(
1424
+ luma: np.ndarray,
1425
+ sky_map: np.ndarray,
1426
+ rms_map: np.ndarray,
1427
+ hdr,
1428
+ *,
1429
+ smooth_sigma: float = 1.0,
1430
+ floor: float = 1e-8,
1431
+ tile_hw: tuple[int, int] = (512, 512),
1432
+ scratch_dir: str | None = None,
1433
+ tag: str = "varmap",
1434
+ status_cb=lambda s: None,
1435
+ progress_cb=lambda f, m="": None,
1436
+ ) -> str:
1437
+ """
1438
+ Compute a per-pixel variance map exactly like _variance_map_from_precomputed,
1439
+ but stream to a disk-backed memmap file (RAM bounded). Returns the file path.
1440
+ """
1441
+ luma = np.asarray(luma, dtype=np.float32)
1442
+ sky_map = np.asarray(sky_map, dtype=np.float32)
1443
+ rms_map = np.asarray(rms_map, dtype=np.float32)
1444
+
1445
+ H, W = int(luma.shape[0]), int(luma.shape[1])
1446
+ th, tw = int(tile_hw[0]), int(tile_hw[1])
1447
+
1448
+ scratch_dir = _ensure_scratch_dir(scratch_dir)
1449
+ var_path = _mm_unique_path(scratch_dir, tag, ext=".mm")
1450
+
1451
+ # create writeable memmap
1452
+ var_mm = np.memmap(var_path, mode="w+", dtype=np.float32, shape=(H, W))
1453
+
1454
+ tiles = [(y, min(y + th, H), x, min(x + tw, W)) for y in range(0, H, th) for x in range(0, W, tw)]
1455
+ total = len(tiles)
1456
+
1457
+ for ti, (y0, y1, x0, x1) in enumerate(tiles, start=1):
1458
+ # compute tile using the existing exact routine
1459
+ v_tile = _variance_map_from_precomputed(
1460
+ luma[y0:y1, x0:x1],
1461
+ sky_map[y0:y1, x0:x1],
1462
+ rms_map[y0:y1, x0:x1],
1463
+ hdr,
1464
+ smooth_sigma=smooth_sigma,
1465
+ floor=floor,
1466
+ status_cb=lambda _s: None
1467
+ )
1468
+ var_mm[y0:y1, x0:x1] = v_tile
1469
+
1470
+ # free tile buffer promptly to avoid resident growth
1471
+ del v_tile
1472
+
1473
+ # periodic flush + progress
1474
+ if (ti & 7) == 0 or ti == total:
1475
+ try: var_mm.flush()
1476
+ except Exception as e:
1477
+ import logging
1478
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
1479
+ try:
1480
+ progress_cb(ti / float(total), f"varmap tiles {ti}/{total}")
1481
+ except Exception:
1482
+ pass
1483
+
1484
+ try: var_mm.flush()
1485
+ except Exception as e:
1486
+ import logging
1487
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
1488
+ # drop the handle so the OS can release the mapping
1489
+ del var_mm
1490
+
1491
+ status_cb(f"Variance map written (memmap): {var_path} ({H}x{W})")
1492
+ return var_path
1493
+
1494
+
1495
+
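A hedged usage sketch of the streaming variant with synthetic sky/RMS maps and an empty header; real callers pass SEP background products and a frame header:

import numpy as np
from astropy.io import fits

H, W = 512, 512
rng = np.random.default_rng(0)
luma = rng.normal(100.0, 5.0, (H, W)).astype(np.float32)
sky = np.full((H, W), 100.0, np.float32)
rms = np.full((H, W), 5.0, np.float32)

var_path = _variance_map_from_precomputed_memmap(
    luma, sky, rms, fits.Header(),
    smooth_sigma=1.0, floor=1e-8, tile_hw=(256, 256), status_cb=print,
)
var = np.memmap(var_path, mode="r", dtype=np.float32, shape=(H, W))
print(float(var.mean()))   # close to rms**2 here, plus a small shot-noise term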
1496
+ # -----------------------------
1497
+ # Robust weighting (Huber)
1498
+ # -----------------------------
1499
+
1500
+ def _estimate_scalar_variance_t(r):
1501
+ # r: tensor on device
1502
+ med = torch.median(r)
1503
+ mad = torch.median(torch.abs(r - med)) + 1e-6
1504
+ return (1.4826 * mad) ** 2
1505
+
1506
+ def _estimate_scalar_variance(a):
1507
+ med = np.median(a)
1508
+ mad = np.median(np.abs(a - med)) + 1e-6
1509
+ return float((1.4826 * mad) ** 2)
1510
+
1511
+ def _weight_map(y, pred, huber_delta, var_map=None, mask=None):
1512
+ """
1513
+ Robust per-pixel weights for the MM update.
1514
+ W = [psi(r)/r] * 1/(var + eps) * mask
1515
+ If huber_delta < 0, delta = (-huber_delta) * RMS(residual) (auto).
1516
+ var_map: per-pixel variance (2D); if None, fall back to robust scalar via MAD.
1517
+ mask: 2D {0,1} validity; if None, treat as ones.
1518
+ """
1519
+ r = y - pred
1520
+ eps = EPS
1521
+
1522
+ # resolve Huber delta
1523
+ if huber_delta < 0:
1524
+ if TORCH_OK and isinstance(r, torch.Tensor):
1525
+ med = torch.median(r)
1526
+ mad = torch.median(torch.abs(r - med)) + 1e-6
1527
+ rms = 1.4826 * mad
1528
+ delta = (-huber_delta) * torch.clamp(rms, min=1e-6)
1529
+ else:
1530
+ med = np.median(r)
1531
+ mad = np.median(np.abs(r - med)) + 1e-6
1532
+ rms = 1.4826 * mad
1533
+ delta = (-huber_delta) * max(rms, 1e-6)
1534
+ else:
1535
+ delta = huber_delta
1536
+
1537
+ # psi(r)/r
1538
+ if TORCH_OK and isinstance(r, torch.Tensor):
1539
+ absr = torch.abs(r)
1540
+ if float(delta) > 0:
1541
+ psi_over_r = torch.where(absr <= delta, torch.ones_like(r), delta / (absr + eps))
1542
+ else:
1543
+ psi_over_r = torch.ones_like(r)
1544
+ if var_map is None:
1545
+ v = _estimate_scalar_variance_t(r)
1546
+ else:
1547
+ v = var_map
1548
+ if v.ndim == 2 and r.ndim == 3:
1549
+ v = v[None, ...] # broadcast over channels
1550
+ w = psi_over_r / (v + eps)
1551
+ if mask is not None:
1552
+ m = mask if mask.ndim == w.ndim else (mask[None, ...] if w.ndim == 3 else mask)
1553
+ w = w * m
1554
+ return w
1555
+ else:
1556
+ absr = np.abs(r)
1557
+ if float(delta) > 0:
1558
+ psi_over_r = np.where(absr <= delta, 1.0, delta / (absr + eps)).astype(np.float32)
1559
+ else:
1560
+ psi_over_r = np.ones_like(r, dtype=np.float32)
1561
+ if var_map is None:
1562
+ v = _estimate_scalar_variance(r)
1563
+ else:
1564
+ v = var_map
1565
+ if v.ndim == 2 and r.ndim == 3:
1566
+ v = v[None, ...]
1567
+ w = psi_over_r / (v + eps)
1568
+ if mask is not None:
1569
+ m = mask if mask.ndim == w.ndim else (mask[None, ...] if w.ndim == 3 else mask)
1570
+ w = w * m
1571
+ return w
1572
+
1573
+
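The docstring above describes W = [psi(r)/r] * 1/(var + eps) * mask with the Huber influence function psi. A self-contained numeric check of just the psi(r)/r factor, written independently of the helpers above:

import numpy as np

def huber_psi_over_r(r, delta):
    # 1 inside the linear zone (|r| <= delta), delta/|r| outside it
    absr = np.abs(r)
    return np.where(absr <= delta, 1.0, delta / (absr + 1e-12))

r = np.array([0.1, 1.0, 5.0, -20.0])
print(huber_psi_over_r(r, delta=1.0))   # [1.0, 1.0, 0.2, 0.05]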
1574
+ # -----------------------------
1575
+ # Torch / conv
1576
+ # -----------------------------
1577
+
1578
+ def _fftshape_same(H, W, kh, kw):
1579
+ return H + kh - 1, W + kw - 1
1580
+
1581
+ # ---------- Torch FFT helpers (carry padH/padW alongside each kernel FFT) ----------
1582
+ def _precompute_torch_psf_ffts(psfs, flip_psf, H, W, device, dtype):
1583
+ tfft = torch.fft
1584
+ psf_fft, psfT_fft = [], []
1585
+ for k, kT in zip(psfs, flip_psf):
1586
+ kh, kw = k.shape
1587
+ padH, padW = _fftshape_same(H, W, kh, kw)
1588
+
1589
+ # shift the small kernels to the origin, then FFT into padded size
1590
+ k_small = torch.as_tensor(np.fft.ifftshift(k), device=device, dtype=dtype)
1591
+ kT_small = torch.as_tensor(np.fft.ifftshift(kT), device=device, dtype=dtype)
1592
+
1593
+ Kf = tfft.rfftn(k_small, s=(padH, padW))
1594
+ KTf = tfft.rfftn(kT_small, s=(padH, padW))
1595
+
1596
+ psf_fft.append((Kf, padH, padW, kh, kw))
1597
+ psfT_fft.append((KTf, padH, padW, kh, kw))
1598
+ return psf_fft, psfT_fft
1599
+
1600
+
1601
+
1602
+ # ---------- NumPy FFT helpers ----------
1603
+ def _precompute_np_psf_ffts(psfs, flip_psf, H, W):
1604
+ import numpy.fft as fft
1605
+ meta, Kfs, KTfs = [], [], []
1606
+ for k, kT in zip(psfs, flip_psf):
1607
+ kh, kw = k.shape
1608
+ fftH, fftW = _fftshape_same(H, W, kh, kw)
1609
+ Kfs.append( fft.rfftn(np.fft.ifftshift(k), s=(fftH, fftW)) )
1610
+ KTfs.append(fft.rfftn(np.fft.ifftshift(kT), s=(fftH, fftW)) )
1611
+ meta.append((kh, kw, fftH, fftW))
1612
+ return Kfs, KTfs, meta
1613
+
1614
+ def _fft_conv_same_np(a, Kf, kh, kw, fftH, fftW, out):
1615
+ import numpy.fft as fft
1616
+ if a.ndim == 2:
1617
+ A = fft.rfftn(a, s=(fftH, fftW))
1618
+ y = fft.irfftn(A * Kf, s=(fftH, fftW))
1619
+ sh, sw = kh // 2, kw // 2
1620
+ out[...] = y[sh:sh+a.shape[0], sw:sw+a.shape[1]]
1621
+ return out
1622
+ else:
1623
+ C, H, W = a.shape
1624
+ acc = []
1625
+ for c in range(C):
1626
+ A = fft.rfftn(a[c], s=(fftH, fftW))
1627
+ y = fft.irfftn(A * Kf, s=(fftH, fftW))
1628
+ sh, sw = kh // 2, kw // 2
1629
+ acc.append(y[sh:sh+H, sw:sw+W])
1630
+ out[...] = np.stack(acc, 0)
1631
+ return out
1632
+
1633
+
1634
+
1635
+ def _torch_device():
1636
+ if TORCH_OK and (torch is not None):
1637
+ if hasattr(torch, "cuda") and torch.cuda.is_available():
1638
+ return torch.device("cuda")
1639
+ if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
1640
+ return torch.device("mps")
1641
+ # DirectML: we passed dml_device from outer scope; keep a module-global
1642
+ if globals().get("dml_ok", False) and globals().get("dml_device", None) is not None:
1643
+ return globals()["dml_device"]
1644
+ return torch.device("cpu")
1645
+
1646
+ def _to_t(x: np.ndarray):
1647
+ if not (TORCH_OK and (torch is not None)):
1648
+ raise RuntimeError("Torch path requested but torch is unavailable")
1649
+ device = _torch_device()
1650
+ t = torch.from_numpy(x)
1651
+ # DirectML wants explicit .to(device)
1652
+ return t.to(device, non_blocking=True) if str(device) != "cpu" else t
1653
+
1654
+ def _contig(x):
1655
+ return np.ascontiguousarray(x, dtype=np.float32)
1656
+
1657
+ def _conv_same_torch(img_t, psf_t):
1658
+ """
1659
+ img_t: torch tensor on DEVICE, (H,W) or (C,H,W)
1660
+ psf_t: torch tensor on DEVICE, (1,1,kh,kw) (single kernel)
1661
+ Pads with 'reflect' to avoid zero-padding ringing.
1662
+ """
1663
+ kh, kw = psf_t.shape[-2:]
1664
+ pad = (kw // 2, kw - kw // 2 - 1, # left, right
1665
+ kh // 2, kh - kh // 2 - 1) # top, bottom
1666
+
1667
+ if img_t.ndim == 2:
1668
+ x = img_t[None, None]
1669
+ x = torch.nn.functional.pad(x, pad, mode="reflect")
1670
+ y = torch.nn.functional.conv2d(x, psf_t, padding=0)
1671
+ return y[0, 0]
1672
+ else:
1673
+ C = img_t.shape[0]
1674
+ x = img_t[None]
1675
+ x = torch.nn.functional.pad(x, pad, mode="reflect")
1676
+ w = psf_t.repeat(C, 1, 1, 1)
1677
+ y = torch.nn.functional.conv2d(x, w, padding=0, groups=C)
1678
+ return y[0]
1679
+
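A quick shape check, assuming a local PyTorch install: reflect padding adds kh-1 / kw-1 pixels, which the valid convolution removes again, so the output keeps the input's spatial size.

import torch

img = torch.rand(3, 65, 80)            # (C, H, W)
psf = torch.rand(1, 1, 7, 7)
psf = psf / psf.sum()                  # normalized kernel
out = _conv_same_torch(img, psf)
print(out.shape)                       # torch.Size([3, 65, 80])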
1680
+ def _safe_inference_context():
1681
+ """
1682
+ Return a valid, working no-grad context:
1683
+ - prefer torch.inference_mode() if it exists *and* can be entered,
1684
+ - otherwise fall back to torch.no_grad(),
1685
+ - if torch is unavailable, return NO_GRAD.
1686
+ """
1687
+ if not (TORCH_OK and (torch is not None)):
1688
+ return NO_GRAD
1689
+
1690
+ cm = getattr(torch, "inference_mode", None)
1691
+ if cm is None:
1692
+ return torch.no_grad
1693
+
1694
+ # Probe inference_mode once; if it explodes on this build, fall back.
1695
+ try:
1696
+ with cm():
1697
+ pass
1698
+ return cm
1699
+ except Exception:
1700
+ return torch.no_grad
1701
+
1702
+ def _ensure_mask_list(masks, data):
1703
+ # 1s where valid, 0s where invalid (soft edges allowed)
1704
+ if masks is None:
1705
+ return [np.ones_like(a if a.ndim==2 else a[0], dtype=np.float32) for a in data]
1706
+ out = []
1707
+ for a, m in zip(data, masks):
1708
+ base = a if a.ndim==2 else a[0] # mask is 2D; shared across channels
1709
+ if m is None:
1710
+ out.append(np.ones_like(base, dtype=np.float32))
1711
+ else:
1712
+ mm = np.asarray(m, dtype=np.float32)
1713
+ if mm.ndim == 3: # tolerate (1,H,W) or (C,H,W)
1714
+ mm = mm[0]
1715
+ if mm.shape != base.shape:
1716
+ # center crop to match (common intersection already applied)
1717
+ Ht, Wt = base.shape
1718
+ mm = _center_crop(mm, Ht, Wt)
1719
+ # keep as float weights in [0,1] (do not threshold!)
1720
+ out.append(np.clip(mm.astype(np.float32, copy=False), 0.0, 1.0))
1721
+ return out
1722
+
1723
+ def _ensure_var_list(variances, data):
1724
+ # If None, we’ll estimate a robust scalar per frame on-the-fly.
1725
+ if variances is None:
1726
+ return [None]*len(data)
1727
+ out = []
1728
+ for a, v in zip(data, variances):
1729
+ if v is None:
1730
+ out.append(None)
1731
+ else:
1732
+ vv = np.asarray(v, dtype=np.float32)
1733
+ if vv.ndim == 3:
1734
+ vv = vv[0]
1735
+ base = a if a.ndim==2 else a[0]
1736
+ if vv.shape != base.shape:
1737
+ Ht, Wt = base.shape
1738
+ vv = _center_crop(vv, Ht, Wt)
1739
+ # clip tiny/negatives
1740
+ vv = np.clip(vv, 1e-8, None).astype(np.float32, copy=False)
1741
+ out.append(vv)
1742
+ return out
1743
+
1744
+ # ---- SR operators (downsample / upsample-sum) ----
1745
+ def _downsample_avg(img, r: int):
1746
+ """Average-pool over non-overlapping r×r blocks. Works for (H,W) or (C,H,W)."""
1747
+ if r <= 1:
1748
+ return img
1749
+ a = np.asarray(img, dtype=np.float32)
1750
+ if a.ndim == 2:
1751
+ H, W = a.shape
1752
+ Hs, Ws = (H // r) * r, (W // r) * r
1753
+ a = a[:Hs, :Ws].reshape(Hs//r, r, Ws//r, r).mean(axis=(1,3))
1754
+ return a
1755
+ else:
1756
+ C, H, W = a.shape
1757
+ Hs, Ws = (H // r) * r, (W // r) * r
1758
+ a = a[:, :Hs, :Ws].reshape(C, Hs//r, r, Ws//r, r).mean(axis=(2,4))
1759
+ return a
1760
+
1761
+ def _upsample_sum(img, r: int, target_hw: tuple[int,int] | None = None):
1762
+ """Adjoint of average-pooling: replicate-sum each pixel into an r×r block.
1763
+ For (H,W) or (C,H,W). If target_hw given, center-crop/pad to that size.
1764
+ """
1765
+ if r <= 1:
1766
+ return img
1767
+ a = np.asarray(img, dtype=np.float32)
1768
+ if a.ndim == 2:
1769
+ H, W = a.shape
1770
+ out = np.kron(a, np.ones((r, r), dtype=np.float32))
1771
+ else:
1772
+ C, H, W = a.shape
1773
+ out = np.stack([np.kron(a[c], np.ones((r, r), dtype=np.float32)) for c in range(C)], axis=0)
1774
+ if target_hw is not None:
1775
+ Ht, Wt = target_hw
1776
+ out = _center_crop(out, Ht, Wt)
1777
+ return out
1778
+
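These two operators form an adjoint-like pair: because _downsample_avg averages and _upsample_sum replicates, the inner products satisfy <x, U(y)> = r^2 * <D(x), y> for sizes divisible by r, which is consistent with the seed being divided by r*r before _upsample_sum later in this module. A small check of that identity:

import numpy as np

rng = np.random.default_rng(0)
r = 2
x = rng.random((8, 8)).astype(np.float32)   # fine grid
y = rng.random((4, 4)).astype(np.float32)   # coarse grid

lhs = float(np.sum(x * _upsample_sum(y, r)))
rhs = float(r * r * np.sum(_downsample_avg(x, r) * y))
print(np.isclose(lhs, rhs))   # True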
1779
+ def _gaussian2d(ksize: int, sigma: float) -> np.ndarray:
1780
+ r = (ksize - 1) // 2
1781
+ y, x = np.mgrid[-r:r+1, -r:r+1].astype(np.float32)
1782
+ g = np.exp(-(x*x + y*y)/(2.0*sigma*sigma)).astype(np.float32)
1783
+ g /= g.sum() + EPS
1784
+ return g
1785
+
1786
+ def _conv2_same_np(a: np.ndarray, k: np.ndarray) -> np.ndarray:
1787
+ # lightweight wrap for 2D conv on (H,W) or (C,H,W) with same-size output
1788
+ return _conv_same_np(a if a.ndim==3 else a[None], k)[0] if a.ndim==2 else _conv_same_np(a, k)
1789
+
1790
+ def _solve_super_psf_from_native(f_native: np.ndarray, r: int, sigma: float = 1.1,
1791
+ iters: int = 500, lr: float = 0.1) -> np.ndarray:
1792
+ """
1793
+ Solve: h* = argmin_h || f_native - (D(h) * g_sigma) ||_2^2,
1794
+ where h is (r*k)×(r*k) if f_native is k×k. Returns normalized h (sum=1).
1795
+ """
1796
+ f = np.asarray(f_native, dtype=np.float32)
1797
+ k = int(f.shape[0]); assert f.shape[0] == f.shape[1]
1798
+ kr = int(k * r)
1799
+
1800
+ # build Gaussian pre-blur at native scale (match paper §4.2)
1801
+ g = _gaussian2d(k, max(sigma, 1e-3)).astype(np.float32)
1802
+
1803
+ # init h by zero-insertion of f at stride r (nearest-neighbour style upsample), then normalize
1804
+ h0 = np.zeros((kr, kr), dtype=np.float32)
1805
+ h0[::r, ::r] = f
1806
+ h0 = _normalize_psf(h0)
1807
+
1808
+ if TORCH_OK:
1809
+ dev = _torch_device()
1810
+ t = torch.tensor(h0, device=dev, dtype=torch.float32, requires_grad=True)
1811
+ f_t = torch.tensor(f, device=dev, dtype=torch.float32)
1812
+ g_t = torch.tensor(g, device=dev, dtype=torch.float32)
1813
+ opt = torch.optim.Adam([t], lr=lr)
1814
+ for _ in range(max(10, iters)):
1815
+ opt.zero_grad(set_to_none=True)
1816
+ H, W = t.shape
1817
+ Hr, Wr = H//r, W//r
1818
+ th = t[:Hr*r, :Wr*r].reshape(Hr, r, Wr, r).mean(dim=(1,3))
1819
+ # conv native: (Dh) * g
1820
+ conv = torch.nn.functional.conv2d(th[None,None], g_t[None,None], padding=g_t.shape[-1]//2)[0,0]
1821
+ loss = torch.mean((conv - f_t)**2)
1822
+ loss.backward()
1823
+ opt.step()
1824
+ with torch.no_grad():
1825
+ t.clamp_(min=0.0)
1826
+ t /= (t.sum() + 1e-8)
1827
+ h = t.detach().cpu().numpy().astype(np.float32)
1828
+ else:
1829
+ # Tiny gradient-descent fallback on numpy
1830
+ h = h0.copy()
1831
+ eta = float(lr)
1832
+ for _ in range(max(50, iters)):
1833
+ Dh = _downsample_avg(h, r)
1834
+ conv = _conv2_same_np(Dh, g)
1835
+ resid = (conv - f)
1836
+ # backprop through conv and D: grad wrt Dh is the residual convolved with the flipped g; the adjoint of D is upsample-sum
1837
+ grad_Dh = _conv2_same_np(resid, np.flip(np.flip(g, 0), 1))
1838
+ grad_h = _upsample_sum(grad_Dh, r, target_hw=h.shape)
1839
+ h = np.clip(h - eta * grad_h, 0.0, None)
1840
+ s = float(h.sum()); h /= (s + 1e-8)
1841
+ eta *= 0.995
1842
+ return _normalize_psf(h)
1843
+
1844
+ def _downsample_avg_t(x, r: int):
1845
+ """
1846
+ Average-pool over non-overlapping r×r blocks.
1847
+ Works for (H,W) or (C,H,W). Crops to multiples of r.
1848
+ """
1849
+ if r <= 1:
1850
+ return x
1851
+ if x.ndim == 2:
1852
+ H, W = x.shape
1853
+ Hr, Wr = (H // r) * r, (W // r) * r
1854
+ if Hr == 0 or Wr == 0:
1855
+ return x # nothing to pool
1856
+ x2 = x[:Hr, :Wr]
1857
+ return x2.view(Hr // r, r, Wr // r, r).mean(dim=(1, 3))
1858
+ else:
1859
+ C, H, W = x.shape
1860
+ Hr, Wr = (H // r) * r, (W // r) * r
1861
+ if Hr == 0 or Wr == 0:
1862
+ return x
1863
+ x2 = x[:, :Hr, :Wr]
1864
+ return x2.view(C, Hr // r, r, Wr // r, r).mean(dim=(2, 4))
1865
+
1866
+ def _upsample_sum_t(x, r: int):
1867
+ if r <= 1:
1868
+ return x
1869
+ if x.ndim == 2:
1870
+ return x.repeat_interleave(r, dim=0).repeat_interleave(r, dim=1)
1871
+ else:
1872
+ return x.repeat_interleave(r, dim=-2).repeat_interleave(r, dim=-1)
1873
+
1874
+ def _sep_bg_rms(frames):
1875
+ """Return a robust background RMS using SEP's background model on the first frame."""
1876
+ if sep is None or not frames:
1877
+ return None
1878
+ try:
1879
+ y0 = frames[0] if frames[0].ndim == 2 else frames[0][0] # use luma/first channel
1880
+ a = np.ascontiguousarray(y0, dtype=np.float32)
1881
+ b = sep.Background(a, bw=64, bh=64, fw=3, fh=3)
1882
+ try:
1883
+ rms_val = float(b.globalrms)
1884
+ except Exception:
1885
+ # some SEP builds don’t expose globalrms; fall back to the map’s median
1886
+ rms_val = float(np.median(np.asarray(b.rms(), dtype=np.float32)))
1887
+ return rms_val
1888
+ except Exception:
1889
+ return None
1890
+
1891
+ # =========================
1892
+ # Memory/streaming helpers
1893
+ # =========================
1894
+
1895
+ def _approx_bytes(arr_like_shape, dtype=np.float32):
1896
+ """Rough byte estimator for a given shape/dtype."""
1897
+ return int(np.prod(arr_like_shape)) * np.dtype(dtype).itemsize
1898
+
1899
+
1900
+
1901
+ def _read_shape_fast(path) -> tuple[int,int,int]:
1902
+ if _is_xisf(path):
1903
+ a, _ = _load_image_array(path)
1904
+ if a is None:
1905
+ raise ValueError(f"No data in {path}")
1906
+ a = np.asarray(a)
1907
+ else:
1908
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
1909
+ a = hdul[0].data
1910
+ if a is None:
1911
+ raise ValueError(f"No data in {path}")
1912
+
1913
+ # common logic for both XISF and FITS
1914
+ if a.ndim == 2:
1915
+ H, W = a.shape
1916
+ return (1, int(H), int(W))
1917
+ if a.ndim == 3:
1918
+ if a.shape[-1] in (1, 3): # HWC
1919
+ C = int(a.shape[-1]); H = int(a.shape[0]); W = int(a.shape[1])
1920
+ return (1 if C == 1 else 3, H, W)
1921
+ if a.shape[0] in (1, 3): # CHW
1922
+ return (int(a.shape[0]), int(a.shape[1]), int(a.shape[2]))
1923
+ s = tuple(map(int, a.shape))
1924
+ H, W = s[-2], s[-1]
1925
+ return (1, H, W)
1926
+
1927
+ def _read_tile_fits_any(path: str, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
1928
+ """FITS/XISF-aware tile read: returns spatial tile; supports 2D, HWC, and CHW."""
1929
+ ext = os.path.splitext(path)[1].lower()
1930
+
1931
+ if ext == ".xisf":
1932
+ a, _ = _load_image_array(path) # helper returns array-like + hdr/metadata
1933
+ if a is None:
1934
+ raise ValueError(f"XISF loader returned None for {path}")
1935
+ a = np.asarray(a)
1936
+ if a.ndim == 2: # HW
1937
+ return np.array(a[y0:y1, x0:x1], copy=True)
1938
+ elif a.ndim == 3:
1939
+ if a.shape[-1] in (1, 3): # HWC
1940
+ out = a[y0:y1, x0:x1, :]
1941
+ if out.shape[-1] == 1:
1942
+ out = out[..., 0]
1943
+ return np.array(out, copy=True)
1944
+ elif a.shape[0] in (1, 3): # CHW
1945
+ out = a[:, y0:y1, x0:x1]
1946
+ if out.shape[0] == 1:
1947
+ out = out[0]
1948
+ return np.array(out, copy=True)
1949
+ else:
1950
+ raise ValueError(f"Unsupported XISF 3D shape {a.shape} in {path}")
1951
+ else:
1952
+ raise ValueError(f"Unsupported XISF ndim {a.ndim} in {path}")
1953
+
1954
+ # FITS
1955
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
1956
+ a = None
1957
+ for h in hdul:
1958
+ if getattr(h, "data", None) is not None:
1959
+ a = h.data
1960
+ break
1961
+ if a is None:
1962
+ raise ValueError(f"No image data in {path}")
1963
+
1964
+ a = np.asarray(a)
1965
+
1966
+ if a.ndim == 2: # HW
1967
+ return np.array(a[y0:y1, x0:x1], copy=True)
1968
+
1969
+ if a.ndim == 3:
1970
+ if a.shape[0] in (1, 3): # CHW (planes, rows, cols)
1971
+ out = a[:, y0:y1, x0:x1]
1972
+ if out.shape[0] == 1: out = out[0]
1973
+ return np.array(out, copy=True)
1974
+ if a.shape[-1] in (1, 3): # HWC
1975
+ out = a[y0:y1, x0:x1, :]
1976
+ if out.shape[-1] == 1: out = out[..., 0]
1977
+ return np.array(out, copy=True)
1978
+
1979
+ # Fallback: assume last two axes are spatial (…, H, W)
1980
+ try:
1981
+ out = a[(..., slice(y0, y1), slice(x0, x1))]
1982
+ return np.array(out, copy=True)
1983
+ except Exception:
1984
+ raise ValueError(f"Unsupported FITS data shape {a.shape} in {path}")
1985
+
1986
+ def _select_io_workers(n_frames: int,
1987
+ user_io_workers: int | None,
1988
+ tile_hw: tuple[int,int] = (256,256),
1989
+ disk_bound: bool = True) -> int:
1990
+ """Heuristic picker for I/O threadpool size."""
1991
+ try:
1992
+ cpu = os.cpu_count() or 4
1993
+ except Exception:
1994
+ cpu = 4
1995
+
1996
+ if user_io_workers is not None:
1997
+ # Respect caller, but clamp to sane bounds
1998
+ return max(1, min(cpu, int(user_io_workers)))
1999
+
2000
+ # Default: don’t oversubscribe CPU; don’t exceed frame count
2001
+ base = min(cpu, 8, int(n_frames))
2002
+
2003
+ # If we’re disk-bound (memmaps/tiling) or tiles are small, cap lower to reduce thrash
2004
+ th, tw = tile_hw
2005
+ if disk_bound or (th * tw <= 256 * 256):
2006
+ base = min(base, 4)
2007
+
2008
+ return max(1, base)
2009
+
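Illustrative outcomes of the heuristic, assuming a machine that reports 16 CPUs (the values are worked examples, not package output):

print(_select_io_workers(100, None))                        # 4  (disk-bound default cap)
print(_select_io_workers(100, None, tile_hw=(1024, 1024),
                         disk_bound=False))                 # 8  (min(cpu, 8, n_frames))
print(_select_io_workers(3, None))                          # 3  (never exceeds frame count)
print(_select_io_workers(100, 64))                          # 16 (caller value clamped to cpu count)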
2010
+
2011
+ def _seed_median_streaming(
2012
+ paths,
2013
+ Ht,
2014
+ Wt,
2015
+ *,
2016
+ color_mode="luma",
2017
+ tile_hw=(256, 256),
2018
+ status_cb=lambda s: None,
2019
+ progress_cb=lambda f, m="": None,
2020
+ io_workers: int = 4,
2021
+ scratch_dir: str | None = None,
2022
+ tile_loader=None, # NEW
2023
+ ):
2024
+ import time
2025
+ import gc
2026
+ from concurrent.futures import ThreadPoolExecutor, as_completed
2027
+ scratch_dir = _ensure_scratch_dir(scratch_dir)
2028
+ th, tw = int(tile_hw[0]), int(tile_hw[1])
2029
+
2030
+ # Prefer channel hint from tile_loader (so we don't reopen original files)
2031
+ if str(color_mode).lower() == "luma":
2032
+ want_c = 1
2033
+ else:
2034
+ want_c = getattr(tile_loader, "want_c", None)
2035
+ if not want_c:
2036
+ want_c = _infer_channels_from_tile(paths[0], Ht, Wt)
2037
+
2038
+ seed = (np.zeros((Ht, Wt), np.float32)
2039
+ if want_c == 1 else np.zeros((want_c, Ht, Wt), np.float32))
2040
+
2041
+ tiles = [(y, min(y + th, Ht), x, min(x + tw, Wt))
2042
+ for y in range(0, Ht, th)
2043
+ for x in range(0, Wt, tw)]
2044
+ total = len(tiles)
2045
+ n_frames = len(paths)
2046
+
2047
+ io_workers_eff = _select_io_workers(n_frames, io_workers, tile_hw=tile_hw, disk_bound=True)
2048
+
2049
+ def _tile_msg(ti, tn): return f"median tiles {ti}/{tn}"
2050
+ done = 0
2051
+
2052
+ def _read_slab_for_channel(y0, y1, x0, x1, csel=None):
2053
+ h, w = (y1 - y0), (x1 - x0)
2054
+ # use fp16 slabs to halve peak RAM; we cast to f32 only for the median op
2055
+ slab = np.empty((n_frames, h, w), np.float16)
2056
+
2057
+ def _load_one(i):
2058
+ if tile_loader is not None:
2059
+ t = tile_loader(i, y0, y1, x0, x1, csel=csel)
2060
+ else:
2061
+ t = _read_tile_fits_any(paths[i], y0, y1, x0, x1)
2062
+ # luma/channel selection only for file reads
2063
+ if want_c == 1:
2064
+ if t.ndim != 2:
2065
+ t = _to_luma_local(t)
2066
+ else:
2067
+ if t.ndim == 2:
2068
+ pass
2069
+ elif t.ndim == 3 and t.shape[-1] == 3:
2070
+ t = t[..., int(csel)]
2071
+ elif t.ndim == 3 and t.shape[0] == 3:
2072
+ t = t[int(csel)]
2073
+ else:
2074
+ t = _to_luma_local(t)
2075
+
2076
+ # normalize to [0,1] if integer; store as fp16
2077
+ if t.dtype.kind in "ui":
2078
+ t = (t.astype(np.float32) / (float(np.iinfo(t.dtype).max) or 1.0)).astype(np.float16)
2079
+ else:
2080
+ t = t.astype(np.float16, copy=False)
2081
+ return i, np.ascontiguousarray(t, dtype=np.float16)
2082
+
2083
+ done_local = 0
2084
+ with ThreadPoolExecutor(max_workers=min(io_workers_eff, n_frames)) as ex:
2085
+ futures = [ex.submit(_load_one, i) for i in range(n_frames)]
2086
+ for fut in as_completed(futures):
2087
+ i, t2d = fut.result()
2088
+ if t2d.shape != (h, w):
2089
+ raise RuntimeError(
2090
+ f"Tile read mismatch at frame {i}: got {t2d.shape}, expected {(h, w)} "
2091
+ f"tile={(y0,y1,x0,x1)}"
2092
+ )
2093
+ slab[i] = t2d
2094
+ done_local += 1
2095
+ if (done_local & 7) == 0 or done_local == n_frames:
2096
+ tile_base = done / total
2097
+ tile_span = 1.0 / total
2098
+ inner = done_local / n_frames
2099
+ progress_cb(tile_base + 0.8 * tile_span * inner, _tile_msg(done + 1, total))
2100
+ return slab
2101
+
2102
+ for (y0, y1, x0, x1) in tiles:
2103
+ h, w = (y1 - y0), (x1 - x0)
2104
+ t0 = time.perf_counter()
2105
+
2106
+ if want_c == 1:
2107
+ slab = _read_slab_for_channel(y0, y1, x0, x1)
2108
+ t1 = time.perf_counter()
2109
+ med_np = np.median(slab.astype(np.float32, copy=False), axis=0).astype(np.float32, copy=False)
2110
+ t2 = time.perf_counter()
2111
+ seed[y0:y1, x0:x1] = med_np
2112
+ status_cb(f"seed tile {y0}:{y1},{x0}:{x1} I/O={t1-t0:.3f}s median(CPU)={t2-t1:.3f}s")
2113
+ else:
2114
+ for c in range(int(want_c)):
2115
+ slab = _read_slab_for_channel(y0, y1, x0, x1, csel=c)
2116
+ med_np = np.median(slab.astype(np.float32, copy=False), axis=0).astype(np.float32, copy=False)
2117
+ seed[c, y0:y1, x0:x1] = med_np
2118
+
2119
+ # free tile buffers aggressively
2120
+ del slab
2121
+ try: del med_np
2122
+ except Exception as e:
2123
+ import logging
2124
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2125
+
2126
+ done += 1
2127
+ if (done & 1) == 0:
2128
+ gc.collect() # encourage RSS to drop
2129
+ progress_cb(done / total, _tile_msg(done, total))
2130
+ if (done & 3) == 0:
2131
+ _process_gui_events_safely()
2132
+
2133
+ status_cb(f"Median seed (CPU): want_c={want_c}, seed_shape={seed.shape}")
2134
+ return seed
2135
+
2136
+
2137
+
2138
+ def _seed_median_full_from_data(data_list):
2139
+ """
2140
+ data_list: list of np.ndarray each shaped either (H,W) or (C,H,W),
2141
+ already cropped/sanitized to the same size by the caller.
2142
+ Returns: (H,W) or (C,H,W) median image in float32.
2143
+ """
2144
+ if not data_list:
2145
+ raise ValueError("Empty stack for median seed")
2146
+
2147
+ a0 = data_list[0]
2148
+ if a0.ndim == 2:
2149
+ # (N, H, W) -> (H, W)
2150
+ cube = np.stack([np.asarray(a, dtype=np.float32, order="C") for a in data_list], axis=0)
2151
+ med = np.median(cube, axis=0).astype(np.float32, copy=False)
2152
+ return np.ascontiguousarray(med)
2153
+ else:
2154
+ # (N, C, H, W) -> (C, H, W)
2155
+ cube = np.stack([np.asarray(a, dtype=np.float32, order="C") for a in data_list], axis=0)
2156
+ med = np.median(cube, axis=0).astype(np.float32, copy=False)
2157
+ return np.ascontiguousarray(med)
2158
+
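A tiny usage example of the in-memory median seed over three mono frames (illustrative shapes only):

import numpy as np

frames = [np.full((4, 4), v, np.float32) for v in (0.1, 0.2, 0.9)]
seed = _seed_median_full_from_data(frames)
print(seed.shape, seed[0, 0])   # (4, 4) 0.2  (per-pixel median)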
2159
+ def _infer_channels_from_tile(path: str, Ht: int, Wt: int) -> int:
2160
+ """
2161
+ Decide output channel count for median seeding in PerChannel mode.
2162
+ Returns 3 if the source is RGB, else 1.
2163
+ """
2164
+ C, _, _ = _read_shape_fast(path) # (C,H,W) with C in {1,3}
2165
+ return 3 if C == 3 else 1
2166
+
2167
+
2168
+ def _build_seed_running_mu_sigma_from_paths(paths, Ht, Wt, color_mode,
2169
+ *, bootstrap_frames=20, clip_sigma=5.0,
2170
+ status_cb=lambda s: None, progress_cb=lambda f,m='': None):
2171
+ K = max(1, min(int(bootstrap_frames), len(paths)))
2172
+ def _load_chw(i):
2173
+ ys, _ = _stack_loader_memmap([paths[i]], Ht, Wt, color_mode)
2174
+ return _as_chw(ys[0]).astype(np.float32, copy=False)
2175
+ x0 = _load_chw(0).copy()
2176
+ mean = x0; m2 = np.zeros_like(mean); count = 1
2177
+ for i in range(1, K):
2178
+ x = _load_chw(i); count += 1
2179
+ d = x - mean; mean += d/count; m2 += d*(x-mean)
2180
+ progress_cb(i/K*0.5, "μ-σ bootstrap")
2181
+ var = m2 / max(1, count-1); sigma = np.sqrt(np.clip(var, 1e-12, None)).astype(np.float32)
2182
+ lo = mean - float(clip_sigma)*sigma; hi = mean + float(clip_sigma)*sigma
2183
+ acc = np.zeros_like(mean); n=0
2184
+ for i in range(len(paths)):
2185
+ x = _load_chw(i); x = np.clip(x, lo, hi, out=x)
2186
+ acc += x; n += 1; progress_cb(0.5 + 0.5*(i+1)/len(paths), "clipped mean")
2187
+ seed = (acc/max(1,n)).astype(np.float32)
2188
+ return seed[0] if (seed.ndim==3 and seed.shape[0]==1) else seed
2189
+
2190
+ def _flush_memmaps(*arrs):
2191
+ for a in arrs:
2192
+ try:
2193
+ mm = a[0] if (isinstance(a, tuple) and isinstance(a[0], np.memmap)) else a
2194
+ if isinstance(mm, np.memmap):
2195
+ mm.flush()
2196
+ except Exception:
2197
+ pass
2198
+
2199
+ def _mm_create(shape, dtype=np.float32, scratch_dir=None, tag="mm"):
2200
+ root = scratch_dir or tempfile.gettempdir()
2201
+ os.makedirs(root, exist_ok=True)
2202
+ fd, path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".dat", dir=root)
2203
+ os.close(fd)
2204
+ mm = np.memmap(path, mode="w+", dtype=dtype, shape=tuple(shape))
2205
+ return mm, path
2206
+
2207
+ def _mm_flush(*arrs):
2208
+ for a in arrs:
2209
+ try:
2210
+ if isinstance(a, np.memmap):
2211
+ a.flush()
2212
+ except Exception:
2213
+ pass
2214
+
2215
+ def _gauss_tile(a, sigma):
2216
+ """
2217
+ Small helper: Gaussian blur a tile (2D float32), reflect borders.
2218
+ Prefers OpenCV; falls back to SciPy; otherwise no-op.
2219
+ """
2220
+ if sigma <= 0:
2221
+ return a
2222
+ k = int(max(1, int(round(3*sigma))) * 2 + 1)
2223
+ try:
2224
+ import cv2
2225
+ return cv2.GaussianBlur(a, (k, k), float(sigma), borderType=cv2.BORDER_REFLECT)
2226
+ except Exception:
2227
+ try:
2228
+ from scipy.ndimage import gaussian_filter
2229
+ return gaussian_filter(a, sigma=float(sigma), mode="reflect")
2230
+ except Exception:
2231
+ return a
2232
+
2233
+ # Tile loader over the per-frame memmaps produced by _prepare_frame_stack_memmap
2234
+ def _make_memmap_tile_loader(frame_infos, max_open=32):
2235
+ """
2236
+ Returns tile_loader(i, y0,y1,x0,x1, csel=None) that slices from each frame's memmap.
2237
+ Keeps a tiny LRU cache of opened memmaps (handles only; not image-sized arrays).
2238
+ """
2239
+ opened = OrderedDict()
2240
+
2241
+ def _open_mm(i):
2242
+ fi = frame_infos[i]
2243
+ mm = np.memmap(fi["path"], mode="r", dtype=fi["dtype"], shape=fi["shape"])
2244
+ opened[i] = mm
2245
+ # evict least-recently used beyond max_open
2246
+ while len(opened) > int(max_open):
2247
+ _, old = opened.popitem(last=False)
2248
+ try: del old
2249
+ except Exception as e:
2250
+ import logging
2251
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2252
+ return mm
2253
+
2254
+ def tile_loader(i, y0, y1, x0, x1, csel=None):
2255
+ # reuse or open on demand
2256
+ mm = opened.get(i)
2257
+ if mm is None:
2258
+ mm = _open_mm(i)
2259
+ else:
2260
+ # bump LRU
2261
+ opened.move_to_end(i, last=True)
2262
+
2263
+ a = mm # (H,W) or (C,H,W)
2264
+ if a.ndim == 2:
2265
+ t = a[y0:y1, x0:x1]
2266
+ else:
2267
+ # (C,H,W); pick channel or default to first (luma-equivalent)
2268
+ cc = 0 if csel is None else int(csel)
2269
+ t = a[cc, y0:y1, x0:x1]
2270
+ # return a copy so median slab is independent/contiguous
2271
+ return np.array(t, copy=True)
2272
+
2273
+ # advertise channel count so the seeder doesn't reopen original files
2274
+ shp = frame_infos[0]["shape"]
2275
+ tile_loader.want_c = (shp[0] if (len(shp) == 3) else 1)
2276
+
2277
+ def _close():
2278
+ # drop handles
2279
+ while opened:
2280
+ _, mm = opened.popitem(last=False)
2281
+ try: del mm
2282
+ except Exception as e:
2283
+ import logging
2284
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2285
+ tile_loader.close = _close
2286
+
2287
+ return tile_loader
2288
+
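A minimal sketch of driving the tile loader directly; the single frame_infos entry is built by hand here purely for illustration (normally _prepare_frame_stack_memmap below produces it):

import os, tempfile
import numpy as np

fd, path = tempfile.mkstemp(suffix=".mm"); os.close(fd)
np.memmap(path, mode="w+", dtype=np.float32, shape=(64, 64))[:] = 1.0

frame_infos = [{"path": path, "shape": (64, 64), "dtype": np.float32}]
loader = _make_memmap_tile_loader(frame_infos, max_open=8)
tile = loader(0, 0, 32, 0, 32)          # independent (32, 32) copy of the top-left tile
print(tile.shape, loader.want_c)        # (32, 32) 1
loader.close()                          # drop the cached memmap handles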
2289
+
2290
+ def _prepare_frame_stack_memmap(
2291
+ paths: list[str],
2292
+ Ht: int,
2293
+ Wt: int,
2294
+ color_mode: str = "luma",
2295
+ *,
2296
+ scratch_dir: str | None = None,
2297
+ dtype: np.dtype | str = np.float32,
2298
+ tile_hw: tuple[int,int] = (512, 512),
2299
+ status_cb=lambda s: None,
2300
+ ):
2301
+ """
2302
+ Create one disk-backed memmap per input frame, already cropped to (Ht,Wt)
2303
+ and normalized to float32 (or requested dtype). Returns:
2304
+ frame_infos: list[dict(path, shape, dtype)]
2305
+ hdrs: list[fits.Header]
2306
+ Each memmap stores (H,W) or (C,H,W) in row-major order.
2307
+ """
2308
+ scratch_dir = _ensure_scratch_dir(scratch_dir)
2309
+
2310
+ # normalize dtype
2311
+ if isinstance(dtype, str):
2312
+ _d = dtype.lower().strip()
2313
+ out_dtype = np.float16 if _d in ("float16","fp16","half") else np.float32
2314
+ else:
2315
+ out_dtype = np.dtype(dtype)
2316
+
2317
+ th, tw = int(tile_hw[0]), int(tile_hw[1])
2318
+ infos, hdrs = [], []
2319
+
2320
+ status_cb(f"Preparing {len(paths)} frame memmaps → {scratch_dir}")
2321
+ for idx, p in enumerate(paths, start=1):
2322
+ try:
2323
+ hdr = _safe_primary_header(p)
2324
+ except Exception:
2325
+ hdr = fits.Header()
2326
+ hdrs.append(hdr)
2327
+
2328
+ mode = str(color_mode).lower().strip()
2329
+ if mode == "luma":
2330
+ shape = (Ht, Wt) # 2D
2331
+ C_out = 1
2332
+ else:
2333
+ # Per-channel path: keep CHW even if mono → (1,H,W) or (3,H,W)
2334
+ C0, _, _ = _read_shape_fast(p)
2335
+ C_out = 3 if C0 == 3 else 1
2336
+ shape = (C_out, Ht, Wt) # 3D
2337
+
2338
+ mm_path = _mm_unique_path(scratch_dir, tag=f"frame_{idx:04d}", ext=".mm")
2339
+ mm = np.memmap(mm_path, mode="w+", dtype=out_dtype, shape=shape)
2340
+ mm_is_3d = (mm.ndim == 3)
2341
+
2342
+ tiles = [(y, min(y + th, Ht), x, min(x + tw, Wt))
2343
+ for y in range(0, Ht, th) for x in range(0, Wt, tw)]
2344
+
2345
+ for (y0, y1, x0, x1) in tiles:
2346
+ # 1) read source tile
2347
+ t = _read_tile_fits_any(p, y0, y1, x0, x1)
2348
+
2349
+ # 2) normalize to float32 in [0,1] if integer input
2350
+ if t.dtype.kind in "ui":
2351
+ t = t.astype(np.float32) / (float(np.iinfo(t.dtype).max) or 1.0)
2352
+ else:
2353
+ t = t.astype(np.float32, copy=False)
2354
+
2355
+ # 3) layout to match memmap
2356
+ if not mm_is_3d:
2357
+ if t.ndim != 2:
+ t = _to_luma_local(t)
2361
+ t = _to_luma_local(t)
2362
+ if out_dtype != np.float32:
2363
+ t = t.astype(out_dtype, copy=False)
2364
+ mm[y0:y1, x0:x1] = t
2365
+ else:
2366
+ # target is 3D (C,H,W)
2367
+ if C_out == 3:
2368
+ # ensure CHW
2369
+ if t.ndim == 2:
2370
+ # replicate luma across 3 channels
2371
+ t = np.stack([t, t, t], axis=0) # CHW
2372
+ elif t.ndim == 3 and t.shape[-1] == 3: # HWC → CHW
2373
+ t = np.moveaxis(t, -1, 0)
2374
+ elif t.ndim == 3 and t.shape[0] == 3: # already CHW
2375
+ pass
2376
+ else:
2377
+ t = _to_luma_local(t)
2378
+ t = np.stack([t, t, t], axis=0)
2379
+ if out_dtype != np.float32:
2380
+ t = t.astype(out_dtype, copy=False)
2381
+ mm[:, y0:y1, x0:x1] = t
2382
+ else:
2383
+ # C_out == 1: store single channel at mm[0, ...]
2384
+ if t.ndim == 3:
2385
+ t = _to_luma_local(t)
2386
+ # t must be 2D here
2387
+ if out_dtype != np.float32:
2388
+ t = t.astype(out_dtype, copy=False)
2389
+ mm[0, y0:y1, x0:x1] = t
2390
+
2391
+ try: mm.flush()
2392
+ except Exception as e:
2393
+ import logging
2394
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2395
+ del mm
2396
+
2397
+ infos.append({"path": mm_path, "shape": tuple(shape), "dtype": out_dtype})
2398
+
2399
+ if (idx % 8) == 0 or idx == len(paths):
2400
+ status_cb(f"Frame memmaps: {idx}/{len(paths)} ready")
2401
+ gc.collect()
2402
+
2403
+ return infos, hdrs
2404
+
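An end-to-end smoke test of the staging path on two tiny synthetic FITS frames; it assumes the module-level helpers this function relies on (for example _read_tile_fits_any and _safe_primary_header) are importable alongside it:

import os, tempfile
import numpy as np
from astropy.io import fits

tmp = tempfile.mkdtemp()
paths = []
for i in range(2):
    p = os.path.join(tmp, f"frame_{i}.fits")
    fits.writeto(p, np.random.default_rng(i).random((32, 32)).astype(np.float32))
    paths.append(p)

infos, hdrs = _prepare_frame_stack_memmap(paths, 32, 32, "luma", tile_hw=(16, 16))
loader = _make_memmap_tile_loader(infos)
print(loader(1, 0, 16, 0, 16).shape)    # (16, 16)
loader.close()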
2405
+
2406
+ # -----------------------------
2407
+ # Core
2408
+ # -----------------------------
2409
+ def multiframe_deconv(
2410
+ paths,
2411
+ out_path,
2412
+ iters=20,
2413
+ kappa=2.0,
2414
+ color_mode="luma",
2415
+ seed_mode: str = "robust",
2416
+ huber_delta=0.0,
2417
+ masks=None,
2418
+ variances=None,
2419
+ rho="huber",
2420
+ status_cb=lambda s: None,
2421
+ min_iters: int = 3,
2422
+ use_star_masks: bool = False,
2423
+ use_variance_maps: bool = False,
2424
+ star_mask_cfg: dict | None = None,
2425
+ varmap_cfg: dict | None = None,
2426
+ save_intermediate: bool = False,
2427
+ save_every: int = 1,
2428
+ # >>> SR options
2429
+ super_res_factor: int = 1,
2430
+ sr_sigma: float = 1.1,
2431
+ sr_psf_opt_iters: int = 250,
2432
+ sr_psf_opt_lr: float = 0.1,
2433
+ star_mask_ref_path: str | None = None,
2434
+ scratch_to_disk: bool = True, # spill large scratch arrays to memmap
2435
+ scratch_dir: str | None = None, # where to put memmaps (default: tempdir)
2436
+ memmap_threshold_mb: int = 512, # always memmap if buffer > this
2437
+ force_cpu: bool = False, # disable torch entirely unless caller opts in
2438
+ cache_psf_ffts: str = "disk", # 'disk' | 'ram' | 'none'
2439
+ fft_reuse_across_iters: bool = True, # keep PSF FFTs across iters (same math)
2440
+ io_workers: int | None = None, # cap I/O threadpool (seed/tiles)
2441
+ blas_threads: int = 1, # limit BLAS threads to avoid oversub
2442
+ ):
2443
+ # sanitize and clamp
2444
+ max_iters = max(1, int(iters))
2445
+ min_iters = max(1, int(min_iters))
2446
+ if min_iters > max_iters:
2447
+ min_iters = max_iters
2448
+
2449
+ def _emit_pct(pct: float, msg: str | None = None):
2450
+ pct = float(max(0.0, min(1.0, pct)))
2451
+ status_cb(f"__PROGRESS__ {pct:.4f}" + (f" {msg}" if msg else ""))
2452
+
2453
+ status_cb(f"MFDeconv: loading {len(paths)} aligned frames…")
2454
+ _emit_pct(0.02, "loading")
2455
+
2456
+ # Use unified probe to pick a common crop without loading full images
2457
+ Ht, Wt = _common_hw_from_paths(paths)
2458
+ _emit_pct(0.05, "preparing")
2459
+
2460
+ # Stream actual pixels cropped to (Ht,Wt), float32 CHW/2D + headers
2461
+ frame_infos, hdrs = _prepare_frame_stack_memmap(
2462
+ paths, Ht, Wt, color_mode,
2463
+ scratch_dir=scratch_dir,
2464
+ dtype=np.float32, # or pull from a cfg
2465
+ tile_hw=(512, 512),
2466
+ status_cb=status_cb,
2467
+ )
2468
+
2469
+ tile_loader = _make_memmap_tile_loader(frame_infos, max_open=32)
2470
+
2471
+ def _open_frame_numpy(i: int) -> np.ndarray:
2472
+ fi = frame_infos[i]
2473
+ a = np.memmap(fi["path"], mode="r", dtype=fi["dtype"], shape=fi["shape"])
2474
+ # Solver expects float32 math; cast on read (no copy if already f32)
2475
+ return np.asarray(a, dtype=np.float32)
2476
+
2477
+ # For functions that only need luma/HW:
2478
+ def _open_frame_hw(i: int) -> np.ndarray:
2479
+ arr = _open_frame_numpy(i)
2480
+ if arr.ndim == 3:
2481
+ return arr[0] # use first/luma channel consistently
2482
+ return arr
2483
+
2484
+
2485
+ relax = 0.7
2486
+ use_torch = False
2487
+ global torch, TORCH_OK
2488
+
2489
+ # -------- try to import torch from per-user runtime venv --------
2491
+ torch = None
2492
+ TORCH_OK = False
2493
+ cuda_ok = mps_ok = dml_ok = False
2494
+ dml_device = None
2495
+ try:
2496
+ from setiastro.saspro.runtime_torch import import_torch
2497
+ torch = import_torch(prefer_cuda=True, status_cb=status_cb)
2498
+ TORCH_OK = True
2499
+
2500
+ try: cuda_ok = hasattr(torch, "cuda") and torch.cuda.is_available()
2501
+ except Exception: cuda_ok = False
2502
+ try: mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
2503
+ except Exception: mps_ok = False
2504
+ try:
2505
+ import torch_directml
2506
+ dml_device = torch_directml.device()
2507
+ _ = (torch.ones(1, device=dml_device) + 1).item()
2508
+ dml_ok = True
2509
+ except Exception:
2510
+ dml_ok = False
2511
+
2512
+ if cuda_ok:
2513
+ status_cb(f"PyTorch CUDA available: True | device={torch.cuda.get_device_name(0)}")
2514
+ elif mps_ok:
2515
+ status_cb("PyTorch MPS (Apple) available: True")
2516
+ elif dml_ok:
2517
+ status_cb("PyTorch DirectML (Windows) available: True")
2518
+ else:
2519
+ status_cb("PyTorch present, using CPU backend.")
2520
+
2521
+ status_cb(
2522
+ f"PyTorch {getattr(torch, '__version__', '?')} backend: "
2523
+ + ("CUDA" if cuda_ok else "MPS" if mps_ok else "DirectML" if dml_ok else "CPU")
2524
+ )
2525
+ except Exception as e:
2526
+ TORCH_OK = False
2527
+ status_cb(f"PyTorch not available → CPU path. ({e})")
2528
+
2529
+ # ----------------------------
2530
+ # Torch usage policy gate (STEP 5)
2531
+ # ----------------------------
2532
+ # 1) Hard off-switch from caller
2533
+ if force_cpu:
2534
+ TORCH_OK = False
2535
+ torch = None
2536
+ status_cb("Torch disabled by policy: force_cpu=True → using NumPy everywhere.")
2537
+
2538
+ # 2) (Optional) clamp BLAS threads globally to avoid oversubscription
2539
+ try:
2540
+ from threadpoolctl import threadpool_limits
2541
+ _blas_ctx = threadpool_limits(limits=int(max(1, blas_threads)))
2542
+ except Exception:
2543
+ _blas_ctx = contextlib.nullcontext()
2544
+
2545
+ use_torch = bool(TORCH_OK)
2546
+
2547
+ # Only configure Torch backends if policy allowed Torch
2548
+ if use_torch:
2549
+ # ----- Precision policy (strict FP32) -----
2550
+ try:
2551
+ torch.backends.cudnn.benchmark = True
2552
+ if hasattr(torch.backends, "cudnn"):
2553
+ torch.backends.cudnn.allow_tf32 = False
2554
+ if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
2555
+ torch.backends.cuda.matmul.allow_tf32 = False
2556
+ if hasattr(torch, "set_float32_matmul_precision"):
2557
+ torch.set_float32_matmul_precision("highest")
2558
+ except Exception:
2559
+ pass
2560
+
2561
+ try:
2562
+ c_tf32 = getattr(torch.backends.cudnn, "allow_tf32", None)
2563
+ m_tf32 = getattr(getattr(torch.backends.cuda, "matmul", object()), "allow_tf32", None)
2564
+ status_cb(
2565
+ f"Precision: cudnn.allow_tf32={c_tf32} | "
2566
+ f"matmul.allow_tf32={m_tf32} | "
2567
+ f"benchmark={getattr(torch.backends.cudnn, 'benchmark', None)}"
2568
+ )
2569
+ except Exception:
2570
+ pass
2571
+
2572
+
2573
+ _process_gui_events_safely()
2574
+
2575
+ # PSFs (auto-size per frame) + flipped copies
2576
+ psf_out_dir = None
2577
+ psfs, masks_auto, vars_auto, var_paths = _build_psf_and_assets(
2578
+ paths,
2579
+ make_masks=bool(use_star_masks),
2580
+ make_varmaps=bool(use_variance_maps),
2581
+ status_cb=status_cb,
2582
+ save_dir=None,
2583
+ star_mask_cfg=star_mask_cfg,
2584
+ varmap_cfg=varmap_cfg,
2585
+ star_mask_ref_path=star_mask_ref_path,
2586
+ Ht=Ht, Wt=Wt, color_mode=color_mode,
2587
+ )
2588
+
2589
+ # >>> SR: lift PSFs to super-res if requested
2590
+ r = int(max(1, super_res_factor))
2591
+ if r > 1:
2592
+ status_cb(f"MFDeconv: Super-resolution r={r} with σ={sr_sigma} — solving SR PSFs…")
2593
+ _process_gui_events_safely()
2594
+ sr_psfs = []
2595
+ for i, k_native in enumerate(psfs, start=1):
2596
+ h = _solve_super_psf_from_native(k_native, r=r, sigma=float(sr_sigma),
2597
+ iters=int(sr_psf_opt_iters), lr=float(sr_psf_opt_lr))
2598
+ sr_psfs.append(h)
2599
+ status_cb(f" SR-PSF{i}: native {k_native.shape[0]} → {h.shape[0]} (sum={h.sum():.6f})")
2600
+ psfs = sr_psfs
2601
+
2602
+ flip_psf = [_flip_kernel(k) for k in psfs]
2603
+ _emit_pct(0.20, "PSF Ready")
2604
+
2605
+
2606
+ # --- SR/native seed ---
2607
+ seed_mode_s = str(seed_mode).lower().strip()
2608
+ if seed_mode_s not in ("robust", "median"):
2609
+ seed_mode_s = "robust"
2610
+
2611
+ if seed_mode_s == "median":
2612
+ status_cb("MFDeconv: Building median seed (streaming, CPU)…")
2613
+ try:
2614
+ seed_native = _seed_median_streaming(
2615
+ paths,
2616
+ Ht,
2617
+ Wt,
2618
+ color_mode=color_mode,
2619
+ tile_hw=(256, 256),
2620
+ status_cb=status_cb,
2621
+ progress_cb=lambda f, m="": _emit_pct(0.10 + 0.10 * f, f"median seed: {m}"),
2622
+ io_workers=io_workers,
2623
+ scratch_dir=scratch_dir,
2624
+ tile_loader=tile_loader, # <<< use the memmap-backed tiles
2625
+ )
2626
+ finally:
2627
+ # drop any open memmap handles held by the loader
2628
+ try: tile_loader.close()
2629
+ except Exception as e:
2630
+ import logging
2631
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2632
+ import gc as _gc; _gc.collect()
2633
+ else:
2634
+ status_cb("MFDeconv: Building robust seed (live μ-σ stacking)…")
2635
+ seed_native = _build_seed_running_mu_sigma_from_paths(
2636
+ paths, Ht, Wt, color_mode,
2637
+ bootstrap_frames=20, clip_sigma=5.0,
2638
+ status_cb=status_cb, progress_cb=lambda f,m='': None
2639
+ )
2640
+ if r > 1:
2641
+ if seed_native.ndim == 2:
2642
+ x = _upsample_sum(seed_native / (r*r), r, target_hw=(Ht*r, Wt*r))
2643
+ else:
2644
+ C, Hn, Wn = seed_native.shape
2645
+ x = np.stack(
2646
+ [_upsample_sum(seed_native[c] / (r*r), r, target_hw=(Hn*r, Wn*r)) for c in range(C)],
2647
+ axis=0
2648
+ )
2649
+ else:
2650
+ x = seed_native
2651
+ # Ensure CHW shape under PerChannel, even for mono (C=1)
2652
+ if str(color_mode).lower() != "luma" and x.ndim == 2:
2653
+ x = x[None, ...] # (1,H,W) to match frame & Torch ops
2654
+
2655
+ # Robust H,W extraction
2656
+ Hs, Ws = (x.shape[-2], x.shape[-1]) if x.ndim >= 2 else (Ht, Wt)
2657
+
2658
+ # masks/vars aligned to native grid (2D each)
2659
+ auto_masks = masks_auto if use_star_masks else None
2660
+ auto_vars = vars_auto if use_variance_maps else None
2661
+ base_template = np.empty((Ht, Wt), dtype=np.float32)
2662
+ data_like = [base_template] * len(paths)
2663
+
2664
+ # build per-frame mask weights; caller-provided masks take precedence over the auto star masks
2665
+ mask_list = _ensure_mask_list(
2666
+ masks if masks is not None else masks_auto,
2667
+ data_like
2668
+ )
2669
+ # variance maps: open the per-frame memmaps read-only when enabled
2670
+ if use_variance_maps and var_paths is not None:
2671
+ def _open_var_mm(p):
2672
+ return None if p is None else np.memmap(p, mode="r", dtype=_VAR_DTYPE, shape=(Ht, Wt))
2673
+ var_list = [_open_var_mm(p) for p in var_paths]
2674
+ else:
2675
+ var_list = [None] * len(paths)
2676
+
2677
+ iter_dir = None
2678
+ hdr0_seed = None
2679
+ if save_intermediate:
2680
+ iter_dir = _iter_folder(out_path)
2681
+ status_cb(f"MFDeconv: Intermediate outputs → {iter_dir}")
2682
+ try:
2683
+ hdr0_seed = _safe_primary_header(paths[0])
2684
+ except Exception:
2685
+ hdr0_seed = fits.Header()
2686
+ _save_iter_image(x, hdr0_seed, iter_dir, "seed", color_mode)
2687
+
2688
+ status_cb("MFDeconv: Calculating Backgrounds and MADs…")
2689
+ _process_gui_events_safely()
2690
+ y0 = _open_frame_hw(0)
2691
+ bg_est = _sep_bg_rms([y0]) or (np.median(np.abs(y0 - np.median(y0))) * 1.4826)
2692
+ del y0
2693
+ status_cb(f"MFDeconv: color_mode={color_mode}, huber_delta={huber_delta} (bg RMS~{bg_est:.3g})")
2694
+ _process_gui_events_safely()
2695
+
2696
+ status_cb("Computing FFTs and Allocating Scratch…")
2697
+ _process_gui_events_safely()
2698
+
2699
+ # -------- precompute and allocate scratch --------
2700
+ pred_super = None # CPU-only temp; avoid UnboundLocal on Torch path
2701
+ tmp_out = None # CPU-only temp; avoid UnboundLocal on Torch path
2702
+ def _arr_only(x):
2703
+ """Accept either ndarray/memmap or (memmap, path) and return the array."""
2704
+ if isinstance(x, tuple) and len(x) == 2 and hasattr(x[0], "dtype"):
2705
+ return x[0]
2706
+ return x
2707
+ per_frame_logging = (r > 1)
2708
+ if use_torch:
2709
+ x_t = _to_t(_contig(x))
2710
+ num = torch.zeros_like(x_t)
2711
+ den = torch.zeros_like(x_t)
2712
+
2713
+ if r > 1:
2714
+ # >>> SR path now uses SPATIAL CONV (cuDNN) to avoid huge FFT buffers
2715
+ psf_t = [_to_t(_contig(k))[None, None] for k in psfs] # (1,1,kh,kw)
2716
+ psfT_t = [_to_t(_contig(kT))[None, None] for kT in flip_psf]
2717
+ else:
2718
+ # Native spatial (as before)
2719
+ psf_t = [_to_t(_contig(k))[None, None] for k in psfs]
2720
+ psfT_t = [_to_t(_contig(kT))[None, None] for kT in flip_psf]
2721
+
2722
+ else:
2723
+ # ---------- CPU path (NumPy) ----------
2724
+ x_t = x
2725
+
2726
+ # Determine working size
2727
+ if x_t.ndim == 2:
2728
+ Hs, Ws = x_t.shape
2729
+ else:
2730
+ _, Hs, Ws = x_t.shape
2731
+
2732
+ # Choose PSF FFT caching policy
2733
+ # Expect this to be threaded in from function args; fallback to "ram"
2734
+ cache_psf_ffts = locals().get("cache_psf_ffts", "ram")
2735
+ scratch_dir = locals().get("scratch_dir", None)
2736
+
2737
+ import numpy.fft as _fft
2738
+
2739
+ Kfs = KTfs = meta = None
2740
+
2741
+ if cache_psf_ffts == "ram":
2742
+ # Original behavior: keep FFTs in RAM
2743
+ Kfs, KTfs, meta = _precompute_np_psf_ffts(psfs, flip_psf, Hs, Ws)
2744
+
2745
+ elif cache_psf_ffts == "disk":
2746
+ # New behavior: keep FFTs in disk-backed memmaps to save RAM
2747
+ # Requires _mm_create() helper (Step 2). If not added yet, set cache_psf_ffts="ram".
2748
+ Kfs, KTfs, meta = [], [], []
2749
+ for idx, (k, kT) in enumerate(zip(psfs, flip_psf), start=1):
2750
+ kh, kw = k.shape
2751
+ fftH, fftW = _fftshape_same(Hs, Ws, kh, kw)
2752
+
2753
+ # Create complex64 memmaps for rfftn grids (H, W//2+1)
2754
+ Kf_mm, Kf_path = _mm_create((fftH, fftW//2 + 1), np.complex64, scratch_dir, tag=f"Kf_{idx}")
2755
+ KTf_mm, KTf_path = _mm_create((fftH, fftW//2 + 1), np.complex64, scratch_dir, tag=f"KTf_{idx}")
2756
+
2757
+ # Compute once into the memmaps (same math)
2758
+ Kf_mm[...] = _fft.rfftn(np.fft.ifftshift(k).astype(np.float32, copy=False), s=(fftH, fftW)).astype(np.complex64, copy=False)
2759
+ KTf_mm[...] = _fft.rfftn(np.fft.ifftshift(kT).astype(np.float32, copy=False), s=(fftH, fftW)).astype(np.complex64, copy=False)
2760
+ Kf_mm.flush(); KTf_mm.flush()
2761
+
2762
+ Kfs.append(Kf_mm)
2763
+ KTfs.append(KTf_mm)
2764
+ meta.append((kh, kw, fftH, fftW))
2765
+
2766
+ elif cache_psf_ffts == "none":
2767
+ # Don’t precompute; compute per-frame inside the iter loop (same math, less RAM, more CPU).
2768
+ Kfs = KTfs = meta = None
2769
+ else:
2770
+ # Fallback to RAM behavior
2771
+ cache_psf_ffts = "ram"
2772
+ Kfs, KTfs, meta = _precompute_np_psf_ffts(psfs, flip_psf, Hs, Ws)
2773
+
2774
+ # Allocate CPU scratch (keep as-is for Step 3)
2775
+ def _shape_of(a): return a.shape, a.dtype
2776
+
2777
+ # Always keep x_t in RAM for speed (it’s the only array updated iteratively)
2778
+ # But allow opting into memmap if strictly necessary:
2779
+ if scratch_to_disk and (_approx_bytes(x.shape, x.dtype) / 1e6 > memmap_threshold_mb):
2780
+ x_mm, x_path = _mm_create(x.shape, x.dtype, scratch_dir, tag="x")
2781
+ x_mm[...] = x # copy the seed into the memmap
2782
+ x_t = x_mm
2783
+ else:
2784
+ x_t = x
2785
+
2786
+ num = _arr_only(_maybe_memmap(x_t.shape, x_t.dtype,
2787
+ force_mm=scratch_to_disk,
2788
+ threshold_mb=memmap_threshold_mb,
2789
+ scratch_dir=scratch_dir, tag="num"))
2790
+ den = _arr_only(_maybe_memmap(x_t.shape, x_t.dtype,
2791
+ force_mm=scratch_to_disk,
2792
+ threshold_mb=memmap_threshold_mb,
2793
+ scratch_dir=scratch_dir, tag="den"))
2794
+
2795
+ pred_super = _arr_only(_maybe_memmap(x_t.shape, x_t.dtype,
2796
+ force_mm=scratch_to_disk,
2797
+ threshold_mb=memmap_threshold_mb,
2798
+ scratch_dir=scratch_dir, tag="pred"))
2799
+ tmp_out = _arr_only(_maybe_memmap(x_t.shape, x_t.dtype,
2800
+ force_mm=scratch_to_disk,
2801
+ threshold_mb=memmap_threshold_mb,
2802
+ scratch_dir=scratch_dir, tag="tmp"))
2803
+
2804
+
2805
+ _to_check = [('x_t', x_t), ('num', num), ('den', den)]
2806
+ if not use_torch:
2807
+ _to_check += [('pred_super', pred_super), ('tmp_out', tmp_out)]
2808
+ for _name, _arr in _to_check:
2809
+ assert hasattr(_arr, 'shape'), f"{_name} must be array-like with .shape, got {type(_arr)}"
2810
+ _flush_memmaps(num, den)
2811
+
2812
+ # CPU-only scratch; may not exist on Torch path
2813
+ if isinstance(pred_super, np.memmap):
2814
+ _flush_memmaps(pred_super)
2815
+ if isinstance(tmp_out, np.memmap):
2816
+ _flush_memmaps(tmp_out)
2817
+
2818
+
2819
+ status_cb("Starting First Multiplicative Iteration…")
2820
+ _process_gui_events_safely()
2821
+
2822
+ cm = _safe_inference_context() if use_torch else NO_GRAD
2823
+ rho_is_l2 = (str(rho).lower() == "l2")
2824
+ local_delta = 0.0 if rho_is_l2 else huber_delta
2825
+
2826
+
2827
+ auto_delta_cache = None
2828
+ if use_torch and (huber_delta < 0) and (not rho_is_l2):
2829
+ auto_delta_cache = [None] * len(paths)
2830
+ # ---- unified EarlyStopper ----
2831
+ early = EarlyStopper(
2832
+ tol_upd_floor=1e-3,
2833
+ tol_rel_floor=5e-4,
2834
+ early_frac=0.40,
2835
+ ema_alpha=0.5,
2836
+ patience=2,
2837
+ min_iters=min_iters,
2838
+ status_cb=status_cb
2839
+ )
2840
+
2841
+ used_iters = 0
2842
+ early_stopped = False
2843
+
2844
+ with cm():
2845
+ for it in range(1, max_iters + 1):
2846
+ if use_torch:
2847
+ num.zero_(); den.zero_()
2848
+
2849
+ if r > 1:
2850
+ # -------- SR path (SPATIAL conv + stream) --------
2851
+ for fidx, (wk, wkT) in enumerate(zip(psf_t, psfT_t)):
2852
+ yt_np = _open_frame_numpy(fidx) # CHW or HW (CPU)
2853
+ mt_np = mask_list[fidx]
2854
+ vt_np = var_list[fidx]
2855
+
2856
+ yt = torch.as_tensor(yt_np, dtype=x_t.dtype, device=x_t.device)
2857
+ mt = None if mt_np is None else torch.as_tensor(mt_np, dtype=x_t.dtype, device=x_t.device)
2858
+ vt = None if vt_np is None else torch.as_tensor(vt_np, dtype=x_t.dtype, device=x_t.device)
2859
+
2860
+ # SR conv on grid of x_t
2861
+ pred_sr = _conv_same_torch(x_t, wk) # SR grid
2862
+ pred_low = _downsample_avg_t(pred_sr, r) # native grid
2863
+
2864
+ if auto_delta_cache is not None:
2865
+ if (auto_delta_cache[fidx] is None) or (it % 5 == 1):
2866
+ rnat = yt - pred_low
2867
+ med = torch.median(rnat)
2868
+ mad = torch.median(torch.abs(rnat - med)) + 1e-6
2869
+ rms = 1.4826 * mad
2870
+ auto_delta_cache[fidx] = float((-huber_delta) * torch.clamp(rms, min=1e-6).item())
2871
+ wmap_low = _weight_map(yt, pred_low, auto_delta_cache[fidx], var_map=vt, mask=mt)
2872
+ else:
2873
+ wmap_low = _weight_map(yt, pred_low, local_delta, var_map=vt, mask=mt)
2874
+
2875
+ # lift back to SR via sum-replicate
2876
+ up_y = _upsample_sum_t(wmap_low * yt, r)
2877
+ up_pred = _upsample_sum_t(wmap_low * pred_low, r)
2878
+
2879
+ # accumulate via adjoint kernel on SR grid
2880
+ num += _conv_same_torch(up_y, wkT)
2881
+ den += _conv_same_torch(up_pred, wkT)
2882
+
2883
+ # free temps as aggressively as possible
2884
+ del yt, mt, vt, pred_sr, pred_low, wmap_low, up_y, up_pred
2885
+ if cuda_ok:
2886
+ try: torch.cuda.empty_cache()
2887
+ except Exception as e:
2888
+ import logging
2889
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2890
+
2891
+ if per_frame_logging and ((fidx & 7) == 0):
2892
+ status_cb(f"Iter {it}/{max_iters} — frame {fidx+1}/{len(paths)} (SR spatial)")
2893
+
2894
+ del yt_np
2895
+
2896
+                 else:
+                     # -------- Native path (spatial conv, stream) --------
+                     for fidx, (wk, wkT) in enumerate(zip(psf_t, psfT_t)):
+                         yt_np = _open_frame_numpy(fidx)  # CHW or HW (CPU → to Torch tensor)
+                         mt_np = mask_list[fidx]
+                         vt_np = var_list[fidx]
+
+                         yt = torch.as_tensor(yt_np, dtype=x_t.dtype, device=x_t.device)
+                         mt = None if mt_np is None else torch.as_tensor(mt_np, dtype=x_t.dtype, device=x_t.device)
+                         vt = None if vt_np is None else torch.as_tensor(vt_np, dtype=x_t.dtype, device=x_t.device)
+
+                         pred = _conv_same_torch(x_t, wk)
+                         wmap = _weight_map(yt, pred, local_delta, var_map=vt, mask=mt)
+                         up_y = wmap * yt
+                         up_pred = wmap * pred
+                         num += _conv_same_torch(up_y, wkT)
+                         den += _conv_same_torch(up_pred, wkT)
+
+                         del yt, mt, vt, pred, wmap, up_y, up_pred
+                         if cuda_ok:
+                             try: torch.cuda.empty_cache()
+                             except Exception as e:
+                                 import logging
+                                 logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
+
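Editor's note: both Torch branches weight each pixel through `_weight_map`, which is defined elsewhere in the module. The sketch below shows the general shape of an IRLS-style Huber weight map with optional inverse-variance and mask factors, plus the MAD-based auto delta used in the SR branch; the exact behaviour of the real helper (and its mask/variance conventions) is an assumption.

# Sketch of an IRLS-style Huber weight map in the spirit of _weight_map above.
import numpy as np

def weight_map_sketch(y, pred, delta, var_map=None, mask=None):
    """Per-pixel weights: 1 inside the Huber zone, delta/|r| outside (delta <= 0 means pure L2)."""
    r = y - pred
    if delta and delta > 0:
        w = np.minimum(1.0, delta / (np.abs(r) + 1e-12))  # Huber: down-weight large residuals
    else:
        w = np.ones_like(r)                               # L2: uniform weights
    if var_map is not None:
        w = w / (var_map + 1e-12)                         # inverse-variance weighting
    if mask is not None:
        w = w * mask                                      # e.g. a star mask zeroes saturated cores
    return w

def robust_delta_from_residuals(r, k=2.0):
    """Auto delta = k × robust RMS, estimated via the median absolute deviation (MAD)."""
    mad = np.median(np.abs(r - np.median(r))) + 1e-6
    return k * 1.4826 * mad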
+                 ratio = num / (den + EPS)
+                 neutral = (den.abs() < 1e-12) & (num.abs() < 1e-12)
+                 ratio = torch.where(neutral, torch.ones_like(ratio), ratio)
+                 upd = torch.clamp(ratio, 1.0 / kappa, kappa)
+                 x_next = torch.clamp(x_t * upd, min=0.0)
+
+                 upd_med = torch.median(torch.abs(upd - 1))
+                 rel_change = (torch.median(torch.abs(x_next - x_t)) /
+                               (torch.median(torch.abs(x_t)) + 1e-8))
+
+                 um = float(upd_med.detach().cpu().item())
+                 rc = float(rel_change.detach().cpu().item())
+
+                 if early.step(it, max_iters, um, rc):
+                     x_t = x_next
+                     used_iters = it
+                     early_stopped = True
+                     status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
+                     _process_gui_events_safely()
+                     break
+
+                 x_t = (1.0 - relax) * x_t + relax * x_next
+
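Editor's note: the per-iteration step above is a clamped multiplicative (Richardson–Lucy-style) update followed by a relaxation blend. The toy NumPy rendering below only makes the arithmetic concrete; kappa and relax values are examples, not the defaults used by this function.

# Toy rendering of the clamped multiplicative update with relaxation (example values).
import numpy as np

EPS = 1e-12
kappa, relax = 2.0, 1.0

x_t = np.array([0.5, 1.0, 2.0])
num = np.array([0.6, 1.0, 6.0])
den = np.array([0.5, 1.0, 1.0])

ratio = num / (den + EPS)                    # raw multiplicative ratio
upd = np.clip(ratio, 1.0 / kappa, kappa)     # clamp to [1/kappa, kappa] → [1.2, 1.0, 2.0]
x_next = np.clip(x_t * upd, 0.0, None)       # enforce non-negativity → [0.6, 1.0, 4.0]
x_t = (1.0 - relax) * x_t + relax * x_next   # relax = 1.0 simply takes the full step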
+             else:
+                 # -------- NumPy path (fixed, no 'data') --------
+                 num.fill(0.0); den.fill(0.0)
+
+                 if r > 1:
+                     # -------- Super-resolution (NumPy) --------
+                     if cache_psf_ffts == "none":
+                         # No precomputed PSF FFTs → compute per-frame per-iter
+                         for fidx, (m2d, v2d) in enumerate(zip(mask_list, var_list)):
+                             # Load native frame on demand (CHW or HW)
+                             y_nat = _open_frame_numpy(fidx)
+
+                             # PSF for this frame
+                             k, kT = psfs[fidx], flip_psf[fidx]
+                             kh, kw = k.shape
+                             fftH, fftW = _fftshape_same(Hs, Ws, kh, kw)
+
+                             # Per-frame FFTs (same math as precomputed branch)
+                             Kf = _fft.rfftn(np.fft.ifftshift(k).astype(np.float32, copy=False), s=(fftH, fftW))
+                             KTf = _fft.rfftn(np.fft.ifftshift(kT).astype(np.float32, copy=False), s=(fftH, fftW))
+
+                             # Convolve current estimate x_t → SR prediction, then downsample
+                             _fft_conv_same_np(x_t, Kf, kh, kw, fftH, fftW, pred_super)
+                             pred_low = _downsample_avg(pred_super, r)
+
+                             # Weight map in native grid
+                             wmap_low = _weight_map(y_nat, pred_low, local_delta, var_map=v2d, mask=m2d)
+
+                             # Lift back to SR via sum-replicate
+                             up_y = _upsample_sum(wmap_low * y_nat, r, target_hw=pred_super.shape[-2:])
+                             up_pred = _upsample_sum(wmap_low * pred_low, r, target_hw=pred_super.shape[-2:])
+
+                             # Accumulate adjoint contributions
+                             _fft_conv_same_np(up_y, KTf, kh, kw, fftH, fftW, tmp_out); num += tmp_out
+                             _fft_conv_same_np(up_pred, KTf, kh, kw, fftH, fftW, tmp_out); den += tmp_out
+
+                             del y_nat, up_y, up_pred, wmap_low, pred_low, Kf, KTf
+
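Editor's note: the NumPy branches perform their convolutions via `_fft_conv_same_np`, which writes into a preallocated buffer and takes an ifftshifted kernel spectrum. The self-contained sketch below conveys the same idea with a standard pad / multiply / crop "same"-size FFT convolution; the helper's actual buffer handling and shift convention are assumptions, and SciPy is used only for the sanity check.

# Standalone sketch of a "same"-size FFT convolution.
import numpy as np
from scipy.signal import convolve2d

def fft_conv_same_sketch(img: np.ndarray, psf: np.ndarray) -> np.ndarray:
    """Convolve img (H×W) with psf (kh×kw), returning an H×W 'same' result."""
    H, W = img.shape
    kh, kw = psf.shape
    fftH, fftW = H + kh - 1, W + kw - 1               # linear-convolution size
    Fi = np.fft.rfft2(img, s=(fftH, fftW))
    Fk = np.fft.rfft2(psf, s=(fftH, fftW))
    full = np.fft.irfft2(Fi * Fk, s=(fftH, fftW))     # full convolution
    top, left = (kh - 1) // 2, (kw - 1) // 2          # crop back to the centered 'same' window
    return full[top:top + H, left:left + W].astype(np.float32)

# Sanity check against direct convolution
img = np.random.rand(32, 32).astype(np.float32)
psf = np.random.rand(5, 5).astype(np.float32)
assert np.allclose(fft_conv_same_sketch(img, psf),
                   convolve2d(img, psf, mode="same"), atol=1e-4)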
+                     else:
+                         # Precomputed PSF FFTs (RAM or disk memmap)
+                         for (Kf, KTf, (kh, kw, fftH, fftW)), m2d, pvar, fidx in zip(
+                             zip(Kfs, KTfs, meta),
+                             mask_list,
+                             (var_paths or [None] * len(frame_infos)),
+                             range(len(frame_infos)),
+                         ):
+                             y_nat = _open_frame_numpy(fidx)  # CHW or HW
+
+                             vt_np = None
+                             if use_variance_maps and pvar is not None:
+                                 vt_np = np.memmap(pvar, mode="r", dtype=_VAR_DTYPE, shape=(Ht, Wt))
+
+                             _fft_conv_same_np(x_t, Kf, kh, kw, fftH, fftW, pred_super)
+                             pred = pred_super
+
+                             wmap = _weight_map(y_nat, pred, local_delta, var_map=vt_np, mask=m2d)
+                             up_y, up_pred = (wmap * y_nat), (wmap * pred)
+
+                             _fft_conv_same_np(up_y, KTf, kh, kw, fftH, fftW, tmp_out); num += tmp_out
+                             _fft_conv_same_np(up_pred, KTf, kh, kw, fftH, fftW, tmp_out); den += tmp_out
+
+                             if vt_np is not None:
+                                 try:
+                                     del vt_np
+                                 except Exception as e:
+                                     import logging
+                                     logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
+                             del y_nat, up_y, up_pred, wmap, pred
+
+                 else:
+                     # -------- Native (NumPy) --------
+                     if cache_psf_ffts == "none":
+                         # No precomputed PSF FFTs → compute per-frame per-iter
+                         for fidx, (m2d, v2d) in enumerate(zip(mask_list, var_list)):
+                             y_nat = _open_frame_numpy(fidx)
+
+                             k, kT = psfs[fidx], flip_psf[fidx]
+                             kh, kw = k.shape
+                             fftH, fftW = _fftshape_same(Hs, Ws, kh, kw)
+
+                             Kf = _fft.rfftn(np.fft.ifftshift(k).astype(np.float32, copy=False), s=(fftH, fftW))
+                             KTf = _fft.rfftn(np.fft.ifftshift(kT).astype(np.float32, copy=False), s=(fftH, fftW))
+
+                             _fft_conv_same_np(x_t, Kf, kh, kw, fftH, fftW, pred_super)
+                             pred = pred_super
+
+                             wmap = _weight_map(y_nat, pred, local_delta, var_map=v2d, mask=m2d)
+                             up_y, up_pred = (wmap * y_nat), (wmap * pred)
+
+                             _fft_conv_same_np(up_y, KTf, kh, kw, fftH, fftW, tmp_out); num += tmp_out
+                             _fft_conv_same_np(up_pred, KTf, kh, kw, fftH, fftW, tmp_out); den += tmp_out
+
+                             del y_nat, up_y, up_pred, wmap, pred, Kf, KTf
+
+                     else:
+                         # Precomputed PSF FFTs (RAM or disk memmap)
+                         for (Kf, KTf, (kh, kw, fftH, fftW)), m2d, pvar, fidx in zip(
+                             zip(Kfs, KTfs, meta),
+                             mask_list,
+                             (var_paths or [None] * len(frame_infos)),
+                             range(len(frame_infos)),
+                         ):
+                             y_nat = _open_frame_numpy(fidx)
+
+                             vt_np = None
+                             if use_variance_maps and pvar is not None:
+                                 vt_np = np.memmap(pvar, mode="r", dtype=_VAR_DTYPE, shape=(Ht, Wt))
+
+                             _fft_conv_same_np(x_t, Kf, kh, kw, fftH, fftW, pred_super)
+                             pred = pred_super
+
+                             wmap = _weight_map(y_nat, pred, local_delta, var_map=vt_np, mask=m2d)
+                             up_y, up_pred = (wmap * y_nat), (wmap * pred)
+
+                             _fft_conv_same_np(up_y, KTf, kh, kw, fftH, fftW, tmp_out); num += tmp_out
+                             _fft_conv_same_np(up_pred, KTf, kh, kw, fftH, fftW, tmp_out); den += tmp_out
+
+                             if vt_np is not None:
+                                 try:
+                                     del vt_np
+                                 except Exception as e:
+                                     import logging
+                                     logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
+                             del y_nat, up_y, up_pred, wmap, pred
+
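Editor's note: the precomputed branches open per-frame variance maps as read-only `np.memmap` files so that only one map is resident at a time. The sketch below shows one way such maps could be written and reopened; the dtype, shape handling, and the shot-noise + read-noise model are illustrative assumptions, not this package's variance-map generator.

# Sketch of writing a per-frame variance map to disk and reading it back as a memmap.
import numpy as np

def write_variance_map(path, frame, gain=1.0, read_noise=5.0):
    """Simple noise model for illustration: var ≈ frame/gain + read_noise²."""
    var = np.maximum(frame / max(gain, 1e-6), 0.0) + read_noise ** 2
    mm = np.memmap(path, mode="w+", dtype=np.float32, shape=frame.shape)
    mm[:] = var.astype(np.float32)
    mm.flush()
    return frame.shape

def open_variance_map(path, shape):
    return np.memmap(path, mode="r", dtype=np.float32, shape=shape)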
+                 # --- multiplicative update (NumPy) ---
+                 ratio = num / (den + EPS)
+                 neutral = (np.abs(den) < 1e-12) & (np.abs(num) < 1e-12)
+                 ratio[neutral] = 1.0
+
+                 upd = np.clip(ratio, 1.0 / kappa, kappa)
+                 x_next = np.clip(x_t * upd, 0.0, None)
+
+                 upd_med = np.median(np.abs(upd - 1.0))
+                 rel_change = (
+                     np.median(np.abs(x_next - x_t)) /
+                     (np.median(np.abs(x_t)) + 1e-8)
+                 )
+
+                 um = float(upd_med)
+                 rc = float(rel_change)
+
+                 if early.step(it, max_iters, um, rc):
+                     x_t = x_next
+                     used_iters = it
+                     early_stopped = True
+                     status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
+                     _process_gui_events_safely()
+                     break
+
+                 x_t = (1.0 - relax) * x_t + relax * x_next
+
+             # save intermediate
+             if save_intermediate and (it % int(max(1, save_every)) == 0):
+                 try:
+                     x_np = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
+                     _save_iter_image(x_np, hdr0_seed, iter_dir, f"iter_{it:03d}", color_mode)
+                 except Exception as _e:
+                     status_cb(f"Intermediate save failed at iter {it}: {_e}")
+
+             frac = 0.25 + 0.70 * (it / float(max_iters))
+             _emit_pct(frac, f"Iteration {it}/{max_iters}")
+             status_cb(f"Iter {it}/{max_iters}")
+             _process_gui_events_safely()
+             _flush_memmaps(num, den)
+
+             # If present in your CPU path / SR path:
+             if isinstance(pred_super, np.memmap):
+                 _flush_memmaps(pred_super)
+             if isinstance(tmp_out, np.memmap):
+                 _flush_memmaps(tmp_out)
+
+     if not early_stopped:
+         used_iters = max_iters
+
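Editor's note: `_save_iter_image` (used for the periodic snapshots above) is defined elsewhere in the module. A minimal sketch of the kind of snapshot writer it suggests is shown below, using astropy's public FITS API; the function name, header keyword, file naming, and lack of any preview/stretch handling are assumptions.

# Sketch of an iteration-snapshot writer: stamp the iteration tag and write float32 FITS.
import os
import numpy as np
from astropy.io import fits

def save_iter_snapshot_sketch(x_np: np.ndarray, base_header, out_dir: str, tag: str):
    os.makedirs(out_dir, exist_ok=True)
    hdr = base_header.copy() if base_header is not None else fits.Header()
    hdr['MF_ITTAG'] = (tag, 'Intermediate MFDeconv snapshot')
    path = os.path.join(out_dir, f"{tag}.fit")
    fits.PrimaryHDU(data=x_np.astype(np.float32), header=hdr).writeto(path, overwrite=True)
    return path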
+     # ----------------------------
+     # Save result (keep FITS-friendly order: (C,H,W))
+     # ----------------------------
+     _emit_pct(0.97, "saving")
+     x_final = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
+
+     if x_final.ndim == 3:
+         if x_final.shape[0] not in (1, 3) and x_final.shape[-1] in (1, 3):
+             x_final = np.moveaxis(x_final, -1, 0)
+         if x_final.shape[0] == 1:
+             x_final = x_final[0]
+
+     try:
+         hdr0 = _safe_primary_header(paths[0])
+     except Exception:
+         hdr0 = fits.Header()
+     hdr0['MFDECONV'] = (True, 'Seti Astro multi-frame deconvolution')
+     hdr0['MF_COLOR'] = (str(color_mode), 'Color mode used')
+     hdr0['MF_RHO'] = (str(rho), 'Loss: huber|l2')
+     hdr0['MF_HDEL'] = (float(huber_delta), 'Huber delta (>0 abs, <0 autoxRMS)')
+     hdr0['MF_MASK'] = (bool(use_star_masks), 'Used auto star masks')
+     hdr0['MF_VAR'] = (bool(use_variance_maps), 'Used auto variance maps')
+
+     hdr0['MF_SR'] = (int(r), 'Super-resolution factor (1 := native)')
+     if r > 1:
+         hdr0['MF_SRSIG'] = (float(sr_sigma), 'Gaussian sigma for SR PSF fit (pixels at native)')
+         hdr0['MF_SRIT'] = (int(sr_psf_opt_iters), 'SR-PSF solver iters')
+
+     hdr0['MF_ITMAX'] = (int(max_iters), 'Requested max iterations')
+     hdr0['MF_ITERS'] = (int(used_iters), 'Actual iterations run')
+     hdr0['MF_ESTOP'] = (bool(early_stopped), 'Early stop triggered')
+
+     if isinstance(x_final, np.ndarray):
+         if x_final.ndim == 2:
+             hdr0['MF_SHAPE'] = (f"{x_final.shape[0]}x{x_final.shape[1]}", 'Saved as 2D image (HxW)')
+         elif x_final.ndim == 3:
+             C, H, W = x_final.shape
+             hdr0['MF_SHAPE'] = (f"{C}x{H}x{W}", 'Saved as 3D cube (CxHxW)')
+     _flush_memmaps(x_t, num, den)
+     if isinstance(pred_super, np.memmap):
+         _flush_memmaps(pred_super)
+     if isinstance(tmp_out, np.memmap):
+         _flush_memmaps(tmp_out)
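Editor's note: a tiny illustration of the two ideas in the block above, using astropy's public header API: normalizing channel order to a FITS-friendly C×H×W cube, and recording provenance as (value, comment) header cards. The array contents are made up; only the MF_* keyword style mirrors the cards written here.

# Channel-order normalization and provenance cards (illustrative values).
import numpy as np
from astropy.io import fits

rgb_hwc = np.random.rand(64, 64, 3).astype(np.float32)   # H×W×C, as many tools produce
rgb_chw = np.moveaxis(rgb_hwc, -1, 0)                     # → C×H×W, FITS-friendly axis order

hdr = fits.Header()
hdr['MFDECONV'] = (True, 'Seti Astro multi-frame deconvolution')
hdr['MF_SHAPE'] = ("3x64x64", 'Saved as 3D cube (CxHxW)')
print(repr(hdr['MF_SHAPE']), hdr.comments['MF_SHAPE'])    # value and its comment card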
+     save_path = _sr_out_path(out_path, super_res_factor)
+     safe_out_path = _nonclobber_path(str(save_path))
+     if safe_out_path != str(save_path):
+         status_cb(f"Output exists — saving as: {safe_out_path}")
+     fits.PrimaryHDU(data=x_final, header=hdr0).writeto(safe_out_path, overwrite=False)
+
+     status_cb(f"✅ MFDeconv saved: {safe_out_path} (iters used: {used_iters}{', early stop' if early_stopped else ''})")
+     _emit_pct(1.00, "done")
+     _process_gui_events_safely()
+
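Editor's note: `_nonclobber_path` is defined elsewhere in the package; paired with `writeto(..., overwrite=False)` it guarantees an existing result is never overwritten. A plausible sketch of the usual pattern (append a numeric suffix until the name is free) follows; the suffix style is an assumption.

# Plausible sketch of a non-clobbering path helper.
import os

def nonclobber_path_sketch(path: str) -> str:
    """Return path unchanged if free, else append _1, _2, … before the extension."""
    if not os.path.exists(path):
        return path
    root, ext = os.path.splitext(path)
    n = 1
    while os.path.exists(f"{root}_{n}{ext}"):
        n += 1
    return f"{root}_{n}{ext}"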
+     try:
+         if use_torch:
+             try: del num, den
+             except Exception as e:
+                 import logging
+                 logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
+             try: del psf_t, psfT_t
+             except Exception as e:
+                 import logging
+                 logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
+             _free_torch_memory()
+     except Exception:
+         pass
+
+     return safe_out_path
+
+
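Editor's note: `_free_torch_memory` lives elsewhere in the package. The sketch below shows the usual CUDA cleanup such a helper performs (drop references, collect garbage, release cached blocks); the real helper may do more, so treat this as an assumption rather than its implementation.

# Sketch of typical Torch/CUDA cleanup after a long-running job.
import gc

def free_torch_memory_sketch():
    gc.collect()
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()   # return cached blocks to the driver
            torch.cuda.ipc_collect()   # reclaim memory from dead IPC handles
    except Exception:
        pass  # torch not installed or no CUDA device: nothing to free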
+ # -----------------------------
+ # Worker
+ # -----------------------------
+
+ class MultiFrameDeconvWorkercuDNN(QObject):
+     progress = pyqtSignal(str)
+     finished = pyqtSignal(bool, str, str)  # success, message, out_path
+
+     def __init__(self, parent, aligned_paths, output_path, iters, kappa, color_mode,
+                  huber_delta, min_iters, use_star_masks=False, use_variance_maps=False, rho="huber",
+                  star_mask_cfg: dict | None = None, varmap_cfg: dict | None = None,
+                  save_intermediate: bool = False,
+                  seed_mode: str = "robust",
+                  # NEW SR params
+                  super_res_factor: int = 1,
+                  sr_sigma: float = 1.1,
+                  sr_psf_opt_iters: int = 250,
+                  sr_psf_opt_lr: float = 0.1,
+                  star_mask_ref_path: str | None = None):
+         super().__init__(parent)
+         self.aligned_paths = aligned_paths
+         self.output_path = output_path
+         self.iters = iters
+         self.kappa = kappa
+         self.color_mode = color_mode
+         self.huber_delta = huber_delta
+         self.min_iters = min_iters  # NEW
+         self.star_mask_cfg = star_mask_cfg or {}
+         self.varmap_cfg = varmap_cfg or {}
+         self.use_star_masks = use_star_masks
+         self.use_variance_maps = use_variance_maps
+         self.rho = rho
+         self.save_intermediate = save_intermediate
+         self.super_res_factor = int(super_res_factor)
+         self.sr_sigma = float(sr_sigma)
+         self.sr_psf_opt_iters = int(sr_psf_opt_iters)
+         self.sr_psf_opt_lr = float(sr_psf_opt_lr)
+         self.star_mask_ref_path = star_mask_ref_path
+         self.seed_mode = seed_mode
+
+     def _log(self, s): self.progress.emit(s)
+
+     def run(self):
+         try:
+             out = multiframe_deconv(
+                 self.aligned_paths,
+                 self.output_path,
+                 iters=self.iters,
+                 kappa=self.kappa,
+                 color_mode=self.color_mode,
+                 seed_mode=self.seed_mode,
+                 huber_delta=self.huber_delta,
+                 use_star_masks=self.use_star_masks,
+                 use_variance_maps=self.use_variance_maps,
+                 rho=self.rho,
+                 min_iters=self.min_iters,
+                 status_cb=self._log,
+                 star_mask_cfg=self.star_mask_cfg,
+                 varmap_cfg=self.varmap_cfg,
+                 save_intermediate=self.save_intermediate,
+                 # NEW SR forwards
+                 super_res_factor=self.super_res_factor,
+                 sr_sigma=self.sr_sigma,
+                 sr_psf_opt_iters=self.sr_psf_opt_iters,
+                 sr_psf_opt_lr=self.sr_psf_opt_lr,
+                 star_mask_ref_path=self.star_mask_ref_path,
+             )
+             self.finished.emit(True, "MF deconvolution complete.", out)
+             _process_gui_events_safely()
+         except Exception as e:
+             self.finished.emit(False, f"MF deconvolution failed: {e}", "")
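Editor's note: the dialog code that actually drives this worker is elsewhere in the package. The sketch below shows the generic Qt pattern for running a QObject worker like this off the GUI thread (PyQt6 assumed, since the class uses QObject/pyqtSignal); the argument values passed to the constructor are placeholders, not the tool's defaults.

# Generic wiring for a QObject worker: move it to a QThread and connect its signals.
from PyQt6.QtCore import QThread

def start_worker_sketch(parent, aligned_paths, output_path):
    thread = QThread(parent)
    worker = MultiFrameDeconvWorkercuDNN(None, aligned_paths, output_path,
                                         iters=30, kappa=2.0, color_mode="auto",
                                         huber_delta=-2.0, min_iters=5)
    worker.moveToThread(thread)                       # run() executes off the GUI thread
    thread.started.connect(worker.run)
    worker.progress.connect(print)                    # or a status-label slot
    worker.finished.connect(lambda ok, msg, path: thread.quit())
    worker.finished.connect(worker.deleteLater)
    thread.finished.connect(thread.deleteLater)
    thread.start()
    return thread, worker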