setiastrosuitepro-1.6.2.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of setiastrosuitepro might be problematic.

Files changed (367)
  1. setiastro/__init__.py +2 -0
  2. setiastro/data/SASP_data.fits +0 -0
  3. setiastro/data/catalogs/List_of_Galaxies_with_Distances_Gly.csv +488 -0
  4. setiastro/data/catalogs/astrobin_filters.csv +890 -0
  5. setiastro/data/catalogs/astrobin_filters_page1_local.csv +51 -0
  6. setiastro/data/catalogs/cali2.csv +63 -0
  7. setiastro/data/catalogs/cali2color.csv +65 -0
  8. setiastro/data/catalogs/celestial_catalog - original.csv +16471 -0
  9. setiastro/data/catalogs/celestial_catalog.csv +24031 -0
  10. setiastro/data/catalogs/detected_stars.csv +24784 -0
  11. setiastro/data/catalogs/fits_header_data.csv +46 -0
  12. setiastro/data/catalogs/test.csv +8 -0
  13. setiastro/data/catalogs/updated_celestial_catalog.csv +16471 -0
  14. setiastro/images/Astro_Spikes.png +0 -0
  15. setiastro/images/Background_startup.jpg +0 -0
  16. setiastro/images/HRDiagram.png +0 -0
  17. setiastro/images/LExtract.png +0 -0
  18. setiastro/images/LInsert.png +0 -0
  19. setiastro/images/Oxygenation-atm-2.svg.png +0 -0
  20. setiastro/images/RGB080604.png +0 -0
  21. setiastro/images/abeicon.png +0 -0
  22. setiastro/images/aberration.png +0 -0
  23. setiastro/images/andromedatry.png +0 -0
  24. setiastro/images/andromedatry_satellited.png +0 -0
  25. setiastro/images/annotated.png +0 -0
  26. setiastro/images/aperture.png +0 -0
  27. setiastro/images/astrosuite.ico +0 -0
  28. setiastro/images/astrosuite.png +0 -0
  29. setiastro/images/astrosuitepro.icns +0 -0
  30. setiastro/images/astrosuitepro.ico +0 -0
  31. setiastro/images/astrosuitepro.png +0 -0
  32. setiastro/images/background.png +0 -0
  33. setiastro/images/background2.png +0 -0
  34. setiastro/images/benchmark.png +0 -0
  35. setiastro/images/big_moon_stabilizer_timeline.png +0 -0
  36. setiastro/images/big_moon_stabilizer_timeline_clean.png +0 -0
  37. setiastro/images/blaster.png +0 -0
  38. setiastro/images/blink.png +0 -0
  39. setiastro/images/clahe.png +0 -0
  40. setiastro/images/collage.png +0 -0
  41. setiastro/images/colorwheel.png +0 -0
  42. setiastro/images/contsub.png +0 -0
  43. setiastro/images/convo.png +0 -0
  44. setiastro/images/copyslot.png +0 -0
  45. setiastro/images/cosmic.png +0 -0
  46. setiastro/images/cosmicsat.png +0 -0
  47. setiastro/images/crop1.png +0 -0
  48. setiastro/images/cropicon.png +0 -0
  49. setiastro/images/curves.png +0 -0
  50. setiastro/images/cvs.png +0 -0
  51. setiastro/images/debayer.png +0 -0
  52. setiastro/images/denoise_cnn_custom.png +0 -0
  53. setiastro/images/denoise_cnn_graph.png +0 -0
  54. setiastro/images/disk.png +0 -0
  55. setiastro/images/dse.png +0 -0
  56. setiastro/images/exoicon.png +0 -0
  57. setiastro/images/eye.png +0 -0
  58. setiastro/images/fliphorizontal.png +0 -0
  59. setiastro/images/flipvertical.png +0 -0
  60. setiastro/images/font.png +0 -0
  61. setiastro/images/freqsep.png +0 -0
  62. setiastro/images/functionbundle.png +0 -0
  63. setiastro/images/graxpert.png +0 -0
  64. setiastro/images/green.png +0 -0
  65. setiastro/images/gridicon.png +0 -0
  66. setiastro/images/halo.png +0 -0
  67. setiastro/images/hdr.png +0 -0
  68. setiastro/images/histogram.png +0 -0
  69. setiastro/images/hubble.png +0 -0
  70. setiastro/images/imagecombine.png +0 -0
  71. setiastro/images/invert.png +0 -0
  72. setiastro/images/isophote.png +0 -0
  73. setiastro/images/isophote_demo_figure.png +0 -0
  74. setiastro/images/isophote_demo_image.png +0 -0
  75. setiastro/images/isophote_demo_model.png +0 -0
  76. setiastro/images/isophote_demo_residual.png +0 -0
  77. setiastro/images/jwstpupil.png +0 -0
  78. setiastro/images/linearfit.png +0 -0
  79. setiastro/images/livestacking.png +0 -0
  80. setiastro/images/mask.png +0 -0
  81. setiastro/images/maskapply.png +0 -0
  82. setiastro/images/maskcreate.png +0 -0
  83. setiastro/images/maskremove.png +0 -0
  84. setiastro/images/morpho.png +0 -0
  85. setiastro/images/mosaic.png +0 -0
  86. setiastro/images/multiscale_decomp.png +0 -0
  87. setiastro/images/nbtorgb.png +0 -0
  88. setiastro/images/neutral.png +0 -0
  89. setiastro/images/nuke.png +0 -0
  90. setiastro/images/openfile.png +0 -0
  91. setiastro/images/pedestal.png +0 -0
  92. setiastro/images/pen.png +0 -0
  93. setiastro/images/pixelmath.png +0 -0
  94. setiastro/images/platesolve.png +0 -0
  95. setiastro/images/ppp.png +0 -0
  96. setiastro/images/pro.png +0 -0
  97. setiastro/images/project.png +0 -0
  98. setiastro/images/psf.png +0 -0
  99. setiastro/images/redo.png +0 -0
  100. setiastro/images/redoicon.png +0 -0
  101. setiastro/images/rescale.png +0 -0
  102. setiastro/images/rgbalign.png +0 -0
  103. setiastro/images/rgbcombo.png +0 -0
  104. setiastro/images/rgbextract.png +0 -0
  105. setiastro/images/rotate180.png +0 -0
  106. setiastro/images/rotateclockwise.png +0 -0
  107. setiastro/images/rotatecounterclockwise.png +0 -0
  108. setiastro/images/satellite.png +0 -0
  109. setiastro/images/script.png +0 -0
  110. setiastro/images/selectivecolor.png +0 -0
  111. setiastro/images/simbad.png +0 -0
  112. setiastro/images/slot0.png +0 -0
  113. setiastro/images/slot1.png +0 -0
  114. setiastro/images/slot2.png +0 -0
  115. setiastro/images/slot3.png +0 -0
  116. setiastro/images/slot4.png +0 -0
  117. setiastro/images/slot5.png +0 -0
  118. setiastro/images/slot6.png +0 -0
  119. setiastro/images/slot7.png +0 -0
  120. setiastro/images/slot8.png +0 -0
  121. setiastro/images/slot9.png +0 -0
  122. setiastro/images/spcc.png +0 -0
  123. setiastro/images/spin_precession_vs_lunar_distance.png +0 -0
  124. setiastro/images/spinner.gif +0 -0
  125. setiastro/images/stacking.png +0 -0
  126. setiastro/images/staradd.png +0 -0
  127. setiastro/images/staralign.png +0 -0
  128. setiastro/images/starnet.png +0 -0
  129. setiastro/images/starregistration.png +0 -0
  130. setiastro/images/starspike.png +0 -0
  131. setiastro/images/starstretch.png +0 -0
  132. setiastro/images/statstretch.png +0 -0
  133. setiastro/images/supernova.png +0 -0
  134. setiastro/images/uhs.png +0 -0
  135. setiastro/images/undoicon.png +0 -0
  136. setiastro/images/upscale.png +0 -0
  137. setiastro/images/viewbundle.png +0 -0
  138. setiastro/images/whitebalance.png +0 -0
  139. setiastro/images/wimi_icon_256x256.png +0 -0
  140. setiastro/images/wimilogo.png +0 -0
  141. setiastro/images/wims.png +0 -0
  142. setiastro/images/wrench_icon.png +0 -0
  143. setiastro/images/xisfliberator.png +0 -0
  144. setiastro/qml/ResourceMonitor.qml +126 -0
  145. setiastro/saspro/__init__.py +20 -0
  146. setiastro/saspro/__main__.py +945 -0
  147. setiastro/saspro/_generated/__init__.py +7 -0
  148. setiastro/saspro/_generated/build_info.py +3 -0
  149. setiastro/saspro/abe.py +1346 -0
  150. setiastro/saspro/abe_preset.py +196 -0
  151. setiastro/saspro/aberration_ai.py +694 -0
  152. setiastro/saspro/aberration_ai_preset.py +224 -0
  153. setiastro/saspro/accel_installer.py +218 -0
  154. setiastro/saspro/accel_workers.py +30 -0
  155. setiastro/saspro/add_stars.py +624 -0
  156. setiastro/saspro/astrobin_exporter.py +1010 -0
  157. setiastro/saspro/astrospike.py +153 -0
  158. setiastro/saspro/astrospike_python.py +1841 -0
  159. setiastro/saspro/autostretch.py +198 -0
  160. setiastro/saspro/backgroundneutral.py +602 -0
  161. setiastro/saspro/batch_convert.py +328 -0
  162. setiastro/saspro/batch_renamer.py +522 -0
  163. setiastro/saspro/blemish_blaster.py +491 -0
  164. setiastro/saspro/blink_comparator_pro.py +2926 -0
  165. setiastro/saspro/bundles.py +61 -0
  166. setiastro/saspro/bundles_dock.py +114 -0
  167. setiastro/saspro/cheat_sheet.py +213 -0
  168. setiastro/saspro/clahe.py +368 -0
  169. setiastro/saspro/comet_stacking.py +1442 -0
  170. setiastro/saspro/common_tr.py +107 -0
  171. setiastro/saspro/config.py +38 -0
  172. setiastro/saspro/config_bootstrap.py +40 -0
  173. setiastro/saspro/config_manager.py +316 -0
  174. setiastro/saspro/continuum_subtract.py +1617 -0
  175. setiastro/saspro/convo.py +1400 -0
  176. setiastro/saspro/convo_preset.py +414 -0
  177. setiastro/saspro/copyastro.py +190 -0
  178. setiastro/saspro/cosmicclarity.py +1589 -0
  179. setiastro/saspro/cosmicclarity_preset.py +407 -0
  180. setiastro/saspro/crop_dialog_pro.py +973 -0
  181. setiastro/saspro/crop_preset.py +189 -0
  182. setiastro/saspro/curve_editor_pro.py +2562 -0
  183. setiastro/saspro/curves_preset.py +375 -0
  184. setiastro/saspro/debayer.py +673 -0
  185. setiastro/saspro/debug_utils.py +29 -0
  186. setiastro/saspro/dnd_mime.py +35 -0
  187. setiastro/saspro/doc_manager.py +2664 -0
  188. setiastro/saspro/exoplanet_detector.py +2166 -0
  189. setiastro/saspro/file_utils.py +284 -0
  190. setiastro/saspro/fitsmodifier.py +748 -0
  191. setiastro/saspro/fix_bom.py +32 -0
  192. setiastro/saspro/free_torch_memory.py +48 -0
  193. setiastro/saspro/frequency_separation.py +1349 -0
  194. setiastro/saspro/function_bundle.py +1596 -0
  195. setiastro/saspro/generate_translations.py +3092 -0
  196. setiastro/saspro/ghs_dialog_pro.py +663 -0
  197. setiastro/saspro/ghs_preset.py +284 -0
  198. setiastro/saspro/graxpert.py +637 -0
  199. setiastro/saspro/graxpert_preset.py +287 -0
  200. setiastro/saspro/gui/__init__.py +0 -0
  201. setiastro/saspro/gui/main_window.py +8810 -0
  202. setiastro/saspro/gui/mixins/__init__.py +33 -0
  203. setiastro/saspro/gui/mixins/dock_mixin.py +362 -0
  204. setiastro/saspro/gui/mixins/file_mixin.py +450 -0
  205. setiastro/saspro/gui/mixins/geometry_mixin.py +403 -0
  206. setiastro/saspro/gui/mixins/header_mixin.py +441 -0
  207. setiastro/saspro/gui/mixins/mask_mixin.py +421 -0
  208. setiastro/saspro/gui/mixins/menu_mixin.py +389 -0
  209. setiastro/saspro/gui/mixins/theme_mixin.py +367 -0
  210. setiastro/saspro/gui/mixins/toolbar_mixin.py +1457 -0
  211. setiastro/saspro/gui/mixins/update_mixin.py +309 -0
  212. setiastro/saspro/gui/mixins/view_mixin.py +435 -0
  213. setiastro/saspro/gui/statistics_dialog.py +47 -0
  214. setiastro/saspro/halobgon.py +488 -0
  215. setiastro/saspro/header_viewer.py +448 -0
  216. setiastro/saspro/headless_utils.py +88 -0
  217. setiastro/saspro/histogram.py +756 -0
  218. setiastro/saspro/history_explorer.py +941 -0
  219. setiastro/saspro/i18n.py +168 -0
  220. setiastro/saspro/image_combine.py +417 -0
  221. setiastro/saspro/image_peeker_pro.py +1604 -0
  222. setiastro/saspro/imageops/__init__.py +37 -0
  223. setiastro/saspro/imageops/mdi_snap.py +292 -0
  224. setiastro/saspro/imageops/scnr.py +36 -0
  225. setiastro/saspro/imageops/starbasedwhitebalance.py +210 -0
  226. setiastro/saspro/imageops/stretch.py +236 -0
  227. setiastro/saspro/isophote.py +1182 -0
  228. setiastro/saspro/layers.py +208 -0
  229. setiastro/saspro/layers_dock.py +714 -0
  230. setiastro/saspro/lazy_imports.py +193 -0
  231. setiastro/saspro/legacy/__init__.py +2 -0
  232. setiastro/saspro/legacy/image_manager.py +2226 -0
  233. setiastro/saspro/legacy/numba_utils.py +3676 -0
  234. setiastro/saspro/legacy/xisf.py +1071 -0
  235. setiastro/saspro/linear_fit.py +537 -0
  236. setiastro/saspro/live_stacking.py +1841 -0
  237. setiastro/saspro/log_bus.py +5 -0
  238. setiastro/saspro/logging_config.py +460 -0
  239. setiastro/saspro/luminancerecombine.py +309 -0
  240. setiastro/saspro/main_helpers.py +201 -0
  241. setiastro/saspro/mask_creation.py +931 -0
  242. setiastro/saspro/masks_core.py +56 -0
  243. setiastro/saspro/mdi_widgets.py +353 -0
  244. setiastro/saspro/memory_utils.py +666 -0
  245. setiastro/saspro/metadata_patcher.py +75 -0
  246. setiastro/saspro/mfdeconv.py +3831 -0
  247. setiastro/saspro/mfdeconv_earlystop.py +71 -0
  248. setiastro/saspro/mfdeconvcudnn.py +3263 -0
  249. setiastro/saspro/mfdeconvsport.py +2382 -0
  250. setiastro/saspro/minorbodycatalog.py +567 -0
  251. setiastro/saspro/morphology.py +407 -0
  252. setiastro/saspro/multiscale_decomp.py +1293 -0
  253. setiastro/saspro/nbtorgb_stars.py +541 -0
  254. setiastro/saspro/numba_utils.py +3145 -0
  255. setiastro/saspro/numba_warmup.py +141 -0
  256. setiastro/saspro/ops/__init__.py +9 -0
  257. setiastro/saspro/ops/command_help_dialog.py +623 -0
  258. setiastro/saspro/ops/command_runner.py +217 -0
  259. setiastro/saspro/ops/commands.py +1594 -0
  260. setiastro/saspro/ops/script_editor.py +1102 -0
  261. setiastro/saspro/ops/scripts.py +1473 -0
  262. setiastro/saspro/ops/settings.py +637 -0
  263. setiastro/saspro/parallel_utils.py +554 -0
  264. setiastro/saspro/pedestal.py +121 -0
  265. setiastro/saspro/perfect_palette_picker.py +1071 -0
  266. setiastro/saspro/pipeline.py +110 -0
  267. setiastro/saspro/pixelmath.py +1604 -0
  268. setiastro/saspro/plate_solver.py +2445 -0
  269. setiastro/saspro/project_io.py +797 -0
  270. setiastro/saspro/psf_utils.py +136 -0
  271. setiastro/saspro/psf_viewer.py +549 -0
  272. setiastro/saspro/pyi_rthook_astroquery.py +95 -0
  273. setiastro/saspro/remove_green.py +331 -0
  274. setiastro/saspro/remove_stars.py +1599 -0
  275. setiastro/saspro/remove_stars_preset.py +404 -0
  276. setiastro/saspro/resources.py +501 -0
  277. setiastro/saspro/rgb_combination.py +208 -0
  278. setiastro/saspro/rgb_extract.py +19 -0
  279. setiastro/saspro/rgbalign.py +723 -0
  280. setiastro/saspro/runtime_imports.py +7 -0
  281. setiastro/saspro/runtime_torch.py +754 -0
  282. setiastro/saspro/save_options.py +73 -0
  283. setiastro/saspro/selective_color.py +1552 -0
  284. setiastro/saspro/sfcc.py +1472 -0
  285. setiastro/saspro/shortcuts.py +3043 -0
  286. setiastro/saspro/signature_insert.py +1102 -0
  287. setiastro/saspro/stacking_suite.py +18470 -0
  288. setiastro/saspro/star_alignment.py +7435 -0
  289. setiastro/saspro/star_alignment_preset.py +329 -0
  290. setiastro/saspro/star_metrics.py +49 -0
  291. setiastro/saspro/star_spikes.py +765 -0
  292. setiastro/saspro/star_stretch.py +507 -0
  293. setiastro/saspro/stat_stretch.py +538 -0
  294. setiastro/saspro/status_log_dock.py +78 -0
  295. setiastro/saspro/subwindow.py +3328 -0
  296. setiastro/saspro/supernovaasteroidhunter.py +1719 -0
  297. setiastro/saspro/swap_manager.py +99 -0
  298. setiastro/saspro/torch_backend.py +89 -0
  299. setiastro/saspro/torch_rejection.py +434 -0
  300. setiastro/saspro/translations/all_source_strings.json +3654 -0
  301. setiastro/saspro/translations/ar_translations.py +3865 -0
  302. setiastro/saspro/translations/de_translations.py +3749 -0
  303. setiastro/saspro/translations/es_translations.py +3939 -0
  304. setiastro/saspro/translations/fr_translations.py +3858 -0
  305. setiastro/saspro/translations/hi_translations.py +3571 -0
  306. setiastro/saspro/translations/integrate_translations.py +270 -0
  307. setiastro/saspro/translations/it_translations.py +3678 -0
  308. setiastro/saspro/translations/ja_translations.py +3601 -0
  309. setiastro/saspro/translations/pt_translations.py +3869 -0
  310. setiastro/saspro/translations/ru_translations.py +2848 -0
  311. setiastro/saspro/translations/saspro_ar.qm +0 -0
  312. setiastro/saspro/translations/saspro_ar.ts +255 -0
  313. setiastro/saspro/translations/saspro_de.qm +0 -0
  314. setiastro/saspro/translations/saspro_de.ts +253 -0
  315. setiastro/saspro/translations/saspro_es.qm +0 -0
  316. setiastro/saspro/translations/saspro_es.ts +12520 -0
  317. setiastro/saspro/translations/saspro_fr.qm +0 -0
  318. setiastro/saspro/translations/saspro_fr.ts +12514 -0
  319. setiastro/saspro/translations/saspro_hi.qm +0 -0
  320. setiastro/saspro/translations/saspro_hi.ts +257 -0
  321. setiastro/saspro/translations/saspro_it.qm +0 -0
  322. setiastro/saspro/translations/saspro_it.ts +12520 -0
  323. setiastro/saspro/translations/saspro_ja.qm +0 -0
  324. setiastro/saspro/translations/saspro_ja.ts +257 -0
  325. setiastro/saspro/translations/saspro_pt.qm +0 -0
  326. setiastro/saspro/translations/saspro_pt.ts +257 -0
  327. setiastro/saspro/translations/saspro_ru.qm +0 -0
  328. setiastro/saspro/translations/saspro_ru.ts +237 -0
  329. setiastro/saspro/translations/saspro_sw.qm +0 -0
  330. setiastro/saspro/translations/saspro_sw.ts +257 -0
  331. setiastro/saspro/translations/saspro_uk.qm +0 -0
  332. setiastro/saspro/translations/saspro_uk.ts +10771 -0
  333. setiastro/saspro/translations/saspro_zh.qm +0 -0
  334. setiastro/saspro/translations/saspro_zh.ts +12520 -0
  335. setiastro/saspro/translations/sw_translations.py +3671 -0
  336. setiastro/saspro/translations/uk_translations.py +3700 -0
  337. setiastro/saspro/translations/zh_translations.py +3675 -0
  338. setiastro/saspro/versioning.py +77 -0
  339. setiastro/saspro/view_bundle.py +1558 -0
  340. setiastro/saspro/wavescale_hdr.py +645 -0
  341. setiastro/saspro/wavescale_hdr_preset.py +101 -0
  342. setiastro/saspro/wavescalede.py +680 -0
  343. setiastro/saspro/wavescalede_preset.py +230 -0
  344. setiastro/saspro/wcs_update.py +374 -0
  345. setiastro/saspro/whitebalance.py +492 -0
  346. setiastro/saspro/widgets/__init__.py +48 -0
  347. setiastro/saspro/widgets/common_utilities.py +306 -0
  348. setiastro/saspro/widgets/graphics_views.py +122 -0
  349. setiastro/saspro/widgets/image_utils.py +518 -0
  350. setiastro/saspro/widgets/minigame/game.js +986 -0
  351. setiastro/saspro/widgets/minigame/index.html +53 -0
  352. setiastro/saspro/widgets/minigame/style.css +241 -0
  353. setiastro/saspro/widgets/preview_dialogs.py +280 -0
  354. setiastro/saspro/widgets/resource_monitor.py +237 -0
  355. setiastro/saspro/widgets/spinboxes.py +275 -0
  356. setiastro/saspro/widgets/themed_buttons.py +13 -0
  357. setiastro/saspro/widgets/wavelet_utils.py +331 -0
  358. setiastro/saspro/wimi.py +7996 -0
  359. setiastro/saspro/wims.py +578 -0
  360. setiastro/saspro/window_shelf.py +185 -0
  361. setiastro/saspro/xisf.py +1123 -0
  362. setiastrosuitepro-1.6.2.post1.dist-info/METADATA +278 -0
  363. setiastrosuitepro-1.6.2.post1.dist-info/RECORD +367 -0
  364. setiastrosuitepro-1.6.2.post1.dist-info/WHEEL +4 -0
  365. setiastrosuitepro-1.6.2.post1.dist-info/entry_points.txt +6 -0
  366. setiastrosuitepro-1.6.2.post1.dist-info/licenses/LICENSE +674 -0
  367. setiastrosuitepro-1.6.2.post1.dist-info/licenses/license.txt +2580 -0
@@ -0,0 +1,3831 @@
1
+ # pro/mfdeconv.py non-sport normal version
2
+ from __future__ import annotations
3
+ import os, sys
4
+ import math
5
+ import re
6
+ import numpy as np
7
+ import time
8
+ from astropy.io import fits
9
+ from PyQt6.QtCore import QObject, pyqtSignal
10
+ from setiastro.saspro.psf_utils import compute_psf_kernel_for_image
11
+ from PyQt6.QtWidgets import QApplication
12
+ from PyQt6.QtCore import QThread
13
+ import contextlib
14
+ from threadpoolctl import threadpool_limits
15
+ from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
16
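+ # Frozen builds (sys.frozen set, e.g. a bundled executable) fall back to thread pools for asset computation.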
+ _USE_PROCESS_POOL_FOR_ASSETS = not getattr(sys, "frozen", False)
17
+ import gc
18
+ try:
19
+ import sep
20
+ except Exception:
21
+ sep = None
22
+ from setiastro.saspro.free_torch_memory import _free_torch_memory
23
+ from setiastro.saspro.mfdeconv_earlystop import EarlyStopper
24
+ torch = None # filled by runtime loader if available
25
+ TORCH_OK = False
26
+ NO_GRAD = contextlib.nullcontext # fallback
27
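+ # Ordered fallback loaders: each takes a file path and returns an ndarray; _read_xisf_numpy uses the first that succeeds.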
+ _XISF_READERS = []
28
+ try:
29
+ # e.g. your legacy module
30
+ from setiastro.saspro.legacy import xisf as _legacy_xisf
31
+ if hasattr(_legacy_xisf, "read"):
32
+ _XISF_READERS.append(lambda p: _legacy_xisf.read(p))
33
+ elif hasattr(_legacy_xisf, "open"):
34
+ _XISF_READERS.append(lambda p: _legacy_xisf.open(p)[0])
35
+ except Exception:
36
+ pass
37
+ try:
38
+ # sometimes projects expose a generic load_image
39
+ from setiastro.saspro.legacy.image_manager import load_image as _generic_load_image # adjust if needed
40
+ _XISF_READERS.append(lambda p: _generic_load_image(p)[0])
41
+ except Exception:
42
+ pass
43
+
44
+ from pathlib import Path
45
+
46
+ # at top of file with the other imports
47
+
48
+ from queue import SimpleQueue
49
+ from setiastro.saspro.memory_utils import LRUDict
50
+
51
+ # ── XISF decode cache → memmap on disk ─────────────────────────────────
52
+ import tempfile
53
+ import threading
54
+ import uuid
55
+ import atexit
56
+ _XISF_CACHE = LRUDict(50)
57
+ _XISF_LOCK = threading.Lock()
58
+ _XISF_TMPFILES = []
59
+
60
+ from collections import OrderedDict
61
+
62
+ # ── CHW LRU (float32) built on top of FITS memmap & XISF memmap ────────────────
63
+ class _FrameCHWLRU:
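+ """Small LRU of decoded frames as float32 CHW arrays, keyed by (path, Ht, Wt, color_mode)."""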
64
+ def __init__(self, capacity=8):
65
+ self.cap = int(max(1, capacity))
66
+ self.od = OrderedDict()
67
+
68
+ def clear(self):
69
+ self.od.clear()
70
+
71
+ def get(self, path, Ht, Wt, color_mode):
72
+ key = (path, Ht, Wt, str(color_mode).lower())
73
+ hit = self.od.get(key)
74
+ if hit is not None:
75
+ self.od.move_to_end(key)
76
+ return hit
77
+
78
+ # Load backing array cheaply (memmap for FITS, cached memmap for XISF)
79
+ ext = os.path.splitext(path)[1].lower()
80
+ if ext == ".xisf":
81
+ a = _xisf_cached_array(path) # float32, HW/HWC/CHW
82
+ else:
83
+ # FITS path: use astropy memmap (no data copy)
84
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
85
+ arr = None
86
+ for h in hdul:
87
+ if getattr(h, "data", None) is not None:
88
+ arr = h.data
89
+ break
90
+ if arr is None:
91
+ raise ValueError(f"No image data in {path}")
92
+ a = np.asarray(arr)
93
+ # dtype normalize once; keep float32
94
+ if a.dtype.kind in "ui":
95
+ a = a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0)
96
+ else:
97
+ a = a.astype(np.float32, copy=False)
98
+
99
+ # Center-crop to (Ht, Wt) and convert to CHW
100
+ a = np.asarray(a) # float32
101
+ a = _center_crop(a, Ht, Wt)
102
+
103
+ # Respect color_mode: “luma” → 1×H×W, “PerChannel” → 3×H×W if RGB present
104
+ cm = str(color_mode).lower()
105
+ if cm == "luma":
106
+ a_chw = _as_chw(_to_luma_local(a)).astype(np.float32, copy=False)
107
+ else:
108
+ a_chw = _as_chw(a).astype(np.float32, copy=False)
109
+ if a_chw.shape[0] == 1 and cm != "luma":
110
+ # mono data: a single channel is fine even in per-channel mode
111
+ pass
112
+
113
+ # LRU insert
114
+ self.od[key] = a_chw
115
+ if len(self.od) > self.cap:
116
+ self.od.popitem(last=False)
117
+ return a_chw
118
+
119
+ _FRAME_LRU = _FrameCHWLRU(capacity=8) # tune if you like
120
+
121
+ def _clear_all_caches():
122
+ try: _clear_xisf_cache()
123
+ except Exception as e:
124
+ import logging
125
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
126
+ try: _FRAME_LRU.clear()
127
+ except Exception as e:
128
+ import logging
129
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
130
+
131
+
132
+ def _normalize_to_float32(a: np.ndarray) -> np.ndarray:
133
+ if a.dtype.kind in "ui":
134
+ return (a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0))
135
+ if a.dtype == np.float32:
136
+ return a
137
+ return a.astype(np.float32, copy=False)
138
+
139
+ def _xisf_cached_array(path: str) -> np.memmap:
140
+ """
141
+ Decode an XISF image exactly once and back it by a read-only float32 memmap.
142
+ Returns a memmap that can be sliced cheaply for tiles.
143
+ """
144
+ with _XISF_LOCK:
145
+ hit = _XISF_CACHE.get(path)
146
+ if hit is not None:
147
+ fn, shape = hit
148
+ return np.memmap(fn, dtype=np.float32, mode="r", shape=shape)
149
+
150
+ # Decode once
151
+ arr, _ = _load_image_array(path) # your existing loader
152
+ if arr is None:
153
+ raise ValueError(f"XISF loader returned None for {path}")
154
+ arr = np.asarray(arr)
155
+ arrf = _normalize_to_float32(arr)
156
+
157
+ # Create a temp file-backed memmap
158
+ tmpdir = tempfile.gettempdir()
159
+ fn = os.path.join(tmpdir, f"xisf_cache_{uuid.uuid4().hex}.mmap")
160
+ mm = np.memmap(fn, dtype=np.float32, mode="w+", shape=arrf.shape)
161
+ mm[...] = arrf[...]
162
+ mm.flush()
163
+ del mm # close writer handle; re-open below as read-only
164
+
165
+ _XISF_CACHE[path] = (fn, arrf.shape)
166
+ _XISF_TMPFILES.append(fn)
167
+ return np.memmap(fn, dtype=np.float32, mode="r", shape=arrf.shape)
168
+
169
+ def _clear_xisf_cache():
170
+ with _XISF_LOCK:
171
+ for fn in _XISF_TMPFILES:
172
+ try: os.remove(fn)
173
+ except Exception as e:
174
+ import logging
175
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
176
+ _XISF_CACHE.clear()
177
+ _XISF_TMPFILES.clear()
178
+
179
+ atexit.register(_clear_xisf_cache)
180
+
181
+
182
+ def _is_xisf(path: str) -> bool:
183
+ return os.path.splitext(path)[1].lower() == ".xisf"
184
+
185
+ def _read_xisf_numpy(path: str) -> np.ndarray:
186
+ if not _XISF_READERS:
187
+ raise RuntimeError(
188
+ "No XISF readers registered. Ensure one of "
189
+ "legacy.xisf.read/open or *.image_io.load_image is importable."
190
+ )
191
+ last_err = None
192
+ for fn in _XISF_READERS:
193
+ try:
194
+ arr = fn(path)
195
+ if isinstance(arr, tuple):
196
+ arr = arr[0]
197
+ return np.asarray(arr)
198
+ except Exception as e:
199
+ last_err = e
200
+ raise RuntimeError(f"All XISF readers failed for {path}: {last_err}")
201
+
202
+ def _fits_open_data(path: str):
203
+ # ignore_missing_simple=True lets us open headers missing SIMPLE
204
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
205
+ hdu = hdul[0]
206
+ if hdu.data is None:
207
+ # find first image HDU if primary is header-only
208
+ for h in hdul[1:]:
209
+ if getattr(h, "data", None) is not None:
210
+ hdu = h
211
+ break
212
+ data = np.asanyarray(hdu.data)
213
+ hdr = hdu.header
214
+ return data, hdr
215
+
216
+ def _load_image_array(path: str) -> tuple[np.ndarray, "fits.Header | None"]:
217
+ """
218
+ Return (numpy array, fits.Header or None). Color-last if 3D.
219
+ dtype left as-is; callers cast to float32. Array is C-contig & writeable.
220
+ """
221
+ if _is_xisf(path):
222
+ arr = _read_xisf_numpy(path)
223
+ hdr = None
224
+ else:
225
+ arr, hdr = _fits_open_data(path)
226
+
227
+ a = np.asarray(arr)
228
+ # Move color axis to last if 3D with a leading channel axis
229
+ if a.ndim == 3 and a.shape[0] in (1, 3) and a.shape[-1] not in (1, 3):
230
+ a = np.moveaxis(a, 0, -1)
231
+ # float32 conversion happens later; here we only ensure the array is C-contiguous and writeable
232
+ if (not a.flags.c_contiguous) or (not a.flags.writeable):
233
+ a = np.array(a, copy=True)
234
+ return a, hdr
235
+
236
+ def _probe_hw(path: str) -> tuple[int, int, int | None]:
237
+ """
238
+ Returns (H, W, C_or_None) without changing data. Moves color to last if needed.
239
+ """
240
+ a, _ = _load_image_array(path)
241
+ if a.ndim == 2:
242
+ return a.shape[0], a.shape[1], None
243
+ if a.ndim == 3:
244
+ h, w, c = a.shape
245
+ # treat mono-3D as (H,W,1)
246
+ if c not in (1, 3) and a.shape[0] in (1, 3):
247
+ a = np.moveaxis(a, 0, -1)
248
+ h, w, c = a.shape
249
+ return h, w, c if c in (1, 3) else None
250
+ raise ValueError(f"Unsupported ndim={a.ndim} for {path}")
251
+
252
+ def _common_hw_from_paths(paths: list[str]) -> tuple[int, int]:
253
+ """
254
+ Replacement for the old FITS-only version: min(H), min(W) across files.
255
+ """
256
+ Hs, Ws = [], []
257
+ for p in paths:
258
+ h, w, _ = _probe_hw(p)
259
+ Hs.append(int(h)); Ws.append(int(w))
260
+ return int(min(Hs)), int(min(Ws))
261
+
262
+ def _to_chw_float32(img: np.ndarray, color_mode: str) -> np.ndarray:
263
+ """
264
+ Convert to CHW float32:
265
+ - mono → (1,H,W)
266
+ - RGB → (3,H,W) if 'PerChannel'; (1,H,W) if 'luma'
267
+ """
268
+ x = np.asarray(img)
269
+ if x.ndim == 2:
270
+ y = x.astype(np.float32, copy=False)[None, ...] # (1,H,W)
271
+ return y
272
+ if x.ndim == 3:
273
+ # color-last (H,W,C) expected
274
+ if x.shape[-1] == 1:
275
+ return x[..., 0].astype(np.float32, copy=False)[None, ...]
276
+ if x.shape[-1] == 3:
277
+ if str(color_mode).lower() in ("perchannel", "per_channel", "perchannelrgb"):
278
+ r, g, b = x[..., 0], x[..., 1], x[..., 2]
279
+ return np.stack([r.astype(np.float32, copy=False),
280
+ g.astype(np.float32, copy=False),
281
+ b.astype(np.float32, copy=False)], axis=0)
282
+ # luma
283
+ r, g, b = x[..., 0].astype(np.float32, copy=False), x[..., 1].astype(np.float32, copy=False), x[..., 2].astype(np.float32, copy=False)
284
+ L = 0.2126*r + 0.7152*g + 0.0722*b
285
+ return L[None, ...]
286
+ # rare mono-3D
287
+ if x.shape[0] in (1, 3) and x.shape[-1] not in (1, 3):
288
+ x = np.moveaxis(x, 0, -1)
289
+ return _to_chw_float32(x, color_mode)
290
+ raise ValueError(f"Unsupported image shape {x.shape}")
291
+
292
+ def _center_crop_hw(img: np.ndarray, Ht: int, Wt: int) -> np.ndarray:
293
+ h, w = img.shape[:2]
294
+ y0 = max(0, (h - Ht)//2); x0 = max(0, (w - Wt)//2)
295
+ return img[y0:y0+Ht, x0:x0+Wt, ...].copy() if (Ht < h or Wt < w) else img
296
+
297
+ def _stack_loader_memmap(paths: list[str], Ht: int, Wt: int, color_mode: str):
298
+ """
299
+ Drop-in replacement of the old FITS-only helper.
300
+ Returns (ys, hdrs):
301
+ ys : list of CHW float32 arrays cropped to (Ht,Wt)
302
+ hdrs : list of fits.Header or None (XISF)
303
+ """
304
+ ys, hdrs = [], []
305
+ for p in paths:
306
+ arr, hdr = _load_image_array(p)
307
+ arr = _center_crop_hw(arr, Ht, Wt)
308
+ # normalize integer data to [0,1] like the rest of your code
309
+ if arr.dtype.kind in "ui":
310
+ mx = np.float32(np.iinfo(arr.dtype).max)
311
+ arr = arr.astype(np.float32, copy=False) / (mx if mx > 0 else 1.0)
312
+ elif arr.dtype.kind == "f":
313
+ arr = arr.astype(np.float32, copy=False)
314
+ else:
315
+ arr = arr.astype(np.float32, copy=False)
316
+
317
+ y = _to_chw_float32(arr, color_mode)
318
+ if (not y.flags.c_contiguous) or (not y.flags.writeable):
319
+ y = np.ascontiguousarray(y.astype(np.float32, copy=True))
320
+ ys.append(y)
321
+ hdrs.append(hdr if isinstance(hdr, fits.Header) else None)
322
+ return ys, hdrs
323
+
324
+ def _safe_primary_header(path: str) -> fits.Header:
325
+ if _is_xisf(path):
326
+ # best-effort synthetic header
327
+ h = fits.Header()
328
+ h["SIMPLE"] = (True, "created by MFDeconv")
329
+ h["BITPIX"] = -32
330
+ h["NAXIS"] = 2
331
+ return h
332
+ try:
333
+ return fits.getheader(path, ext=0, ignore_missing_simple=True)
334
+ except Exception:
335
+ return fits.Header()
336
+
337
+ # --- CUDA busy/unavailable detector (runtime fallback helper) ---
338
+ def _is_cuda_busy_error(e: Exception) -> bool:
339
+ """
340
+ Return True for 'device busy/unavailable' style CUDA errors that can pop up
341
+ mid-run on shared Linux systems. We *exclude* OOM here (handled elsewhere).
342
+ """
343
+ em = str(e).lower()
344
+ if "out of memory" in em:
345
+ return False # OOM handled by your batch size backoff
346
+ return (
347
+ ("cuda" in em and ("busy" in em or "unavailable" in em))
348
+ or ("device-side" in em and "assert" in em)
349
+ or ("driver shutting down" in em)
350
+ )
351
+
352
+
353
+ def _compute_frame_assets(i, arr, hdr, *, make_masks, make_varmaps,
354
+ star_mask_cfg, varmap_cfg, status_sink=lambda s: None):
355
+ """
356
+ Worker function: compute PSF and optional star mask / varmap for one frame.
357
+ Returns (index, psf, mask_or_None, var_or_None, log_lines)
358
+ """
359
+ logs = []
360
+ def log(s): logs.append(s)
361
+
362
+ # --- PSF sizing by FWHM ---
363
+ f_hdr = _estimate_fwhm_from_header(hdr)
364
+ f_img = _estimate_fwhm_from_image(arr)
365
+ f_whm = f_hdr if (np.isfinite(f_hdr)) else f_img
366
+ if not np.isfinite(f_whm) or f_whm <= 0:
367
+ f_whm = 2.5
368
+ k_auto = _auto_ksize_from_fwhm(f_whm)
369
+
370
+ # --- Star-derived PSF with retries ---
371
+ tried, psf = [], None
372
+ for k_try in [k_auto, max(k_auto - 4, 11), 21, 17, 15, 13, 11]:
373
+ if k_try in tried: continue
374
+ tried.append(k_try)
375
+ try:
376
+ out = compute_psf_kernel_for_image(arr, ksize=k_try, det_sigma=6.0, max_stars=80)
377
+ psf_try = out[0] if (isinstance(out, tuple) and len(out) >= 1) else out
378
+ if psf_try is not None:
379
+ psf = psf_try
380
+ break
381
+ except Exception:
382
+ psf = None
383
+ if psf is None:
384
+ psf = _gaussian_psf(f_whm, ksize=k_auto)
385
+ psf = _soften_psf(_normalize_psf(psf.astype(np.float32, copy=False)), sigma_px=0.0)
386
+
387
+ mask = None
388
+ var = None
389
+
390
+ if make_masks or make_varmaps:
391
+ # one background per frame (reused by both)
392
+ luma = _to_luma_local(arr)
393
+ vmc = (varmap_cfg or {})
394
+ sky_map, rms_map, err_scalar = _sep_background_precompute(
395
+ luma, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
396
+ )
397
+
398
+ if make_masks:
399
+ smc = star_mask_cfg or {}
400
+ mask = _star_mask_from_precomputed(
401
+ luma, sky_map, err_scalar,
402
+ thresh_sigma = smc.get("thresh_sigma", THRESHOLD_SIGMA),
403
+ max_objs = smc.get("max_objs", STAR_MASK_MAXOBJS),
404
+ grow_px = smc.get("grow_px", GROW_PX),
405
+ ellipse_scale= smc.get("ellipse_scale", ELLIPSE_SCALE),
406
+ soft_sigma = smc.get("soft_sigma", SOFT_SIGMA),
407
+ max_radius_px= smc.get("max_radius_px", MAX_STAR_RADIUS),
408
+ keep_floor = smc.get("keep_floor", KEEP_FLOOR),
409
+ max_side = smc.get("max_side", STAR_MASK_MAXSIDE),
410
+ status_cb = log,
411
+ )
412
+
413
+ if make_varmaps:
414
+ vmc = varmap_cfg or {}
415
+ var = _variance_map_from_precomputed(
416
+ luma, sky_map, rms_map, hdr,
417
+ smooth_sigma = vmc.get("smooth_sigma", 1.0),
418
+ floor = vmc.get("floor", 1e-8),
419
+ status_cb = log,
420
+ )
421
+
422
+ # small per-frame summary
423
+ fwhm_est = _psf_fwhm_px(psf)
424
+ logs.insert(0, f"MFDeconv: PSF{i}: ksize={psf.shape[0]} | FWHM≈{fwhm_est:.2f}px")
425
+
426
+ return i, psf, mask, var, logs
427
+
428
+ def _compute_one_worker(args):
429
+ """
430
+ Top-level picklable worker for ProcessPoolExecutor.
431
+ args: (i, path, make_masks_in_worker, make_varmaps, star_mask_cfg, varmap_cfg)
432
+ Returns (i, psf, mask, var, logs)
433
+ """
434
+ (i, path, make_masks_in_worker, make_varmaps, star_mask_cfg, varmap_cfg) = args
435
+ # avoid BLAS/OMP storm inside each process
436
+ with threadpool_limits(limits=1):
437
+ arr, hdr = _load_image_array(path) # FITS or XISF
438
+ arr = np.asarray(arr, dtype=np.float32, order="C")
439
+ if arr.ndim == 3 and arr.shape[-1] == 1:
440
+ arr = np.squeeze(arr, axis=-1)
441
+ if not isinstance(hdr, fits.Header): # synthesize FITS-like header for XISF
442
+ hdr = _safe_primary_header(path)
443
+ return _compute_frame_assets(
444
+ i, arr, hdr,
445
+ make_masks=bool(make_masks_in_worker),
446
+ make_varmaps=bool(make_varmaps),
447
+ star_mask_cfg=star_mask_cfg,
448
+ varmap_cfg=varmap_cfg,
449
+ )
450
+
451
+
452
+ def _build_psf_and_assets(
453
+ paths, # list[str]
454
+ make_masks=False,
455
+ make_varmaps=False,
456
+ status_cb=lambda s: None,
457
+ save_dir: str | None = None,
458
+ star_mask_cfg: dict | None = None,
459
+ varmap_cfg: dict | None = None,
460
+ max_workers: int | None = None,
461
+ star_mask_ref_path: str | None = None, # build one mask from this frame if provided
462
+ # NEW (passed from multiframe_deconv so we don’t re-probe/convert):
463
+ Ht: int | None = None,
464
+ Wt: int | None = None,
465
+ color_mode: str = "luma",
466
+ ):
467
+ """
468
+ Parallel PSF + (optional) star mask + variance map per frame.
469
+
470
+ Changes from the original:
471
+ • Reuses the decoded frame cache (_FRAME_LRU) for FITS/XISF so we never re-decode.
472
+ • Automatically switches to threads for XISF (so memmaps are shared across workers).
473
+ • Builds a single reference star mask (if requested) from the cached frame and
474
+ center-pads/crops it for all frames (no extra I/O).
475
+ • Preserves return order and streams worker logs back to the UI.
476
+ """
477
+ if save_dir:
478
+ os.makedirs(save_dir, exist_ok=True)
479
+
480
+ n = len(paths)
481
+
482
+ # Resolve target intersection size if caller didn't pass it
483
+ if Ht is None or Wt is None:
484
+ Ht, Wt = _common_hw_from_paths(paths)
485
+
486
+ # Sensible default worker count (cap at 8)
487
+ if max_workers is None:
488
+ try:
489
+ hw = os.cpu_count() or 4
490
+ except Exception:
491
+ hw = 4
492
+ max_workers = max(1, min(8, hw))
493
+
494
+ # Decide executor: for any XISF, prefer threads so the memmap/cache is shared
495
+ any_xisf = any(os.path.splitext(p)[1].lower() == ".xisf" for p in paths)
496
+ use_proc_pool = (not any_xisf) and _USE_PROCESS_POOL_FOR_ASSETS
497
+ Executor = ProcessPoolExecutor if use_proc_pool else ThreadPoolExecutor
498
+ pool_kind = "process" if use_proc_pool else "thread"
499
+ status_cb(f"MFDeconv: measuring PSFs/masks/varmaps with {max_workers} {pool_kind}s…")
500
+
501
+ # ---- helper: pad-or-crop a 2D array to (Ht,Wt), centered ----
502
+ def _center_pad_or_crop_2d(a2d: np.ndarray, Ht: int, Wt: int, fill: float = 1.0) -> np.ndarray:
503
+ a2d = np.asarray(a2d, dtype=np.float32)
504
+ H, W = int(a2d.shape[0]), int(a2d.shape[1])
505
+ # crop first if bigger
506
+ y0 = max(0, (H - Ht) // 2); x0 = max(0, (W - Wt) // 2)
507
+ y1 = min(H, y0 + Ht); x1 = min(W, x0 + Wt)
508
+ cropped = a2d[y0:y1, x0:x1]
509
+ ch, cw = cropped.shape
510
+ if ch == Ht and cw == Wt:
511
+ return np.ascontiguousarray(cropped, dtype=np.float32)
512
+ # pad if smaller
513
+ out = np.full((Ht, Wt), float(fill), dtype=np.float32)
514
+ oy = (Ht - ch) // 2; ox = (Wt - cw) // 2
515
+ out[oy:oy+ch, ox:ox+cw] = cropped
516
+ return out
517
+
518
+ # ---- optional: build one mask from the reference frame and reuse ----
519
+ base_ref_mask = None
520
+ if make_masks and star_mask_ref_path:
521
+ try:
522
+ status_cb(f"Star mask: using reference frame for all masks → {os.path.basename(star_mask_ref_path)}")
523
+ # Pull from the shared frame cache as luma on (Ht,Wt)
524
+ ref_chw = _FRAME_LRU.get(star_mask_ref_path, Ht, Wt, "luma") # (1,H,W) or (H,W)
525
+ L = ref_chw[0] if (ref_chw.ndim == 3) else ref_chw # 2D float32
526
+
527
+ vmc = (varmap_cfg or {})
528
+ sky_map, rms_map, err_scalar = _sep_background_precompute(
529
+ L, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
530
+ )
531
+ smc = (star_mask_cfg or {})
532
+ base_ref_mask = _star_mask_from_precomputed(
533
+ L, sky_map, err_scalar,
534
+ thresh_sigma = smc.get("thresh_sigma", THRESHOLD_SIGMA),
535
+ max_objs = smc.get("max_objs", STAR_MASK_MAXOBJS),
536
+ grow_px = smc.get("grow_px", GROW_PX),
537
+ ellipse_scale= smc.get("ellipse_scale", ELLIPSE_SCALE),
538
+ soft_sigma = smc.get("soft_sigma", SOFT_SIGMA),
539
+ max_radius_px= smc.get("max_radius_px", MAX_STAR_RADIUS),
540
+ keep_floor = smc.get("keep_floor", KEEP_FLOOR),
541
+ max_side = smc.get("max_side", STAR_MASK_MAXSIDE),
542
+ status_cb = status_cb,
543
+ )
544
+ except Exception as e:
545
+ status_cb(f"⚠️ Star mask (reference) failed: {e}. Falling back to per-frame masks.")
546
+ base_ref_mask = None
547
+
548
+ # for GUI safety, queue logs from workers and flush in the main thread
549
+ log_queue: SimpleQueue = SimpleQueue()
550
+
551
+ def enqueue_logs(lines):
552
+ for s in lines:
553
+ log_queue.put(s)
554
+
555
+ psfs = [None] * n
556
+ masks = ([None] * n) if make_masks else None
557
+ vars_ = ([None] * n) if make_varmaps else None
558
+ make_masks_in_worker = bool(make_masks and (base_ref_mask is None))
559
+
560
+ # --- thread worker: get frame from cache and compute assets ---
561
+ def _compute_one(i: int, path: str):
562
+ # avoid heavy BLAS oversubscription inside each worker
563
+ with threadpool_limits(limits=1):
564
+ # Pull frame from cache honoring color_mode & target (Ht,Wt)
565
+ img_chw = _FRAME_LRU.get(path, Ht, Wt, color_mode) # (C,H,W) float32
566
+ # For PSF/mask/varmap we operate on a 2D plane (luma/mono)
567
+ arr2d = img_chw[0] if (img_chw.ndim == 3) else img_chw # (H,W) float32
568
+
569
+ # Header: synthesize a safe FITS-like header (works for XISF too)
570
+ try:
571
+ hdr = _safe_primary_header(path)
572
+ except Exception:
573
+ hdr = fits.Header()
574
+
575
+ return _compute_frame_assets(
576
+ i, arr2d, hdr,
577
+ make_masks=bool(make_masks_in_worker),
578
+ make_varmaps=bool(make_varmaps),
579
+ star_mask_cfg=star_mask_cfg,
580
+ varmap_cfg=varmap_cfg,
581
+ )
582
+
583
+ # --- submit jobs ---
584
+ with Executor(max_workers=max_workers) as ex:
585
+ futs = []
586
+ for i, p in enumerate(paths, start=1):
587
+ status_cb(f"MFDeconv: measuring PSF {i}/{n} …")
588
+ if use_proc_pool:
589
+ # Process-safe path: worker re-loads inside the subprocess
590
+ futs.append(ex.submit(
591
+ _compute_one_worker,
592
+ (i, p, bool(make_masks_in_worker), bool(make_varmaps), star_mask_cfg, varmap_cfg)
593
+ ))
594
+ else:
595
+ # Thread path: hits the shared cache (fast path for XISF/FITS)
596
+ futs.append(ex.submit(_compute_one, i, p))
597
+
598
+ done_cnt = 0
599
+ for fut in as_completed(futs):
600
+ i, psf, m, v, logs = fut.result()
601
+ idx = i - 1
602
+ psfs[idx] = psf
603
+ if masks is not None:
604
+ masks[idx] = m
605
+ if vars_ is not None:
606
+ vars_[idx] = v
607
+ enqueue_logs(logs)
608
+
609
+ done_cnt += 1
610
+ if (done_cnt % 4) == 0 or done_cnt == n:
611
+ while not log_queue.empty():
612
+ try:
613
+ status_cb(log_queue.get_nowait())
614
+ except Exception:
615
+ break
616
+
617
+ # If we built a single reference mask, apply it to every frame (center pad/crop)
618
+ if base_ref_mask is not None and masks is not None:
619
+ for idx in range(n):
620
+ masks[idx] = _center_pad_or_crop_2d(base_ref_mask, int(Ht), int(Wt), fill=1.0)
621
+
622
+ # final flush of any remaining logs
623
+ while not log_queue.empty():
624
+ try:
625
+ status_cb(log_queue.get_nowait())
626
+ except Exception:
627
+ break
628
+
629
+ # save PSFs if requested
630
+ if save_dir:
631
+ for i, k in enumerate(psfs, start=1):
632
+ if k is not None:
633
+ fits.PrimaryHDU(k.astype(np.float32, copy=False)).writeto(
634
+ os.path.join(save_dir, f"psf_{i:03d}.fit"), overwrite=True
635
+ )
636
+
637
+ return psfs, masks, vars_
638
+
639
+
640
+ _ALLOWED = re.compile(r"[^A-Za-z0-9_-]+")
641
+
642
+ # known FITS-style multi-extensions (rightmost-first match)
643
+ _KNOWN_EXTS = [
644
+ ".fits.fz", ".fit.fz", ".fits.gz", ".fit.gz",
645
+ ".fz", ".gz",
646
+ ".fits", ".fit"
647
+ ]
648
+
649
+ def _sanitize_token(s: str) -> str:
650
+ s = _ALLOWED.sub("_", s)
651
+ s = re.sub(r"_+", "_", s).strip("_")
652
+ return s
653
+
654
+ def _split_known_exts(p: Path) -> tuple[str, str]:
655
+ """
656
+ Return (name_body, full_ext) where full_ext is a REAL extension block
657
+ (e.g. '.fits.fz'). Any junk like '.0s (1310x880)_MFDeconv' stays in body.
658
+ """
659
+ name = p.name
660
+ for ext in _KNOWN_EXTS:
661
+ if name.lower().endswith(ext):
662
+ body = name[:-len(ext)]
663
+ return body, ext
664
+ # fallback: single suffix
665
+ return p.stem, "".join(p.suffixes)
666
+
667
+ _SIZE_RE = re.compile(r"\(?\s*(\d{2,5})x(\d{2,5})\s*\)?", re.IGNORECASE)
668
+ _EXP_RE = re.compile(r"(?<![A-Za-z0-9])(\d+(?:\.\d+)?)\s*s\b", re.IGNORECASE)
669
+ _RX_RE = re.compile(r"(?<![A-Za-z0-9])(\d+)x\b", re.IGNORECASE)
670
+
671
+ def _extract_size(body: str) -> str | None:
672
+ m = _SIZE_RE.search(body)
673
+ return f"{m.group(1)}x{m.group(2)}" if m else None
674
+
675
+ def _extract_exposure_secs(body: str) -> str | None:
676
+ m = _EXP_RE.search(body)
677
+ if not m:
678
+ return None
679
+ secs = int(round(float(m.group(1))))
680
+ return f"{secs}s"
681
+
682
+ def _strip_metadata_from_base(body: str) -> str:
683
+ s = body
684
+
685
+ # normalize common separators first
686
+ s = s.replace(" - ", "_")
687
+
688
+ # remove known trailing marker '_MFDeconv'
689
+ s = re.sub(r"(?i)[\s_]+MFDeconv$", "", s)
690
+
691
+ # remove parenthetical copy counters e.g. '(1)'
692
+ s = re.sub(r"\(\s*\d+\s*\)$", "", s)
693
+
694
+ # remove size (with or without parens) anywhere
695
+ s = _SIZE_RE.sub("", s)
696
+
697
+ # remove exposures like '0s', '0.5s', ' 45 s' (even if preceded by a dot)
698
+ s = _EXP_RE.sub("", s)
699
+
700
+ # remove any _#x tokens
701
+ s = _RX_RE.sub("", s)
702
+
703
+ # collapse whitespace/underscores and sanitize
704
+ s = re.sub(r"[\s]+", "_", s)
705
+ s = _sanitize_token(s)
706
+ return s or "output"
707
+
708
+ def _canonical_out_name_prefix(base: str, r: int, size: str | None,
709
+ exposure_secs: str | None, tag: str = "MFDeconv") -> str:
710
+ parts = [_sanitize_token(tag), _sanitize_token(base)]
711
+ if size:
712
+ parts.append(_sanitize_token(size))
713
+ if exposure_secs:
714
+ parts.append(_sanitize_token(exposure_secs))
715
+ if int(max(1, r)) > 1:
716
+ parts.append(f"{int(r)}x")
717
+ return "_".join(parts)
718
+
719
+ def _sr_out_path(out_path: str, r: int) -> Path:
720
+ """
721
+ Build: MFDeconv_<base>[_<HxW>][_<secs>s][_2x], preserving REAL extensions.
722
+ """
723
+ p = Path(out_path)
724
+ body, real_ext = _split_known_exts(p)
725
+
726
+ # harvest metadata from the whole body (not Path.stem)
727
+ size = _extract_size(body)
728
+ ex_sec = _extract_exposure_secs(body)
729
+
730
+ # clean base
731
+ base = _strip_metadata_from_base(body)
732
+
733
+ new_stem = _canonical_out_name_prefix(base, r=int(max(1, r)), size=size, exposure_secs=ex_sec, tag="MFDeconv")
734
+ return p.with_name(f"{new_stem}{real_ext}")
735
+
736
+ def _nonclobber_path(path: str) -> str:
737
+ """
738
+ Version collisions as '_v2', '_v3', ... (no spaces/parentheses).
739
+ """
740
+ p = Path(path)
741
+ if not p.exists():
742
+ return str(p)
743
+
744
+ # keep the true extension(s)
745
+ body, real_ext = _split_known_exts(p)
746
+
747
+ # if already has _vN, bump it
748
+ m = re.search(r"(.*)_v(\d+)$", body)
749
+ if m:
750
+ base = m.group(1); n = int(m.group(2)) + 1
751
+ else:
752
+ base = body; n = 2
753
+
754
+ while True:
755
+ candidate = p.with_name(f"{base}_v{n}{real_ext}")
756
+ if not candidate.exists():
757
+ return str(candidate)
758
+ n += 1
759
+
760
+ def _iter_folder(basefile: str) -> str:
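+ """Create (and return) a '<name>.iters' folder next to basefile for saving per-iteration images."""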
761
+ d, fname = os.path.split(basefile)
762
+ root, ext = os.path.splitext(fname)
763
+ tgt = os.path.join(d, f"{root}.iters")
764
+ if not os.path.exists(tgt):
765
+ try:
766
+ os.makedirs(tgt, exist_ok=True)
767
+ except Exception:
768
+ # last resort: suffix (n)
769
+ n = 1
770
+ while True:
771
+ cand = os.path.join(d, f"{root}.iters ({n})")
772
+ try:
773
+ os.makedirs(cand, exist_ok=True)
774
+ return cand
775
+ except Exception:
776
+ n += 1
777
+ return tgt
778
+
779
+ def _save_iter_image(arr, hdr_base, folder, tag, color_mode):
780
+ """
781
+ arr: numpy array (H,W) or (C,H,W) float32
782
+ tag: 'seed' or 'iter_###'
783
+ """
784
+ if arr.ndim == 3 and arr.shape[0] not in (1, 3) and arr.shape[-1] in (1, 3):
785
+ arr = np.moveaxis(arr, -1, 0)
786
+ if arr.ndim == 3 and arr.shape[0] == 1:
787
+ arr = arr[0]
788
+
789
+ hdr = fits.Header(hdr_base) if isinstance(hdr_base, fits.Header) else fits.Header()
790
+ hdr['MF_PART'] = (str(tag), 'MFDeconv intermediate (seed/iter)')
791
+ hdr['MF_COLOR'] = (str(color_mode), 'Color mode used')
792
+ path = os.path.join(folder, f"{tag}.fit")
793
+ # overwrite allowed inside the dedicated folder
794
+ fits.PrimaryHDU(data=arr.astype(np.float32, copy=False), header=hdr).writeto(path, overwrite=True)
795
+ return path
796
+
797
+
798
+ def _process_gui_events_safely():
799
+ app = QApplication.instance()
800
+ if app and QThread.currentThread() is app.thread():
801
+ app.processEvents()
802
+
803
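+ # Small positive constant to guard divisions and normalizations.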
+ EPS = 1e-6
804
+
805
+ # -----------------------------
806
+ # Helpers: image prep / shapes
807
+ # -----------------------------
808
+
809
+ # new: lightweight loader that yields one frame at a time
810
+ def _iter_fits(paths):
811
+ for p in paths:
812
+ with fits.open(p, memmap=False) as hdul: # ⬅ False
813
+ arr = np.array(hdul[0].data, dtype=np.float32, copy=True) # ⬅ copy
814
+ if arr.ndim == 3 and arr.shape[-1] == 1:
815
+ arr = np.squeeze(arr, axis=-1)
816
+ hdr = hdul[0].header.copy()
817
+ yield arr, hdr
818
+
819
+ def _to_luma_local(a: np.ndarray) -> np.ndarray:
820
+ a = np.asarray(a, dtype=np.float32)
821
+ if a.ndim == 2:
822
+ return a
823
+ # (H,W,3) or (3,H,W)
824
+ if a.ndim == 3 and a.shape[-1] == 3:
825
+ try:
826
+ import cv2
827
+ return cv2.cvtColor(a, cv2.COLOR_RGB2GRAY).astype(np.float32, copy=False)
828
+ except Exception:
829
+ pass
830
+ r, g, b = a[..., 0], a[..., 1], a[..., 2]
831
+ return (0.2126*r + 0.7152*g + 0.0722*b).astype(np.float32, copy=False)
832
+ if a.ndim == 3 and a.shape[0] == 3:
833
+ r, g, b = a[0], a[1], a[2]
834
+ return (0.2126*r + 0.7152*g + 0.0722*b).astype(np.float32, copy=False)
835
+ return a.mean(axis=-1).astype(np.float32, copy=False)
836
+
837
+ def _stack_loader(paths):
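+ """Eagerly load every FITS frame (memmap disabled, data copied); returns (list of float32 arrays, list of headers)."""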
838
+ ys, hdrs = [], []
839
+ for p in paths:
840
+ with fits.open(p, memmap=False) as hdul: # ⬅ False
841
+ arr = np.array(hdul[0].data, dtype=np.float32, copy=True) # ⬅ copy inside with
842
+ hdr = hdul[0].header.copy()
843
+ if arr.ndim == 3 and arr.shape[-1] == 1:
844
+ arr = np.squeeze(arr, axis=-1)
845
+ ys.append(arr)
846
+ hdrs.append(hdr)
847
+ return ys, hdrs
848
+
849
+ def _normalize_layout_single(a, color_mode):
850
+ """
851
+ Coerce to:
852
+ - 'luma' -> (H, W)
853
+ - 'perchannel' -> (C, H, W); mono stays (1,H,W), RGB → (3,H,W)
854
+ Accepts (H,W), (H,W,3), or (3,H,W).
855
+ """
856
+ a = np.asarray(a, dtype=np.float32)
857
+
858
+ if color_mode == "luma":
859
+ return _to_luma_local(a) # returns (H,W)
860
+
861
+ # perchannel
862
+ if a.ndim == 2:
863
+ return a[None, ...] # (1,H,W) ← keep mono as 1 channel
864
+ if a.ndim == 3 and a.shape[-1] == 3:
865
+ return np.moveaxis(a, -1, 0) # (3,H,W)
866
+ if a.ndim == 3 and a.shape[0] in (1, 3):
867
+ return a # already (1,H,W) or (3,H,W)
868
+ # fallback: average any weird shape into luma 1×H×W
869
+ l = _to_luma_local(a)
870
+ return l[None, ...]
871
+
872
+
873
+ def _normalize_layout_batch(arrs, color_mode):
874
+ return [_normalize_layout_single(a, color_mode) for a in arrs]
875
+
876
+ def _common_hw(data_list):
877
+ """Return minimal (H,W) across items; items are (H,W) or (C,H,W)."""
878
+ Hs, Ws = [], []
879
+ for a in data_list:
880
+ if a.ndim == 2:
881
+ H, W = a.shape
882
+ else:
883
+ _, H, W = a.shape
884
+ Hs.append(H); Ws.append(W)
885
+ return int(min(Hs)), int(min(Ws))
886
+
887
+ def _center_crop(arr, Ht, Wt):
888
+ """Center-crop arr (H,W) or (C,H,W) to (Ht,Wt)."""
889
+ if arr.ndim == 2:
890
+ H, W = arr.shape
891
+ if H == Ht and W == Wt:
892
+ return arr
893
+ y0 = max(0, (H - Ht) // 2)
894
+ x0 = max(0, (W - Wt) // 2)
895
+ return arr[y0:y0+Ht, x0:x0+Wt]
896
+ else:
897
+ C, H, W = arr.shape
898
+ if H == Ht and W == Wt:
899
+ return arr
900
+ y0 = max(0, (H - Ht) // 2)
901
+ x0 = max(0, (W - Wt) // 2)
902
+ return arr[:, y0:y0+Ht, x0:x0+Wt]
903
+
904
+ def _sanitize_numeric(a):
905
+ """Replace NaN/Inf, clip negatives, make contiguous float32."""
906
+ a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
907
+ a = np.clip(a, 0.0, None).astype(np.float32, copy=False)
908
+ return np.ascontiguousarray(a)
909
+
910
+ # -----------------------------
911
+ # PSF utilities
912
+ # -----------------------------
913
+
914
+ def _gaussian_psf(fwhm_px: float, ksize: int) -> np.ndarray:
915
+ sigma = max(fwhm_px, 1.0) / 2.3548
916
+ r = (ksize - 1) / 2
917
+ y, x = np.mgrid[-r:r+1, -r:r+1]
918
+ g = np.exp(-(x*x + y*y) / (2*sigma*sigma))
919
+ g /= (np.sum(g) + EPS)
920
+ return g.astype(np.float32, copy=False)
921
+
922
+ def _estimate_fwhm_from_header(hdr) -> float:
923
+ for key in ("FWHM", "FWHM_PIX", "PSF_FWHM"):
924
+ if key in hdr:
925
+ try:
926
+ val = float(hdr[key])
927
+ if np.isfinite(val) and val > 0:
928
+ return val
929
+ except Exception:
930
+ pass
931
+ return float("nan")
932
+
933
+ def _estimate_fwhm_from_image(arr) -> float:
934
+ """Fast FWHM estimate from SEP 'a','b' parameters (≈ sigma in px)."""
935
+ if sep is None:
936
+ return float("nan")
937
+ try:
938
+ img = _contig(_to_luma_local(arr)) # ← ensure C-contig float32
939
+ bkg = sep.Background(img)
940
+ data = _contig(img - bkg.back()) # ← ensure data is C-contig
941
+ try:
942
+ err = bkg.globalrms
943
+ except Exception:
944
+ err = float(np.median(bkg.rms()))
945
+ sources = sep.extract(data, 6.0, err=err)
946
+ if sources is None or len(sources) == 0:
947
+ return float("nan")
948
+ a = np.asarray(sources["a"], dtype=np.float32)
949
+ b = np.asarray(sources["b"], dtype=np.float32)
950
+ ab = (a + b) * 0.5
951
+ sigma = float(np.median(ab[np.isfinite(ab) & (ab > 0)]))
952
+ if not np.isfinite(sigma) or sigma <= 0:
953
+ return float("nan")
954
+ return 2.3548 * sigma
955
+ except Exception:
956
+ return float("nan")
957
+
958
+ def _auto_ksize_from_fwhm(fwhm_px: float, kmin: int = 11, kmax: int = 51) -> int:
959
+ """
960
+ Choose odd kernel size to cover about ±4σ.
961
+ """
962
+ sigma = max(fwhm_px, 1.0) / 2.3548
963
+ r = int(math.ceil(4.0 * sigma))
964
+ k = 2 * r + 1
965
+ k = max(kmin, min(k, kmax))
966
+ if (k % 2) == 0:
967
+ k += 1
968
+ return k
969
+
970
+ def _flip_kernel(psf):
971
+ # PyTorch dislikes negative strides; make it contiguous.
972
+ return np.flip(np.flip(psf, -1), -2).copy()
973
+
974
+ def _conv_same_np(img, psf):
975
+ """
976
+ NumPy FFT-based SAME convolution for (H,W) or (C,H,W).
977
+ IMPORTANT: ifftshift the PSF so its peak is at [0,0] before FFT.
978
+ """
979
+ import numpy as _np
980
+ import numpy.fft as _fft
981
+
982
+ kh, kw = psf.shape
983
+
984
+ def fftconv2(a, k):
985
+ # a is (1,H,W); k is (kh,kw)
986
+ H, W = a.shape[-2:]
987
+ fftH, fftW = _fftshape_same(H, W, kh, kw)
988
+ A = _fft.rfftn(a, s=(fftH, fftW), axes=(-2, -1))
989
+ K = _fft.rfftn(_np.fft.ifftshift(k), s=(fftH, fftW), axes=(-2, -1))
990
+ y = _fft.irfftn(A * K, s=(fftH, fftW), axes=(-2, -1))
991
+ sh, sw = (kh - 1)//2, (kw - 1)//2
992
+ return y[..., sh:sh+H, sw:sw+W]
993
+
994
+ if img.ndim == 2:
995
+ return fftconv2(img[None], psf)[0]
996
+ else:
997
+ # per-channel
998
+ return _np.stack([fftconv2(img[c:c+1], psf)[0] for c in range(img.shape[0])], axis=0)
999
+
1000
+ def _normalize_psf(psf):
1001
+ psf = np.maximum(psf, 0.0).astype(np.float32, copy=False)
1002
+ s = float(psf.sum())
1003
+ if not np.isfinite(s) or s <= 1e-6:
1004
+ return psf
1005
+ return (psf / s).astype(np.float32, copy=False)
1006
+
1007
+ def _soften_psf(psf, sigma_px=0.25):
1008
+ if sigma_px <= 0:
1009
+ return psf
1010
+ r = int(max(1, round(3 * sigma_px)))
1011
+ y, x = np.mgrid[-r:r+1, -r:r+1]
1012
+ g = np.exp(-(x*x + y*y) / (2 * sigma_px * sigma_px)).astype(np.float32)
1013
+ g /= g.sum() + 1e-6
1014
+ return _conv_same_np(psf[None], g)[0]
1015
+
1016
+ def _psf_fwhm_px(psf: np.ndarray) -> float:
1017
+ """Approximate FWHM (pixels) from second moments of a normalized kernel."""
1018
+ psf = np.maximum(psf, 0).astype(np.float32, copy=False)
1019
+ s = float(psf.sum())
1020
+ if s <= EPS:
1021
+ return float("nan")
1022
+ k = psf.shape[0]
1023
+ y, x = np.mgrid[:k, :k].astype(np.float32)
1024
+ cy = float((psf * y).sum() / s)
1025
+ cx = float((psf * x).sum() / s)
1026
+ var_y = float((psf * (y - cy) ** 2).sum() / s)
1027
+ var_x = float((psf * (x - cx) ** 2).sum() / s)
1028
+ sigma = math.sqrt(max(0.0, 0.5 * (var_x + var_y)))
1029
+ return 2.3548 * sigma # FWHM≈2.355σ
1030
+
1031
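+ # Defaults for star-mask / variance-map construction; individual values can be overridden via star_mask_cfg / varmap_cfg.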
+ STAR_MASK_MAXSIDE = 2048
1032
+ STAR_MASK_MAXOBJS = 2000 # cap number of objects
1033
+ VARMAP_SAMPLE_STRIDE = 8 # (kept for compat; currently unused internally)
1034
+ THRESHOLD_SIGMA = 2.0
1035
+ KEEP_FLOOR = 0.20
1036
+ GROW_PX = 8
1037
+ MAX_STAR_RADIUS = 16
1038
+ SOFT_SIGMA = 2.0
1039
+ ELLIPSE_SCALE = 1.2
1040
+
1041
+ def _sep_background_precompute(img_2d: np.ndarray, bw: int = 64, bh: int = 64):
1042
+ """One-time SEP background build; returns (sky_map, rms_map, err_scalar)."""
1043
+ if sep is None:
1044
+ # robust fallback
1045
+ med = float(np.median(img_2d))
1046
+ mad = float(np.median(np.abs(img_2d - med))) + 1e-6
1047
+ sky = np.full_like(img_2d, med, dtype=np.float32)
1048
+ rmsm = np.full_like(img_2d, 1.4826 * mad, dtype=np.float32)
1049
+ return sky, rmsm, float(np.median(rmsm))
1050
+
1051
+ a = np.ascontiguousarray(img_2d.astype(np.float32))
1052
+ b = sep.Background(a, bw=int(bw), bh=int(bh), fw=3, fh=3)
1053
+ sky = np.asarray(b.back(), dtype=np.float32)
1054
+ try:
1055
+ rmsm = np.asarray(b.rms(), dtype=np.float32)
1056
+ err = float(b.globalrms)
1057
+ except Exception:
1058
+ rmsm = np.full_like(a, float(np.median(b.rms())), dtype=np.float32)
1059
+ err = float(np.median(rmsm))
1060
+ return sky, rmsm, err
1061
+
1062
+
1063
+ def _auto_star_mask_sep(
1064
+ img_2d: np.ndarray,
1065
+ thresh_sigma: float = THRESHOLD_SIGMA,
1066
+ grow_px: int = GROW_PX,
1067
+ max_objs: int = STAR_MASK_MAXOBJS,
1068
+ max_side: int = STAR_MASK_MAXSIDE,
1069
+ ellipse_scale: float = ELLIPSE_SCALE,
1070
+ soft_sigma: float = SOFT_SIGMA,
1071
+ max_semiaxis_px: float | None = None, # kept for API compat; unused
1072
+ max_area_px2: float | None = None, # kept for API compat; unused
1073
+ max_radius_px: int = MAX_STAR_RADIUS,
1074
+ keep_floor: float = KEEP_FLOOR,
1075
+ status_cb=lambda s: None
1076
+ ) -> np.ndarray:
1077
+ """
1078
+ Build a KEEP weight map (float32 in [0,1]) using SEP detections only.
1079
+ **Never writes to img_2d** and draws only into a fresh mask buffer.
1080
+ """
1081
+ if sep is None:
1082
+ return np.ones_like(img_2d, dtype=np.float32, order="C")
1083
+
1084
+ # Optional OpenCV path for fast draw/blur
1085
+ try:
1086
+ import cv2 as _cv2
1087
+ _HAS_CV2 = True
1088
+ except Exception:
1089
+ _HAS_CV2 = False
1090
+ _cv2 = None # type: ignore
1091
+
1092
+ h, w = map(int, img_2d.shape)
1093
+
1094
+ # Background / residual on our own contiguous buffer
1095
+ data = np.ascontiguousarray(img_2d.astype(np.float32))
1096
+ bkg = sep.Background(data)
1097
+ data_sub = np.ascontiguousarray(data - bkg.back(), dtype=np.float32)
1098
+ try:
1099
+ err_scalar = float(bkg.globalrms)
1100
+ except Exception:
1101
+ err_scalar = float(np.median(np.asarray(bkg.rms(), dtype=np.float32)))
1102
+
1103
+ # Downscale for detection only
1104
+ det = data_sub
1105
+ scale = 1.0
1106
+ if max_side and max(h, w) > int(max_side):
1107
+ scale = float(max(h, w)) / float(max_side)
1108
+ if _HAS_CV2:
1109
+ det = _cv2.resize(
1110
+ det,
1111
+ (max(1, int(round(w / scale))), max(1, int(round(h / scale)))),
1112
+ interpolation=_cv2.INTER_AREA
1113
+ )
1114
+ else:
1115
+ s = int(max(1, round(scale)))
1116
+ det = det[:(h // s) * s, :(w // s) * s].reshape(h // s, s, w // s, s).mean(axis=(1, 3))
1117
+ scale = float(s)
1118
+
1119
+ # When averaging down by 'scale', per-pixel noise scales ~ 1/scale
1120
+ err_det = float(err_scalar) / float(max(1.0, scale))
1121
+
1122
+ thresholds = [thresh_sigma, thresh_sigma*2, thresh_sigma*4,
1123
+ thresh_sigma*8, thresh_sigma*16]
1124
+ objs = None; used = float("nan"); raw = 0
1125
+ for t in thresholds:
1126
+ try:
1127
+ cand = sep.extract(det, thresh=float(t), err=err_det)
1128
+ except Exception:
1129
+ cand = None
1130
+ n = 0 if cand is None else len(cand)
1131
+ if n == 0: continue
1132
+ if n > max_objs*12: continue
1133
+ objs, raw, used = cand, n, float(t)
1134
+ break
1135
+
1136
+ if objs is None or len(objs) == 0:
1137
+ try:
1138
+ cand = sep.extract(det, thresh=float(thresholds[-1]), err=err_det, minarea=9)
1139
+ except Exception:
1140
+ cand = None
1141
+ if cand is None or len(cand) == 0:
1142
+ status_cb("Star mask: no sources found (mask disabled for this frame).")
1143
+ return np.ones((h, w), dtype=np.float32, order="C")
1144
+ objs, raw, used = cand, len(cand), float(thresholds[-1])
1145
+
1146
+ # Keep brightest max_objs
1147
+ if "flux" in objs.dtype.names:
1148
+ idx = np.argsort(objs["flux"])[-int(max_objs):]
1149
+ objs = objs[idx]
1150
+ else:
1151
+ objs = objs[:int(max_objs)]
1152
+ kept_after_cap = int(len(objs))
1153
+
1154
+ # Draw on full-res fresh buffer
1155
+ mask_u8 = np.zeros((h, w), dtype=np.uint8, order="C")
1156
+ s_back = float(scale)
1157
+ MR = int(max(1, max_radius_px))
1158
+ G = int(max(0, grow_px))
1159
+ ES = float(max(0.1, ellipse_scale))
1160
+
1161
+ drawn = 0
1162
+ if _HAS_CV2:
1163
+ for o in objs:
1164
+ x = int(round(float(o["x"]) * s_back))
1165
+ y = int(round(float(o["y"]) * s_back))
1166
+ if not (0 <= x < w and 0 <= y < h):
1167
+ continue
1168
+ a = float(o["a"]) * s_back
1169
+ b = float(o["b"]) * s_back
1170
+ r = int(math.ceil(ES * max(a, b)))
1171
+ r = min(max(r, 0) + G, MR)
1172
+ if r <= 0:
1173
+ continue
1174
+ _cv2.circle(mask_u8, (x, y), r, 1, thickness=-1, lineType=_cv2.LINE_8)
1175
+ drawn += 1
1176
+ else:
1177
+ yy, xx = np.ogrid[:h, :w]
1178
+ for o in objs:
1179
+ x = int(round(float(o["x"]) * s_back))
1180
+ y = int(round(float(o["y"]) * s_back))
1181
+ if not (0 <= x < w and 0 <= y < h):
1182
+ continue
1183
+ a = float(o["a"]) * s_back
1184
+ b = float(o["b"]) * s_back
1185
+ r = int(math.ceil(ES * max(a, b)))
1186
+ r = min(max(r, 0) + G, MR)
1187
+ if r <= 0:
1188
+ continue
1189
+ y0 = max(0, y - r); y1 = min(h, y + r + 1)
1190
+ x0 = max(0, x - r); x1 = min(w, x + r + 1)
1191
+ ys = yy[y0:y1] - y
1192
+ xs = xx[x0:x1] - x
1193
+ disk = (ys * ys + xs * xs) <= (r * r)
1194
+ mask_u8[y0:y1, x0:x1][disk] = 1
1195
+ drawn += 1
1196
+
1197
+ masked_px_hard = int(mask_u8.sum())
1198
+
1199
+ # Feather and convert to KEEP weights in [0,1]
1200
+ m = mask_u8.astype(np.float32, copy=False)
1201
+ if soft_sigma and soft_sigma > 0.0:
1202
+ try:
1203
+ if _HAS_CV2:
1204
+ k = int(max(1, math.ceil(soft_sigma * 3)) * 2 + 1)
1205
+ m = _cv2.GaussianBlur(m, (k, k), float(soft_sigma),
1206
+ borderType=_cv2.BORDER_REFLECT)
1207
+ else:
1208
+ from scipy.ndimage import gaussian_filter
1209
+ m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
1210
+ except Exception:
1211
+ pass
1212
+ np.clip(m, 0.0, 1.0, out=m)
1213
+
1214
+ keep = 1.0 - m
1215
+ kf = float(max(0.0, min(0.99, keep_floor)))
1216
+ keep = kf + (1.0 - kf) * keep
1217
+ np.clip(keep, 0.0, 1.0, out=keep)
1218
+
1219
+ status_cb(
1220
+ f"Star mask: thresh={used:.3g} | detected={raw} | kept={kept_after_cap} | "
1221
+ f"drawn={drawn} | masked_px={masked_px_hard} | grow_px={G} | soft_sigma={soft_sigma} | keep_floor={keep_floor}"
1222
+ )
1223
+ return np.ascontiguousarray(keep, dtype=np.float32)
1224
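+
+ # Usage sketch (illustrative; `_example_star_keep_map` is hypothetical and not
+ # called anywhere). Builds a tiny synthetic frame with one bright star and shows
+ # that the KEEP map stays near 1.0 in the background and drops toward KEEP_FLOOR
+ # at the star. If `sep` is missing the function returns all-ones, so this
+ # degrades gracefully.
+ def _example_star_keep_map():
+     rng = np.random.default_rng(0)
+     frame = rng.normal(100.0, 2.0, size=(256, 256)).astype(np.float32)
+     yy, xx = np.mgrid[:256, :256].astype(np.float32)
+     frame += 500.0 * np.exp(-((yy - 128) ** 2 + (xx - 128) ** 2) / (2 * 2.5 ** 2))
+     keep = _auto_star_mask_sep(frame, status_cb=lambda s: None)
+     return float(keep[10, 10]), float(keep[128, 128])  # ~1.0 background, lower on the star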
+
1225
+
1226
+
1227
+ def _auto_variance_map(
1228
+ img_2d: np.ndarray,
1229
+ hdr,
1230
+ status_cb=lambda s: None,
1231
+ sample_stride: int = VARMAP_SAMPLE_STRIDE, # kept for signature compat; not used
1232
+ bw: int = 64, # SEP background box width (pixels)
1233
+ bh: int = 64, # SEP background box height (pixels)
1234
+ smooth_sigma: float = 1.0, # Gaussian sigma (px) to smooth the variance map
1235
+ floor: float = 1e-8, # hard floor to prevent blow-up in 1/var
1236
+ ) -> np.ndarray:
1237
+ """
1238
+ Build a per-pixel variance map in DN^2:
1239
+
1240
+ var_DN ≈ (object_only_DN)/gain + var_bg_DN^2
1241
+
1242
+ where:
1243
+ - object_only_DN = max(img - sky_DN, 0)
1244
+ - var_bg_DN^2 comes from SEP's local background rms (Poisson(sky)+readnoise)
1245
+ - if GAIN is missing, estimate 1/gain ≈ median(var_bg)/median(sky)
1246
+
1247
+ Returns float32 array, clipped below by `floor`, optionally smoothed with a
1248
+ small Gaussian to stabilize weights. Emits a summary status line.
1249
+ """
1250
+ img = np.clip(np.asarray(img_2d, dtype=np.float32), 0.0, None)
1251
+
1252
+ # --- Parse header for camera params (optional) ---
1253
+ gain = None
1254
+ for k in ("EGAIN", "GAIN", "GAIN1", "GAIN2"):
1255
+ if k in hdr:
1256
+ try:
1257
+ g = float(hdr[k])
1258
+ if np.isfinite(g) and g > 0:
1259
+ gain = g
1260
+ break
1261
+ except Exception:
1262
+ pass
1263
+
1264
+ readnoise = None
1265
+ for k in ("RDNOISE", "READNOISE", "RN"):
1266
+ if k in hdr:
1267
+ try:
1268
+ rn = float(hdr[k])
1269
+ if np.isfinite(rn) and rn >= 0:
1270
+ readnoise = rn
1271
+ break
1272
+ except Exception:
1273
+ pass
1274
+
1275
+ # --- Local background (full-res) ---
1276
+ if sep is not None:
1277
+ try:
1278
+ b = sep.Background(img, bw=int(bw), bh=int(bh), fw=3, fh=3)
1279
+ sky_dn_map = np.asarray(b.back(), dtype=np.float32)
1280
+ try:
1281
+ rms_dn_map = np.asarray(b.rms(), dtype=np.float32)
1282
+ except Exception:
1283
+ rms_dn_map = np.full_like(img, float(np.median(b.rms())), dtype=np.float32)
1284
+ except Exception:
1285
+ sky_dn_map = np.full_like(img, float(np.median(img)), dtype=np.float32)
1286
+ med = float(np.median(img))
1287
+ mad = float(np.median(np.abs(img - med))) + 1e-6
1288
+ rms_dn_map = np.full_like(img, float(1.4826 * mad), dtype=np.float32)
1289
+ else:
1290
+ sky_dn_map = np.full_like(img, float(np.median(img)), dtype=np.float32)
1291
+ med = float(np.median(img))
1292
+ mad = float(np.median(np.abs(img - med))) + 1e-6
1293
+ rms_dn_map = np.full_like(img, float(1.4826 * mad), dtype=np.float32)
1294
+
1295
+ # Background variance in DN^2
1296
+ var_bg_dn2 = np.maximum(rms_dn_map, 1e-6) ** 2
1297
+
1298
+ # Object-only DN
1299
+ obj_dn = np.clip(img - sky_dn_map, 0.0, None)
1300
+
1301
+ # Shot-noise coefficient
1302
+ if gain is not None and np.isfinite(gain) and gain > 0:
1303
+ a_shot = 1.0 / gain
1304
+ else:
1305
+ sky_med = float(np.median(sky_dn_map))
1306
+ varbg_med = float(np.median(var_bg_dn2))
1307
+ if sky_med > 1e-6:
1308
+ a_shot = np.clip(varbg_med / sky_med, 0.0, 10.0) # ~ 1/gain estimate
1309
+ else:
1310
+ a_shot = 0.0
1311
+
1312
+ # Total variance: background + shot noise from object-only flux
1313
+ v = var_bg_dn2 + a_shot * obj_dn
1314
+ v_raw = v.copy()
1315
+
1316
+ # Optional mild smoothing
1317
+ if smooth_sigma and smooth_sigma > 0:
1318
+ try:
1319
+ import cv2 as _cv2
1320
+ k = int(max(1, int(round(3 * float(smooth_sigma)))) * 2 + 1)
1321
+ v = _cv2.GaussianBlur(v, (k, k), float(smooth_sigma), borderType=_cv2.BORDER_REFLECT)
1322
+ except Exception:
1323
+ try:
1324
+ from scipy.ndimage import gaussian_filter
1325
+ v = gaussian_filter(v, sigma=float(smooth_sigma), mode="reflect")
1326
+ except Exception:
1327
+ r = int(max(1, round(3 * float(smooth_sigma))))
1328
+ yy, xx = np.mgrid[-r:r+1, -r:r+1].astype(np.float32)
1329
+ gk = np.exp(-(xx*xx + yy*yy) / (2.0 * float(smooth_sigma) * float(smooth_sigma))).astype(np.float32)
1330
+ gk /= (gk.sum() + EPS)
1331
+ v = _conv_same_np(v, gk)
1332
+
1333
+ # Clip to avoid zero/negative variances
1334
+ v = np.clip(v, float(floor), None).astype(np.float32, copy=False)
1335
+
1336
+ # Emit telemetry
1337
+ try:
1338
+ sky_med = float(np.median(sky_dn_map))
1339
+ rms_med = float(np.median(np.sqrt(var_bg_dn2)))
1340
+ floor_pct = float((v <= floor).mean() * 100.0)
1341
+ status_cb(
1342
+ "Variance map: "
1343
+ f"sky_med={sky_med:.3g} DN | rms_med={rms_med:.3g} DN | "
1344
+ f"gain={(gain if gain is not None else 'NA')} | rn={(readnoise if readnoise is not None else 'NA')} | "
1345
+ f"smooth_sigma={smooth_sigma} | floor={floor} ({floor_pct:.2f}% at floor)"
1346
+ )
1347
+ except Exception:
1348
+ pass
1349
+
1350
+ return v
1351
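+
+ # Worked example of the noise model above (numbers are made up for illustration).
+ # With gain = 1.25 e-/DN, local sky rms = 6 DN, and a pixel 400 DN above sky:
+ #   var_bg  = 6^2            = 36  DN^2
+ #   a_shot  = 1/gain         = 0.8 DN^2 per object-DN
+ #   var_tot = 36 + 0.8 * 400 = 356 DN^2   (sigma ~ 18.9 DN)
+ # When GAIN is absent, a_shot falls back to median(var_bg)/median(sky), i.e. the
+ # Poisson relation var ~ sky/gain rearranged for 1/gain.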
+
1352
+
1353
+ def _star_mask_from_precomputed(
1354
+ img_2d: np.ndarray,
1355
+ sky_map: np.ndarray,
1356
+ err_scalar: float,
1357
+ *,
1358
+ thresh_sigma: float,
1359
+ max_objs: int,
1360
+ grow_px: int,
1361
+ ellipse_scale: float,
1362
+ soft_sigma: float,
1363
+ max_radius_px: int,
1364
+ keep_floor: float,
1365
+ max_side: int,
1366
+ status_cb=lambda s: None
1367
+ ) -> np.ndarray:
1368
+ """
1369
+ Build a KEEP weight map using a *downscaled detection / full-res draw* path.
1370
+ **Never writes to img_2d**; all drawing happens in a fresh `mask_u8`.
1371
+ """
1372
+ # Optional OpenCV fast path
1373
+ try:
1374
+ import cv2 as _cv2
1375
+ _HAS_CV2 = True
1376
+ except Exception:
1377
+ _HAS_CV2 = False
1378
+ _cv2 = None # type: ignore
1379
+
1380
+ H, W = map(int, img_2d.shape)
1381
+
1382
+ # Residual for detection (contiguous, separate buffer)
1383
+ data_sub = np.ascontiguousarray((img_2d - sky_map).astype(np.float32))
1384
+
1385
+ # Downscale *detection only* to speed up, never the draw step
1386
+ det = data_sub
1387
+ scale = 1.0
1388
+ if max_side and max(H, W) > int(max_side):
1389
+ scale = float(max(H, W)) / float(max_side)
1390
+ if _HAS_CV2:
1391
+ det = _cv2.resize(
1392
+ det,
1393
+ (max(1, int(round(W / scale))), max(1, int(round(H / scale)))),
1394
+ interpolation=_cv2.INTER_AREA
1395
+ )
1396
+ else:
1397
+ s = int(max(1, round(scale)))
1398
+ det = det[:(H // s) * s, :(W // s) * s].reshape(H // s, s, W // s, s).mean(axis=(1, 3))
1399
+ scale = float(s)
1400
+
1401
+ # Threshold ladder
1402
+ thresholds = [thresh_sigma, thresh_sigma*2, thresh_sigma*4,
1403
+ thresh_sigma*8, thresh_sigma*16]
1404
+ objs = None; used = float("nan"); raw = 0
1405
+ for t in thresholds:
1406
+ cand = sep.extract(det, thresh=float(t), err=float(err_scalar))
1407
+ n = 0 if cand is None else len(cand)
1408
+ if n == 0: continue
1409
+ if n > max_objs*12: continue
1410
+ objs, raw, used = cand, n, float(t)
1411
+ break
1412
+
1413
+ if objs is None or len(objs) == 0:
1414
+ try:
1415
+ cand = sep.extract(det, thresh=thresholds[-1], err=float(err_scalar), minarea=9)
1416
+ except Exception:
1417
+ cand = None
1418
+ if cand is None or len(cand) == 0:
1419
+ status_cb("Star mask: no sources found (mask disabled for this frame).")
1420
+ return np.ones((H, W), dtype=np.float32, order="C")
1421
+ objs, raw, used = cand, len(cand), float(thresholds[-1])
1422
+
1423
+ # Brightest max_objs
1424
+ if "flux" in objs.dtype.names:
1425
+ idx = np.argsort(objs["flux"])[-int(max_objs):]
1426
+ objs = objs[idx]
1427
+ else:
1428
+ objs = objs[:int(max_objs)]
1429
+ kept = len(objs)
1430
+
1431
+ # ---- draw back on full-res into a brand-new buffer ----
1432
+ mask_u8 = np.zeros((H, W), dtype=np.uint8, order="C")
1433
+ s_back = float(scale)
1434
+ MR = int(max(1, max_radius_px))
1435
+ G = int(max(0, grow_px))
1436
+ ES = float(max(0.1, ellipse_scale))
1437
+
1438
+ drawn = 0
1439
+ if _HAS_CV2:
1440
+ for o in objs:
1441
+ x = int(round(float(o["x"]) * s_back))
1442
+ y = int(round(float(o["y"]) * s_back))
1443
+ if not (0 <= x < W and 0 <= y < H):
1444
+ continue
1445
+ a = float(o["a"]) * s_back
1446
+ b = float(o["b"]) * s_back
1447
+ r = int(math.ceil(ES * max(a, b)))
1448
+ r = min(max(r, 0) + G, MR)
1449
+ if r <= 0:
1450
+ continue
1451
+ _cv2.circle(mask_u8, (x, y), r, 1, thickness=-1, lineType=_cv2.LINE_8)
1452
+ drawn += 1
1453
+ else:
1454
+ for o in objs:
1455
+ x = int(round(float(o["x"]) * s_back))
1456
+ y = int(round(float(o["y"]) * s_back))
1457
+ if not (0 <= x < W and 0 <= y < H):
1458
+ continue
1459
+ a = float(o["a"]) * s_back
1460
+ b = float(o["b"]) * s_back
1461
+ r = int(math.ceil(ES * max(a, b)))
1462
+ r = min(max(r, 0) + G, MR)
1463
+ if r <= 0:
1464
+ continue
1465
+ y0 = max(0, y - r); y1 = min(H, y + r + 1)
1466
+ x0 = max(0, x - r); x1 = min(W, x + r + 1)
1467
+ yy, xx = np.ogrid[y0:y1, x0:x1]
1468
+ disk = (yy - y)*(yy - y) + (xx - x)*(xx - x) <= r*r
1469
+ mask_u8[y0:y1, x0:x1][disk] = 1
1470
+ drawn += 1
1471
+
1472
+ # Feather + convert to keep weights
1473
+ m = mask_u8.astype(np.float32, copy=False)
1474
+ if soft_sigma > 0:
1475
+ try:
1476
+ if _HAS_CV2:
1477
+ k = int(max(1, int(round(3*soft_sigma)))*2 + 1)
1478
+ m = _cv2.GaussianBlur(m, (k, k), float(soft_sigma),
1479
+ borderType=_cv2.BORDER_REFLECT)
1480
+ else:
1481
+ from scipy.ndimage import gaussian_filter
1482
+ m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
1483
+ except Exception:
1484
+ pass
1485
+ np.clip(m, 0.0, 1.0, out=m)
1486
+
1487
+ keep = 1.0 - m
1488
+ kf = float(max(0.0, min(0.99, keep_floor)))
1489
+ keep = kf + (1.0 - kf) * keep
1490
+ np.clip(keep, 0.0, 1.0, out=keep)
1491
+
1492
+ status_cb(f"Star mask: thresh={used:.3g} | detected={raw} | kept={kept} | drawn={drawn} | keep_floor={keep_floor}")
1493
+ return np.ascontiguousarray(keep, dtype=np.float32)
1494
+
1495
+
1496
+ def _variance_map_from_precomputed(
1497
+ img_2d: np.ndarray,
1498
+ sky_map: np.ndarray,
1499
+ rms_map: np.ndarray,
1500
+ hdr,
1501
+ *,
1502
+ smooth_sigma: float,
1503
+ floor: float,
1504
+ status_cb=lambda s: None
1505
+ ) -> np.ndarray:
1506
+ img = np.clip(np.asarray(img_2d, dtype=np.float32), 0.0, None)
1507
+ var_bg_dn2 = np.maximum(rms_map, 1e-6) ** 2
1508
+ obj_dn = np.clip(img - sky_map, 0.0, None)
1509
+
1510
+ gain = None
1511
+ for k in ("EGAIN", "GAIN", "GAIN1", "GAIN2"):
1512
+ if k in hdr:
1513
+ try:
1514
+ g = float(hdr[k]); gain = g if (np.isfinite(g) and g > 0) else None
1515
+ if gain is not None: break
1516
+ except Exception as e:
1517
+ import logging
1518
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
1519
+
1520
+ if gain is not None:
1521
+ a_shot = 1.0 / gain
1522
+ else:
1523
+ sky_med = float(np.median(sky_map))
1524
+ varbg_med = float(np.median(var_bg_dn2))
1525
+ a_shot = (varbg_med / sky_med) if sky_med > 1e-6 else 0.0
1526
+ a_shot = float(np.clip(a_shot, 0.0, 10.0))
1527
+
1528
+ v = var_bg_dn2 + a_shot * obj_dn
1529
+ if smooth_sigma > 0:
1530
+ try:
1531
+ import cv2 as _cv2
1532
+ k = int(max(1, int(round(3*smooth_sigma)))*2 + 1)
1533
+ v = _cv2.GaussianBlur(v, (k,k), float(smooth_sigma), borderType=_cv2.BORDER_REFLECT)
1534
+ except Exception:
1535
+ try:
1536
+ from scipy.ndimage import gaussian_filter
1537
+ v = gaussian_filter(v, sigma=float(smooth_sigma), mode="reflect")
1538
+ except Exception:
1539
+ pass
1540
+
1541
+ np.clip(v, float(floor), None, out=v)
1542
+ try:
1543
+ rms_med = float(np.median(np.sqrt(var_bg_dn2)))
1544
+ status_cb(f"Variance map: sky_med={float(np.median(sky_map)):.3g} DN | rms_med={rms_med:.3g} DN | smooth_sigma={smooth_sigma} | floor={floor}")
1545
+ except Exception:
1546
+ pass
1547
+ return v.astype(np.float32, copy=False)
1548
+
1549
+
1550
+
1551
+ # -----------------------------
1552
+ # Robust weighting (Huber)
1553
+ # -----------------------------
1554
+
1555
+ EPS = 1e-6
1556
+
1557
+ def _estimate_scalar_variance(a):
1558
+ med = np.median(a)
1559
+ mad = np.median(np.abs(a - med)) + 1e-6
1560
+ return float((1.4826 * mad) ** 2)
1561
+
1562
+ def _weight_map(y, pred, huber_delta, var_map=None, mask=None):
1563
+ """
1564
+ W = [psi(r)/r] * 1/(var + eps) * mask, psi=Huber
1565
+ If huber_delta<0, delta = (-huber_delta) * RMS(residual) via MAD.
1566
+ y,pred: (H,W) or (C,H,W). var_map/mask are 2D; broadcast if needed.
1567
+ """
1568
+ r = y - pred
1569
+ # auto delta?
1570
+ if huber_delta < 0:
1571
+ med = np.median(r)
1572
+ mad = np.median(np.abs(r - med)) + 1e-6
1573
+ rms = 1.4826 * mad
1574
+ delta = (-huber_delta) * max(rms, 1e-6)
1575
+ else:
1576
+ delta = huber_delta
1577
+
1578
+ absr = np.abs(r)
1579
+ if float(delta) > 0:
1580
+ psi_over_r = np.where(absr <= delta, 1.0, delta / (absr + EPS)).astype(np.float32)
1581
+ else:
1582
+ psi_over_r = np.ones_like(r, dtype=np.float32)
1583
+
1584
+ if var_map is None:
1585
+ v = _estimate_scalar_variance(r)
1586
+ else:
1587
+ v = var_map
1588
+ if v.ndim == 2 and r.ndim == 3:
1589
+ v = v[None, ...]
1590
+ w = psi_over_r / (v + EPS)
1591
+ if mask is not None:
1592
+ m = mask if mask.ndim == w.ndim else (mask[None, ...] if w.ndim == 3 else mask)
1593
+ w = w * m
1594
+ return w
1595
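+
+ # Mini demonstration of the Huber factor above (a sketch; `_example_huber_weights`
+ # is hypothetical). With delta=1, residuals of 0.5, 1 and 4 get psi(r)/r factors
+ # of about 1.0, 1.0 and 0.25, so gross outliers are down-weighted like delta/|r|
+ # instead of being clipped outright.
+ def _example_huber_weights():
+     y = np.array([[0.5, 1.0, 4.0]], dtype=np.float32)
+     pred = np.zeros_like(y)
+     return _weight_map(y, pred, huber_delta=1.0, var_map=np.ones_like(y))
+     # ~[[1.0, 1.0, 0.25]] up to the EPS regularizers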
+
1596
+
1597
+ # -----------------------------
1598
+ # Torch / conv
1599
+ # -----------------------------
1600
+
1601
+ def _fftshape_same(H, W, kh, kw):
1602
+ return H + kh - 1, W + kw - 1
1603
+
1604
+ # ---------- Torch FFT helpers (FIXED: carry padH/padW) ----------
1605
+ def _precompute_torch_psf_ffts(psfs, flip_psf, H, W, device, dtype):
1606
+ """
1607
+ Pack (Kf,padH,padW,kh,kw) so the conv can crop correctly to SAME.
1608
+ Kernel is ifftshifted before padding.
1609
+ """
1610
+ tfft = torch.fft
1611
+ psf_fft, psfT_fft = [], []
1612
+ for k, kT in zip(psfs, flip_psf):
1613
+ kh, kw = k.shape
1614
+ padH, padW = _fftshape_same(H, W, kh, kw)
1615
+ k_small = torch.as_tensor(np.fft.ifftshift(k), device=device, dtype=dtype)
1616
+ kT_small = torch.as_tensor(np.fft.ifftshift(kT), device=device, dtype=dtype)
1617
+ Kf = tfft.rfftn(k_small, s=(padH, padW))
1618
+ KTf = tfft.rfftn(kT_small, s=(padH, padW))
1619
+ psf_fft.append((Kf, padH, padW, kh, kw))
1620
+ psfT_fft.append((KTf, padH, padW, kh, kw))
1621
+ return psf_fft, psfT_fft
1622
+
1623
+
1624
+ def _fft_conv_same_torch(x, Kf_pack, out_spatial):
1625
+ tfft = torch.fft
1626
+ Kf, padH, padW, kh, kw = Kf_pack
1627
+ H, W = x.shape[-2], x.shape[-1]
1628
+ if x.ndim == 2:
1629
+ X = tfft.rfftn(x, s=(padH, padW))
1630
+ y = tfft.irfftn(X * Kf, s=(padH, padW))
1631
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1632
+ out_spatial.copy_(y[sh:sh+H, sw:sw+W])
1633
+ return out_spatial
1634
+ else:
1635
+ X = tfft.rfftn(x, s=(padH, padW), dim=(-2,-1))
1636
+ y = tfft.irfftn(X * Kf, s=(padH, padW), dim=(-2,-1))
1637
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1638
+ out_spatial.copy_(y[..., sh:sh+H, sw:sw+W])
1639
+ return out_spatial
1640
+
1641
+ # ---------- NumPy FFT helpers ----------
1642
+ def _precompute_np_psf_ffts(psfs, flip_psf, H, W):
1643
+ import numpy.fft as fft
1644
+ meta, Kfs, KTfs = [], [], []
1645
+ for k, kT in zip(psfs, flip_psf):
1646
+ kh, kw = k.shape
1647
+ fftH, fftW = _fftshape_same(H, W, kh, kw)
1648
+ Kfs.append( fft.rfftn(np.fft.ifftshift(k), s=(fftH, fftW)) )
1649
+ KTfs.append(fft.rfftn(np.fft.ifftshift(kT), s=(fftH, fftW)) )
1650
+ meta.append((kh, kw, fftH, fftW))
1651
+ return Kfs, KTfs, meta
1652
+
1653
+ def _fft_conv_same_np(a, Kf, kh, kw, fftH, fftW, out):
1654
+ import numpy.fft as fft
1655
+ if a.ndim == 2:
1656
+ A = fft.rfftn(a, s=(fftH, fftW))
1657
+ y = fft.irfftn(A * Kf, s=(fftH, fftW))
1658
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1659
+ out[...] = y[sh:sh+a.shape[0], sw:sw+a.shape[1]]
1660
+ return out
1661
+ else:
1662
+ C, H, W = a.shape
1663
+ acc = []
1664
+ for c in range(C):
1665
+ A = fft.rfftn(a[c], s=(fftH, fftW))
1666
+ y = fft.irfftn(A * Kf, s=(fftH, fftW))
1667
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1668
+ acc.append(y[sh:sh+H, sw:sw+W])
1669
+ out[...] = np.stack(acc, 0)
1670
+ return out
1671
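+
+ # Usage sketch for the NumPy pack/conv pair above (illustrative only; the names in
+ # `_example_np_fft_roundtrip` are hypothetical). The rFFT packs are built once per
+ # grid size and then reused every iteration, which is why packing and convolving
+ # are split into two helpers.
+ def _example_np_fft_roundtrip():
+     rng = np.random.default_rng(1)
+     img = rng.random((64, 64), dtype=np.float32)
+     psf = _gaussian2d(11, 1.5)  # defined further down in this module
+     Kfs, KTfs, meta = _precompute_np_psf_ffts([psf], [np.flip(np.flip(psf, 0), 1)], 64, 64)
+     kh, kw, fftH, fftW = meta[0]
+     out = np.empty_like(img)
+     _fft_conv_same_np(img, Kfs[0], kh, kw, fftH, fftW, out)
+     return out.shape == img.shape  # SAME-size output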
+
1672
+
1673
+
1674
+ def _torch_device():
1675
+ if TORCH_OK and (torch is not None):
1676
+ if hasattr(torch, "cuda") and torch.cuda.is_available():
1677
+ return torch.device("cuda")
1678
+ if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
1679
+ return torch.device("mps")
1680
+ # DirectML: we passed dml_device from outer scope; keep a module-global
1681
+ if globals().get("dml_ok", False) and globals().get("dml_device", None) is not None:
1682
+ return globals()["dml_device"]
1683
+ return torch.device("cpu")
1684
+
1685
+ def _to_t(x: np.ndarray):
1686
+ if not (TORCH_OK and (torch is not None)):
1687
+ raise RuntimeError("Torch path requested but torch is unavailable")
1688
+ device = _torch_device()
1689
+ t = torch.from_numpy(x)
1690
+ # DirectML wants explicit .to(device)
1691
+ return t.to(device, non_blocking=True) if str(device) != "cpu" else t
1692
+
1693
+ def _contig(x):
1694
+ return np.ascontiguousarray(x, dtype=np.float32)
1695
+
1696
+ def _conv_same_torch(img_t, psf_t):
1697
+ """
1698
+ img_t: torch tensor on DEVICE, (H,W) or (C,H,W)
1699
+ psf_t: torch tensor on DEVICE, (1,1,kh,kw) (single kernel)
1700
+ Pads with 'reflect' to avoid zero-padding ringing.
1701
+ """
1702
+ kh, kw = psf_t.shape[-2:]
1703
+ pad = (kw // 2, kw - kw // 2 - 1, # left, right
1704
+ kh // 2, kh - kh // 2 - 1) # top, bottom
1705
+
1706
+ if img_t.ndim == 2:
1707
+ x = img_t[None, None]
1708
+ x = torch.nn.functional.pad(x, pad, mode="reflect")
1709
+ y = torch.nn.functional.conv2d(x, psf_t, padding=0)
1710
+ return y[0, 0]
1711
+ else:
1712
+ C = img_t.shape[0]
1713
+ x = img_t[None]
1714
+ x = torch.nn.functional.pad(x, pad, mode="reflect")
1715
+ w = psf_t.repeat(C, 1, 1, 1)
1716
+ y = torch.nn.functional.conv2d(x, w, padding=0, groups=C)
1717
+ return y[0]
1718
+
1719
+ def _safe_inference_context():
1720
+ """
1721
+ Return a valid, working no-grad context:
1722
+ - prefer torch.inference_mode() if it exists *and* can be entered,
1723
+ - otherwise fall back to torch.no_grad(),
1724
+ - if torch is unavailable, return NO_GRAD.
1725
+ """
1726
+ if not (TORCH_OK and (torch is not None)):
1727
+ return NO_GRAD
1728
+
1729
+ cm = getattr(torch, "inference_mode", None)
1730
+ if cm is None:
1731
+ return torch.no_grad
1732
+
1733
+ # Probe inference_mode once; if it explodes on this build, fall back.
1734
+ try:
1735
+ with cm():
1736
+ pass
1737
+ return cm
1738
+ except Exception:
1739
+ return torch.no_grad
1740
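+
+ # Usage sketch (assumes torch was importable; otherwise the NO_GRAD fallback is
+ # returned). The helper hands back a context-manager *factory*, not an entered
+ # context, so call it before the `with`:
+ def _example_inference_ctx():
+     ctx = _safe_inference_context()
+     with ctx():
+         pass  # run forward passes here without autograd bookkeeping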
+
1741
+ def _ensure_mask_list(masks, data):
1742
+ # 1s where valid, 0s where invalid (soft edges allowed)
1743
+ if masks is None:
1744
+ return [np.ones_like(a if a.ndim==2 else a[0], dtype=np.float32) for a in data]
1745
+ out = []
1746
+ for a, m in zip(data, masks):
1747
+ base = a if a.ndim==2 else a[0] # mask is 2D; shared across channels
1748
+ if m is None:
1749
+ out.append(np.ones_like(base, dtype=np.float32))
1750
+ else:
1751
+ mm = np.asarray(m, dtype=np.float32)
1752
+ if mm.ndim == 3: # tolerate (1,H,W) or (C,H,W)
1753
+ mm = mm[0]
1754
+ if mm.shape != base.shape:
1755
+ # center crop to match (common intersection already applied)
1756
+ Ht, Wt = base.shape
1757
+ mm = _center_crop(mm, Ht, Wt)
1758
+ # keep as float weights in [0,1] (do not threshold!)
1759
+ out.append(np.clip(mm.astype(np.float32, copy=False), 0.0, 1.0))
1760
+ return out
1761
+
1762
+ def _ensure_var_list(variances, data):
1763
+ # If None, we’ll estimate a robust scalar per frame on-the-fly.
1764
+ if variances is None:
1765
+ return [None]*len(data)
1766
+ out = []
1767
+ for a, v in zip(data, variances):
1768
+ if v is None:
1769
+ out.append(None)
1770
+ else:
1771
+ vv = np.asarray(v, dtype=np.float32)
1772
+ if vv.ndim == 3:
1773
+ vv = vv[0]
1774
+ base = a if a.ndim==2 else a[0]
1775
+ if vv.shape != base.shape:
1776
+ Ht, Wt = base.shape
1777
+ vv = _center_crop(vv, Ht, Wt)
1778
+ # clip tiny/negatives
1779
+ vv = np.nan_to_num(vv, nan=1e-8, posinf=1e8, neginf=1e8, copy=False)
1780
+ vv = np.clip(vv, 1e-8, None).astype(np.float32, copy=False)
1781
+ out.append(vv)
1782
+ return out
1783
+
1784
+ # ---- SR operators (downsample / upsample-sum) ----
1785
+ def _downsample_avg(img, r: int):
1786
+ """Average-pool over non-overlapping r×r blocks. Works for (H,W) or (C,H,W)."""
1787
+ if r <= 1:
1788
+ return img
1789
+ a = np.asarray(img, dtype=np.float32)
1790
+ if a.ndim == 2:
1791
+ H, W = a.shape
1792
+ Hs, Ws = (H // r) * r, (W // r) * r
1793
+ a = a[:Hs, :Ws].reshape(Hs//r, r, Ws//r, r).mean(axis=(1,3))
1794
+ return a
1795
+ else:
1796
+ C, H, W = a.shape
1797
+ Hs, Ws = (H // r) * r, (W // r) * r
1798
+ a = a[:, :Hs, :Ws].reshape(C, Hs//r, r, Ws//r, r).mean(axis=(2,4))
1799
+ return a
1800
+
1801
+ def _upsample_sum(img, r: int, target_hw: tuple[int,int] | None = None):
1802
+ """Adjoint of average-pooling: replicate-sum each pixel into an r×r block.
1803
+ For (H,W) or (C,H,W). If target_hw given, center-crop/pad to that size.
1804
+ """
1805
+ if r <= 1:
1806
+ return img
1807
+ a = np.asarray(img, dtype=np.float32)
1808
+ if a.ndim == 2:
1809
+ H, W = a.shape
1810
+ out = np.kron(a, np.ones((r, r), dtype=np.float32))
1811
+ else:
1812
+ C, H, W = a.shape
1813
+ out = np.stack([np.kron(a[c], np.ones((r, r), dtype=np.float32)) for c in range(C)], axis=0)
1814
+ if target_hw is not None:
1815
+ Ht, Wt = target_hw
1816
+ out = _center_crop(out, Ht, Wt)
1817
+ return out
1818
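+
+ # Shape sketch for the SR operator pair (illustrative; `_example_sr_roundtrip` is
+ # hypothetical). Downsampling by r shrinks H and W by r; the replicate upsample
+ # restores the original grid, optionally center-cropped to an exact target size.
+ def _example_sr_roundtrip(r: int = 2):
+     hi = np.arange(64, dtype=np.float32).reshape(8, 8)
+     lo = _downsample_avg(hi, r)                      # (4, 4) block means
+     back = _upsample_sum(lo, r, target_hw=hi.shape)
+     return lo.shape, back.shape                      # ((4, 4), (8, 8))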
+
1819
+ def _gaussian2d(ksize: int, sigma: float) -> np.ndarray:
1820
+ r = (ksize - 1) // 2
1821
+ y, x = np.mgrid[-r:r+1, -r:r+1].astype(np.float32)
1822
+ g = np.exp(-(x*x + y*y)/(2.0*sigma*sigma)).astype(np.float32)
1823
+ g /= g.sum() + EPS
1824
+ return g
1825
+
1826
+ def _conv2_same_np(a: np.ndarray, k: np.ndarray) -> np.ndarray:
1827
+ # lightweight wrap for 2D conv on (H,W) or (C,H,W) with same-size output
1828
+ return _conv_same_np(a if a.ndim==3 else a[None], k)[0] if a.ndim==2 else _conv_same_np(a, k)
1829
+
1830
+ def _solve_super_psf_from_native(f_native: np.ndarray, r: int, sigma: float = 1.1,
1831
+ iters: int = 500, lr: float = 0.1) -> np.ndarray:
1832
+ """
1833
+ Solve: h* = argmin_h || f_native - (D(h) * g_sigma) ||_2^2,
1834
+ where h is (r*k)×(r*k) if f_native is k×k. Returns normalized h (sum=1).
1835
+ """
1836
+ f = np.asarray(f_native, dtype=np.float32)
1837
+
1838
+ # Sanitize to a 2D odd square before anything else
1839
+ if f.ndim != 2:
1840
+ f = np.squeeze(f)
1841
+ if f.ndim != 2:
1842
+ raise ValueError(f"PSF must be 2D, got shape {f.shape}")
1843
+
1844
+ H, W = int(f.shape[0]), int(f.shape[1])
1845
+ k_sq = min(H, W)
1846
+ # center-crop to square if needed
1847
+ if H != W:
1848
+ y0 = (H - k_sq) // 2
1849
+ x0 = (W - k_sq) // 2
1850
+ f = f[y0:y0 + k_sq, x0:x0 + k_sq]
1851
+ H = W = k_sq
1852
+
1853
+ # enforce odd size (required by SAME padding math)
1854
+ if (H % 2) == 0:
1855
+ # drop one pixel border to make it odd (centered)
1856
+ f = f[1:, 1:]
1857
+ H = W = f.shape[0]
1858
+
1859
+ k = int(H) # k is now odd and square
1860
+ kr = int(k * r)
1861
+
1862
+
1863
+ g = _gaussian2d(k, max(sigma, 1e-3)).astype(np.float32)
1864
+
1865
+ h0 = np.zeros((kr, kr), dtype=np.float32)
1866
+ h0[::r, ::r] = f
1867
+ h0 = _normalize_psf(h0)
1868
+
1869
+ if TORCH_OK:
1870
+ import torch.nn.functional as F
1871
+ dev = _torch_device()
1872
+
1873
+ # (1) Make sure Gaussian kernel is odd-sized for SAME conv padding
1874
+ g_pad = g
1875
+ if (g.shape[-1] % 2) == 0:
1876
+ # ensure odd + renormalize
1877
+ gg = _pad_kernel_to(g, g.shape[-1] + 1)
1878
+ g_pad = gg.astype(np.float32, copy=False)
1879
+
1880
+ t = torch.tensor(h0, device=dev, dtype=torch.float32, requires_grad=True)
1881
+ f_t = torch.tensor(f, device=dev, dtype=torch.float32)
1882
+ g_t = torch.tensor(g_pad, device=dev, dtype=torch.float32)
1883
+ opt = torch.optim.Adam([t], lr=lr)
1884
+
1885
+ # Helpful assertion avoids silent shape traps
1886
+ H, W = t.shape
1887
+ assert (H % r) == 0 and (W % r) == 0, f"h shape {t.shape} not divisible by r={r}"
1888
+ Hr, Wr = H // r, W // r
1889
+
1890
+ try:
1891
+ for _ in range(max(10, iters)):
1892
+ opt.zero_grad(set_to_none=True)
1893
+
1894
+ # (2) Downsample with avg_pool2d instead of reshape/mean
1895
+ blk = t.narrow(0, 0, Hr * r).narrow(1, 0, Wr * r).contiguous()
1896
+ th = F.avg_pool2d(blk[None, None], kernel_size=r, stride=r)[0, 0] # (k,k)
1897
+
1898
+ # (3) Native-space blur with guaranteed-odd g_t
1899
+ pad = g_t.shape[-1] // 2
1900
+ conv = F.conv2d(th[None, None], g_t[None, None], padding=pad)[0, 0]
1901
+
1902
+ loss = torch.mean((conv - f_t) ** 2)
1903
+ loss.backward()
1904
+ opt.step()
1905
+ with torch.no_grad():
1906
+ t.clamp_(min=0.0)
1907
+ t /= (t.sum() + 1e-8)
1908
+ h = t.detach().cpu().numpy().astype(np.float32)
1909
+ except Exception:
1910
+ # (4) Conservative safety net: if a backend balks (commonly at r=2),
1911
+ # fall back to the NumPy solver *just for this kernel*.
1912
+ h = None
1913
+
1914
+ if not TORCH_OK or h is None:
1915
+ # NumPy fallback (unchanged)
1916
+ h = h0.copy()
1917
+ eta = float(lr)
1918
+ for _ in range(max(50, iters)):
1919
+ Dh = _downsample_avg(h, r)
1920
+ conv = _conv2_same_np(Dh, g)
1921
+ resid = (conv - f)
1922
+ grad_Dh = _conv2_same_np(resid, np.flip(np.flip(g, 0), 1))
1923
+ grad_h = _upsample_sum(grad_Dh, r, target_hw=h.shape)
1924
+ h = np.clip(h - eta * grad_h, 0.0, None)
1925
+ s = float(h.sum()); h /= (s + 1e-8)
1926
+ eta *= 0.995
1927
+
1928
+ return _normalize_psf(h)
1929
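+
+ # Forward-model sketch (illustration only; `_example_sr_psf_forward` is
+ # hypothetical). The solver above searches for a super-res kernel h such that
+ # blur(downsample(h)) reproduces the measured native PSF f; this runs that
+ # mapping one way on a delta kernel, without solving anything.
+ def _example_sr_psf_forward(r: int = 2, k: int = 9, sigma: float = 1.1):
+     h = np.zeros((k * r, k * r), dtype=np.float32)
+     h[(k * r) // 2, (k * r) // 2] = 1.0              # delta on the SR grid
+     f_pred = _conv2_same_np(_downsample_avg(h, r), _gaussian2d(k, sigma))
+     return f_pred.shape                              # (k, k): native-grid PSF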
+
1930
+
1931
+ def _downsample_avg_t(x, r: int):
1932
+ if r <= 1:
1933
+ return x
1934
+ if x.ndim == 2:
1935
+ H, W = x.shape
1936
+ Hr, Wr = (H // r) * r, (W // r) * r
1937
+ if Hr == 0 or Wr == 0:
1938
+ return x
1939
+ x2 = x[:Hr, :Wr]
1940
+ # use .reshape here, not .view: the sliced tensor above may be non-contiguous
1941
+ return x2.reshape(Hr // r, r, Wr // r, r).mean(dim=(1, 3))
1942
+ else:
1943
+ C, H, W = x.shape
1944
+ Hr, Wr = (H // r) * r, (W // r) * r
1945
+ if Hr == 0 or Wr == 0:
1946
+ return x
1947
+ x2 = x[:, :Hr, :Wr]
1948
+ # use .reshape here, not .view: the sliced tensor above may be non-contiguous
1949
+ return x2.reshape(C, Hr // r, r, Wr // r, r).mean(dim=(2, 4))
1950
+
1951
+ def _upsample_sum_t(x, r: int):
1952
+ if r <= 1:
1953
+ return x
1954
+ if x.ndim == 2:
1955
+ return x.repeat_interleave(r, dim=0).repeat_interleave(r, dim=1)
1956
+ else:
1957
+ return x.repeat_interleave(r, dim=-2).repeat_interleave(r, dim=-1)
1958
+
1959
+ def _sep_bg_rms(frames):
1960
+ """Return a robust background RMS using SEP's background model on the first frame."""
1961
+ if sep is None or not frames:
1962
+ return None
1963
+ try:
1964
+ y0 = frames[0] if frames[0].ndim == 2 else frames[0][0] # use luma/first channel
1965
+ a = np.ascontiguousarray(y0, dtype=np.float32)
1966
+ b = sep.Background(a, bw=64, bh=64, fw=3, fh=3)
1967
+ try:
1968
+ rms_val = float(b.globalrms)
1969
+ except Exception:
1970
+ # some SEP builds don’t expose globalrms; fall back to the map’s median
1971
+ rms_val = float(np.median(np.asarray(b.rms(), dtype=np.float32)))
1972
+ return rms_val
1973
+ except Exception:
1974
+ return None
1975
+
1976
+ # =========================
1977
+ # Memory/streaming helpers
1978
+ # =========================
1979
+
1980
+ def _approx_bytes(arr_like_shape, dtype=np.float32):
1981
+ """Rough byte estimator for a given shape/dtype."""
1982
+ return int(np.prod(arr_like_shape)) * np.dtype(dtype).itemsize
1983
+
1984
+ def _mem_model(
1985
+ grid_hw: tuple[int,int],
1986
+ r: int,
1987
+ ksize: int,
1988
+ channels: int,
1989
+ mem_target_mb: int,
1990
+ prefer_tiles: bool = False,
1991
+ min_tile: int = 256,
1992
+ max_tile: int = 2048,
1993
+ ) -> dict:
1994
+ """
1995
+ Pick a batch size (#frames) and optional tile size (HxW) given a memory budget.
1996
+ Very conservative — aims to bound peak working-set on CPU/GPU.
1997
+ """
1998
+ Hs, Ws = grid_hw
1999
+ halo = (ksize // 2) * max(1, r) # SR grid halo if r>1
2000
+ C = max(1, channels)
2001
+
2002
+ # working-set per *full-frame* conv scratch (num/den/tmp/etc.)
2003
+ per_frame_fft_like = 3 * _approx_bytes((C, Hs, Ws)) # tmp/pred + in/out buffers
2004
+ global_accum = 2 * _approx_bytes((C, Hs, Ws)) # num + den
2005
+
2006
+ budget = int(mem_target_mb * 1024 * 1024)
2007
+
2008
+ # Try to stay in full-frame mode first unless prefer_tiles
2009
+ B_full = max(1, (budget - global_accum) // max(per_frame_fft_like, 1))
2010
+ use_tiles = prefer_tiles or (B_full < 1)
2011
+
2012
+ if not use_tiles:
2013
+ return dict(batch_frames=int(B_full), tiles=None, halo=int(halo), ksize=int(ksize))
2014
+
2015
+ # Tile mode: pick a square tile side t that fits
2016
+ # scratch per tile ~ 3*C*(t+2h)^2 + accum(core) ~ small
2017
+ # try descending from max_tile
2018
+ t = int(min(max_tile, max(min_tile, 1 << int(np.floor(np.log2(min(Hs, Ws)))))))
2019
+ while t >= min_tile:
2020
+ th = t + 2 * halo
2021
+ per_tile = 3 * _approx_bytes((C, th, th))
2022
+ B_tile = max(1, (budget - global_accum) // max(per_tile, 1))
2023
+ if B_tile >= 1:
2024
+ return dict(batch_frames=int(B_tile), tiles=(t, t), halo=int(halo), ksize=int(ksize))
2025
+ t //= 2
2026
+
2027
+ # Worst case: 1 frame, minimal tile
2028
+ return dict(batch_frames=1, tiles=(min_tile, min_tile), halo=int(halo), ksize=int(ksize))
2029
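+
+ # Example call (a sketch with arbitrary numbers). A 3-channel 4k x 4k SR grid with
+ # a 31-px PSF at r=2 gives halo = 15 * 2 = 30; with a ~2 GB budget the planner
+ # either returns a small full-frame batch (tiles=None) or a (t, t) tile plan.
+ def _example_mem_plan():
+     return _mem_model((4096, 4096), r=2, ksize=31, channels=3, mem_target_mb=2048)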
+
2030
+ def _build_seed_running_mu_sigma_from_paths(
2031
+ paths, Ht, Wt, color_mode,
2032
+ *, bootstrap_frames=24, clip_sigma=3.5, # clip_sigma used for streaming updates
2033
+ status_cb=lambda s: None, progress_cb=None
2034
+ ):
2035
+ """
2036
+ Seed:
2037
+ 1) Load first B frames -> mean0
2038
+ 2) MAD around mean0 -> ±4·MAD mask -> masked-mean seed (one mini-iteration)
2039
+ 3) Stream remaining frames with σ-clipped Welford updates (unchanged behavior)
2040
+ Returns float32 image in (H,W) or (C,H,W) matching color_mode.
2041
+ """
2042
+ def p(frac, msg):
2043
+ if progress_cb:
2044
+ progress_cb(float(max(0.0, min(1.0, frac))), msg)
2045
+
2046
+ n_total = len(paths)
2047
+ B = int(max(1, min(int(bootstrap_frames), n_total)))
2048
+ status_cb(f"MFDeconv: Seed bootstrap {B} frame(s) with ±4·MAD clip on the average…")
2049
+ p(0.00, f"bootstrap load 0/{B}")
2050
+
2051
+ # ---------- load first B frames ----------
2052
+ boot = []
2053
+ for i, pth in enumerate(paths[:B], start=1):
2054
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2055
+ boot.append(ys[0].astype(np.float32, copy=False))
2056
+ if (i == B) or (i % 4 == 0):
2057
+ p(0.25 * (i / float(B)), f"bootstrap load {i}/{B}")
2058
+
2059
+ stack = np.stack(boot, axis=0) # (B,H,W) or (B,C,H,W)
2060
+ del boot
2061
+
2062
+ # ---------- mean0 ----------
2063
+ mean0 = np.mean(stack, axis=0, dtype=np.float32)
2064
+ p(0.28, "bootstrap mean computed")
2065
+
2066
+ # ---------- ±4·MAD clip around mean0, then masked mean (one pass) ----------
2067
+ # MAD per-pixel: median(|x - mean0|)
2068
+ abs_dev = np.abs(stack - mean0[None, ...])
2069
+ mad = np.median(abs_dev, axis=0).astype(np.float32, copy=False)
2070
+
2071
+ thr = 4.0 * mad + EPS
2072
+ mask = (abs_dev <= thr)
2073
+
2074
+ # masked mean with fallback to mean0 where all rejected
2075
+ m = mask.astype(np.float32, copy=False)
2076
+ sum_acc = np.sum(stack * m, axis=0, dtype=np.float32)
2077
+ cnt_acc = np.sum(m, axis=0, dtype=np.float32)
2078
+ seed = mean0.copy()
2079
+ np.divide(sum_acc, np.maximum(cnt_acc, 1.0), out=seed, where=(cnt_acc > 0.5))
2080
+ p(0.36, "±4·MAD masked mean computed")
2081
+
2082
+ # ---------- initialize Welford state from seed ----------
2083
+ # Start μ=seed, set an initial variance envelope from the bootstrap dispersion
2084
+ dif = stack - seed[None, ...]
2085
+ M2 = np.sum(dif * dif, axis=0, dtype=np.float32)
2086
+ cnt = np.full_like(seed, float(B), dtype=np.float32)
2087
+ mu = seed.astype(np.float32, copy=False)
2088
+ del stack, abs_dev, mad, m, sum_acc, cnt_acc, dif
2089
+
2090
+ p(0.40, "seed initialized; streaming refinements…")
2091
+
2092
+ # ---------- stream remaining frames with σ-clipped Welford updates ----------
2093
+ remain = n_total - B
2094
+ if remain > 0:
2095
+ status_cb(f"MFDeconv: Seed μ–σ clipping {remain} remaining frame(s) (k={clip_sigma:.2f})…")
2096
+
2097
+ k = float(clip_sigma)
2098
+ for j, pth in enumerate(paths[B:], start=1):
2099
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2100
+ x = ys[0].astype(np.float32, copy=False)
2101
+
2102
+ var = M2 / np.maximum(cnt - 1.0, 1.0)
2103
+ sigma = np.sqrt(np.maximum(var, 1e-12, dtype=np.float32))
2104
+
2105
+ accept = (np.abs(x - mu) <= (k * sigma))
2106
+ acc = accept.astype(np.float32, copy=False)
2107
+
2108
+ n_new = cnt + acc
2109
+ delta = x - mu
2110
+ mu_n = mu + (acc * delta) / np.maximum(n_new, 1.0)
2111
+ M2 = M2 + acc * delta * (x - mu_n)
2112
+
2113
+ mu, cnt = mu_n, n_new
2114
+
2115
+ if (j == remain) or (j % 8 == 0):
2116
+ p(0.40 + 0.60 * (j / float(remain)), f"μ–σ refine {j}/{remain}")
2117
+
2118
+ p(1.0, "seed ready")
2119
+ return np.clip(mu, 0.0, None).astype(np.float32, copy=False)
2120
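+
+ # The sigma-clipped Welford update used above, shown on a plain 1-D stream
+ # (a sketch; `_example_clipped_welford` is hypothetical). Accepted samples update
+ # mu and M2 exactly as in Welford's algorithm; rejected samples leave the state
+ # untouched because their accept flag is 0.
+ def _example_clipped_welford(samples, k=3.5, mu0=0.0, var0=1.0, n0=1.0):
+     mu, M2, cnt = float(mu0), float(var0) * float(n0), float(n0)
+     for x in samples:
+         sigma = math.sqrt(max(M2 / max(cnt - 1.0, 1.0), 1e-12))
+         acc = 1.0 if abs(x - mu) <= k * sigma else 0.0
+         n_new = cnt + acc
+         delta = x - mu
+         mu_new = mu + (acc * delta) / max(n_new, 1.0)
+         M2 += acc * delta * (x - mu_new)
+         mu, cnt = mu_new, n_new
+     return mu, cnt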
+
2121
+
2122
+ def _chunk(seq, n):
2123
+ """Yield chunks of size n from seq."""
2124
+ for i in range(0, len(seq), n):
2125
+ yield seq[i:i+n]
2126
+
2127
+ def _read_shape_fast(path) -> tuple[int,int,int]:
2128
+ if _is_xisf(path):
2129
+ a, _ = _load_image_array(path)
2130
+ if a is None:
2131
+ raise ValueError(f"No data in {path}")
2132
+ a = np.asarray(a)
2133
+ else:
2134
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
2135
+ a = hdul[0].data
2136
+ if a is None:
2137
+ raise ValueError(f"No data in {path}")
2138
+
2139
+ # common logic for both XISF and FITS
2140
+ if a.ndim == 2:
2141
+ H, W = a.shape
2142
+ return (1, int(H), int(W))
2143
+ if a.ndim == 3:
2144
+ if a.shape[-1] in (1, 3): # HWC
2145
+ C = int(a.shape[-1]); H = int(a.shape[0]); W = int(a.shape[1])
2146
+ return (1 if C == 1 else 3, H, W)
2147
+ if a.shape[0] in (1, 3): # CHW
2148
+ return (int(a.shape[0]), int(a.shape[1]), int(a.shape[2]))
2149
+ s = tuple(map(int, a.shape))
2150
+ H, W = s[-2], s[-1]
2151
+ return (1, H, W)
2152
+
2153
+
2154
+
2155
+
2156
+ def _tiles_of(hw: tuple[int,int], tile_hw: tuple[int,int], halo: int):
2157
+ """
2158
+ Yield tiles as dicts: {y0,y1,x0,x1,yc0,yc1,xc0,xc1}
2159
+ (outer region includes halo; core (yc0:yc1, xc0:xc1) excludes halo).
2160
+ """
2161
+ H, W = hw
2162
+ th, tw = tile_hw
2163
+ th = max(1, int(th)); tw = max(1, int(tw))
2164
+ for y in range(0, H, th):
2165
+ for x in range(0, W, tw):
2166
+ yc0 = y; yc1 = min(y + th, H)
2167
+ xc0 = x; xc1 = min(x + tw, W)
2168
+ y0 = max(0, yc0 - halo); y1 = min(H, yc1 + halo)
2169
+ x0 = max(0, xc0 - halo); x1 = min(W, xc1 + halo)
2170
+ yield dict(y0=y0, y1=y1, x0=x0, x1=x1, yc0=yc0, yc1=yc1, xc0=xc0, xc1=xc1)
2171
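+
+ # Tile-layout example (illustrative; `_example_tile_layout` is hypothetical).
+ # A 100x100 grid cut into 64x64 cores with an 8-px halo yields four tiles; the
+ # first core is (0:64, 0:64) and its outer region (0:72, 0:72) because the halo
+ # is clipped at the image border.
+ def _example_tile_layout():
+     tiles = list(_tiles_of((100, 100), (64, 64), halo=8))
+     first = tiles[0]
+     return len(tiles), (first["y0"], first["y1"], first["x0"], first["x1"])  # (4, (0, 72, 0, 72))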
+
2172
+ def _extract_with_halo(a, tile):
2173
+ """
2174
+ Slice 'a' ((H,W) or (C,H,W)) to [y0:y1, x0:x1] with channel kept.
2175
+ """
2176
+ y0,y1,x0,x1 = tile["y0"], tile["y1"], tile["x0"], tile["x1"]
2177
+ if a.ndim == 2:
2178
+ return a[y0:y1, x0:x1]
2179
+ else:
2180
+ return a[:, y0:y1, x0:x1]
2181
+
2182
+ def _add_core(accum, tile_val, tile):
2183
+ """
2184
+ Add tile_val core into accum at (yc0:yc1, xc0:xc1).
2185
+ Shapes match (2D) or (C,H,W).
2186
+ """
2187
+ yc0,yc1,xc0,xc1 = tile["yc0"], tile["yc1"], tile["xc0"], tile["xc1"]
2188
+ if accum.ndim == 2:
2189
+ h0 = yc0 - tile["y0"]; h1 = h0 + (yc1 - yc0)
2190
+ w0 = xc0 - tile["x0"]; w1 = w0 + (xc1 - xc0)
2191
+ accum[yc0:yc1, xc0:xc1] += tile_val[h0:h1, w0:w1]
2192
+ else:
2193
+ h0 = yc0 - tile["y0"]; h1 = h0 + (yc1 - yc0)
2194
+ w0 = xc0 - tile["x0"]; w1 = w0 + (xc1 - xc0)
2195
+ accum[:, yc0:yc1, xc0:xc1] += tile_val[:, h0:h1, w0:w1]
2196
+
2197
+ def _prepare_np_fft_packs_batch(psfs, flip_psf, Hs, Ws):
2198
+ """Precompute rFFT packs on current grid for NumPy path; returns lists aligned to batch psfs."""
2199
+ Kfs, KTfs, meta = [], [], []
2200
+ import numpy.fft as fft
2201
+ for k, kT in zip(psfs, flip_psf):
2202
+ kh, kw = k.shape
2203
+ fftH, fftW = _fftshape_same(Hs, Ws, kh, kw)
2204
+ Kfs.append(fft.rfftn(np.fft.ifftshift(k), s=(fftH, fftW)))
2205
+ KTfs.append(fft.rfftn(np.fft.ifftshift(kT), s=(fftH, fftW)))
2206
+ meta.append((kh, kw, fftH, fftW))
2207
+ return Kfs, KTfs, meta
2208
+
2209
+ def _prepare_torch_fft_packs_batch(psfs, flip_psf, Hs, Ws, device, dtype):
2210
+ """Torch FFT packs per PSF on current grid; mirrors your existing packer."""
2211
+ return _precompute_torch_psf_ffts(psfs, flip_psf, Hs, Ws, device, dtype)
2212
+
2213
+ def _as_chw(np_img: np.ndarray) -> np.ndarray:
2214
+ x = np.asarray(np_img, dtype=np.float32, order="C")
2215
+ if x.size == 0:
2216
+ raise RuntimeError(f"Empty image array after load; raw shape={np_img.shape}")
2217
+ if x.ndim == 2:
2218
+ return x[None, ...] # 1,H,W
2219
+ if x.ndim == 3 and x.shape[0] in (1, 3):
2220
+ if x.shape[0] == 0:
2221
+ raise RuntimeError(f"Zero channels in CHW array; shape={x.shape}")
2222
+ return x
2223
+ if x.ndim == 3 and x.shape[-1] in (1, 3):
2224
+ if x.shape[-1] == 0:
2225
+ raise RuntimeError(f"Zero channels in HWC array; shape={x.shape}")
2226
+ return np.moveaxis(x, -1, 0)
2227
+ # last resort: treat first dim as channels, but reject zero
2228
+ if x.shape[0] == 0:
2229
+ raise RuntimeError(f"Zero channels in array; shape={x.shape}")
2230
+ return x
2231
+
2232
+
2233
+
2234
+ def _conv_same_np_spatial(a: np.ndarray, k: np.ndarray, out: np.ndarray | None = None):
2235
+ try:
2236
+ import cv2
2237
+ except Exception:
2238
+ return None # no opencv -> caller falls back to FFT
2239
+
2240
+ # cv2 wants HxW single-channel float32
2241
+ kf = np.ascontiguousarray(k.astype(np.float32))
2242
+ kf = np.flip(np.flip(kf, 0), 1) # OpenCV uses correlation; flip to emulate conv
2243
+
2244
+ if a.ndim == 2:
2245
+ y = cv2.filter2D(a, -1, kf, borderType=cv2.BORDER_REFLECT)
2246
+ if out is None: return y
2247
+ out[...] = y; return out
2248
+ else:
2249
+ C, H, W = a.shape
2250
+ if out is None:
2251
+ out = np.empty_like(a)
2252
+ for c in range(C):
2253
+ out[c] = cv2.filter2D(a[c], -1, kf, borderType=cv2.BORDER_REFLECT)
2254
+ return out
2255
+
2256
+ def _grouped_conv_same_torch_per_sample(x_bc_hw, w_b1kk, B, C):
2257
+ F = torch.nn.functional
2258
+ x_bc_hw = x_bc_hw.to(memory_format=torch.contiguous_format).contiguous()
2259
+ w_b1kk = w_b1kk.to(memory_format=torch.contiguous_format).contiguous()
2260
+
2261
+ kh, kw = int(w_b1kk.shape[-2]), int(w_b1kk.shape[-1])
2262
+ pad = (kw // 2, kw - kw // 2 - 1, kh // 2, kh - kh // 2 - 1)
2263
+
2264
+ # unified path (CUDA/CPU/MPS): one grouped conv with G=B*C
2265
+ G = int(B * C)
2266
+ x_1ghw = x_bc_hw.reshape(1, G, x_bc_hw.shape[-2], x_bc_hw.shape[-1])
2267
+ x_1ghw = F.pad(x_1ghw, pad, mode="reflect")
2268
+ w_g1kk = w_b1kk.repeat_interleave(C, dim=0) # (G,1,kh,kw)
2269
+ y_1ghw = F.conv2d(x_1ghw, w_g1kk, padding=0, groups=G)
2270
+ return y_1ghw.reshape(B, C, y_1ghw.shape[-2], y_1ghw.shape[-1]).contiguous()
2271
+
2272
+
2273
+ # put near other small helpers
2274
+ def _robust_med_mad_t(x, max_elems_per_sample: int = 2_000_000):
2275
+ """
2276
+ x: (B, C, H, W) tensor on device.
2277
+ Returns (median[B,1,1,1], mad[B,1,1,1]) computed on a strided subsample
2278
+ to avoid 'quantile() input tensor is too large'.
2279
+ """
2280
+ import math
2281
+ import torch
2282
+ B = x.shape[0]
2283
+ flat = x.reshape(B, -1)
2284
+ N = flat.shape[1]
2285
+ if N > max_elems_per_sample:
2286
+ stride = int(math.ceil(N / float(max_elems_per_sample)))
2287
+ flat = flat[:, ::stride] # strided subsample
2288
+ med = torch.quantile(flat, 0.5, dim=1, keepdim=True)
2289
+ mad = torch.quantile((flat - med).abs(), 0.5, dim=1, keepdim=True) + 1e-6
2290
+ return med.view(B,1,1,1), mad.view(B,1,1,1)
2291
+
2292
+ def _torch_should_use_spatial(psf_ksize: int) -> bool:
2293
+ # Prefer spatial on non-CUDA backends and for modest kernels.
2294
+ try:
2295
+ dev = _torch_device()
2296
+ if dev.type in ("mps", "privateuseone"): # privateuseone = DirectML
2297
+ return True
2298
+ if dev.type == "cuda":
2299
+ return psf_ksize <= 51 # typical PSF sizes; spatial is fast & stable
2300
+ except Exception:
2301
+ pass
2302
+ # Allow override via env
2303
+ import os as _os
2304
+ if _os.environ.get("MF_SPATIAL", "") == "1":
2305
+ return True
2306
+ return False
2307
+
2308
+ def _read_tile_fits(path: str, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
2309
+ """Return a (H,W) or (H,W,3|1) tile via FITS memmap, without loading whole image."""
2310
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
2311
+ hdu = hdul[0]
2312
+ a = hdu.data
2313
+ if a is None:
2314
+ # find first image HDU if primary is header-only
2315
+ for h in hdul[1:]:
2316
+ if getattr(h, "data", None) is not None:
2317
+ a = h.data; break
2318
+ a = np.asarray(a) # still lazy until sliced
2319
+ # squeeze trailing singleton if present to keep your conventions
2320
+ if a.ndim == 3 and a.shape[-1] == 1:
2321
+ a = np.squeeze(a, axis=-1)
2322
+ tile = a[y0:y1, x0:x1, ...]
2323
+ # copy so we own the buffer (we will cast/normalize)
2324
+ return np.array(tile, copy=True)
2325
+
2326
+ def _read_tile_fits_any(path: str, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
2327
+ """FITS/XISF-aware tile read: returns spatial tile; supports 2D, HWC, and CHW."""
2328
+ ext = os.path.splitext(path)[1].lower()
2329
+
2330
+ if ext == ".xisf":
2331
+ a = _xisf_cached_array(path) # float32 memmap; cheap slicing
2332
+ # a is HW, HWC, or CHW (whatever _load_image_array returned)
2333
+ if a.ndim == 2:
2334
+ return np.array(a[y0:y1, x0:x1], copy=True)
2335
+ elif a.ndim == 3:
2336
+ if a.shape[-1] in (1, 3): # HWC
2337
+ out = a[y0:y1, x0:x1, :]
2338
+ if out.shape[-1] == 1: out = out[..., 0]
2339
+ return np.array(out, copy=True)
2340
+ elif a.shape[0] in (1, 3): # CHW
2341
+ out = a[:, y0:y1, x0:x1]
2342
+ if out.shape[0] == 1: out = out[0]
2343
+ return np.array(out, copy=True)
2344
+ else:
2345
+ raise ValueError(f"Unsupported XISF 3D shape {a.shape} in {path}")
2346
+ else:
2347
+ raise ValueError(f"Unsupported XISF ndim {a.ndim} in {path}")
2348
+
2349
+ # FITS
2350
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
2351
+ a = None
2352
+ for h in hdul:
2353
+ if getattr(h, "data", None) is not None:
2354
+ a = h.data
2355
+ break
2356
+ if a is None:
2357
+ raise ValueError(f"No image data in {path}")
2358
+
2359
+ a = np.asarray(a)
2360
+
2361
+ if a.ndim == 2: # HW
2362
+ return np.array(a[y0:y1, x0:x1], copy=True)
2363
+
2364
+ if a.ndim == 3:
2365
+ if a.shape[0] in (1, 3): # CHW (planes, rows, cols)
2366
+ out = a[:, y0:y1, x0:x1]
2367
+ if out.shape[0] == 1:
2368
+ out = out[0]
2369
+ return np.array(out, copy=True)
2370
+ if a.shape[-1] in (1, 3): # HWC
2371
+ out = a[y0:y1, x0:x1, :]
2372
+ if out.shape[-1] == 1:
2373
+ out = out[..., 0]
2374
+ return np.array(out, copy=True)
2375
+
2376
+ # Fallback: assume last two axes are spatial (…, H, W)
2377
+ try:
2378
+ out = a[(..., slice(y0, y1), slice(x0, x1))]
2379
+ return np.array(out, copy=True)
2380
+ except Exception:
2381
+ raise ValueError(f"Unsupported FITS data shape {a.shape} in {path}")
2382
+
2383
+ def _infer_channels_from_tile(p: str, Ht: int, Wt: int) -> int:
2384
+ """Look at a 1×1 tile to infer channel count; supports HW, HWC, CHW."""
2385
+ y1 = min(1, Ht); x1 = min(1, Wt)
2386
+ t = _read_tile_fits_any(p, 0, y1, 0, x1)
2387
+
2388
+ if t.ndim == 2:
2389
+ return 1
2390
+
2391
+ if t.ndim == 3:
2392
+ # Prefer the axis that actually carries the color planes
2393
+ ch_first = t.shape[0] in (1, 3)
2394
+ ch_last = t.shape[-1] in (1, 3)
2395
+
2396
+ if ch_first and not ch_last:
2397
+ return int(t.shape[0])
2398
+ if ch_last and not ch_first:
2399
+ return int(t.shape[-1])
2400
+
2401
+ # Ambiguous tiny tile (e.g. CHW 3×1×1 or HWC 1×1×3):
2402
+ if t.shape[0] == 3 or t.shape[-1] == 3:
2403
+ return 3
2404
+ return 1
2405
+
2406
+ return 1
2407
+
2408
+
2409
+
2410
+ def _seed_median_streaming(
2411
+ paths,
2412
+ Ht,
2413
+ Wt,
2414
+ *,
2415
+ color_mode="luma",
2416
+ tile_hw=(256, 256),
2417
+ status_cb=lambda s: None,
2418
+ progress_cb=lambda f, m="": None,
2419
+ use_torch: bool | None = None, # auto by default
2420
+ ):
2421
+ """
2422
+ Exact per-pixel median via tiling; RAM-bounded.
2423
+ Shows per-tile progress, uses Torch on the GPU when available, and parallelizes
+ per-tile slab reads to hide I/O and luma work.
2425
+ """
2426
+
2427
+ th, tw = int(tile_hw[0]), int(tile_hw[1])
2428
+ # old: want_c = 1 if str(color_mode).lower() == "luma" else (3 if _read_shape_fast(paths[0])[0] == 3 else 1)
2429
+ if str(color_mode).lower() == "luma":
2430
+ want_c = 1
2431
+ else:
2432
+ want_c = _infer_channels_from_tile(paths[0], Ht, Wt)
2433
+ seed = np.zeros((Ht, Wt), np.float32) if want_c == 1 else np.zeros((want_c, Ht, Wt), np.float32)
2434
+ tiles = [(y, min(y + th, Ht), x, min(x + tw, Wt)) for y in range(0, Ht, th) for x in range(0, Wt, tw)]
2435
+ total = len(tiles)
2436
+ n_frames = len(paths)
2437
+
2438
+ # Choose a sensible number of I/O workers (bounded by frames)
2439
+ try:
2440
+ _cpu = (os.cpu_count() or 4)
2441
+ except Exception:
2442
+ _cpu = 4
2443
+ io_workers = max(1, min(8, _cpu, n_frames))
2444
+
2445
+ # Torch autodetect (once)
2446
+ TORCH_OK = False
2447
+ device = None
2448
+ if use_torch is not False:
2449
+ try:
2450
+ from setiastro.saspro.runtime_torch import import_torch
2451
+ _t = import_torch(prefer_cuda=True, status_cb=status_cb)
2452
+ dev = None
2453
+ if hasattr(_t, "cuda") and _t.cuda.is_available():
2454
+ dev = _t.device("cuda")
2455
+ elif hasattr(_t.backends, "mps") and _t.backends.mps.is_available():
2456
+ dev = _t.device("mps")
2457
+ else:
2458
+ dev = None # CPU tensors slower than NumPy for median; only use if forced
2459
+ if dev is not None:
2460
+ TORCH_OK = True
2461
+ device = dev
2462
+ status_cb(f"Median seed: using Torch device {device}")
2463
+ except Exception as e:
2464
+ status_cb(f"Median seed: Torch unavailable → NumPy fallback ({e})")
2465
+ TORCH_OK = False
2466
+ device = None
2467
+
2468
+ def _tile_msg(ti, tn):
2469
+ return f"median tiles {ti}/{tn}"
2470
+
2471
+ done = 0
2472
+
2473
+ for (y0, y1, x0, x1) in tiles:
2474
+ h, w = (y1 - y0), (x1 - x0)
2475
+
2476
+ # per-tile slab reader with incremental progress (parallel)
2477
+ def _read_slab_for_channel(csel=None):
2478
+ """
2479
+ Returns slab of shape (N, h, w) float32 in [0,1]-ish (normalized if input was integer).
2480
+ If csel is None and luma is requested, computes luma.
2481
+ """
2482
+ # parallel worker returns (i, tile2d)
2483
+ def _load_one(i):
2484
+ t = _read_tile_fits_any(paths[i], y0, y1, x0, x1)
2485
+
2486
+ # normalize dtype
2487
+ if t.dtype.kind in "ui":
2488
+ t = t.astype(np.float32) / (float(np.iinfo(t.dtype).max) or 1.0)
2489
+ else:
2490
+ t = t.astype(np.float32, copy=False)
2491
+
2492
+ # luma / channel selection
2493
+ if want_c == 1:
2494
+ if t.ndim == 3:
2495
+ t = _to_luma_local(t)
2496
+ elif t.ndim != 2:
2497
+ t = _to_luma_local(t)
2498
+ else:
2499
+ if t.ndim == 2:
2500
+ pass
2501
+ elif t.ndim == 3 and t.shape[-1] == 3: # HWC
2502
+ t = t[..., csel]
2503
+ elif t.ndim == 3 and t.shape[0] == 3: # CHW
2504
+ t = t[csel]
2505
+ else:
2506
+ t = _to_luma_local(t)
2507
+ return i, np.ascontiguousarray(t, dtype=np.float32)
2508
+
2509
+ slab = np.empty((n_frames, h, w), np.float32)
2510
+ done_local = 0
2511
+ # cap workers by n_frames so we don't spawn useless threads
2512
+ with ThreadPoolExecutor(max_workers=min(io_workers, n_frames)) as ex:
2513
+ futures = [ex.submit(_load_one, i) for i in range(n_frames)]
2514
+ for fut in as_completed(futures):
2515
+ i, t2d = fut.result()
2516
+ # quick sanity (avoid silent mis-shapes)
2517
+ if t2d.shape != (h, w):
2518
+ raise RuntimeError(
2519
+ f"Tile read mismatch at frame {i}: got {t2d.shape}, expected {(h, w)} "
2520
+ f"tile={(y0,y1,x0,x1)}"
2521
+ )
2522
+ slab[i] = t2d
2523
+ done_local += 1
2524
+ if (done_local & 7) == 0 or done_local == n_frames:
2525
+ tile_base = done / total
2526
+ tile_span = 1.0 / total
2527
+ inner = done_local / n_frames
2528
+ progress_cb(tile_base + 0.8 * tile_span * inner, _tile_msg(done + 1, total))
2529
+ return slab
2530
+
2531
+ try:
2532
+ if want_c == 1:
2533
+ t0 = time.perf_counter()
2534
+ slab = _read_slab_for_channel()
2535
+ t1 = time.perf_counter()
2536
+ if TORCH_OK:
2537
+ import torch as _t
2538
+ slab_t = _t.as_tensor(slab, device=device, dtype=_t.float32) # one H2D
2539
+ med_t = slab_t.median(dim=0).values
2540
+ med_np = med_t.detach().cpu().numpy().astype(np.float32, copy=False)
2541
+ # no per-tile empty_cache() — avoid forced syncs
2542
+ else:
2543
+ med_np = np.median(slab, axis=0).astype(np.float32, copy=False)
2544
+ t2 = time.perf_counter()
2545
+ seed[y0:y1, x0:x1] = med_np
2546
+ # lightweight telemetry to confirm bottleneck
2547
+ status_cb(f"seed tile {y0}:{y1},{x0}:{x1} I/O={t1-t0:.3f}s median={'GPU' if TORCH_OK else 'CPU'}={t2-t1:.3f}s")
2548
+ else:
2549
+ for c in range(want_c):
2550
+ slab = _read_slab_for_channel(csel=c)
2551
+ if TORCH_OK:
2552
+ import torch as _t
2553
+ slab_t = _t.as_tensor(slab, device=device, dtype=_t.float32)
2554
+ med_t = slab_t.median(dim=0).values
2555
+ med_np = med_t.detach().cpu().numpy().astype(np.float32, copy=False)
2556
+ else:
2557
+ med_np = np.median(slab, axis=0).astype(np.float32, copy=False)
2558
+ seed[c, y0:y1, x0:x1] = med_np
2559
+
2560
+ except RuntimeError as e:
2561
+ # Per-tile GPU OOM or device issues → fallback to NumPy for this tile
2562
+ msg = str(e).lower()
2563
+ if TORCH_OK and ("out of memory" in msg or "resource" in msg or "alloc" in msg):
2564
+ status_cb(f"Median seed: GPU OOM on tile ({h}x{w}); falling back to NumPy for this tile.")
2565
+ if want_c == 1:
2566
+ slab = _read_slab_for_channel()
2567
+ seed[y0:y1, x0:x1] = np.median(slab, axis=0).astype(np.float32, copy=False)
2568
+ else:
2569
+ for c in range(want_c):
2570
+ slab = _read_slab_for_channel(csel=c)
2571
+ seed[c, y0:y1, x0:x1] = np.median(slab, axis=0).astype(np.float32, copy=False)
2572
+ else:
2573
+ raise
2574
+
2575
+ done += 1
2576
+ # report tile completion (remaining 20% of tile span reserved for median compute)
2577
+ progress_cb(done / total, _tile_msg(done, total))
2578
+
2579
+ if (done & 3) == 0:
2580
+ _process_gui_events_safely()
2581
+ status_cb(f"Median seed: want_c={want_c}, seed_shape={seed.shape}")
2582
+ return seed
2583
+
2584
+ def _seed_bootstrap_streaming(paths, Ht, Wt, color_mode,
2585
+ bootstrap_frames: int = 20,
2586
+ clip_sigma: float = 5.0,
2587
+ status_cb=lambda s: None,
2588
+ progress_cb=None):
2589
+ """
2590
+ Seed = average first B frames, estimate global MAD threshold, masked-mean on B,
2591
+ then stream the remaining frames with σ-clipped Welford μ–σ updates.
2592
+
2593
+ Returns float32, CHW for per-channel mode, or CHW(1,...) for luma (the caller squeezes later if needed).
2594
+ """
2595
+ def p(frac, msg):
2596
+ if progress_cb:
2597
+ progress_cb(float(max(0.0, min(1.0, frac))), msg)
2598
+
2599
+ n = len(paths)
2600
+ B = int(max(1, min(int(bootstrap_frames), n)))
2601
+ status_cb(f"Seed: bootstrap={B}, clip_sigma={clip_sigma}")
2602
+
2603
+ # ---------- pass 1: running mean over the first B frames (Welford μ only) ----------
2604
+ cnt = None
2605
+ mu = None
2606
+ for i, pth in enumerate(paths[:B], 1):
2607
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2608
+ x = ys[0].astype(np.float32, copy=False)
2609
+ if mu is None:
2610
+ mu = x.copy()
2611
+ cnt = np.ones_like(x, dtype=np.float32)
2612
+ else:
2613
+ delta = x - mu
2614
+ cnt += 1.0
2615
+ mu += delta / cnt
2616
+ if (i == B) or (i % 4) == 0:
2617
+ p(0.20 * (i / float(B)), f"bootstrap mean {i}/{B}")
2618
+
2619
+ # ---------- pass 2: estimate a *global* MAD around μ using strided samples ----------
2620
+ stride = 4 if max(Ht, Wt) > 1500 else 2
2621
+ # CHW or HW → build a slice that keeps C and strides H,W
2622
+         samp_slices = ((slice(None), slice(None, None, stride), slice(None, None, stride))
+                        if mu.ndim == 3 else
+                        (slice(None, None, stride), slice(None, None, stride)))
2626
+
2627
+ mad_samples = []
2628
+ for i, pth in enumerate(paths[:B], 1):
2629
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2630
+ x = ys[0].astype(np.float32, copy=False)
2631
+ d = np.abs(x - mu)
2632
+ mad_samples.append(d[samp_slices].ravel())
2633
+ if (i == B) or (i % 8) == 0:
2634
+ p(0.35 + 0.10 * (i / float(B)), f"bootstrap MAD {i}/{B}")
2635
+
2636
+ # robust MAD estimate (scalar) → 4·MAD clip band
2637
+ mad_est = float(np.median(np.concatenate(mad_samples).astype(np.float32)))
2638
+ thr = 4.0 * max(mad_est, 1e-6)
2639
+
2640
+ # ---------- pass 3: masked mean over first B frames using the global threshold ----------
2641
+ sum_acc = np.zeros_like(mu, dtype=np.float32)
2642
+ cnt_acc = np.zeros_like(mu, dtype=np.float32)
2643
+ for i, pth in enumerate(paths[:B], 1):
2644
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2645
+ x = ys[0].astype(np.float32, copy=False)
2646
+ m = (np.abs(x - mu) <= thr).astype(np.float32, copy=False)
2647
+ sum_acc += x * m
2648
+ cnt_acc += m
2649
+ if (i == B) or (i % 4) == 0:
2650
+ p(0.48 + 0.12 * (i / float(B)), f"masked mean {i}/{B}")
2651
+
2652
+ seed = mu.copy()
2653
+ np.divide(sum_acc, np.maximum(cnt_acc, 1.0), out=seed, where=(cnt_acc > 0.5))
2654
+
2655
+ # ---------- pass 4: μ–σ streaming on the remaining frames (σ-clipped Welford) ----------
2656
+         # Seed the variance accumulator from the bootstrap MAD so sigma does not start at zero
+         # (with M2 = 0 the first clip band is ~0 and nearly every remaining sample would be rejected).
+         M2 = np.full_like(seed, ((1.4826 * max(mad_est, 1e-6)) ** 2) * max(B - 1, 1), dtype=np.float32)
2657
+ cnt = np.full_like(seed, float(B), dtype=np.float32)
2658
+ mu = seed.astype(np.float32, copy=False)
2659
+
2660
+ remain = n - B
2661
+ k = float(clip_sigma)
2662
+ for j, pth in enumerate(paths[B:], 1):
2663
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2664
+ x = ys[0].astype(np.float32, copy=False)
2665
+
2666
+ var = M2 / np.maximum(cnt - 1.0, 1.0)
2667
+ sigma = np.sqrt(np.maximum(var, 1e-12, dtype=np.float32))
2668
+
2669
+ acc = (np.abs(x - mu) <= (k * sigma)).astype(np.float32, copy=False) # {0,1} mask
2670
+ n_new = cnt + acc
2671
+ delta = x - mu
2672
+ mu_n = mu + (acc * delta) / np.maximum(n_new, 1.0)
2673
+ M2 = M2 + acc * delta * (x - mu_n)
2674
+
2675
+ mu, cnt = mu_n, n_new
2676
+
2677
+ if (j == remain) or (j % 8) == 0:
2678
+ p(0.60 + 0.40 * (j / float(max(1, remain))), f"μ–σ refine {j}/{remain}")
2679
+
2680
+ return np.clip(mu, 0.0, None).astype(np.float32, copy=False)
2681
+
2682
+
2683
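The streaming refinement above is a per-pixel Welford update gated by a sigma-clip mask. A minimal sketch of that update on toy arrays (all names here are illustrative; the variance is seeded from a prior noise estimate so the clip band is not zero at the start):

import numpy as np

def sigma_clipped_welford(frames, mu0, sigma0, n0, k=5.0):
    """Stream frames into a running mean, accepting only samples within k*sigma per pixel."""
    mu = mu0.astype(np.float32).copy()
    cnt = np.full_like(mu, float(n0), dtype=np.float32)
    M2 = (np.float32(sigma0) ** 2) * np.maximum(cnt - 1.0, 1.0)   # prior variance -> M2
    for x in frames:
        sigma = np.sqrt(np.maximum(M2 / np.maximum(cnt - 1.0, 1.0), 1e-12))
        acc = (np.abs(x - mu) <= k * sigma).astype(np.float32)    # {0,1} acceptance mask
        n_new = cnt + acc
        delta = x - mu
        mu_new = mu + (acc * delta) / np.maximum(n_new, 1.0)
        M2 += acc * delta * (x - mu_new)
        mu, cnt = mu_new, n_new
    return mu

frames = [np.random.rand(8, 8).astype(np.float32) for _ in range(10)]
mu0 = np.mean(frames[:3], axis=0)
seed = sigma_clipped_welford(frames[3:], mu0, sigma0=0.3, n0=3)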
+ def _coerce_sr_factor(srf, *, default_on_bad=2):
2684
+ """
2685
+ Parse super-res factor robustly:
2686
+ - accepts 2, '2', '2x', ' 2 X ', 2.0
2687
+ - clamps to integers >= 1
2688
+     - if invalid/missing → returns default_on_bad (defaults to 2)
2689
+ """
2690
+ if srf is None:
2691
+ return int(default_on_bad)
2692
+ if isinstance(srf, (float, int)):
2693
+ r = int(round(float(srf)))
2694
+ return int(r if r >= 1 else default_on_bad)
2695
+ s = str(srf).strip().lower()
2696
+ # common GUIs pass e.g. "2x", "3×", etc.
2697
+ s = s.replace("×", "x")
2698
+ if s.endswith("x"):
2699
+ s = s[:-1]
2700
+ try:
2701
+ r = int(round(float(s)))
2702
+ return int(r if r >= 1 else default_on_bad)
2703
+ except Exception:
2704
+ return int(default_on_bad)
2705
+
2706
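A few examples of what _coerce_sr_factor returns for typical GUI inputs:

assert _coerce_sr_factor(2) == 2
assert _coerce_sr_factor("2x") == 2
assert _coerce_sr_factor(" 3 × ") == 3     # the unicode multiply sign is normalized to 'x'
assert _coerce_sr_factor(None) == 2        # missing -> default_on_bad
assert _coerce_sr_factor("garbage") == 2   # unparsable -> default_on_bad
assert _coerce_sr_factor(0) == 2           # values below 1 fall back to the default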
+ def _pad_kernel_to(k: np.ndarray, K: int) -> np.ndarray:
2707
+ """Pad/center an odd-sized kernel to K×K (K odd)."""
2708
+ k = np.asarray(k, dtype=np.float32)
2709
+ kh, kw = int(k.shape[0]), int(k.shape[1])
2710
+ assert (kh % 2 == 1) and (kw % 2 == 1)
2711
+ if kh == K and kw == K:
2712
+ return k
2713
+ out = np.zeros((K, K), dtype=np.float32)
2714
+ y0 = (K - kh)//2; x0 = (K - kw)//2
2715
+ out[y0:y0+kh, x0:x0+kw] = k
2716
+ s = float(out.sum())
2717
+ return out if s <= 0 else (out / s).astype(np.float32, copy=False)
2718
+
2719
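For example, padding a normalized 3x3 box kernel to 5x5 centers it and leaves the sum at 1 (the renormalization is a no-op here because the padding only adds zeros):

import numpy as np

k3 = np.ones((3, 3), np.float32) / 9.0
k5 = _pad_kernel_to(k3, 5)
assert k5.shape == (5, 5)
assert abs(float(k5.sum()) - 1.0) < 1e-6
assert np.allclose(k5[1:4, 1:4], k3)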
+ # -----------------------------
2720
+ # Core
2721
+ # -----------------------------
2722
+ def multiframe_deconv(
2723
+ paths,
2724
+ out_path,
2725
+ iters=20,
2726
+ kappa=2.0,
2727
+ color_mode="luma",
2728
+ seed_mode: str = "robust",
2729
+ huber_delta=0.0,
2730
+ masks=None,
2731
+ variances=None,
2732
+ rho="huber",
2733
+ status_cb=lambda s: None,
2734
+ min_iters: int = 3,
2735
+ use_star_masks: bool = False,
2736
+ use_variance_maps: bool = False,
2737
+ star_mask_cfg: dict | None = None,
2738
+ varmap_cfg: dict | None = None,
2739
+ save_intermediate: bool = False,
2740
+ save_every: int = 1,
2741
+ # SR options
2742
+ super_res_factor: int = 1,
2743
+ sr_sigma: float = 1.1,
2744
+ sr_psf_opt_iters: int = 250,
2745
+ sr_psf_opt_lr: float = 0.1,
2746
+ # NEW
2747
+ batch_frames: int | None = None,
2748
+ # GPU tuning (optional knobs)
2749
+ mixed_precision: bool | None = None, # default: auto (True on CUDA/MPS)
2750
+ fft_kernel_threshold: int = 1024, # switch to FFT if K >= this (or lower if SR)
2751
+ prefetch_batches: bool = True, # CPU→GPU double-buffer prefetch
2752
+ use_channels_last: bool | None = None, # default: auto (True on CUDA/MPS)
2753
+ force_cpu: bool = False,
2754
+ star_mask_ref_path: str | None = None,
2755
+ low_mem: bool = False,
2756
+ ):
2757
+ """
2758
+ Streaming multi-frame deconvolution with optional SR (r>1).
2759
+     Optimized GPU path: channels-last layout, pinned-memory prefetch, and optional FFT for large kernels; mixed precision is currently forced off so all math stays in FP32.
2760
+ """
2761
+     mixed_precision = False  # force full FP32; the mixed_precision argument is currently ignored
2762
+ DEBUG_FLAT_WEIGHTS = False
2763
+ # ---------- local helpers (kept self-contained) ----------
2764
+ def _emit_pct(pct: float, msg: str | None = None):
2765
+ pct = float(max(0.0, min(1.0, pct)))
2766
+ status_cb(f"__PROGRESS__ {pct:.4f}" + (f" {msg}" if msg else ""))
2767
+
2768
+ def _pad_kernel_to(k: np.ndarray, K: int) -> np.ndarray:
2769
+ """Pad/center an odd-sized kernel to K×K (K odd)."""
2770
+ k = np.asarray(k, dtype=np.float32)
2771
+ kh, kw = int(k.shape[0]), int(k.shape[1])
2772
+ assert (kh % 2 == 1) and (kw % 2 == 1)
2773
+ if kh == K and kw == K:
2774
+ return k
2775
+ out = np.zeros((K, K), dtype=np.float32)
2776
+ y0 = (K - kh)//2; x0 = (K - kw)//2
2777
+ out[y0:y0+kh, x0:x0+kw] = k
2778
+ s = float(out.sum())
2779
+ return out if s <= 0 else (out / s).astype(np.float32, copy=False)
2780
+
2781
+ max_iters = max(1, int(iters))
2782
+ min_iters = max(1, int(min_iters))
2783
+ if min_iters > max_iters:
2784
+ min_iters = max_iters
2785
+
2786
+ n_frames = len(paths)
2787
+ status_cb(f"MFDeconv: scanning {n_frames} aligned frames (memmap)…")
2788
+ _emit_pct(0.02, "scanning")
2789
+
2790
+ # choose common intersection size without loading full pixels
2791
+ Ht, Wt = _common_hw_from_paths(paths)
2792
+ _emit_pct(0.05, "preparing")
2793
+
2794
+ # --- LOW-MEM PATCH (begin) ---
2795
+ if low_mem:
2796
+ # Cap decoded-frame LRU to keep peak RAM sane on 16 GB laptops
2797
+ try:
2798
+ _FRAME_LRU.cap = max(1, min(getattr(_FRAME_LRU, "cap", 8), 2))
2799
+ except Exception:
2800
+ pass
2801
+
2802
+ # Disable CPU→GPU prefetch to avoid double-buffering allocations
2803
+ prefetch_batches = False
2804
+
2805
+ # Relax SEP background grid & star detection canvas when requested
2806
+ if use_variance_maps:
2807
+ varmap_cfg = {**(varmap_cfg or {})}
2808
+ # fewer, larger tiles → fewer big temporaries
2809
+ varmap_cfg.setdefault("bw", 96)
2810
+ varmap_cfg.setdefault("bh", 96)
2811
+
2812
+ if use_star_masks:
2813
+ star_mask_cfg = {**(star_mask_cfg or {})}
2814
+ # shrink detection canvas to limit temp buffers inside SEP/mask draw
2815
+ star_mask_cfg["max_side"] = int(min(1024, int(star_mask_cfg.get("max_side", 2048))))
2816
+ # --- LOW-MEM PATCH (end) ---
2817
+
2818
+
2819
+ if any(os.path.splitext(p)[1].lower() == ".xisf" for p in paths):
2820
+ status_cb("MFDeconv: priming XISF cache (one-time decode per frame)…")
2821
+ for i, p in enumerate(paths, 1):
2822
+ try:
2823
+ _ = _xisf_cached_array(p) # decode once, store memmap
2824
+ except Exception as e:
2825
+ status_cb(f"XISF cache failed for {p}: {e}")
2826
+ if (i & 7) == 0 or i == len(paths):
2827
+ _process_gui_events_safely()
2828
+
2829
+ # per-frame loader & sequence view (closures capture Ht/Wt/color_mode/paths)
2830
+ def _load_frame_chw(i: int):
2831
+ return _FRAME_LRU.get(paths[i], Ht, Wt, color_mode)
2832
+
2833
+ class _FrameSeq:
2834
+ def __len__(self): return len(paths)
2835
+ def __getitem__(self, i): return _load_frame_chw(i)
2836
+ data = _FrameSeq()
2837
+
2838
+ # ---- torch detection (optional) ----
2839
+ global torch, TORCH_OK
2840
+ torch = None
2841
+ TORCH_OK = False
2842
+ cuda_ok = mps_ok = dml_ok = False
2843
+ dml_device = None
2844
+
2845
+ try:
2846
+ from setiastro.saspro.runtime_torch import import_torch
2847
+ torch = import_torch(prefer_cuda=True, status_cb=status_cb)
2848
+ TORCH_OK = True
2849
+ try: cuda_ok = hasattr(torch, "cuda") and torch.cuda.is_available()
2850
+ except Exception as e:
2851
+ import logging
2852
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2853
+ try: mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
2854
+ except Exception as e:
2855
+ import logging
2856
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2857
+ try:
2858
+ import torch_directml
2859
+ dml_device = torch_directml.device()
2860
+ _ = (torch.ones(1, device=dml_device) + 1).item()
2861
+ dml_ok = True
2862
+ # NEW: expose to _torch_device()
2863
+ globals()["dml_ok"] = True
2864
+ globals()["dml_device"] = dml_device
2865
+ except Exception:
2866
+ dml_ok = False
2867
+ globals()["dml_ok"] = False
2868
+ globals()["dml_device"] = None
2869
+
2870
+ if cuda_ok:
2871
+ status_cb(f"PyTorch CUDA available: True | device={torch.cuda.get_device_name(0)}")
2872
+ elif mps_ok:
2873
+ status_cb("PyTorch MPS (Apple) available: True")
2874
+ elif dml_ok:
2875
+ status_cb("PyTorch DirectML (Windows) available: True")
2876
+ else:
2877
+ status_cb("PyTorch present, using CPU backend.")
2878
+ status_cb(
2879
+ f"PyTorch {getattr(torch,'__version__','?')} backend: "
2880
+ + ("CUDA" if cuda_ok else "MPS" if mps_ok else "DirectML" if dml_ok else "CPU")
2881
+ )
2882
+
2883
+ try:
2884
+             # keep cuDNN autotune on; TF32 is disabled so convolutions stay in full FP32
2885
+ if getattr(torch.backends, "cudnn", None) is not None:
2886
+ torch.backends.cudnn.benchmark = True
2887
+ torch.backends.cudnn.allow_tf32 = False
2888
+ except Exception:
2889
+ pass
2890
+
2891
+ try:
2892
+ # disable TF32 matmul shortcuts if present (CUDA-only; safe no-op elsewhere)
2893
+ if getattr(getattr(torch.backends, "cuda", None), "matmul", None) is not None:
2894
+ torch.backends.cuda.matmul.allow_tf32 = False
2895
+ except Exception:
2896
+ pass
2897
+
2898
+ try:
2899
+ # prefer highest FP32 precision on PT 2.x
2900
+ if hasattr(torch, "set_float32_matmul_precision"):
2901
+ torch.set_float32_matmul_precision("highest")
2902
+ except Exception:
2903
+ pass
2904
+ except Exception as e:
2905
+ TORCH_OK = False
2906
+ status_cb(f"PyTorch not available → CPU path. ({e})")
2907
+
2908
+ if force_cpu:
2909
+ status_cb("⚠️ CPU-only debug mode: disabling PyTorch path.")
2910
+ TORCH_OK = False
2911
+
2912
+ _process_gui_events_safely()
2913
+
2914
+ # ---- PSFs + optional assets (computed in parallel, streaming I/O) ----
2915
+ psfs, masks_auto, vars_auto = _build_psf_and_assets(
2916
+ paths,
2917
+ make_masks=bool(use_star_masks),
2918
+ make_varmaps=bool(use_variance_maps),
2919
+ status_cb=status_cb,
2920
+ save_dir=None,
2921
+ star_mask_cfg=star_mask_cfg,
2922
+ varmap_cfg=varmap_cfg,
2923
+ star_mask_ref_path=star_mask_ref_path,
2924
+ # NEW:
2925
+ Ht=Ht, Wt=Wt, color_mode=color_mode,
2926
+ )
2927
+
2928
+ try:
2929
+ import gc as _gc
2930
+ _gc.collect()
2931
+ except Exception:
2932
+ pass
2933
+
2934
+ psfs_native = psfs # keep a reference to the original native PSFs for fallback
2935
+
2936
+ # ---- SR lift of PSFs if needed ----
2937
+ r_req = _coerce_sr_factor(super_res_factor, default_on_bad=2)
2938
+ status_cb(f"MFDeconv: SR factor requested={super_res_factor!r} → using r={r_req}")
2939
+ r = int(r_req)
2940
+
2941
+ if r > 1:
2942
+ status_cb(f"MFDeconv: Super-resolution r={r} with σ={sr_sigma} — solving SR PSFs…")
2943
+ _process_gui_events_safely()
2944
+
2945
+ def _naive_sr_from_native(f_nat: np.ndarray, r_: int) -> np.ndarray:
2946
+ f2 = np.asarray(f_nat, np.float32)
2947
+ H, W = f2.shape[:2]
2948
+ k_sq = min(H, W)
2949
+ if H != W:
2950
+ y0 = (H - k_sq) // 2; x0 = (W - k_sq) // 2
2951
+ f2 = f2[y0:y0+k_sq, x0:x0+k_sq]
2952
+ if (f2.shape[0] % 2) == 0:
2953
+ f2 = f2[1:, 1:] # make odd
2954
+ f2 = _normalize_psf(f2)
2955
+ h0 = np.zeros((f2.shape[0]*r_, f2.shape[1]*r_), np.float32)
2956
+ h0[::r_, ::r_] = f2
2957
+ return _normalize_psf(h0)
2958
+
2959
+ sr_psfs = []
2960
+ for i, k_native in enumerate(psfs, start=1):
2961
+ try:
2962
+ status_cb(f" SR-PSF{i}: native shape={np.asarray(k_native).shape}")
2963
+ h = _solve_super_psf_from_native(
2964
+ k_native, r=r, sigma=float(sr_sigma),
2965
+ iters=int(sr_psf_opt_iters), lr=float(sr_psf_opt_lr)
2966
+ )
2967
+ except Exception as e:
2968
+ status_cb(f" SR-PSF{i} failed: {e!r} → using naïve upsample")
2969
+ h = _naive_sr_from_native(k_native, r)
2970
+
2971
+ # guarantee odd size for downstream SAME-padding math
2972
+ if (h.shape[0] % 2) == 0:
2973
+ h = h[:-1, :-1]
2974
+ sr_psfs.append(h.astype(np.float32, copy=False))
2975
+ status_cb(f" SR-PSF{i}: native {np.asarray(k_native).shape[0]} → {h.shape[0]} (sum={h.sum():.6f})")
2976
+ psfs = sr_psfs
2977
+
2978
+
2979
+
2980
+ # ---- Seed (streaming) with robust bootstrap already in file helpers ----
2981
+ _emit_pct(0.25, "Calculating Seed Image...")
2982
+ def _seed_progress(frac, msg):
2983
+ _emit_pct(0.25 + 0.15 * float(frac), f"seed: {msg}")
2984
+
2985
+ seed_mode_s = str(seed_mode).lower().strip()
2986
+ if seed_mode_s not in ("robust", "median"):
2987
+ seed_mode_s = "robust"
2988
+ if seed_mode_s == "median":
2989
+ status_cb("MFDeconv: Building median seed (tiled, streaming)…")
2990
+ seed_native = _seed_median_streaming(
2991
+ paths, Ht, Wt,
2992
+ color_mode=color_mode,
2993
+ tile_hw=(256, 256),
2994
+ status_cb=status_cb,
2995
+ progress_cb=_seed_progress,
2996
+ use_torch=TORCH_OK, # ← auto: GPU if available, else NumPy
2997
+ )
2998
+ else:
2999
+ seed_native = _seed_bootstrap_streaming(
3000
+ paths, Ht, Wt, color_mode,
3001
+ bootstrap_frames=20, clip_sigma=5,
3002
+ status_cb=status_cb, progress_cb=_seed_progress
3003
+ )
3004
+
3005
+ # lift seed if SR
3006
+ if r > 1:
3007
+ target_hw = (Ht * r, Wt * r)
3008
+ if seed_native.ndim == 2:
3009
+ x = _upsample_sum(seed_native / (r*r), r, target_hw=target_hw)
3010
+ else:
3011
+ C, Hn, Wn = seed_native.shape
3012
+ x = np.stack(
3013
+ [_upsample_sum(seed_native[c] / (r*r), r, target_hw=target_hw) for c in range(C)],
3014
+ axis=0
3015
+ )
3016
+ else:
3017
+ x = seed_native
3018
+
3019
+ # FINAL SHAPE CHECKS (auto-correct if a GUI sent something odd)
3020
+ if x.ndim == 2: x = x[None, ...]
3021
+ Hs, Ws = x.shape[-2], x.shape[-1]
3022
+ if r > 1:
3023
+ expected_H, expected_W = Ht * r, Wt * r
3024
+ if (Hs, Ws) != (expected_H, expected_W):
3025
+ status_cb(f"SR seed grid mismatch: got {(Hs, Ws)}, expected {(expected_H, expected_W)} → correcting")
3026
+ # Rebuild from the native mean to ensure exact SR size
3027
+ x = _upsample_sum(x if x.ndim==2 else x[0], r, target_hw=(expected_H, expected_W))
3028
+ if x.ndim == 2: x = x[None, ...]
3029
+ try: del seed_native
3030
+ except Exception as e:
3031
+ import logging
3032
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3033
+
3034
+ try:
3035
+ import gc as _gc
3036
+ _gc.collect()
3037
+ except Exception:
3038
+ pass
3039
+
3040
+ flip_psf = [_flip_kernel(k) for k in psfs]
3041
+ _emit_pct(0.20, "PSF Ready")
3042
+
3043
+ # --- Harmonize seed channels with the actual frames ---
3044
+ # Probe the first frame's channel count (CHW from the cache)
3045
+ try:
3046
+ y0_probe = _load_frame_chw(0) # CHW float32
3047
+ C_ref = 1 if y0_probe.ndim == 2 else int(y0_probe.shape[0])
3048
+ except Exception:
3049
+ # Fallback: infer from seed if probing fails
3050
+ C_ref = 1 if x.ndim == 2 else int(x.shape[0])
3051
+
3052
+ # If median seed came back mono but frames are RGB, broadcast it
3053
+ if x.ndim == 2 and C_ref == 3:
3054
+ x = np.stack([x] * 3, axis=0)
3055
+ elif x.ndim == 3 and x.shape[0] == 1 and C_ref == 3:
3056
+ x = np.repeat(x, 3, axis=0)
3057
+ # If seed is RGB but frames are mono (rare), collapse to luma
3058
+ elif x.ndim == 3 and x.shape[0] == 3 and C_ref == 1:
3059
+ # ITU-R BT.709 luma (safe in float)
3060
+ x = (0.2126 * x[0] + 0.7152 * x[1] + 0.0722 * x[2]).astype(np.float32)
3061
+
3062
+ # Ensure CHW shape for the rest of the pipeline
3063
+ if x.ndim == 2:
3064
+ x = x[None, ...]
3065
+ elif x.ndim == 3 and x.shape[0] not in (1, 3) and x.shape[-1] in (1, 3):
3066
+ x = np.moveaxis(x, -1, 0)
3067
+
3068
+ # Set the expected channel count from the FRAMES, not from the seed
3069
+ C_EXPECTED = int(C_ref)
3070
+ _, Hs, Ws = x.shape
3071
+
3072
+
3073
+
3074
+ # ---- choose default batch size ----
3075
+ if batch_frames is None:
3076
+ px = Hs * Ws
3077
+ if px >= 16_000_000: auto_B = 2
3078
+ elif px >= 8_000_000: auto_B = 4
3079
+ else: auto_B = 8
3080
+ else:
3081
+ auto_B = int(max(1, batch_frames))
3082
+
3083
+ # --- LOW-MEM PATCH: clamp batch size hard ---
3084
+ if low_mem:
3085
+ auto_B = max(1, min(auto_B, 2))
3086
+
3087
+ # ---- background/MAD telemetry (first frame) ----
3088
+ status_cb("MFDeconv: Calculating Backgrounds and MADs…")
3089
+ _process_gui_events_safely()
3090
+ try:
3091
+ y0 = data[0]; y0l = y0 if y0.ndim == 2 else y0[0]
3092
+ med = float(np.median(y0l)); mad = float(np.median(np.abs(y0l - med))) + 1e-6
3093
+ bg_est = 1.4826 * mad
3094
+ except Exception:
3095
+ bg_est = 0.0
3096
+ status_cb(f"MFDeconv: color_mode={color_mode}, huber_delta={huber_delta} (bg RMS~{bg_est:.3g})")
3097
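The 1.4826 factor used here (and in the variance estimates further down) converts a median absolute deviation into a Gaussian-equivalent standard deviation, since 1 / Phi^-1(0.75) ≈ 1.4826; a quick numerical check:

import numpy as np

x = np.random.normal(0.0, 2.0, size=200_000).astype(np.float32)
med = np.median(x)
mad = np.median(np.abs(x - med))
sigma_est = 1.4826 * mad          # ≈ 2.0, the true standard deviation of the noise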
+
3098
+ # ---- mask/variance accessors ----
3099
+ def _mask_for(i, like_img):
3100
+ src = (masks if masks is not None else masks_auto)
3101
+ if src is None: # no masks at all
3102
+ return np.ones((like_img.shape[-2], like_img.shape[-1]), dtype=np.float32)
3103
+ m = src[i]
3104
+ if m is None:
3105
+ return np.ones((like_img.shape[-2], like_img.shape[-1]), dtype=np.float32)
3106
+ m = np.asarray(m, dtype=np.float32)
3107
+ if m.ndim == 3: m = m[0]
3108
+ return _center_crop(m, like_img.shape[-2], like_img.shape[-1]).astype(np.float32, copy=False)
3109
+
3110
+ def _var_for(i, like_img):
3111
+ src = (variances if variances is not None else vars_auto)
3112
+ if src is None: return None
3113
+ v = src[i]
3114
+ if v is None: return None
3115
+ v = np.asarray(v, dtype=np.float32)
3116
+ if v.ndim == 3: v = v[0]
3117
+ v = _center_crop(v, like_img.shape[-2], like_img.shape[-1])
3118
+ return np.clip(v, 1e-8, None).astype(np.float32, copy=False)
3119
+
3120
+ # ---- NumPy path conv helper (keep as-is) ----
3121
+ def _conv_np_same(a, k, out=None):
3122
+ y = _conv_same_np_spatial(a, k, out)
3123
+ if y is not None:
3124
+ return y
3125
+ # No OpenCV → always use the ifftshifted FFT path
3126
+ import numpy as _np
3127
+ import numpy.fft as _fft
3128
+ H, W = a.shape[-2:]
3129
+ kh, kw = k.shape
3130
+ fftH, fftW = _fftshape_same(H, W, kh, kw)
3131
+ Kf = _fft.rfftn(_np.fft.ifftshift(k), s=(fftH, fftW))
3132
+ if out is None:
3133
+ out = _np.empty_like(a, dtype=_np.float32)
3134
+ return _fft_conv_same_np(a, Kf, kh, kw, fftH, fftW, out)
3135
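For reference, a minimal zero-padded FFT 'same' convolution looks like the sketch below (independent of the helpers above; the package's _fft_conv_same_np may size its FFTs and handle boundaries differently):

import numpy as np

def fft_conv_same(img: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Zero-padded linear convolution cropped back to img.shape (odd-sized kernel assumed)."""
    H, W = img.shape
    kh, kw = k.shape
    fH, fW = H + kh - 1, W + kw - 1
    out = np.fft.irfft2(np.fft.rfft2(img, s=(fH, fW)) * np.fft.rfft2(k, s=(fH, fW)), s=(fH, fW))
    return out[kh // 2:kh // 2 + H, kw // 2:kw // 2 + W].astype(np.float32)

img = np.random.rand(32, 32).astype(np.float32)
psf = np.ones((5, 5), np.float32) / 25.0
blurred = fft_conv_same(img, psf)          # same 32x32 shape as the input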
+
3136
+ # ---- allocate scratch & prepare PSF tensors if torch ----
3137
+ relax = 0.7
3138
+ use_torch = bool(TORCH_OK)
3139
+ cm = _safe_inference_context() if use_torch else NO_GRAD
3140
+ rho_is_l2 = (str(rho).lower() == "l2")
3141
+ local_delta = 0.0 if rho_is_l2 else huber_delta
3142
+
3143
+ if use_torch:
3144
+ F = torch.nn.functional
3145
+ device = _torch_device()
3146
+
3147
+ # Force FP32 tensors
3148
+ x_t = _to_t(_contig(x)).to(torch.float32)
3149
+ num = torch.zeros_like(x_t, dtype=torch.float32)
3150
+ den = torch.zeros_like(x_t, dtype=torch.float32)
3151
+
3152
+ # channels-last preference kept (does not change dtype)
3153
+ if use_channels_last is None:
3154
+ use_channels_last = bool(cuda_ok) # <- never enable on MPS
3155
+ if mps_ok:
3156
+ use_channels_last = False # <- force NCHW on MPS
3157
+
3158
+ # PSF tensors strictly FP32
3159
+ psf_t = [_to_t(_contig(k))[None, None] for k in psfs] # (1,1,kh,kw)
3160
+ psfT_t = [_to_t(_contig(kT))[None, None] for kT in flip_psf]
3161
+ use_spatial = _torch_should_use_spatial(psf_t[0].shape[-1])
3162
+
3163
+ # No mixed precision, no autocast
3164
+ use_amp = False
3165
+ amp_cm = None
3166
+ amp_kwargs = {}
3167
+
3168
+ # FFT gate (worth it for large kernels / SR); still FP32
3169
+ # Decide FFT vs spatial FIRST, based on native kernel sizes (no global padding!)
3170
+ K_native_max = max(int(k.shape[0]) for k in psfs)
3171
+ use_fft = False
3172
+ if device.type == "cuda" and ((K_native_max >= int(fft_kernel_threshold)) or (r > 1 and K_native_max >= max(21, int(fft_kernel_threshold) - 4))):
3173
+ use_fft = True
3174
+ # Force NCHW for FFT branch to keep channel slices contiguous/sane
3175
+ use_channels_last = False
3176
+ # Precompute FFT packs…
3177
+ psf_fft, psfT_fft = _precompute_torch_psf_ffts(
3178
+ psfs, flip_psf, Hs, Ws, device=x_t.device, dtype=torch.float32
3179
+ )
3180
+ else:
3181
+ psf_fft = psfT_fft = None
3182
+ Kmax = K_native_max
3183
+ if (Kmax % 2) == 0:
3184
+ Kmax += 1
3185
+ if any(int(k.shape[0]) != Kmax for k in psfs):
3186
+ status_cb(f"MFDeconv: normalizing PSF sizes → {Kmax}×{Kmax}")
3187
+ psfs = [_pad_kernel_to(k, Kmax) for k in psfs]
3188
+ flip_psf = [_flip_kernel(k) for k in psfs]
3189
+
3190
+ # (Re)build spatial kernels strictly as contiguous FP32 tensors
3191
+ psf_t = [_to_t(_contig(k))[None, None].to(torch.float32).contiguous() for k in psfs]
3192
+ psfT_t = [_to_t(_contig(kT))[None, None].to(torch.float32).contiguous() for kT in flip_psf]
3193
+ else:
3194
+ x_t = _contig(x).astype(np.float32, copy=False)
3195
+ num = np.zeros_like(x_t, dtype=np.float32)
3196
+ den = np.zeros_like(x_t, dtype=np.float32)
3197
+ use_amp = False
3198
+ use_fft = False
3199
+
3200
+ # ---- batched torch helper (grouped depthwise per-sample) ----
3201
+ if use_torch:
3202
+             # batched per-sample convolution helper: one PSF per frame, shared across that frame's channels
3203
+ def _grouped_conv_same_torch_per_sample(x_bc_hw, w_b1kk, B, C):
3204
+ """
3205
+ x_bc_hw : (B,C,H,W), torch.float32 on device
3206
+ w_b1kk : (B,1,kh,kw), torch.float32 on device
3207
+ Returns (B,C,H,W) contiguous (NCHW).
3208
+ """
3209
+ F = torch.nn.functional
3210
+
3211
+ # Force standard NCHW contiguous tensors
3212
+ x_bc_hw = x_bc_hw.to(memory_format=torch.contiguous_format).contiguous()
3213
+ w_b1kk = w_b1kk.to(memory_format=torch.contiguous_format).contiguous()
3214
+
3215
+ kh, kw = int(w_b1kk.shape[-2]), int(w_b1kk.shape[-1])
3216
+ pad = (kw // 2, kw - kw // 2 - 1, kh // 2, kh - kh // 2 - 1)
3217
+
3218
+ if x_bc_hw.device.type == "mps":
3219
+ # Safe, slower path: convolve each channel separately, no groups
3220
+ ys = []
3221
+ for j in range(B): # per sample
3222
+ xj = x_bc_hw[j:j+1] # (1,C,H,W)
3223
+ # reflect pad once per sample
3224
+ xj = F.pad(xj, pad, mode="reflect")
3225
+ cj_out = []
3226
+ # one shared kernel per sample j: (1,1,kh,kw)
3227
+ kj = w_b1kk[j:j+1] # keep shape (1,1,kh,kw)
3228
+ for c in range(C):
3229
+ # slice that channel as its own (1,1,H,W) tensor
3230
+ xjc = xj[:, c:c+1, ...]
3231
+ yjc = F.conv2d(xjc, kj, padding=0, groups=1) # no groups
3232
+ cj_out.append(yjc)
3233
+ ys.append(torch.cat(cj_out, dim=1)) # (1,C,H,W)
3234
+ return torch.stack([y[0] for y in ys], 0).contiguous()
3235
+
3236
+
3237
+ # ---- FAST PATH (CUDA/CPU): single grouped conv with G=B*C ----
3238
+ G = int(B * C)
3239
+ x_1ghw = x_bc_hw.reshape(1, G, x_bc_hw.shape[-2], x_bc_hw.shape[-1])
3240
+ x_1ghw = F.pad(x_1ghw, pad, mode="reflect")
3241
+ w_g1kk = w_b1kk.repeat_interleave(C, dim=0) # (G,1,kh,kw)
3242
+ y_1ghw = F.conv2d(x_1ghw, w_g1kk, padding=0, groups=G)
3243
+ return y_1ghw.reshape(B, C, y_1ghw.shape[-2], y_1ghw.shape[-1]).contiguous()
3244
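The fast path above folds one kernel per frame into a single grouped convolution by treating every (frame, channel) pair as its own group. A small equivalence check on random tensors (plain PyTorch, zero padding for brevity instead of the reflect padding used above):

import torch
import torch.nn.functional as F

B, C, H, W, k = 2, 3, 16, 16, 5
x = torch.rand(B, C, H, W)
w = torch.rand(B, 1, k, k)                        # one kernel per frame

# grouped form: every (frame, channel) pair is its own group
y_grouped = F.conv2d(
    x.reshape(1, B * C, H, W),
    w.repeat_interleave(C, dim=0),                # (B*C, 1, k, k)
    padding=k // 2, groups=B * C,
).reshape(B, C, H, W)

# reference: loop over frames, sharing each frame's kernel across its channels
y_ref = torch.stack([
    F.conv2d(x[b:b + 1], w[b:b + 1].expand(C, 1, k, k).contiguous(), padding=k // 2, groups=C)[0]
    for b in range(B)
])
assert torch.allclose(y_grouped, y_ref, atol=1e-5)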
+
3245
+
3246
+ def _downsample_avg_bt_t(x, r_):
3247
+ if r_ <= 1:
3248
+ return x
3249
+ B, C, H, W = x.shape
3250
+ Hr, Wr = (H // r_) * r_, (W // r_) * r_
3251
+ if Hr == 0 or Wr == 0:
3252
+ return x
3253
+ return x[:, :, :Hr, :Wr].reshape(B, C, Hr // r_, r_, Wr // r_, r_).mean(dim=(3, 5))
3254
+
3255
+ def _upsample_sum_bt_t(x, r_):
3256
+ if r_ <= 1:
3257
+ return x
3258
+ return x.repeat_interleave(r_, dim=-2).repeat_interleave(r_, dim=-1)
3259
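These two helpers form the (scaled) adjoint pair the super-resolution model relies on: block-averaging down by r and nearest-neighbour replication back up, with the 1/r² compensation applied where the seed is lifted. A toy check of that relationship:

import torch

def down_avg(x, r):   # block mean over r x r cells, mirroring _downsample_avg_bt_t on exact multiples
    B, C, H, W = x.shape
    return x.reshape(B, C, H // r, r, W // r, r).mean(dim=(3, 5))

def up_rep(y, r):     # replication upsample, mirroring _upsample_sum_bt_t
    return y.repeat_interleave(r, dim=-2).repeat_interleave(r, dim=-1)

r = 2
x = torch.rand(1, 1, 8, 8)
y = torch.rand(1, 1, 4, 4)
lhs = (down_avg(x, r) * y).sum()            # <D x, y>
rhs = (x * up_rep(y, r)).sum() / (r * r)    # <x, U y> / r^2, i.e. U = r^2 * D^T
assert torch.allclose(lhs, rhs, atol=1e-5)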
+
3260
+ def _make_pinned_batch(idx, C_expected, to_device_dtype):
3261
+ y_list, m_list, v_list = [], [], []
3262
+ for fi in idx:
3263
+ y_chw = _load_frame_chw(fi) # CHW float32 from cache
3264
+ if y_chw.ndim != 3:
3265
+ raise RuntimeError(f"Frame {fi}: expected CHW, got {tuple(y_chw.shape)}")
3266
+ C_here = int(y_chw.shape[0])
3267
+ if C_expected is not None and C_here != C_expected:
3268
+ raise RuntimeError(f"Mixed channel counts: expected C={C_expected}, got C={C_here} (frame {fi})")
3269
+
3270
+ m2d = _mask_for(fi, y_chw)
3271
+ v2d = _var_for(fi, y_chw)
3272
+ y_list.append(y_chw); m_list.append(m2d); v_list.append(v2d)
3273
+
3274
+ # CPU (NCHW) tensors
3275
+ y_cpu = torch.from_numpy(np.stack(y_list, 0)).to(torch.float32).contiguous()
3276
+ m_cpu = torch.from_numpy(np.stack(m_list, 0)).to(torch.float32).contiguous()
3277
+ have_v = all(v is not None for v in v_list)
3278
+ vb_cpu = None if not have_v else torch.from_numpy(np.stack(v_list, 0)).to(torch.float32).contiguous()
3279
+
3280
+ # optional pin for faster H2D on CUDA
3281
+ if hasattr(torch, "cuda") and torch.cuda.is_available():
3282
+ y_cpu = y_cpu.pin_memory(); m_cpu = m_cpu.pin_memory()
3283
+ if vb_cpu is not None: vb_cpu = vb_cpu.pin_memory()
3284
+
3285
+ # move to device in FP32, keep NCHW contiguous format
3286
+ yb = y_cpu.to(x_t.device, dtype=torch.float32, non_blocking=True).contiguous()
3287
+ mb = m_cpu.to(x_t.device, dtype=torch.float32, non_blocking=True).contiguous()
3288
+ vb = None if vb_cpu is None else vb_cpu.to(x_t.device, dtype=torch.float32, non_blocking=True).contiguous()
3289
+
3290
+ return yb, mb, vb
3291
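A minimal sketch of the pinned-host to asynchronous-device copy pattern used above (guarded so it degrades to a plain CPU tensor when CUDA is absent):

import numpy as np
import torch

batch = np.random.rand(4, 3, 256, 256).astype(np.float32)   # a (B, C, H, W) host batch
cpu_t = torch.from_numpy(batch).contiguous()

if torch.cuda.is_available():
    cpu_t = cpu_t.pin_memory()                     # page-locked host memory
    dev_t = cpu_t.to("cuda", non_blocking=True)    # async H2D copy that can overlap compute
else:
    dev_t = cpu_t                                  # no transfer needed on CPU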
+
3292
+
3293
+
3294
+ # ---- intermediates folder ----
3295
+ iter_dir = None
3296
+ hdr0_seed = None
3297
+ if save_intermediate:
3298
+ iter_dir = _iter_folder(out_path)
3299
+ status_cb(f"MFDeconv: Intermediate outputs → {iter_dir}")
3300
+ try:
3301
+ hdr0 = _safe_primary_header(paths[0])
3302
+ except Exception:
3303
+ hdr0 = fits.Header()
3304
+         hdr0_seed = hdr0  # reuse this header for the per-iteration intermediate saves below
+         _save_iter_image(x, hdr0, iter_dir, "seed", color_mode)
3305
+
3306
+ # ---- iterative loop ----
3307
+
3308
+
3309
+
3310
+ auto_delta_cache = None
3311
+ if use_torch and (huber_delta < 0) and (not rho_is_l2):
3312
+ auto_delta_cache = [None] * n_frames
3313
+ early = EarlyStopper(
3314
+ tol_upd_floor=2e-4,
3315
+ tol_rel_floor=5e-4,
3316
+ early_frac=0.40,
3317
+ ema_alpha=0.5,
3318
+ patience=2,
3319
+ min_iters=min_iters,
3320
+ status_cb=status_cb
3321
+ )
3322
+
3323
+ used_iters = 0
3324
+ early_stopped = False
3325
+ with cm():
3326
+ for it in range(1, max_iters + 1):
3327
+ # reset accumulators
3328
+ if use_torch:
3329
+ num.zero_(); den.zero_()
3330
+ else:
3331
+ num.fill(0.0); den.fill(0.0)
3332
+
3333
+ if use_torch:
3334
+ # ---- batched GPU path (hardened with busy→CPU fallback) ----
3335
+ retry_cpu = False # if True, we’ll redo this iteration’s accumulators on CPU
3336
+
3337
+ try:
3338
+ frame_idx = list(range(n_frames))
3339
+ B_cur = int(max(1, auto_B))
3340
+ ci = 0
3341
+ Cn = None
3342
+
3343
+ # AMP context helper
3344
+ def _maybe_amp():
3345
+ if use_amp and (amp_cm is not None):
3346
+ return amp_cm(**amp_kwargs)
3347
+ from contextlib import nullcontext
3348
+ return nullcontext()
3349
+
3350
+                         # prefetch the first batch
3351
+ if prefetch_batches:
3352
+ idx0 = frame_idx[ci:ci+B_cur]
3353
+ if idx0:
3354
+ Cn = C_EXPECTED
3355
+ yb_next, mb_next, vb_next = _make_pinned_batch(idx0, Cn, x_t.dtype)
3356
+ else:
3357
+ yb_next = mb_next = vb_next = None
3358
+
3359
+ while ci < n_frames:
3360
+ idx = frame_idx[ci:ci + B_cur]
3361
+ B = len(idx)
3362
+ try:
3363
+                                 # ---- gather / convolve / weight / backproject for this chunk ----
3364
+ if prefetch_batches and (yb_next is not None):
3365
+ yb, mb, vb = yb_next, mb_next, vb_next
3366
+ ci2 = ci + B_cur
3367
+ if ci2 < n_frames:
3368
+ idx2 = frame_idx[ci2:ci2+B_cur]
3369
+ Cn = Cn or (_as_chw(_sanitize_numeric(data[idx[0]])).shape[0])
3370
+ yb_next, mb_next, vb_next = _make_pinned_batch(idx2, Cn, torch.float32)
3371
+ else:
3372
+ yb_next = mb_next = vb_next = None
3373
+ else:
3374
+ Cn = C_EXPECTED
3375
+ yb, mb, vb = _make_pinned_batch(idx, Cn, torch.float32)
3376
+ if use_channels_last:
3377
+ yb = yb.contiguous(memory_format=torch.channels_last)
3378
+
3379
+ wk = torch.cat([psf_t[fi] for fi in idx], dim=0).to(memory_format=torch.contiguous_format).contiguous()
3380
+ wkT = torch.cat([psfT_t[fi] for fi in idx], dim=0).to(memory_format=torch.contiguous_format).contiguous()
3381
+
3382
+ # --- predict on SR grid ---
3383
+ x_bc_hw = x_t.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
3384
+ if use_channels_last:
3385
+ x_bc_hw = x_bc_hw.contiguous(memory_format=torch.channels_last)
3386
+
3387
+ if use_fft:
3388
+ pred_list = []
3389
+ chan_tmp = torch.empty_like(x_t[0], dtype=torch.float32) # (H,W) single-channel
3390
+ for j, fi in enumerate(idx):
3391
+ C_here = x_bc_hw.shape[1]
3392
+ pred_j = torch.empty_like(x_t, dtype=torch.float32) # (C,H,W)
3393
+ Kf_pack = psf_fft[fi]
3394
+ for c in range(C_here):
3395
+ _fft_conv_same_torch(x_bc_hw[j, c], Kf_pack, out_spatial=chan_tmp) # write into chan_tmp
3396
+ pred_j[c].copy_(chan_tmp) # copy once
3397
+ pred_list.append(pred_j)
3398
+ pred_super = torch.stack(pred_list, 0).contiguous()
3399
+ else:
3400
+ pred_super = _grouped_conv_same_torch_per_sample(
3401
+ x_bc_hw, wk, B, Cn
3402
+ ).contiguous()
3403
+
3404
+ pred_low = _downsample_avg_bt_t(pred_super, r) if r > 1 else pred_super
3405
+
3406
+                                 # --- robust weights ---
3407
+ rnat = yb - pred_low
3408
+
3409
+ # Build/estimate variance map on the native grid (per batch)
3410
+ if vb is None:
3411
+ # robust per-batch variance estimate
3412
+ med, mad = _robust_med_mad_t(rnat, max_elems_per_sample=2_000_000)
3413
+ vmap = (1.4826 * mad) ** 2
3414
+ # repeat across channels
3415
+ vmap = vmap.repeat(1, Cn, rnat.shape[-2], rnat.shape[-1]).contiguous()
3416
+ else:
3417
+ vmap = vb.unsqueeze(1).repeat(1, Cn, 1, 1).contiguous()
3418
+
3419
+ if str(rho).lower() == "l2":
3420
+ # L2 ⇒ psi/r = 1 (no robust clipping). Weighted LS with mask/variance.
3421
+ psi_over_r = torch.ones_like(rnat, dtype=torch.float32, device=x_t.device)
3422
+ else:
3423
+ # Huber ⇒ auto-delta or fixed delta
3424
+ if huber_delta < 0:
3425
+ # ensure we have auto deltas
3426
+ if (auto_delta_cache is None) or any(auto_delta_cache[fi] is None for fi in idx) or (it % 5 == 1):
3427
+ Btmp, C_here, H0, W0 = rnat.shape
3428
+ med, mad = _robust_med_mad_t(rnat, max_elems_per_sample=2_000_000)
3429
+ rms = 1.4826 * mad
3430
+ # store per-frame deltas
3431
+ if auto_delta_cache is None:
3432
+ # should not happen because we gated creation earlier, but be safe
3433
+ auto_delta_cache = [None] * n_frames
3434
+ for j, fi in enumerate(idx):
3435
+ auto_delta_cache[fi] = float((-huber_delta) * torch.clamp(rms[j, 0, 0, 0], min=1e-6).item())
3436
+ deltas = torch.tensor([auto_delta_cache[fi] for fi in idx],
3437
+ device=x_t.device, dtype=torch.float32).view(B, 1, 1, 1)
3438
+ else:
3439
+ deltas = torch.tensor(float(huber_delta), device=x_t.device,
3440
+ dtype=torch.float32).view(1, 1, 1, 1)
3441
+
3442
+ absr = rnat.abs()
3443
+ psi_over_r = torch.where(absr <= deltas, torch.ones_like(absr, dtype=torch.float32),
3444
+ deltas / (absr + EPS))
3445
+
3446
+ # compose weights with mask and variance
3447
+ if DEBUG_FLAT_WEIGHTS:
3448
+ wmap_low = mb.unsqueeze(1) # debug: mask-only weighting
3449
+ else:
3450
+ m1 = mb.unsqueeze(1).repeat(1, Cn, 1, 1).contiguous()
3451
+ wmap_low = (psi_over_r / (vmap + EPS)) * m1
3452
+ wmap_low = torch.nan_to_num(wmap_low, nan=0.0, posinf=0.0, neginf=0.0)
3453
+
3454
+                                 # --- adjoint + backproject ---
3455
+ if r > 1:
3456
+ up_y = _upsample_sum_bt_t(wmap_low * yb, r)
3457
+ up_pred = _upsample_sum_bt_t(wmap_low * pred_low, r)
3458
+ else:
3459
+ up_y, up_pred = wmap_low * yb, wmap_low * pred_low
3460
+
3461
+ if use_fft:
3462
+ back_num_list, back_den_list = [], []
3463
+ for j, fi in enumerate(idx):
3464
+ C_here = up_y.shape[1]
3465
+ bn_j = torch.empty_like(x_t, dtype=torch.float32)
3466
+ bd_j = torch.empty_like(x_t, dtype=torch.float32)
3467
+ KTf_pack = psfT_fft[fi]
3468
+ for c in range(C_here):
3469
+ _fft_conv_same_torch(up_y[j, c], KTf_pack, out_spatial=bn_j[c])
3470
+ _fft_conv_same_torch(up_pred[j, c], KTf_pack, out_spatial=bd_j[c])
3471
+ back_num_list.append(bn_j)
3472
+ back_den_list.append(bd_j)
3473
+ back_num = torch.stack(back_num_list, 0).sum(dim=0)
3474
+ back_den = torch.stack(back_den_list, 0).sum(dim=0)
3475
+ else:
3476
+ back_num = _grouped_conv_same_torch_per_sample(up_y, wkT, B, Cn).sum(dim=0)
3477
+ back_den = _grouped_conv_same_torch_per_sample(up_pred, wkT, B, Cn).sum(dim=0)
3478
+
3479
+ num += back_num
3480
+ den += back_den
3481
+ ci += B
3482
+
3483
+ if os.environ.get("MFDECONV_DEBUG_SYNC", "0") == "1":
3484
+ try:
3485
+ torch.cuda.synchronize()
3486
+ except Exception:
3487
+ pass
3488
+
3489
+ except RuntimeError as e:
3490
+ emsg = str(e).lower()
3491
+ # existing OOM backoff stays the same
3492
+ if ("out of memory" in emsg or "resource" in emsg or "alloc" in emsg) and B_cur > 1:
3493
+ B_cur = max(1, B_cur // 2)
3494
+ status_cb(f"GPU OOM: reducing batch_frames → {B_cur} and retrying this chunk.")
3495
+ if prefetch_batches:
3496
+ yb_next = mb_next = vb_next = None
3497
+ continue
3498
+ # NEW: busy/unavailable → trigger CPU retry for this iteration
3499
+ if _is_cuda_busy_error(e):
3500
+ status_cb("CUDA became unavailable mid-run → retrying this iteration on CPU and switching to CPU for the rest.")
3501
+ retry_cpu = True
3502
+ break # break the inner while; we'll recompute on CPU
3503
+ raise
3504
+
3505
+ _process_gui_events_safely()
3506
+
3507
+ except RuntimeError as e:
3508
+ # Catch outer-level CUDA busy/unavailable too (e.g., from syncs)
3509
+ if _is_cuda_busy_error(e):
3510
+ status_cb("CUDA error indicates device busy/unavailable → retrying this iteration on CPU and switching to CPU for the rest.")
3511
+ retry_cpu = True
3512
+ else:
3513
+ raise
3514
+
3515
+ if retry_cpu:
3516
+ # flip to CPU for this and subsequent iterations
3517
+ globals()["cuda_usable"] = False
3518
+ use_torch = False
3519
+ try:
3520
+ torch.cuda.empty_cache()
3521
+ except Exception:
3522
+ pass
3523
+
3524
+ try:
3525
+ import gc as _gc
3526
+ _gc.collect()
3527
+ except Exception:
3528
+ pass
3529
+
3530
+ # Rebuild accumulators on CPU for this iteration:
3531
+ # 1) convert x_t (seed/current estimate) to NumPy
3532
+ x_t = x_t.detach().cpu().numpy()
3533
+ # 2) reset accumulators as NumPy arrays
3534
+ num = np.zeros_like(x_t, dtype=np.float32)
3535
+ den = np.zeros_like(x_t, dtype=np.float32)
3536
+
3537
+                 # 3) recompute this iteration's accumulators on CPU (NumPy path)
3539
+ for fi in range(n_frames):
3540
+ # load one frame as CHW (float32, sanitized)
3541
+
3542
+ y_chw = _load_frame_chw(fi) # CHW float32 from cache
3543
+
3544
+ # forward predict (super grid) with this frame's PSF
3545
+ pred_super = _conv_np_same(x_t, psfs[fi])
3546
+
3547
+ # downsample to native if SR was on
3548
+ pred_low = _downsample_avg(pred_super, r) if r > 1 else pred_super
3549
+
3550
+ # per-frame mask/variance
3551
+ m2d = _mask_for(fi, y_chw) # 2D, [0..1]
3552
+ v2d = _var_for(fi, y_chw) # 2D or None
3553
+
3554
+ # robust weights per pixel/channel
3555
+ w = _weight_map(
3556
+ y=y_chw, pred=pred_low,
3557
+ huber_delta=local_delta, # 0.0 for L2, else huber_delta
3558
+ var_map=v2d, mask=m2d
3559
+ ).astype(np.float32, copy=False)
3560
+
3561
+ # adjoint (backproject) on super grid
3562
+ if r > 1:
3563
+ up_y = _upsample_sum(w * y_chw, r, target_hw=x_t.shape[-2:])
3564
+ up_pred = _upsample_sum(w * pred_low, r, target_hw=x_t.shape[-2:])
3565
+ else:
3566
+ up_y, up_pred = (w * y_chw), (w * pred_low)
3567
+
3568
+ num += _conv_np_same(up_y, flip_psf[fi])
3569
+ den += _conv_np_same(up_pred, flip_psf[fi])
3570
+
3571
+ # ensure strictly positive denominator
3572
+ den = np.clip(den, 1e-8, None).astype(np.float32, copy=False)
3573
+
3574
+ # switch everything to NumPy for the remainder of the run
3575
+ psf_fft = psfT_fft = None # (Torch packs no longer used)
3576
+
3577
+ # ---- multiplicative RL/MM step with clamping ----
3578
+ if use_torch:
3579
+ ratio = num / (den + EPS)
3580
+ neutral = (den.abs() < 1e-12) & (num.abs() < 1e-12)
3581
+ ratio = torch.where(neutral, torch.ones_like(ratio), ratio)
3582
+ upd = torch.clamp(ratio, 1.0 / kappa, kappa)
3583
+ x_next = torch.clamp(x_t * upd, min=0.0)
3584
+ # Robust scalars
3585
+ upd_med = torch.median(torch.abs(upd - 1))
3586
+ rel_change = (torch.median(torch.abs(x_next - x_t)) /
3587
+ (torch.median(torch.abs(x_t)) + 1e-8))
3588
+
3589
+ um = float(upd_med.detach().item())
3590
+ rc = float(rel_change.detach().item())
3591
+
3592
+ if early.step(it, max_iters, um, rc):
3593
+ x_t = x_next
3594
+ used_iters = it
3595
+ early_stopped = True
3596
+ status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
3597
+ _process_gui_events_safely()
3598
+ break
3599
+
3600
+ x_t = (1.0 - relax) * x_t + relax * x_next
3601
+ else:
3602
+ ratio = num / (den + EPS)
3603
+ neutral = (np.abs(den) < 1e-12) & (np.abs(num) < 1e-12)
3604
+ if np.any(neutral): ratio[neutral] = 1.0
3605
+ upd = np.clip(ratio, 1.0 / kappa, kappa)
3606
+ x_next = np.clip(x_t * upd, 0.0, None)
3607
+
3608
+ um = float(np.median(np.abs(upd - 1.0)))
3609
+ rc = float(np.median(np.abs(x_next - x_t)) / (np.median(np.abs(x_t)) + 1e-8))
3610
+
3611
+ if early.step(it, max_iters, um, rc):
3612
+ x_t = x_next
3613
+ used_iters = it
3614
+ early_stopped = True
3615
+ status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
3616
+ _process_gui_events_safely()
3617
+ break
3618
+
3619
+ x_t = (1.0 - relax) * x_t + relax * x_next
3620
+
3621
+ # --- LOW-MEM CLEANUP (per-iteration) ---
3622
+ if low_mem:
3623
+                     # Torch temporaries created in this iteration: rebind them to None so the
+                     # frame actually drops its references (deleting entries from locals() is a
+                     # no-op for function locals in CPython, so a dict-based delete frees nothing).
+                     pred_super = pred_low = wmap_low = None
+                     yb = mb = vb = wk = wkT = None
+                     back_num = back_den = None
+                     pred_list = back_num_list = back_den_list = None
+                     x_bc_hw = up_y = up_pred = None
+                     psi_over_r = vmap = rnat = deltas = chan_tmp = None
3636
+
3637
+ # Proactively release CUDA cache every other iter
3638
+ if use_torch:
3639
+ try:
3640
+ dev = _torch_device()
3641
+ if dev.type == "cuda" and (it % 2) == 0:
3642
+ import torch as _t
3643
+ _t.cuda.empty_cache()
3644
+ except Exception:
3645
+ pass
3646
+
3647
+ # Encourage Python to return big NumPy buffers to the OS sooner
3648
+ import gc as _gc
3649
+ _gc.collect()
3650
+
3651
+
3652
+ # ---- save intermediates ----
3653
+ if save_intermediate and (it % int(max(1, save_every)) == 0):
3654
+ try:
3655
+ x_np = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
3656
+ _save_iter_image(x_np, hdr0_seed, iter_dir, f"iter_{it:03d}", color_mode)
3657
+ except Exception as _e:
3658
+ status_cb(f"Intermediate save failed at iter {it}: {_e}")
3659
+
3660
+ frac = 0.25 + 0.70 * (it / float(max_iters))
3661
+ _emit_pct(frac, f"Iteration {it}/{max_iters}")
3662
+ status_cb(f"Iter {it}/{max_iters}")
3663
+ _process_gui_events_safely()
3664
+
3665
+ if not early_stopped:
3666
+ used_iters = max_iters
3667
+
3668
+ # ---- save result ----
3669
+ _emit_pct(0.97, "saving")
3670
+ x_final = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
3671
+ if x_final.ndim == 3:
3672
+ if x_final.shape[0] not in (1, 3) and x_final.shape[-1] in (1, 3):
3673
+ x_final = np.moveaxis(x_final, -1, 0)
3674
+ if x_final.shape[0] == 1:
3675
+ x_final = x_final[0]
3676
+
3677
+ try:
3678
+ hdr0 = _safe_primary_header(paths[0])
3679
+ except Exception:
3680
+ hdr0 = fits.Header()
3681
+
3682
+ hdr0['MFDECONV'] = (True, 'Seti Astro multi-frame deconvolution')
3683
+ hdr0['MF_COLOR'] = (str(color_mode), 'Color mode used')
3684
+ hdr0['MF_RHO'] = (str(rho), 'Loss: huber|l2')
3685
+ hdr0['MF_HDEL'] = (float(huber_delta), 'Huber delta (>0 abs, <0 autoxRMS)')
3686
+ hdr0['MF_MASK'] = (bool(use_star_masks), 'Used auto star masks')
3687
+ hdr0['MF_VAR'] = (bool(use_variance_maps), 'Used auto variance maps')
3688
+     # r already holds the coerced SR factor (via _coerce_sr_factor); re-parsing super_res_factor here could fail on inputs like "2x"
3689
+ hdr0['MF_SR'] = (int(r), 'Super-resolution factor (1 := native)')
3690
+ if r > 1:
3691
+ hdr0['MF_SRSIG'] = (float(sr_sigma), 'Gaussian sigma for SR PSF fit (native px)')
3692
+ hdr0['MF_SRIT'] = (int(sr_psf_opt_iters), 'SR-PSF solver iters')
3693
+ hdr0['MF_ITMAX'] = (int(max_iters), 'Requested max iterations')
3694
+ hdr0['MF_ITERS'] = (int(used_iters), 'Actual iterations run')
3695
+ hdr0['MF_ESTOP'] = (bool(early_stopped), 'Early stop triggered')
3696
+
3697
+ if isinstance(x_final, np.ndarray):
3698
+ if x_final.ndim == 2:
3699
+ hdr0['MF_SHAPE'] = (f"{x_final.shape[0]}x{x_final.shape[1]}", 'Saved as 2D image (HxW)')
3700
+ elif x_final.ndim == 3:
3701
+ C, H, W = x_final.shape
3702
+ hdr0['MF_SHAPE'] = (f"{C}x{H}x{W}", 'Saved as 3D cube (CxHxW)')
3703
+
3704
+ save_path = _sr_out_path(out_path, super_res_factor)
3705
+ safe_out_path = _nonclobber_path(str(save_path))
3706
+ if safe_out_path != str(save_path):
3707
+ status_cb(f"Output exists — saving as: {safe_out_path}")
3708
+ fits.PrimaryHDU(data=x_final, header=hdr0).writeto(safe_out_path, overwrite=False)
3709
+
3710
+ status_cb(f"✅ MFDeconv saved: {safe_out_path} (iters used: {used_iters}{', early stop' if early_stopped else ''})")
3711
+ _emit_pct(1.00, "done")
3712
+ _process_gui_events_safely()
3713
+
3714
+ try:
3715
+ if use_torch:
3716
+ try: del num, den
3717
+ except Exception as e:
3718
+ import logging
3719
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3720
+ try: del psf_t, psfT_t
3721
+ except Exception as e:
3722
+ import logging
3723
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3724
+ _free_torch_memory()
3725
+ except Exception:
3726
+ pass
3727
+
3728
+ try:
3729
+ _clear_all_caches()
3730
+ except Exception:
3731
+ pass
3732
+
3733
+ return safe_out_path
3734
+
3735
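A minimal way to drive multiframe_deconv directly (the file names below are placeholders; the keyword values are just simple choices over the documented defaults):

aligned = ["frame_000_aligned.fit", "frame_001_aligned.fit", "frame_002_aligned.fit"]

out = multiframe_deconv(
    aligned,
    "master_mfdeconv.fit",
    iters=20,
    kappa=2.0,
    color_mode="luma",
    seed_mode="robust",
    rho="huber",
    huber_delta=-3.0,   # negative means auto delta = 3 x the robust RMS (see the MF_HDEL header comment)
    status_cb=print,
)
print("saved:", out)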
+ # -----------------------------
3736
+ # Worker
3737
+ # -----------------------------
3738
+
3739
+ class MultiFrameDeconvWorker(QObject):
3740
+ progress = pyqtSignal(str)
3741
+ finished = pyqtSignal(bool, str, str) # success, message, out_path
3742
+
3743
+ def __init__(self, parent, aligned_paths, output_path, iters, kappa, color_mode,
3744
+ huber_delta, min_iters, use_star_masks=False, use_variance_maps=False, rho="huber",
3745
+ star_mask_cfg: dict | None = None, varmap_cfg: dict | None = None,
3746
+ save_intermediate: bool = False,
3747
+ seed_mode: str = "robust",
3748
+ # NEW SR params
3749
+ super_res_factor: int = 1,
3750
+ sr_sigma: float = 1.1,
3751
+ sr_psf_opt_iters: int = 250,
3752
+ sr_psf_opt_lr: float = 0.1,
3753
+ star_mask_ref_path: str | None = None):
3754
+
3755
+ super().__init__(parent)
3756
+ self.aligned_paths = aligned_paths
3757
+ self.output_path = output_path
3758
+ self.iters = iters
3759
+ self.kappa = kappa
3760
+ self.color_mode = color_mode
3761
+ self.huber_delta = huber_delta
3762
+ self.min_iters = min_iters # NEW
3763
+ self.star_mask_cfg = star_mask_cfg or {}
3764
+ self.varmap_cfg = varmap_cfg or {}
3765
+ self.use_star_masks = use_star_masks
3766
+ self.use_variance_maps = use_variance_maps
3767
+ self.rho = rho
3768
+ self.save_intermediate = save_intermediate
3769
+ self.super_res_factor = int(super_res_factor)
3770
+ self.sr_sigma = float(sr_sigma)
3771
+ self.sr_psf_opt_iters = int(sr_psf_opt_iters)
3772
+ self.sr_psf_opt_lr = float(sr_psf_opt_lr)
3773
+ self.star_mask_ref_path = star_mask_ref_path
3774
+ self.seed_mode = seed_mode
3775
+
3776
+
3777
+ def _log(self, s): self.progress.emit(s)
3778
+
3779
+ def run(self):
3780
+ try:
3781
+ out = multiframe_deconv(
3782
+ self.aligned_paths,
3783
+ self.output_path,
3784
+ iters=self.iters,
3785
+ kappa=self.kappa,
3786
+ color_mode=self.color_mode,
3787
+ seed_mode=self.seed_mode,
3788
+ huber_delta=self.huber_delta,
3789
+ use_star_masks=self.use_star_masks,
3790
+ use_variance_maps=self.use_variance_maps,
3791
+ rho=self.rho,
3792
+ min_iters=self.min_iters,
3793
+ status_cb=self._log,
3794
+ star_mask_cfg=self.star_mask_cfg,
3795
+ varmap_cfg=self.varmap_cfg,
3796
+ save_intermediate=self.save_intermediate,
3797
+ super_res_factor=self.super_res_factor,
3798
+ sr_sigma=self.sr_sigma,
3799
+ sr_psf_opt_iters=self.sr_psf_opt_iters,
3800
+ sr_psf_opt_lr=self.sr_psf_opt_lr,
3801
+ star_mask_ref_path=self.star_mask_ref_path,
3802
+ )
3803
+ self.finished.emit(True, "MF deconvolution complete.", out)
3804
+ _process_gui_events_safely()
3805
+ except Exception as e:
3806
+ self.finished.emit(False, f"MF deconvolution failed: {e}", "")
3807
+ finally:
3808
+ # Hard cleanup: drop references + free GPU memory
3809
+ try:
3810
+ # Drop big Python refs that might keep tensors alive indirectly
3811
+ self.aligned_paths = []
3812
+ self.star_mask_cfg = {}
3813
+ self.varmap_cfg = {}
3814
+ except Exception:
3815
+ pass
3816
+ try:
3817
+                 _free_torch_memory()  # drops cached tensors, runs gc.collect(), etc.
3818
+ except Exception:
3819
+ pass
3820
+ try:
3821
+ import torch as _t
3822
+ if hasattr(_t, "cuda") and _t.cuda.is_available():
3823
+ _t.cuda.synchronize()
3824
+ _t.cuda.empty_cache()
3825
+ if hasattr(_t, "mps") and getattr(_t.backends, "mps", None) and _t.backends.mps.is_available():
3826
+ # PyTorch 2.x has this
3827
+ if hasattr(_t.mps, "empty_cache"):
3828
+ _t.mps.empty_cache()
3829
+ # DirectML usually frees on GC; nothing special to call.
3830
+ except Exception:
3831
+ pass
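A sketch of how such a worker is typically wired to a background thread from the GUI side (illustrative only: it assumes the PyQt6 QThread pattern and placeholder paths, and the package's actual dialog code is not shown in this diff):

from PyQt6.QtCore import QThread

aligned_paths = ["frame_000_aligned.fit", "frame_001_aligned.fit"]   # placeholders
output_path = "master_mfdeconv.fit"

thread = QThread()
worker = MultiFrameDeconvWorker(
    None, aligned_paths, output_path,
    iters=20, kappa=2.0, color_mode="luma",
    huber_delta=-3.0, min_iters=3,
)
worker.moveToThread(thread)

thread.started.connect(worker.run)
worker.progress.connect(print)                              # status lines from status_cb
worker.finished.connect(lambda ok, msg, out: print(msg, out))
worker.finished.connect(thread.quit)                        # stop the thread when the job ends
thread.finished.connect(thread.deleteLater)
thread.start()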