setiastrosuitepro-1.6.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of setiastrosuitepro might be problematic.

Files changed (394)
  1. setiastro/__init__.py +2 -0
  2. setiastro/data/SASP_data.fits +0 -0
  3. setiastro/data/catalogs/List_of_Galaxies_with_Distances_Gly.csv +488 -0
  4. setiastro/data/catalogs/astrobin_filters.csv +890 -0
  5. setiastro/data/catalogs/astrobin_filters_page1_local.csv +51 -0
  6. setiastro/data/catalogs/cali2.csv +63 -0
  7. setiastro/data/catalogs/cali2color.csv +65 -0
  8. setiastro/data/catalogs/celestial_catalog - original.csv +16471 -0
  9. setiastro/data/catalogs/celestial_catalog.csv +24031 -0
  10. setiastro/data/catalogs/detected_stars.csv +24784 -0
  11. setiastro/data/catalogs/fits_header_data.csv +46 -0
  12. setiastro/data/catalogs/test.csv +8 -0
  13. setiastro/data/catalogs/updated_celestial_catalog.csv +16471 -0
  14. setiastro/images/Astro_Spikes.png +0 -0
  15. setiastro/images/Background_startup.jpg +0 -0
  16. setiastro/images/HRDiagram.png +0 -0
  17. setiastro/images/LExtract.png +0 -0
  18. setiastro/images/LInsert.png +0 -0
  19. setiastro/images/Oxygenation-atm-2.svg.png +0 -0
  20. setiastro/images/RGB080604.png +0 -0
  21. setiastro/images/abeicon.png +0 -0
  22. setiastro/images/aberration.png +0 -0
  23. setiastro/images/acv_icon.png +0 -0
  24. setiastro/images/andromedatry.png +0 -0
  25. setiastro/images/andromedatry_satellited.png +0 -0
  26. setiastro/images/annotated.png +0 -0
  27. setiastro/images/aperture.png +0 -0
  28. setiastro/images/astrosuite.ico +0 -0
  29. setiastro/images/astrosuite.png +0 -0
  30. setiastro/images/astrosuitepro.icns +0 -0
  31. setiastro/images/astrosuitepro.ico +0 -0
  32. setiastro/images/astrosuitepro.png +0 -0
  33. setiastro/images/background.png +0 -0
  34. setiastro/images/background2.png +0 -0
  35. setiastro/images/benchmark.png +0 -0
  36. setiastro/images/big_moon_stabilizer_timeline.png +0 -0
  37. setiastro/images/big_moon_stabilizer_timeline_clean.png +0 -0
  38. setiastro/images/blaster.png +0 -0
  39. setiastro/images/blink.png +0 -0
  40. setiastro/images/clahe.png +0 -0
  41. setiastro/images/collage.png +0 -0
  42. setiastro/images/colorwheel.png +0 -0
  43. setiastro/images/contsub.png +0 -0
  44. setiastro/images/convo.png +0 -0
  45. setiastro/images/copyslot.png +0 -0
  46. setiastro/images/cosmic.png +0 -0
  47. setiastro/images/cosmicsat.png +0 -0
  48. setiastro/images/crop1.png +0 -0
  49. setiastro/images/cropicon.png +0 -0
  50. setiastro/images/curves.png +0 -0
  51. setiastro/images/cvs.png +0 -0
  52. setiastro/images/debayer.png +0 -0
  53. setiastro/images/denoise_cnn_custom.png +0 -0
  54. setiastro/images/denoise_cnn_graph.png +0 -0
  55. setiastro/images/disk.png +0 -0
  56. setiastro/images/dse.png +0 -0
  57. setiastro/images/exoicon.png +0 -0
  58. setiastro/images/eye.png +0 -0
  59. setiastro/images/first_quarter.png +0 -0
  60. setiastro/images/fliphorizontal.png +0 -0
  61. setiastro/images/flipvertical.png +0 -0
  62. setiastro/images/font.png +0 -0
  63. setiastro/images/freqsep.png +0 -0
  64. setiastro/images/full_moon.png +0 -0
  65. setiastro/images/functionbundle.png +0 -0
  66. setiastro/images/graxpert.png +0 -0
  67. setiastro/images/green.png +0 -0
  68. setiastro/images/gridicon.png +0 -0
  69. setiastro/images/halo.png +0 -0
  70. setiastro/images/hdr.png +0 -0
  71. setiastro/images/histogram.png +0 -0
  72. setiastro/images/hubble.png +0 -0
  73. setiastro/images/imagecombine.png +0 -0
  74. setiastro/images/invert.png +0 -0
  75. setiastro/images/isophote.png +0 -0
  76. setiastro/images/isophote_demo_figure.png +0 -0
  77. setiastro/images/isophote_demo_image.png +0 -0
  78. setiastro/images/isophote_demo_model.png +0 -0
  79. setiastro/images/isophote_demo_residual.png +0 -0
  80. setiastro/images/jwstpupil.png +0 -0
  81. setiastro/images/last_quarter.png +0 -0
  82. setiastro/images/linearfit.png +0 -0
  83. setiastro/images/livestacking.png +0 -0
  84. setiastro/images/mask.png +0 -0
  85. setiastro/images/maskapply.png +0 -0
  86. setiastro/images/maskcreate.png +0 -0
  87. setiastro/images/maskremove.png +0 -0
  88. setiastro/images/morpho.png +0 -0
  89. setiastro/images/mosaic.png +0 -0
  90. setiastro/images/multiscale_decomp.png +0 -0
  91. setiastro/images/nbtorgb.png +0 -0
  92. setiastro/images/neutral.png +0 -0
  93. setiastro/images/new_moon.png +0 -0
  94. setiastro/images/nuke.png +0 -0
  95. setiastro/images/openfile.png +0 -0
  96. setiastro/images/pedestal.png +0 -0
  97. setiastro/images/pen.png +0 -0
  98. setiastro/images/pixelmath.png +0 -0
  99. setiastro/images/platesolve.png +0 -0
  100. setiastro/images/ppp.png +0 -0
  101. setiastro/images/pro.png +0 -0
  102. setiastro/images/project.png +0 -0
  103. setiastro/images/psf.png +0 -0
  104. setiastro/images/redo.png +0 -0
  105. setiastro/images/redoicon.png +0 -0
  106. setiastro/images/rescale.png +0 -0
  107. setiastro/images/rgbalign.png +0 -0
  108. setiastro/images/rgbcombo.png +0 -0
  109. setiastro/images/rgbextract.png +0 -0
  110. setiastro/images/rotate180.png +0 -0
  111. setiastro/images/rotatearbitrary.png +0 -0
  112. setiastro/images/rotateclockwise.png +0 -0
  113. setiastro/images/rotatecounterclockwise.png +0 -0
  114. setiastro/images/satellite.png +0 -0
  115. setiastro/images/script.png +0 -0
  116. setiastro/images/selectivecolor.png +0 -0
  117. setiastro/images/simbad.png +0 -0
  118. setiastro/images/slot0.png +0 -0
  119. setiastro/images/slot1.png +0 -0
  120. setiastro/images/slot2.png +0 -0
  121. setiastro/images/slot3.png +0 -0
  122. setiastro/images/slot4.png +0 -0
  123. setiastro/images/slot5.png +0 -0
  124. setiastro/images/slot6.png +0 -0
  125. setiastro/images/slot7.png +0 -0
  126. setiastro/images/slot8.png +0 -0
  127. setiastro/images/slot9.png +0 -0
  128. setiastro/images/spcc.png +0 -0
  129. setiastro/images/spin_precession_vs_lunar_distance.png +0 -0
  130. setiastro/images/spinner.gif +0 -0
  131. setiastro/images/stacking.png +0 -0
  132. setiastro/images/staradd.png +0 -0
  133. setiastro/images/staralign.png +0 -0
  134. setiastro/images/starnet.png +0 -0
  135. setiastro/images/starregistration.png +0 -0
  136. setiastro/images/starspike.png +0 -0
  137. setiastro/images/starstretch.png +0 -0
  138. setiastro/images/statstretch.png +0 -0
  139. setiastro/images/supernova.png +0 -0
  140. setiastro/images/uhs.png +0 -0
  141. setiastro/images/undoicon.png +0 -0
  142. setiastro/images/upscale.png +0 -0
  143. setiastro/images/viewbundle.png +0 -0
  144. setiastro/images/waning_crescent_1.png +0 -0
  145. setiastro/images/waning_crescent_2.png +0 -0
  146. setiastro/images/waning_crescent_3.png +0 -0
  147. setiastro/images/waning_crescent_4.png +0 -0
  148. setiastro/images/waning_crescent_5.png +0 -0
  149. setiastro/images/waning_gibbous_1.png +0 -0
  150. setiastro/images/waning_gibbous_2.png +0 -0
  151. setiastro/images/waning_gibbous_3.png +0 -0
  152. setiastro/images/waning_gibbous_4.png +0 -0
  153. setiastro/images/waning_gibbous_5.png +0 -0
  154. setiastro/images/waxing_crescent_1.png +0 -0
  155. setiastro/images/waxing_crescent_2.png +0 -0
  156. setiastro/images/waxing_crescent_3.png +0 -0
  157. setiastro/images/waxing_crescent_4.png +0 -0
  158. setiastro/images/waxing_crescent_5.png +0 -0
  159. setiastro/images/waxing_gibbous_1.png +0 -0
  160. setiastro/images/waxing_gibbous_2.png +0 -0
  161. setiastro/images/waxing_gibbous_3.png +0 -0
  162. setiastro/images/waxing_gibbous_4.png +0 -0
  163. setiastro/images/waxing_gibbous_5.png +0 -0
  164. setiastro/images/whitebalance.png +0 -0
  165. setiastro/images/wimi_icon_256x256.png +0 -0
  166. setiastro/images/wimilogo.png +0 -0
  167. setiastro/images/wims.png +0 -0
  168. setiastro/images/wrench_icon.png +0 -0
  169. setiastro/images/xisfliberator.png +0 -0
  170. setiastro/qml/ResourceMonitor.qml +128 -0
  171. setiastro/saspro/__init__.py +20 -0
  172. setiastro/saspro/__main__.py +964 -0
  173. setiastro/saspro/_generated/__init__.py +7 -0
  174. setiastro/saspro/_generated/build_info.py +3 -0
  175. setiastro/saspro/abe.py +1379 -0
  176. setiastro/saspro/abe_preset.py +196 -0
  177. setiastro/saspro/aberration_ai.py +910 -0
  178. setiastro/saspro/aberration_ai_preset.py +224 -0
  179. setiastro/saspro/accel_installer.py +218 -0
  180. setiastro/saspro/accel_workers.py +30 -0
  181. setiastro/saspro/acv_exporter.py +379 -0
  182. setiastro/saspro/add_stars.py +627 -0
  183. setiastro/saspro/astrobin_exporter.py +1010 -0
  184. setiastro/saspro/astrospike.py +153 -0
  185. setiastro/saspro/astrospike_python.py +1841 -0
  186. setiastro/saspro/autostretch.py +198 -0
  187. setiastro/saspro/backgroundneutral.py +639 -0
  188. setiastro/saspro/batch_convert.py +328 -0
  189. setiastro/saspro/batch_renamer.py +522 -0
  190. setiastro/saspro/blemish_blaster.py +494 -0
  191. setiastro/saspro/blink_comparator_pro.py +3149 -0
  192. setiastro/saspro/bundles.py +61 -0
  193. setiastro/saspro/bundles_dock.py +114 -0
  194. setiastro/saspro/cheat_sheet.py +213 -0
  195. setiastro/saspro/clahe.py +371 -0
  196. setiastro/saspro/comet_stacking.py +1442 -0
  197. setiastro/saspro/common_tr.py +107 -0
  198. setiastro/saspro/config.py +38 -0
  199. setiastro/saspro/config_bootstrap.py +40 -0
  200. setiastro/saspro/config_manager.py +316 -0
  201. setiastro/saspro/continuum_subtract.py +1620 -0
  202. setiastro/saspro/convo.py +1403 -0
  203. setiastro/saspro/convo_preset.py +414 -0
  204. setiastro/saspro/copyastro.py +190 -0
  205. setiastro/saspro/cosmicclarity.py +1593 -0
  206. setiastro/saspro/cosmicclarity_preset.py +407 -0
  207. setiastro/saspro/crop_dialog_pro.py +1005 -0
  208. setiastro/saspro/crop_preset.py +189 -0
  209. setiastro/saspro/curve_editor_pro.py +2608 -0
  210. setiastro/saspro/curves_preset.py +375 -0
  211. setiastro/saspro/debayer.py +673 -0
  212. setiastro/saspro/debug_utils.py +29 -0
  213. setiastro/saspro/dnd_mime.py +35 -0
  214. setiastro/saspro/doc_manager.py +2727 -0
  215. setiastro/saspro/exoplanet_detector.py +2258 -0
  216. setiastro/saspro/file_utils.py +284 -0
  217. setiastro/saspro/fitsmodifier.py +748 -0
  218. setiastro/saspro/fix_bom.py +32 -0
  219. setiastro/saspro/free_torch_memory.py +48 -0
  220. setiastro/saspro/frequency_separation.py +1352 -0
  221. setiastro/saspro/function_bundle.py +1596 -0
  222. setiastro/saspro/generate_translations.py +3092 -0
  223. setiastro/saspro/ghs_dialog_pro.py +728 -0
  224. setiastro/saspro/ghs_preset.py +284 -0
  225. setiastro/saspro/graxpert.py +638 -0
  226. setiastro/saspro/graxpert_preset.py +287 -0
  227. setiastro/saspro/gui/__init__.py +0 -0
  228. setiastro/saspro/gui/main_window.py +8928 -0
  229. setiastro/saspro/gui/mixins/__init__.py +33 -0
  230. setiastro/saspro/gui/mixins/dock_mixin.py +375 -0
  231. setiastro/saspro/gui/mixins/file_mixin.py +450 -0
  232. setiastro/saspro/gui/mixins/geometry_mixin.py +503 -0
  233. setiastro/saspro/gui/mixins/header_mixin.py +441 -0
  234. setiastro/saspro/gui/mixins/mask_mixin.py +421 -0
  235. setiastro/saspro/gui/mixins/menu_mixin.py +391 -0
  236. setiastro/saspro/gui/mixins/theme_mixin.py +367 -0
  237. setiastro/saspro/gui/mixins/toolbar_mixin.py +1824 -0
  238. setiastro/saspro/gui/mixins/update_mixin.py +323 -0
  239. setiastro/saspro/gui/mixins/view_mixin.py +477 -0
  240. setiastro/saspro/gui/statistics_dialog.py +47 -0
  241. setiastro/saspro/halobgon.py +492 -0
  242. setiastro/saspro/header_viewer.py +448 -0
  243. setiastro/saspro/headless_utils.py +88 -0
  244. setiastro/saspro/histogram.py +760 -0
  245. setiastro/saspro/history_explorer.py +941 -0
  246. setiastro/saspro/i18n.py +168 -0
  247. setiastro/saspro/image_combine.py +421 -0
  248. setiastro/saspro/image_peeker_pro.py +1608 -0
  249. setiastro/saspro/imageops/__init__.py +37 -0
  250. setiastro/saspro/imageops/mdi_snap.py +292 -0
  251. setiastro/saspro/imageops/scnr.py +36 -0
  252. setiastro/saspro/imageops/starbasedwhitebalance.py +210 -0
  253. setiastro/saspro/imageops/stretch.py +236 -0
  254. setiastro/saspro/isophote.py +1186 -0
  255. setiastro/saspro/layers.py +208 -0
  256. setiastro/saspro/layers_dock.py +714 -0
  257. setiastro/saspro/lazy_imports.py +193 -0
  258. setiastro/saspro/legacy/__init__.py +2 -0
  259. setiastro/saspro/legacy/image_manager.py +2360 -0
  260. setiastro/saspro/legacy/numba_utils.py +3676 -0
  261. setiastro/saspro/legacy/xisf.py +1213 -0
  262. setiastro/saspro/linear_fit.py +537 -0
  263. setiastro/saspro/live_stacking.py +1854 -0
  264. setiastro/saspro/log_bus.py +5 -0
  265. setiastro/saspro/logging_config.py +460 -0
  266. setiastro/saspro/luminancerecombine.py +510 -0
  267. setiastro/saspro/main_helpers.py +201 -0
  268. setiastro/saspro/mask_creation.py +1090 -0
  269. setiastro/saspro/masks_core.py +56 -0
  270. setiastro/saspro/mdi_widgets.py +353 -0
  271. setiastro/saspro/memory_utils.py +666 -0
  272. setiastro/saspro/metadata_patcher.py +75 -0
  273. setiastro/saspro/mfdeconv.py +3909 -0
  274. setiastro/saspro/mfdeconv_earlystop.py +71 -0
  275. setiastro/saspro/mfdeconvcudnn.py +3312 -0
  276. setiastro/saspro/mfdeconvsport.py +2459 -0
  277. setiastro/saspro/minorbodycatalog.py +567 -0
  278. setiastro/saspro/morphology.py +411 -0
  279. setiastro/saspro/multiscale_decomp.py +1751 -0
  280. setiastro/saspro/nbtorgb_stars.py +541 -0
  281. setiastro/saspro/numba_utils.py +3145 -0
  282. setiastro/saspro/numba_warmup.py +141 -0
  283. setiastro/saspro/ops/__init__.py +9 -0
  284. setiastro/saspro/ops/command_help_dialog.py +623 -0
  285. setiastro/saspro/ops/command_runner.py +217 -0
  286. setiastro/saspro/ops/commands.py +1594 -0
  287. setiastro/saspro/ops/script_editor.py +1105 -0
  288. setiastro/saspro/ops/scripts.py +1476 -0
  289. setiastro/saspro/ops/settings.py +637 -0
  290. setiastro/saspro/parallel_utils.py +554 -0
  291. setiastro/saspro/pedestal.py +121 -0
  292. setiastro/saspro/perfect_palette_picker.py +1105 -0
  293. setiastro/saspro/pipeline.py +110 -0
  294. setiastro/saspro/pixelmath.py +1604 -0
  295. setiastro/saspro/plate_solver.py +2480 -0
  296. setiastro/saspro/project_io.py +797 -0
  297. setiastro/saspro/psf_utils.py +136 -0
  298. setiastro/saspro/psf_viewer.py +631 -0
  299. setiastro/saspro/pyi_rthook_astroquery.py +95 -0
  300. setiastro/saspro/remove_green.py +331 -0
  301. setiastro/saspro/remove_stars.py +1599 -0
  302. setiastro/saspro/remove_stars_preset.py +446 -0
  303. setiastro/saspro/resources.py +570 -0
  304. setiastro/saspro/rgb_combination.py +208 -0
  305. setiastro/saspro/rgb_extract.py +19 -0
  306. setiastro/saspro/rgbalign.py +727 -0
  307. setiastro/saspro/runtime_imports.py +7 -0
  308. setiastro/saspro/runtime_torch.py +754 -0
  309. setiastro/saspro/save_options.py +73 -0
  310. setiastro/saspro/selective_color.py +1614 -0
  311. setiastro/saspro/sfcc.py +1530 -0
  312. setiastro/saspro/shortcuts.py +3125 -0
  313. setiastro/saspro/signature_insert.py +1106 -0
  314. setiastro/saspro/stacking_suite.py +19069 -0
  315. setiastro/saspro/star_alignment.py +7383 -0
  316. setiastro/saspro/star_alignment_preset.py +329 -0
  317. setiastro/saspro/star_metrics.py +49 -0
  318. setiastro/saspro/star_spikes.py +769 -0
  319. setiastro/saspro/star_stretch.py +542 -0
  320. setiastro/saspro/stat_stretch.py +554 -0
  321. setiastro/saspro/status_log_dock.py +78 -0
  322. setiastro/saspro/subwindow.py +3523 -0
  323. setiastro/saspro/supernovaasteroidhunter.py +1719 -0
  324. setiastro/saspro/swap_manager.py +134 -0
  325. setiastro/saspro/torch_backend.py +89 -0
  326. setiastro/saspro/torch_rejection.py +434 -0
  327. setiastro/saspro/translations/all_source_strings.json +4726 -0
  328. setiastro/saspro/translations/ar_translations.py +4096 -0
  329. setiastro/saspro/translations/de_translations.py +3728 -0
  330. setiastro/saspro/translations/es_translations.py +4169 -0
  331. setiastro/saspro/translations/fr_translations.py +4090 -0
  332. setiastro/saspro/translations/hi_translations.py +3803 -0
  333. setiastro/saspro/translations/integrate_translations.py +271 -0
  334. setiastro/saspro/translations/it_translations.py +4728 -0
  335. setiastro/saspro/translations/ja_translations.py +3834 -0
  336. setiastro/saspro/translations/pt_translations.py +3847 -0
  337. setiastro/saspro/translations/ru_translations.py +3082 -0
  338. setiastro/saspro/translations/saspro_ar.qm +0 -0
  339. setiastro/saspro/translations/saspro_ar.ts +16019 -0
  340. setiastro/saspro/translations/saspro_de.qm +0 -0
  341. setiastro/saspro/translations/saspro_de.ts +14548 -0
  342. setiastro/saspro/translations/saspro_es.qm +0 -0
  343. setiastro/saspro/translations/saspro_es.ts +16202 -0
  344. setiastro/saspro/translations/saspro_fr.qm +0 -0
  345. setiastro/saspro/translations/saspro_fr.ts +15870 -0
  346. setiastro/saspro/translations/saspro_hi.qm +0 -0
  347. setiastro/saspro/translations/saspro_hi.ts +14855 -0
  348. setiastro/saspro/translations/saspro_it.qm +0 -0
  349. setiastro/saspro/translations/saspro_it.ts +19046 -0
  350. setiastro/saspro/translations/saspro_ja.qm +0 -0
  351. setiastro/saspro/translations/saspro_ja.ts +14980 -0
  352. setiastro/saspro/translations/saspro_pt.qm +0 -0
  353. setiastro/saspro/translations/saspro_pt.ts +15024 -0
  354. setiastro/saspro/translations/saspro_ru.qm +0 -0
  355. setiastro/saspro/translations/saspro_ru.ts +11835 -0
  356. setiastro/saspro/translations/saspro_sw.qm +0 -0
  357. setiastro/saspro/translations/saspro_sw.ts +15237 -0
  358. setiastro/saspro/translations/saspro_uk.qm +0 -0
  359. setiastro/saspro/translations/saspro_uk.ts +15248 -0
  360. setiastro/saspro/translations/saspro_zh.qm +0 -0
  361. setiastro/saspro/translations/saspro_zh.ts +15289 -0
  362. setiastro/saspro/translations/sw_translations.py +3897 -0
  363. setiastro/saspro/translations/uk_translations.py +3929 -0
  364. setiastro/saspro/translations/zh_translations.py +3910 -0
  365. setiastro/saspro/versioning.py +77 -0
  366. setiastro/saspro/view_bundle.py +1558 -0
  367. setiastro/saspro/wavescale_hdr.py +648 -0
  368. setiastro/saspro/wavescale_hdr_preset.py +101 -0
  369. setiastro/saspro/wavescalede.py +683 -0
  370. setiastro/saspro/wavescalede_preset.py +230 -0
  371. setiastro/saspro/wcs_update.py +374 -0
  372. setiastro/saspro/whitebalance.py +540 -0
  373. setiastro/saspro/widgets/__init__.py +48 -0
  374. setiastro/saspro/widgets/common_utilities.py +306 -0
  375. setiastro/saspro/widgets/graphics_views.py +122 -0
  376. setiastro/saspro/widgets/image_utils.py +518 -0
  377. setiastro/saspro/widgets/minigame/game.js +991 -0
  378. setiastro/saspro/widgets/minigame/index.html +53 -0
  379. setiastro/saspro/widgets/minigame/style.css +241 -0
  380. setiastro/saspro/widgets/preview_dialogs.py +280 -0
  381. setiastro/saspro/widgets/resource_monitor.py +313 -0
  382. setiastro/saspro/widgets/spinboxes.py +290 -0
  383. setiastro/saspro/widgets/themed_buttons.py +13 -0
  384. setiastro/saspro/widgets/wavelet_utils.py +331 -0
  385. setiastro/saspro/wimi.py +7367 -0
  386. setiastro/saspro/wims.py +588 -0
  387. setiastro/saspro/window_shelf.py +185 -0
  388. setiastro/saspro/xisf.py +1213 -0
  389. setiastrosuitepro-1.6.7.dist-info/METADATA +279 -0
  390. setiastrosuitepro-1.6.7.dist-info/RECORD +394 -0
  391. setiastrosuitepro-1.6.7.dist-info/WHEEL +4 -0
  392. setiastrosuitepro-1.6.7.dist-info/entry_points.txt +6 -0
  393. setiastrosuitepro-1.6.7.dist-info/licenses/LICENSE +674 -0
  394. setiastrosuitepro-1.6.7.dist-info/licenses/license.txt +2580 -0
setiastro/saspro/mfdeconv.py
@@ -0,0 +1,3909 @@
+ # pro/mfdeconv.py non-sport normal version
+ from __future__ import annotations
+ import os, sys
+ import math
+ import re
+ import numpy as np
+ import time
+ from astropy.io import fits
+ from PyQt6.QtCore import QObject, pyqtSignal
+ from setiastro.saspro.psf_utils import compute_psf_kernel_for_image
+ from PyQt6.QtWidgets import QApplication
+ from PyQt6.QtCore import QThread
+ import contextlib
+ from threadpoolctl import threadpool_limits
+ from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
+ _USE_PROCESS_POOL_FOR_ASSETS = not getattr(sys, "frozen", False)
+ import gc
+ try:
+     import sep
+ except Exception:
+     sep = None
+ from setiastro.saspro.free_torch_memory import _free_torch_memory
+ from setiastro.saspro.mfdeconv_earlystop import EarlyStopper
+ torch = None  # filled by runtime loader if available
+ TORCH_OK = False
+ NO_GRAD = contextlib.nullcontext  # fallback
+ _XISF_READERS = []
+ try:
+     # e.g. your legacy module
+     from setiastro.saspro.legacy import xisf as _legacy_xisf
+     if hasattr(_legacy_xisf, "read"):
+         _XISF_READERS.append(lambda p: _legacy_xisf.read(p))
+     elif hasattr(_legacy_xisf, "open"):
+         _XISF_READERS.append(lambda p: _legacy_xisf.open(p)[0])
+ except Exception:
+     pass
+ try:
+     # sometimes projects expose a generic load_image
+     from setiastro.saspro.legacy.image_manager import load_image as _generic_load_image  # adjust if needed
+     _XISF_READERS.append(lambda p: _generic_load_image(p)[0])
+ except Exception:
+     pass
+
+ from pathlib import Path
+
+ # at top of file with the other imports
+
+ from queue import SimpleQueue
+ from setiastro.saspro.memory_utils import LRUDict
+
+ # ── XISF decode cache → memmap on disk ─────────────────────────────────
+ import tempfile
+ import threading
+ import uuid
+ import atexit
+ _XISF_CACHE = LRUDict(50)
+ _XISF_LOCK = threading.Lock()
+ _XISF_TMPFILES = []
+
+ from collections import OrderedDict
+
+ # ── CHW LRU (float32) built on top of FITS memmap & XISF memmap ────────────────
+ class _FrameCHWLRU:
+     def __init__(self, capacity=8):
+         self.cap = int(max(1, capacity))
+         self.od = OrderedDict()
+
+     def clear(self):
+         self.od.clear()
+
+     def get(self, path, Ht, Wt, color_mode):
+         key = (path, Ht, Wt, str(color_mode).lower())
+         hit = self.od.get(key)
+         if hit is not None:
+             self.od.move_to_end(key)
+             return hit
+
+         # Load backing array cheaply (memmap for FITS, cached memmap for XISF)
+         ext = os.path.splitext(path)[1].lower()
+         if ext == ".xisf":
+             a = _xisf_cached_array(path)  # float32, HW/HWC/CHW
+         else:
+             # FITS path: use astropy memmap (no data copy)
+             with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
+                 arr = None
+                 for h in hdul:
+                     if getattr(h, "data", None) is not None:
+                         arr = h.data
+                         break
+                 if arr is None:
+                     raise ValueError(f"No image data in {path}")
+                 a = np.asarray(arr)
+             # dtype normalize once; keep float32
+             if a.dtype.kind in "ui":
+                 a = a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0)
+             else:
+                 a = a.astype(np.float32, copy=False)
+
+         # Center-crop to (Ht, Wt) and convert to CHW
+         a = np.asarray(a)  # float32
+         a = _center_crop(a, Ht, Wt)
+
+         # Respect color_mode: “luma” → 1×H×W, “PerChannel” → 3×H×W if RGB present
+         cm = str(color_mode).lower()
+         if cm == "luma":
+             a_chw = _as_chw(_to_luma_local(a)).astype(np.float32, copy=False)
+         else:
+             a_chw = _as_chw(a).astype(np.float32, copy=False)
+             if a_chw.shape[0] == 1 and cm != "luma":
+                 # still OK (mono data)
+                 pass
+
+         # LRU insert
+         self.od[key] = a_chw
+         if len(self.od) > self.cap:
+             self.od.popitem(last=False)
+         return a_chw
+
+ _FRAME_LRU = _FrameCHWLRU(capacity=8)  # tune if you like
+
+ def _clear_all_caches():
+     try: _clear_xisf_cache()
+     except Exception as e:
+         import logging
+         logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
+     try: _FRAME_LRU.clear()
+     except Exception as e:
+         import logging
+         logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
+
+
+ def _normalize_to_float32(a: np.ndarray) -> np.ndarray:
+     if a.dtype.kind in "ui":
+         return (a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0))
+     if a.dtype == np.float32:
+         return a
+     return a.astype(np.float32, copy=False)
+
+ def _xisf_cached_array(path: str) -> np.memmap:
+     """
+     Decode an XISF image exactly once and back it by a read-only float32 memmap.
+     Returns a memmap that can be sliced cheaply for tiles.
+     """
+     with _XISF_LOCK:
+         hit = _XISF_CACHE.get(path)
+         if hit is not None:
+             fn, shape = hit
+             return np.memmap(fn, dtype=np.float32, mode="r", shape=shape)
+
+         # Decode once
+         arr, _ = _load_image_array(path)  # your existing loader
+         if arr is None:
+             raise ValueError(f"XISF loader returned None for {path}")
+         arr = np.asarray(arr)
+         arrf = _normalize_to_float32(arr)
+
+         # Create a temp file-backed memmap
+         tmpdir = tempfile.gettempdir()
+         fn = os.path.join(tmpdir, f"xisf_cache_{uuid.uuid4().hex}.mmap")
+         mm = np.memmap(fn, dtype=np.float32, mode="w+", shape=arrf.shape)
+         mm[...] = arrf[...]
+         mm.flush()
+         del mm  # close writer handle; re-open below as read-only
+
+         _XISF_CACHE[path] = (fn, arrf.shape)
+         _XISF_TMPFILES.append(fn)
+         return np.memmap(fn, dtype=np.float32, mode="r", shape=arrf.shape)
+
+ def _clear_xisf_cache():
+     with _XISF_LOCK:
+         for fn in _XISF_TMPFILES:
+             try: os.remove(fn)
+             except Exception as e:
+                 import logging
+                 logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
+         _XISF_CACHE.clear()
+         _XISF_TMPFILES.clear()
+
+ atexit.register(_clear_xisf_cache)
+
+
+ def _is_xisf(path: str) -> bool:
+     return os.path.splitext(path)[1].lower() == ".xisf"
+
+ def _read_xisf_numpy(path: str) -> np.ndarray:
+     if not _XISF_READERS:
+         raise RuntimeError(
+             "No XISF readers registered. Ensure one of "
+             "legacy.xisf.read/open or *.image_io.load_image is importable."
+         )
+     last_err = None
+     for fn in _XISF_READERS:
+         try:
+             arr = fn(path)
+             if isinstance(arr, tuple):
+                 arr = arr[0]
+             return np.asarray(arr)
+         except Exception as e:
+             last_err = e
+     raise RuntimeError(f"All XISF readers failed for {path}: {last_err}")
+
+ def _fits_open_data(path: str):
+     # ignore_missing_simple=True lets us open headers missing SIMPLE
+     with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
+         hdu = hdul[0]
+         if hdu.data is None:
+             # find first image HDU if primary is header-only
+             for h in hdul[1:]:
+                 if getattr(h, "data", None) is not None:
+                     hdu = h
+                     break
+         data = np.asanyarray(hdu.data)
+         hdr = hdu.header
+     return data, hdr
+
+ def _load_image_array(path: str) -> tuple[np.ndarray, "fits.Header | None"]:
+     """
+     Return (numpy array, fits.Header or None). Color-last if 3D.
+     dtype left as-is; callers cast to float32. Array is C-contig & writeable.
+     """
+     if _is_xisf(path):
+         arr = _read_xisf_numpy(path)
+         hdr = None
+     else:
+         arr, hdr = _fits_open_data(path)
+
+     a = np.asarray(arr)
+     # Move color axis to last if 3D with a leading channel axis
+     if a.ndim == 3 and a.shape[0] in (1, 3) and a.shape[-1] not in (1, 3):
+         a = np.moveaxis(a, 0, -1)
+     # Ensure contiguous, writeable float32 decisions happen later; here we just ensure writeable
+     if (not a.flags.c_contiguous) or (not a.flags.writeable):
+         a = np.array(a, copy=True)
+     return a, hdr
+
+ def _probe_hw(path: str) -> tuple[int, int, int | None]:
+     """
+     Returns (H, W, C_or_None) without changing data. Moves color to last if needed.
+     """
+     a, _ = _load_image_array(path)
+     if a.ndim == 2:
+         return a.shape[0], a.shape[1], None
+     if a.ndim == 3:
+         h, w, c = a.shape
+         # treat mono-3D as (H,W,1)
+         if c not in (1, 3) and a.shape[0] in (1, 3):
+             a = np.moveaxis(a, 0, -1)
+             h, w, c = a.shape
+         return h, w, c if c in (1, 3) else None
+     raise ValueError(f"Unsupported ndim={a.ndim} for {path}")
+
+ def _common_hw_from_paths(paths: list[str]) -> tuple[int, int]:
+     Hs, Ws = [], []
+     for p in paths:
+         h, w, _ = _probe_hw(p)
+         h = int(h); w = int(w)
+         if h > 0 and w > 0:
+             Hs.append(h); Ws.append(w)
+
+     if not Hs:
+         raise ValueError("Could not determine any valid frame sizes.")
+     Ht = min(Hs); Wt = min(Ws)
+     if Ht < 8 or Wt < 8:
+         raise ValueError(f"Intersection too small: {Ht}x{Wt}")
+     return Ht, Wt
+
+
+ def _to_chw_float32(img: np.ndarray, color_mode: str) -> np.ndarray:
+     """
+     Convert to CHW float32:
+       - mono → (1,H,W)
+       - RGB → (3,H,W) if 'PerChannel'; (1,H,W) if 'luma'
+     """
+     x = np.asarray(img)
+     if x.ndim == 2:
+         y = x.astype(np.float32, copy=False)[None, ...]  # (1,H,W)
+         return y
+     if x.ndim == 3:
+         # color-last (H,W,C) expected
+         if x.shape[-1] == 1:
+             return x[..., 0].astype(np.float32, copy=False)[None, ...]
+         if x.shape[-1] == 3:
+             if str(color_mode).lower() in ("perchannel", "per_channel", "perchannelrgb"):
+                 r, g, b = x[..., 0], x[..., 1], x[..., 2]
+                 return np.stack([r.astype(np.float32, copy=False),
+                                  g.astype(np.float32, copy=False),
+                                  b.astype(np.float32, copy=False)], axis=0)
+             # luma
+             r, g, b = x[..., 0].astype(np.float32, copy=False), x[..., 1].astype(np.float32, copy=False), x[..., 2].astype(np.float32, copy=False)
+             L = 0.2126*r + 0.7152*g + 0.0722*b
+             return L[None, ...]
+         # rare mono-3D
+         if x.shape[0] in (1, 3) and x.shape[-1] not in (1, 3):
+             x = np.moveaxis(x, 0, -1)
+             return _to_chw_float32(x, color_mode)
+     raise ValueError(f"Unsupported image shape {x.shape}")
+
+ def _center_crop_hw(img: np.ndarray, Ht: int, Wt: int) -> np.ndarray:
+     h, w = img.shape[:2]
+     y0 = max(0, (h - Ht)//2); x0 = max(0, (w - Wt)//2)
+     return img[y0:y0+Ht, x0:x0+Wt, ...].copy() if (Ht < h or Wt < w) else img
+
+ def _stack_loader_memmap(paths: list[str], Ht: int, Wt: int, color_mode: str):
+     """
+     Drop-in replacement of the old FITS-only helper.
+     Returns (ys, hdrs):
+       ys   : list of CHW float32 arrays cropped to (Ht,Wt)
+       hdrs : list of fits.Header or None (XISF)
+     """
+     ys, hdrs = [], []
+     for p in paths:
+         arr, hdr = _load_image_array(p)
+         arr = _center_crop_hw(arr, Ht, Wt)
+         # normalize integer data to [0,1] like the rest of your code
+         if arr.dtype.kind in "ui":
+             mx = np.float32(np.iinfo(arr.dtype).max)
+             arr = arr.astype(np.float32, copy=False) / (mx if mx > 0 else 1.0)
+         elif arr.dtype.kind == "f":
+             arr = arr.astype(np.float32, copy=False)
+         else:
+             arr = arr.astype(np.float32, copy=False)
+
+         y = _to_chw_float32(arr, color_mode)
+         if (not y.flags.c_contiguous) or (not y.flags.writeable):
+             y = np.ascontiguousarray(y.astype(np.float32, copy=True))
+         ys.append(y)
+         hdrs.append(hdr if isinstance(hdr, fits.Header) else None)
+     return ys, hdrs
+
+ def _safe_primary_header(path: str) -> fits.Header:
+     if _is_xisf(path):
+         # best-effort synthetic header
+         h = fits.Header()
+         h["SIMPLE"] = (True, "created by MFDeconv")
+         h["BITPIX"] = -32
+         h["NAXIS"] = 2
+         return h
+     try:
+         return fits.getheader(path, ext=0, ignore_missing_simple=True)
+     except Exception:
+         return fits.Header()
+
+ # --- CUDA busy/unavailable detector (runtime fallback helper) ---
+ def _is_cuda_busy_error(e: Exception) -> bool:
+     """
+     Return True for 'device busy/unavailable' style CUDA errors that can pop up
+     mid-run on shared Linux systems. We *exclude* OOM here (handled elsewhere).
+     """
+     em = str(e).lower()
+     if "out of memory" in em:
+         return False  # OOM handled by your batch size backoff
+     return (
+         ("cuda" in em and ("busy" in em or "unavailable" in em))
+         or ("device-side" in em and "assert" in em)
+         or ("driver shutting down" in em)
+     )
+
+
+ def _compute_frame_assets(i, arr, hdr, *, make_masks, make_varmaps,
+                           star_mask_cfg, varmap_cfg, status_sink=lambda s: None):
+     """
+     Worker function: compute PSF and optional star mask / varmap for one frame.
+     Returns (index, psf, mask_or_None, var_or_None, log_lines)
+     """
+     logs = []
+     def log(s): logs.append(s)
+
+     # --- PSF sizing by FWHM ---
+     f_hdr = _estimate_fwhm_from_header(hdr)
+     f_img = _estimate_fwhm_from_image(arr)
+     f_whm = f_hdr if (np.isfinite(f_hdr)) else f_img
+     if not np.isfinite(f_whm) or f_whm <= 0:
+         f_whm = 2.5
+     k_auto = _auto_ksize_from_fwhm(f_whm)
+
+     # --- Star-derived PSF with retries (dynamic det_sigma ladder) ---
+     psf = None
+
+     # Your existing ksize ladder
+     k_ladder = [k_auto, max(k_auto - 4, 11), 21, 17, 15, 13, 11]
+
+     # New: start high to avoid detecting 10k stars; step down only if needed
+     sigma_ladder = [50.0, 25.0, 12.0, 6.0]
+
+     tried = set()
+     for det_sigma in sigma_ladder:
+         for k_try in k_ladder:
+             if (det_sigma, k_try) in tried:
+                 continue
+             tried.add((det_sigma, k_try))
+             try:
+                 out = compute_psf_kernel_for_image(arr, ksize=k_try, det_sigma=det_sigma, max_stars=80)
+                 psf_try = out[0] if (isinstance(out, tuple) and len(out) >= 1) else out
+                 if psf_try is not None:
+                     psf = psf_try
+                     break
+             except Exception:
+                 psf = None
+         if psf is not None:
+             break
+
+     if psf is None:
+         psf = _gaussian_psf(f_whm, ksize=k_auto)
+
+     psf = _soften_psf(_normalize_psf(psf.astype(np.float32, copy=False)), sigma_px=0.25)
+
+     mask = None
+     var = None
+
+     if make_masks or make_varmaps:
+         # one background per frame (reused by both)
+         luma = _to_luma_local(arr)
+         vmc = (varmap_cfg or {})
+         sky_map, rms_map, err_scalar = _sep_background_precompute(
+             luma, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
+         )
+
+         if make_masks:
+             smc = star_mask_cfg or {}
+             mask = _star_mask_from_precomputed(
+                 luma, sky_map, err_scalar,
+                 thresh_sigma = smc.get("thresh_sigma", THRESHOLD_SIGMA),
+                 max_objs = smc.get("max_objs", STAR_MASK_MAXOBJS),
+                 grow_px = smc.get("grow_px", GROW_PX),
+                 ellipse_scale= smc.get("ellipse_scale", ELLIPSE_SCALE),
+                 soft_sigma = smc.get("soft_sigma", SOFT_SIGMA),
+                 max_radius_px= smc.get("max_radius_px", MAX_STAR_RADIUS),
+                 keep_floor = smc.get("keep_floor", KEEP_FLOOR),
+                 max_side = smc.get("max_side", STAR_MASK_MAXSIDE),
+                 status_cb = log,
+             )
+
+         if make_varmaps:
+             vmc = varmap_cfg or {}
+             var = _variance_map_from_precomputed(
+                 luma, sky_map, rms_map, hdr,
+                 smooth_sigma = vmc.get("smooth_sigma", 1.0),
+                 floor = vmc.get("floor", 1e-8),
+                 status_cb = log,
+             )
+
+     # small per-frame summary
+     fwhm_est = _psf_fwhm_px(psf)
+     logs.insert(0, f"MFDeconv: PSF{i}: ksize={psf.shape[0]} | FWHM≈{fwhm_est:.2f}px")
+
+     return i, psf, mask, var, logs
+
+
+ def _compute_one_worker(args):
+     """
+     Top-level picklable worker for ProcessPoolExecutor.
+     args: (i, path, make_masks_in_worker, make_varmaps, star_mask_cfg, varmap_cfg)
+     Returns (i, psf, mask, var, logs)
+     """
+     (i, path, make_masks_in_worker, make_varmaps, star_mask_cfg, varmap_cfg) = args
+     # avoid BLAS/OMP storm inside each process
+     with threadpool_limits(limits=1):
+         arr, hdr = _load_image_array(path)  # FITS or XISF
+         arr = np.asarray(arr, dtype=np.float32, order="C")
+         if arr.ndim == 3 and arr.shape[-1] == 1:
+             arr = np.squeeze(arr, axis=-1)
+         if not isinstance(hdr, fits.Header):  # synthesize FITS-like header for XISF
+             hdr = _safe_primary_header(path)
+         return _compute_frame_assets(
+             i, arr, hdr,
+             make_masks=bool(make_masks_in_worker),
+             make_varmaps=bool(make_varmaps),
+             star_mask_cfg=star_mask_cfg,
+             varmap_cfg=varmap_cfg,
+         )
+
+
+ def _build_psf_and_assets(
+     paths,                      # list[str]
+     make_masks=False,
+     make_varmaps=False,
+     status_cb=lambda s: None,
+     save_dir: str | None = None,
+     star_mask_cfg: dict | None = None,
+     varmap_cfg: dict | None = None,
+     max_workers: int | None = None,
+     star_mask_ref_path: str | None = None,  # build one mask from this frame if provided
+     # NEW (passed from multiframe_deconv so we don’t re-probe/convert):
+     Ht: int | None = None,
+     Wt: int | None = None,
+     color_mode: str = "luma",
+ ):
+     """
+     Parallel PSF + (optional) star mask + variance map per frame.
+
+     Changes from the original:
+       • Reuses the decoded frame cache (_FRAME_LRU) for FITS/XISF so we never re-decode.
+       • Automatically switches to threads for XISF (so memmaps are shared across workers).
+       • Builds a single reference star mask (if requested) from the cached frame and
+         center-pads/crops it for all frames (no extra I/O).
+       • Preserves return order and streams worker logs back to the UI.
+     """
+     if save_dir:
+         os.makedirs(save_dir, exist_ok=True)
+
+     n = len(paths)
+
+     # Resolve target intersection size if caller didn't pass it
+     if Ht is None or Wt is None:
+         Ht, Wt = _common_hw_from_paths(paths)
+
+     # Sensible default worker count (cap at 8)
+     if max_workers is None:
+         try:
+             hw = os.cpu_count() or 4
+         except Exception:
+             hw = 4
+         max_workers = max(1, min(8, hw))
+
+     # Decide executor: for any XISF, prefer threads so the memmap/cache is shared
+     any_xisf = any(os.path.splitext(p)[1].lower() == ".xisf" for p in paths)
+     use_proc_pool = (not any_xisf) and _USE_PROCESS_POOL_FOR_ASSETS
+     Executor = ProcessPoolExecutor if use_proc_pool else ThreadPoolExecutor
+     pool_kind = "process" if use_proc_pool else "thread"
+     status_cb(f"MFDeconv: measuring PSFs/masks/varmaps with {max_workers} {pool_kind}s…")
+
+     # ---- helper: pad-or-crop a 2D array to (Ht,Wt), centered ----
+     def _center_pad_or_crop_2d(a2d: np.ndarray, Ht: int, Wt: int, fill: float = 1.0) -> np.ndarray:
+         a2d = np.asarray(a2d, dtype=np.float32)
+         H, W = int(a2d.shape[0]), int(a2d.shape[1])
+         # crop first if bigger
+         y0 = max(0, (H - Ht) // 2); x0 = max(0, (W - Wt) // 2)
+         y1 = min(H, y0 + Ht); x1 = min(W, x0 + Wt)
+         cropped = a2d[y0:y1, x0:x1]
+         ch, cw = cropped.shape
+         if ch == Ht and cw == Wt:
+             return np.ascontiguousarray(cropped, dtype=np.float32)
+         # pad if smaller
+         out = np.full((Ht, Wt), float(fill), dtype=np.float32)
+         oy = (Ht - ch) // 2; ox = (Wt - cw) // 2
+         out[oy:oy+ch, ox:ox+cw] = cropped
+         return out
+
+     # ---- optional: build one mask from the reference frame and reuse ----
+     base_ref_mask = None
+     if make_masks and star_mask_ref_path:
+         try:
+             status_cb(f"Star mask: using reference frame for all masks → {os.path.basename(star_mask_ref_path)}")
+             # Pull from the shared frame cache as luma on (Ht,Wt)
+             ref_chw = _FRAME_LRU.get(star_mask_ref_path, Ht, Wt, "luma")  # (1,H,W) or (H,W)
+             L = ref_chw[0] if (ref_chw.ndim == 3) else ref_chw  # 2D float32
+
+             vmc = (varmap_cfg or {})
+             sky_map, rms_map, err_scalar = _sep_background_precompute(
+                 L, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
+             )
+             smc = (star_mask_cfg or {})
+             base_ref_mask = _star_mask_from_precomputed(
+                 L, sky_map, err_scalar,
+                 thresh_sigma = smc.get("thresh_sigma", THRESHOLD_SIGMA),
+                 max_objs = smc.get("max_objs", STAR_MASK_MAXOBJS),
+                 grow_px = smc.get("grow_px", GROW_PX),
+                 ellipse_scale= smc.get("ellipse_scale", ELLIPSE_SCALE),
+                 soft_sigma = smc.get("soft_sigma", SOFT_SIGMA),
+                 max_radius_px= smc.get("max_radius_px", MAX_STAR_RADIUS),
+                 keep_floor = smc.get("keep_floor", KEEP_FLOOR),
+                 max_side = smc.get("max_side", STAR_MASK_MAXSIDE),
+                 status_cb = status_cb,
+             )
+         except Exception as e:
+             status_cb(f"⚠️ Star mask (reference) failed: {e}. Falling back to per-frame masks.")
+             base_ref_mask = None
+
+     # for GUI safety, queue logs from workers and flush in the main thread
+     log_queue: SimpleQueue = SimpleQueue()
+
+     def enqueue_logs(lines):
+         for s in lines:
+             log_queue.put(s)
+
+     psfs = [None] * n
+     masks = ([None] * n) if make_masks else None
+     vars_ = ([None] * n) if make_varmaps else None
+     make_masks_in_worker = bool(make_masks and (base_ref_mask is None))
+
+     # --- thread worker: get frame from cache and compute assets ---
+     def _compute_one(i: int, path: str):
+         # avoid heavy BLAS oversubscription inside each worker
+         with threadpool_limits(limits=1):
+             # Pull frame from cache honoring color_mode & target (Ht,Wt)
+             img_chw = _FRAME_LRU.get(path, Ht, Wt, color_mode)  # (C,H,W) float32
+             # For PSF/mask/varmap we operate on a 2D plane (luma/mono)
+             arr2d = img_chw[0] if (img_chw.ndim == 3) else img_chw  # (H,W) float32
+
+             # Header: synthesize a safe FITS-like header (works for XISF too)
+             try:
+                 hdr = _safe_primary_header(path)
+             except Exception:
+                 hdr = fits.Header()
+
+             return _compute_frame_assets(
+                 i, arr2d, hdr,
+                 make_masks=bool(make_masks_in_worker),
+                 make_varmaps=bool(make_varmaps),
+                 star_mask_cfg=star_mask_cfg,
+                 varmap_cfg=varmap_cfg,
+             )
+
+     # --- submit jobs ---
+     with Executor(max_workers=max_workers) as ex:
+         futs = []
+         for i, p in enumerate(paths, start=1):
+             status_cb(f"MFDeconv: measuring PSF {i}/{n} …")
+             if use_proc_pool:
+                 # Process-safe path: worker re-loads inside the subprocess
+                 futs.append(ex.submit(
+                     _compute_one_worker,
+                     (i, p, bool(make_masks_in_worker), bool(make_varmaps), star_mask_cfg, varmap_cfg)
+                 ))
+             else:
+                 # Thread path: hits the shared cache (fast path for XISF/FITS)
+                 futs.append(ex.submit(_compute_one, i, p))
+
+         done_cnt = 0
+         for fut in as_completed(futs):
+             i, psf, m, v, logs = fut.result()
+             idx = i - 1
+             psfs[idx] = psf
+             if masks is not None:
+                 masks[idx] = m
+             if vars_ is not None:
+                 vars_[idx] = v
+             enqueue_logs(logs)
+
+             done_cnt += 1
+             if (done_cnt % 4) == 0 or done_cnt == n:
+                 while not log_queue.empty():
+                     try:
+                         status_cb(log_queue.get_nowait())
+                     except Exception:
+                         break
+
+     # If we built a single reference mask, apply it to every frame (center pad/crop)
+     if base_ref_mask is not None and masks is not None:
+         for idx in range(n):
+             masks[idx] = _center_pad_or_crop_2d(base_ref_mask, int(Ht), int(Wt), fill=1.0)
+
+     # final flush of any remaining logs
+     while not log_queue.empty():
+         try:
+             status_cb(log_queue.get_nowait())
+         except Exception:
+             break
+
+     # save PSFs if requested
+     if save_dir:
+         for i, k in enumerate(psfs, start=1):
+             if k is not None:
+                 fits.PrimaryHDU(k.astype(np.float32, copy=False)).writeto(
+                     os.path.join(save_dir, f"psf_{i:03d}.fit"), overwrite=True
+                 )
+
+     return psfs, masks, vars_
+
+
+ _ALLOWED = re.compile(r"[^A-Za-z0-9_-]+")
+
+ # known FITS-style multi-extensions (rightmost-first match)
+ _KNOWN_EXTS = [
+     ".fits.fz", ".fit.fz", ".fits.gz", ".fit.gz",
+     ".fz", ".gz",
+     ".fits", ".fit"
+ ]
+
+ def _sanitize_token(s: str) -> str:
+     s = _ALLOWED.sub("_", s)
+     s = re.sub(r"_+", "_", s).strip("_")
+     return s
+
+ def _split_known_exts(p: Path) -> tuple[str, str]:
+     """
+     Return (name_body, full_ext) where full_ext is a REAL extension block
+     (e.g. '.fits.fz'). Any junk like '.0s (1310x880)_MFDeconv' stays in body.
+     """
+     name = p.name
+     for ext in _KNOWN_EXTS:
+         if name.lower().endswith(ext):
+             body = name[:-len(ext)]
+             return body, ext
+     # fallback: single suffix
+     return p.stem, "".join(p.suffixes)
+
+ _SIZE_RE = re.compile(r"\(?\s*(\d{2,5})x(\d{2,5})\s*\)?", re.IGNORECASE)
+ _EXP_RE = re.compile(r"(?<![A-Za-z0-9])(\d+(?:\.\d+)?)\s*s\b", re.IGNORECASE)
+ _RX_RE = re.compile(r"(?<![A-Za-z0-9])(\d+)x\b", re.IGNORECASE)
+
+ def _extract_size(body: str) -> str | None:
+     m = _SIZE_RE.search(body)
+     return f"{m.group(1)}x{m.group(2)}" if m else None
+
+ def _extract_exposure_secs(body: str) -> str | None:
+     m = _EXP_RE.search(body)
+     if not m:
+         return None
+     secs = int(round(float(m.group(1))))
+     return f"{secs}s"
+
+ def _strip_metadata_from_base(body: str) -> str:
+     s = body
+
+     # normalize common separators first
+     s = s.replace(" - ", "_")
+
+     # remove known trailing marker '_MFDeconv'
+     s = re.sub(r"(?i)[\s_]+MFDeconv$", "", s)
+
+     # remove parenthetical copy counters e.g. '(1)'
+     s = re.sub(r"\(\s*\d+\s*\)$", "", s)
+
+     # remove size (with or without parens) anywhere
+     s = _SIZE_RE.sub("", s)
+
+     # remove exposures like '0s', '0.5s', ' 45 s' (even if preceded by a dot)
+     s = _EXP_RE.sub("", s)
+
+     # remove any _#x tokens
+     s = _RX_RE.sub("", s)
+
+     # collapse whitespace/underscores and sanitize
+     s = re.sub(r"[\s]+", "_", s)
+     s = _sanitize_token(s)
+     return s or "output"
+
+ def _canonical_out_name_prefix(base: str, r: int, size: str | None,
+                                exposure_secs: str | None, tag: str = "MFDeconv") -> str:
+     parts = [_sanitize_token(tag), _sanitize_token(base)]
+     if size:
+         parts.append(_sanitize_token(size))
+     if exposure_secs:
+         parts.append(_sanitize_token(exposure_secs))
+     if int(max(1, r)) > 1:
+         parts.append(f"{int(r)}x")
+     return "_".join(parts)
+
+ def _sr_out_path(out_path: str, r: int) -> Path:
+     """
+     Build: MFDeconv_<base>[_<HxW>][_<secs>s][_2x], preserving REAL extensions.
+     """
+     p = Path(out_path)
+     body, real_ext = _split_known_exts(p)
+
+     # harvest metadata from the whole body (not Path.stem)
+     size = _extract_size(body)
+     ex_sec = _extract_exposure_secs(body)
+
+     # clean base
+     base = _strip_metadata_from_base(body)
+
+     new_stem = _canonical_out_name_prefix(base, r=int(max(1, r)), size=size, exposure_secs=ex_sec, tag="MFDeconv")
+     return p.with_name(f"{new_stem}{real_ext}")
+
+ def _nonclobber_path(path: str) -> str:
+     """
+     Version collisions as '_v2', '_v3', ... (no spaces/parentheses).
+     """
+     p = Path(path)
+     if not p.exists():
+         return str(p)
+
+     # keep the true extension(s)
+     body, real_ext = _split_known_exts(p)
+
+     # if already has _vN, bump it
+     m = re.search(r"(.*)_v(\d+)$", body)
+     if m:
+         base = m.group(1); n = int(m.group(2)) + 1
+     else:
+         base = body; n = 2
+
+     while True:
+         candidate = p.with_name(f"{base}_v{n}{real_ext}")
+         if not candidate.exists():
+             return str(candidate)
+         n += 1
+
+ def _iter_folder(basefile: str) -> str:
+     d, fname = os.path.split(basefile)
+     root, ext = os.path.splitext(fname)
+     tgt = os.path.join(d, f"{root}.iters")
+     if not os.path.exists(tgt):
+         try:
+             os.makedirs(tgt, exist_ok=True)
+         except Exception:
+             # last resort: suffix (n)
+             n = 1
+             while True:
+                 cand = os.path.join(d, f"{root}.iters ({n})")
+                 try:
+                     os.makedirs(cand, exist_ok=True)
+                     return cand
+                 except Exception:
+                     n += 1
+     return tgt
+
+ def _save_iter_image(arr, hdr_base, folder, tag, color_mode):
+     """
+     arr: numpy array (H,W) or (C,H,W) float32
+     tag: 'seed' or 'iter_###'
+     """
+     if arr.ndim == 3 and arr.shape[0] not in (1, 3) and arr.shape[-1] in (1, 3):
+         arr = np.moveaxis(arr, -1, 0)
+     if arr.ndim == 3 and arr.shape[0] == 1:
+         arr = arr[0]
+
+     hdr = fits.Header(hdr_base) if isinstance(hdr_base, fits.Header) else fits.Header()
+     hdr['MF_PART'] = (str(tag), 'MFDeconv intermediate (seed/iter)')
+     hdr['MF_COLOR'] = (str(color_mode), 'Color mode used')
+     path = os.path.join(folder, f"{tag}.fit")
+     # overwrite allowed inside the dedicated folder
+     fits.PrimaryHDU(data=arr.astype(np.float32, copy=False), header=hdr).writeto(path, overwrite=True)
+     return path
+
+
+ def _process_gui_events_safely():
+     app = QApplication.instance()
+     if app and QThread.currentThread() is app.thread():
+         app.processEvents()
+
+ EPS = 1e-6
+
+ # -----------------------------
+ # Helpers: image prep / shapes
+ # -----------------------------
+
+ # new: lightweight loader that yields one frame at a time
+ def _iter_fits(paths):
+     for p in paths:
+         with fits.open(p, memmap=False) as hdul:  # ⬅ False
+             arr = np.array(hdul[0].data, dtype=np.float32, copy=True)  # ⬅ copy
+             if arr.ndim == 3 and arr.shape[-1] == 1:
+                 arr = np.squeeze(arr, axis=-1)
+             hdr = hdul[0].header.copy()
+         yield arr, hdr
+
+ def _to_luma_local(a: np.ndarray) -> np.ndarray:
+     a = np.asarray(a, dtype=np.float32)
+     if a.ndim == 2:
+         return a
+     # (H,W,3) or (3,H,W)
+     if a.ndim == 3 and a.shape[-1] == 3:
+         try:
+             import cv2
+             return cv2.cvtColor(a, cv2.COLOR_RGB2GRAY).astype(np.float32, copy=False)
+         except Exception:
+             pass
+         r, g, b = a[..., 0], a[..., 1], a[..., 2]
+         return (0.2126*r + 0.7152*g + 0.0722*b).astype(np.float32, copy=False)
+     if a.ndim == 3 and a.shape[0] == 3:
+         r, g, b = a[0], a[1], a[2]
+         return (0.2126*r + 0.7152*g + 0.0722*b).astype(np.float32, copy=False)
+     return a.mean(axis=-1).astype(np.float32, copy=False)
+
+ def _stack_loader(paths):
+     ys, hdrs = [], []
+     for p in paths:
+         with fits.open(p, memmap=False) as hdul:  # ⬅ False
+             arr = np.array(hdul[0].data, dtype=np.float32, copy=True)  # ⬅ copy inside with
+             hdr = hdul[0].header.copy()
+         if arr.ndim == 3 and arr.shape[-1] == 1:
+             arr = np.squeeze(arr, axis=-1)
+         ys.append(arr)
+         hdrs.append(hdr)
+     return ys, hdrs
+
+ def _normalize_layout_single(a, color_mode):
+     """
+     Coerce to:
+       - 'luma' -> (H, W)
+       - 'perchannel' -> (C, H, W); mono stays (1,H,W), RGB → (3,H,W)
+     Accepts (H,W), (H,W,3), or (3,H,W).
+     """
+     a = np.asarray(a, dtype=np.float32)
+
+     if color_mode == "luma":
+         return _to_luma_local(a)  # returns (H,W)
+
+     # perchannel
+     if a.ndim == 2:
+         return a[None, ...]  # (1,H,W) ← keep mono as 1 channel
+     if a.ndim == 3 and a.shape[-1] == 3:
+         return np.moveaxis(a, -1, 0)  # (3,H,W)
+     if a.ndim == 3 and a.shape[0] in (1, 3):
+         return a  # already (1,H,W) or (3,H,W)
+     # fallback: average any weird shape into luma 1×H×W
+     l = _to_luma_local(a)
+     return l[None, ...]
+
+
+ def _normalize_layout_batch(arrs, color_mode):
+     return [_normalize_layout_single(a, color_mode) for a in arrs]
+
+ def _common_hw(data_list):
+     """Return minimal (H,W) across items; items are (H,W) or (C,H,W)."""
+     Hs, Ws = [], []
+     for a in data_list:
+         if a.ndim == 2:
+             H, W = a.shape
+         else:
+             _, H, W = a.shape
+         Hs.append(H); Ws.append(W)
+     return int(min(Hs)), int(min(Ws))
+
+ def _center_crop(arr, Ht, Wt):
+     """Center-crop arr (H,W) or (C,H,W) to (Ht,Wt)."""
+     if arr.ndim == 2:
+         H, W = arr.shape
+         if H == Ht and W == Wt:
+             return arr
+         y0 = max(0, (H - Ht) // 2)
+         x0 = max(0, (W - Wt) // 2)
+         return arr[y0:y0+Ht, x0:x0+Wt]
+     else:
+         C, H, W = arr.shape
+         if H == Ht and W == Wt:
+             return arr
+         y0 = max(0, (H - Ht) // 2)
+         x0 = max(0, (W - Wt) // 2)
+         return arr[:, y0:y0+Ht, x0:x0+Wt]
+
+ def _sanitize_numeric(a):
+     """Replace NaN/Inf, clip negatives, make contiguous float32."""
+     a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
+     a = np.clip(a, 0.0, None).astype(np.float32, copy=False)
+     return np.ascontiguousarray(a)
+
+ # -----------------------------
+ # PSF utilities
+ # -----------------------------
+
+ def _gaussian_psf(fwhm_px: float, ksize: int) -> np.ndarray:
+     sigma = max(fwhm_px, 1.0) / 2.3548
+     r = (ksize - 1) / 2
+     y, x = np.mgrid[-r:r+1, -r:r+1]
+     g = np.exp(-(x*x + y*y) / (2*sigma*sigma))
+     g /= (np.sum(g) + EPS)
+     return g.astype(np.float32, copy=False)
+
+ def _estimate_fwhm_from_header(hdr) -> float:
+     for key in ("FWHM", "FWHM_PIX", "PSF_FWHM"):
+         if key in hdr:
+             try:
+                 val = float(hdr[key])
+                 if np.isfinite(val) and val > 0:
+                     return val
+             except Exception:
+                 pass
+     return float("nan")
+
+ def _estimate_fwhm_from_image(arr) -> float:
+     """Fast FWHM estimate from SEP 'a','b' parameters (≈ sigma in px)."""
+     if sep is None:
+         return float("nan")
+     try:
+         img = _contig(_to_luma_local(arr))  # ← ensure C-contig float32
+         bkg = sep.Background(img)
+         data = _contig(img - bkg.back())  # ← ensure data is C-contig
+         try:
+             err = bkg.globalrms
+         except Exception:
+             err = float(np.median(bkg.rms()))
+         sources = sep.extract(data, 6.0, err=err)
+         if sources is None or len(sources) == 0:
+             return float("nan")
+         a = np.asarray(sources["a"], dtype=np.float32)
+         b = np.asarray(sources["b"], dtype=np.float32)
+         ab = (a + b) * 0.5
+         sigma = float(np.median(ab[np.isfinite(ab) & (ab > 0)]))
+         if not np.isfinite(sigma) or sigma <= 0:
+             return float("nan")
+         return 2.3548 * sigma
+     except Exception:
+         return float("nan")
+
+ def _auto_ksize_from_fwhm(fwhm_px: float, kmin: int = 11, kmax: int = 51) -> int:
+     """
+     Choose odd kernel size to cover about ±4σ.
+     """
+     sigma = max(fwhm_px, 1.0) / 2.3548
+     r = int(math.ceil(4.0 * sigma))
+     k = 2 * r + 1
+     k = max(kmin, min(k, kmax))
+     if (k % 2) == 0:
+         k += 1
+     return k
+
+ def _flip_kernel(psf):
+     # PyTorch dislikes negative strides; make it contiguous.
+     return np.flip(np.flip(psf, -1), -2).copy()
+
+ def _conv_same_np(img, psf):
+     """
+     NumPy FFT-based SAME convolution for (H,W) or (C,H,W).
+     IMPORTANT: ifftshift the PSF so its peak is at [0,0] before FFT.
+     """
+     import numpy as _np
+     import numpy.fft as _fft
+
+     kh, kw = psf.shape
+
+     def fftconv2(a, k):
+         # a is (1,H,W); k is (kh,kw)
+         H, W = a.shape[-2:]
+         fftH, fftW = _fftshape_same(H, W, kh, kw)
+         A = _fft.rfftn(a, s=(fftH, fftW), axes=(-2, -1))
+         K = _fft.rfftn(_np.fft.ifftshift(k), s=(fftH, fftW), axes=(-2, -1))
+         y = _fft.irfftn(A * K, s=(fftH, fftW), axes=(-2, -1))
+         sh, sw = (kh - 1)//2, (kw - 1)//2
+         return y[..., sh:sh+H, sw:sw+W]
+
+     if img.ndim == 2:
+         return fftconv2(img[None], psf)[0]
+     else:
+         # per-channel
+         return _np.stack([fftconv2(img[c:c+1], psf)[0] for c in range(img.shape[0])], axis=0)
+
+ def _normalize_psf(psf):
+     psf = np.maximum(psf, 0.0).astype(np.float32, copy=False)
+     s = float(psf.sum())
+     if not np.isfinite(s) or s <= 1e-6:
+         return psf
+     return (psf / s).astype(np.float32, copy=False)
+
+ def _soften_psf(psf, sigma_px=0.25):
+     if sigma_px <= 0:
+         return psf
+     r = int(max(1, round(3 * sigma_px)))
+     y, x = np.mgrid[-r:r+1, -r:r+1]
+     g = np.exp(-(x*x + y*y) / (2 * sigma_px * sigma_px)).astype(np.float32)
+     g /= g.sum() + 1e-6
+     return _conv_same_np(psf[None], g)[0]
+
+ def _psf_fwhm_px(psf: np.ndarray) -> float:
+     """Approximate FWHM (pixels) from second moments of a normalized kernel."""
+     psf = np.maximum(psf, 0).astype(np.float32, copy=False)
+     s = float(psf.sum())
+     if s <= EPS:
+         return float("nan")
+     k = psf.shape[0]
+     y, x = np.mgrid[:k, :k].astype(np.float32)
+     cy = float((psf * y).sum() / s)
+     cx = float((psf * x).sum() / s)
+     var_y = float((psf * (y - cy) ** 2).sum() / s)
+     var_x = float((psf * (x - cx) ** 2).sum() / s)
+     sigma = math.sqrt(max(0.0, 0.5 * (var_x + var_y)))
+     return 2.3548 * sigma  # FWHM≈2.355σ
+
1052
+ STAR_MASK_MAXSIDE = 2048
1053
+ STAR_MASK_MAXOBJS = 2000 # cap number of objects
1054
+ VARMAP_SAMPLE_STRIDE = 8 # (kept for compat; currently unused internally)
1055
+ THRESHOLD_SIGMA = 2.0
1056
+ KEEP_FLOOR = 0.20
1057
+ GROW_PX = 8
1058
+ MAX_STAR_RADIUS = 16
1059
+ SOFT_SIGMA = 2.0
1060
+ ELLIPSE_SCALE = 1.2
1061
+
1062
+ def _sep_background_precompute(img_2d: np.ndarray, bw: int = 64, bh: int = 64):
1063
+ """
1064
+ One-time SEP background build; returns (sky_map, rms_map, err_scalar).
1065
+
1066
+ Guarantees:
1067
+ - Always returns a 3-tuple (sky, rms, err)
1068
+ - sky/rms are float32 and same shape as img_2d
1069
+ - Robust to sep missing, sep errors, NaNs/Infs, and tiny frames
1070
+ """
1071
+ a = np.asarray(img_2d, dtype=np.float32)
1072
+ if a.ndim != 2:
1073
+ # be strict; callers expect 2D
1074
+ raise ValueError(f"_sep_background_precompute expects 2D, got shape={a.shape}")
1075
+
1076
+ H, W = int(a.shape[0]), int(a.shape[1])
1077
+ if H == 0 or W == 0:
1078
+ # should never happen, but don't return empty tuple
1079
+ sky = np.zeros((H, W), dtype=np.float32)
1080
+ rms = np.ones((H, W), dtype=np.float32)
1081
+ return sky, rms, 1.0
1082
+
1083
+ # --- robust fallback builder (works for any input) ---
1084
+ def _fallback():
1085
+ # Use finite-only stats if possible
1086
+ finite = np.isfinite(a)
1087
+ if finite.any():
1088
+ vals = a[finite]
1089
+ med = float(np.median(vals))
1090
+ mad = float(np.median(np.abs(vals - med))) + 1e-6
1091
+ else:
1092
+ med = 0.0
1093
+ mad = 1.0
1094
+ sky = np.full((H, W), med, dtype=np.float32)
1095
+ rms = np.full((H, W), 1.4826 * mad, dtype=np.float32)
1096
+ err = float(np.median(rms))
1097
+ return sky, rms, err
1098
+
1099
+ # If sep isn't available, always fall back
1100
+ if sep is None:
1101
+ return _fallback()
1102
+
1103
+ # SEP is present: sanitize input and clamp tile sizes
1104
+ # sep can choke on NaNs/Infs
1105
+ if not np.isfinite(a).all():
1106
+ # replace non-finite with median of finite values (or 0)
1107
+ finite = np.isfinite(a)
1108
+ fill = float(np.median(a[finite])) if finite.any() else 0.0
1109
+ a = np.where(finite, a, fill).astype(np.float32, copy=False)
1110
+
1111
+ a = np.ascontiguousarray(a, dtype=np.float32)
1112
+
1113
+ # Clamp bw/bh to image size; SEP doesn't like bw/bh > dims
1114
+ bw = int(max(8, min(int(bw), W)))
1115
+ bh = int(max(8, min(int(bh), H)))
1116
+
1117
+ try:
1118
+ b = sep.Background(a, bw=bw, bh=bh, fw=3, fh=3)
1119
+
1120
+ sky = np.asarray(b.back(), dtype=np.float32)
1121
+ rms = np.asarray(b.rms(), dtype=np.float32)
1122
+
1123
+ # Ensure shape sanity (SEP should match, but be paranoid)
1124
+ if sky.shape != a.shape or rms.shape != a.shape:
1125
+ return _fallback()
1126
+
1127
+ # globalrms sometimes isn't available depending on SEP build
1128
+ err = float(getattr(b, "globalrms", np.nan))
1129
+ if not np.isfinite(err) or err <= 0:
1130
+ # robust scalar: median rms
1131
+ err = float(np.median(rms)) if rms.size else 1.0
1132
+
1133
+ return sky, rms, err
1134
+
1135
+ except Exception:
1136
+ # If SEP blows up for any reason, degrade gracefully
1137
+ return _fallback()
1138
+
1139
+
1140
+
1141
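When SEP is missing or fails, the fallback collapses the background to a flat median sky with a MAD-scaled rms (1.4826·MAD ≈ σ for Gaussian noise). A standalone sketch of just that fallback on synthetic data with a known noise level (names and numbers are illustrative):

```python
import numpy as np

def mad_background(img_2d):
    """Median/MAD fallback: flat sky = median, per-pixel rms = 1.4826*MAD,
    which is consistent with sigma for Gaussian noise."""
    a = np.asarray(img_2d, dtype=np.float32)
    finite = np.isfinite(a)
    vals = a[finite] if finite.any() else np.zeros(1, np.float32)
    med = float(np.median(vals))
    mad = float(np.median(np.abs(vals - med))) + 1e-6
    sky = np.full(a.shape, med, np.float32)
    rms = np.full(a.shape, 1.4826 * mad, np.float32)
    return sky, rms, float(np.median(rms))

rng = np.random.default_rng(1)
frame = 100.0 + rng.normal(0.0, 5.0, size=(256, 256)).astype(np.float32)
sky, rms, err = mad_background(frame)
print(round(float(sky[0, 0]), 1), round(err, 2))   # ~100.0 sky and ~5.0 rms for this synthetic frame
```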
+ def _auto_star_mask_sep(
1142
+ img_2d: np.ndarray,
1143
+ thresh_sigma: float = THRESHOLD_SIGMA,
1144
+ grow_px: int = GROW_PX,
1145
+ max_objs: int = STAR_MASK_MAXOBJS,
1146
+ max_side: int = STAR_MASK_MAXSIDE,
1147
+ ellipse_scale: float = ELLIPSE_SCALE,
1148
+ soft_sigma: float = SOFT_SIGMA,
1149
+ max_semiaxis_px: float | None = None, # kept for API compat; unused
1150
+ max_area_px2: float | None = None, # kept for API compat; unused
1151
+ max_radius_px: int = MAX_STAR_RADIUS,
1152
+ keep_floor: float = KEEP_FLOOR,
1153
+ status_cb=lambda s: None
1154
+ ) -> np.ndarray:
1155
+ """
1156
+ Build a KEEP weight map (float32 in [0,1]) using SEP detections only.
1157
+ **Never writes to img_2d** and draws only into a fresh mask buffer.
1158
+ """
1159
+ if sep is None:
1160
+ return np.ones_like(img_2d, dtype=np.float32, order="C")
1161
+
1162
+ # Optional OpenCV path for fast draw/blur
1163
+ try:
1164
+ import cv2 as _cv2
1165
+ _HAS_CV2 = True
1166
+ except Exception:
1167
+ _HAS_CV2 = False
1168
+ _cv2 = None # type: ignore
1169
+
1170
+ h, w = map(int, img_2d.shape)
1171
+
1172
+ # Background / residual on our own contiguous buffer
1173
+ data = np.ascontiguousarray(img_2d.astype(np.float32))
1174
+ bkg = sep.Background(data)
1175
+ data_sub = np.ascontiguousarray(data - bkg.back(), dtype=np.float32)
1176
+ try:
1177
+ err_scalar = float(bkg.globalrms)
1178
+ except Exception:
1179
+ err_scalar = float(np.median(np.asarray(bkg.rms(), dtype=np.float32)))
1180
+
1181
+ # Downscale for detection only
1182
+ det = data_sub
1183
+ scale = 1.0
1184
+ if max_side and max(h, w) > int(max_side):
1185
+ scale = float(max(h, w)) / float(max_side)
1186
+ if _HAS_CV2:
1187
+ det = _cv2.resize(
1188
+ det,
1189
+ (max(1, int(round(w / scale))), max(1, int(round(h / scale)))),
1190
+ interpolation=_cv2.INTER_AREA
1191
+ )
1192
+ else:
1193
+ s = int(max(1, round(scale)))
1194
+ det = det[:(h // s) * s, :(w // s) * s].reshape(h // s, s, w // s, s).mean(axis=(1, 3))
1195
+ scale = float(s)
1196
+
1197
+ # When averaging down by 'scale', per-pixel noise scales ~ 1/scale
1198
+ err_det = float(err_scalar) / float(max(1.0, scale))
1199
+
1200
+ thresholds = [thresh_sigma, thresh_sigma*2, thresh_sigma*4,
1201
+ thresh_sigma*8, thresh_sigma*16]
1202
+ objs = None; used = float("nan"); raw = 0
1203
+ for t in thresholds:
1204
+ try:
1205
+ cand = sep.extract(det, thresh=float(t), err=err_det)
1206
+ except Exception:
1207
+ cand = None
1208
+ n = 0 if cand is None else len(cand)
1209
+ if n == 0: continue
1210
+ if n > max_objs*12: continue
1211
+ objs, raw, used = cand, n, float(t)
1212
+ break
1213
+
1214
+ if objs is None or len(objs) == 0:
1215
+ try:
1216
+ cand = sep.extract(det, thresh=float(thresholds[-1]), err=err_det, minarea=9)
1217
+ except Exception:
1218
+ cand = None
1219
+ if cand is None or len(cand) == 0:
1220
+ status_cb("Star mask: no sources found (mask disabled for this frame).")
1221
+ return np.ones((h, w), dtype=np.float32, order="C")
1222
+ objs, raw, used = cand, len(cand), float(thresholds[-1])
1223
+
1224
+ # Keep brightest max_objs
1225
+ if "flux" in objs.dtype.names:
1226
+ idx = np.argsort(objs["flux"])[-int(max_objs):]
1227
+ objs = objs[idx]
1228
+ else:
1229
+ objs = objs[:int(max_objs)]
1230
+ kept_after_cap = int(len(objs))
1231
+
1232
+ # Draw on full-res fresh buffer
1233
+ mask_u8 = np.zeros((h, w), dtype=np.uint8, order="C")
1234
+ s_back = float(scale)
1235
+ MR = int(max(1, max_radius_px))
1236
+ G = int(max(0, grow_px))
1237
+ ES = float(max(0.1, ellipse_scale))
1238
+
1239
+ drawn = 0
1240
+ if _HAS_CV2:
1241
+ for o in objs:
1242
+ x = int(round(float(o["x"]) * s_back))
1243
+ y = int(round(float(o["y"]) * s_back))
1244
+ if not (0 <= x < w and 0 <= y < h):
1245
+ continue
1246
+ a = float(o["a"]) * s_back
1247
+ b = float(o["b"]) * s_back
1248
+ r = int(math.ceil(ES * max(a, b)))
1249
+ r = min(max(r, 0) + G, MR)
1250
+ if r <= 0:
1251
+ continue
1252
+ _cv2.circle(mask_u8, (x, y), r, 1, thickness=-1, lineType=_cv2.LINE_8)
1253
+ drawn += 1
1254
+ else:
1255
+ yy, xx = np.ogrid[:h, :w]
1256
+ for o in objs:
1257
+ x = int(round(float(o["x"]) * s_back))
1258
+ y = int(round(float(o["y"]) * s_back))
1259
+ if not (0 <= x < w and 0 <= y < h):
1260
+ continue
1261
+ a = float(o["a"]) * s_back
1262
+ b = float(o["b"]) * s_back
1263
+ r = int(math.ceil(ES * max(a, b)))
1264
+ r = min(max(r, 0) + G, MR)
1265
+ if r <= 0:
1266
+ continue
1267
+ y0 = max(0, y - r); y1 = min(h, y + r + 1)
1268
+ x0 = max(0, x - r); x1 = min(w, x + r + 1)
1269
+ ys = yy[y0:y1] - y
1270
+ xs = xx[:, x0:x1] - x  # xx is the (1, w) ogrid row; slice its second axis
1271
+ disk = (ys * ys + xs * xs) <= (r * r)
1272
+ mask_u8[y0:y1, x0:x1][disk] = 1
1273
+ drawn += 1
1274
+
1275
+ masked_px_hard = int(mask_u8.sum())
1276
+
1277
+ # Feather and convert to KEEP weights in [0,1]
1278
+ m = mask_u8.astype(np.float32, copy=False)
1279
+ if soft_sigma and soft_sigma > 0.0:
1280
+ try:
1281
+ if _HAS_CV2:
1282
+ k = int(max(1, math.ceil(soft_sigma * 3)) * 2 + 1)
1283
+ m = _cv2.GaussianBlur(m, (k, k), float(soft_sigma),
1284
+ borderType=_cv2.BORDER_REFLECT)
1285
+ else:
1286
+ from scipy.ndimage import gaussian_filter
1287
+ m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
1288
+ except Exception:
1289
+ pass
1290
+ np.clip(m, 0.0, 1.0, out=m)
1291
+
1292
+ keep = 1.0 - m
1293
+ kf = float(max(0.0, min(0.99, keep_floor)))
1294
+ keep = kf + (1.0 - kf) * keep
1295
+ np.clip(keep, 0.0, 1.0, out=keep)
1296
+
1297
+ status_cb(
1298
+ f"Star mask: thresh={used:.3g} | detected={raw} | kept={kept_after_cap} | "
1299
+ f"drawn={drawn} | masked_px={masked_px_hard} | grow_px={G} | soft_sigma={soft_sigma} | keep_floor={keep_floor}"
1300
+ )
1301
+ return np.ascontiguousarray(keep, dtype=np.float32)
1302
+
1303
+
1304
+
1305
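The tail of the mask builder converts a hard 0/1 star mask into KEEP weights: feather, invert, then lift by keep_floor so masked pixels never drop to zero weight. A minimal sketch of that conversion, assuming scipy for the feathering blur (the helper above prefers OpenCV when present) and mirroring the SOFT_SIGMA/KEEP_FLOOR defaults:

```python
import numpy as np
from scipy.ndimage import gaussian_filter  # same fallback blur the helper reaches for

def keep_weights(hard_mask, soft_sigma=2.0, keep_floor=0.20):
    """Turn a 0/1 star mask into KEEP weights in [keep_floor, 1]: feather,
    invert, then lift by the floor so masked pixels still carry some weight."""
    m = hard_mask.astype(np.float32)
    if soft_sigma > 0:
        m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
    m = np.clip(m, 0.0, 1.0)
    keep = 1.0 - m
    kf = float(np.clip(keep_floor, 0.0, 0.99))
    return kf + (1.0 - kf) * keep

mask = np.zeros((64, 64), np.float32)
mask[24:40, 24:40] = 1.0                                   # one large, saturated "star"
w = keep_weights(mask)
print(round(float(w.min()), 2), round(float(w.max()), 2))  # ~0.2 inside the star, 1.0 in clear sky
```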
+ def _auto_variance_map(
1306
+ img_2d: np.ndarray,
1307
+ hdr,
1308
+ status_cb=lambda s: None,
1309
+ sample_stride: int = VARMAP_SAMPLE_STRIDE, # kept for signature compat; not used
1310
+ bw: int = 64, # SEP background box width (pixels)
1311
+ bh: int = 64, # SEP background box height (pixels)
1312
+ smooth_sigma: float = 1.0, # Gaussian sigma (px) to smooth the variance map
1313
+ floor: float = 1e-8, # hard floor to prevent blow-up in 1/var
1314
+ ) -> np.ndarray:
1315
+ """
1316
+ Build a per-pixel variance map in DN^2:
1317
+
1318
+ var_DN ≈ (object_only_DN)/gain + var_bg_DN^2
1319
+
1320
+ where:
1321
+ - object_only_DN = max(img - sky_DN, 0)
1322
+ - var_bg_DN^2 comes from SEP's local background rms (Poisson(sky)+readnoise)
1323
+ - if GAIN is missing, estimate 1/gain ≈ median(var_bg)/median(sky)
1324
+
1325
+ Returns float32 array, clipped below by `floor`, optionally smoothed with a
1326
+ small Gaussian to stabilize weights. Emits a summary status line.
1327
+ """
1328
+ img = np.clip(np.asarray(img_2d, dtype=np.float32), 0.0, None)
1329
+
1330
+ # --- Parse header for camera params (optional) ---
1331
+ gain = None
1332
+ for k in ("EGAIN", "GAIN", "GAIN1", "GAIN2"):
1333
+ if k in hdr:
1334
+ try:
1335
+ g = float(hdr[k])
1336
+ if np.isfinite(g) and g > 0:
1337
+ gain = g
1338
+ break
1339
+ except Exception:
1340
+ pass
1341
+
1342
+ readnoise = None
1343
+ for k in ("RDNOISE", "READNOISE", "RN"):
1344
+ if k in hdr:
1345
+ try:
1346
+ rn = float(hdr[k])
1347
+ if np.isfinite(rn) and rn >= 0:
1348
+ readnoise = rn
1349
+ break
1350
+ except Exception:
1351
+ pass
1352
+
1353
+ # --- Local background (full-res) ---
1354
+ if sep is not None:
1355
+ try:
1356
+ b = sep.Background(img, bw=int(bw), bh=int(bh), fw=3, fh=3)
1357
+ sky_dn_map = np.asarray(b.back(), dtype=np.float32)
1358
+ try:
1359
+ rms_dn_map = np.asarray(b.rms(), dtype=np.float32)
1360
+ except Exception:
1361
+ rms_dn_map = np.full_like(img, float(np.median(b.rms())), dtype=np.float32)
1362
+ except Exception:
1363
+ sky_dn_map = np.full_like(img, float(np.median(img)), dtype=np.float32)
1364
+ med = float(np.median(img))
1365
+ mad = float(np.median(np.abs(img - med))) + 1e-6
1366
+ rms_dn_map = np.full_like(img, float(1.4826 * mad), dtype=np.float32)
1367
+ else:
1368
+ sky_dn_map = np.full_like(img, float(np.median(img)), dtype=np.float32)
1369
+ med = float(np.median(img))
1370
+ mad = float(np.median(np.abs(img - med))) + 1e-6
1371
+ rms_dn_map = np.full_like(img, float(1.4826 * mad), dtype=np.float32)
1372
+
1373
+ # Background variance in DN^2
1374
+ var_bg_dn2 = np.maximum(rms_dn_map, 1e-6) ** 2
1375
+
1376
+ # Object-only DN
1377
+ obj_dn = np.clip(img - sky_dn_map, 0.0, None)
1378
+
1379
+ # Shot-noise coefficient
1380
+ if gain is not None and np.isfinite(gain) and gain > 0:
1381
+ a_shot = 1.0 / gain
1382
+ else:
1383
+ sky_med = float(np.median(sky_dn_map))
1384
+ varbg_med = float(np.median(var_bg_dn2))
1385
+ if sky_med > 1e-6:
1386
+ a_shot = np.clip(varbg_med / sky_med, 0.0, 10.0) # ~ 1/gain estimate
1387
+ else:
1388
+ a_shot = 0.0
1389
+
1390
+ # Total variance: background + shot noise from object-only flux
1391
+ v = var_bg_dn2 + a_shot * obj_dn
1392
+ v_raw = v.copy()
1393
+
1394
+ # Optional mild smoothing
1395
+ if smooth_sigma and smooth_sigma > 0:
1396
+ try:
1397
+ import cv2 as _cv2
1398
+ k = int(max(1, int(round(3 * float(smooth_sigma)))) * 2 + 1)
1399
+ v = _cv2.GaussianBlur(v, (k, k), float(smooth_sigma), borderType=_cv2.BORDER_REFLECT)
1400
+ except Exception:
1401
+ try:
1402
+ from scipy.ndimage import gaussian_filter
1403
+ v = gaussian_filter(v, sigma=float(smooth_sigma), mode="reflect")
1404
+ except Exception:
1405
+ r = int(max(1, round(3 * float(smooth_sigma))))
1406
+ yy, xx = np.mgrid[-r:r+1, -r:r+1].astype(np.float32)
1407
+ gk = np.exp(-(xx*xx + yy*yy) / (2.0 * float(smooth_sigma) * float(smooth_sigma))).astype(np.float32)
1408
+ gk /= (gk.sum() + EPS)
1409
+ v = _conv_same_np(v, gk)
1410
+
1411
+ # Clip to avoid zero/negative variances
1412
+ v = np.clip(v, float(floor), None).astype(np.float32, copy=False)
1413
+
1414
+ # Emit telemetry
1415
+ try:
1416
+ sky_med = float(np.median(sky_dn_map))
1417
+ rms_med = float(np.median(np.sqrt(var_bg_dn2)))
1418
+ floor_pct = float((v <= floor).mean() * 100.0)
1419
+ status_cb(
1420
+ "Variance map: "
1421
+ f"sky_med={sky_med:.3g} DN | rms_med={rms_med:.3g} DN | "
1422
+ f"gain={(gain if gain is not None else 'NA')} | rn={(readnoise if readnoise is not None else 'NA')} | "
1423
+ f"smooth_sigma={smooth_sigma} | floor={floor} ({floor_pct:.2f}% at floor)"
1424
+ )
1425
+ except Exception:
1426
+ pass
1427
+
1428
+ return v
1429
+
1430
+
1431
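The gain-free branch relies on the sky-limited identity var_bg ≈ sky/gain (both in DN), so median(var_bg)/median(sky) recovers 1/gain. A quick synthetic check with an assumed gain of 1.6 e-/DN (illustrative values only):

```python
import numpy as np

rng = np.random.default_rng(2)
gain = 1.6                                    # e-/DN (assumed)
sky_e = 800.0                                 # mean sky level in electrons (assumed)
frame_dn = rng.poisson(sky_e, size=(512, 512)).astype(np.float32) / gain

sky_med = float(np.median(frame_dn))          # ≈ sky_e / gain = 500 DN
var_bg = float(np.var(frame_dn))              # Poisson: sky_e / gain**2 ≈ 312.5 DN^2
inv_gain_est = var_bg / sky_med               # ≈ 1/gain = 0.625
print(round(sky_med, 1), round(var_bg, 1), round(inv_gain_est, 3))
```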
+ def _star_mask_from_precomputed(
1432
+ img_2d: np.ndarray,
1433
+ sky_map: np.ndarray,
1434
+ err_scalar: float,
1435
+ *,
1436
+ thresh_sigma: float,
1437
+ max_objs: int,
1438
+ grow_px: int,
1439
+ ellipse_scale: float,
1440
+ soft_sigma: float,
1441
+ max_radius_px: int,
1442
+ keep_floor: float,
1443
+ max_side: int,
1444
+ status_cb=lambda s: None
1445
+ ) -> np.ndarray:
1446
+ """
1447
+ Build a KEEP weight map using a *downscaled detection / full-res draw* path.
1448
+ **Never writes to img_2d**; all drawing happens in a fresh `mask_u8`.
1449
+ """
1450
+ # Optional OpenCV fast path
1451
+ try:
1452
+ import cv2 as _cv2
1453
+ _HAS_CV2 = True
1454
+ except Exception:
1455
+ _HAS_CV2 = False
1456
+ _cv2 = None # type: ignore
1457
+
1458
+ H, W = map(int, img_2d.shape)
1459
+
1460
+ # Residual for detection (contiguous, separate buffer)
1461
+ data_sub = np.ascontiguousarray((img_2d - sky_map).astype(np.float32))
1462
+
1463
+ # Downscale *detection only* to speed up, never the draw step
1464
+ det = data_sub
1465
+ scale = 1.0
1466
+ if max_side and max(H, W) > int(max_side):
1467
+ scale = float(max(H, W)) / float(max_side)
1468
+ if _HAS_CV2:
1469
+ det = _cv2.resize(
1470
+ det,
1471
+ (max(1, int(round(W / scale))), max(1, int(round(H / scale)))),
1472
+ interpolation=_cv2.INTER_AREA
1473
+ )
1474
+ else:
1475
+ s = int(max(1, round(scale)))
1476
+ det = det[:(H // s) * s, :(W // s) * s].reshape(H // s, s, W // s, s).mean(axis=(1, 3))
1477
+ scale = float(s)
1478
+
1479
+ # Threshold ladder
1480
+ thresholds = [thresh_sigma, thresh_sigma*2, thresh_sigma*4,
1481
+ thresh_sigma*8, thresh_sigma*16]
1482
+ objs = None; used = float("nan"); raw = 0
1483
+ for t in thresholds:
1484
+ cand = sep.extract(det, thresh=float(t), err=float(err_scalar))
1485
+ n = 0 if cand is None else len(cand)
1486
+ if n == 0: continue
1487
+ if n > max_objs*12: continue
1488
+ objs, raw, used = cand, n, float(t)
1489
+ break
1490
+
1491
+ if objs is None or len(objs) == 0:
1492
+ try:
1493
+ cand = sep.extract(det, thresh=thresholds[-1], err=float(err_scalar), minarea=9)
1494
+ except Exception:
1495
+ cand = None
1496
+ if cand is None or len(cand) == 0:
1497
+ status_cb("Star mask: no sources found (mask disabled for this frame).")
1498
+ return np.ones((H, W), dtype=np.float32, order="C")
1499
+ objs, raw, used = cand, len(cand), float(thresholds[-1])
1500
+
1501
+ # Brightest max_objs
1502
+ if "flux" in objs.dtype.names:
1503
+ idx = np.argsort(objs["flux"])[-int(max_objs):]
1504
+ objs = objs[idx]
1505
+ else:
1506
+ objs = objs[:int(max_objs)]
1507
+ kept = len(objs)
1508
+
1509
+ # ---- draw back on full-res into a brand-new buffer ----
1510
+ mask_u8 = np.zeros((H, W), dtype=np.uint8, order="C")
1511
+ s_back = float(scale)
1512
+ MR = int(max(1, max_radius_px))
1513
+ G = int(max(0, grow_px))
1514
+ ES = float(max(0.1, ellipse_scale))
1515
+
1516
+ drawn = 0
1517
+ if _HAS_CV2:
1518
+ for o in objs:
1519
+ x = int(round(float(o["x"]) * s_back))
1520
+ y = int(round(float(o["y"]) * s_back))
1521
+ if not (0 <= x < W and 0 <= y < H):
1522
+ continue
1523
+ a = float(o["a"]) * s_back
1524
+ b = float(o["b"]) * s_back
1525
+ r = int(math.ceil(ES * max(a, b)))
1526
+ r = min(max(r, 0) + G, MR)
1527
+ if r <= 0:
1528
+ continue
1529
+ _cv2.circle(mask_u8, (x, y), r, 1, thickness=-1, lineType=_cv2.LINE_8)
1530
+ drawn += 1
1531
+ else:
1532
+ for o in objs:
1533
+ x = int(round(float(o["x"]) * s_back))
1534
+ y = int(round(float(o["y"]) * s_back))
1535
+ if not (0 <= x < W and 0 <= y < H):
1536
+ continue
1537
+ a = float(o["a"]) * s_back
1538
+ b = float(o["b"]) * s_back
1539
+ r = int(math.ceil(ES * max(a, b)))
1540
+ r = min(max(r, 0) + G, MR)
1541
+ if r <= 0:
1542
+ continue
1543
+ y0 = max(0, y - r); y1 = min(H, y + r + 1)
1544
+ x0 = max(0, x - r); x1 = min(W, x + r + 1)
1545
+ yy, xx = np.ogrid[y0:y1, x0:x1]
1546
+ disk = (yy - y)*(yy - y) + (xx - x)*(xx - x) <= r*r
1547
+ mask_u8[y0:y1, x0:x1][disk] = 1
1548
+ drawn += 1
1549
+
1550
+ # Feather + convert to keep weights
1551
+ m = mask_u8.astype(np.float32, copy=False)
1552
+ if soft_sigma > 0:
1553
+ try:
1554
+ if _HAS_CV2:
1555
+ k = int(max(1, int(round(3*soft_sigma)))*2 + 1)
1556
+ m = _cv2.GaussianBlur(m, (k, k), float(soft_sigma),
1557
+ borderType=_cv2.BORDER_REFLECT)
1558
+ else:
1559
+ from scipy.ndimage import gaussian_filter
1560
+ m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
1561
+ except Exception:
1562
+ pass
1563
+ np.clip(m, 0.0, 1.0, out=m)
1564
+
1565
+ keep = 1.0 - m
1566
+ kf = float(max(0.0, min(0.99, keep_floor)))
1567
+ keep = kf + (1.0 - kf) * keep
1568
+ np.clip(keep, 0.0, 1.0, out=keep)
1569
+
1570
+ status_cb(f"Star mask: thresh={used:.3g} | detected={raw} | kept={kept} | drawn={drawn} | keep_floor={keep_floor}")
1571
+ return np.ascontiguousarray(keep, dtype=np.float32)
1572
+
1573
+
1574
+ def _variance_map_from_precomputed(
1575
+ img_2d: np.ndarray,
1576
+ sky_map: np.ndarray,
1577
+ rms_map: np.ndarray,
1578
+ hdr,
1579
+ *,
1580
+ smooth_sigma: float,
1581
+ floor: float,
1582
+ status_cb=lambda s: None
1583
+ ) -> np.ndarray:
1584
+ img = np.clip(np.asarray(img_2d, dtype=np.float32), 0.0, None)
1585
+ var_bg_dn2 = np.maximum(rms_map, 1e-6) ** 2
1586
+ obj_dn = np.clip(img - sky_map, 0.0, None)
1587
+
1588
+ gain = None
1589
+ for k in ("EGAIN", "GAIN", "GAIN1", "GAIN2"):
1590
+ if k in hdr:
1591
+ try:
1592
+ g = float(hdr[k]); gain = g if (np.isfinite(g) and g > 0) else None
1593
+ if gain is not None: break
1594
+ except Exception as e:
1595
+ import logging
1596
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
1597
+
1598
+ if gain is not None:
1599
+ a_shot = 1.0 / gain
1600
+ else:
1601
+ sky_med = float(np.median(sky_map))
1602
+ varbg_med= float(np.median(var_bg_dn2))
1603
+ a_shot = (varbg_med / sky_med) if sky_med > 1e-6 else 0.0
1604
+ a_shot = float(np.clip(a_shot, 0.0, 10.0))
1605
+
1606
+ v = var_bg_dn2 + a_shot * obj_dn
1607
+ if smooth_sigma > 0:
1608
+ try:
1609
+ import cv2 as _cv2
1610
+ k = int(max(1, int(round(3*smooth_sigma)))*2 + 1)
1611
+ v = _cv2.GaussianBlur(v, (k,k), float(smooth_sigma), borderType=_cv2.BORDER_REFLECT)
1612
+ except Exception:
1613
+ try:
1614
+ from scipy.ndimage import gaussian_filter
1615
+ v = gaussian_filter(v, sigma=float(smooth_sigma), mode="reflect")
1616
+ except Exception:
1617
+ pass
1618
+
1619
+ np.clip(v, float(floor), None, out=v)
1620
+ try:
1621
+ rms_med = float(np.median(np.sqrt(var_bg_dn2)))
1622
+ status_cb(f"Variance map: sky_med={float(np.median(sky_map)):.3g} DN | rms_med={rms_med:.3g} DN | smooth_sigma={smooth_sigma} | floor={floor}")
1623
+ except Exception:
1624
+ pass
1625
+ return v.astype(np.float32, copy=False)
1626
+
1627
+
1628
+
1629
+ # -----------------------------
1630
+ # Robust weighting (Huber)
1631
+ # -----------------------------
1632
+
1633
+ EPS = 1e-6
1634
+
1635
+ def _estimate_scalar_variance(a):
1636
+ med = np.median(a)
1637
+ mad = np.median(np.abs(a - med)) + 1e-6
1638
+ return float((1.4826 * mad) ** 2)
1639
+
1640
+ def _weight_map(y, pred, huber_delta, var_map=None, mask=None):
1641
+ """
1642
+ W = [psi(r)/r] * 1/(var + eps) * mask, psi=Huber
1643
+ If huber_delta<0, delta = (-huber_delta) * RMS(residual) via MAD.
1644
+ y,pred: (H,W) or (C,H,W). var_map/mask are 2D; broadcast if needed.
1645
+ """
1646
+ r = y - pred
1647
+ # auto delta?
1648
+ if huber_delta < 0:
1649
+ med = np.median(r)
1650
+ mad = np.median(np.abs(r - med)) + 1e-6
1651
+ rms = 1.4826 * mad
1652
+ delta = (-huber_delta) * max(rms, 1e-6)
1653
+ else:
1654
+ delta = huber_delta
1655
+
1656
+ absr = np.abs(r)
1657
+ if float(delta) > 0:
1658
+ psi_over_r = np.where(absr <= delta, 1.0, delta / (absr + EPS)).astype(np.float32)
1659
+ else:
1660
+ psi_over_r = np.ones_like(r, dtype=np.float32)
1661
+
1662
+ if var_map is None:
1663
+ v = _estimate_scalar_variance(r)
1664
+ else:
1665
+ v = var_map
1666
+ if v.ndim == 2 and r.ndim == 3:
1667
+ v = v[None, ...]
1668
+ w = psi_over_r / (v + EPS)
1669
+ if mask is not None:
1670
+ m = mask if mask.ndim == w.ndim else (mask[None, ...] if w.ndim == 3 else mask)
1671
+ w = w * m
1672
+ return w
1673
+
1674
+
1675
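The Huber weights below are the usual IRLS form psi(r)/r: unity inside ±delta and delta/|r| beyond it, so large residuals are progressively down-weighted rather than rejected outright (with huber_delta < 0 the code derives delta from a MAD-based RMS). A toy illustration with a fixed delta:

```python
import numpy as np

def huber_weights(residual, delta):
    """IRLS weights psi(r)/r for the Huber loss: 1 inside ±delta, delta/|r| outside."""
    absr = np.abs(residual)
    return np.where(absr <= delta, 1.0, delta / (absr + 1e-6)).astype(np.float32)

r = np.array([0.1, -0.4, 0.9, 5.0, -12.0], dtype=np.float32)  # last two are outliers
print(huber_weights(r, delta=1.0))   # [1, 1, 1, ~0.2, ~0.083]: outliers are down-weighted
```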
+ # -----------------------------
1676
+ # Torch / conv
1677
+ # -----------------------------
1678
+
1679
+ def _fftshape_same(H, W, kh, kw):
1680
+ return H + kh - 1, W + kw - 1
1681
+
1682
+ # ---------- Torch FFT helpers (carry padH/padW so the SAME crop stays aligned) ----------
1683
+ def _precompute_torch_psf_ffts(psfs, flip_psf, H, W, device, dtype):
1684
+ """
1685
+ Pack (Kf,padH,padW,kh,kw) so the conv can crop correctly to SAME.
1686
+ Kernel is ifftshifted before padding.
1687
+ """
1688
+ tfft = torch.fft
1689
+ psf_fft, psfT_fft = [], []
1690
+ for k, kT in zip(psfs, flip_psf):
1691
+ kh, kw = k.shape
1692
+ padH, padW = _fftshape_same(H, W, kh, kw)
1693
+ k_small = torch.as_tensor(np.fft.ifftshift(k), device=device, dtype=dtype)
1694
+ kT_small = torch.as_tensor(np.fft.ifftshift(kT), device=device, dtype=dtype)
1695
+ Kf = tfft.rfftn(k_small, s=(padH, padW))
1696
+ KTf = tfft.rfftn(kT_small, s=(padH, padW))
1697
+ psf_fft.append((Kf, padH, padW, kh, kw))
1698
+ psfT_fft.append((KTf, padH, padW, kh, kw))
1699
+ return psf_fft, psfT_fft
1700
+
1701
+
1702
+ def _fft_conv_same_torch(x, Kf_pack, out_spatial):
1703
+ tfft = torch.fft
1704
+ Kf, padH, padW, kh, kw = Kf_pack
1705
+ H, W = x.shape[-2], x.shape[-1]
1706
+ if x.ndim == 2:
1707
+ X = tfft.rfftn(x, s=(padH, padW))
1708
+ y = tfft.irfftn(X * Kf, s=(padH, padW))
1709
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1710
+ out_spatial.copy_(y[sh:sh+H, sw:sw+W])
1711
+ return out_spatial
1712
+ else:
1713
+ X = tfft.rfftn(x, s=(padH, padW), dim=(-2,-1))
1714
+ y = tfft.irfftn(X * Kf, s=(padH, padW), dim=(-2,-1))
1715
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1716
+ out_spatial.copy_(y[..., sh:sh+H, sw:sw+W])
1717
+ return out_spatial
1718
+
1719
+ # ---------- NumPy FFT helpers ----------
1720
+ def _precompute_np_psf_ffts(psfs, flip_psf, H, W):
1721
+ import numpy.fft as fft
1722
+ meta, Kfs, KTfs = [], [], []
1723
+ for k, kT in zip(psfs, flip_psf):
1724
+ kh, kw = k.shape
1725
+ fftH, fftW = _fftshape_same(H, W, kh, kw)
1726
+ Kfs.append( fft.rfftn(np.fft.ifftshift(k), s=(fftH, fftW)) )
1727
+ KTfs.append(fft.rfftn(np.fft.ifftshift(kT), s=(fftH, fftW)) )
1728
+ meta.append((kh, kw, fftH, fftW))
1729
+ return Kfs, KTfs, meta
1730
+
1731
+ def _fft_conv_same_np(a, Kf, kh, kw, fftH, fftW, out):
1732
+ import numpy.fft as fft
1733
+ if a.ndim == 2:
1734
+ A = fft.rfftn(a, s=(fftH, fftW))
1735
+ y = fft.irfftn(A * Kf, s=(fftH, fftW))
1736
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1737
+ out[...] = y[sh:sh+a.shape[0], sw:sw+a.shape[1]]
1738
+ return out
1739
+ else:
1740
+ C, H, W = a.shape
1741
+ acc = []
1742
+ for c in range(C):
1743
+ A = fft.rfftn(a[c], s=(fftH, fftW))
1744
+ y = fft.irfftn(A * Kf, s=(fftH, fftW))
1745
+ sh, sw = (kh - 1)//2, (kw - 1)//2
1746
+ acc.append(y[sh:sh+H, sw:sw+W])
1747
+ out[...] = np.stack(acc, 0)
1748
+ return out
1749
+
1750
+
1751
+
1752
+ def _torch_device():
1753
+ if TORCH_OK and (torch is not None):
1754
+ if hasattr(torch, "cuda") and torch.cuda.is_available():
1755
+ return torch.device("cuda")
1756
+ if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
1757
+ return torch.device("mps")
1758
+ # DirectML: we passed dml_device from outer scope; keep a module-global
1759
+ if globals().get("dml_ok", False) and globals().get("dml_device", None) is not None:
1760
+ return globals()["dml_device"]
1761
+ return torch.device("cpu")
1762
+
1763
+ def _to_t(x: np.ndarray):
1764
+ if not (TORCH_OK and (torch is not None)):
1765
+ raise RuntimeError("Torch path requested but torch is unavailable")
1766
+ device = _torch_device()
1767
+ t = torch.from_numpy(x)
1768
+ # DirectML wants explicit .to(device)
1769
+ return t.to(device, non_blocking=True) if str(device) != "cpu" else t
1770
+
1771
+ def _contig(x):
1772
+ return np.ascontiguousarray(x, dtype=np.float32)
1773
+
1774
+ def _conv_same_torch(img_t, psf_t):
1775
+ """
1776
+ img_t: torch tensor on DEVICE, (H,W) or (C,H,W)
1777
+ psf_t: torch tensor on DEVICE, (1,1,kh,kw) (single kernel)
1778
+ Pads with 'reflect' to avoid zero-padding ringing.
1779
+ """
1780
+ kh, kw = psf_t.shape[-2:]
1781
+ pad = (kw // 2, kw - kw // 2 - 1, # left, right
1782
+ kh // 2, kh - kh // 2 - 1) # top, bottom
1783
+
1784
+ if img_t.ndim == 2:
1785
+ x = img_t[None, None]
1786
+ x = torch.nn.functional.pad(x, pad, mode="reflect")
1787
+ y = torch.nn.functional.conv2d(x, psf_t, padding=0)
1788
+ return y[0, 0]
1789
+ else:
1790
+ C = img_t.shape[0]
1791
+ x = img_t[None]
1792
+ x = torch.nn.functional.pad(x, pad, mode="reflect")
1793
+ w = psf_t.repeat(C, 1, 1, 1)
1794
+ y = torch.nn.functional.conv2d(x, w, padding=0, groups=C)
1795
+ return y[0]
1796
+
1797
+ def _safe_inference_context():
1798
+ """
1799
+ Return a valid, working no-grad context:
1800
+ - prefer torch.inference_mode() if it exists *and* can be entered,
1801
+ - otherwise fall back to torch.no_grad(),
1802
+ - if torch is unavailable, return NO_GRAD.
1803
+ """
1804
+ if not (TORCH_OK and (torch is not None)):
1805
+ return NO_GRAD
1806
+
1807
+ cm = getattr(torch, "inference_mode", None)
1808
+ if cm is None:
1809
+ return torch.no_grad
1810
+
1811
+ # Probe inference_mode once; if it explodes on this build, fall back.
1812
+ try:
1813
+ with cm():
1814
+ pass
1815
+ return cm
1816
+ except Exception:
1817
+ return torch.no_grad
1818
+
1819
+ def _ensure_mask_list(masks, data):
1820
+ # 1s where valid, 0s where invalid (soft edges allowed)
1821
+ if masks is None:
1822
+ return [np.ones_like(a if a.ndim==2 else a[0], dtype=np.float32) for a in data]
1823
+ out = []
1824
+ for a, m in zip(data, masks):
1825
+ base = a if a.ndim==2 else a[0] # mask is 2D; shared across channels
1826
+ if m is None:
1827
+ out.append(np.ones_like(base, dtype=np.float32))
1828
+ else:
1829
+ mm = np.asarray(m, dtype=np.float32)
1830
+ if mm.ndim == 3: # tolerate (1,H,W) or (C,H,W)
1831
+ mm = mm[0]
1832
+ if mm.shape != base.shape:
1833
+ # center crop to match (common intersection already applied)
1834
+ Ht, Wt = base.shape
1835
+ mm = _center_crop(mm, Ht, Wt)
1836
+ # keep as float weights in [0,1] (do not threshold!)
1837
+ out.append(np.clip(mm.astype(np.float32, copy=False), 0.0, 1.0))
1838
+ return out
1839
+
1840
+ def _ensure_var_list(variances, data):
1841
+ # If None, we’ll estimate a robust scalar per frame on-the-fly.
1842
+ if variances is None:
1843
+ return [None]*len(data)
1844
+ out = []
1845
+ for a, v in zip(data, variances):
1846
+ if v is None:
1847
+ out.append(None)
1848
+ else:
1849
+ vv = np.asarray(v, dtype=np.float32)
1850
+ if vv.ndim == 3:
1851
+ vv = vv[0]
1852
+ base = a if a.ndim==2 else a[0]
1853
+ if vv.shape != base.shape:
1854
+ Ht, Wt = base.shape
1855
+ vv = _center_crop(vv, Ht, Wt)
1856
+ # clip tiny/negatives
1857
+ vv = np.nan_to_num(vv, nan=1e-8, posinf=1e8, neginf=1e8, copy=False)
1858
+ vv = np.clip(vv, 1e-8, None).astype(np.float32, copy=False)
1859
+ out.append(vv)
1860
+ return out
1861
+
1862
+ # ---- SR operators (downsample / upsample-sum) ----
1863
+ def _downsample_avg(img, r: int):
1864
+ """Average-pool over non-overlapping r×r blocks. Works for (H,W) or (C,H,W)."""
1865
+ if r <= 1:
1866
+ return img
1867
+ a = np.asarray(img, dtype=np.float32)
1868
+ if a.ndim == 2:
1869
+ H, W = a.shape
1870
+ Hs, Ws = (H // r) * r, (W // r) * r
1871
+ a = a[:Hs, :Ws].reshape(Hs//r, r, Ws//r, r).mean(axis=(1,3))
1872
+ return a
1873
+ else:
1874
+ C, H, W = a.shape
1875
+ Hs, Ws = (H // r) * r, (W // r) * r
1876
+ a = a[:, :Hs, :Ws].reshape(C, Hs//r, r, Ws//r, r).mean(axis=(2,4))
1877
+ return a
1878
+
1879
+ def _upsample_sum(img, r: int, target_hw: tuple[int,int] | None = None):
1880
+ """Adjoint of average-pooling: replicate-sum each pixel into an r×r block.
1881
+ For (H,W) or (C,H,W). If target_hw given, center-crop/pad to that size.
1882
+ """
1883
+ if r <= 1:
1884
+ return img
1885
+ a = np.asarray(img, dtype=np.float32)
1886
+ if a.ndim == 2:
1887
+ H, W = a.shape
1888
+ out = np.kron(a, np.ones((r, r), dtype=np.float32))
1889
+ else:
1890
+ C, H, W = a.shape
1891
+ out = np.stack([np.kron(a[c], np.ones((r, r), dtype=np.float32)) for c in range(C)], axis=0)
1892
+ if target_hw is not None:
1893
+ Ht, Wt = target_hw
1894
+ out = _center_crop(out, Ht, Wt)
1895
+ return out
1896
+
1897
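A note on these two operators: replicating each coarse pixel into an r×r block is, up to the r² block area, the adjoint of the block average, and averaging a replicated image returns the original. A quick numerical check of both facts (standalone sketch, float64 for tight comparisons):

```python
import numpy as np

def down_avg(a, r):
    H, W = a.shape
    Hs, Ws = (H // r) * r, (W // r) * r
    return a[:Hs, :Ws].reshape(Hs // r, r, Ws // r, r).mean(axis=(1, 3))

def up_rep(a, r):
    return np.kron(a, np.ones((r, r), dtype=a.dtype))

rng = np.random.default_rng(3)
r = 2
x = rng.random((8, 8))
y = rng.random((4, 4))

# replicate then average is the identity on the coarse grid
assert np.allclose(down_avg(up_rep(y, r), r), y)

# <D x, y> relates to <x, U y> through the r*r block area
lhs = np.sum(down_avg(x, r) * y)
rhs = np.sum(x * up_rep(y, r))
assert np.isclose(lhs * r * r, rhs)
print("pooling adjoint checks passed")
```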
+ def _gaussian2d(ksize: int, sigma: float) -> np.ndarray:
1898
+ r = (ksize - 1) // 2
1899
+ y, x = np.mgrid[-r:r+1, -r:r+1].astype(np.float32)
1900
+ g = np.exp(-(x*x + y*y)/(2.0*sigma*sigma)).astype(np.float32)
1901
+ g /= g.sum() + EPS
1902
+ return g
1903
+
1904
+ def _conv2_same_np(a: np.ndarray, k: np.ndarray) -> np.ndarray:
1905
+ # lightweight wrap for 2D conv on (H,W) or (C,H,W) with same-size output
1906
+ return _conv_same_np(a if a.ndim==3 else a[None], k)[0] if a.ndim==2 else _conv_same_np(a, k)
1907
+
1908
+ def _solve_super_psf_from_native(f_native: np.ndarray, r: int, sigma: float = 1.1,
1909
+ iters: int = 500, lr: float = 0.1) -> np.ndarray:
1910
+ """
1911
+ Solve: h* = argmin_h || f_native - (D(h) * g_sigma) ||_2^2,
1912
+ where h is (r*k)×(r*k) if f_native is k×k. Returns normalized h (sum=1).
1913
+ """
1914
+ f = np.asarray(f_native, dtype=np.float32)
1915
+
1916
+ # Sanitize to a 2D odd square before anything else
1917
+ if f.ndim != 2:
1918
+ f = np.squeeze(f)
1919
+ if f.ndim != 2:
1920
+ raise ValueError(f"PSF must be 2D, got shape {f.shape}")
1921
+
1922
+ H, W = int(f.shape[0]), int(f.shape[1])
1923
+ k_sq = min(H, W)
1924
+ # center-crop to square if needed
1925
+ if H != W:
1926
+ y0 = (H - k_sq) // 2
1927
+ x0 = (W - k_sq) // 2
1928
+ f = f[y0:y0 + k_sq, x0:x0 + k_sq]
1929
+ H = W = k_sq
1930
+
1931
+ # enforce odd size (required by SAME padding math)
1932
+ if (H % 2) == 0:
1933
+ # drop one pixel border to make it odd (centered)
1934
+ f = f[1:, 1:]
1935
+ H = W = f.shape[0]
1936
+
1937
+ k = int(H) # k is now odd and square
1938
+ kr = int(k * r)
1939
+
1940
+
1941
+ g = _gaussian2d(k, max(sigma, 1e-3)).astype(np.float32)
1942
+
1943
+ h0 = np.zeros((kr, kr), dtype=np.float32)
1944
+ h0[::r, ::r] = f
1945
+ h0 = _normalize_psf(h0)
1946
+
1947
+ if TORCH_OK:
1948
+ import torch.nn.functional as F
1949
+ dev = _torch_device()
1950
+
1951
+ # (1) Make sure Gaussian kernel is odd-sized for SAME conv padding
1952
+ g_pad = g
1953
+ if (g.shape[-1] % 2) == 0:
1954
+ # ensure odd + renormalize
1955
+ gg = _pad_kernel_to(g, g.shape[-1] + 1)
1956
+ g_pad = gg.astype(np.float32, copy=False)
1957
+
1958
+ t = torch.tensor(h0, device=dev, dtype=torch.float32, requires_grad=True)
1959
+ f_t = torch.tensor(f, device=dev, dtype=torch.float32)
1960
+ g_t = torch.tensor(g_pad, device=dev, dtype=torch.float32)
1961
+ opt = torch.optim.Adam([t], lr=lr)
1962
+
1963
+ # Helpful assertion avoids silent shape traps
1964
+ H, W = t.shape
1965
+ assert (H % r) == 0 and (W % r) == 0, f"h shape {t.shape} not divisible by r={r}"
1966
+ Hr, Wr = H // r, W // r
1967
+
1968
+ try:
1969
+ for _ in range(max(10, iters)):
1970
+ opt.zero_grad(set_to_none=True)
1971
+
1972
+ # (2) Downsample with avg_pool2d instead of reshape/mean
1973
+ blk = t.narrow(0, 0, Hr * r).narrow(1, 0, Wr * r).contiguous()
1974
+ th = F.avg_pool2d(blk[None, None], kernel_size=r, stride=r)[0, 0] # (k,k)
1975
+
1976
+ # (3) Native-space blur with guaranteed-odd g_t
1977
+ pad = g_t.shape[-1] // 2
1978
+ conv = F.conv2d(th[None, None], g_t[None, None], padding=pad)[0, 0]
1979
+
1980
+ loss = torch.mean((conv - f_t) ** 2)
1981
+ loss.backward()
1982
+ opt.step()
1983
+ with torch.no_grad():
1984
+ t.clamp_(min=0.0)
1985
+ t /= (t.sum() + 1e-8)
1986
+ h = t.detach().cpu().numpy().astype(np.float32)
1987
+ except Exception:
1988
+ # (4) Conservative safety net: if a backend balks (commonly at r=2),
1989
+ # fall back to the NumPy solver *just for this kernel*.
1990
+ h = None
1991
+
1992
+ if not TORCH_OK or h is None:
1993
+ # NumPy fallback: projected gradient descent with a decaying step
1994
+ h = h0.copy()
1995
+ eta = float(lr)
1996
+ for _ in range(max(50, iters)):
1997
+ Dh = _downsample_avg(h, r)
1998
+ conv = _conv2_same_np(Dh, g)
1999
+ resid = (conv - f)
2000
+ grad_Dh = _conv2_same_np(resid, np.flip(np.flip(g, 0), 1))
2001
+ grad_h = _upsample_sum(grad_Dh, r, target_hw=h.shape)
2002
+ h = np.clip(h - eta * grad_h, 0.0, None)
2003
+ s = float(h.sum()); h /= (s + 1e-8)
2004
+ eta *= 0.995
2005
+
2006
+ return _normalize_psf(h)
2007
+
2008
+
2009
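The solver seeds the super-resolved kernel by dropping the native PSF samples onto an r-strided grid and normalizing; block-averaging that seed reproduces the native PSF up to the r²·sum(f) scale, which is what the optimization then refines against the blurred target. A small consistency check with an assumed Gaussian native PSF:

```python
import numpy as np

def gaussian_psf(k, sigma):
    r = (k - 1) // 2
    y, x = np.mgrid[-r:r + 1, -r:r + 1].astype(np.float32)
    g = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    return (g / g.sum()).astype(np.float32)

f = gaussian_psf(9, 1.8)              # native-resolution PSF (illustrative)
r = 2
h0 = np.zeros((9 * r, 9 * r), np.float32)
h0[::r, ::r] = f                      # native samples on an r-strided grid
h0 /= h0.sum()

down = h0.reshape(9, r, 9, r).mean(axis=(1, 3))
assert np.allclose(down * (r * r) * f.sum(), f, atol=1e-6)
print("super-PSF seed is consistent with the native PSF")
```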
+ def _downsample_avg_t(x, r: int):
2010
+ if r <= 1:
2011
+ return x
2012
+ if x.ndim == 2:
2013
+ H, W = x.shape
2014
+ Hr, Wr = (H // r) * r, (W // r) * r
2015
+ if Hr == 0 or Wr == 0:
2016
+ return x
2017
+ x2 = x[:Hr, :Wr]
2018
+ # use .reshape rather than .view: the sliced tensor may be non-contiguous
2019
+ return x2.reshape(Hr // r, r, Wr // r, r).mean(dim=(1, 3))
2020
+ else:
2021
+ C, H, W = x.shape
2022
+ Hr, Wr = (H // r) * r, (W // r) * r
2023
+ if Hr == 0 or Wr == 0:
2024
+ return x
2025
+ x2 = x[:, :Hr, :Wr]
2026
+ # use .reshape rather than .view: the sliced tensor may be non-contiguous
2027
+ return x2.reshape(C, Hr // r, r, Wr // r, r).mean(dim=(2, 4))
2028
+
2029
+ def _upsample_sum_t(x, r: int):
2030
+ if r <= 1:
2031
+ return x
2032
+ if x.ndim == 2:
2033
+ return x.repeat_interleave(r, dim=0).repeat_interleave(r, dim=1)
2034
+ else:
2035
+ return x.repeat_interleave(r, dim=-2).repeat_interleave(r, dim=-1)
2036
+
2037
+ def _sep_bg_rms(frames):
2038
+ """Return a robust background RMS using SEP's background model on the first frame."""
2039
+ if sep is None or not frames:
2040
+ return None
2041
+ try:
2042
+ y0 = frames[0] if frames[0].ndim == 2 else frames[0][0] # use luma/first channel
2043
+ a = np.ascontiguousarray(y0, dtype=np.float32)
2044
+ b = sep.Background(a, bw=64, bh=64, fw=3, fh=3)
2045
+ try:
2046
+ rms_val = float(b.globalrms)
2047
+ except Exception:
2048
+ # some SEP builds don’t expose globalrms; fall back to the map’s median
2049
+ rms_val = float(np.median(np.asarray(b.rms(), dtype=np.float32)))
2050
+ return rms_val
2051
+ except Exception:
2052
+ return None
2053
+
2054
+ # =========================
2055
+ # Memory/streaming helpers
2056
+ # =========================
2057
+
2058
+ def _approx_bytes(arr_like_shape, dtype=np.float32):
2059
+ """Rough byte estimator for a given shape/dtype."""
2060
+ return int(np.prod(arr_like_shape)) * np.dtype(dtype).itemsize
2061
+
2062
+ def _mem_model(
2063
+ grid_hw: tuple[int,int],
2064
+ r: int,
2065
+ ksize: int,
2066
+ channels: int,
2067
+ mem_target_mb: int,
2068
+ prefer_tiles: bool = False,
2069
+ min_tile: int = 256,
2070
+ max_tile: int = 2048,
2071
+ ) -> dict:
2072
+ """
2073
+ Pick a batch size (#frames) and optional tile size (HxW) given a memory budget.
2074
+ Very conservative — aims to bound peak working-set on CPU/GPU.
2075
+ """
2076
+ Hs, Ws = grid_hw
2077
+ halo = (ksize // 2) * max(1, r) # SR grid halo if r>1
2078
+ C = max(1, channels)
2079
+
2080
+ # working-set per *full-frame* conv scratch (num/den/tmp/etc.)
2081
+ per_frame_fft_like = 3 * _approx_bytes((C, Hs, Ws)) # tmp/pred + in/out buffers
2082
+ global_accum = 2 * _approx_bytes((C, Hs, Ws)) # num + den
2083
+
2084
+ budget = int(mem_target_mb * 1024 * 1024)
2085
+
2086
+ # Try to stay in full-frame mode first unless prefer_tiles
2087
+ B_full = max(1, (budget - global_accum) // max(per_frame_fft_like, 1))
2088
+ use_tiles = prefer_tiles or (B_full < 1)
2089
+
2090
+ if not use_tiles:
2091
+ return dict(batch_frames=int(B_full), tiles=None, halo=int(halo), ksize=int(ksize))
2092
+
2093
+ # Tile mode: pick a square tile side t that fits
2094
+ # scratch per tile ~ 3*C*(t+2h)^2 + accum(core) ~ small
2095
+ # try descending from max_tile
2096
+ t = int(min(max_tile, max(min_tile, 1 << int(np.floor(np.log2(min(Hs, Ws)))))))
2097
+ while t >= min_tile:
2098
+ th = t + 2 * halo
2099
+ per_tile = 3 * _approx_bytes((C, th, th))
2100
+ B_tile = max(1, (budget - global_accum) // max(per_tile, 1))
2101
+ if B_tile >= 1:
2102
+ return dict(batch_frames=int(B_tile), tiles=(t, t), halo=int(halo), ksize=int(ksize))
2103
+ t //= 2
2104
+
2105
+ # Worst case: 1 frame, minimal tile
2106
+ return dict(batch_frames=1, tiles=(min_tile, min_tile), halo=int(halo), ksize=int(ksize))
2107
+
2108
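The memory model boils down to simple byte accounting: a few float32 working copies per frame plus two global accumulators measured against a fixed budget. A back-of-envelope sketch with assumed frame dimensions and a 4 GiB target (illustrative numbers only, not the package function):

```python
import numpy as np

# How many frames fit the budget if each frame needs ~3 float32 working copies
# and the accumulators (num + den) need 2 more. All sizes below are assumptions.
C, H, W = 3, 4096, 4096
bytes_per_image = C * H * W * np.dtype(np.float32).itemsize   # ≈ 192 MiB
per_frame_scratch = 3 * bytes_per_image
global_accum = 2 * bytes_per_image
budget = 4096 * 1024 * 1024                                   # 4 GiB target

batch_frames = max(1, (budget - global_accum) // per_frame_scratch)
print(bytes_per_image // 2**20, "MiB/image ->", batch_frames, "frames per batch")
```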
+ def _build_seed_running_mu_sigma_from_paths(
2109
+ paths, Ht, Wt, color_mode,
2110
+ *, bootstrap_frames=24, clip_sigma=3.5, # clip_sigma used for streaming updates
2111
+ status_cb=lambda s: None, progress_cb=None
2112
+ ):
2113
+ """
2114
+ Seed:
2115
+ 1) Load first B frames -> mean0
2116
+ 2) MAD around mean0 -> ±4·MAD mask -> masked-mean seed (one mini-iteration)
2117
+ 3) Stream remaining frames with σ-clipped Welford updates (unchanged behavior)
2118
+ Returns float32 image in (H,W) or (C,H,W) matching color_mode.
2119
+ """
2120
+ def p(frac, msg):
2121
+ if progress_cb:
2122
+ progress_cb(float(max(0.0, min(1.0, frac))), msg)
2123
+
2124
+ n_total = len(paths)
2125
+ B = int(max(1, min(int(bootstrap_frames), n_total)))
2126
+ status_cb(f"MFDeconv: Seed bootstrap {B} frame(s) with ±4·MAD clip on the average…")
2127
+ p(0.00, f"bootstrap load 0/{B}")
2128
+
2129
+ # ---------- load first B frames ----------
2130
+ boot = []
2131
+ for i, pth in enumerate(paths[:B], start=1):
2132
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2133
+ boot.append(ys[0].astype(np.float32, copy=False))
2134
+ if (i == B) or (i % 4 == 0):
2135
+ p(0.25 * (i / float(B)), f"bootstrap load {i}/{B}")
2136
+
2137
+ stack = np.stack(boot, axis=0) # (B,H,W) or (B,C,H,W)
2138
+ del boot
2139
+
2140
+ # ---------- mean0 ----------
2141
+ mean0 = np.mean(stack, axis=0, dtype=np.float32)
2142
+ p(0.28, "bootstrap mean computed")
2143
+
2144
+ # ---------- ±4·MAD clip around mean0, then masked mean (one pass) ----------
2145
+ # MAD per-pixel: median(|x - mean0|)
2146
+ abs_dev = np.abs(stack - mean0[None, ...])
2147
+ mad = np.median(abs_dev, axis=0).astype(np.float32, copy=False)
2148
+
2149
+ thr = 4.0 * mad + EPS
2150
+ mask = (abs_dev <= thr)
2151
+
2152
+ # masked mean with fallback to mean0 where all rejected
2153
+ m = mask.astype(np.float32, copy=False)
2154
+ sum_acc = np.sum(stack * m, axis=0, dtype=np.float32)
2155
+ cnt_acc = np.sum(m, axis=0, dtype=np.float32)
2156
+ seed = mean0.copy()
2157
+ np.divide(sum_acc, np.maximum(cnt_acc, 1.0), out=seed, where=(cnt_acc > 0.5))
2158
+ p(0.36, "±4·MAD masked mean computed")
2159
+
2160
+ # ---------- initialize Welford state from seed ----------
2161
+ # Start μ=seed, set an initial variance envelope from the bootstrap dispersion
2162
+ dif = stack - seed[None, ...]
2163
+ M2 = np.sum(dif * dif, axis=0, dtype=np.float32)
2164
+ cnt = np.full_like(seed, float(B), dtype=np.float32)
2165
+ mu = seed.astype(np.float32, copy=False)
2166
+ del stack, abs_dev, mad, m, sum_acc, cnt_acc, dif
2167
+
2168
+ p(0.40, "seed initialized; streaming refinements…")
2169
+
2170
+ # ---------- stream remaining frames with σ-clipped Welford updates ----------
2171
+ remain = n_total - B
2172
+ if remain > 0:
2173
+ status_cb(f"MFDeconv: Seed μ–σ clipping {remain} remaining frame(s) (k={clip_sigma:.2f})…")
2174
+
2175
+ k = float(clip_sigma)
2176
+ for j, pth in enumerate(paths[B:], start=1):
2177
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2178
+ x = ys[0].astype(np.float32, copy=False)
2179
+
2180
+ var = M2 / np.maximum(cnt - 1.0, 1.0)
2181
+ sigma = np.sqrt(np.maximum(var, 1e-12, dtype=np.float32))
2182
+
2183
+ accept = (np.abs(x - mu) <= (k * sigma))
2184
+ acc = accept.astype(np.float32, copy=False)
2185
+
2186
+ n_new = cnt + acc
2187
+ delta = x - mu
2188
+ mu_n = mu + (acc * delta) / np.maximum(n_new, 1.0)
2189
+ M2 = M2 + acc * delta * (x - mu_n)
2190
+
2191
+ mu, cnt = mu_n, n_new
2192
+
2193
+ if (j == remain) or (j % 8 == 0):
2194
+ p(0.40 + 0.60 * (j / float(remain)), f"μ–σ refine {j}/{remain}")
2195
+
2196
+ p(1.0, "seed ready")
2197
+ return np.clip(mu, 0.0, None).astype(np.float32, copy=False)
2198
+
2199
+
2200
+ def _chunk(seq, n):
2201
+ """Yield chunks of size n from seq."""
2202
+ for i in range(0, len(seq), n):
2203
+ yield seq[i:i+n]
2204
+
2205
+ def _read_shape_fast(path) -> tuple[int,int,int]:
2206
+ if _is_xisf(path):
2207
+ a, _ = _load_image_array(path)
2208
+ if a is None:
2209
+ raise ValueError(f"No data in {path}")
2210
+ a = np.asarray(a)
2211
+ else:
2212
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
2213
+ a = hdul[0].data
2214
+ if a is None:
2215
+ raise ValueError(f"No data in {path}")
2216
+
2217
+ # common logic for both XISF and FITS
2218
+ if a.ndim == 2:
2219
+ H, W = a.shape
2220
+ return (1, int(H), int(W))
2221
+ if a.ndim == 3:
2222
+ if a.shape[-1] in (1, 3): # HWC
2223
+ C = int(a.shape[-1]); H = int(a.shape[0]); W = int(a.shape[1])
2224
+ return (1 if C == 1 else 3, H, W)
2225
+ if a.shape[0] in (1, 3): # CHW
2226
+ return (int(a.shape[0]), int(a.shape[1]), int(a.shape[2]))
2227
+ s = tuple(map(int, a.shape))
2228
+ H, W = s[-2], s[-1]
2229
+ return (1, H, W)
2230
+
2231
+
2232
+
2233
+
2234
+ def _tiles_of(hw: tuple[int,int], tile_hw: tuple[int,int], halo: int):
2235
+ """
2236
+ Yield tiles as dicts: {y0,y1,x0,x1,yc0,yc1,xc0,xc1}
2237
+ (outer region includes halo; core (yc0:yc1, xc0:xc1) excludes halo).
2238
+ """
2239
+ H, W = hw
2240
+ th, tw = tile_hw
2241
+ th = max(1, int(th)); tw = max(1, int(tw))
2242
+ for y in range(0, H, th):
2243
+ for x in range(0, W, tw):
2244
+ yc0 = y; yc1 = min(y + th, H)
2245
+ xc0 = x; xc1 = min(x + tw, W)
2246
+ y0 = max(0, yc0 - halo); y1 = min(H, yc1 + halo)
2247
+ x0 = max(0, xc0 - halo); x1 = min(W, xc1 + halo)
2248
+ yield dict(y0=y0, y1=y1, x0=x0, x1=x1, yc0=yc0, yc1=yc1, xc0=xc0, xc1=xc1)
2249
+
2250
+ def _extract_with_halo(a, tile):
2251
+ """
2252
+ Slice 'a' ((H,W) or (C,H,W)) to [y0:y1, x0:x1] with channel kept.
2253
+ """
2254
+ y0,y1,x0,x1 = tile["y0"], tile["y1"], tile["x0"], tile["x1"]
2255
+ if a.ndim == 2:
2256
+ return a[y0:y1, x0:x1]
2257
+ else:
2258
+ return a[:, y0:y1, x0:x1]
2259
+
2260
+ def _add_core(accum, tile_val, tile):
2261
+ """
2262
+ Add tile_val core into accum at (yc0:yc1, xc0:xc1).
2263
+ Shapes match (2D) or (C,H,W).
2264
+ """
2265
+ yc0,yc1,xc0,xc1 = tile["yc0"], tile["yc1"], tile["xc0"], tile["xc1"]
2266
+ if accum.ndim == 2:
2267
+ h0 = yc0 - tile["y0"]; h1 = h0 + (yc1 - yc0)
2268
+ w0 = xc0 - tile["x0"]; w1 = w0 + (xc1 - xc0)
2269
+ accum[yc0:yc1, xc0:xc1] += tile_val[h0:h1, w0:w1]
2270
+ else:
2271
+ h0 = yc0 - tile["y0"]; h1 = h0 + (yc1 - yc0)
2272
+ w0 = xc0 - tile["x0"]; w1 = w0 + (xc1 - xc0)
2273
+ accum[:, yc0:yc1, xc0:xc1] += tile_val[:, h0:h1, w0:w1]
2274
+
2275
+ def _prepare_np_fft_packs_batch(psfs, flip_psf, Hs, Ws):
2276
+ """Precompute rFFT packs on current grid for NumPy path; returns lists aligned to batch psfs."""
2277
+ Kfs, KTfs, meta = [], [], []
2278
+ import numpy.fft as fft
2279
+ for k, kT in zip(psfs, flip_psf):
2280
+ kh, kw = k.shape
2281
+ fftH, fftW = _fftshape_same(Hs, Ws, kh, kw)
2282
+ Kfs.append(fft.rfftn(np.fft.ifftshift(k), s=(fftH, fftW)))
2283
+ KTfs.append(fft.rfftn(np.fft.ifftshift(kT), s=(fftH, fftW)))
2284
+ meta.append((kh, kw, fftH, fftW))
2285
+ return Kfs, KTfs, meta
2286
+
2287
+ def _prepare_torch_fft_packs_batch(psfs, flip_psf, Hs, Ws, device, dtype):
2288
+ """Torch FFT packs per PSF on current grid; mirrors your existing packer."""
2289
+ return _precompute_torch_psf_ffts(psfs, flip_psf, Hs, Ws, device, dtype)
2290
+
2291
+ def _as_chw(np_img: np.ndarray) -> np.ndarray:
2292
+ x = np.asarray(np_img, dtype=np.float32, order="C")
2293
+ if x.size == 0:
2294
+ raise RuntimeError(f"Empty image array after load; raw shape={np_img.shape}")
2295
+ if x.ndim == 2:
2296
+ return x[None, ...] # 1,H,W
2297
+ if x.ndim == 3 and x.shape[0] in (1, 3):
2298
+ if x.shape[0] == 0:
2299
+ raise RuntimeError(f"Zero channels in CHW array; shape={x.shape}")
2300
+ return x
2301
+ if x.ndim == 3 and x.shape[-1] in (1, 3):
2302
+ if x.shape[-1] == 0:
2303
+ raise RuntimeError(f"Zero channels in HWC array; shape={x.shape}")
2304
+ return np.moveaxis(x, -1, 0)
2305
+ # last resort: treat first dim as channels, but reject zero
2306
+ if x.shape[0] == 0:
2307
+ raise RuntimeError(f"Zero channels in array; shape={x.shape}")
2308
+ return x
2309
+
2310
+
2311
+
2312
+ def _conv_same_np_spatial(a: np.ndarray, k: np.ndarray, out: np.ndarray | None = None):
2313
+ try:
2314
+ import cv2
2315
+ except Exception:
2316
+ return None # no opencv -> caller falls back to FFT
2317
+
2318
+ # cv2 wants HxW single-channel float32
2319
+ kf = np.ascontiguousarray(k.astype(np.float32))
2320
+ kf = np.flip(np.flip(kf, 0), 1) # OpenCV uses correlation; flip to emulate conv
2321
+
2322
+ if a.ndim == 2:
2323
+ y = cv2.filter2D(a, -1, kf, borderType=cv2.BORDER_REFLECT)
2324
+ if out is None: return y
2325
+ out[...] = y; return out
2326
+ else:
2327
+ C, H, W = a.shape
2328
+ if out is None:
2329
+ out = np.empty_like(a)
2330
+ for c in range(C):
2331
+ out[c] = cv2.filter2D(a[c], -1, kf, borderType=cv2.BORDER_REFLECT)
2332
+ return out
2333
+
2334
+ def _grouped_conv_same_torch_per_sample(x_bc_hw, w_b1kk, B, C):
2335
+ F = torch.nn.functional
2336
+ x_bc_hw = x_bc_hw.to(memory_format=torch.contiguous_format).contiguous()
2337
+ w_b1kk = w_b1kk.to(memory_format=torch.contiguous_format).contiguous()
2338
+
2339
+ kh, kw = int(w_b1kk.shape[-2]), int(w_b1kk.shape[-1])
2340
+ pad = (kw // 2, kw - kw // 2 - 1, kh // 2, kh - kh // 2 - 1)
2341
+
2342
+ # unified path (CUDA/CPU/MPS): one grouped conv with G=B*C
2343
+ G = int(B * C)
2344
+ x_1ghw = x_bc_hw.reshape(1, G, x_bc_hw.shape[-2], x_bc_hw.shape[-1])
2345
+ x_1ghw = F.pad(x_1ghw, pad, mode="reflect")
2346
+ w_g1kk = w_b1kk.repeat_interleave(C, dim=0) # (G,1,kh,kw)
2347
+ y_1ghw = F.conv2d(x_1ghw, w_g1kk, padding=0, groups=G)
2348
+ return y_1ghw.reshape(B, C, y_1ghw.shape[-2], y_1ghw.shape[-1]).contiguous()
2349
+
2350
+
2351
+ # put near other small helpers
2352
+ def _robust_med_mad_t(x, max_elems_per_sample: int = 2_000_000):
2353
+ """
2354
+ x: (B, C, H, W) tensor on device.
2355
+ Returns (median[B,1,1,1], mad[B,1,1,1]) computed on a strided subsample
2356
+ to avoid 'quantile() input tensor is too large'.
2357
+ """
2358
+ import math
2359
+ import torch
2360
+ B = x.shape[0]
2361
+ flat = x.reshape(B, -1)
2362
+ N = flat.shape[1]
2363
+ if N > max_elems_per_sample:
2364
+ stride = int(math.ceil(N / float(max_elems_per_sample)))
2365
+ flat = flat[:, ::stride] # strided subsample
2366
+ med = torch.quantile(flat, 0.5, dim=1, keepdim=True)
2367
+ mad = torch.quantile((flat - med).abs(), 0.5, dim=1, keepdim=True) + 1e-6
2368
+ return med.view(B,1,1,1), mad.view(B,1,1,1)
2369
+
2370
+ def _torch_should_use_spatial(psf_ksize: int) -> bool:
2371
+ # Prefer spatial on non-CUDA backends and for modest kernels.
2372
+ try:
2373
+ dev = _torch_device()
2374
+ if dev.type in ("mps", "privateuseone"): # privateuseone = DirectML
2375
+ return True
2376
+ if dev.type == "cuda":
2377
+ return psf_ksize <= 51 # typical PSF sizes; spatial is fast & stable
2378
+ except Exception:
2379
+ pass
2380
+ # Allow override via env
2381
+ import os as _os
2382
+ if _os.environ.get("MF_SPATIAL", "") == "1":
2383
+ return True
2384
+ return False
2385
+
2386
+ def _read_tile_fits(path: str, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
2387
+ """Return a (H,W) or (H,W,3|1) tile via FITS memmap, without loading whole image."""
2388
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
2389
+ hdu = hdul[0]
2390
+ a = hdu.data
2391
+ if a is None:
2392
+ # find first image HDU if primary is header-only
2393
+ for h in hdul[1:]:
2394
+ if getattr(h, "data", None) is not None:
2395
+ a = h.data; break
2396
+ a = np.asarray(a) # still lazy until sliced
2397
+ # squeeze trailing singleton if present to keep your conventions
2398
+ if a.ndim == 3 and a.shape[-1] == 1:
2399
+ a = np.squeeze(a, axis=-1)
2400
+ tile = a[y0:y1, x0:x1, ...]
2401
+ # copy so we own the buffer (we will cast/normalize)
2402
+ return np.array(tile, copy=True)
2403
+
2404
+ def _read_tile_fits_any(path: str, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
2405
+ """FITS/XISF-aware tile read: returns spatial tile; supports 2D, HWC, and CHW."""
2406
+ ext = os.path.splitext(path)[1].lower()
2407
+
2408
+ if ext == ".xisf":
2409
+ a = _xisf_cached_array(path) # float32 memmap; cheap slicing
2410
+ # a is HW, HWC, or CHW (whatever _load_image_array returned)
2411
+ if a.ndim == 2:
2412
+ return np.array(a[y0:y1, x0:x1], copy=True)
2413
+ elif a.ndim == 3:
2414
+ if a.shape[-1] in (1, 3): # HWC
2415
+ out = a[y0:y1, x0:x1, :]
2416
+ if out.shape[-1] == 1: out = out[..., 0]
2417
+ return np.array(out, copy=True)
2418
+ elif a.shape[0] in (1, 3): # CHW
2419
+ out = a[:, y0:y1, x0:x1]
2420
+ if out.shape[0] == 1: out = out[0]
2421
+ return np.array(out, copy=True)
2422
+ else:
2423
+ raise ValueError(f"Unsupported XISF 3D shape {a.shape} in {path}")
2424
+ else:
2425
+ raise ValueError(f"Unsupported XISF ndim {a.ndim} in {path}")
2426
+
2427
+ # FITS
2428
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
2429
+ a = None
2430
+ for h in hdul:
2431
+ if getattr(h, "data", None) is not None:
2432
+ a = h.data
2433
+ break
2434
+ if a is None:
2435
+ raise ValueError(f"No image data in {path}")
2436
+
2437
+ a = np.asarray(a)
2438
+
2439
+ if a.ndim == 2: # HW
2440
+ return np.array(a[y0:y1, x0:x1], copy=True)
2441
+
2442
+ if a.ndim == 3:
2443
+ if a.shape[0] in (1, 3): # CHW (planes, rows, cols)
2444
+ out = a[:, y0:y1, x0:x1]
2445
+ if out.shape[0] == 1:
2446
+ out = out[0]
2447
+ return np.array(out, copy=True)
2448
+ if a.shape[-1] in (1, 3): # HWC
2449
+ out = a[y0:y1, x0:x1, :]
2450
+ if out.shape[-1] == 1:
2451
+ out = out[..., 0]
2452
+ return np.array(out, copy=True)
2453
+
2454
+ # Fallback: assume last two axes are spatial (…, H, W)
2455
+ try:
2456
+ out = a[(..., slice(y0, y1), slice(x0, x1))]
2457
+ return np.array(out, copy=True)
2458
+ except Exception:
2459
+ raise ValueError(f"Unsupported FITS data shape {a.shape} in {path}")
2460
+
2461
+ def _infer_channels_from_tile(p: str, Ht: int, Wt: int) -> int:
2462
+ """Look at a 1×1 tile to infer channel count; supports HW, HWC, CHW."""
2463
+ y1 = min(1, Ht); x1 = min(1, Wt)
2464
+ t = _read_tile_fits_any(p, 0, y1, 0, x1)
2465
+
2466
+ if t.ndim == 2:
2467
+ return 1
2468
+
2469
+ if t.ndim == 3:
2470
+ # Prefer the axis that actually carries the color planes
2471
+ ch_first = t.shape[0] in (1, 3)
2472
+ ch_last = t.shape[-1] in (1, 3)
2473
+
2474
+ if ch_first and not ch_last:
2475
+ return int(t.shape[0])
2476
+ if ch_last and not ch_first:
2477
+ return int(t.shape[-1])
2478
+
2479
+ # Ambiguous tiny tile (e.g. CHW 3×1×1 or HWC 1×1×3):
2480
+ if t.shape[0] == 3 or t.shape[-1] == 3:
2481
+ return 3
2482
+ return 1
2483
+
2484
+ return 1
2485
+
2486
+
2487
+
2488
+ def _seed_median_streaming(
2489
+ paths,
2490
+ Ht,
2491
+ Wt,
2492
+ *,
2493
+ color_mode="luma",
2494
+ tile_hw=(256, 256),
2495
+ status_cb=lambda s: None,
2496
+ progress_cb=lambda f, m="": None,
2497
+ use_torch: bool | None = None, # auto by default
2498
+ ):
2499
+ """
2500
+ Exact per-pixel median via tiling; RAM-bounded.
2501
+ Now shows per-tile progress and uses Torch on GPU if available.
2502
+ Parallelizes per-tile slab reads to hide I/O and luma work.
2503
+ """
2504
+
2505
+ th, tw = int(tile_hw[0]), int(tile_hw[1])
2506
+ # old: want_c = 1 if str(color_mode).lower() == "luma" else (3 if _read_shape_fast(paths[0])[0] == 3 else 1)
2507
+ if str(color_mode).lower() == "luma":
2508
+ want_c = 1
2509
+ else:
2510
+ want_c = _infer_channels_from_tile(paths[0], Ht, Wt)
2511
+ seed = np.zeros((Ht, Wt), np.float32) if want_c == 1 else np.zeros((want_c, Ht, Wt), np.float32)
2512
+ tiles = [(y, min(y + th, Ht), x, min(x + tw, Wt)) for y in range(0, Ht, th) for x in range(0, Wt, tw)]
2513
+ total = len(tiles)
2514
+ n_frames = len(paths)
2515
+
2516
+ # Choose a sensible number of I/O workers (bounded by frames)
2517
+ try:
2518
+ _cpu = (os.cpu_count() or 4)
2519
+ except Exception:
2520
+ _cpu = 4
2521
+ io_workers = max(1, min(8, _cpu, n_frames))
2522
+
2523
+ # Torch autodetect (once)
2524
+ TORCH_OK = False
2525
+ device = None
2526
+ if use_torch is not False:
2527
+ try:
2528
+ from setiastro.saspro.runtime_torch import import_torch
2529
+ _t = import_torch(prefer_cuda=True, status_cb=status_cb)
2530
+ dev = None
2531
+ if hasattr(_t, "cuda") and _t.cuda.is_available():
2532
+ dev = _t.device("cuda")
2533
+ elif hasattr(_t.backends, "mps") and _t.backends.mps.is_available():
2534
+ dev = _t.device("mps")
2535
+ else:
2536
+ dev = None # CPU tensors slower than NumPy for median; only use if forced
2537
+ if dev is not None:
2538
+ TORCH_OK = True
2539
+ device = dev
2540
+ status_cb(f"Median seed: using Torch device {device}")
2541
+ except Exception as e:
2542
+ status_cb(f"Median seed: Torch unavailable → NumPy fallback ({e})")
2543
+ TORCH_OK = False
2544
+ device = None
2545
+
2546
+ def _tile_msg(ti, tn):
2547
+ return f"median tiles {ti}/{tn}"
2548
+
2549
+ done = 0
2550
+
2551
+ for (y0, y1, x0, x1) in tiles:
2552
+ h, w = (y1 - y0), (x1 - x0)
2553
+
2554
+ # per-tile slab reader with incremental progress (parallel)
2555
+ def _read_slab_for_channel(csel=None):
2556
+ """
2557
+ Returns slab of shape (N, h, w) float32 in [0,1]-ish (normalized if input was integer).
2558
+ If csel is None and luma is requested, computes luma.
2559
+ """
2560
+ # parallel worker returns (i, tile2d)
2561
+ def _load_one(i):
2562
+ t = _read_tile_fits_any(paths[i], y0, y1, x0, x1)
2563
+
2564
+ # normalize dtype
2565
+ if t.dtype.kind in "ui":
2566
+ t = t.astype(np.float32) / (float(np.iinfo(t.dtype).max) or 1.0)
2567
+ else:
2568
+ t = t.astype(np.float32, copy=False)
2569
+
2570
+ # luma / channel selection
2571
+ if want_c == 1:
2572
+ if t.ndim == 3:
2573
+ t = _to_luma_local(t)
2574
+ elif t.ndim != 2:
2575
+ t = _to_luma_local(t)
2576
+ else:
2577
+ if t.ndim == 2:
2578
+ pass
2579
+ elif t.ndim == 3 and t.shape[-1] == 3: # HWC
2580
+ t = t[..., csel]
2581
+ elif t.ndim == 3 and t.shape[0] == 3: # CHW
2582
+ t = t[csel]
2583
+ else:
2584
+ t = _to_luma_local(t)
2585
+ return i, np.ascontiguousarray(t, dtype=np.float32)
2586
+
2587
+ slab = np.empty((n_frames, h, w), np.float32)
2588
+ done_local = 0
2589
+ # cap workers by n_frames so we don't spawn useless threads
2590
+ with ThreadPoolExecutor(max_workers=min(io_workers, n_frames)) as ex:
2591
+ futures = [ex.submit(_load_one, i) for i in range(n_frames)]
2592
+ for fut in as_completed(futures):
2593
+ i, t2d = fut.result()
2594
+ # quick sanity (avoid silent mis-shapes)
2595
+ if t2d.shape != (h, w):
2596
+ raise RuntimeError(
2597
+ f"Tile read mismatch at frame {i}: got {t2d.shape}, expected {(h, w)} "
2598
+ f"tile={(y0,y1,x0,x1)}"
2599
+ )
2600
+ slab[i] = t2d
2601
+ done_local += 1
2602
+ if (done_local & 7) == 0 or done_local == n_frames:
2603
+ tile_base = done / total
2604
+ tile_span = 1.0 / total
2605
+ inner = done_local / n_frames
2606
+ progress_cb(tile_base + 0.8 * tile_span * inner, _tile_msg(done + 1, total))
2607
+ return slab
2608
+
2609
+ try:
2610
+ if want_c == 1:
2611
+ t0 = time.perf_counter()
2612
+ slab = _read_slab_for_channel()
2613
+ t1 = time.perf_counter()
2614
+ if TORCH_OK:
2615
+ import torch as _t
2616
+ slab_t = _t.as_tensor(slab, device=device, dtype=_t.float32) # one H2D
2617
+ med_t = slab_t.median(dim=0).values
2618
+ med_np = med_t.detach().cpu().numpy().astype(np.float32, copy=False)
2619
+ # no per-tile empty_cache() — avoid forced syncs
2620
+ else:
2621
+ med_np = np.median(slab, axis=0).astype(np.float32, copy=False)
2622
+ t2 = time.perf_counter()
2623
+ seed[y0:y1, x0:x1] = med_np
2624
+ # lightweight telemetry to confirm bottleneck
2625
+ status_cb(f"seed tile {y0}:{y1},{x0}:{x1} I/O={t1-t0:.3f}s median={'GPU' if TORCH_OK else 'CPU'}={t2-t1:.3f}s")
2626
+ else:
2627
+ for c in range(want_c):
2628
+ slab = _read_slab_for_channel(csel=c)
2629
+ if TORCH_OK:
2630
+ import torch as _t
2631
+ slab_t = _t.as_tensor(slab, device=device, dtype=_t.float32)
2632
+ med_t = slab_t.median(dim=0).values
2633
+ med_np = med_t.detach().cpu().numpy().astype(np.float32, copy=False)
2634
+ else:
2635
+ med_np = np.median(slab, axis=0).astype(np.float32, copy=False)
2636
+ seed[c, y0:y1, x0:x1] = med_np
2637
+
2638
+ except RuntimeError as e:
2639
+ # Per-tile GPU OOM or device issues → fallback to NumPy for this tile
2640
+ msg = str(e).lower()
2641
+ if TORCH_OK and ("out of memory" in msg or "resource" in msg or "alloc" in msg):
2642
+ status_cb(f"Median seed: GPU OOM on tile ({h}x{w}); falling back to NumPy for this tile.")
2643
+ if want_c == 1:
2644
+ slab = _read_slab_for_channel()
2645
+ seed[y0:y1, x0:x1] = np.median(slab, axis=0).astype(np.float32, copy=False)
2646
+ else:
2647
+ for c in range(want_c):
2648
+ slab = _read_slab_for_channel(csel=c)
2649
+ seed[c, y0:y1, x0:x1] = np.median(slab, axis=0).astype(np.float32, copy=False)
2650
+ else:
2651
+ raise
2652
+
2653
+ done += 1
2654
+ # report tile completion (remaining 20% of tile span reserved for median compute)
2655
+ progress_cb(done / total, _tile_msg(done, total))
2656
+
2657
+ if (done & 3) == 0:
2658
+ _process_gui_events_safely()
2659
+ status_cb(f"Median seed: want_c={want_c}, seed_shape={seed.shape}")
2660
+ return seed
2661
+
2662
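
For reference, a minimal in-memory sketch of the tiled-median idea used by _seed_median_streaming above; the real helper streams each tile from disk, handles channel selection, and can route the median through Torch. The function name and the assumption that all frames fit in memory are illustrative only.

import numpy as np

def tiled_median(frames, tile=256):
    """Median of a stack, computed one tile at a time (frames: equal-size 2-D float32 arrays)."""
    H, W = frames[0].shape
    out = np.empty((H, W), np.float32)
    for y0 in range(0, H, tile):
        for x0 in range(0, W, tile):
            y1, x1 = min(y0 + tile, H), min(x0 + tile, W)
            # gather the (N, h, w) slab for this tile, then reduce along the frame axis
            slab = np.stack([f[y0:y1, x0:x1] for f in frames], axis=0)
            out[y0:y1, x0:x1] = np.median(slab, axis=0)
    return out
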
+ def _seed_bootstrap_streaming(paths, Ht, Wt, color_mode,
2663
+ bootstrap_frames: int = 20,
2664
+ clip_sigma: float = 5.0,
2665
+ status_cb=lambda s: None,
2666
+ progress_cb=None):
2667
+ """
2668
+ Seed = average first B frames, estimate global MAD threshold, masked-mean on B,
2669
+ then stream the remaining frames with σ-clipped Welford μ–σ updates.
2670
+
2671
+ Returns float32: CHW for per-channel mode, or (1, H, W) for luma (the caller squeezes it later if needed).
2672
+ """
2673
+ def p(frac, msg):
2674
+ if progress_cb:
2675
+ progress_cb(float(max(0.0, min(1.0, frac))), msg)
2676
+
2677
+ n = len(paths)
2678
+ B = int(max(1, min(int(bootstrap_frames), n)))
2679
+ status_cb(f"Seed: bootstrap={B}, clip_sigma={clip_sigma}")
2680
+
2681
+ # ---------- pass 1: running mean over the first B frames (Welford μ only) ----------
2682
+ cnt = None
2683
+ mu = None
2684
+ for i, pth in enumerate(paths[:B], 1):
2685
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2686
+ x = ys[0].astype(np.float32, copy=False)
2687
+ if mu is None:
2688
+ mu = x.copy()
2689
+ cnt = np.ones_like(x, dtype=np.float32)
2690
+ else:
2691
+ delta = x - mu
2692
+ cnt += 1.0
2693
+ mu += delta / cnt
2694
+ if (i == B) or (i % 4) == 0:
2695
+ p(0.20 * (i / float(B)), f"bootstrap mean {i}/{B}")
2696
+
2697
+ # ---------- pass 2: estimate a *global* MAD around μ using strided samples ----------
2698
+ stride = 4 if max(Ht, Wt) > 1500 else 2
2699
+ # CHW or HW → build a slice that keeps C and strides H,W
2700
+ if mu.ndim == 3:
2701
+ samp_slices = (slice(None), slice(None, None, stride), slice(None, None, stride))
2702
+ else:
2703
+ samp_slices = (slice(None, None, stride), slice(None, None, stride))
2704
+
2705
+ mad_samples = []
2706
+ for i, pth in enumerate(paths[:B], 1):
2707
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2708
+ x = ys[0].astype(np.float32, copy=False)
2709
+ d = np.abs(x - mu)
2710
+ mad_samples.append(d[samp_slices].ravel())
2711
+ if (i == B) or (i % 8) == 0:
2712
+ p(0.35 + 0.10 * (i / float(B)), f"bootstrap MAD {i}/{B}")
2713
+
2714
+ # robust MAD estimate (scalar) → 4·MAD clip band
2715
+ mad_est = float(np.median(np.concatenate(mad_samples).astype(np.float32)))
2716
+ thr = 4.0 * max(mad_est, 1e-6)
2717
+
2718
+ # ---------- pass 3: masked mean over first B frames using the global threshold ----------
2719
+ sum_acc = np.zeros_like(mu, dtype=np.float32)
2720
+ cnt_acc = np.zeros_like(mu, dtype=np.float32)
2721
+ for i, pth in enumerate(paths[:B], 1):
2722
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2723
+ x = ys[0].astype(np.float32, copy=False)
2724
+ m = (np.abs(x - mu) <= thr).astype(np.float32, copy=False)
2725
+ sum_acc += x * m
2726
+ cnt_acc += m
2727
+ if (i == B) or (i % 4) == 0:
2728
+ p(0.48 + 0.12 * (i / float(B)), f"masked mean {i}/{B}")
2729
+
2730
+ seed = mu.copy()
2731
+ np.divide(sum_acc, np.maximum(cnt_acc, 1.0), out=seed, where=(cnt_acc > 0.5))
2732
+
2733
+ # ---------- pass 4: μ–σ streaming on the remaining frames (σ-clipped Welford) ----------
2734
+ M2 = np.zeros_like(seed, dtype=np.float32) # sum of squared diffs
2735
+ cnt = np.full_like(seed, float(B), dtype=np.float32)
2736
+ mu = seed.astype(np.float32, copy=False)
2737
+
2738
+ remain = n - B
2739
+ k = float(clip_sigma)
2740
+ for j, pth in enumerate(paths[B:], 1):
2741
+ ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
2742
+ x = ys[0].astype(np.float32, copy=False)
2743
+
2744
+ var = M2 / np.maximum(cnt - 1.0, 1.0)
2745
+ sigma = np.sqrt(np.maximum(var, 1e-12, dtype=np.float32))
2746
+
2747
+ acc = (np.abs(x - mu) <= (k * sigma)).astype(np.float32, copy=False) # {0,1} mask
2748
+ n_new = cnt + acc
2749
+ delta = x - mu
2750
+ mu_n = mu + (acc * delta) / np.maximum(n_new, 1.0)
2751
+ M2 = M2 + acc * delta * (x - mu_n)
2752
+
2753
+ mu, cnt = mu_n, n_new
2754
+
2755
+ if (j == remain) or (j % 8) == 0:
2756
+ p(0.60 + 0.40 * (j / float(max(1, remain))), f"μ–σ refine {j}/{remain}")
2757
+
2758
+ return np.clip(mu, 0.0, None).astype(np.float32, copy=False)
2759
+
2760
+
2761
+ def _coerce_sr_factor(srf, *, default_on_bad=2):
2762
+ """
2763
+ Parse super-res factor robustly:
2764
+ - accepts 2, '2', '2x', ' 2 X ', 2.0
2765
+ - clamps to integers >= 1
2766
+ - if invalid/missing → returns default_on_bad (we want 2 by your request)
2767
+ - if invalid/missing → returns default_on_bad (2 by default)
2768
+ if srf is None:
2769
+ return int(default_on_bad)
2770
+ if isinstance(srf, (float, int)):
2771
+ r = int(round(float(srf)))
2772
+ return int(r if r >= 1 else default_on_bad)
2773
+ s = str(srf).strip().lower()
2774
+ # common GUIs pass e.g. "2x", "3×", etc.
2775
+ s = s.replace("×", "x")
2776
+ if s.endswith("x"):
2777
+ s = s[:-1]
2778
+ try:
2779
+ r = int(round(float(s)))
2780
+ return int(r if r >= 1 else default_on_bad)
2781
+ except Exception:
2782
+ return int(default_on_bad)
2783
+
2784
+ def _pad_kernel_to(k: np.ndarray, K: int) -> np.ndarray:
2785
+ """Pad/center an odd-sized kernel to K×K (K odd)."""
2786
+ k = np.asarray(k, dtype=np.float32)
2787
+ kh, kw = int(k.shape[0]), int(k.shape[1])
2788
+ assert (kh % 2 == 1) and (kw % 2 == 1)
2789
+ if kh == K and kw == K:
2790
+ return k
2791
+ out = np.zeros((K, K), dtype=np.float32)
2792
+ y0 = (K - kh)//2; x0 = (K - kw)//2
2793
+ out[y0:y0+kh, x0:x0+kw] = k
2794
+ s = float(out.sum())
2795
+ return out if s <= 0 else (out / s).astype(np.float32, copy=False)
2796
+
2797
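
A short usage illustration for the helper above (assuming numpy is imported as np, as the surrounding module does): padding a normalized 3×3 kernel into a 5×5 canvas keeps it centered and normalized.

k3 = np.full((3, 3), 1.0 / 9.0, dtype=np.float32)   # normalized 3x3 box kernel
k5 = _pad_kernel_to(k3, 5)                           # zero border, kernel stays centered
assert k5.shape == (5, 5) and abs(float(k5.sum()) - 1.0) < 1e-6
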
+ # -----------------------------
2798
+ # Core
2799
+ # -----------------------------
2800
+ def multiframe_deconv(
2801
+ paths,
2802
+ out_path,
2803
+ iters=20,
2804
+ kappa=2.0,
2805
+ color_mode="luma",
2806
+ seed_mode: str = "robust",
2807
+ huber_delta=0.0,
2808
+ masks=None,
2809
+ variances=None,
2810
+ rho="huber",
2811
+ status_cb=lambda s: None,
2812
+ min_iters: int = 3,
2813
+ use_star_masks: bool = False,
2814
+ use_variance_maps: bool = False,
2815
+ star_mask_cfg: dict | None = None,
2816
+ varmap_cfg: dict | None = None,
2817
+ save_intermediate: bool = False,
2818
+ save_every: int = 1,
2819
+ # SR options
2820
+ super_res_factor: int = 1,
2821
+ sr_sigma: float = 1.1,
2822
+ sr_psf_opt_iters: int = 250,
2823
+ sr_psf_opt_lr: float = 0.1,
2824
+ # NEW
2825
+ batch_frames: int | None = None,
2826
+ # GPU tuning (optional knobs)
2827
+ mixed_precision: bool | None = None, # default: auto (True on CUDA/MPS)
2828
+ fft_kernel_threshold: int = 1024, # switch to FFT if K >= this (or lower if SR)
2829
+ prefetch_batches: bool = True, # CPU→GPU double-buffer prefetch
2830
+ use_channels_last: bool | None = None, # default: auto (True on CUDA/MPS)
2831
+ force_cpu: bool = False,
2832
+ star_mask_ref_path: str | None = None,
2833
+ low_mem: bool = False,
2834
+ ):
2835
+ """
2836
+ Streaming multi-frame deconvolution with optional SR (r>1).
2837
+ Optimized GPU path: AMP for convs, channels-last, pinned-memory prefetch, optional FFT for large kernels.
2838
+ """
2839
+ mixed_precision = False  # AMP is forced off here: the whole pipeline stays in FP32
2840
+ DEBUG_FLAT_WEIGHTS = False
2841
+ # ---------- local helpers (kept self-contained) ----------
2842
+ def _emit_pct(pct: float, msg: str | None = None):
2843
+ pct = float(max(0.0, min(1.0, pct)))
2844
+ status_cb(f"__PROGRESS__ {pct:.4f}" + (f" {msg}" if msg else ""))
2845
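+ # Lines emitted by _emit_pct look like "__PROGRESS__ 0.0500 preparing":
+ # the literal tag, a [0,1]-clamped fraction, then an optional message.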
+
2846
+ def _pad_kernel_to(k: np.ndarray, K: int) -> np.ndarray:
2847
+ """Pad/center an odd-sized kernel to K×K (K odd)."""
2848
+ k = np.asarray(k, dtype=np.float32)
2849
+ kh, kw = int(k.shape[0]), int(k.shape[1])
2850
+ assert (kh % 2 == 1) and (kw % 2 == 1)
2851
+ if kh == K and kw == K:
2852
+ return k
2853
+ out = np.zeros((K, K), dtype=np.float32)
2854
+ y0 = (K - kh)//2; x0 = (K - kw)//2
2855
+ out[y0:y0+kh, x0:x0+kw] = k
2856
+ s = float(out.sum())
2857
+ return out if s <= 0 else (out / s).astype(np.float32, copy=False)
2858
+
2859
+ max_iters = max(1, int(iters))
2860
+ min_iters = max(1, int(min_iters))
2861
+ if min_iters > max_iters:
2862
+ min_iters = max_iters
2863
+
2864
+ n_frames = len(paths)
2865
+ status_cb(f"MFDeconv: scanning {n_frames} aligned frames (memmap)…")
2866
+ _emit_pct(0.02, "scanning")
2867
+
2868
+ # choose common intersection size without loading full pixels
2869
+ Ht, Wt = _common_hw_from_paths(paths)
2870
+ _emit_pct(0.05, "preparing")
2871
+
2872
+ # --- LOW-MEM PATCH (begin) ---
2873
+ if low_mem:
2874
+ # Cap decoded-frame LRU to keep peak RAM sane on 16 GB laptops
2875
+ try:
2876
+ _FRAME_LRU.cap = max(1, min(getattr(_FRAME_LRU, "cap", 8), 2))
2877
+ except Exception:
2878
+ pass
2879
+
2880
+ # Disable CPU→GPU prefetch to avoid double-buffering allocations
2881
+ prefetch_batches = False
2882
+
2883
+ # Relax SEP background grid & star detection canvas when requested
2884
+ if use_variance_maps:
2885
+ varmap_cfg = {**(varmap_cfg or {})}
2886
+ # fewer, larger tiles → fewer big temporaries
2887
+ varmap_cfg.setdefault("bw", 96)
2888
+ varmap_cfg.setdefault("bh", 96)
2889
+
2890
+ if use_star_masks:
2891
+ star_mask_cfg = {**(star_mask_cfg or {})}
2892
+ # shrink detection canvas to limit temp buffers inside SEP/mask draw
2893
+ star_mask_cfg["max_side"] = int(min(1024, int(star_mask_cfg.get("max_side", 2048))))
2894
+ # --- LOW-MEM PATCH (end) ---
2895
+
2896
+
2897
+ if any(os.path.splitext(p)[1].lower() == ".xisf" for p in paths):
2898
+ status_cb("MFDeconv: priming XISF cache (one-time decode per frame)…")
2899
+ for i, p in enumerate(paths, 1):
2900
+ try:
2901
+ _ = _xisf_cached_array(p) # decode once, store memmap
2902
+ except Exception as e:
2903
+ status_cb(f"XISF cache failed for {p}: {e}")
2904
+ if (i & 7) == 0 or i == len(paths):
2905
+ _process_gui_events_safely()
2906
+
2907
+ # per-frame loader & sequence view (closures capture Ht/Wt/color_mode/paths)
2908
+ def _load_frame_chw(i: int):
2909
+ return _FRAME_LRU.get(paths[i], Ht, Wt, color_mode)
2910
+
2911
+ class _FrameSeq:
2912
+ def __len__(self): return len(paths)
2913
+ def __getitem__(self, i): return _load_frame_chw(i)
2914
+ data = _FrameSeq()
2915
+
2916
+ # ---- torch detection (optional) ----
2917
+ global torch, TORCH_OK
2918
+ torch = None
2919
+ TORCH_OK = False
2920
+ cuda_ok = mps_ok = dml_ok = False
2921
+ dml_device = None
2922
+
2923
+ try:
2924
+ from setiastro.saspro.runtime_torch import import_torch
2925
+ torch = import_torch(prefer_cuda=True, status_cb=status_cb)
2926
+ TORCH_OK = True
2927
+ try: cuda_ok = hasattr(torch, "cuda") and torch.cuda.is_available()
2928
+ except Exception as e:
2929
+ import logging
2930
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2931
+ try: mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
2932
+ except Exception as e:
2933
+ import logging
2934
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2935
+ try:
2936
+ import torch_directml
2937
+ dml_device = torch_directml.device()
2938
+ _ = (torch.ones(1, device=dml_device) + 1).item()
2939
+ dml_ok = True
2940
+ # NEW: expose to _torch_device()
2941
+ globals()["dml_ok"] = True
2942
+ globals()["dml_device"] = dml_device
2943
+ except Exception:
2944
+ dml_ok = False
2945
+ globals()["dml_ok"] = False
2946
+ globals()["dml_device"] = None
2947
+
2948
+ if cuda_ok:
2949
+ status_cb(f"PyTorch CUDA available: True | device={torch.cuda.get_device_name(0)}")
2950
+ elif mps_ok:
2951
+ status_cb("PyTorch MPS (Apple) available: True")
2952
+ elif dml_ok:
2953
+ status_cb("PyTorch DirectML (Windows) available: True")
2954
+ else:
2955
+ status_cb("PyTorch present, using CPU backend.")
2956
+ status_cb(
2957
+ f"PyTorch {getattr(torch,'__version__','?')} backend: "
2958
+ + ("CUDA" if cuda_ok else "MPS" if mps_ok else "DirectML" if dml_ok else "CPU")
2959
+ )
2960
+
2961
+ try:
2962
+ # keep cuDNN autotune on
2963
+ if getattr(torch.backends, "cudnn", None) is not None:
2964
+ torch.backends.cudnn.benchmark = True
2965
+ torch.backends.cudnn.allow_tf32 = False
2966
+ except Exception:
2967
+ pass
2968
+
2969
+ try:
2970
+ # disable TF32 matmul shortcuts if present (CUDA-only; safe no-op elsewhere)
2971
+ if getattr(getattr(torch.backends, "cuda", None), "matmul", None) is not None:
2972
+ torch.backends.cuda.matmul.allow_tf32 = False
2973
+ except Exception:
2974
+ pass
2975
+
2976
+ try:
2977
+ # prefer highest FP32 precision on PT 2.x
2978
+ if hasattr(torch, "set_float32_matmul_precision"):
2979
+ torch.set_float32_matmul_precision("highest")
2980
+ except Exception:
2981
+ pass
2982
+ except Exception as e:
2983
+ TORCH_OK = False
2984
+ status_cb(f"PyTorch not available → CPU path. ({e})")
2985
+
2986
+ if force_cpu:
2987
+ status_cb("⚠️ CPU-only debug mode: disabling PyTorch path.")
2988
+ TORCH_OK = False
2989
+
2990
+ _process_gui_events_safely()
2991
+
2992
+ # ---- PSFs + optional assets (computed in parallel, streaming I/O) ----
2993
+ psfs, masks_auto, vars_auto = _build_psf_and_assets(
2994
+ paths,
2995
+ make_masks=bool(use_star_masks),
2996
+ make_varmaps=bool(use_variance_maps),
2997
+ status_cb=status_cb,
2998
+ save_dir=None,
2999
+ star_mask_cfg=star_mask_cfg,
3000
+ varmap_cfg=varmap_cfg,
3001
+ star_mask_ref_path=star_mask_ref_path,
3002
+ # NEW:
3003
+ Ht=Ht, Wt=Wt, color_mode=color_mode,
3004
+ )
3005
+
3006
+ try:
3007
+ import gc as _gc
3008
+ _gc.collect()
3009
+ except Exception:
3010
+ pass
3011
+
3012
+ psfs_native = psfs # keep a reference to the original native PSFs for fallback
3013
+
3014
+ # ---- SR lift of PSFs if needed ----
3015
+ r_req = _coerce_sr_factor(super_res_factor, default_on_bad=2)
3016
+ status_cb(f"MFDeconv: SR factor requested={super_res_factor!r} → using r={r_req}")
3017
+ r = int(r_req)
3018
+
3019
+ if r > 1:
3020
+ status_cb(f"MFDeconv: Super-resolution r={r} with σ={sr_sigma} — solving SR PSFs…")
3021
+ _process_gui_events_safely()
3022
+
3023
+ def _naive_sr_from_native(f_nat: np.ndarray, r_: int) -> np.ndarray:
3024
+ f2 = np.asarray(f_nat, np.float32)
3025
+ H, W = f2.shape[:2]
3026
+ k_sq = min(H, W)
3027
+ if H != W:
3028
+ y0 = (H - k_sq) // 2; x0 = (W - k_sq) // 2
3029
+ f2 = f2[y0:y0+k_sq, x0:x0+k_sq]
3030
+ if (f2.shape[0] % 2) == 0:
3031
+ f2 = f2[1:, 1:] # make odd
3032
+ f2 = _normalize_psf(f2)
3033
+ h0 = np.zeros((f2.shape[0]*r_, f2.shape[1]*r_), np.float32)
3034
+ h0[::r_, ::r_] = f2
3035
+ return _normalize_psf(h0)
3036
+
3037
+ sr_psfs = []
3038
+ for i, k_native in enumerate(psfs, start=1):
3039
+ try:
3040
+ status_cb(f" SR-PSF{i}: native shape={np.asarray(k_native).shape}")
3041
+ h = _solve_super_psf_from_native(
3042
+ k_native, r=r, sigma=float(sr_sigma),
3043
+ iters=int(sr_psf_opt_iters), lr=float(sr_psf_opt_lr)
3044
+ )
3045
+ except Exception as e:
3046
+ status_cb(f" SR-PSF{i} failed: {e!r} → using naïve upsample")
3047
+ h = _naive_sr_from_native(k_native, r)
3048
+
3049
+ # guarantee odd size for downstream SAME-padding math
3050
+ if (h.shape[0] % 2) == 0:
3051
+ h = h[:-1, :-1]
3052
+ sr_psfs.append(h.astype(np.float32, copy=False))
3053
+ status_cb(f" SR-PSF{i}: native {np.asarray(k_native).shape[0]} → {h.shape[0]} (sum={h.sum():.6f})")
3054
+ psfs = sr_psfs
3055
+
3056
+
3057
+
3058
+ # ---- Seed (streaming) with robust bootstrap already in file helpers ----
3059
+ _emit_pct(0.25, "Calculating Seed Image...")
3060
+ def _seed_progress(frac, msg):
3061
+ _emit_pct(0.25 + 0.15 * float(frac), f"seed: {msg}")
3062
+
3063
+ seed_mode_s = str(seed_mode).lower().strip()
3064
+ if seed_mode_s not in ("robust", "median"):
3065
+ seed_mode_s = "robust"
3066
+ if seed_mode_s == "median":
3067
+ status_cb("MFDeconv: Building median seed (tiled, streaming)…")
3068
+ seed_native = _seed_median_streaming(
3069
+ paths, Ht, Wt,
3070
+ color_mode=color_mode,
3071
+ tile_hw=(256, 256),
3072
+ status_cb=status_cb,
3073
+ progress_cb=_seed_progress,
3074
+ use_torch=TORCH_OK, # ← auto: GPU if available, else NumPy
3075
+ )
3076
+ else:
3077
+ seed_native = _seed_bootstrap_streaming(
3078
+ paths, Ht, Wt, color_mode,
3079
+ bootstrap_frames=20, clip_sigma=5,
3080
+ status_cb=status_cb, progress_cb=_seed_progress
3081
+ )
3082
+
3083
+ # lift seed if SR
3084
+ if r > 1:
3085
+ target_hw = (Ht * r, Wt * r)
3086
+ if seed_native.ndim == 2:
3087
+ x = _upsample_sum(seed_native / (r*r), r, target_hw=target_hw)
3088
+ else:
3089
+ C, Hn, Wn = seed_native.shape
3090
+ x = np.stack(
3091
+ [_upsample_sum(seed_native[c] / (r*r), r, target_hw=target_hw) for c in range(C)],
3092
+ axis=0
3093
+ )
3094
+ else:
3095
+ x = seed_native
3096
+
3097
+ # FINAL SHAPE CHECKS (auto-correct if a GUI sent something odd)
3098
+ if x.ndim == 2: x = x[None, ...]
3099
+ Hs, Ws = x.shape[-2], x.shape[-1]
3100
+ if r > 1:
3101
+ expected_H, expected_W = Ht * r, Wt * r
3102
+ if (Hs, Ws) != (expected_H, expected_W):
3103
+ status_cb(f"SR seed grid mismatch: got {(Hs, Ws)}, expected {(expected_H, expected_W)} → correcting")
3104
+ # Rebuild from the native mean to ensure exact SR size
3105
+ x = _upsample_sum(x if x.ndim==2 else x[0], r, target_hw=(expected_H, expected_W))
3106
+ if x.ndim == 2: x = x[None, ...]
3107
+ try: del seed_native
3108
+ except Exception as e:
3109
+ import logging
3110
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3111
+
3112
+ try:
3113
+ import gc as _gc
3114
+ _gc.collect()
3115
+ except Exception:
3116
+ pass
3117
+
3118
+ flip_psf = [_flip_kernel(k) for k in psfs]
3119
+ _emit_pct(0.20, "PSF Ready")
3120
+
3121
+ # --- Harmonize seed channels with the actual frames ---
3122
+ # Probe the first frame's channel count (CHW from the cache)
3123
+ try:
3124
+ y0_probe = _load_frame_chw(0) # CHW float32
3125
+ C_ref = 1 if y0_probe.ndim == 2 else int(y0_probe.shape[0])
3126
+ except Exception:
3127
+ # Fallback: infer from seed if probing fails
3128
+ C_ref = 1 if x.ndim == 2 else int(x.shape[0])
3129
+
3130
+ # If median seed came back mono but frames are RGB, broadcast it
3131
+ if x.ndim == 2 and C_ref == 3:
3132
+ x = np.stack([x] * 3, axis=0)
3133
+ elif x.ndim == 3 and x.shape[0] == 1 and C_ref == 3:
3134
+ x = np.repeat(x, 3, axis=0)
3135
+ # If seed is RGB but frames are mono (rare), collapse to luma
3136
+ elif x.ndim == 3 and x.shape[0] == 3 and C_ref == 1:
3137
+ # ITU-R BT.709 luma (safe in float)
3138
+ x = (0.2126 * x[0] + 0.7152 * x[1] + 0.0722 * x[2]).astype(np.float32)
3139
+
3140
+ # Ensure CHW shape for the rest of the pipeline
3141
+ if x.ndim == 2:
3142
+ x = x[None, ...]
3143
+ elif x.ndim == 3 and x.shape[0] not in (1, 3) and x.shape[-1] in (1, 3):
3144
+ x = np.moveaxis(x, -1, 0)
3145
+
3146
+ # Set the expected channel count from the FRAMES, not from the seed
3147
+ C_EXPECTED = int(C_ref)
3148
+ _, Hs, Ws = x.shape
3149
+
3150
+
3151
+
3152
+ # ---- choose default batch size ----
3153
+ if batch_frames is None:
3154
+ px = Hs * Ws
3155
+ if px >= 16_000_000: auto_B = 2
3156
+ elif px >= 8_000_000: auto_B = 4
3157
+ else: auto_B = 8
3158
+ else:
3159
+ auto_B = int(max(1, batch_frames))
3160
+
3161
+ # --- LOW-MEM PATCH: clamp batch size hard ---
3162
+ if low_mem:
3163
+ auto_B = max(1, min(auto_B, 2))
3164
+
3165
+ # ---- background/MAD telemetry (first frame) ----
3166
+ status_cb("MFDeconv: Calculating Backgrounds and MADs…")
3167
+ _process_gui_events_safely()
3168
+ try:
3169
+ y0 = data[0]; y0l = y0 if y0.ndim == 2 else y0[0]
3170
+ med = float(np.median(y0l)); mad = float(np.median(np.abs(y0l - med))) + 1e-6
3171
+ bg_est = 1.4826 * mad
3172
+ except Exception:
3173
+ bg_est = 0.0
3174
+ status_cb(f"MFDeconv: color_mode={color_mode}, huber_delta={huber_delta} (bg RMS~{bg_est:.3g})")
3175
+
3176
+ # ---- mask/variance accessors ----
3177
+ def _mask_for(i, like_img):
3178
+ src = (masks if masks is not None else masks_auto)
3179
+ if src is None: # no masks at all
3180
+ return np.ones((like_img.shape[-2], like_img.shape[-1]), dtype=np.float32)
3181
+ m = src[i]
3182
+ if m is None:
3183
+ return np.ones((like_img.shape[-2], like_img.shape[-1]), dtype=np.float32)
3184
+ m = np.asarray(m, dtype=np.float32)
3185
+ if m.ndim == 3: m = m[0]
3186
+ return _center_crop(m, like_img.shape[-2], like_img.shape[-1]).astype(np.float32, copy=False)
3187
+
3188
+ def _var_for(i, like_img):
3189
+ src = (variances if variances is not None else vars_auto)
3190
+ if src is None: return None
3191
+ v = src[i]
3192
+ if v is None: return None
3193
+ v = np.asarray(v, dtype=np.float32)
3194
+ if v.ndim == 3: v = v[0]
3195
+ v = _center_crop(v, like_img.shape[-2], like_img.shape[-1])
3196
+ return np.clip(v, 1e-8, None).astype(np.float32, copy=False)
3197
+
3198
+ # ---- NumPy path conv helper (keep as-is) ----
3199
+ def _conv_np_same(a, k, out=None):
3200
+ y = _conv_same_np_spatial(a, k, out)
3201
+ if y is not None:
3202
+ return y
3203
+ # No OpenCV → always use the ifftshifted FFT path
3204
+ import numpy as _np
3205
+ import numpy.fft as _fft
3206
+ H, W = a.shape[-2:]
3207
+ kh, kw = k.shape
3208
+ fftH, fftW = _fftshape_same(H, W, kh, kw)
3209
+ Kf = _fft.rfftn(_np.fft.ifftshift(k), s=(fftH, fftW))
3210
+ if out is None:
3211
+ out = _np.empty_like(a, dtype=_np.float32)
3212
+ return _fft_conv_same_np(a, Kf, kh, kw, fftH, fftW, out)
3213
+
3214
+ # ---- allocate scratch & prepare PSF tensors if torch ----
3215
+ relax = 0.7
3216
+ use_torch = bool(TORCH_OK)
3217
+ cm = _safe_inference_context() if use_torch else NO_GRAD
3218
+ rho_is_l2 = (str(rho).lower() == "l2")
3219
+ local_delta = 0.0 if rho_is_l2 else huber_delta
3220
+
3221
+ if use_torch:
3222
+ F = torch.nn.functional
3223
+ device = _torch_device()
3224
+
3225
+ # Force FP32 tensors
3226
+ x_t = _to_t(_contig(x)).to(torch.float32)
3227
+ num = torch.zeros_like(x_t, dtype=torch.float32)
3228
+ den = torch.zeros_like(x_t, dtype=torch.float32)
3229
+
3230
+ # channels-last preference kept (does not change dtype)
3231
+ if use_channels_last is None:
3232
+ use_channels_last = bool(cuda_ok) # <- never enable on MPS
3233
+ if mps_ok:
3234
+ use_channels_last = False # <- force NCHW on MPS
3235
+
3236
+ # PSF tensors strictly FP32
3237
+ psf_t = [_to_t(_contig(k))[None, None] for k in psfs] # (1,1,kh,kw)
3238
+ psfT_t = [_to_t(_contig(kT))[None, None] for kT in flip_psf]
3239
+ use_spatial = _torch_should_use_spatial(psf_t[0].shape[-1])
3240
+
3241
+ # No mixed precision, no autocast
3242
+ use_amp = False
3243
+ amp_cm = None
3244
+ amp_kwargs = {}
3245
+
3246
+ # FFT gate (worth it for large kernels / SR); still FP32
3247
+ # Decide FFT vs spatial FIRST, based on native kernel sizes (no global padding!)
3248
+ K_native_max = max(int(k.shape[0]) for k in psfs)
3249
+ use_fft = False
3250
+ if device.type == "cuda" and ((K_native_max >= int(fft_kernel_threshold)) or (r > 1 and K_native_max >= max(21, int(fft_kernel_threshold) - 4))):
3251
+ use_fft = True
3252
+ # Force NCHW for FFT branch to keep channel slices contiguous/sane
3253
+ use_channels_last = False
3254
+ # Precompute FFT packs…
3255
+ psf_fft, psfT_fft = _precompute_torch_psf_ffts(
3256
+ psfs, flip_psf, Hs, Ws, device=x_t.device, dtype=torch.float32
3257
+ )
3258
+ else:
3259
+ psf_fft = psfT_fft = None
3260
+ Kmax = K_native_max
3261
+ if (Kmax % 2) == 0:
3262
+ Kmax += 1
3263
+ if any(int(k.shape[0]) != Kmax for k in psfs):
3264
+ status_cb(f"MFDeconv: normalizing PSF sizes → {Kmax}×{Kmax}")
3265
+ psfs = [_pad_kernel_to(k, Kmax) for k in psfs]
3266
+ flip_psf = [_flip_kernel(k) for k in psfs]
3267
+
3268
+ # (Re)build spatial kernels strictly as contiguous FP32 tensors
3269
+ psf_t = [_to_t(_contig(k))[None, None].to(torch.float32).contiguous() for k in psfs]
3270
+ psfT_t = [_to_t(_contig(kT))[None, None].to(torch.float32).contiguous() for kT in flip_psf]
3271
+ else:
3272
+ x_t = _contig(x).astype(np.float32, copy=False)
3273
+ num = np.zeros_like(x_t, dtype=np.float32)
3274
+ den = np.zeros_like(x_t, dtype=np.float32)
3275
+ use_amp = False
3276
+ use_fft = False
3277
+
3278
+ # ---- batched torch helper (grouped depthwise per-sample) ----
3279
+ if use_torch:
3280
+ # Per-sample grouped "same" convolution: safe channel-by-channel loop on MPS, fast grouped conv2d elsewhere.
3281
+ def _grouped_conv_same_torch_per_sample(x_bc_hw, w_b1kk, B, C):
3282
+ """
3283
+ x_bc_hw : (B,C,H,W), torch.float32 on device
3284
+ w_b1kk : (B,1,kh,kw), torch.float32 on device
3285
+ Returns (B,C,H,W) contiguous (NCHW).
3286
+ """
3287
+ F = torch.nn.functional
3288
+
3289
+ # Force standard NCHW contiguous tensors
3290
+ x_bc_hw = x_bc_hw.to(memory_format=torch.contiguous_format).contiguous()
3291
+ w_b1kk = w_b1kk.to(memory_format=torch.contiguous_format).contiguous()
3292
+
3293
+ kh, kw = int(w_b1kk.shape[-2]), int(w_b1kk.shape[-1])
3294
+ pad = (kw // 2, kw - kw // 2 - 1, kh // 2, kh - kh // 2 - 1)
3295
+
3296
+ if x_bc_hw.device.type == "mps":
3297
+ # Safe, slower path: convolve each channel separately, no groups
3298
+ ys = []
3299
+ for j in range(B): # per sample
3300
+ xj = x_bc_hw[j:j+1] # (1,C,H,W)
3301
+ # reflect pad once per sample
3302
+ xj = F.pad(xj, pad, mode="reflect")
3303
+ cj_out = []
3304
+ # one shared kernel per sample j: (1,1,kh,kw)
3305
+ kj = w_b1kk[j:j+1] # keep shape (1,1,kh,kw)
3306
+ for c in range(C):
3307
+ # slice that channel as its own (1,1,H,W) tensor
3308
+ xjc = xj[:, c:c+1, ...]
3309
+ yjc = F.conv2d(xjc, kj, padding=0, groups=1) # no groups
3310
+ cj_out.append(yjc)
3311
+ ys.append(torch.cat(cj_out, dim=1)) # (1,C,H,W)
3312
+ return torch.stack([y[0] for y in ys], 0).contiguous()
3313
+
3314
+
3315
+ # ---- FAST PATH (CUDA/CPU): single grouped conv with G=B*C ----
3316
+ G = int(B * C)
3317
+ x_1ghw = x_bc_hw.reshape(1, G, x_bc_hw.shape[-2], x_bc_hw.shape[-1])
3318
+ x_1ghw = F.pad(x_1ghw, pad, mode="reflect")
3319
+ w_g1kk = w_b1kk.repeat_interleave(C, dim=0) # (G,1,kh,kw)
3320
+ y_1ghw = F.conv2d(x_1ghw, w_g1kk, padding=0, groups=G)
3321
+ return y_1ghw.reshape(B, C, y_1ghw.shape[-2], y_1ghw.shape[-1]).contiguous()
3322
+
3323
+
3324
+ def _downsample_avg_bt_t(x, r_):
3325
+ if r_ <= 1:
3326
+ return x
3327
+ B, C, H, W = x.shape
3328
+ Hr, Wr = (H // r_) * r_, (W // r_) * r_
3329
+ if Hr == 0 or Wr == 0:
3330
+ return x
3331
+ return x[:, :, :Hr, :Wr].reshape(B, C, Hr // r_, r_, Wr // r_, r_).mean(dim=(3, 5))
3332
+
3333
+ def _upsample_sum_bt_t(x, r_):
3334
+ if r_ <= 1:
3335
+ return x
3336
+ return x.repeat_interleave(r_, dim=-2).repeat_interleave(r_, dim=-1)
3337
+
3338
+ def _make_pinned_batch(idx, C_expected, to_device_dtype):
3339
+ y_list, m_list, v_list = [], [], []
3340
+ for fi in idx:
3341
+ y_chw = _load_frame_chw(fi) # CHW float32 from cache
3342
+ if y_chw.ndim != 3:
3343
+ raise RuntimeError(f"Frame {fi}: expected CHW, got {tuple(y_chw.shape)}")
3344
+ C_here = int(y_chw.shape[0])
3345
+ if C_expected is not None and C_here != C_expected:
3346
+ raise RuntimeError(f"Mixed channel counts: expected C={C_expected}, got C={C_here} (frame {fi})")
3347
+
3348
+ m2d = _mask_for(fi, y_chw)
3349
+ v2d = _var_for(fi, y_chw)
3350
+ y_list.append(y_chw); m_list.append(m2d); v_list.append(v2d)
3351
+
3352
+ # CPU (NCHW) tensors
3353
+ y_cpu = torch.from_numpy(np.stack(y_list, 0)).to(torch.float32).contiguous()
3354
+ m_cpu = torch.from_numpy(np.stack(m_list, 0)).to(torch.float32).contiguous()
3355
+ have_v = all(v is not None for v in v_list)
3356
+ vb_cpu = None if not have_v else torch.from_numpy(np.stack(v_list, 0)).to(torch.float32).contiguous()
3357
+
3358
+ # optional pin for faster H2D on CUDA
3359
+ if hasattr(torch, "cuda") and torch.cuda.is_available():
3360
+ y_cpu = y_cpu.pin_memory(); m_cpu = m_cpu.pin_memory()
3361
+ if vb_cpu is not None: vb_cpu = vb_cpu.pin_memory()
3362
+
3363
+ # move to device in FP32, keep NCHW contiguous format
3364
+ yb = y_cpu.to(x_t.device, dtype=torch.float32, non_blocking=True).contiguous()
3365
+ mb = m_cpu.to(x_t.device, dtype=torch.float32, non_blocking=True).contiguous()
3366
+ vb = None if vb_cpu is None else vb_cpu.to(x_t.device, dtype=torch.float32, non_blocking=True).contiguous()
3367
+
3368
+ return yb, mb, vb
3369
+
3370
+
3371
+
3372
+ # ---- intermediates folder ----
3373
+ iter_dir = None
3374
+ hdr0_seed = None
3375
+ if save_intermediate:
3376
+ iter_dir = _iter_folder(out_path)
3377
+ status_cb(f"MFDeconv: Intermediate outputs → {iter_dir}")
3378
+ try:
3379
+ hdr0 = _safe_primary_header(paths[0])
3380
+ except Exception:
3381
+ hdr0 = fits.Header()
3382
+ _save_iter_image(x, hdr0, iter_dir, "seed", color_mode)
3383
+
3384
+ # ---- iterative loop ----
3385
+
3386
+
3387
+
3388
+ auto_delta_cache = None
3389
+ if use_torch and (huber_delta < 0) and (not rho_is_l2):
3390
+ auto_delta_cache = [None] * n_frames
3391
+ early = EarlyStopper(
3392
+ tol_upd_floor=2e-4,
3393
+ tol_rel_floor=5e-4,
3394
+ early_frac=0.40,
3395
+ ema_alpha=0.5,
3396
+ patience=2,
3397
+ min_iters=min_iters,
3398
+ status_cb=status_cb
3399
+ )
3400
+
3401
+ used_iters = 0
3402
+ early_stopped = False
3403
+ with cm():
3404
+ for it in range(1, max_iters + 1):
3405
+ # reset accumulators
3406
+ if use_torch:
3407
+ num.zero_(); den.zero_()
3408
+ else:
3409
+ num.fill(0.0); den.fill(0.0)
3410
+
3411
+ if use_torch:
3412
+ # ---- batched GPU path (hardened with busy→CPU fallback) ----
3413
+ retry_cpu = False # if True, we’ll redo this iteration’s accumulators on CPU
3414
+
3415
+ try:
3416
+ frame_idx = list(range(n_frames))
3417
+ B_cur = int(max(1, auto_B))
3418
+ ci = 0
3419
+ Cn = None
3420
+
3421
+ # AMP context helper
3422
+ def _maybe_amp():
3423
+ if use_amp and (amp_cm is not None):
3424
+ return amp_cm(**amp_kwargs)
3425
+ from contextlib import nullcontext
3426
+ return nullcontext()
3427
+
3428
+ # prefetch the first batch
3429
+ if prefetch_batches:
3430
+ idx0 = frame_idx[ci:ci+B_cur]
3431
+ if idx0:
3432
+ Cn = C_EXPECTED
3433
+ yb_next, mb_next, vb_next = _make_pinned_batch(idx0, Cn, x_t.dtype)
3434
+ else:
3435
+ yb_next = mb_next = vb_next = None
3436
+
3437
+ while ci < n_frames:
3438
+ idx = frame_idx[ci:ci + B_cur]
3439
+ B = len(idx)
3440
+ try:
3441
+ # ---- gather / convolve / weight / backproject for this batch ----
3442
+ if prefetch_batches and (yb_next is not None):
3443
+ yb, mb, vb = yb_next, mb_next, vb_next
3444
+ ci2 = ci + B_cur
3445
+ if ci2 < n_frames:
3446
+ idx2 = frame_idx[ci2:ci2+B_cur]
3447
+ Cn = Cn or (_as_chw(_sanitize_numeric(data[idx[0]])).shape[0])
3448
+ yb_next, mb_next, vb_next = _make_pinned_batch(idx2, Cn, torch.float32)
3449
+ else:
3450
+ yb_next = mb_next = vb_next = None
3451
+ else:
3452
+ Cn = C_EXPECTED
3453
+ yb, mb, vb = _make_pinned_batch(idx, Cn, torch.float32)
3454
+ if use_channels_last:
3455
+ yb = yb.contiguous(memory_format=torch.channels_last)
3456
+
3457
+ wk = torch.cat([psf_t[fi] for fi in idx], dim=0).to(memory_format=torch.contiguous_format).contiguous()
3458
+ wkT = torch.cat([psfT_t[fi] for fi in idx], dim=0).to(memory_format=torch.contiguous_format).contiguous()
3459
+
3460
+ # --- predict on SR grid ---
3461
+ x_bc_hw = x_t.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
3462
+ if use_channels_last:
3463
+ x_bc_hw = x_bc_hw.contiguous(memory_format=torch.channels_last)
3464
+
3465
+ if use_fft:
3466
+ pred_list = []
3467
+ chan_tmp = torch.empty_like(x_t[0], dtype=torch.float32) # (H,W) single-channel
3468
+ for j, fi in enumerate(idx):
3469
+ C_here = x_bc_hw.shape[1]
3470
+ pred_j = torch.empty_like(x_t, dtype=torch.float32) # (C,H,W)
3471
+ Kf_pack = psf_fft[fi]
3472
+ for c in range(C_here):
3473
+ _fft_conv_same_torch(x_bc_hw[j, c], Kf_pack, out_spatial=chan_tmp) # write into chan_tmp
3474
+ pred_j[c].copy_(chan_tmp) # copy once
3475
+ pred_list.append(pred_j)
3476
+ pred_super = torch.stack(pred_list, 0).contiguous()
3477
+ else:
3478
+ pred_super = _grouped_conv_same_torch_per_sample(
3479
+ x_bc_hw, wk, B, Cn
3480
+ ).contiguous()
3481
+
3482
+ pred_low = _downsample_avg_bt_t(pred_super, r) if r > 1 else pred_super
3483
+
3484
+ # --- robust weights ---
3485
+ rnat = yb - pred_low
3486
+
3487
+ # Build/estimate variance map on the native grid (per batch)
3488
+ if vb is None:
3489
+ # robust per-batch variance estimate
3490
+ med, mad = _robust_med_mad_t(rnat, max_elems_per_sample=2_000_000)
3491
+ vmap = (1.4826 * mad) ** 2
3492
+ # repeat across channels
3493
+ vmap = vmap.repeat(1, Cn, rnat.shape[-2], rnat.shape[-1]).contiguous()
3494
+ else:
3495
+ vmap = vb.unsqueeze(1).repeat(1, Cn, 1, 1).contiguous()
3496
+
3497
+ if str(rho).lower() == "l2":
3498
+ # L2 ⇒ psi/r = 1 (no robust clipping). Weighted LS with mask/variance.
3499
+ psi_over_r = torch.ones_like(rnat, dtype=torch.float32, device=x_t.device)
3500
+ else:
3501
+ # Huber ⇒ auto-delta or fixed delta
3502
+ if huber_delta < 0:
3503
+ # ensure we have auto deltas
3504
+ if (auto_delta_cache is None) or any(auto_delta_cache[fi] is None for fi in idx) or (it % 5 == 1):
3505
+ Btmp, C_here, H0, W0 = rnat.shape
3506
+ med, mad = _robust_med_mad_t(rnat, max_elems_per_sample=2_000_000)
3507
+ rms = 1.4826 * mad
3508
+ # store per-frame deltas
3509
+ if auto_delta_cache is None:
3510
+ # should not happen because we gated creation earlier, but be safe
3511
+ auto_delta_cache = [None] * n_frames
3512
+ for j, fi in enumerate(idx):
3513
+ auto_delta_cache[fi] = float((-huber_delta) * torch.clamp(rms[j, 0, 0, 0], min=1e-6).item())
3514
+ deltas = torch.tensor([auto_delta_cache[fi] for fi in idx],
3515
+ device=x_t.device, dtype=torch.float32).view(B, 1, 1, 1)
3516
+ else:
3517
+ deltas = torch.tensor(float(huber_delta), device=x_t.device,
3518
+ dtype=torch.float32).view(1, 1, 1, 1)
3519
+
3520
+ absr = rnat.abs()
3521
+ psi_over_r = torch.where(absr <= deltas, torch.ones_like(absr, dtype=torch.float32),
3522
+ deltas / (absr + EPS))
3523
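+ # psi_over_r is the Huber influence function divided by the residual:
+ # 1 in the quadratic region (|r| <= delta), delta/|r| in the linear tail,
+ # so outliers are down-weighted rather than fit at full weight.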
+
3524
+ # compose weights with mask and variance
3525
+ if DEBUG_FLAT_WEIGHTS:
3526
+ wmap_low = mb.unsqueeze(1) # debug: mask-only weighting
3527
+ else:
3528
+ m1 = mb.unsqueeze(1).repeat(1, Cn, 1, 1).contiguous()
3529
+ wmap_low = (psi_over_r / (vmap + EPS)) * m1
3530
+ wmap_low = torch.nan_to_num(wmap_low, nan=0.0, posinf=0.0, neginf=0.0)
3531
+
3532
+ # --- adjoint + backproject ---
3533
+ if r > 1:
3534
+ up_y = _upsample_sum_bt_t(wmap_low * yb, r)
3535
+ up_pred = _upsample_sum_bt_t(wmap_low * pred_low, r)
3536
+ else:
3537
+ up_y, up_pred = wmap_low * yb, wmap_low * pred_low
3538
+
3539
+ if use_fft:
3540
+ back_num_list, back_den_list = [], []
3541
+ for j, fi in enumerate(idx):
3542
+ C_here = up_y.shape[1]
3543
+ bn_j = torch.empty_like(x_t, dtype=torch.float32)
3544
+ bd_j = torch.empty_like(x_t, dtype=torch.float32)
3545
+ KTf_pack = psfT_fft[fi]
3546
+ for c in range(C_here):
3547
+ _fft_conv_same_torch(up_y[j, c], KTf_pack, out_spatial=bn_j[c])
3548
+ _fft_conv_same_torch(up_pred[j, c], KTf_pack, out_spatial=bd_j[c])
3549
+ back_num_list.append(bn_j)
3550
+ back_den_list.append(bd_j)
3551
+ back_num = torch.stack(back_num_list, 0).sum(dim=0)
3552
+ back_den = torch.stack(back_den_list, 0).sum(dim=0)
3553
+ else:
3554
+ back_num = _grouped_conv_same_torch_per_sample(up_y, wkT, B, Cn).sum(dim=0)
3555
+ back_den = _grouped_conv_same_torch_per_sample(up_pred, wkT, B, Cn).sum(dim=0)
3556
+
3557
+ num += back_num
3558
+ den += back_den
3559
+ ci += B
3560
+
3561
+ if os.environ.get("MFDECONV_DEBUG_SYNC", "0") == "1":
3562
+ try:
3563
+ torch.cuda.synchronize()
3564
+ except Exception:
3565
+ pass
3566
+
3567
+ except RuntimeError as e:
3568
+ emsg = str(e).lower()
3569
+ # OOM backoff: halve the batch size and retry this chunk
3570
+ if ("out of memory" in emsg or "resource" in emsg or "alloc" in emsg) and B_cur > 1:
3571
+ B_cur = max(1, B_cur // 2)
3572
+ status_cb(f"GPU OOM: reducing batch_frames → {B_cur} and retrying this chunk.")
3573
+ if prefetch_batches:
3574
+ yb_next = mb_next = vb_next = None
3575
+ continue
3576
+ # NEW: busy/unavailable → trigger CPU retry for this iteration
3577
+ if _is_cuda_busy_error(e):
3578
+ status_cb("CUDA became unavailable mid-run → retrying this iteration on CPU and switching to CPU for the rest.")
3579
+ retry_cpu = True
3580
+ break # break the inner while; we'll recompute on CPU
3581
+ raise
3582
+
3583
+ _process_gui_events_safely()
3584
+
3585
+ except RuntimeError as e:
3586
+ # Catch outer-level CUDA busy/unavailable too (e.g., from syncs)
3587
+ if _is_cuda_busy_error(e):
3588
+ status_cb("CUDA error indicates device busy/unavailable → retrying this iteration on CPU and switching to CPU for the rest.")
3589
+ retry_cpu = True
3590
+ else:
3591
+ raise
3592
+
3593
+ if retry_cpu:
3594
+ # flip to CPU for this and subsequent iterations
3595
+ globals()["cuda_usable"] = False
3596
+ use_torch = False
3597
+ try:
3598
+ torch.cuda.empty_cache()
3599
+ except Exception:
3600
+ pass
3601
+
3602
+ try:
3603
+ import gc as _gc
3604
+ _gc.collect()
3605
+ except Exception:
3606
+ pass
3607
+
3608
+ # Rebuild accumulators on CPU for this iteration:
3609
+ # 1) convert x_t (seed/current estimate) to NumPy
3610
+ x_t = x_t.detach().cpu().numpy()
3611
+ # 2) reset accumulators as NumPy arrays
3612
+ num = np.zeros_like(x_t, dtype=np.float32)
3613
+ den = np.zeros_like(x_t, dtype=np.float32)
3614
+
3615
+ # 3) recompute this iteration's accumulators on CPU (plain NumPy path)
3617
+ for fi in range(n_frames):
3618
+ # load one frame as CHW (float32, sanitized)
3619
+
3620
+ y_chw = _load_frame_chw(fi) # CHW float32 from cache
3621
+
3622
+ # forward predict (super grid) with this frame's PSF
3623
+ pred_super = _conv_np_same(x_t, psfs[fi])
3624
+
3625
+ # downsample to native if SR was on
3626
+ pred_low = _downsample_avg(pred_super, r) if r > 1 else pred_super
3627
+
3628
+ # per-frame mask/variance
3629
+ m2d = _mask_for(fi, y_chw) # 2D, [0..1]
3630
+ v2d = _var_for(fi, y_chw) # 2D or None
3631
+
3632
+ # robust weights per pixel/channel
3633
+ w = _weight_map(
3634
+ y=y_chw, pred=pred_low,
3635
+ huber_delta=local_delta, # 0.0 for L2, else huber_delta
3636
+ var_map=v2d, mask=m2d
3637
+ ).astype(np.float32, copy=False)
3638
+
3639
+ # adjoint (backproject) on super grid
3640
+ if r > 1:
3641
+ up_y = _upsample_sum(w * y_chw, r, target_hw=x_t.shape[-2:])
3642
+ up_pred = _upsample_sum(w * pred_low, r, target_hw=x_t.shape[-2:])
3643
+ else:
3644
+ up_y, up_pred = (w * y_chw), (w * pred_low)
3645
+
3646
+ num += _conv_np_same(up_y, flip_psf[fi])
3647
+ den += _conv_np_same(up_pred, flip_psf[fi])
3648
+
3649
+ # ensure strictly positive denominator
3650
+ den = np.clip(den, 1e-8, None).astype(np.float32, copy=False)
3651
+
3652
+ # switch everything to NumPy for the remainder of the run
3653
+ psf_fft = psfT_fft = None # (Torch packs no longer used)
3654
+
3655
+ # ---- multiplicative RL/MM step with clamping ----
3656
+ if use_torch:
3657
+ ratio = num / (den + EPS)
3658
+ neutral = (den.abs() < 1e-12) & (num.abs() < 1e-12)
3659
+ ratio = torch.where(neutral, torch.ones_like(ratio), ratio)
3660
+ upd = torch.clamp(ratio, 1.0 / kappa, kappa)
3661
+ x_next = torch.clamp(x_t * upd, min=0.0)
3662
+ # Robust scalars
3663
+ upd_med = torch.median(torch.abs(upd - 1))
3664
+ rel_change = (torch.median(torch.abs(x_next - x_t)) /
3665
+ (torch.median(torch.abs(x_t)) + 1e-8))
3666
+
3667
+ um = float(upd_med.detach().item())
3668
+ rc = float(rel_change.detach().item())
3669
+
3670
+ if early.step(it, max_iters, um, rc):
3671
+ x_t = x_next
3672
+ used_iters = it
3673
+ early_stopped = True
3674
+ status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
3675
+ _process_gui_events_safely()
3676
+ break
3677
+
3678
+ x_t = (1.0 - relax) * x_t + relax * x_next
3679
+ else:
3680
+ ratio = num / (den + EPS)
3681
+ neutral = (np.abs(den) < 1e-12) & (np.abs(num) < 1e-12)
3682
+ if np.any(neutral): ratio[neutral] = 1.0
3683
+ upd = np.clip(ratio, 1.0 / kappa, kappa)
3684
+ x_next = np.clip(x_t * upd, 0.0, None)
3685
+
3686
+ um = float(np.median(np.abs(upd - 1.0)))
3687
+ rc = float(np.median(np.abs(x_next - x_t)) / (np.median(np.abs(x_t)) + 1e-8))
3688
+
3689
+ if early.step(it, max_iters, um, rc):
3690
+ x_t = x_next
3691
+ used_iters = it
3692
+ early_stopped = True
3693
+ status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
3694
+ _process_gui_events_safely()
3695
+ break
3696
+
3697
+ x_t = (1.0 - relax) * x_t + relax * x_next
3698
+
3699
+ # --- LOW-MEM CLEANUP (per-iteration) ---
3700
+ if low_mem:
3701
+ # Torch temporaries created this iteration (best-effort: in CPython, deleting from locals() may not actually drop the frame's references)
3702
+ to_kill = [
3703
+ "pred_super", "pred_low", "wmap_low", "yb", "mb", "vb", "wk", "wkT",
3704
+ "back_num", "back_den", "pred_list", "back_num_list", "back_den_list",
3705
+ "x_bc_hw", "up_y", "up_pred", "psi_over_r", "vmap", "rnat", "deltas", "chan_tmp"
3706
+ ]
3707
+ loc = locals()
3708
+ for _name in to_kill:
3709
+ if _name in loc:
3710
+ try:
3711
+ del loc[_name]
3712
+ except Exception:
3713
+ pass
3714
+
3715
+ # Proactively release CUDA cache every other iter
3716
+ if use_torch:
3717
+ try:
3718
+ dev = _torch_device()
3719
+ if dev.type == "cuda" and (it % 2) == 0:
3720
+ import torch as _t
3721
+ _t.cuda.empty_cache()
3722
+ except Exception:
3723
+ pass
3724
+
3725
+ # Encourage Python to return big NumPy buffers to the OS sooner
3726
+ import gc as _gc
3727
+ _gc.collect()
3728
+
3729
+
3730
+ # ---- save intermediates ----
3731
+ if save_intermediate and (it % int(max(1, save_every)) == 0):
3732
+ try:
3733
+ x_np = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
3734
+ _save_iter_image(x_np, hdr0_seed, iter_dir, f"iter_{it:03d}", color_mode)
3735
+ except Exception as _e:
3736
+ status_cb(f"Intermediate save failed at iter {it}: {_e}")
3737
+
3738
+ frac = 0.25 + 0.70 * (it / float(max_iters))
3739
+ _emit_pct(frac, f"Iteration {it}/{max_iters}")
3740
+ status_cb(f"Iter {it}/{max_iters}")
3741
+ _process_gui_events_safely()
3742
+
3743
+ if not early_stopped:
3744
+ used_iters = max_iters
3745
+
3746
+ # ---- save result ----
3747
+ _emit_pct(0.97, "saving")
3748
+ x_final = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
3749
+ if x_final.ndim == 3:
3750
+ if x_final.shape[0] not in (1, 3) and x_final.shape[-1] in (1, 3):
3751
+ x_final = np.moveaxis(x_final, -1, 0)
3752
+ if x_final.shape[0] == 1:
3753
+ x_final = x_final[0]
3754
+
3755
+ try:
3756
+ hdr0 = _safe_primary_header(paths[0])
3757
+ except Exception:
3758
+ hdr0 = fits.Header()
3759
+
3760
+ hdr0['MFDECONV'] = (True, 'Seti Astro multi-frame deconvolution')
3761
+ hdr0['MF_COLOR'] = (str(color_mode), 'Color mode used')
3762
+ hdr0['MF_RHO'] = (str(rho), 'Loss: huber|l2')
3763
+ hdr0['MF_HDEL'] = (float(huber_delta), 'Huber delta (>0 abs, <0 autoxRMS)')
3764
+ hdr0['MF_MASK'] = (bool(use_star_masks), 'Used auto star masks')
3765
+ hdr0['MF_VAR'] = (bool(use_variance_maps), 'Used auto variance maps')
3766
+ # NOTE: r already holds the SR factor coerced by _coerce_sr_factor() above;
+ # re-parsing super_res_factor here would fail on string inputs like "2x".
3767
+ hdr0['MF_SR'] = (int(r), 'Super-resolution factor (1 := native)')
3768
+ if r > 1:
3769
+ hdr0['MF_SRSIG'] = (float(sr_sigma), 'Gaussian sigma for SR PSF fit (native px)')
3770
+ hdr0['MF_SRIT'] = (int(sr_psf_opt_iters), 'SR-PSF solver iters')
3771
+ hdr0['MF_ITMAX'] = (int(max_iters), 'Requested max iterations')
3772
+ hdr0['MF_ITERS'] = (int(used_iters), 'Actual iterations run')
3773
+ hdr0['MF_ESTOP'] = (bool(early_stopped), 'Early stop triggered')
3774
+
3775
+ if isinstance(x_final, np.ndarray):
3776
+ if x_final.ndim == 2:
3777
+ hdr0['MF_SHAPE'] = (f"{x_final.shape[0]}x{x_final.shape[1]}", 'Saved as 2D image (HxW)')
3778
+ elif x_final.ndim == 3:
3779
+ C, H, W = x_final.shape
3780
+ hdr0['MF_SHAPE'] = (f"{C}x{H}x{W}", 'Saved as 3D cube (CxHxW)')
3781
+
3782
+ save_path = _sr_out_path(out_path, super_res_factor)
3783
+ safe_out_path = _nonclobber_path(str(save_path))
3784
+ if safe_out_path != str(save_path):
3785
+ status_cb(f"Output exists — saving as: {safe_out_path}")
3786
+ fits.PrimaryHDU(data=x_final, header=hdr0).writeto(safe_out_path, overwrite=False)
3787
+
3788
+ status_cb(f"✅ MFDeconv saved: {safe_out_path} (iters used: {used_iters}{', early stop' if early_stopped else ''})")
3789
+ _emit_pct(1.00, "done")
3790
+ _process_gui_events_safely()
3791
+
3792
+ try:
3793
+ if use_torch:
3794
+ try: del num, den
3795
+ except Exception as e:
3796
+ import logging
3797
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3798
+ try: del psf_t, psfT_t
3799
+ except Exception as e:
3800
+ import logging
3801
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3802
+ _free_torch_memory()
3803
+ except Exception:
3804
+ pass
3805
+
3806
+ try:
3807
+ _clear_all_caches()
3808
+ except Exception:
3809
+ pass
3810
+
3811
+ return safe_out_path
3812
+
3813
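
For orientation, a single-frame NumPy sketch of the clamped multiplicative update that the iteration loop above accumulates over all frames (no super-resolution; reflect padding is replaced by an FFT "same" convolution). The helper name and the use of scipy here are illustrative, not part of the package.

import numpy as np
from scipy.signal import fftconvolve

def mm_step(x, y, psf, w, kappa=2.0, eps=1e-12):
    """One clamped multiplicative update for a single frame (sketch only)."""
    psf_flip = psf[::-1, ::-1]                           # adjoint (flipped) kernel
    pred = fftconvolve(x, psf, mode="same")              # forward model H x
    num = fftconvolve(w * y, psf_flip, mode="same")      # H^T (w y)
    den = fftconvolve(w * pred, psf_flip, mode="same")   # H^T (w H x)
    ratio = num / (den + eps)
    upd = np.clip(ratio, 1.0 / kappa, kappa)             # clamp the step, as in the loop above
    return np.clip(x * upd, 0.0, None)                   # keep the estimate non-negative
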
+ # -----------------------------
3814
+ # Worker
3815
+ # -----------------------------
3816
+
3817
+ class MultiFrameDeconvWorker(QObject):
3818
+ progress = pyqtSignal(str)
3819
+ finished = pyqtSignal(bool, str, str) # success, message, out_path
3820
+
3821
+ def __init__(self, parent, aligned_paths, output_path, iters, kappa, color_mode,
3822
+ huber_delta, min_iters, use_star_masks=False, use_variance_maps=False, rho="huber",
3823
+ star_mask_cfg: dict | None = None, varmap_cfg: dict | None = None,
3824
+ save_intermediate: bool = False,
3825
+ seed_mode: str = "robust",
3826
+ # NEW SR params
3827
+ super_res_factor: int = 1,
3828
+ sr_sigma: float = 1.1,
3829
+ sr_psf_opt_iters: int = 250,
3830
+ sr_psf_opt_lr: float = 0.1,
3831
+ star_mask_ref_path: str | None = None):
3832
+
3833
+ super().__init__(parent)
3834
+ self.aligned_paths = aligned_paths
3835
+ self.output_path = output_path
3836
+ self.iters = iters
3837
+ self.kappa = kappa
3838
+ self.color_mode = color_mode
3839
+ self.huber_delta = huber_delta
3840
+ self.min_iters = min_iters # NEW
3841
+ self.star_mask_cfg = star_mask_cfg or {}
3842
+ self.varmap_cfg = varmap_cfg or {}
3843
+ self.use_star_masks = use_star_masks
3844
+ self.use_variance_maps = use_variance_maps
3845
+ self.rho = rho
3846
+ self.save_intermediate = save_intermediate
3847
+ self.super_res_factor = int(super_res_factor)
3848
+ self.sr_sigma = float(sr_sigma)
3849
+ self.sr_psf_opt_iters = int(sr_psf_opt_iters)
3850
+ self.sr_psf_opt_lr = float(sr_psf_opt_lr)
3851
+ self.star_mask_ref_path = star_mask_ref_path
3852
+ self.seed_mode = seed_mode
3853
+
3854
+
3855
+ def _log(self, s): self.progress.emit(s)
3856
+
3857
+ def run(self):
3858
+ try:
3859
+ out = multiframe_deconv(
3860
+ self.aligned_paths,
3861
+ self.output_path,
3862
+ iters=self.iters,
3863
+ kappa=self.kappa,
3864
+ color_mode=self.color_mode,
3865
+ seed_mode=self.seed_mode,
3866
+ huber_delta=self.huber_delta,
3867
+ use_star_masks=self.use_star_masks,
3868
+ use_variance_maps=self.use_variance_maps,
3869
+ rho=self.rho,
3870
+ min_iters=self.min_iters,
3871
+ status_cb=self._log,
3872
+ star_mask_cfg=self.star_mask_cfg,
3873
+ varmap_cfg=self.varmap_cfg,
3874
+ save_intermediate=self.save_intermediate,
3875
+ super_res_factor=self.super_res_factor,
3876
+ sr_sigma=self.sr_sigma,
3877
+ sr_psf_opt_iters=self.sr_psf_opt_iters,
3878
+ sr_psf_opt_lr=self.sr_psf_opt_lr,
3879
+ star_mask_ref_path=self.star_mask_ref_path,
3880
+ )
3881
+ self.finished.emit(True, "MF deconvolution complete.", out)
3882
+ _process_gui_events_safely()
3883
+ except Exception as e:
3884
+ self.finished.emit(False, f"MF deconvolution failed: {e}", "")
3885
+ finally:
3886
+ # Hard cleanup: drop references + free GPU memory
3887
+ try:
3888
+ # Drop big Python refs that might keep tensors alive indirectly
3889
+ self.aligned_paths = []
3890
+ self.star_mask_cfg = {}
3891
+ self.varmap_cfg = {}
3892
+ except Exception:
3893
+ pass
3894
+ try:
3895
+ _free_torch_memory() # your helper: del tensors, gc.collect(), etc.
3896
+ except Exception:
3897
+ pass
3898
+ try:
3899
+ import torch as _t
3900
+ if hasattr(_t, "cuda") and _t.cuda.is_available():
3901
+ _t.cuda.synchronize()
3902
+ _t.cuda.empty_cache()
3903
+ if hasattr(_t, "mps") and getattr(_t.backends, "mps", None) and _t.backends.mps.is_available():
3904
+ # PyTorch 2.x has this
3905
+ if hasattr(_t.mps, "empty_cache"):
3906
+ _t.mps.empty_cache()
3907
+ # DirectML usually frees on GC; nothing special to call.
3908
+ except Exception:
3909
+ pass
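
A hypothetical wiring sketch for the worker above, assuming a PyQt6 GUI thread; the binding, function name, and parameter values are placeholders, not part of the package.

from PyQt6.QtCore import QThread  # assumption: PyQt6 binding

def start_mfdeconv(parent, aligned_paths, output_path):
    thread = QThread(parent)
    worker = MultiFrameDeconvWorker(
        None, aligned_paths, output_path,
        iters=20, kappa=2.0, color_mode="luma",
        huber_delta=-2.0, min_iters=3,          # placeholder values
    )
    worker.moveToThread(thread)
    thread.started.connect(worker.run)
    worker.progress.connect(print)              # route status lines to a log pane in practice
    worker.finished.connect(lambda ok, msg, out: print(ok, msg, out))
    worker.finished.connect(thread.quit)
    thread.finished.connect(thread.deleteLater)
    thread.start()
    return thread, worker                       # caller keeps references alive
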