setiastrosuitepro 1.6.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of setiastrosuitepro might be problematic. Click here for more details.

Files changed (394)
  1. setiastro/__init__.py +2 -0
  2. setiastro/data/SASP_data.fits +0 -0
  3. setiastro/data/catalogs/List_of_Galaxies_with_Distances_Gly.csv +488 -0
  4. setiastro/data/catalogs/astrobin_filters.csv +890 -0
  5. setiastro/data/catalogs/astrobin_filters_page1_local.csv +51 -0
  6. setiastro/data/catalogs/cali2.csv +63 -0
  7. setiastro/data/catalogs/cali2color.csv +65 -0
  8. setiastro/data/catalogs/celestial_catalog - original.csv +16471 -0
  9. setiastro/data/catalogs/celestial_catalog.csv +24031 -0
  10. setiastro/data/catalogs/detected_stars.csv +24784 -0
  11. setiastro/data/catalogs/fits_header_data.csv +46 -0
  12. setiastro/data/catalogs/test.csv +8 -0
  13. setiastro/data/catalogs/updated_celestial_catalog.csv +16471 -0
  14. setiastro/images/Astro_Spikes.png +0 -0
  15. setiastro/images/Background_startup.jpg +0 -0
  16. setiastro/images/HRDiagram.png +0 -0
  17. setiastro/images/LExtract.png +0 -0
  18. setiastro/images/LInsert.png +0 -0
  19. setiastro/images/Oxygenation-atm-2.svg.png +0 -0
  20. setiastro/images/RGB080604.png +0 -0
  21. setiastro/images/abeicon.png +0 -0
  22. setiastro/images/aberration.png +0 -0
  23. setiastro/images/acv_icon.png +0 -0
  24. setiastro/images/andromedatry.png +0 -0
  25. setiastro/images/andromedatry_satellited.png +0 -0
  26. setiastro/images/annotated.png +0 -0
  27. setiastro/images/aperture.png +0 -0
  28. setiastro/images/astrosuite.ico +0 -0
  29. setiastro/images/astrosuite.png +0 -0
  30. setiastro/images/astrosuitepro.icns +0 -0
  31. setiastro/images/astrosuitepro.ico +0 -0
  32. setiastro/images/astrosuitepro.png +0 -0
  33. setiastro/images/background.png +0 -0
  34. setiastro/images/background2.png +0 -0
  35. setiastro/images/benchmark.png +0 -0
  36. setiastro/images/big_moon_stabilizer_timeline.png +0 -0
  37. setiastro/images/big_moon_stabilizer_timeline_clean.png +0 -0
  38. setiastro/images/blaster.png +0 -0
  39. setiastro/images/blink.png +0 -0
  40. setiastro/images/clahe.png +0 -0
  41. setiastro/images/collage.png +0 -0
  42. setiastro/images/colorwheel.png +0 -0
  43. setiastro/images/contsub.png +0 -0
  44. setiastro/images/convo.png +0 -0
  45. setiastro/images/copyslot.png +0 -0
  46. setiastro/images/cosmic.png +0 -0
  47. setiastro/images/cosmicsat.png +0 -0
  48. setiastro/images/crop1.png +0 -0
  49. setiastro/images/cropicon.png +0 -0
  50. setiastro/images/curves.png +0 -0
  51. setiastro/images/cvs.png +0 -0
  52. setiastro/images/debayer.png +0 -0
  53. setiastro/images/denoise_cnn_custom.png +0 -0
  54. setiastro/images/denoise_cnn_graph.png +0 -0
  55. setiastro/images/disk.png +0 -0
  56. setiastro/images/dse.png +0 -0
  57. setiastro/images/exoicon.png +0 -0
  58. setiastro/images/eye.png +0 -0
  59. setiastro/images/first_quarter.png +0 -0
  60. setiastro/images/fliphorizontal.png +0 -0
  61. setiastro/images/flipvertical.png +0 -0
  62. setiastro/images/font.png +0 -0
  63. setiastro/images/freqsep.png +0 -0
  64. setiastro/images/full_moon.png +0 -0
  65. setiastro/images/functionbundle.png +0 -0
  66. setiastro/images/graxpert.png +0 -0
  67. setiastro/images/green.png +0 -0
  68. setiastro/images/gridicon.png +0 -0
  69. setiastro/images/halo.png +0 -0
  70. setiastro/images/hdr.png +0 -0
  71. setiastro/images/histogram.png +0 -0
  72. setiastro/images/hubble.png +0 -0
  73. setiastro/images/imagecombine.png +0 -0
  74. setiastro/images/invert.png +0 -0
  75. setiastro/images/isophote.png +0 -0
  76. setiastro/images/isophote_demo_figure.png +0 -0
  77. setiastro/images/isophote_demo_image.png +0 -0
  78. setiastro/images/isophote_demo_model.png +0 -0
  79. setiastro/images/isophote_demo_residual.png +0 -0
  80. setiastro/images/jwstpupil.png +0 -0
  81. setiastro/images/last_quarter.png +0 -0
  82. setiastro/images/linearfit.png +0 -0
  83. setiastro/images/livestacking.png +0 -0
  84. setiastro/images/mask.png +0 -0
  85. setiastro/images/maskapply.png +0 -0
  86. setiastro/images/maskcreate.png +0 -0
  87. setiastro/images/maskremove.png +0 -0
  88. setiastro/images/morpho.png +0 -0
  89. setiastro/images/mosaic.png +0 -0
  90. setiastro/images/multiscale_decomp.png +0 -0
  91. setiastro/images/nbtorgb.png +0 -0
  92. setiastro/images/neutral.png +0 -0
  93. setiastro/images/new_moon.png +0 -0
  94. setiastro/images/nuke.png +0 -0
  95. setiastro/images/openfile.png +0 -0
  96. setiastro/images/pedestal.png +0 -0
  97. setiastro/images/pen.png +0 -0
  98. setiastro/images/pixelmath.png +0 -0
  99. setiastro/images/platesolve.png +0 -0
  100. setiastro/images/ppp.png +0 -0
  101. setiastro/images/pro.png +0 -0
  102. setiastro/images/project.png +0 -0
  103. setiastro/images/psf.png +0 -0
  104. setiastro/images/redo.png +0 -0
  105. setiastro/images/redoicon.png +0 -0
  106. setiastro/images/rescale.png +0 -0
  107. setiastro/images/rgbalign.png +0 -0
  108. setiastro/images/rgbcombo.png +0 -0
  109. setiastro/images/rgbextract.png +0 -0
  110. setiastro/images/rotate180.png +0 -0
  111. setiastro/images/rotatearbitrary.png +0 -0
  112. setiastro/images/rotateclockwise.png +0 -0
  113. setiastro/images/rotatecounterclockwise.png +0 -0
  114. setiastro/images/satellite.png +0 -0
  115. setiastro/images/script.png +0 -0
  116. setiastro/images/selectivecolor.png +0 -0
  117. setiastro/images/simbad.png +0 -0
  118. setiastro/images/slot0.png +0 -0
  119. setiastro/images/slot1.png +0 -0
  120. setiastro/images/slot2.png +0 -0
  121. setiastro/images/slot3.png +0 -0
  122. setiastro/images/slot4.png +0 -0
  123. setiastro/images/slot5.png +0 -0
  124. setiastro/images/slot6.png +0 -0
  125. setiastro/images/slot7.png +0 -0
  126. setiastro/images/slot8.png +0 -0
  127. setiastro/images/slot9.png +0 -0
  128. setiastro/images/spcc.png +0 -0
  129. setiastro/images/spin_precession_vs_lunar_distance.png +0 -0
  130. setiastro/images/spinner.gif +0 -0
  131. setiastro/images/stacking.png +0 -0
  132. setiastro/images/staradd.png +0 -0
  133. setiastro/images/staralign.png +0 -0
  134. setiastro/images/starnet.png +0 -0
  135. setiastro/images/starregistration.png +0 -0
  136. setiastro/images/starspike.png +0 -0
  137. setiastro/images/starstretch.png +0 -0
  138. setiastro/images/statstretch.png +0 -0
  139. setiastro/images/supernova.png +0 -0
  140. setiastro/images/uhs.png +0 -0
  141. setiastro/images/undoicon.png +0 -0
  142. setiastro/images/upscale.png +0 -0
  143. setiastro/images/viewbundle.png +0 -0
  144. setiastro/images/waning_crescent_1.png +0 -0
  145. setiastro/images/waning_crescent_2.png +0 -0
  146. setiastro/images/waning_crescent_3.png +0 -0
  147. setiastro/images/waning_crescent_4.png +0 -0
  148. setiastro/images/waning_crescent_5.png +0 -0
  149. setiastro/images/waning_gibbous_1.png +0 -0
  150. setiastro/images/waning_gibbous_2.png +0 -0
  151. setiastro/images/waning_gibbous_3.png +0 -0
  152. setiastro/images/waning_gibbous_4.png +0 -0
  153. setiastro/images/waning_gibbous_5.png +0 -0
  154. setiastro/images/waxing_crescent_1.png +0 -0
  155. setiastro/images/waxing_crescent_2.png +0 -0
  156. setiastro/images/waxing_crescent_3.png +0 -0
  157. setiastro/images/waxing_crescent_4.png +0 -0
  158. setiastro/images/waxing_crescent_5.png +0 -0
  159. setiastro/images/waxing_gibbous_1.png +0 -0
  160. setiastro/images/waxing_gibbous_2.png +0 -0
  161. setiastro/images/waxing_gibbous_3.png +0 -0
  162. setiastro/images/waxing_gibbous_4.png +0 -0
  163. setiastro/images/waxing_gibbous_5.png +0 -0
  164. setiastro/images/whitebalance.png +0 -0
  165. setiastro/images/wimi_icon_256x256.png +0 -0
  166. setiastro/images/wimilogo.png +0 -0
  167. setiastro/images/wims.png +0 -0
  168. setiastro/images/wrench_icon.png +0 -0
  169. setiastro/images/xisfliberator.png +0 -0
  170. setiastro/qml/ResourceMonitor.qml +128 -0
  171. setiastro/saspro/__init__.py +20 -0
  172. setiastro/saspro/__main__.py +964 -0
  173. setiastro/saspro/_generated/__init__.py +7 -0
  174. setiastro/saspro/_generated/build_info.py +3 -0
  175. setiastro/saspro/abe.py +1379 -0
  176. setiastro/saspro/abe_preset.py +196 -0
  177. setiastro/saspro/aberration_ai.py +910 -0
  178. setiastro/saspro/aberration_ai_preset.py +224 -0
  179. setiastro/saspro/accel_installer.py +218 -0
  180. setiastro/saspro/accel_workers.py +30 -0
  181. setiastro/saspro/acv_exporter.py +379 -0
  182. setiastro/saspro/add_stars.py +627 -0
  183. setiastro/saspro/astrobin_exporter.py +1010 -0
  184. setiastro/saspro/astrospike.py +153 -0
  185. setiastro/saspro/astrospike_python.py +1841 -0
  186. setiastro/saspro/autostretch.py +198 -0
  187. setiastro/saspro/backgroundneutral.py +639 -0
  188. setiastro/saspro/batch_convert.py +328 -0
  189. setiastro/saspro/batch_renamer.py +522 -0
  190. setiastro/saspro/blemish_blaster.py +494 -0
  191. setiastro/saspro/blink_comparator_pro.py +3149 -0
  192. setiastro/saspro/bundles.py +61 -0
  193. setiastro/saspro/bundles_dock.py +114 -0
  194. setiastro/saspro/cheat_sheet.py +213 -0
  195. setiastro/saspro/clahe.py +371 -0
  196. setiastro/saspro/comet_stacking.py +1442 -0
  197. setiastro/saspro/common_tr.py +107 -0
  198. setiastro/saspro/config.py +38 -0
  199. setiastro/saspro/config_bootstrap.py +40 -0
  200. setiastro/saspro/config_manager.py +316 -0
  201. setiastro/saspro/continuum_subtract.py +1620 -0
  202. setiastro/saspro/convo.py +1403 -0
  203. setiastro/saspro/convo_preset.py +414 -0
  204. setiastro/saspro/copyastro.py +190 -0
  205. setiastro/saspro/cosmicclarity.py +1593 -0
  206. setiastro/saspro/cosmicclarity_preset.py +407 -0
  207. setiastro/saspro/crop_dialog_pro.py +1005 -0
  208. setiastro/saspro/crop_preset.py +189 -0
  209. setiastro/saspro/curve_editor_pro.py +2608 -0
  210. setiastro/saspro/curves_preset.py +375 -0
  211. setiastro/saspro/debayer.py +673 -0
  212. setiastro/saspro/debug_utils.py +29 -0
  213. setiastro/saspro/dnd_mime.py +35 -0
  214. setiastro/saspro/doc_manager.py +2727 -0
  215. setiastro/saspro/exoplanet_detector.py +2258 -0
  216. setiastro/saspro/file_utils.py +284 -0
  217. setiastro/saspro/fitsmodifier.py +748 -0
  218. setiastro/saspro/fix_bom.py +32 -0
  219. setiastro/saspro/free_torch_memory.py +48 -0
  220. setiastro/saspro/frequency_separation.py +1352 -0
  221. setiastro/saspro/function_bundle.py +1596 -0
  222. setiastro/saspro/generate_translations.py +3092 -0
  223. setiastro/saspro/ghs_dialog_pro.py +728 -0
  224. setiastro/saspro/ghs_preset.py +284 -0
  225. setiastro/saspro/graxpert.py +638 -0
  226. setiastro/saspro/graxpert_preset.py +287 -0
  227. setiastro/saspro/gui/__init__.py +0 -0
  228. setiastro/saspro/gui/main_window.py +8928 -0
  229. setiastro/saspro/gui/mixins/__init__.py +33 -0
  230. setiastro/saspro/gui/mixins/dock_mixin.py +375 -0
  231. setiastro/saspro/gui/mixins/file_mixin.py +450 -0
  232. setiastro/saspro/gui/mixins/geometry_mixin.py +503 -0
  233. setiastro/saspro/gui/mixins/header_mixin.py +441 -0
  234. setiastro/saspro/gui/mixins/mask_mixin.py +421 -0
  235. setiastro/saspro/gui/mixins/menu_mixin.py +391 -0
  236. setiastro/saspro/gui/mixins/theme_mixin.py +367 -0
  237. setiastro/saspro/gui/mixins/toolbar_mixin.py +1824 -0
  238. setiastro/saspro/gui/mixins/update_mixin.py +323 -0
  239. setiastro/saspro/gui/mixins/view_mixin.py +477 -0
  240. setiastro/saspro/gui/statistics_dialog.py +47 -0
  241. setiastro/saspro/halobgon.py +492 -0
  242. setiastro/saspro/header_viewer.py +448 -0
  243. setiastro/saspro/headless_utils.py +88 -0
  244. setiastro/saspro/histogram.py +760 -0
  245. setiastro/saspro/history_explorer.py +941 -0
  246. setiastro/saspro/i18n.py +168 -0
  247. setiastro/saspro/image_combine.py +421 -0
  248. setiastro/saspro/image_peeker_pro.py +1608 -0
  249. setiastro/saspro/imageops/__init__.py +37 -0
  250. setiastro/saspro/imageops/mdi_snap.py +292 -0
  251. setiastro/saspro/imageops/scnr.py +36 -0
  252. setiastro/saspro/imageops/starbasedwhitebalance.py +210 -0
  253. setiastro/saspro/imageops/stretch.py +236 -0
  254. setiastro/saspro/isophote.py +1186 -0
  255. setiastro/saspro/layers.py +208 -0
  256. setiastro/saspro/layers_dock.py +714 -0
  257. setiastro/saspro/lazy_imports.py +193 -0
  258. setiastro/saspro/legacy/__init__.py +2 -0
  259. setiastro/saspro/legacy/image_manager.py +2360 -0
  260. setiastro/saspro/legacy/numba_utils.py +3676 -0
  261. setiastro/saspro/legacy/xisf.py +1213 -0
  262. setiastro/saspro/linear_fit.py +537 -0
  263. setiastro/saspro/live_stacking.py +1854 -0
  264. setiastro/saspro/log_bus.py +5 -0
  265. setiastro/saspro/logging_config.py +460 -0
  266. setiastro/saspro/luminancerecombine.py +510 -0
  267. setiastro/saspro/main_helpers.py +201 -0
  268. setiastro/saspro/mask_creation.py +1090 -0
  269. setiastro/saspro/masks_core.py +56 -0
  270. setiastro/saspro/mdi_widgets.py +353 -0
  271. setiastro/saspro/memory_utils.py +666 -0
  272. setiastro/saspro/metadata_patcher.py +75 -0
  273. setiastro/saspro/mfdeconv.py +3909 -0
  274. setiastro/saspro/mfdeconv_earlystop.py +71 -0
  275. setiastro/saspro/mfdeconvcudnn.py +3312 -0
  276. setiastro/saspro/mfdeconvsport.py +2459 -0
  277. setiastro/saspro/minorbodycatalog.py +567 -0
  278. setiastro/saspro/morphology.py +411 -0
  279. setiastro/saspro/multiscale_decomp.py +1751 -0
  280. setiastro/saspro/nbtorgb_stars.py +541 -0
  281. setiastro/saspro/numba_utils.py +3145 -0
  282. setiastro/saspro/numba_warmup.py +141 -0
  283. setiastro/saspro/ops/__init__.py +9 -0
  284. setiastro/saspro/ops/command_help_dialog.py +623 -0
  285. setiastro/saspro/ops/command_runner.py +217 -0
  286. setiastro/saspro/ops/commands.py +1594 -0
  287. setiastro/saspro/ops/script_editor.py +1105 -0
  288. setiastro/saspro/ops/scripts.py +1476 -0
  289. setiastro/saspro/ops/settings.py +637 -0
  290. setiastro/saspro/parallel_utils.py +554 -0
  291. setiastro/saspro/pedestal.py +121 -0
  292. setiastro/saspro/perfect_palette_picker.py +1105 -0
  293. setiastro/saspro/pipeline.py +110 -0
  294. setiastro/saspro/pixelmath.py +1604 -0
  295. setiastro/saspro/plate_solver.py +2480 -0
  296. setiastro/saspro/project_io.py +797 -0
  297. setiastro/saspro/psf_utils.py +136 -0
  298. setiastro/saspro/psf_viewer.py +631 -0
  299. setiastro/saspro/pyi_rthook_astroquery.py +95 -0
  300. setiastro/saspro/remove_green.py +331 -0
  301. setiastro/saspro/remove_stars.py +1599 -0
  302. setiastro/saspro/remove_stars_preset.py +446 -0
  303. setiastro/saspro/resources.py +570 -0
  304. setiastro/saspro/rgb_combination.py +208 -0
  305. setiastro/saspro/rgb_extract.py +19 -0
  306. setiastro/saspro/rgbalign.py +727 -0
  307. setiastro/saspro/runtime_imports.py +7 -0
  308. setiastro/saspro/runtime_torch.py +754 -0
  309. setiastro/saspro/save_options.py +73 -0
  310. setiastro/saspro/selective_color.py +1614 -0
  311. setiastro/saspro/sfcc.py +1530 -0
  312. setiastro/saspro/shortcuts.py +3125 -0
  313. setiastro/saspro/signature_insert.py +1106 -0
  314. setiastro/saspro/stacking_suite.py +19069 -0
  315. setiastro/saspro/star_alignment.py +7383 -0
  316. setiastro/saspro/star_alignment_preset.py +329 -0
  317. setiastro/saspro/star_metrics.py +49 -0
  318. setiastro/saspro/star_spikes.py +769 -0
  319. setiastro/saspro/star_stretch.py +542 -0
  320. setiastro/saspro/stat_stretch.py +554 -0
  321. setiastro/saspro/status_log_dock.py +78 -0
  322. setiastro/saspro/subwindow.py +3523 -0
  323. setiastro/saspro/supernovaasteroidhunter.py +1719 -0
  324. setiastro/saspro/swap_manager.py +134 -0
  325. setiastro/saspro/torch_backend.py +89 -0
  326. setiastro/saspro/torch_rejection.py +434 -0
  327. setiastro/saspro/translations/all_source_strings.json +4726 -0
  328. setiastro/saspro/translations/ar_translations.py +4096 -0
  329. setiastro/saspro/translations/de_translations.py +3728 -0
  330. setiastro/saspro/translations/es_translations.py +4169 -0
  331. setiastro/saspro/translations/fr_translations.py +4090 -0
  332. setiastro/saspro/translations/hi_translations.py +3803 -0
  333. setiastro/saspro/translations/integrate_translations.py +271 -0
  334. setiastro/saspro/translations/it_translations.py +4728 -0
  335. setiastro/saspro/translations/ja_translations.py +3834 -0
  336. setiastro/saspro/translations/pt_translations.py +3847 -0
  337. setiastro/saspro/translations/ru_translations.py +3082 -0
  338. setiastro/saspro/translations/saspro_ar.qm +0 -0
  339. setiastro/saspro/translations/saspro_ar.ts +16019 -0
  340. setiastro/saspro/translations/saspro_de.qm +0 -0
  341. setiastro/saspro/translations/saspro_de.ts +14548 -0
  342. setiastro/saspro/translations/saspro_es.qm +0 -0
  343. setiastro/saspro/translations/saspro_es.ts +16202 -0
  344. setiastro/saspro/translations/saspro_fr.qm +0 -0
  345. setiastro/saspro/translations/saspro_fr.ts +15870 -0
  346. setiastro/saspro/translations/saspro_hi.qm +0 -0
  347. setiastro/saspro/translations/saspro_hi.ts +14855 -0
  348. setiastro/saspro/translations/saspro_it.qm +0 -0
  349. setiastro/saspro/translations/saspro_it.ts +19046 -0
  350. setiastro/saspro/translations/saspro_ja.qm +0 -0
  351. setiastro/saspro/translations/saspro_ja.ts +14980 -0
  352. setiastro/saspro/translations/saspro_pt.qm +0 -0
  353. setiastro/saspro/translations/saspro_pt.ts +15024 -0
  354. setiastro/saspro/translations/saspro_ru.qm +0 -0
  355. setiastro/saspro/translations/saspro_ru.ts +11835 -0
  356. setiastro/saspro/translations/saspro_sw.qm +0 -0
  357. setiastro/saspro/translations/saspro_sw.ts +15237 -0
  358. setiastro/saspro/translations/saspro_uk.qm +0 -0
  359. setiastro/saspro/translations/saspro_uk.ts +15248 -0
  360. setiastro/saspro/translations/saspro_zh.qm +0 -0
  361. setiastro/saspro/translations/saspro_zh.ts +15289 -0
  362. setiastro/saspro/translations/sw_translations.py +3897 -0
  363. setiastro/saspro/translations/uk_translations.py +3929 -0
  364. setiastro/saspro/translations/zh_translations.py +3910 -0
  365. setiastro/saspro/versioning.py +77 -0
  366. setiastro/saspro/view_bundle.py +1558 -0
  367. setiastro/saspro/wavescale_hdr.py +648 -0
  368. setiastro/saspro/wavescale_hdr_preset.py +101 -0
  369. setiastro/saspro/wavescalede.py +683 -0
  370. setiastro/saspro/wavescalede_preset.py +230 -0
  371. setiastro/saspro/wcs_update.py +374 -0
  372. setiastro/saspro/whitebalance.py +540 -0
  373. setiastro/saspro/widgets/__init__.py +48 -0
  374. setiastro/saspro/widgets/common_utilities.py +306 -0
  375. setiastro/saspro/widgets/graphics_views.py +122 -0
  376. setiastro/saspro/widgets/image_utils.py +518 -0
  377. setiastro/saspro/widgets/minigame/game.js +991 -0
  378. setiastro/saspro/widgets/minigame/index.html +53 -0
  379. setiastro/saspro/widgets/minigame/style.css +241 -0
  380. setiastro/saspro/widgets/preview_dialogs.py +280 -0
  381. setiastro/saspro/widgets/resource_monitor.py +313 -0
  382. setiastro/saspro/widgets/spinboxes.py +290 -0
  383. setiastro/saspro/widgets/themed_buttons.py +13 -0
  384. setiastro/saspro/widgets/wavelet_utils.py +331 -0
  385. setiastro/saspro/wimi.py +7367 -0
  386. setiastro/saspro/wims.py +588 -0
  387. setiastro/saspro/window_shelf.py +185 -0
  388. setiastro/saspro/xisf.py +1213 -0
  389. setiastrosuitepro-1.6.7.dist-info/METADATA +279 -0
  390. setiastrosuitepro-1.6.7.dist-info/RECORD +394 -0
  391. setiastrosuitepro-1.6.7.dist-info/WHEEL +4 -0
  392. setiastrosuitepro-1.6.7.dist-info/entry_points.txt +6 -0
  393. setiastrosuitepro-1.6.7.dist-info/licenses/LICENSE +674 -0
  394. setiastrosuitepro-1.6.7.dist-info/licenses/license.txt +2580 -0
@@ -0,0 +1,3312 @@
1
+ # pro/mfdeconvsport.py
2
+ from __future__ import annotations
3
+ import os, sys
4
+ import math
5
+ import re
6
+ import numpy as np
7
+ import tempfile
8
+ import uuid
9
+ import atexit
10
+ from astropy.io import fits
11
+ from PyQt6.QtCore import QObject, pyqtSignal
12
+ from setiastro.saspro.psf_utils import compute_psf_kernel_for_image
13
+ from PyQt6.QtWidgets import QApplication
14
+ from PyQt6.QtCore import QThread
15
+ from threadpoolctl import threadpool_limits
16
+ from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
17
+ _USE_PROCESS_POOL_FOR_ASSETS = not getattr(sys, "frozen", False)
18
+ import numpy.fft as _fft
19
+ import contextlib
20
+ from setiastro.saspro.mfdeconv_earlystop import EarlyStopper
21
+
22
+ import gc
23
+ try:
24
+ import sep
25
+ except Exception:
26
+ sep = None
27
+ from setiastro.saspro.free_torch_memory import _free_torch_memory
28
+ torch = None # filled by runtime loader if available
29
+ TORCH_OK = False
30
+ NO_GRAD = contextlib.nullcontext # fallback
31
+
32
+ _SCRATCH_MMAPS = []
33
+ _XISF_READERS = []
34
+ try:
35
+ # e.g. your legacy module
36
+ from setiastro.saspro.legacy import xisf as _legacy_xisf
37
+ if hasattr(_legacy_xisf, "read"):
38
+ _XISF_READERS.append(lambda p: _legacy_xisf.read(p))
39
+ elif hasattr(_legacy_xisf, "open"):
40
+ _XISF_READERS.append(lambda p: _legacy_xisf.open(p)[0])
41
+ except Exception:
42
+ pass
43
+ try:
44
+ # sometimes projects expose a generic load_image
45
+ from setiastro.saspro.legacy.image_manager import load_image as _generic_load_image # adjust if needed
46
+ _XISF_READERS.append(lambda p: _generic_load_image(p)[0])
47
+ except Exception:
48
+ pass
49
+
50
+ # at top of file with the other imports
51
+ from concurrent.futures import ThreadPoolExecutor, as_completed
52
+ from queue import SimpleQueue
53
+ from setiastro.saspro.memory_utils import LRUDict
54
+
55
+ # ── XISF decode cache → memmap on disk ─────────────────────────────────
56
+ import tempfile
57
+ import threading
58
+ import uuid
59
+ import atexit
60
+ _XISF_CACHE = LRUDict(50)
61
+ _XISF_LOCK = threading.Lock()
62
+ _XISF_TMPFILES = []
63
+
64
+ from collections import OrderedDict
65
+
66
+ # ─────────────────────────────────────────────────────────────────────────────
67
+ # Unified image I/O for MFDeconv (FITS + XISF)
68
+ # ─────────────────────────────────────────────────────────────────────────────
69
+
70
+ from pathlib import Path
71
+
72
+
73
+ from collections import OrderedDict
74
+
75
+ _VAR_DTYPE = np.float32
76
+
77
def _mm_create(shape, dtype, scratch_dir=None, tag="scratch"):
    """
    Allocate a zero-filled, disk-backed ``np.memmap``.

    The backing file is named ``mfdeconv_<tag>_<random hex>.mmap`` and lives in
    *scratch_dir* (the system temp directory when *scratch_dir* is falsy). Its
    path is recorded in ``_SCRATCH_MMAPS`` so ``_cleanup_scratch_mm`` can
    delete it later.
    """
    target_dir = scratch_dir if scratch_dir else tempfile.gettempdir()
    path = os.path.join(target_dir, f"mfdeconv_{tag}_{uuid.uuid4().hex}.mmap")
    dims = tuple(int(n) for n in shape)

    buf = np.memmap(path, mode="w+", dtype=dtype, shape=dims)
    buf[...] = 0          # explicit zero-init, then push it to disk
    buf.flush()

    _SCRATCH_MMAPS.append(path)
    return buf
86
+
87
+ def _maybe_memmap(shape, dtype=np.float32, *,
88
+ force_mm=False, threshold_mb=512, scratch_dir=None, tag="scratch"):
89
+ """Return either np.zeros(...) or a zeroed memmap based on size/flags."""
90
+ nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
91
+ if force_mm or (nbytes >= threshold_mb * 1024 * 1024):
92
+ return _mm_create(shape, dtype, scratch_dir, tag)
93
+ return np.zeros(shape, dtype=dtype)
94
+
95
def _cleanup_scratch_mm():
    """
    Delete the scratch memmap files recorded by ``_mm_create``.

    Best effort: a file that cannot be removed (for example because it is
    still in use) is logged at DEBUG level and kept in ``_SCRATCH_MMAPS`` so a
    later call -- such as the ``atexit`` hook registered below -- can retry it.
    Files that no longer exist count as cleaned up and are forgotten.
    """
    import logging

    still_present = []
    for fn in _SCRATCH_MMAPS[:]:
        try:
            os.remove(fn)
        except FileNotFoundError:
            pass  # already gone; nothing to retry
        except Exception as e:
            logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
            still_present.append(fn)
    # Keep only the paths that still need deleting.
    _SCRATCH_MMAPS[:] = still_present


atexit.register(_cleanup_scratch_mm)
104
+
105
+ # ── CHW LRU (float32) built on top of FITS memmap & XISF memmap ────────────────
106
+ class _FrameCHWLRU:
107
+ def __init__(self, capacity=8):
108
+ self.cap = int(max(1, capacity))
109
+ self.od = OrderedDict()
110
+
111
+ def clear(self):
112
+ self.od.clear()
113
+
114
+ def get(self, path, Ht, Wt, color_mode):
115
+ key = (path, Ht, Wt, str(color_mode).lower())
116
+ hit = self.od.get(key)
117
+ if hit is not None:
118
+ self.od.move_to_end(key)
119
+ return hit
120
+
121
+ # Load backing array cheaply (memmap for FITS, cached memmap for XISF)
122
+ ext = os.path.splitext(path)[1].lower()
123
+ if ext == ".xisf":
124
+ a = _xisf_cached_array(path) # float32, HW/HWC/CHW
125
+ else:
126
+ # FITS path: use astropy memmap (no data copy)
127
+ with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
128
+ arr = None
129
+ for h in hdul:
130
+ if getattr(h, "data", None) is not None:
131
+ arr = h.data
132
+ break
133
+ if arr is None:
134
+ raise ValueError(f"No image data in {path}")
135
+ a = np.asarray(arr)
136
+ # dtype normalize once; keep float32
137
+ if a.dtype.kind in "ui":
138
+ a = a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0)
139
+ else:
140
+ a = a.astype(np.float32, copy=False)
141
+
142
+ # Center-crop to (Ht, Wt) and convert to CHW
143
+ a = np.asarray(a) # float32
144
+ a = _center_crop(a, Ht, Wt)
145
+
146
+ # Respect color_mode: “luma” → 1×H×W, “PerChannel” → 3×H×W if RGB present
147
+ cm = str(color_mode).lower()
148
+ if cm == "luma":
149
+ a_chw = _as_chw(_to_luma_local(a)).astype(np.float32, copy=False)
150
+ else:
151
+ a_chw = _as_chw(a).astype(np.float32, copy=False)
152
+ if a_chw.shape[0] == 1 and cm != "luma":
153
+ # still OK (mono data)
154
+ pass
155
+
156
+ # LRU insert
157
+ self.od[key] = a_chw
158
+ if len(self.od) > self.cap:
159
+ self.od.popitem(last=False)
160
+ return a_chw
161
+
162
+ _FRAME_LRU = _FrameCHWLRU(capacity=8) # tune if you like
163
+
164
+ def _clear_all_caches():
165
+ try: _clear_xisf_cache()
166
+ except Exception as e:
167
+ import logging
168
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
169
+ try: _FRAME_LRU.clear()
170
+ except Exception as e:
171
+ import logging
172
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
173
+
174
+ def _as_chw(np_img: np.ndarray) -> np.ndarray:
175
+ x = np.asarray(np_img, dtype=np.float32, order="C")
176
+ if x.size == 0:
177
+ raise RuntimeError(f"Empty image array after load; raw shape={np_img.shape}")
178
+ if x.ndim == 2:
179
+ return x[None, ...] # 1,H,W
180
+ if x.ndim == 3 and x.shape[0] in (1, 3):
181
+ if x.shape[0] == 0:
182
+ raise RuntimeError(f"Zero channels in CHW array; shape={x.shape}")
183
+ return x
184
+ if x.ndim == 3 and x.shape[-1] in (1, 3):
185
+ if x.shape[-1] == 0:
186
+ raise RuntimeError(f"Zero channels in HWC array; shape={x.shape}")
187
+ return np.moveaxis(x, -1, 0)
188
+ # last resort: treat first dim as channels, but reject zero
189
+ if x.shape[0] == 0:
190
+ raise RuntimeError(f"Zero channels in array; shape={x.shape}")
191
+ return x
192
+
193
+ def _normalize_to_float32(a: np.ndarray) -> np.ndarray:
194
+ if a.dtype.kind in "ui":
195
+ return (a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0))
196
+ if a.dtype == np.float32:
197
+ return a
198
+ return a.astype(np.float32, copy=False)
199
+
200
def _xisf_cached_array(path: str) -> np.memmap:
    """
    Decode an XISF image exactly once and back it by a read-only float32 memmap.
    Returns a memmap that can be sliced cheaply for tiles.

    On the first call for *path*, the image is decoded via ``_load_image_array``
    and normalized with ``_normalize_to_float32``. The pixels are then written
    to ``xisf_cache_<hex>.mmap`` in the system temp directory. Later calls
    re-open that file read-only. Lookup, decode and insert all run under
    ``_XISF_LOCK``.

    Raises:
        ValueError: if the loader returns ``None``.
    """
    with _XISF_LOCK:
        hit = _XISF_CACHE.get(path)
        if hit is not None:
            fn, shape = hit
            return np.memmap(fn, dtype=np.float32, mode="r", shape=shape)

        # Decode once
        arr, _ = _load_image_array(path)
        if arr is None:
            raise ValueError(f"XISF loader returned None for {path}")
        arrf = _normalize_to_float32(np.asarray(arr))
        shape = arrf.shape

        fn = os.path.join(tempfile.gettempdir(), f"xisf_cache_{uuid.uuid4().hex}.mmap")
        # Register the file before creating it: if anything below fails, the
        # temp file is still tracked and removed by _clear_xisf_cache / atexit.
        _XISF_TMPFILES.append(fn)
        try:
            mm = np.memmap(fn, dtype=np.float32, mode="w+", shape=shape)
            try:
                mm[...] = arrf
                mm.flush()
            finally:
                del mm  # close writer handle; re-opened read-only below
        except Exception:
            # Best effort: drop the half-written file right away.
            with contextlib.suppress(OSError):
                os.remove(fn)
            raise

        _XISF_CACHE[path] = (fn, shape)
        return np.memmap(fn, dtype=np.float32, mode="r", shape=shape)
229
+
230
def _clear_xisf_cache():
    """
    Drop all cached XISF decodes and delete their temp memmap files.

    This runs under ``_XISF_LOCK``. Removal is best effort:
      * A failed delete is logged at DEBUG level, and the file stays in
        ``_XISF_TMPFILES`` so a later call (including the ``atexit`` hook
        registered below) retries it.
      * Files that are already missing are simply forgotten.

    The path -> file mapping is always cleared, so the next access decodes
    afresh.
    """
    import logging

    with _XISF_LOCK:
        still_present = []
        for fn in _XISF_TMPFILES:
            try:
                os.remove(fn)
            except FileNotFoundError:
                pass  # already gone; nothing to retry
            except Exception as e:
                logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
                still_present.append(fn)
        _XISF_CACHE.clear()
        _XISF_TMPFILES[:] = still_present


atexit.register(_clear_xisf_cache)
241
+
242
+
243
+ def _is_xisf(path: str) -> bool:
244
+ return os.path.splitext(path)[1].lower() == ".xisf"
245
+
246
def _read_xisf_numpy(path: str) -> np.ndarray:
    """
    Read an XISF file with the first registered reader that succeeds.

    Readers in ``_XISF_READERS`` are tried in registration order. A reader may
    return either an array or a tuple whose first element is the array.

    Raises:
        RuntimeError: if no reader is registered, or if every reader fails.
            In the latter case the last reader's exception is chained as the cause.
    """
    if not _XISF_READERS:
        raise RuntimeError(
            "No XISF readers registered. Ensure one of "
            "legacy.xisf.read/open or legacy.image_manager.load_image is importable."
        )
    last_err = None
    for reader in _XISF_READERS:
        try:
            arr = reader(path)
            if isinstance(arr, tuple):
                arr = arr[0]
            return np.asarray(arr)
        except Exception as e:
            last_err = e
    # Chain the cause so the original traceback is not lost.
    raise RuntimeError(f"All XISF readers failed for {path}: {last_err}") from last_err
262
+
263
def _fits_open_data(path: str):
    """
    Open a FITS file and return ``(data, header)`` for its first image HDU.

    The primary HDU is used when it carries data. Otherwise the first later
    HDU with data is chosen, for files whose primary HDU is header-only.
    The file is opened with ``memmap=True``. ``ignore_missing_simple=True``
    lets headers that lack the SIMPLE card be read.

    Raises:
        ValueError: if no HDU in the file contains data.
    """
    with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
        hdu = hdul[0]
        if hdu.data is None:
            hdu = next(
                (h for h in hdul[1:] if getattr(h, "data", None) is not None),
                None,
            )
            if hdu is None:
                # Previously this fell through to np.asanyarray(None), a 0-d
                # object array that only failed later with a confusing error.
                raise ValueError(f"No image data in {path}")
        data = np.asanyarray(hdu.data)
        hdr = hdu.header
    return data, hdr
276
+
277
def _load_image_array(path: str) -> tuple[np.ndarray, "fits.Header | None"]:
    """
    Return (numpy array, fits.Header or None). Color-last if 3D.

    XISF files are read via ``_read_xisf_numpy`` and get a ``None`` header.
    Everything else is read as FITS via ``_fits_open_data``. A 3-D array whose
    first axis has length 1 or 3 (and whose last axis does not) is moved to
    (H, W, C). The dtype is left as-is; callers cast to float32. The returned
    array is C-contiguous and writeable.
    """
    if _is_xisf(path):
        arr, hdr = _read_xisf_numpy(path), None
    else:
        arr, hdr = _fits_open_data(path)

    a = np.asarray(arr)
    # Move color axis to last if 3D with a leading channel axis
    if a.ndim == 3 and a.shape[0] in (1, 3) and a.shape[-1] not in (1, 3):
        a = np.moveaxis(a, 0, -1)
    if not (a.flags.c_contiguous and a.flags.writeable):
        # order="C" is required: np.array's default order="K" keeps the
        # transposed stride layout produced by moveaxis, so the copy would
        # still not be C-contiguous.
        a = np.array(a, order="C", copy=True)
    return a, hdr
296
+
297
def _probe_hw(path: str) -> tuple[int, int, int | None]:
    """
    Report ``(height, width, channels_or_None)`` for the image at *path*.

    The image is loaded with ``_load_image_array``. For 3-D data a leading
    1- or 3-channel axis is moved last. The channel count is reported only
    when it is 1 or 3; otherwise it is ``None``.

    Raises:
        ValueError: for arrays that are neither 2-D nor 3-D.
    """
    data, _hdr = _load_image_array(path)
    if data.ndim == 2:
        rows, cols = data.shape
        return rows, cols, None
    if data.ndim != 3:
        raise ValueError(f"Unsupported ndim={data.ndim} for {path}")
    if data.shape[-1] not in (1, 3) and data.shape[0] in (1, 3):
        # treat mono-3D as (H,W,1): bring the leading channel axis to the end
        data = np.moveaxis(data, 0, -1)
    rows, cols, chans = data.shape
    return rows, cols, (chans if chans in (1, 3) else None)
312
+
313
def _common_hw_from_paths(paths: list[str]) -> tuple[int, int]:
    """
    Return the ``(Ht, Wt)`` size that every frame in *paths* can be cropped to.

    This is the minimum height and minimum width over all frames with a
    positive size, as reported by ``_probe_hw``.

    Raises:
        ValueError: if no frame has a positive size, or if the result is
            smaller than 8 pixels on either side.
    """
    heights, widths = [], []
    for p in paths:
        h, w, _ = _probe_hw(p)
        h, w = int(h), int(w)
        if h <= 0 or w <= 0:
            continue
        heights.append(h)
        widths.append(w)

    if not heights:
        raise ValueError("Could not determine any valid frame sizes.")
    Ht, Wt = min(heights), min(widths)
    if Ht < 8 or Wt < 8:
        raise ValueError(f"Intersection too small: {Ht}x{Wt}")
    return Ht, Wt
327
+
328
+
329
def _to_chw_float32(img: np.ndarray, color_mode: str) -> np.ndarray:
    """
    Convert an image to channel-first float32.

    * 2-D mono and colour-last (H,W,1) -> (1,H,W)
    * colour-last (H,W,3) -> (3,H,W) when *color_mode* is a per-channel mode
      ('perchannel', 'per_channel' or 'perchannelrgb', case-insensitive),
      otherwise a single luma plane (1,H,W) using 0.2126/0.7152/0.0722 weights
    * channel-first (1|3,H,W) input is moved to colour-last and re-processed

    Raises ValueError for any other shape.
    """
    x = np.asarray(img)
    if x.ndim == 2:
        return x.astype(np.float32, copy=False)[np.newaxis, ...]
    if x.ndim != 3:
        raise ValueError(f"Unsupported image shape {x.shape}")

    last = x.shape[-1]
    if last == 1:
        return x[..., 0].astype(np.float32, copy=False)[np.newaxis, ...]

    if last == 3:
        planes = [x[..., k].astype(np.float32, copy=False) for k in range(3)]
        per_channel = str(color_mode).lower() in ("perchannel", "per_channel", "perchannelrgb")
        if per_channel:
            return np.stack(planes, axis=0)
        red, green, blue = planes
        return (0.2126 * red + 0.7152 * green + 0.0722 * blue)[np.newaxis, ...]

    if x.shape[0] in (1, 3):
        # channel-first data: move channels to the end and try again
        return _to_chw_float32(np.moveaxis(x, 0, -1), color_mode)
    raise ValueError(f"Unsupported image shape {x.shape}")
358
+
359
def _center_crop_hw(img: np.ndarray, Ht: int, Wt: int) -> np.ndarray:
    """
    Centre-crop the two leading (H, W) axes of *img* to at most (Ht, Wt).

    Any trailing axes (e.g. colour) are kept. A fresh copy is returned when
    something is trimmed; otherwise *img* itself is returned. No padding is
    applied when *img* is already smaller than the target.
    """
    h, w = img.shape[:2]
    if not (Ht < h or Wt < w):
        return img
    top = max(0, (h - Ht) // 2)
    left = max(0, (w - Wt) // 2)
    return img[top:top + Ht, left:left + Wt, ...].copy()
363
+
364
def _stack_loader_memmap(paths: list[str], Ht: int, Wt: int, color_mode: str):
    """
    Drop-in replacement of the old FITS-only helper.

    Each frame is read with ``_load_image_array``, centre-cropped to (Ht, Wt)
    and converted to channel-first float32 via ``_to_chw_float32``. Integer
    data (signed or unsigned) is divided by its dtype's maximum value; any
    other dtype is only cast to float32.

    Returns ``(ys, hdrs)``:
        ys   : list of C-contiguous, writeable CHW float32 arrays
        hdrs : per frame, the loader's ``fits.Header`` or None when the loader
               returned something else (e.g. for XISF)
    """
    ys, hdrs = [], []
    for frame_path in paths:
        data, header = _load_image_array(frame_path)
        data = _center_crop_hw(data, Ht, Wt)

        if data.dtype.kind in "ui":
            full_scale = np.float32(np.iinfo(data.dtype).max)
            data = data.astype(np.float32, copy=False) / (full_scale if full_scale > 0 else 1.0)
        else:
            data = data.astype(np.float32, copy=False)

        chw = _to_chw_float32(data, color_mode)
        if not (chw.flags.c_contiguous and chw.flags.writeable):
            chw = np.ascontiguousarray(chw.astype(np.float32, copy=True))

        ys.append(chw)
        hdrs.append(header if isinstance(header, fits.Header) else None)
    return ys, hdrs
390
+
391
def _safe_primary_header(path: str) -> fits.Header:
    """
    Best-effort primary header for *path*.

    XISF files get a small synthetic FITS header (SIMPLE, BITPIX=-32,
    NAXIS=2). Other files get their HDU-0 header; if reading it fails for
    any reason, an empty header is returned instead of raising.
    """
    if not _is_xisf(path):
        try:
            return fits.getheader(path, ext=0, ignore_missing_simple=True)
        except Exception:
            return fits.Header()

    synthetic = fits.Header()
    synthetic["SIMPLE"] = (True, "created by MFDeconv")
    synthetic["BITPIX"] = -32
    synthetic["NAXIS"] = 2
    return synthetic
403
+
404
def _compute_frame_assets(i, arr, hdr, *, make_masks, make_varmaps,
                          star_mask_cfg, varmap_cfg, status_sink=lambda s: None):
    """
    Worker function: compute the PSF and an optional star mask / variance map
    for one frame.

    Parameters
    ----------
    i : frame index; echoed back and used in the summary log line.
    arr : the frame image (callers in this module pass a 2-D float32 plane).
    hdr : header-like mapping, probed for FWHM / FWHM_PIX / PSF_FWHM.
    make_masks, make_varmaps : which optional products to build.
    star_mask_cfg, varmap_cfg : optional dicts overriding the module defaults
        (THRESHOLD_SIGMA, GROW_PX, ... and bw/bh/smooth_sigma/floor).
    status_sink : accepted for compatibility but unused; messages are
        collected into the returned log list instead.

    Returns
    -------
    (index, psf, mask_or_None, var_or_None, log_lines)
    """
    logs = []

    def log(s):
        logs.append(s)

    # --- PSF sizing by FWHM ---
    # Header keywords are cheap; run the SEP-based image estimate only when
    # the header does not provide a usable value (previously it always ran
    # and its result was discarded whenever the header had an FWHM).
    f_whm = _estimate_fwhm_from_header(hdr)
    if not np.isfinite(f_whm):
        f_whm = _estimate_fwhm_from_image(arr)
    if not np.isfinite(f_whm) or f_whm <= 0:
        f_whm = 2.5
    k_auto = _auto_ksize_from_fwhm(f_whm)

    # --- Star-derived PSF with retries (dynamic det_sigma ladder) ---
    # Kernel sizes to try for each detection threshold (may contain repeats).
    k_ladder = [k_auto, max(k_auto - 4, 11), 21, 17, 15, 13, 11]
    # Start with a high detection threshold so only bright stars are used;
    # step down only if no kernel could be built.
    sigma_ladder = [50.0, 25.0, 12.0, 6.0]

    psf = None
    tried = set()
    for det_sigma in sigma_ladder:
        for k_try in k_ladder:
            if (det_sigma, k_try) in tried:
                continue
            tried.add((det_sigma, k_try))
            try:
                out = compute_psf_kernel_for_image(arr, ksize=k_try, det_sigma=det_sigma, max_stars=80)
                psf_try = out[0] if (isinstance(out, tuple) and len(out) >= 1) else out
                if psf_try is not None:
                    psf = psf_try
                    break
            except Exception:
                psf = None
        if psf is not None:
            break

    if psf is None:
        # No star-based kernel: synthesize a Gaussian of the estimated FWHM.
        psf = _gaussian_psf(f_whm, ksize=k_auto)

    psf = _soften_psf(_normalize_psf(psf.astype(np.float32, copy=False)), sigma_px=0.25)

    mask = None
    var = None

    if make_masks or make_varmaps:
        # One background model per frame, shared by the mask and the varmap.
        luma = _to_luma_local(arr)
        vmc = varmap_cfg or {}
        sky_map, rms_map, err_scalar = _sep_background_precompute(
            luma, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
        )

        if make_masks:
            smc = star_mask_cfg or {}
            mask = _star_mask_from_precomputed(
                luma, sky_map, err_scalar,
                thresh_sigma  = smc.get("thresh_sigma", THRESHOLD_SIGMA),
                max_objs      = smc.get("max_objs", STAR_MASK_MAXOBJS),
                grow_px       = smc.get("grow_px", GROW_PX),
                ellipse_scale = smc.get("ellipse_scale", ELLIPSE_SCALE),
                soft_sigma    = smc.get("soft_sigma", SOFT_SIGMA),
                max_radius_px = smc.get("max_radius_px", MAX_STAR_RADIUS),
                keep_floor    = smc.get("keep_floor", KEEP_FLOOR),
                max_side      = smc.get("max_side", STAR_MASK_MAXSIDE),
                status_cb     = log,
            )

        if make_varmaps:
            var = _variance_map_from_precomputed(
                luma, sky_map, rms_map, hdr,
                smooth_sigma = vmc.get("smooth_sigma", 1.0),
                floor        = vmc.get("floor", 1e-8),
                status_cb    = log,
            )

    # small per-frame summary, placed first in the log
    fwhm_est = _psf_fwhm_px(psf)
    logs.insert(0, f"MFDeconv: PSF{i}: ksize={psf.shape[0]} | FWHM≈{fwhm_est:.2f}px")

    return i, psf, mask, var, logs
492
+
493
+
494
def _compute_one_worker(args):
    """
    Process-safe worker wrapper: load one frame, reduce it to a 2-D plane of
    size (Ht, Wt), and compute its PSF / optional mask / optional varmap.

    Args tuple: (i, path, make_masks, make_varmaps, star_mask_cfg, varmap_cfg,
                 Ht, Wt, color_mode). Older 6-item tuples are padded with
    None; a missing Ht/Wt then defaults to the frame's own size.
    ``color_mode`` is accepted but not used here.

    Returns: (i, psf, mask, var_or_None, var_path_or_None, logs)
    """
    (i, path, make_masks, make_varmaps, star_mask_cfg, varmap_cfg, Ht, Wt, color_mode) = (
        args if len(args) == 9 else (*args, None, None, None)  # allow old callers
    )[:9]

    # header (FWHM hints); never fatal
    try:
        hdr = _safe_primary_header(path)
    except Exception:
        hdr = fits.Header()

    # read full image, then crop/pad to the centre
    arr_all = None
    ext = os.path.splitext(path)[1].lower()
    if ext != ".xisf":
        with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
            primary_data = hdul[0].data
            if primary_data is not None:
                arr_all = np.asarray(primary_data)
    if arr_all is None:
        # XISF, or a FITS file whose primary HDU holds no data (previously
        # this crashed converting None to float32): use the generic loader.
        arr_all, _ = _load_image_array(path)
        arr_all = np.asarray(arr_all)

    # to a single 2-D plane
    if arr_all.ndim == 3:
        if arr_all.shape[0] in (1, 3):
            # CHW: first plane only.
            # NOTE(review): for CHW RGB this is the first channel, not a luma
            # mix as in the HWC branch — confirm this is intended.
            arr2d = arr_all[0].astype(np.float32, copy=False)
        else:
            # HWC (or any other 3-D layout): luma / channel mean
            arr2d = _to_luma_local(arr_all).astype(np.float32, copy=False)
    else:
        arr2d = np.asarray(arr_all, dtype=np.float32)

    # centre-crop, then zero-pad, to (Ht, Wt); default to the frame size
    H, W = arr2d.shape
    Ht = H if Ht is None else int(Ht)
    Wt = W if Wt is None else int(Wt)
    y0 = max(0, (H - Ht) // 2); x0 = max(0, (W - Wt) // 2)
    y1 = min(H, y0 + Ht); x1 = min(W, x0 + Wt)
    arr2d = np.ascontiguousarray(arr2d[y0:y1, x0:x1], dtype=np.float32)
    if arr2d.shape != (Ht, Wt):
        out = np.zeros((Ht, Wt), dtype=np.float32)
        oy = (Ht - arr2d.shape[0]) // 2; ox = (Wt - arr2d.shape[1]) // 2
        out[oy:oy + arr2d.shape[0], ox:ox + arr2d.shape[1]] = arr2d
        arr2d = out

    # compute assets — _compute_frame_assets returns a 5-tuple
    # (the old 6-name unpacking raised ValueError on every call)
    i2, psf, mask, var, logs = _compute_frame_assets(
        i, arr2d, hdr,
        make_masks=bool(make_masks),
        make_varmaps=bool(make_varmaps),
        star_mask_cfg=star_mask_cfg,
        varmap_cfg=varmap_cfg,
    )

    # Option B: for a memmap-backed varmap, flush it and hand back only the
    # path of its backing file so the parent never holds the array.
    var_path = None
    if isinstance(var, np.memmap):
        backing = getattr(var, "filename", None)
        var_path = str(backing) if backing is not None else None
        try:
            var.flush()
        except Exception as e:
            import logging
            logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
        if var_path is not None:
            var = None

    return i2, psf, mask, var, var_path, logs
559
+
560
+
561
def _normalize_assets_result(res):
    """
    Normalise a worker result to ``(i, psf, mask, var_or_None,
    var_path_or_None, logs)``.

    Accepted inputs (tuple or list):
      - (i, psf, mask, var, logs)                   -> var taken verbatim
      - (i, psf, mask, var, var_path, logs)
      - (i, psf, mask, var, var_mm, var_path, logs)  # legacy, both returned

    For the longer forms, the first middle item with a ``shape`` attribute
    becomes ``var`` and the first ``str`` becomes ``var_path``.

    Raises ValueError for anything that is not a tuple/list of length >= 5.
    """
    if not isinstance(res, (tuple, list)) or len(res) < 5:
        size = len(res) if hasattr(res, '__len__') else 'na'
        raise ValueError(f"Unexpected assets result: {type(res)} len={size}")

    i, psf, mask, *middle, logs = res

    if len(res) == 5:
        # legacy 5-tuple: slot 3 is the variance map, whatever its type
        return i, psf, mask, middle[0], None, logs

    var = next((item for item in middle if hasattr(item, "shape")), None)
    var_path = next((item for item in middle if isinstance(item, str)), None)
    return i, psf, mask, var, var_path, logs
595
+
596
+
597
def _build_psf_and_assets(
    paths,                                  # list[str]
    make_masks=False,
    make_varmaps=False,
    status_cb=lambda s: None,
    save_dir: str | None = None,
    star_mask_cfg: dict | None = None,
    varmap_cfg: dict | None = None,
    max_workers: int | None = None,
    star_mask_ref_path: str | None = None,  # build one mask from this frame if provided
    # Passed from multiframe_deconv so we don't re-probe/convert:
    Ht: int | None = None,
    Wt: int | None = None,
    color_mode: str = "luma",
):
    """
    Parallel PSF + (optional) star mask + variance map per frame.

    Parameters
    ----------
    paths : frame paths; all result lists follow this order.
    make_masks, make_varmaps : which optional products to build.
    status_cb : progress/log callback; worker log lines are queued and
        forwarded to it from the calling thread.
    save_dir : if given, created and each PSF written as ``psf_###.fit``.
    star_mask_cfg, varmap_cfg : optional overrides for the mask / background /
        variance helpers.
    max_workers : pool size; default is half the CPU count, capped at 4.
    star_mask_ref_path : with *make_masks*, build one mask from this frame and
        give every frame a centre-cropped/padded copy of it.
    Ht, Wt : common frame size; probed from *paths* when either is None.
    color_mode : forwarded to the frame cache / process workers.

    A process pool is used only when no input is XISF and
    ``_USE_PROCESS_POOL_FOR_ASSETS`` is set; otherwise threads share
    ``_FRAME_LRU``, which is trimmed as frames complete.

    Variance maps (Option B): when a worker reports an on-disk location (an
    explicit path, or the backing file of an ``np.memmap``) only that path is
    kept in ``var_paths`` and ``vars_`` holds None for the frame. A variance
    map that exists only in memory is kept in ``vars_`` — previously it was
    discarded, so computed varmaps never reached the caller.

    Returns
    -------
    (psfs, masks, vars_, var_paths)
        ``masks`` is None unless *make_masks*; ``vars_`` and ``var_paths``
        are None unless *make_varmaps*.
    """
    # Module-scope helpers used here: _FRAME_LRU, _common_hw_from_paths,
    # _safe_primary_header, fits, _gaussian_psf, threadpool_limits, etc.

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    n = len(paths)

    # Resolve target intersection size if caller didn't pass it
    if Ht is None or Wt is None:
        Ht, Wt = _common_hw_from_paths(paths)

    # Conservative default worker count to cap concurrent RAM:
    # half the cores, max 4 (keeps sky/rms/luma concurrency modest)
    if max_workers is None:
        hw = os.cpu_count() or 4
        max_workers = max(1, min(4, hw // 2))

    # Any XISF input forces threads so the memmap/frame cache is shared
    any_xisf = any(os.path.splitext(p)[1].lower() == ".xisf" for p in paths)
    use_proc_pool = (not any_xisf) and _USE_PROCESS_POOL_FOR_ASSETS
    Executor = ProcessPoolExecutor if use_proc_pool else ThreadPoolExecutor
    pool_kind = "process" if use_proc_pool else "thread"
    status_cb(f"MFDeconv: measuring PSFs/masks/varmaps with {max_workers} {pool_kind}s…")

    def _center_pad_or_crop_2d(a2d: np.ndarray, Ht: int, Wt: int, fill: float = 1.0) -> np.ndarray:
        # Centre-crop a 2-D array to (Ht, Wt), padding with *fill* if smaller.
        a2d = np.asarray(a2d, dtype=np.float32)
        H, W = int(a2d.shape[0]), int(a2d.shape[1])
        y0 = max(0, (H - Ht) // 2); x0 = max(0, (W - Wt) // 2)
        y1 = min(H, y0 + Ht); x1 = min(W, x0 + Wt)
        cropped = a2d[y0:y1, x0:x1]
        ch, cw = cropped.shape
        if ch == Ht and cw == Wt:
            return np.ascontiguousarray(cropped, dtype=np.float32)
        out = np.full((Ht, Wt), float(fill), dtype=np.float32)
        oy = (Ht - ch) // 2; ox = (Wt - cw) // 2
        out[oy:oy + ch, ox:ox + cw] = cropped
        return out

    # ---- optional: build one mask from the reference frame and reuse ----
    base_ref_mask = None
    if make_masks and star_mask_ref_path:
        try:
            status_cb(f"Star mask: using reference frame for all masks → {os.path.basename(star_mask_ref_path)}")
            ref_chw = _FRAME_LRU.get(star_mask_ref_path, Ht, Wt, "luma")  # (1,H,W) or (H,W)
            L = ref_chw[0] if (ref_chw.ndim == 3) else ref_chw               # 2D float32

            vmc = (varmap_cfg or {})
            sky_map, rms_map, err_scalar = _sep_background_precompute(
                L, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
            )
            smc = (star_mask_cfg or {})
            base_ref_mask = _star_mask_from_precomputed(
                L, sky_map, err_scalar,
                thresh_sigma  = smc.get("thresh_sigma", THRESHOLD_SIGMA),
                max_objs      = smc.get("max_objs", STAR_MASK_MAXOBJS),
                grow_px       = smc.get("grow_px", GROW_PX),
                ellipse_scale = smc.get("ellipse_scale", ELLIPSE_SCALE),
                soft_sigma    = smc.get("soft_sigma", SOFT_SIGMA),
                max_radius_px = smc.get("max_radius_px", MAX_STAR_RADIUS),
                keep_floor    = smc.get("keep_floor", KEEP_FLOOR),
                max_side      = smc.get("max_side", STAR_MASK_MAXSIDE),
                status_cb     = status_cb,
            )
            # keep mask compact
            # NOTE(review): this thresholds the mask at 0.5, whereas per-frame
            # masks are used as returned — confirm the shared mask should be binary.
            if base_ref_mask is not None and base_ref_mask.dtype != np.uint8:
                base_ref_mask = (base_ref_mask > 0.5).astype(np.uint8, copy=False)
            del L, sky_map, rms_map
            gc.collect()
        except Exception as e:
            status_cb(f"⚠️ Star mask (reference) failed: {e}. Falling back to per-frame masks.")
            base_ref_mask = None

    # For GUI safety, queue logs from workers and flush them in this thread
    from queue import SimpleQueue
    log_queue: SimpleQueue = SimpleQueue()

    def enqueue_logs(lines):
        for s in lines:
            log_queue.put(s)

    def _drain_logs():
        # Forward queued lines to status_cb; stop at the first failure.
        while not log_queue.empty():
            try:
                status_cb(log_queue.get_nowait())
            except Exception:
                break

    psfs = [None] * n
    masks = ([None] * n) if make_masks else None
    vars_ = ([None] * n) if make_varmaps else None      # in-memory varmaps (only when no path)
    var_paths = ([None] * n) if make_varmaps else None  # on-disk paths for varmaps
    make_masks_in_worker = bool(make_masks and (base_ref_mask is None))

    def _compute_one(i: int, path: str):
        # Thread-mode worker: reads from the shared frame cache; native
        # thread pools are limited to 1 thread while it runs.
        with threadpool_limits(limits=1):
            img_chw = _FRAME_LRU.get(path, Ht, Wt, color_mode)      # (C,H,W) float32
            arr2d = img_chw[0] if (img_chw.ndim == 3) else img_chw  # (H,W) float32
            try:
                hdr = _safe_primary_header(path)
            except Exception:
                hdr = fits.Header()
            return _compute_frame_assets(
                i, arr2d, hdr,
                make_masks=bool(make_masks_in_worker),
                make_varmaps=bool(make_varmaps),
                star_mask_cfg=star_mask_cfg,
                varmap_cfg=varmap_cfg,
            )

    # Optional: tighten the frame cache during this phase if the LRU supports it
    try:
        _FRAME_LRU.set_limits(max_items=max(4, max_workers + 2))
    except Exception:
        pass

    with Executor(max_workers=max_workers) as ex:
        futs = []
        for i, p in enumerate(paths, start=1):
            status_cb(f"MFDeconv: measuring PSF {i}/{n} …")
            if use_proc_pool:
                futs.append(ex.submit(
                    _compute_one_worker,
                    (i, p, bool(make_masks_in_worker), bool(make_varmaps), star_mask_cfg, varmap_cfg, Ht, Wt, color_mode)
                ))
            else:
                futs.append(ex.submit(_compute_one, i, p))

        done_cnt = 0
        for fut in as_completed(futs):
            res = fut.result()
            i, psf, m, v, vpath, logs = _normalize_assets_result(res)

            idx = i - 1
            psfs[idx] = psf
            if masks is not None:
                masks[idx] = m
            if vars_ is not None:
                # Option B: prefer an on-disk copy; recover a memmap's file.
                if vpath is None and isinstance(v, np.memmap):
                    backing = getattr(v, "filename", None)
                    vpath = str(backing) if backing is not None else None
                # Drop the array only when it can be re-opened from disk.
                vars_[idx] = None if vpath is not None else v
            if var_paths is not None:
                var_paths[idx] = vpath

            enqueue_logs(logs)

            try:
                _FRAME_LRU.drop(paths[idx])
            except Exception:
                pass

            done_cnt += 1
            _drain_logs()

            if (done_cnt % 8) == 0:
                gc.collect()

    # If we built a single reference mask, give every frame its own copy
    if base_ref_mask is not None and masks is not None:
        ref_fit = _center_pad_or_crop_2d(base_ref_mask, int(Ht), int(Wt), fill=1.0)
        for idx in range(n):
            masks[idx] = ref_fit.copy()

    # final flush of any remaining logs
    _drain_logs()

    # save PSFs if requested
    if save_dir:
        for i, k in enumerate(psfs, start=1):
            if k is not None:
                fits.PrimaryHDU(k.astype(np.float32, copy=False)).writeto(
                    os.path.join(save_dir, f"psf_{i:03d}.fit"), overwrite=True
                )

    return psfs, masks, vars_, var_paths
805
# Any run of characters outside [A-Za-z0-9_-] (replaced when sanitising
# filename tokens).
_ALLOWED = re.compile(r"[^A-Za-z0-9_-]+")

# Known FITS-style (multi-)extensions, checked in order: compound forms come
# before their single-suffix tails so '.fits.fz' wins over '.fz'.
_KNOWN_EXTS = [
    ".fits.fz",
    ".fit.fz",
    ".fits.gz",
    ".fit.gz",
    ".fz",
    ".gz",
    ".fits",
    ".fit",
]
813
+
814
def _sanitize_token(s: str) -> str:
    """
    Make *s* filename-safe.

    Each run of disallowed characters (see ``_ALLOWED``) becomes one '_'.
    Repeated underscores are then collapsed, and leading/trailing '_' are
    stripped.
    """
    replaced = _ALLOWED.sub("_", s)
    return re.sub(r"_+", "_", replaced).strip("_")
818
+
819
def _split_known_exts(p: Path, known_exts=None) -> tuple[str, str]:
    """
    Return (name_body, full_ext) where full_ext is a REAL extension block
    (e.g. '.fits.fz'). Any junk like '.0s (1310x880)_MFDeconv' stays in body.

    *known_exts* defaults to ``_KNOWN_EXTS``. Entries are matched in order
    against the lower-cased file name, and the matching entry is returned as
    the extension. For any other name only the final suffix (``Path.suffix``)
    is the extension. Previously all suffixes were joined, which copied
    dotted metadata such as '.5s (1310x880)' into the extension as well.
    *p* may be a ``Path`` or a string.
    """
    p = Path(p)
    name = p.name
    lowered = name.lower()
    exts = _KNOWN_EXTS if known_exts is None else known_exts
    for ext in exts:
        if ext and lowered.endswith(ext):
            return name[:-len(ext)], ext
    # fallback: single suffix
    return p.stem, p.suffix
831
+
832
# Size token such as "1310x880" or "( 1310x880 )": two 2-5 digit numbers
# joined by 'x' (either case), optionally in parentheses.
_SIZE_RE = re.compile(r"\(?\s*(\d{2,5})x(\d{2,5})\s*\)?", re.IGNORECASE)

# Exposure token such as "45s", "0.5s" or "45 s"; the number must not be
# glued to a preceding letter or digit.
_EXP_RE = re.compile(r"(?<![A-Za-z0-9])(\d+(?:\.\d+)?)\s*s\b", re.IGNORECASE)

# Integer-times token such as "2x" (the same shape as the '<r>x' suffix added
# to output names), not glued to a preceding letter or digit.
_RX_RE = re.compile(r"(?<![A-Za-z0-9])(\d+)x\b", re.IGNORECASE)
835
+
836
def _extract_size(body: str) -> str | None:
    """Return the first size token in *body* rewritten as '<a>x<b>' (lower-case
    'x'), or None if *body* has no size token."""
    match = _SIZE_RE.search(body)
    if match is None:
        return None
    first, second = match.group(1), match.group(2)
    return f"{first}x{second}"
839
+
840
def _extract_exposure_secs(body: str) -> str | None:
    """
    Return the first exposure token in *body* as whole seconds, e.g. '45s',
    or None if there is none.

    Rounding uses Python's ``round`` (ties go to the even integer, so '0.5s'
    becomes '0s').
    """
    match = _EXP_RE.search(body)
    return None if match is None else f"{int(round(float(match.group(1))))}s"
846
+
847
def _strip_metadata_from_base(body: str) -> str:
    """
    Reduce a filename body to a clean base name. Applied in this order:

      1. ' - ' separators become '_'
      2. a trailing '_MFDeconv' / ' MFDeconv' marker is removed
      3. a trailing '(N)' copy counter is removed
      4. size tokens, exposure tokens and '<N>x' tokens are removed anywhere
      5. whitespace runs become '_', then the result is sanitised

    Returns 'output' if nothing is left.
    """
    cleaned = body.replace(" - ", "_")

    # trailing markers: '_MFDeconv', then a '(1)'-style copy counter
    for trailing in (r"(?i)[\s_]+MFDeconv$", r"\(\s*\d+\s*\)$"):
        cleaned = re.sub(trailing, "", cleaned)

    # metadata tokens anywhere in the name: size, exposure, '<N>x'
    for token_re in (_SIZE_RE, _EXP_RE, _RX_RE):
        cleaned = token_re.sub("", cleaned)

    cleaned = _sanitize_token(re.sub(r"[\s]+", "_", cleaned))
    return cleaned or "output"
872
+
873
def _canonical_out_name_prefix(base: str, r: int, size: str | None,
                               exposure_secs: str | None, tag: str = "MFDeconv") -> str:
    """
    Build '<tag>_<base>[_<size>][_<exposure_secs>][_<r>x]'.

    Every text token is sanitised. Empty or None size/exposure values are
    skipped, and the '<r>x' suffix is added only when r > 1.
    """
    tokens = [tag, base] + [extra for extra in (size, exposure_secs) if extra]
    parts = [_sanitize_token(token) for token in tokens]
    if int(max(1, r)) > 1:
        parts.append(f"{int(r)}x")
    return "_".join(parts)
883
+
884
def _sr_out_path(out_path: str, r: int) -> Path:
    """
    Build: MFDeconv_<base>[_<HxW>][_<secs>s][_2x], preserving REAL extensions.

    Size and exposure are read from the whole filename body (not
    ``Path.stem``), and the cleaned base comes from
    ``_strip_metadata_from_base``. The result stays in the directory of
    *out_path*.
    """
    original = Path(out_path)
    body, real_ext = _split_known_exts(original)
    stem = _canonical_out_name_prefix(
        _strip_metadata_from_base(body),
        r=int(max(1, r)),
        size=_extract_size(body),
        exposure_secs=_extract_exposure_secs(body),
        tag="MFDeconv",
    )
    return original.with_name(f"{stem}{real_ext}")
900
+
901
def _nonclobber_path(path: str) -> str:
    """
    Return *path* if nothing exists there. Otherwise return the first free
    '<body>_vN<ext>' sibling (no spaces/parentheses).

    An existing '_vN' suffix is bumped (starting at N+1); otherwise numbering
    starts at '_v2'. Real extensions such as '.fits.fz' are preserved.
    """
    target = Path(path)
    if not target.exists():
        return str(target)

    body, real_ext = _split_known_exts(target)
    versioned = re.search(r"(.*)_v(\d+)$", body)
    if versioned:
        stem, n = versioned.group(1), int(versioned.group(2)) + 1
    else:
        stem, n = body, 2

    candidate = target.with_name(f"{stem}_v{n}{real_ext}")
    while candidate.exists():
        n += 1
        candidate = target.with_name(f"{stem}_v{n}{real_ext}")
    return str(candidate)
924
+
925
def _iter_folder(basefile: str, max_attempts: int = 1000) -> str:
    """
    Return (creating it if needed) the directory '<root>.iters' next to
    *basefile*, where <root> is the file name without its last extension.

    An existing directory is reused. If that name cannot be used as a
    directory (creation fails, or a regular file already has that name),
    '<root>.iters (1)', '<root>.iters (2)', ... are tried in turn.

    Raises the last OSError if none of the *max_attempts* fallback names can
    be created. Previously a persistent failure (e.g. a read-only directory)
    looped forever, and an existing regular file was returned as if it were
    a folder.
    """
    d, fname = os.path.split(basefile)
    root, _ext = os.path.splitext(fname)

    tgt = os.path.join(d, f"{root}.iters")
    if os.path.isdir(tgt):
        return tgt
    try:
        os.makedirs(tgt, exist_ok=True)
        return tgt
    except OSError as e:
        last_err = e

    # last resort: suffix (n)
    for n in range(1, max(1, int(max_attempts)) + 1):
        cand = os.path.join(d, f"{root}.iters ({n})")
        try:
            os.makedirs(cand, exist_ok=True)
            return cand
        except OSError as e:
            last_err = e
    raise last_err
943
+
944
def _save_iter_image(arr, hdr_base, folder, tag, color_mode):
    """
    Write one intermediate image to '<folder>/<tag>.fit' and return the path.

    arr      : numpy array (H,W) or (C,H,W); a colour-last (H,W,1|3) array is
               moved to channel-first, and a single channel is squeezed to 2-D.
               Data is written as float32.
    hdr_base : fits.Header copied into the output (anything else -> empty).
    tag      : 'seed' or 'iter_###'; also stored in MF_PART.

    An existing file with the same name is overwritten (the folder is
    dedicated to these intermediates).
    """
    is_channel_last = arr.ndim == 3 and arr.shape[0] not in (1, 3) and arr.shape[-1] in (1, 3)
    if is_channel_last:
        arr = np.moveaxis(arr, -1, 0)
    if arr.ndim == 3 and arr.shape[0] == 1:
        arr = arr[0]

    header = fits.Header(hdr_base) if isinstance(hdr_base, fits.Header) else fits.Header()
    header['MF_PART'] = (str(tag), 'MFDeconv intermediate (seed/iter)')
    header['MF_COLOR'] = (str(color_mode), 'Color mode used')

    out_path = os.path.join(folder, f"{tag}.fit")
    hdu = fits.PrimaryHDU(data=arr.astype(np.float32, copy=False), header=header)
    hdu.writeto(out_path, overwrite=True)
    return out_path
961
+
962
+
963
def _process_gui_events_safely():
    """
    Pump pending Qt events, but only when a QApplication exists and this is
    called on the application's own thread. Otherwise do nothing.
    """
    app = QApplication.instance()
    if not app:
        return
    if QThread.currentThread() is not app.thread():
        return
    app.processEvents()
967
+
968
# Small positive constant guarding the sum/normalisation divisions and
# "near-zero sum" checks in the PSF helpers below.
EPS = 1e-6

# -----------------------------
# Helpers: image prep / shapes
# -----------------------------
975
+
976
def _to_luma_local(a: np.ndarray) -> np.ndarray:
    """
    Reduce an image to a single float32 plane.

    * 2-D input is returned (as float32).
    * 3-D input: a singleton channel axis is dropped (checked last axis
      first, then first axis). Otherwise 3 channels (last axis first, then
      first axis) are combined as 0.2126*R + 0.7152*G + 0.0722*B.
    * Anything else is averaged over its last axis.
    """
    x = np.asarray(a, dtype=np.float32)
    if x.ndim == 2:
        return x

    if x.ndim == 3:
        if x.shape[-1] == 1:          # HWC mono
            return x[..., 0].astype(np.float32, copy=False)
        if x.shape[0] == 1:           # CHW mono
            return x[0].astype(np.float32, copy=False)

        if x.shape[-1] == 3:          # HWC RGB
            channels = (x[..., 0], x[..., 1], x[..., 2])
        elif x.shape[0] == 3:         # CHW RGB
            channels = (x[0], x[1], x[2])
        else:
            channels = None

        if channels is not None:
            red, green, blue = channels
            return (0.2126 * red + 0.7152 * green + 0.0722 * blue).astype(np.float32, copy=False)

    # fallback: average over the last axis
    return x.mean(axis=-1).astype(np.float32, copy=False)
995
+
996
def _normalize_layout_single(a, color_mode):
    """
    Coerce one image to the layout a colour mode expects (float32):

      - 'luma'            -> (H, W) via ``_to_luma_local``
      - any other mode    -> channel-first:
          (H,W)   -> (1,H,W)
          (H,W,3) -> (3,H,W)
          (1|3,H,W) unchanged
          anything else -> luma plane as (1,H,W)
    """
    x = np.asarray(a, dtype=np.float32)

    if color_mode == "luma":
        return _to_luma_local(x)

    if x.ndim == 2:
        return x[np.newaxis, ...]           # keep mono as one channel
    if x.ndim == 3:
        if x.shape[-1] == 3:
            return np.moveaxis(x, -1, 0)    # HWC -> CHW
        if x.shape[0] in (1, 3):
            return x                        # already CHW

    # fallback: collapse any other shape into a 1xHxW luma plane
    return _to_luma_local(x)[np.newaxis, ...]
1018
+
1019
+
1020
def _normalize_layout_batch(arrs, color_mode):
    """Apply ``_normalize_layout_single`` to each array in *arrs*, preserving
    order; returns a new list."""
    return list(map(lambda item: _normalize_layout_single(item, color_mode), arrs))
1022
+
1023
def _common_hw(data_list):
    """
    Return the minimal (H, W) across *data_list* as Python ints.

    Items must be (H, W) or (C, H, W) arrays; other ranks raise ValueError
    from the shape unpacking, and an empty list raises ValueError from min().
    """
    dims = []
    for item in data_list:
        if item.ndim == 2:
            h, w = item.shape
        else:
            _chans, h, w = item.shape
        dims.append((h, w))
    return int(min(h for h, _ in dims)), int(min(w for _, w in dims))
1033
+
1034
def _center_crop(arr, Ht, Wt):
    """
    Centre-crop (H, W) or (C, H, W) *arr* to (Ht, Wt).

    Returns *arr* itself when it already has that size, otherwise a slice
    (view). No padding is applied if *arr* is smaller than the target.
    """
    if arr.ndim == 2:
        H, W = arr.shape
    else:
        _chans, H, W = arr.shape
    if H == Ht and W == Wt:
        return arr

    top = max(0, (H - Ht) // 2)
    left = max(0, (W - Wt) // 2)
    rows = slice(top, top + Ht)
    cols = slice(left, left + Wt)
    return arr[rows, cols] if arr.ndim == 2 else arr[:, rows, cols]
1050
+
1051
def _sanitize_numeric(a):
    """Map NaN/±Inf to 0, clip negatives to 0, and return a C-contiguous
    float32 array."""
    finite = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
    non_negative = np.clip(finite, 0.0, None)
    return np.ascontiguousarray(non_negative.astype(np.float32, copy=False))
1056
+
1057
+ # -----------------------------
1058
+ # PSF utilities
1059
+ # -----------------------------
1060
+
1061
def _gaussian_psf(fwhm_px: float, ksize: int) -> np.ndarray:
    """
    Return a ksize x ksize Gaussian kernel normalised to sum to ~1 (float32).

    An FWHM below 1 px is treated as 1 px. sigma = FWHM / 2.3548.
    """
    sigma = max(fwhm_px, 1.0) / 2.3548
    half = (ksize - 1) / 2
    yy, xx = np.mgrid[-half:half + 1, -half:half + 1]
    kernel = np.exp(-(xx * xx + yy * yy) / (2 * sigma * sigma))
    kernel /= (np.sum(kernel) + EPS)
    return kernel.astype(np.float32, copy=False)
1068
+
1069
def _estimate_fwhm_from_header(hdr) -> float:
    """
    Return the first finite, positive value among the header keywords
    FWHM, FWHM_PIX and PSF_FWHM (checked in that order).

    Values that cannot be converted to float are skipped. Returns NaN when
    no keyword qualifies.
    """
    for key in ("FWHM", "FWHM_PIX", "PSF_FWHM"):
        if key not in hdr:
            continue
        try:
            value = float(hdr[key])
        except Exception:
            continue
        if np.isfinite(value) and value > 0:
            return value
    return float("nan")
1079
+
1080
def _estimate_fwhm_from_image(arr) -> float:
    """
    Fast FWHM estimate (pixels) from SEP source shapes.

    The luma plane is background-subtracted and sources are extracted at a
    6.0 x error threshold. The median of (a + b) / 2 over valid sources is
    taken as sigma, and 2.3548 * sigma is returned. The result is NaN when
    SEP is unavailable, nothing is detected, or any step fails.
    """
    if sep is None:
        return float("nan")
    try:
        plane = _contig(_to_luma_local(arr))       # C-contiguous float32 for SEP
        background = sep.Background(plane)
        residual = _contig(plane - background.back())
        try:
            noise = background.globalrms
        except Exception:
            noise = float(np.median(background.rms()))

        objects = sep.extract(residual, 6.0, err=noise)
        if objects is None or len(objects) == 0:
            return float("nan")

        semi_a = np.asarray(objects["a"], dtype=np.float32)
        semi_b = np.asarray(objects["b"], dtype=np.float32)
        mean_axis = (semi_a + semi_b) * 0.5
        usable = mean_axis[np.isfinite(mean_axis) & (mean_axis > 0)]
        sigma = float(np.median(usable))
        if not np.isfinite(sigma) or sigma <= 0:
            return float("nan")
        return 2.3548 * sigma
    except Exception:
        return float("nan")
1104
+
1105
def _auto_ksize_from_fwhm(fwhm_px: float, kmin: int = 11, kmax: int = 51) -> int:
    """
    Choose an odd kernel size covering about ±4 sigma for the given FWHM.

    An FWHM below 1 px is treated as 1 px. The size 2*ceil(4*sigma)+1 is
    clamped to [kmin, kmax] (kmin wins if kmin > kmax) and bumped to the next
    odd number if clamping made it even.
    """
    sigma = max(fwhm_px, 1.0) / 2.3548
    size = 2 * int(math.ceil(4.0 * sigma)) + 1
    size = max(kmin, min(size, kmax))
    return size + 1 if size % 2 == 0 else size
1116
+
1117
def _flip_kernel(psf):
    """Rotate a kernel by 180 degrees over its last two axes.

    Returns a contiguous copy: PyTorch rejects the negative strides that a
    plain flipped view would carry.
    """
    return psf[..., ::-1, ::-1].copy()
1120
+
1121
def _conv_same_np(img, psf):
    """
    Same-size linear convolution of an (H,W) or (C,H,W) numpy image with a
    2-D kernel, computed with zero-padded real FFTs.

    The full (H+kh-1, W+kw-1) result is cropped starting at
    ((kh-1)//2, (kw-1)//2). Each channel of a 3-D input is convolved
    separately.
    """
    from numpy.fft import irfftn, rfftn

    def _convolve_planes(planes, kernel):
        H, W = planes.shape[-2:]
        kh, kw = kernel.shape
        full_shape = (H + kh - 1, W + kw - 1)
        spec_img = rfftn(planes, s=full_shape, axes=(-2, -1))
        spec_ker = rfftn(kernel, s=full_shape, axes=(-2, -1))
        full = irfftn(spec_img * spec_ker, s=full_shape, axes=(-2, -1))
        off_y, off_x = (kh - 1) // 2, (kw - 1) // 2
        return full[..., off_y:off_y + H, off_x:off_x + W]

    if img.ndim == 2:
        return _convolve_planes(img[None], psf)[0]
    per_channel = [_convolve_planes(img[c:c + 1], psf)[0] for c in range(img.shape[0])]
    return np.stack(per_channel, axis=0)
1138
+
1139
def _normalize_psf(psf):
    """
    Clip negatives to 0 and scale the kernel to unit sum (float32).

    If the clipped sum is not finite or is <= EPS, the clipped kernel is
    returned unscaled.
    """
    clipped = np.maximum(psf, 0.0).astype(np.float32, copy=False)
    total = float(clipped.sum())
    if np.isfinite(total) and total > EPS:
        return (clipped / total).astype(np.float32, copy=False)
    return clipped
1145
+
1146
def _soften_psf(psf, sigma_px=0.25):
    """
    Optionally blur *psf* with a tiny normalised Gaussian to reduce ringing.

    The blur kernel has sigma = *sigma_px* and radius max(1, round(3*sigma)).
    A *sigma_px* <= 0 disables softening and returns *psf* itself.
    """
    if sigma_px <= 0:
        return psf
    radius = int(max(1, round(3 * sigma_px)))
    yy, xx = np.mgrid[-radius:radius + 1, -radius:radius + 1]
    blur = np.exp(-(xx * xx + yy * yy) / (2 * sigma_px * sigma_px)).astype(np.float32)
    blur /= blur.sum() + EPS
    return _conv_same_np(psf[None], blur)[0]
1155
+
1156
def _psf_fwhm_px(psf: np.ndarray) -> float:
    """
    Approximate FWHM (pixels) of a square kernel from its second moments.

    Negatives are clipped first. sigma = sqrt((var_x + var_y) / 2) about the
    intensity centroid, and FWHM ≈ 2.3548 * sigma. Returns NaN when the
    clipped sum is <= EPS.
    """
    weights = np.maximum(psf, 0).astype(np.float32, copy=False)
    total = float(weights.sum())
    if total <= EPS:
        return float("nan")

    side = weights.shape[0]
    yy, xx = np.mgrid[:side, :side].astype(np.float32)
    cy = float((weights * yy).sum() / total)
    cx = float((weights * xx).sum() / total)
    var_y = float((weights * (yy - cy) ** 2).sum() / total)
    var_x = float((weights * (xx - cx) ** 2).sum() / total)
    return 2.3548 * math.sqrt(max(0.0, 0.5 * (var_x + var_y)))
1170
+
1171
# Defaults for star-mask construction; each can be overridden per call via
# the matching key in star_mask_cfg.
STAR_MASK_MAXSIDE = 2048    # longest side before the detection image is downscaled
STAR_MASK_MAXOBJS = 2000    # cap on the number of detected objects
THRESHOLD_SIGMA = 2.0       # default thresh_sigma
KEEP_FLOOR = 0.20           # default keep_floor
GROW_PX = 8                 # default grow_px
MAX_STAR_RADIUS = 16        # default max_radius_px
SOFT_SIGMA = 2.0            # default soft_sigma
ELLIPSE_SCALE = 1.2         # default ellipse_scale

VARMAP_SAMPLE_STRIDE = 8    # kept for compatibility; currently unused internally
1180
+
1181
def _sep_background_precompute(img_2d: np.ndarray, bw: int = 64, bh: int = 64):
    """
    One-time background build; returns ``(sky_map, rms_map, err_scalar)``.

    * sky/rms are float32 maps with the same shape as *img_2d*.
    * With SEP available: non-finite pixels are first replaced by the median
      of the finite ones (0 if none). Tile sizes are clamped to
      [8, image side], and ``sep.Background`` is run with a 3x3 filter.
      err_scalar is SEP's globalrms if finite and positive, else the median
      of the rms map.
    * Fallback (SEP missing, SEP raising, or maps of the wrong shape): a flat
      median sky and a flat 1.4826 * MAD rms. err_scalar is the median rms.
    * A 0-sized image yields a zero sky, a unit rms and err 1.0.

    Raises ValueError if *img_2d* is not 2-D.
    """
    a = np.asarray(img_2d, dtype=np.float32)
    if a.ndim != 2:
        # be strict; callers expect 2D
        raise ValueError(f"_sep_background_precompute expects 2D, got shape={a.shape}")

    H, W = int(a.shape[0]), int(a.shape[1])
    if H == 0 or W == 0:
        return np.zeros((H, W), dtype=np.float32), np.ones((H, W), dtype=np.float32), 1.0

    def _flat_model(src):
        # Robust flat sky / rms from finite pixels of *src*.
        ok = np.isfinite(src)
        if ok.any():
            vals = src[ok]
            med = float(np.median(vals))
            mad = float(np.median(np.abs(vals - med))) + 1e-6
        else:
            med, mad = 0.0, 1.0
        sky = np.full((H, W), med, dtype=np.float32)
        rms = np.full((H, W), 1.4826 * mad, dtype=np.float32)
        return sky, rms, float(np.median(rms))

    if sep is None:
        return _flat_model(a)

    # SEP can choke on NaN/Inf: fill them with the finite median (or 0)
    finite = np.isfinite(a)
    if not finite.all():
        fill = float(np.median(a[finite])) if finite.any() else 0.0
        a = np.where(finite, a, fill).astype(np.float32, copy=False)
    a = np.ascontiguousarray(a, dtype=np.float32)

    # keep tile sizes within the image (but at least 8 px)
    bw = int(max(8, min(int(bw), W)))
    bh = int(max(8, min(int(bh), H)))

    try:
        model = sep.Background(a, bw=bw, bh=bh, fw=3, fh=3)
        sky = np.asarray(model.back(), dtype=np.float32)
        rms = np.asarray(model.rms(), dtype=np.float32)
        if sky.shape != a.shape or rms.shape != a.shape:
            return _flat_model(a)

        # globalrms may be missing depending on the SEP build
        err = float(getattr(model, "globalrms", np.nan))
        if not np.isfinite(err) or err <= 0:
            err = float(np.median(rms)) if rms.size else 1.0
        return sky, rms, err
    except Exception:
        # degrade gracefully if SEP fails for any reason
        return _flat_model(a)
1257
+
1258
+
1259
def _star_mask_from_precomputed(
    img_2d: np.ndarray,
    sky_map: np.ndarray,
    err_scalar: float,
    *,
    thresh_sigma: float,
    max_objs: int,
    grow_px: int,
    ellipse_scale: float,
    soft_sigma: float,
    max_radius_px: int,
    keep_floor: float,
    max_side: int,
    status_cb=lambda s: None
) -> np.ndarray:
    """
    Build a KEEP weight map using a *downscaled detection / full-res draw* path.
    **Never writes to img_2d**; all drawing happens in a fresh `mask_u8`.

    Steps:
      1. Detect sources with ``sep.extract`` on ``img_2d - sky_map``. When the
         long side exceeds ``max_side``, detection runs on a downscaled copy
         (OpenCV INTER_AREA, else integer block averaging).
      2. Climb a threshold ladder (1x, 2x, 4x, 8x, 16x ``thresh_sigma``).
         Stop at the first rung that finds between 1 and ``max_objs * 12``
         sources. A rung on which SEP raises (e.g. its pixel buffer overflows
         at a low threshold) is skipped rather than aborting.
      3. Keep the ``max_objs`` brightest sources by flux. Draw each one at
         full resolution as a filled disk of radius
         ``min(ceil(ellipse_scale * max(a, b)) + grow_px, max_radius_px)``.
      4. Optionally Gaussian-feather the disks (``soft_sigma``). Map the result
         to keep weights in ``[keep_floor, 1]``: 1 means keep, low means star.

    Detection coordinates are mapped back to full resolution with the
    pixel-centre convention ``(x + 0.5) * s - 0.5``. SEP reports 0-based
    pixel-centre coordinates. The factor ``s`` is computed per axis.

    Returns
    -------
    np.ndarray
        C-contiguous float32 array of shape (H, W). It is all ones when no
        source is found.
    """
    # Optional OpenCV fast path
    try:
        import cv2 as _cv2
        _HAS_CV2 = True
    except Exception:
        _HAS_CV2 = False
        _cv2 = None  # type: ignore

    H, W = map(int, img_2d.shape)

    # Residual for detection (contiguous, separate buffer)
    data_sub = np.ascontiguousarray((img_2d - sky_map).astype(np.float32))

    # Downscale *detection only* to speed up, never the draw step.
    # sx/sy map one detection pixel back to full-resolution pixels.
    det = data_sub
    sx = sy = 1.0
    if max_side and max(H, W) > int(max_side):
        scale = float(max(H, W)) / float(max_side)
        if _HAS_CV2:
            det = _cv2.resize(
                det,
                (max(1, int(round(W / scale))), max(1, int(round(H / scale)))),
                interpolation=_cv2.INTER_AREA,
            )
            # true per-axis factors (the target size was rounded)
            sy = H / float(det.shape[0])
            sx = W / float(det.shape[1])
        else:
            s = int(max(1, round(scale)))
            det = det[:(H // s) * s, :(W // s) * s].reshape(H // s, s, W // s, s).mean(axis=(1, 3))
            sx = sy = float(s)
    # SEP wants a C-contiguous native float32 buffer
    det = np.ascontiguousarray(det, dtype=np.float32)
    s_rad = max(sx, sy)  # scale applied to SEP ellipse axes

    # Threshold ladder
    thresholds = [thresh_sigma, thresh_sigma * 2, thresh_sigma * 4,
                  thresh_sigma * 8, thresh_sigma * 16]
    objs = None
    used = float("nan")
    raw = 0
    for t in thresholds:
        try:
            cand = sep.extract(det, thresh=float(t), err=float(err_scalar))
        except Exception:
            # SEP raises when a low threshold marks too many pixels/objects;
            # that is exactly the case a higher rung is meant to handle.
            continue
        n = 0 if cand is None else len(cand)
        if n == 0 or n > max_objs * 12:
            continue
        objs, raw, used = cand, n, float(t)
        break

    if objs is None or len(objs) == 0:
        try:
            cand = sep.extract(det, thresh=thresholds[-1], err=float(err_scalar), minarea=9)
        except Exception:
            cand = None
        if cand is None or len(cand) == 0:
            status_cb("Star mask: no sources found (mask disabled for this frame).")
            return np.ones((H, W), dtype=np.float32, order="C")
        objs, raw, used = cand, len(cand), float(thresholds[-1])

    # Brightest max_objs
    if "flux" in objs.dtype.names:
        idx = np.argsort(objs["flux"])[-int(max_objs):]
        objs = objs[idx]
    else:
        objs = objs[:int(max_objs)]
    kept = len(objs)

    # ---- draw back on full-res into a brand-new buffer ----
    mask_u8 = np.zeros((H, W), dtype=np.uint8, order="C")
    MR = int(max(1, max_radius_px))
    G = int(max(0, grow_px))
    ES = float(max(0.1, ellipse_scale))

    def _disk(o):
        """Full-res (x, y, r) for one detection, or None if off-frame or empty."""
        # detection pixel i covers full-res pixels centred on (i + 0.5) * s - 0.5
        x = int(round((float(o["x"]) + 0.5) * sx - 0.5))
        y = int(round((float(o["y"]) + 0.5) * sy - 0.5))
        if not (0 <= x < W and 0 <= y < H):
            return None
        a = float(o["a"]) * s_rad
        b = float(o["b"]) * s_rad
        r = int(math.ceil(ES * max(a, b)))
        r = min(max(r, 0) + G, MR)
        return None if r <= 0 else (x, y, r)

    drawn = 0
    for o in objs:
        disk_params = _disk(o)
        if disk_params is None:
            continue
        x, y, r = disk_params
        if _HAS_CV2:
            _cv2.circle(mask_u8, (x, y), r, 1, thickness=-1, lineType=_cv2.LINE_8)
        else:
            y0 = max(0, y - r)
            y1 = min(H, y + r + 1)
            x0 = max(0, x - r)
            x1 = min(W, x + r + 1)
            yy, xx = np.ogrid[y0:y1, x0:x1]
            disk = (yy - y) * (yy - y) + (xx - x) * (xx - x) <= r * r
            mask_u8[y0:y1, x0:x1][disk] = 1
        drawn += 1

    # Feather + convert to keep weights
    m = mask_u8.astype(np.float32, copy=False)
    if soft_sigma > 0:
        try:
            if _HAS_CV2:
                k = int(max(1, int(round(3 * soft_sigma))) * 2 + 1)
                m = _cv2.GaussianBlur(m, (k, k), float(soft_sigma),
                                      borderType=_cv2.BORDER_REFLECT)
            else:
                from scipy.ndimage import gaussian_filter
                m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
        except Exception:
            pass  # best effort: keep the hard-edged mask
    np.clip(m, 0.0, 1.0, out=m)

    keep = 1.0 - m
    kf = float(max(0.0, min(0.99, keep_floor)))
    keep = kf + (1.0 - kf) * keep
    np.clip(keep, 0.0, 1.0, out=keep)

    status_cb(f"Star mask: thresh={used:.3g} | detected={raw} | kept={kept} | drawn={drawn} | keep_floor={keep_floor}")
    return np.ascontiguousarray(keep, dtype=np.float32)
1400
+
1401
+
1402
+ def _ensure_scratch_dir(scratch_dir: str | None) -> str:
1403
+ """Ensure a writable scratch directory exists; default to system temp."""
1404
+ if scratch_dir is None or not isinstance(scratch_dir, str) or not scratch_dir.strip():
1405
+ scratch_dir = tempfile.gettempdir()
1406
+ os.makedirs(scratch_dir, exist_ok=True)
1407
+ return scratch_dir
1408
+
1409
+ def _mm_unique_path(scratch_dir: str, tag: str, ext: str = ".mm") -> str:
1410
+ """Return a unique file path (closed fd) for a memmap file."""
1411
+ fd, path = tempfile.mkstemp(prefix=f"sas_{tag}_", suffix=ext, dir=scratch_dir)
1412
+ try:
1413
+ os.close(fd)
1414
+ except Exception:
1415
+ pass
1416
+ return path
1417
+
1418
+ def _variance_map_from_precomputed(
1419
+ img_2d: np.ndarray,
1420
+ sky_map: np.ndarray,
1421
+ rms_map: np.ndarray,
1422
+ hdr,
1423
+ *,
1424
+ smooth_sigma: float,
1425
+ floor: float,
1426
+ status_cb=lambda s: None
1427
+ ) -> np.ndarray:
1428
+ img = np.clip(np.asarray(img_2d, dtype=np.float32), 0.0, None)
1429
+ var_bg_dn2 = np.maximum(rms_map, 1e-6) ** 2
1430
+ obj_dn = np.clip(img - sky_map, 0.0, None)
1431
+
1432
+ gain = None
1433
+ for k in ("EGAIN", "GAIN", "GAIN1", "GAIN2"):
1434
+ if k in hdr:
1435
+ try:
1436
+ g = float(hdr[k]); gain = g if (np.isfinite(g) and g > 0) else None
1437
+ if gain is not None: break
1438
+ except Exception as e:
1439
+ import logging
1440
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
1441
+
1442
+ if gain is not None:
1443
+ a_shot = 1.0 / gain
1444
+ else:
1445
+ sky_med = float(np.median(sky_map))
1446
+ varbg_med= float(np.median(var_bg_dn2))
1447
+ a_shot = (varbg_med / sky_med) if sky_med > 1e-6 else 0.0
1448
+ a_shot = float(np.clip(a_shot, 0.0, 10.0))
1449
+
1450
+ v = var_bg_dn2 + a_shot * obj_dn
1451
+ if smooth_sigma > 0:
1452
+ try:
1453
+ import cv2 as _cv2
1454
+ k = int(max(1, int(round(3*smooth_sigma)))*2 + 1)
1455
+ v = _cv2.GaussianBlur(v, (k,k), float(smooth_sigma), borderType=_cv2.BORDER_REFLECT)
1456
+ except Exception:
1457
+ try:
1458
+ from scipy.ndimage import gaussian_filter
1459
+ v = gaussian_filter(v, sigma=float(smooth_sigma), mode="reflect")
1460
+ except Exception:
1461
+ pass
1462
+
1463
+ np.clip(v, float(floor), None, out=v)
1464
+ try:
1465
+ rms_med = float(np.median(np.sqrt(var_bg_dn2)))
1466
+ status_cb(f"Variance map: sky_med={float(np.median(sky_map)):.3g} DN | rms_med={rms_med:.3g} DN | smooth_sigma={smooth_sigma} | floor={floor}")
1467
+ except Exception:
1468
+ pass
1469
+ return v.astype(np.float32, copy=False)
1470
+
1471
+
1472
def _variance_map_from_precomputed_memmap(
    luma: np.ndarray,
    sky_map: np.ndarray,
    rms_map: np.ndarray,
    hdr,
    *,
    smooth_sigma: float = 1.0,
    floor: float = 1e-8,
    tile_hw: tuple[int, int] = (512, 512),
    scratch_dir: str | None = None,
    tag: str = "varmap",
    status_cb=lambda s: None,
    progress_cb=lambda f, m="": None,
) -> str:
    """
    Compute the same per-pixel variance map as ``_variance_map_from_precomputed``
    but stream it tile by tile into a disk-backed float32 memmap, so RAM stays
    bounded. Returns the path of the raw (H, W) float32 file.

        v = max(rms, 1e-6)**2 + a_shot * clip(clip(luma, 0) - sky, 0)

    Two details keep the tiled result equal to the whole-frame computation:
      * ``a_shot`` is resolved once for the whole frame. It is 1/gain from
        EGAIN/GAIN/GAIN1/GAIN2, else median(bg variance) / median(sky),
        clipped to [0, 10]. Previously each tile re-estimated it from its own
        medians.
      * each tile is smoothed together with a halo of neighbouring pixels and
        then cropped. This removes the seams that per-tile smoothing with
        reflect borders produced at tile edges.

    ``hdr`` may be None (no gain keywords). If anything fails while writing,
    the partially written file is removed and the error is re-raised.
    """
    import logging

    luma = np.asarray(luma, dtype=np.float32)
    sky_map = np.asarray(sky_map, dtype=np.float32)
    rms_map = np.asarray(rms_map, dtype=np.float32)

    H, W = int(luma.shape[0]), int(luma.shape[1])
    th, tw = max(1, int(tile_hw[0])), max(1, int(tile_hw[1]))
    sigma = float(smooth_sigma)
    floor = float(floor)

    # ---- frame-global shot-noise coefficient ----
    gain = None
    if hdr is not None:
        for key in ("EGAIN", "GAIN", "GAIN1", "GAIN2"):
            if key not in hdr:
                continue
            try:
                g = float(hdr[key])
            except Exception as e:
                logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
                continue
            if np.isfinite(g) and g > 0:
                gain = g
                break
    if gain is not None:
        a_shot = 1.0 / gain
    else:
        sky_med = float(np.median(sky_map))
        varbg_med = float(np.median(np.maximum(rms_map, 1e-6) ** 2))
        a_shot = (varbg_med / sky_med) if sky_med > 1e-6 else 0.0
    a_shot = float(np.clip(a_shot, 0.0, 10.0))

    # Halo >= kernel radius of both smoothing backends
    # (OpenCV: max(1, round(3*sigma)); SciPy gaussian_filter: int(4*sigma + 0.5)).
    halo = int(np.ceil(4.0 * sigma)) + 1 if sigma > 0 else 0

    def _smooth(v):
        """Gaussian smoothing with reflect borders: OpenCV, else SciPy, else none."""
        try:
            import cv2 as _cv2
            k = int(max(1, int(round(3 * sigma))) * 2 + 1)
            return _cv2.GaussianBlur(v, (k, k), sigma, borderType=_cv2.BORDER_REFLECT)
        except Exception:
            try:
                from scipy.ndimage import gaussian_filter
                return gaussian_filter(v, sigma=sigma, mode="reflect")
            except Exception:
                return v

    scratch_dir = _ensure_scratch_dir(scratch_dir)
    var_path = _mm_unique_path(scratch_dir, tag, ext=".mm")

    tiles = [(y, min(y + th, H), x, min(x + tw, W))
             for y in range(0, H, th) for x in range(0, W, tw)]
    total = len(tiles)

    # create writeable memmap
    var_mm = np.memmap(var_path, mode="w+", dtype=np.float32, shape=(H, W))
    try:
        for ti, (y0, y1, x0, x1) in enumerate(tiles, start=1):
            # padded window; at the true frame edge both paths use reflect borders
            py0, py1 = max(0, y0 - halo), min(H, y1 + halo)
            px0, px1 = max(0, x0 - halo), min(W, x1 + halo)

            img = np.clip(luma[py0:py1, px0:px1], 0.0, None)
            var_bg = np.maximum(rms_map[py0:py1, px0:px1], 1e-6) ** 2
            v = var_bg + a_shot * np.clip(img - sky_map[py0:py1, px0:px1], 0.0, None)
            if sigma > 0:
                v = _smooth(v)
            v = v[y0 - py0:y1 - py0, x0 - px0:x1 - px0]
            var_mm[y0:y1, x0:x1] = np.clip(v, floor, None)

            # free tile buffers promptly to avoid resident growth
            del img, var_bg, v

            # periodic flush + progress
            if (ti & 7) == 0 or ti == total:
                try:
                    var_mm.flush()
                except Exception as e:
                    logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
                try:
                    progress_cb(ti / float(total), f"varmap tiles {ti}/{total}")
                except Exception:
                    pass  # progress reporting is best effort

        try:
            var_mm.flush()
        except Exception as e:
            logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
    except BaseException:
        # don't leave a half-written scratch file behind
        del var_mm
        try:
            os.remove(var_path)
        except OSError:
            pass
        raise

    # drop the handle (Option B): callers reopen the file by path
    del var_mm

    status_cb(f"Variance map written (memmap): {var_path} ({H}x{W})")
    return var_path
1542
+
1543
+
1544
+
1545
+ # -----------------------------
1546
+ # Robust weighting (Huber)
1547
+ # -----------------------------
1548
+
1549
+ def _estimate_scalar_variance_t(r):
1550
+ # r: tensor on device
1551
+ med = torch.median(r)
1552
+ mad = torch.median(torch.abs(r - med)) + 1e-6
1553
+ return (1.4826 * mad) ** 2
1554
+
1555
+ def _estimate_scalar_variance(a):
1556
+ med = np.median(a)
1557
+ mad = np.median(np.abs(a - med)) + 1e-6
1558
+ return float((1.4826 * mad) ** 2)
1559
+
1560
def _weight_map(y, pred, huber_delta, var_map=None, mask=None):
    """
    Robust per-pixel weights for the MM update.

        W = [psi(r) / r] * 1 / (var + EPS) * mask,   with r = y - pred

    psi is the Huber influence function. psi(r)/r is 1 where |r| <= delta and
    delta / (|r| + EPS) elsewhere; it is all ones when delta <= 0.

    huber_delta
        >= 0: used as delta directly.
        <  0: auto mode, delta = (-huber_delta) * 1.4826 * (MAD(r) + 1e-6),
              floored at 1e-6.
    var_map
        Per-pixel variance, 2D. For a 3D residual it is broadcast over the
        leading (channel) axis. When None, a robust scalar MAD variance of r
        is used.
    mask
        2D validity weights, broadcast over channels like var_map; None means
        all ones.

    The torch branch is used when TORCH_OK and the residual is a torch.Tensor;
    otherwise NumPy is used.
    """
    resid = y - pred
    use_torch = TORCH_OK and isinstance(resid, torch.Tensor)

    if use_torch:
        if huber_delta < 0:
            center = torch.median(resid)
            sigma = 1.4826 * (torch.median(torch.abs(resid - center)) + 1e-6)
            delta = (-huber_delta) * torch.clamp(sigma, min=1e-6)
        else:
            delta = huber_delta
        abs_resid = torch.abs(resid)
        if float(delta) > 0:
            psi_over_r = torch.where(abs_resid <= delta, torch.ones_like(resid), delta / (abs_resid + EPS))
        else:
            psi_over_r = torch.ones_like(resid)
        variance = _estimate_scalar_variance_t(resid) if var_map is None else var_map
    else:
        if huber_delta < 0:
            center = np.median(resid)
            sigma = 1.4826 * (np.median(np.abs(resid - center)) + 1e-6)
            delta = (-huber_delta) * max(sigma, 1e-6)
        else:
            delta = huber_delta
        abs_resid = np.abs(resid)
        if float(delta) > 0:
            psi_over_r = np.where(abs_resid <= delta, 1.0, delta / (abs_resid + EPS)).astype(np.float32)
        else:
            psi_over_r = np.ones_like(resid, dtype=np.float32)
        variance = _estimate_scalar_variance(resid) if var_map is None else var_map

    # a 2D variance map is shared by every channel of a (C, H, W) residual
    if var_map is not None and variance.ndim == 2 and resid.ndim == 3:
        variance = variance[None, ...]
    weights = psi_over_r / (variance + EPS)

    if mask is not None:
        if mask.ndim == weights.ndim:
            m = mask
        elif weights.ndim == 3:
            m = mask[None, ...]
        else:
            m = mask
        weights = weights * m
    return weights
1621
+
1622
+
1623
+ # -----------------------------
1624
+ # Torch / conv
1625
+ # -----------------------------
1626
+
1627
+ def _fftshape_same(H, W, kh, kw):
1628
+ return H + kh - 1, W + kw - 1
1629
+
1630
+ # ---------- Torch FFT helpers (FIXED: carry padH/padW) ----------
1631
+ def _precompute_torch_psf_ffts(psfs, flip_psf, H, W, device, dtype):
1632
+ tfft = torch.fft
1633
+ psf_fft, psfT_fft = [], []
1634
+ for k, kT in zip(psfs, flip_psf):
1635
+ kh, kw = k.shape
1636
+ padH, padW = _fftshape_same(H, W, kh, kw)
1637
+
1638
+ # shift the small kernels to the origin, then FFT into padded size
1639
+ k_small = torch.as_tensor(np.fft.ifftshift(k), device=device, dtype=dtype)
1640
+ kT_small = torch.as_tensor(np.fft.ifftshift(kT), device=device, dtype=dtype)
1641
+
1642
+ Kf = tfft.rfftn(k_small, s=(padH, padW))
1643
+ KTf = tfft.rfftn(kT_small, s=(padH, padW))
1644
+
1645
+ psf_fft.append((Kf, padH, padW, kh, kw))
1646
+ psfT_fft.append((KTf, padH, padW, kh, kw))
1647
+ return psf_fft, psfT_fft
1648
+
1649
+
1650
+
1651
+ # ---------- NumPy FFT helpers ----------
1652
+ def _precompute_np_psf_ffts(psfs, flip_psf, H, W):
1653
+ import numpy.fft as fft
1654
+ meta, Kfs, KTfs = [], [], []
1655
+ for k, kT in zip(psfs, flip_psf):
1656
+ kh, kw = k.shape
1657
+ fftH, fftW = _fftshape_same(H, W, kh, kw)
1658
+ Kfs.append( fft.rfftn(np.fft.ifftshift(k), s=(fftH, fftW)) )
1659
+ KTfs.append(fft.rfftn(np.fft.ifftshift(kT), s=(fftH, fftW)) )
1660
+ meta.append((kh, kw, fftH, fftW))
1661
+ return Kfs, KTfs, meta
1662
+
1663
+ def _fft_conv_same_np(a, Kf, kh, kw, fftH, fftW, out):
1664
+ import numpy.fft as fft
1665
+ if a.ndim == 2:
1666
+ A = fft.rfftn(a, s=(fftH, fftW))
1667
+ y = fft.irfftn(A * Kf, s=(fftH, fftW))
1668
+ sh, sw = kh // 2, kw // 2
1669
+ out[...] = y[sh:sh+a.shape[0], sw:sw+a.shape[1]]
1670
+ return out
1671
+ else:
1672
+ C, H, W = a.shape
1673
+ acc = []
1674
+ for c in range(C):
1675
+ A = fft.rfftn(a[c], s=(fftH, fftW))
1676
+ y = fft.irfftn(A * Kf, s=(fftH, fftW))
1677
+ sh, sw = kh // 2, kw // 2
1678
+ acc.append(y[sh:sh+H, sw:sw+W])
1679
+ out[...] = np.stack(acc, 0)
1680
+ return out
1681
+
1682
+
1683
+
1684
def _torch_device():
    """
    Pick the torch device used by the torch code paths.

    When ``TORCH_OK`` is truthy, devices are preferred in this order:
      1. CUDA
      2. Apple MPS
      3. the DirectML device published as module globals ``dml_ok`` /
         ``dml_device``, used only when ``dml_ok`` is truthy and
         ``dml_device`` is not None
    Otherwise, and as the final fallback, the CPU device is returned.

    Raises
    ------
    RuntimeError
        When torch itself is missing (``torch is None``). Previously this fell
        through to ``torch.device("cpu")`` and failed with an AttributeError on
        ``None``.
    """
    if TORCH_OK and (torch is not None):
        if hasattr(torch, "cuda") and torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        # DirectML: the outer scope publishes dml_device as a module global
        if globals().get("dml_ok", False) and globals().get("dml_device", None) is not None:
            return globals()["dml_device"]
    if torch is None:
        raise RuntimeError("Torch device requested but torch is unavailable")
    return torch.device("cpu")
1694
+
1695
def _to_t(x: np.ndarray):
    """
    Wrap a NumPy array as a torch tensor on the active device.

    Uses ``torch.from_numpy``, so on CPU the tensor shares the array's memory.
    For any non-CPU device (including DirectML) the tensor is moved with a
    non-blocking ``.to(device)``.

    Raises RuntimeError when torch is unavailable.
    """
    if not TORCH_OK or torch is None:
        raise RuntimeError("Torch path requested but torch is unavailable")
    target = _torch_device()
    tensor = torch.from_numpy(x)
    if str(target) == "cpu":
        return tensor
    # DirectML wants explicit .to(device)
    return tensor.to(target, non_blocking=True)
1702
+
1703
+ def _contig(x):
1704
+ return np.ascontiguousarray(x, dtype=np.float32)
1705
+
1706
+ def _conv_same_torch(img_t, psf_t):
1707
+ """
1708
+ img_t: torch tensor on DEVICE, (H,W) or (C,H,W)
1709
+ psf_t: torch tensor on DEVICE, (1,1,kh,kw) (single kernel)
1710
+ Pads with 'reflect' to avoid zero-padding ringing.
1711
+ """
1712
+ kh, kw = psf_t.shape[-2:]
1713
+ pad = (kw // 2, kw - kw // 2 - 1, # left, right
1714
+ kh // 2, kh - kh // 2 - 1) # top, bottom
1715
+
1716
+ if img_t.ndim == 2:
1717
+ x = img_t[None, None]
1718
+ x = torch.nn.functional.pad(x, pad, mode="reflect")
1719
+ y = torch.nn.functional.conv2d(x, psf_t, padding=0)
1720
+ return y[0, 0]
1721
+ else:
1722
+ C = img_t.shape[0]
1723
+ x = img_t[None]
1724
+ x = torch.nn.functional.pad(x, pad, mode="reflect")
1725
+ w = psf_t.repeat(C, 1, 1, 1)
1726
+ y = torch.nn.functional.conv2d(x, w, padding=0, groups=C)
1727
+ return y[0]
1728
+
1729
def _safe_inference_context():
    """
    Return a working no-grad context-manager factory (not an entered context).

    - ``torch.inference_mode`` if this torch build has it *and* entering it
      once succeeds;
    - otherwise ``torch.no_grad``;
    - the module-level ``NO_GRAD`` fallback when torch is unavailable.
    """
    if not TORCH_OK or torch is None:
        return NO_GRAD

    inference_mode = getattr(torch, "inference_mode", None)
    if inference_mode is None:
        return torch.no_grad

    # Probe once; some builds expose inference_mode but fail on entry.
    try:
        with inference_mode():
            pass
    except Exception:
        return torch.no_grad
    return inference_mode
1750
+
1751
+ def _ensure_mask_list(masks, data):
1752
+ # 1s where valid, 0s where invalid (soft edges allowed)
1753
+ if masks is None:
1754
+ return [np.ones_like(a if a.ndim==2 else a[0], dtype=np.float32) for a in data]
1755
+ out = []
1756
+ for a, m in zip(data, masks):
1757
+ base = a if a.ndim==2 else a[0] # mask is 2D; shared across channels
1758
+ if m is None:
1759
+ out.append(np.ones_like(base, dtype=np.float32))
1760
+ else:
1761
+ mm = np.asarray(m, dtype=np.float32)
1762
+ if mm.ndim == 3: # tolerate (1,H,W) or (C,H,W)
1763
+ mm = mm[0]
1764
+ if mm.shape != base.shape:
1765
+ # center crop to match (common intersection already applied)
1766
+ Ht, Wt = base.shape
1767
+ mm = _center_crop(mm, Ht, Wt)
1768
+ # keep as float weights in [0,1] (do not threshold!)
1769
+ out.append(np.clip(mm.astype(np.float32, copy=False), 0.0, 1.0))
1770
+ return out
1771
+
1772
+ def _ensure_var_list(variances, data):
1773
+ # If None, we’ll estimate a robust scalar per frame on-the-fly.
1774
+ if variances is None:
1775
+ return [None]*len(data)
1776
+ out = []
1777
+ for a, v in zip(data, variances):
1778
+ if v is None:
1779
+ out.append(None)
1780
+ else:
1781
+ vv = np.asarray(v, dtype=np.float32)
1782
+ if vv.ndim == 3:
1783
+ vv = vv[0]
1784
+ base = a if a.ndim==2 else a[0]
1785
+ if vv.shape != base.shape:
1786
+ Ht, Wt = base.shape
1787
+ vv = _center_crop(vv, Ht, Wt)
1788
+ # clip tiny/negatives
1789
+ vv = np.clip(vv, 1e-8, None).astype(np.float32, copy=False)
1790
+ out.append(vv)
1791
+ return out
1792
+
1793
+ # ---- SR operators (downsample / upsample-sum) ----
1794
+ def _downsample_avg(img, r: int):
1795
+ """Average-pool over non-overlapping r×r blocks. Works for (H,W) or (C,H,W)."""
1796
+ if r <= 1:
1797
+ return img
1798
+ a = np.asarray(img, dtype=np.float32)
1799
+ if a.ndim == 2:
1800
+ H, W = a.shape
1801
+ Hs, Ws = (H // r) * r, (W // r) * r
1802
+ a = a[:Hs, :Ws].reshape(Hs//r, r, Ws//r, r).mean(axis=(1,3))
1803
+ return a
1804
+ else:
1805
+ C, H, W = a.shape
1806
+ Hs, Ws = (H // r) * r, (W // r) * r
1807
+ a = a[:, :Hs, :Ws].reshape(C, Hs//r, r, Ws//r, r).mean(axis=(2,4))
1808
+ return a
1809
+
1810
+ def _upsample_sum(img, r: int, target_hw: tuple[int,int] | None = None):
1811
+ """Adjoint of average-pooling: replicate-sum each pixel into an r×r block.
1812
+ For (H,W) or (C,H,W). If target_hw given, center-crop/pad to that size.
1813
+ """
1814
+ if r <= 1:
1815
+ return img
1816
+ a = np.asarray(img, dtype=np.float32)
1817
+ if a.ndim == 2:
1818
+ H, W = a.shape
1819
+ out = np.kron(a, np.ones((r, r), dtype=np.float32))
1820
+ else:
1821
+ C, H, W = a.shape
1822
+ out = np.stack([np.kron(a[c], np.ones((r, r), dtype=np.float32)) for c in range(C)], axis=0)
1823
+ if target_hw is not None:
1824
+ Ht, Wt = target_hw
1825
+ out = _center_crop(out, Ht, Wt)
1826
+ return out
1827
+
1828
def _gaussian2d(ksize: int, sigma: float) -> np.ndarray:
    """
    Normalized isotropic 2D Gaussian kernel (float32).

    The kernel covers offsets -h..h with h = (ksize - 1) // 2, so its side is
    2h + 1. That equals ``ksize`` for odd sizes and is one less for even sizes.
    The kernel is divided by (sum + EPS).
    """
    half = (ksize - 1) // 2
    yy, xx = np.mgrid[-half:half + 1, -half:half + 1].astype(np.float32)
    kernel = np.exp(-(xx * xx + yy * yy) / (2.0 * sigma * sigma)).astype(np.float32)
    kernel /= kernel.sum() + EPS
    return kernel
1834
+
1835
def _conv2_same_np(a: np.ndarray, k: np.ndarray) -> np.ndarray:
    """
    Same-size 2D convolution of (H, W) or (C, H, W) data with kernel ``k``.

    Thin adapter over ``_conv_same_np``. A 2D input is lifted to (1, H, W) and
    the single resulting plane is returned; 3D input is passed through as is.
    """
    if a.ndim == 2:
        return _conv_same_np(a[None], k)[0]
    return _conv_same_np(a, k)
1838
+
1839
def _solve_super_psf_from_native(f_native: np.ndarray, r: int, sigma: float = 1.1,
                                 iters: int = 500, lr: float = 0.1) -> np.ndarray:
    """
    Solve: h* = argmin_h || f_native - (D(h) * g_sigma) ||_2^2,
    where h is (r*k)×(r*k) if f_native is k×k. Returns normalized h (sum=1).

    D is r x r mean-pooling and g_sigma is ``_gaussian2d(k, max(sigma, 1e-3))``.
    h starts with f placed on every r-th sample of a zero (r*k)^2 grid and
    normalised. It is then refined by projected gradient descent: after every
    step the iterate is clamped to >= 0 and rescaled to unit sum.
      * With ``TORCH_OK``: Adam on ``_torch_device()``, ``max(10, iters)`` steps.
      * Otherwise: a plain NumPy gradient loop, ``max(50, iters)`` steps, with
        the step size decayed by 0.995 each iteration.

    Raises
    ------
    ValueError
        If ``f_native`` is not a square 2D kernel or ``r < 1``. This used to be
        an ``assert``, which ``python -O`` strips.
    """
    f = np.asarray(f_native, dtype=np.float32)
    if f.ndim != 2 or f.shape[0] != f.shape[1]:
        raise ValueError(f"f_native must be a square 2D kernel, got shape={f.shape}")
    r = int(r)
    if r < 1:
        raise ValueError(f"upsampling factor r must be >= 1, got {r}")
    k = int(f.shape[0])
    kr = int(k * r)

    # build Gaussian pre-blur at native scale (match paper §4.2)
    g = _gaussian2d(k, max(sigma, 1e-3)).astype(np.float32)

    # init: scatter f onto every r-th sample of the fine grid, then normalise
    # (no deconvolution of g happens here; the optimiser does that)
    h0 = np.zeros((kr, kr), dtype=np.float32)
    h0[::r, ::r] = f
    h0 = _normalize_psf(h0)

    if TORCH_OK:
        dev = _torch_device()
        t = torch.tensor(h0, device=dev, dtype=torch.float32, requires_grad=True)
        f_t = torch.tensor(f, device=dev, dtype=torch.float32)
        g_t = torch.tensor(g, device=dev, dtype=torch.float32)
        opt = torch.optim.Adam([t], lr=lr)
        for _ in range(max(10, iters)):
            opt.zero_grad(set_to_none=True)
            H, W = t.shape
            Hr, Wr = H // r, W // r
            # D(h): r x r mean-pool
            th = t[:Hr * r, :Wr * r].reshape(Hr, r, Wr, r).mean(dim=(1, 3))
            # conv native: (Dh) * g
            conv = torch.nn.functional.conv2d(th[None, None], g_t[None, None], padding=g_t.shape[-1] // 2)[0, 0]
            loss = torch.mean((conv - f_t) ** 2)
            loss.backward()
            opt.step()
            # projection onto non-negative, unit-sum kernels
            with torch.no_grad():
                t.clamp_(min=0.0)
                t /= (t.sum() + 1e-8)
        h = t.detach().cpu().numpy().astype(np.float32)
    else:
        # Tiny gradient-descent fallback on numpy
        h = h0.copy()
        eta = float(lr)
        for _ in range(max(50, iters)):
            Dh = _downsample_avg(h, r)
            conv = _conv2_same_np(Dh, g)
            resid = (conv - f)
            # gradient (up to a constant): correlate the residual with g, then
            # push it back through D via the replicate operator
            grad_Dh = _conv2_same_np(resid, np.flip(np.flip(g, 0), 1))
            grad_h = _upsample_sum(grad_Dh, r, target_hw=h.shape)
            h = np.clip(h - eta * grad_h, 0.0, None)
            s = float(h.sum())
            h /= (s + 1e-8)
            eta *= 0.995
    return _normalize_psf(h)
1892
+
1893
+ def _downsample_avg_t(x, r: int):
1894
+ """
1895
+ Average-pool over non-overlapping r×r blocks.
1896
+ Works for (H,W) or (C,H,W). Crops to multiples of r.
1897
+ """
1898
+ if r <= 1:
1899
+ return x
1900
+ if x.ndim == 2:
1901
+ H, W = x.shape
1902
+ Hr, Wr = (H // r) * r, (W // r) * r
1903
+ if Hr == 0 or Wr == 0:
1904
+ return x # nothing to pool
1905
+ x2 = x[:Hr, :Wr]
1906
+ return x2.view(Hr // r, r, Wr // r, r).mean(dim=(1, 3))
1907
+ else:
1908
+ C, H, W = x.shape
1909
+ Hr, Wr = (H // r) * r, (W // r) * r
1910
+ if Hr == 0 or Wr == 0:
1911
+ return x
1912
+ x2 = x[:, :Hr, :Wr]
1913
+ return x2.view(C, Hr // r, r, Wr // r, r).mean(dim=(2, 4))
1914
+
1915
+ def _upsample_sum_t(x, r: int):
1916
+ if r <= 1:
1917
+ return x
1918
+ if x.ndim == 2:
1919
+ return x.repeat_interleave(r, dim=0).repeat_interleave(r, dim=1)
1920
+ else:
1921
+ return x.repeat_interleave(r, dim=-2).repeat_interleave(r, dim=-1)
1922
+
1923
def _sep_bg_rms(frames):
    """
    Return a robust background RMS using SEP's background model on the first frame.

    ``frames`` is a sequence (list, tuple or stacked ndarray) of (H, W) or
    (C, H, W) frames; for a 3D frame the first plane is used.

    Returns SEP's ``globalrms`` when it is a positive finite number, otherwise
    the median of SEP's RMS map. Returns None when SEP is missing, ``frames``
    is None or empty, or SEP fails.

    The emptiness test uses ``len(frames)`` inside the try block. The old
    ``not frames`` raised "truth value is ambiguous" for ndarray input, outside
    any handler.
    """
    if sep is None or frames is None:
        return None
    try:
        if len(frames) == 0:
            return None
        first = frames[0]
        y0 = first if first.ndim == 2 else first[0]  # use luma/first channel
        a = np.ascontiguousarray(y0, dtype=np.float32)
        b = sep.Background(a, bw=64, bh=64, fw=3, fh=3)
        try:
            rms_val = float(b.globalrms)
        except Exception:
            # some SEP builds don't expose globalrms
            rms_val = float("nan")
        if not np.isfinite(rms_val) or rms_val <= 0:
            rms_val = float(np.median(np.asarray(b.rms(), dtype=np.float32)))
        return rms_val
    except Exception:
        return None
1939
+
1940
+ # =========================
1941
+ # Memory/streaming helpers
1942
+ # =========================
1943
+
1944
+ def _approx_bytes(arr_like_shape, dtype=np.float32):
1945
+ """Rough byte estimator for a given shape/dtype."""
1946
+ return int(np.prod(arr_like_shape)) * np.dtype(dtype).itemsize
1947
+
1948
+
1949
+
1950
def _read_shape_fast(path) -> tuple[int,int,int]:
    """
    Return ``(C, H, W)`` for an image file.

    XISF files are loaded with ``_load_image_array``. FITS files are opened
    memory-mapped, and the first HDU whose data is at least 2D is used, in the
    spirit of ``_read_tile_fits_any``. Tile-compressed FITS, whose pixels live
    in an extension behind an empty primary HDU, is therefore no longer
    rejected. Table HDUs are skipped.

    Shape rules:
      * 2D -> (1, H, W)
      * 3D with last axis 1 or 3 -> HWC
      * 3D with first axis 1 or 3 -> CHW
      * anything else -> (1, H, W) taken from the last two axes

    Raises ValueError when no image data is found.
    """
    if _is_xisf(path):
        a, _ = _load_image_array(path)
        if a is None:
            raise ValueError(f"No data in {path}")
        shape = tuple(int(s) for s in np.asarray(a).shape)
    else:
        with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
            data = next(
                (d for d in (getattr(h, "data", None) for h in hdul)
                 if d is not None and getattr(d, "ndim", 0) >= 2),
                None,
            )
            if data is None:
                raise ValueError(f"No data in {path}")
            # read the shape while the file is still open
            shape = tuple(int(s) for s in data.shape)

    # common logic for both XISF and FITS
    if len(shape) == 2:
        H, W = shape
        return (1, H, W)
    if len(shape) == 3:
        if shape[-1] in (1, 3):  # HWC
            return (1 if shape[-1] == 1 else 3, shape[0], shape[1])
        if shape[0] in (1, 3):  # CHW
            return (shape[0], shape[1], shape[2])
    return (1, shape[-2], shape[-1])
1975
+
1976
def _read_tile_fits_any(path: str, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
    """
    Read the spatial window [y0:y1, x0:x1] of a FITS or XISF image as a copy.

    Supported layouts are 2D (H, W) and 3D with 1 or 3 planes, either HWC or
    CHW. Single-plane results are squeezed to 2D; 3-plane results keep their
    axis order.
      * XISF (by ``.xisf`` extension): loaded whole via ``_load_image_array``;
        HWC is tried before CHW.
      * FITS: opened memory-mapped; the first HDU with data is used; CHW is
        tried before HWC. Any other FITS layout falls back to slicing the last
        two axes.

    Raises ValueError for missing data or unsupported shapes.
    """
    rows, cols = slice(y0, y1), slice(x0, x1)

    def _from_hwc(arr):
        tile = arr[rows, cols, :]
        if tile.shape[-1] == 1:
            tile = tile[..., 0]
        return np.array(tile, copy=True)

    def _from_chw(arr):
        tile = arr[:, rows, cols]
        if tile.shape[0] == 1:
            tile = tile[0]
        return np.array(tile, copy=True)

    if os.path.splitext(path)[1].lower() == ".xisf":
        arr, _ = _load_image_array(path)  # (array-like, header/metadata)
        if arr is None:
            raise ValueError(f"XISF loader returned None for {path}")
        arr = np.asarray(arr)
        if arr.ndim == 2:
            return np.array(arr[rows, cols], copy=True)
        if arr.ndim != 3:
            raise ValueError(f"Unsupported XISF ndim {arr.ndim} in {path}")
        if arr.shape[-1] in (1, 3):
            return _from_hwc(arr)
        if arr.shape[0] in (1, 3):
            return _from_chw(arr)
        raise ValueError(f"Unsupported XISF 3D shape {arr.shape} in {path}")

    with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
        arr = None
        for hdu in hdul:
            if getattr(hdu, "data", None) is not None:
                arr = hdu.data
                break
        if arr is None:
            raise ValueError(f"No image data in {path}")

        # slice and copy while the file is open
        arr = np.asarray(arr)
        if arr.ndim == 2:
            return np.array(arr[rows, cols], copy=True)
        if arr.ndim == 3:
            if arr.shape[0] in (1, 3):  # CHW (planes, rows, cols)
                return _from_chw(arr)
            if arr.shape[-1] in (1, 3):  # HWC
                return _from_hwc(arr)

        # Fallback: assume the last two axes are spatial (..., H, W)
        try:
            return np.array(arr[..., rows, cols], copy=True)
        except Exception:
            raise ValueError(f"Unsupported FITS data shape {arr.shape} in {path}")
2034
+
2035
+ def _select_io_workers(n_frames: int,
2036
+ user_io_workers: int | None,
2037
+ tile_hw: tuple[int,int] = (256,256),
2038
+ disk_bound: bool = True) -> int:
2039
+ """Heuristic picker for I/O threadpool size."""
2040
+ try:
2041
+ cpu = os.cpu_count() or 4
2042
+ except Exception:
2043
+ cpu = 4
2044
+
2045
+ if user_io_workers is not None:
2046
+ # Respect caller, but clamp to sane bounds
2047
+ return max(1, min(cpu, int(user_io_workers)))
2048
+
2049
+ # Default: don’t oversubscribe CPU; don’t exceed frame count
2050
+ base = min(cpu, 8, int(n_frames))
2051
+
2052
+ # If we’re disk-bound (memmaps/tiling) or tiles are small, cap lower to reduce thrash
2053
+ th, tw = tile_hw
2054
+ if disk_bound or (th * tw <= 256 * 256):
2055
+ base = min(base, 4)
2056
+
2057
+ return max(1, base)
2058
+
2059
+
2060
def _seed_median_streaming(
    paths,
    Ht,
    Wt,
    *,
    color_mode="luma",
    tile_hw=(256, 256),
    status_cb=lambda s: None,
    progress_cb=lambda f, m="": None,
    io_workers: int = 4,
    scratch_dir: str | None = None,
    tile_loader=None,  # NEW
):
    """
    Per-pixel median of all frames, computed tile by tile on the CPU.

    For every (th, tw) tile of the (Ht, Wt) target, the same window is read
    from every frame in a thread pool. The windows are stacked into an
    (n_frames, h, w) slab and reduced with ``np.median`` along the frame axis.

    Frame tiles come from one of two sources:
      * ``tile_loader(i, y0, y1, x0, x1, csel=c)`` when given; it is expected
        to return the requested 2D plane;
      * otherwise ``_read_tile_fits_any(paths[i], ...)``, with luma or channel
        selection done here.
    Integer tiles are scaled to [0, 1] by their dtype's max.

    ``color_mode == "luma"`` yields an (Ht, Wt) seed. Otherwise the channel
    count comes from ``tile_loader.want_c`` or ``_infer_channels_from_tile``,
    and the seed is (C, Ht, Wt). Returns float32.

    Slabs are stored as float16 to halve peak RAM. If a tile holds finite
    values beyond the float16 range (|x| > 65504, e.g. un-normalised
    ADU-scale float FITS), that tile is returned as float32 and the slab is
    promoted to float32. Before this, such pixels silently became +/-inf and
    could poison the median.

    Raises
    ------
    ValueError
        If ``paths`` is empty.
    RuntimeError
        If a tile read returns an unexpected shape.
    """
    import time
    import gc
    from concurrent.futures import ThreadPoolExecutor, as_completed
    scratch_dir = _ensure_scratch_dir(scratch_dir)
    th, tw = int(tile_hw[0]), int(tile_hw[1])

    n_frames = len(paths)
    if n_frames == 0:
        raise ValueError("_seed_median_streaming: no input frames")

    # Prefer channel hint from tile_loader (so we don't reopen original files)
    if str(color_mode).lower() == "luma":
        want_c = 1
    else:
        want_c = getattr(tile_loader, "want_c", None)
        if not want_c:
            want_c = _infer_channels_from_tile(paths[0], Ht, Wt)

    seed = (np.zeros((Ht, Wt), np.float32)
            if want_c == 1 else np.zeros((want_c, Ht, Wt), np.float32))

    tiles = [(y, min(y + th, Ht), x, min(x + tw, Wt))
             for y in range(0, Ht, th)
             for x in range(0, Wt, tw)]
    total = len(tiles)

    io_workers_eff = _select_io_workers(n_frames, io_workers, tile_hw=tile_hw, disk_bound=True)
    f16_max = float(np.finfo(np.float16).max)

    def _tile_msg(ti, tn): return f"median tiles {ti}/{tn}"
    done = 0

    def _read_slab_for_channel(y0, y1, x0, x1, csel=None):
        h, w = (y1 - y0), (x1 - x0)
        # fp16 slab halves peak RAM; promoted to f32 only if the data needs it
        slab = np.empty((n_frames, h, w), np.float16)

        def _load_one(i):
            if tile_loader is not None:
                t = tile_loader(i, y0, y1, x0, x1, csel=csel)
            else:
                t = _read_tile_fits_any(paths[i], y0, y1, x0, x1)
                # luma/channel selection only for file reads
                if want_c == 1:
                    if t.ndim != 2:
                        t = _to_luma_local(t)
                elif t.ndim == 3 and t.shape[-1] == 3:
                    t = t[..., int(csel)]
                elif t.ndim == 3 and t.shape[0] == 3:
                    t = t[int(csel)]
                elif t.ndim != 2:
                    t = _to_luma_local(t)

            if t.dtype.kind in "ui":
                # normalize integer data to [0,1]
                t = t.astype(np.float32) / (float(np.iinfo(t.dtype).max) or 1.0)
            elif t.dtype.kind == "f" and t.size:
                finite = np.isfinite(t)
                peak = float(np.max(np.abs(t), initial=0.0, where=finite))
                if peak > f16_max:
                    # would overflow to inf in float16: hand back float32
                    return i, np.ascontiguousarray(t, dtype=np.float32)
            return i, np.ascontiguousarray(t, dtype=np.float16)

        done_local = 0
        with ThreadPoolExecutor(max_workers=min(io_workers_eff, n_frames)) as ex:
            futures = [ex.submit(_load_one, i) for i in range(n_frames)]
            for fut in as_completed(futures):
                i, t2d = fut.result()
                if t2d.shape != (h, w):
                    raise RuntimeError(
                        f"Tile read mismatch at frame {i}: got {t2d.shape}, expected {(h, w)} "
                        f"tile={(y0,y1,x0,x1)}"
                    )
                if slab.dtype == np.float16 and t2d.dtype != np.float16:
                    # exact upcast of the rows already stored
                    slab = slab.astype(np.float32)
                slab[i] = t2d
                done_local += 1
                if (done_local & 7) == 0 or done_local == n_frames:
                    tile_base = done / total
                    tile_span = 1.0 / total
                    inner = done_local / n_frames
                    progress_cb(tile_base + 0.8 * tile_span * inner, _tile_msg(done + 1, total))
        return slab

    for (y0, y1, x0, x1) in tiles:
        t0 = time.perf_counter()

        if want_c == 1:
            slab = _read_slab_for_channel(y0, y1, x0, x1)
            t1 = time.perf_counter()
            med_np = np.median(slab.astype(np.float32, copy=False), axis=0).astype(np.float32, copy=False)
            t2 = time.perf_counter()
            seed[y0:y1, x0:x1] = med_np
            status_cb(f"seed tile {y0}:{y1},{x0}:{x1} I/O={t1-t0:.3f}s median=CPU={t2-t1:.3f}s")
        else:
            for c in range(int(want_c)):
                slab = _read_slab_for_channel(y0, y1, x0, x1, csel=c)
                med_np = np.median(slab.astype(np.float32, copy=False), axis=0).astype(np.float32, copy=False)
                seed[c, y0:y1, x0:x1] = med_np

        # free tile buffers aggressively
        del slab
        del med_np

        done += 1
        if (done & 1) == 0:
            gc.collect()  # encourage RSS to drop
        progress_cb(done / total, _tile_msg(done, total))
        if (done & 3) == 0:
            _process_gui_events_safely()

    status_cb(f"Median seed (CPU): want_c={want_c}, seed_shape={seed.shape}")
    return seed
2184
+
2185
+
2186
+
2187
def _seed_median_full_from_data(data_list):
    """
    Per-pixel median across a stack of pre-cropped frames.

    Parameters
    ----------
    data_list : iterable of array-like
        Frames shaped either (H, W) or (C, H, W), already cropped/sanitized to
        the same size by the caller. Any iterable (list, tuple, generator,
        leading axis of an ndarray) is accepted.

    Returns
    -------
    np.ndarray
        C-contiguous float32 median image with the shape of a single frame,
        i.e. (H, W) or (C, H, W).

    Raises
    ------
    ValueError
        If the stack is empty, or (from ``np.stack``) if frame shapes differ.
    """
    frames = [np.asarray(a, dtype=np.float32, order="C") for a in data_list]
    if not frames:
        raise ValueError("Empty stack for median seed")

    # Stacking on a new leading axis gives (N,H,W) or (N,C,H,W); the median
    # over axis 0 therefore handles both frame layouts with the same code.
    cube = np.stack(frames, axis=0)
    med = np.median(cube, axis=0).astype(np.float32, copy=False)
    return np.ascontiguousarray(med)
2207
+
2208
def _infer_channels_from_tile(path: str, Ht: int, Wt: int) -> int:
    """
    Pick the output channel count used for median seeding in PerChannel mode.

    Only the channel axis reported by ``_read_shape_fast`` is consulted;
    ``Ht`` and ``Wt`` are accepted for call-site compatibility and ignored.

    Returns 3 when the source reports three channels, otherwise 1.
    """
    n_chan, _h, _w = _read_shape_fast(path)  # (C, H, W)
    if n_chan == 3:
        return 3
    return 1
2215
+
2216
+
2217
def _build_seed_running_mu_sigma_from_paths(paths, Ht, Wt, color_mode,
                                            *, bootstrap_frames=20, clip_sigma=5.0,
                                            status_cb=lambda s: None,
                                            progress_cb=lambda f, m='': None):
    """
    Robust deconvolution seed: a sigma-clipped mean over every frame.

    Two streaming passes, holding only a few full-size accumulators in RAM:
      1. Welford running mean / variance over the first ``bootstrap_frames``
         frames (at least one).
      2. Every frame is clipped to ``mean ± clip_sigma * sigma`` and averaged.

    Parameters
    ----------
    paths : sequence of str
        Aligned frame paths, loaded one at a time via ``_stack_loader_memmap``.
    Ht, Wt : int
        Common crop size passed to the loader.
    color_mode : str
        Passed through to the loader.
    bootstrap_frames : int
        Number of leading frames used to estimate mean/sigma.
    clip_sigma : float
        Half-width of the clipping window in units of sigma.
    status_cb : callable
        Currently unused; kept for signature symmetry with the other seeders.
    progress_cb : callable(fraction, message)
        Receives progress in [0, 1].

    Returns
    -------
    np.ndarray
        float32 seed. A singleton channel axis is collapsed, so the result is
        (H, W) for single-channel data and (C, H, W) otherwise.

    Raises
    ------
    ValueError
        If ``paths`` is empty.
    """
    n_total = len(paths)
    if n_total == 0:
        raise ValueError("Empty stack for μ-σ seed")

    K = max(1, min(int(bootstrap_frames), n_total))

    def _load_chw(i):
        ys, _ = _stack_loader_memmap([paths[i]], Ht, Wt, color_mode)
        return _as_chw(ys[0]).astype(np.float32, copy=False)

    # Pass 1: Welford's online mean / M2. The first frame is copied because
    # `mean` is updated in place below.
    mean = _load_chw(0).copy()
    m2 = np.zeros_like(mean)
    count = 1
    for i in range(1, K):
        x = _load_chw(i)
        count += 1
        d = x - mean
        mean += d / count
        m2 += d * (x - mean)
        progress_cb(i / K * 0.5, "μ-σ bootstrap")

    # Sample variance (ddof=1). The floor keeps sigma > 0 when K == 1 or a
    # pixel is constant across the bootstrap frames.
    var = m2 / max(1, count - 1)
    sigma = np.sqrt(np.clip(var, 1e-12, None)).astype(np.float32)
    lo = mean - float(clip_sigma) * sigma
    hi = mean + float(clip_sigma) * sigma

    # Pass 2: clipped mean over the full stack.
    acc = np.zeros_like(mean)
    for i in range(n_total):
        x = _load_chw(i)
        x = np.clip(x, lo, hi, out=x)
        acc += x
        progress_cb(0.5 + 0.5 * (i + 1) / n_total, "clipped mean")

    seed = (acc / n_total).astype(np.float32)
    return seed[0] if (seed.ndim == 3 and seed.shape[0] == 1) else seed
2238
+
2239
def _flush_memmaps(*arrs):
    """
    Best-effort flush of disk-backed scratch buffers.

    Each argument may be an ``np.memmap`` or a tuple whose first element is an
    ``np.memmap`` (e.g. the ``(memmap, path)`` pair from ``_mm_create``).
    Anything else is ignored. Errors raised while inspecting or flushing an
    item are swallowed, so cleanup never aborts.
    """
    for item in arrs:
        try:
            is_pair = isinstance(item, tuple) and isinstance(item[0], np.memmap)
            target = item[0] if is_pair else item
            if not isinstance(target, np.memmap):
                continue
            target.flush()
        except Exception:
            pass
2247
+
2248
def _mm_create(shape, dtype=np.float32, scratch_dir=None, tag="mm"):
    """
    Allocate a zero-filled, writable disk-backed array in a fresh temp file.

    Parameters
    ----------
    shape : sequence of int
        Array shape.
    dtype : numpy dtype-like
        Element type of the memmap.
    scratch_dir : str or None
        Directory for the backing file; created if missing. Defaults to the
        system temp directory.
    tag : str
        Filename prefix (the file is named ``{tag}_<random>.dat``).

    Returns
    -------
    (np.memmap, str)
        The ``mode="w+"`` memmap and the path of its backing file. The caller
        owns the file and is responsible for deleting it.

    If the memmap cannot be created (bad dtype/shape, disk full, ...), the temp
    file is removed before the exception propagates, so no stray files are left
    in the scratch directory.
    """
    root = scratch_dir or tempfile.gettempdir()
    os.makedirs(root, exist_ok=True)
    fd, path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".dat", dir=root)
    os.close(fd)
    try:
        mm = np.memmap(path, mode="w+", dtype=dtype, shape=tuple(shape))
    except BaseException:
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return mm, path
2255
+
2256
def _mm_flush(*arrs):
    """
    Best-effort flush of every ``np.memmap`` among the arguments.

    Non-memmap arguments are skipped; any error during the check or the flush
    is swallowed so this never interrupts cleanup.
    """
    for candidate in arrs:
        try:
            flush = candidate.flush if isinstance(candidate, np.memmap) else None
            if flush is not None:
                flush()
        except Exception:
            continue
2263
+
2264
def _gauss_tile(a, sigma):
    """
    Gaussian-blur a single 2D tile with reflected borders.

    ``sigma <= 0`` returns the input object unchanged. Otherwise OpenCV is
    tried first (kernel size ``2 * max(1, round(3 * sigma)) + 1``), then
    SciPy's ``gaussian_filter``. If neither is usable, the input is returned
    as-is.
    """
    if sigma <= 0:
        return a

    radius = max(1, int(round(3 * sigma)))
    ksize = int(radius * 2 + 1)

    try:
        import cv2
        return cv2.GaussianBlur(a, (ksize, ksize), float(sigma),
                                borderType=cv2.BORDER_REFLECT)
    except Exception:
        pass

    try:
        from scipy.ndimage import gaussian_filter
        return gaussian_filter(a, sigma=float(sigma), mode="reflect")
    except Exception:
        return a
2281
+
2282
def _make_memmap_tile_loader(frame_infos, max_open=32):
    """
    Build a tile reader over per-frame disk memmaps.

    Parameters
    ----------
    frame_infos : list[dict]
        One ``{"path", "dtype", "shape"}`` dict per frame (as produced by
        ``_prepare_frame_stack_memmap``); ``shape`` is (H, W) or (C, H, W).
    max_open : int
        Maximum number of memmap handles kept in the LRU cache. Only handles
        are cached, not image-sized arrays.

    Returns
    -------
    callable
        ``tile_loader(i, y0, y1, x0, x1, csel=None)`` returns an independent,
        contiguous copy of the tile ``[y0:y1, x0:x1]`` of frame ``i``. For
        (C, H, W) frames, ``csel`` selects the channel (default 0). The
        callable also carries:

        * ``want_c``: channel count of frame 0 (1 for 2D frames, ``None`` when
          ``frame_infos`` is empty), so the seeder doesn't reopen source files.
        * ``close()``: drops all cached handles. Later calls simply reopen.

    The loader may be called from several threads at once (the median seeder
    uses a ThreadPoolExecutor). A lock guards the LRU bookkeeping, because an
    unsynchronized ``move_to_end`` can race with another thread's eviction and
    raise ``KeyError``. The slice copy itself runs outside the lock.
    """
    import threading

    opened = OrderedDict()
    lock = threading.Lock()

    def _get_mm(i):
        with lock:
            mm = opened.get(i)
            if mm is not None:
                opened.move_to_end(i, last=True)  # bump LRU
                return mm
            fi = frame_infos[i]
            mm = np.memmap(fi["path"], mode="r", dtype=fi["dtype"], shape=fi["shape"])
            opened[i] = mm
            # Evict least-recently used handles beyond max_open. A thread that
            # already obtained an evicted memmap keeps its own reference to it.
            while len(opened) > int(max_open):
                opened.popitem(last=False)
            return mm

    def tile_loader(i, y0, y1, x0, x1, csel=None):
        a = _get_mm(i)  # (H,W) or (C,H,W)
        if a.ndim == 2:
            t = a[y0:y1, x0:x1]
        else:
            cc = 0 if csel is None else int(csel)
            t = a[cc, y0:y1, x0:x1]
        # Copy so the median slab is independent of the mapping and contiguous.
        return np.array(t, copy=True)

    # Advertise channel count so the seeder doesn't reopen original files.
    if frame_infos:
        shp = frame_infos[0]["shape"]
        tile_loader.want_c = shp[0] if len(shp) == 3 else 1
    else:
        tile_loader.want_c = None

    def _close():
        with lock:
            opened.clear()

    tile_loader.close = _close
    return tile_loader
2337
+
2338
+
2339
def _prepare_frame_stack_memmap(
    paths: list[str],
    Ht: int,
    Wt: int,
    color_mode: str = "luma",
    *,
    scratch_dir: str | None = None,
    dtype: np.dtype | str = np.float32,
    tile_hw: tuple[int, int] = (512, 512),
    status_cb=lambda s: None,
):
    """
    Create one disk-backed memmap per input frame, cropped to (Ht, Wt) and
    normalized to float32 (or the requested dtype).

    Frames are streamed tile by tile (``tile_hw``), so a full source frame is
    never held in RAM. Integer pixels are scaled to [0, 1] by their dtype's max.

    Layout of each memmap (row-major):
      * ``color_mode == "luma"`` (case/whitespace-insensitive) -> (Ht, Wt);
        non-2D tiles are reduced with ``_to_luma_local``.
      * any other mode -> (C, Ht, Wt), where C = 3 if ``_read_shape_fast``
        reports 3 channels and C = 1 otherwise (mono stays CHW as (1, Ht, Wt)).

    Parameters
    ----------
    paths : list[str]
        Source frame paths.
    Ht, Wt : int
        Common crop size.
    color_mode : str
        "luma" or a per-channel mode (see above).
    scratch_dir : str or None
        Passed to ``_ensure_scratch_dir`` to obtain the output directory.
    dtype : numpy dtype or str
        Storage dtype. The strings "float16"/"fp16"/"half" select float16;
        any other string selects float32.
    tile_hw : (int, int)
        Tile height/width used while streaming.
    status_cb : callable(str)
        Progress/status messages.

    Returns
    -------
    frame_infos : list[dict]
        ``{"path", "shape", "dtype"}`` per frame, in input order.
    hdrs : list[fits.Header]
        Primary header per frame (an empty Header if it could not be read).

    If writing a frame fails, that frame's partially written memmap file is
    deleted before the exception propagates. Files from earlier frames are
    left in place.
    """
    scratch_dir = _ensure_scratch_dir(scratch_dir)

    # Normalize the requested storage dtype.
    if isinstance(dtype, str):
        _d = dtype.lower().strip()
        out_dtype = np.float16 if _d in ("float16", "fp16", "half") else np.float32
    else:
        out_dtype = np.dtype(dtype)

    # Loop invariants: the colour mode and the tile grid are identical for
    # every frame, so compute them once.
    mode = str(color_mode).lower().strip()
    th, tw = int(tile_hw[0]), int(tile_hw[1])
    tiles = [(y, min(y + th, Ht), x, min(x + tw, Wt))
             for y in range(0, Ht, th) for x in range(0, Wt, tw)]

    infos, hdrs = [], []

    status_cb(f"Preparing {len(paths)} frame memmaps → {scratch_dir}")
    for idx, p in enumerate(paths, start=1):
        try:
            hdr = _safe_primary_header(p)
        except Exception:
            hdr = fits.Header()
        hdrs.append(hdr)

        if mode == "luma":
            shape = (Ht, Wt)  # 2D
            C_out = 1
        else:
            # Per-channel path: keep CHW even if mono -> (1,H,W) or (3,H,W)
            C0, _, _ = _read_shape_fast(p)
            C_out = 3 if C0 == 3 else 1
            shape = (C_out, Ht, Wt)  # 3D

        mm_path = _mm_unique_path(scratch_dir, tag=f"frame_{idx:04d}", ext=".mm")
        mm = np.memmap(mm_path, mode="w+", dtype=out_dtype, shape=shape)
        mm_is_3d = (mm.ndim == 3)

        try:
            for (y0, y1, x0, x1) in tiles:
                # 1) read source tile
                t = _read_tile_fits_any(p, y0, y1, x0, x1)

                # 2) float32; integer input is scaled to [0, 1]
                if t.dtype.kind in "ui":
                    t = t.astype(np.float32) / (float(np.iinfo(t.dtype).max) or 1.0)
                else:
                    t = t.astype(np.float32, copy=False)

                # 3) layout to match the memmap
                if not mm_is_3d:
                    # target is 2D (Ht,Wt): anything that isn't 2D becomes luma
                    if t.ndim != 2:
                        t = _to_luma_local(t)
                    if out_dtype != np.float32:
                        t = t.astype(out_dtype, copy=False)
                    mm[y0:y1, x0:x1] = t
                elif C_out == 3:
                    # target is (3,H,W): coerce the tile to CHW
                    if t.ndim == 2:
                        t = np.stack([t, t, t], axis=0)  # replicate mono
                    elif t.ndim == 3 and t.shape[-1] == 3:
                        t = np.moveaxis(t, -1, 0)  # HWC -> CHW
                    elif t.ndim == 3 and t.shape[0] == 3:
                        pass  # already CHW
                    else:
                        t = _to_luma_local(t)
                        t = np.stack([t, t, t], axis=0)
                    if out_dtype != np.float32:
                        t = t.astype(out_dtype, copy=False)
                    mm[:, y0:y1, x0:x1] = t
                else:
                    # target is (1,H,W): store a single channel at mm[0, ...]
                    if t.ndim == 3:
                        t = _to_luma_local(t)
                    if out_dtype != np.float32:
                        t = t.astype(out_dtype, copy=False)
                    mm[0, y0:y1, x0:x1] = t
        except BaseException:
            # Don't leave a half-written scratch file behind. Drop the mapping
            # first so the file can be deleted on platforms that lock mapped files.
            del mm
            try:
                os.remove(mm_path)
            except OSError:
                pass
            raise

        try:
            mm.flush()
        except Exception as e:
            import logging
            logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
        del mm

        infos.append({"path": mm_path, "shape": tuple(shape), "dtype": out_dtype})

        if (idx % 8) == 0 or idx == len(paths):
            status_cb(f"Frame memmaps: {idx}/{len(paths)} ready")
            gc.collect()

    return infos, hdrs
2453
+
2454
+
2455
+ # -----------------------------
2456
+ # Core
2457
+ # -----------------------------
2458
+ def multiframe_deconv(
2459
+ paths,
2460
+ out_path,
2461
+ iters=20,
2462
+ kappa=2.0,
2463
+ color_mode="luma",
2464
+ seed_mode: str = "robust",
2465
+ huber_delta=0.0,
2466
+ masks=None,
2467
+ variances=None,
2468
+ rho="huber",
2469
+ status_cb=lambda s: None,
2470
+ min_iters: int = 3,
2471
+ use_star_masks: bool = False,
2472
+ use_variance_maps: bool = False,
2473
+ star_mask_cfg: dict | None = None,
2474
+ varmap_cfg: dict | None = None,
2475
+ save_intermediate: bool = False,
2476
+ save_every: int = 1,
2477
+ # >>> SR options
2478
+ super_res_factor: int = 1,
2479
+ sr_sigma: float = 1.1,
2480
+ sr_psf_opt_iters: int = 250,
2481
+ sr_psf_opt_lr: float = 0.1,
2482
+ star_mask_ref_path: str | None = None,
2483
+ scratch_to_disk: bool = True, # spill large scratch arrays to memmap
2484
+ scratch_dir: str | None = None, # where to put memmaps (default: tempdir)
2485
+ memmap_threshold_mb: int = 512, # always memmap if buffer > this
2486
+ force_cpu: bool = False, # disable torch entirely unless caller opts in
2487
+ cache_psf_ffts: str = "disk", # 'disk' | 'ram' | 'none' (see §3)
2488
+ fft_reuse_across_iters: bool = True, # keep PSF FFTs across iters (same math)
2489
+ io_workers: int | None = None, # cap I/O threadpool (seed/tiles)
2490
+ blas_threads: int = 1, # limit BLAS threads to avoid oversub
2491
+ ):
2492
+ # sanitize and clamp
2493
+ max_iters = max(1, int(iters))
2494
+ min_iters = max(1, int(min_iters))
2495
+ if min_iters > max_iters:
2496
+ min_iters = max_iters
2497
+
2498
+ def _emit_pct(pct: float, msg: str | None = None):
2499
+ pct = float(max(0.0, min(1.0, pct)))
2500
+ status_cb(f"__PROGRESS__ {pct:.4f}" + (f" {msg}" if msg else ""))
2501
+
2502
+ status_cb(f"MFDeconv: loading {len(paths)} aligned frames…")
2503
+ _emit_pct(0.02, "loading")
2504
+
2505
+ # Use unified probe to pick a common crop without loading full images
2506
+ Ht, Wt = _common_hw_from_paths(paths)
2507
+ _emit_pct(0.05, "preparing")
2508
+
2509
+ # Stream actual pixels cropped to (Ht,Wt), float32 CHW/2D + headers
2510
+ frame_infos, hdrs = _prepare_frame_stack_memmap(
2511
+ paths, Ht, Wt, color_mode,
2512
+ scratch_dir=scratch_dir,
2513
+ dtype=np.float32, # or pull from a cfg
2514
+ tile_hw=(512, 512),
2515
+ status_cb=status_cb,
2516
+ )
2517
+
2518
+ tile_loader = _make_memmap_tile_loader(frame_infos, max_open=32)
2519
+
2520
+ def _open_frame_numpy(i: int) -> np.ndarray:
2521
+ fi = frame_infos[i]
2522
+ a = np.memmap(fi["path"], mode="r", dtype=fi["dtype"], shape=fi["shape"])
2523
+ # Solver expects float32 math; cast on read (no copy if already f32)
2524
+ return np.asarray(a, dtype=np.float32)
2525
+
2526
+ # For functions that only need luma/HW:
2527
+ def _open_frame_hw(i: int) -> np.ndarray:
2528
+ arr = _open_frame_numpy(i)
2529
+ if arr.ndim == 3:
2530
+ return arr[0] # use first/luma channel consistently
2531
+ return arr
2532
+
2533
+
2534
+ relax = 0.7
2535
+ use_torch = False
2536
+ global torch, TORCH_OK
2537
+
2538
+ # -------- try to import torch from per-user runtime venv --------
2539
+ # -------- try to import torch from per-user runtime venv --------
2540
+ torch = None
2541
+ TORCH_OK = False
2542
+ cuda_ok = mps_ok = dml_ok = False
2543
+ dml_device = None
2544
+ try:
2545
+ from setiastro.saspro.runtime_torch import import_torch
2546
+ torch = import_torch(prefer_cuda=True, status_cb=status_cb)
2547
+ TORCH_OK = True
2548
+
2549
+ try: cuda_ok = hasattr(torch, "cuda") and torch.cuda.is_available()
2550
+ except Exception: cuda_ok = False
2551
+ try: mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
2552
+ except Exception: mps_ok = False
2553
+ try:
2554
+ import torch_directml
2555
+ dml_device = torch_directml.device()
2556
+ _ = (torch.ones(1, device=dml_device) + 1).item()
2557
+ dml_ok = True
2558
+ except Exception:
2559
+ dml_ok = False
2560
+
2561
+ if cuda_ok:
2562
+ status_cb(f"PyTorch CUDA available: True | device={torch.cuda.get_device_name(0)}")
2563
+ elif mps_ok:
2564
+ status_cb("PyTorch MPS (Apple) available: True")
2565
+ elif dml_ok:
2566
+ status_cb("PyTorch DirectML (Windows) available: True")
2567
+ else:
2568
+ status_cb("PyTorch present, using CPU backend.")
2569
+
2570
+ status_cb(
2571
+ f"PyTorch {getattr(torch, '__version__', '?')} backend: "
2572
+ + ("CUDA" if cuda_ok else "MPS" if mps_ok else "DirectML" if dml_ok else "CPU")
2573
+ )
2574
+ except Exception as e:
2575
+ TORCH_OK = False
2576
+ status_cb(f"PyTorch not available → CPU path. ({e})")
2577
+
2578
+ # ----------------------------
2579
+ # Torch usage policy gate (STEP 5)
2580
+ # ----------------------------
2581
+ # 1) Hard off-switch from caller
2582
+ if force_cpu:
2583
+ TORCH_OK = False
2584
+ torch = None
2585
+ status_cb("Torch disabled by policy: force_cpu=True → using NumPy everywhere.")
2586
+
2587
+ # 2) (Optional) clamp BLAS threads globally to avoid oversubscription
2588
+ try:
2589
+ from threadpoolctl import threadpool_limits
2590
+ _blas_ctx = threadpool_limits(limits=int(max(1, blas_threads)))
2591
+ except Exception:
2592
+ _blas_ctx = contextlib.nullcontext()
2593
+
2594
+ use_torch = bool(TORCH_OK)
2595
+
2596
+ # Only configure Torch backends if policy allowed Torch
2597
+ if use_torch:
2598
+ # ----- Precision policy (strict FP32) -----
2599
+ try:
2600
+ torch.backends.cudnn.benchmark = True
2601
+ if hasattr(torch.backends, "cudnn"):
2602
+ torch.backends.cudnn.allow_tf32 = False
2603
+ if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
2604
+ torch.backends.cuda.matmul.allow_tf32 = False
2605
+ if hasattr(torch, "set_float32_matmul_precision"):
2606
+ torch.set_float32_matmul_precision("highest")
2607
+ except Exception:
2608
+ pass
2609
+
2610
+ try:
2611
+ c_tf32 = getattr(torch.backends.cudnn, "allow_tf32", None)
2612
+ m_tf32 = getattr(getattr(torch.backends.cuda, "matmul", object()), "allow_tf32", None)
2613
+ status_cb(
2614
+ f"Precision: cudnn.allow_tf32={c_tf32} | "
2615
+ f"matmul.allow_tf32={m_tf32} | "
2616
+ f"benchmark={getattr(torch.backends.cudnn, 'benchmark', None)}"
2617
+ )
2618
+ except Exception:
2619
+ pass
2620
+
2621
+
2622
+ _process_gui_events_safely()
2623
+
2624
+ # PSFs (auto-size per frame) + flipped copies
2625
+ psf_out_dir = None
2626
+ psfs, masks_auto, vars_auto, var_paths = _build_psf_and_assets(
2627
+ paths,
2628
+ make_masks=bool(use_star_masks),
2629
+ make_varmaps=bool(use_variance_maps),
2630
+ status_cb=status_cb,
2631
+ save_dir=None,
2632
+ star_mask_cfg=star_mask_cfg,
2633
+ varmap_cfg=varmap_cfg,
2634
+ star_mask_ref_path=star_mask_ref_path,
2635
+ Ht=Ht, Wt=Wt, color_mode=color_mode,
2636
+ )
2637
+
2638
+ # >>> SR: lift PSFs to super-res if requested
2639
+ r = int(max(1, super_res_factor))
2640
+ if r > 1:
2641
+ status_cb(f"MFDeconv: Super-resolution r={r} with σ={sr_sigma} — solving SR PSFs…")
2642
+ _process_gui_events_safely()
2643
+ sr_psfs = []
2644
+ for i, k_native in enumerate(psfs, start=1):
2645
+ h = _solve_super_psf_from_native(k_native, r=r, sigma=float(sr_sigma),
2646
+ iters=int(sr_psf_opt_iters), lr=float(sr_psf_opt_lr))
2647
+ sr_psfs.append(h)
2648
+ status_cb(f" SR-PSF{i}: native {k_native.shape[0]} → {h.shape[0]} (sum={h.sum():.6f})")
2649
+ psfs = sr_psfs
2650
+
2651
+ flip_psf = [_flip_kernel(k) for k in psfs]
2652
+ _emit_pct(0.20, "PSF Ready")
2653
+
2654
+
2655
+ # --- SR/native seed ---
2656
+ seed_mode_s = str(seed_mode).lower().strip()
2657
+ if seed_mode_s not in ("robust", "median"):
2658
+ seed_mode_s = "robust"
2659
+
2660
+ if seed_mode_s == "median":
2661
+ status_cb("MFDeconv: Building median seed (streaming, CPU)…")
2662
+ try:
2663
+ seed_native = _seed_median_streaming(
2664
+ paths,
2665
+ Ht,
2666
+ Wt,
2667
+ color_mode=color_mode,
2668
+ tile_hw=(256, 256),
2669
+ status_cb=status_cb,
2670
+ progress_cb=lambda f, m="": _emit_pct(0.10 + 0.10 * f, f"median seed: {m}"),
2671
+ io_workers=io_workers,
2672
+ scratch_dir=scratch_dir,
2673
+ tile_loader=tile_loader, # <<< use the memmap-backed tiles
2674
+ )
2675
+ finally:
2676
+ # drop any open memmap handles held by the loader
2677
+ try: tile_loader.close()
2678
+ except Exception as e:
2679
+ import logging
2680
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2681
+ import gc as _gc; _gc.collect()
2682
+ else:
2683
+ status_cb("MFDeconv: Building robust seed (live μ-σ stacking)…")
2684
+ seed_native = _build_seed_running_mu_sigma_from_paths(
2685
+ paths, Ht, Wt, color_mode,
2686
+ bootstrap_frames=20, clip_sigma=5.0,
2687
+ status_cb=status_cb, progress_cb=lambda f,m='': None
2688
+ )
2689
+ if r > 1:
2690
+ if seed_native.ndim == 2:
2691
+ x = _upsample_sum(seed_native / (r*r), r, target_hw=(Ht*r, Wt*r))
2692
+ else:
2693
+ C, Hn, Wn = seed_native.shape
2694
+ x = np.stack(
2695
+ [_upsample_sum(seed_native[c] / (r*r), r, target_hw=(Hn*r, Wn*r)) for c in range(C)],
2696
+ axis=0
2697
+ )
2698
+ else:
2699
+ x = seed_native
2700
+ # Ensure CHW shape under PerChannel, even for mono (C=1)
2701
+ if str(color_mode).lower() != "luma" and x.ndim == 2:
2702
+ x = x[None, ...] # (1,H,W) to match frame & Torch ops
2703
+
2704
+ # Robust H,W extraction
2705
+ Hs, Ws = (x.shape[-2], x.shape[-1]) if x.ndim >= 2 else (Ht, Wt)
2706
+
2707
+ # masks/vars aligned to native grid (2D each)
2708
+ auto_masks = masks_auto if use_star_masks else None
2709
+ auto_vars = vars_auto if use_variance_maps else None
2710
+ base_template = np.empty((Ht, Wt), dtype=np.float32)
2711
+ data_like = [base_template] * len(paths)
2712
+
2713
+ # replace the bad call with:
2714
+ mask_list = _ensure_mask_list(
2715
+ masks if masks is not None else masks_auto,
2716
+ data_like
2717
+ )
2718
+ # ...
2719
+ if use_variance_maps and var_paths is not None:
2720
+ def _open_var_mm(p):
2721
+ return None if p is None else np.memmap(p, mode="r", dtype=_VAR_DTYPE, shape=(Ht, Wt))
2722
+ var_list = [_open_var_mm(p) for p in var_paths]
2723
+ else:
2724
+ var_list = [None] * len(paths)
2725
+
2726
+ iter_dir = None
2727
+ hdr0_seed = None
2728
+ if save_intermediate:
2729
+ iter_dir = _iter_folder(out_path)
2730
+ status_cb(f"MFDeconv: Intermediate outputs → {iter_dir}")
2731
+ try:
2732
+ hdr0_seed = _safe_primary_header(paths[0])
2733
+ except Exception:
2734
+ hdr0_seed = fits.Header()
2735
+ _save_iter_image(x, hdr0_seed, iter_dir, "seed", color_mode)
2736
+
2737
+ status_cb("MFDeconv: Calculating Backgrounds and MADs…")
2738
+ _process_gui_events_safely()
2739
+ y0 = _open_frame_hw(0)
2740
+ bg_est = _sep_bg_rms([y0]) or (np.median(np.abs(y0 - np.median(y0))) * 1.4826)
2741
+ del y0
2742
+ status_cb(f"MFDeconv: color_mode={color_mode}, huber_delta={huber_delta} (bg RMS~{bg_est:.3g})")
2743
+ _process_gui_events_safely()
2744
+
2745
+ status_cb("Computing FFTs and Allocating Scratch…")
2746
+ _process_gui_events_safely()
2747
+
2748
+ # -------- precompute and allocate scratch --------
2749
+ pred_super = None # CPU-only temp; avoid UnboundLocal on Torch path
2750
+ tmp_out = None # CPU-only temp; avoid UnboundLocal on Torch path
2751
+ def _arr_only(x):
2752
+ """Accept either ndarray/memmap or (memmap, path) and return the array."""
2753
+ if isinstance(x, tuple) and len(x) == 2 and hasattr(x[0], "dtype"):
2754
+ return x[0]
2755
+ return x
2756
+ per_frame_logging = (r > 1)
2757
+ if use_torch:
2758
+ x_t = _to_t(_contig(x))
2759
+ num = torch.zeros_like(x_t)
2760
+ den = torch.zeros_like(x_t)
2761
+
2762
+ if r > 1:
2763
+ # >>> SR path now uses SPATIAL CONV (cuDNN) to avoid huge FFT buffers
2764
+ psf_t = [_to_t(_contig(k))[None, None] for k in psfs] # (1,1,kh,kw)
2765
+ psfT_t = [_to_t(_contig(kT))[None, None] for kT in flip_psf]
2766
+ else:
2767
+ # Native spatial (as before)
2768
+ psf_t = [_to_t(_contig(k))[None, None] for k in psfs]
2769
+ psfT_t = [_to_t(_contig(kT))[None, None] for kT in flip_psf]
2770
+
2771
+ else:
2772
+ # ---------- CPU path (NumPy) ----------
2773
+ x_t = x
2774
+
2775
+ # Determine working size
2776
+ if x_t.ndim == 2:
2777
+ Hs, Ws = x_t.shape
2778
+ else:
2779
+ _, Hs, Ws = x_t.shape
2780
+
2781
+ # Choose PSF FFT caching policy
2782
+ # Expect this to be threaded in from function args; fallback to "ram"
2783
+ cache_psf_ffts = locals().get("cache_psf_ffts", "ram")
2784
+ scratch_dir = locals().get("scratch_dir", None)
2785
+
2786
+ import numpy.fft as _fft
2787
+
2788
+ Kfs = KTfs = meta = None
2789
+
2790
+ if cache_psf_ffts == "ram":
2791
+ # Original behavior: keep FFTs in RAM
2792
+ Kfs, KTfs, meta = _precompute_np_psf_ffts(psfs, flip_psf, Hs, Ws)
2793
+
2794
+ elif cache_psf_ffts == "disk":
2795
+ # New behavior: keep FFTs in disk-backed memmaps to save RAM
2796
+ # Requires _mm_create() helper (Step 2). If not added yet, set cache_psf_ffts="ram".
2797
+ Kfs, KTfs, meta = [], [], []
2798
+ for idx, (k, kT) in enumerate(zip(psfs, flip_psf), start=1):
2799
+ kh, kw = k.shape
2800
+ fftH, fftW = _fftshape_same(Hs, Ws, kh, kw)
2801
+
2802
+ # Create complex64 memmaps for rfftn grids (H, W//2+1)
2803
+ Kf_mm, Kf_path = _mm_create((fftH, fftW//2 + 1), np.complex64, scratch_dir, tag=f"Kf_{idx}")
2804
+ KTf_mm, KTf_path = _mm_create((fftH, fftW//2 + 1), np.complex64, scratch_dir, tag=f"KTf_{idx}")
2805
+
2806
+ # Compute once into the memmaps (same math)
2807
+ Kf_mm[...] = _fft.rfftn(np.fft.ifftshift(k).astype(np.float32, copy=False), s=(fftH, fftW)).astype(np.complex64, copy=False)
2808
+ KTf_mm[...] = _fft.rfftn(np.fft.ifftshift(kT).astype(np.float32, copy=False), s=(fftH, fftW)).astype(np.complex64, copy=False)
2809
+ Kf_mm.flush(); KTf_mm.flush()
2810
+
2811
+ Kfs.append(Kf_mm)
2812
+ KTfs.append(KTf_mm)
2813
+ meta.append((kh, kw, fftH, fftW))
2814
+
2815
+ elif cache_psf_ffts == "none":
2816
+ # Don’t precompute; compute per-frame inside the iter loop (same math, less RAM, more CPU).
2817
+ Kfs = KTfs = meta = None
2818
+ else:
2819
+ # Fallback to RAM behavior
2820
+ cache_psf_ffts = "ram"
2821
+ Kfs, KTfs, meta = _precompute_np_psf_ffts(psfs, flip_psf, Hs, Ws)
2822
+
2823
+ # Allocate CPU scratch (keep as-is for Step 3)
2824
+ def _shape_of(a): return a.shape, a.dtype
2825
+
2826
+ # Always keep x_t in RAM for speed (it’s the only array updated iteratively)
2827
+ # But allow opting into memmap if strictly necessary:
2828
+ if scratch_to_disk and (_approx_bytes(x.shape, x.dtype) / 1e6 > memmap_threshold_mb):
2829
+ x_mm, x_path = _mm_create(x.shape, x.dtype, scratch_dir, tag="x")
2830
+ x_mm[...] = x # copy the seed into the memmap
2831
+ x_t = x_mm
2832
+ else:
2833
+ x_t = x
2834
+
2835
+ num = _arr_only(_maybe_memmap(x_t.shape, x_t.dtype,
2836
+ force_mm=scratch_to_disk,
2837
+ threshold_mb=memmap_threshold_mb,
2838
+ scratch_dir=scratch_dir, tag="num"))
2839
+ den = _arr_only(_maybe_memmap(x_t.shape, x_t.dtype,
2840
+ force_mm=scratch_to_disk,
2841
+ threshold_mb=memmap_threshold_mb,
2842
+ scratch_dir=scratch_dir, tag="den"))
2843
+
2844
+ pred_super = _arr_only(_maybe_memmap(x_t.shape, x_t.dtype,
2845
+ force_mm=scratch_to_disk,
2846
+ threshold_mb=memmap_threshold_mb,
2847
+ scratch_dir=scratch_dir, tag="pred"))
2848
+ tmp_out = _arr_only(_maybe_memmap(x_t.shape, x_t.dtype,
2849
+ force_mm=scratch_to_disk,
2850
+ threshold_mb=memmap_threshold_mb,
2851
+ scratch_dir=scratch_dir, tag="tmp"))
2852
+
2853
+
2854
+ _to_check = [('x_t', x_t), ('num', num), ('den', den)]
2855
+ if not use_torch:
2856
+ _to_check += [('pred_super', pred_super), ('tmp_out', tmp_out)]
2857
+ for _name, _arr in _to_check:
2858
+ assert hasattr(_arr, 'shape'), f"{_name} must be array-like with .shape, got {type(_arr)}"
2859
+ _flush_memmaps(num, den)
2860
+
2861
+ # CPU-only scratch; may not exist on Torch path
2862
+ if isinstance(pred_super, np.memmap):
2863
+ _flush_memmaps(pred_super)
2864
+ if isinstance(tmp_out, np.memmap):
2865
+ _flush_memmaps(tmp_out)
2866
+
2867
+
2868
+ status_cb("Starting First Multiplicative Iteration…")
2869
+ _process_gui_events_safely()
2870
+
2871
+ cm = _safe_inference_context() if use_torch else NO_GRAD
2872
+ rho_is_l2 = (str(rho).lower() == "l2")
2873
+ local_delta = 0.0 if rho_is_l2 else huber_delta
2874
+
2875
+
2876
+ auto_delta_cache = None
2877
+ if use_torch and (huber_delta < 0) and (not rho_is_l2):
2878
+ auto_delta_cache = [None] * len(paths)
2879
+ # ---- unified EarlyStopper ----
2880
+ early = EarlyStopper(
2881
+ tol_upd_floor=1e-3,
2882
+ tol_rel_floor=5e-4,
2883
+ early_frac=0.40,
2884
+ ema_alpha=0.5,
2885
+ patience=2,
2886
+ min_iters=min_iters,
2887
+ status_cb=status_cb
2888
+ )
2889
+
2890
+ used_iters = 0
2891
+ early_stopped = False
2892
+
2893
+ with cm():
2894
+ for it in range(1, max_iters + 1):
2895
+ if use_torch:
2896
+ num.zero_(); den.zero_()
2897
+
2898
+ if r > 1:
2899
+ # -------- SR path (SPATIAL conv + stream) --------
2900
+ for fidx, (wk, wkT) in enumerate(zip(psf_t, psfT_t)):
2901
+ yt_np = _open_frame_numpy(fidx) # CHW or HW (CPU)
2902
+ mt_np = mask_list[fidx]
2903
+ vt_np = var_list[fidx]
2904
+
2905
+ yt = torch.as_tensor(yt_np, dtype=x_t.dtype, device=x_t.device)
2906
+ mt = None if mt_np is None else torch.as_tensor(mt_np, dtype=x_t.dtype, device=x_t.device)
2907
+ vt = None if vt_np is None else torch.as_tensor(vt_np, dtype=x_t.dtype, device=x_t.device)
2908
+
2909
+ # SR conv on grid of x_t
2910
+ pred_sr = _conv_same_torch(x_t, wk) # SR grid
2911
+ pred_low = _downsample_avg_t(pred_sr, r) # native grid
2912
+
2913
+ if auto_delta_cache is not None:
2914
+ if (auto_delta_cache[fidx] is None) or (it % 5 == 1):
2915
+ rnat = yt - pred_low
2916
+ med = torch.median(rnat)
2917
+ mad = torch.median(torch.abs(rnat - med)) + 1e-6
2918
+ rms = 1.4826 * mad
2919
+ auto_delta_cache[fidx] = float((-huber_delta) * torch.clamp(rms, min=1e-6).item())
2920
+ wmap_low = _weight_map(yt, pred_low, auto_delta_cache[fidx], var_map=vt, mask=mt)
2921
+ else:
2922
+ wmap_low = _weight_map(yt, pred_low, local_delta, var_map=vt, mask=mt)
2923
+
2924
+ # lift back to SR via sum-replicate
2925
+ up_y = _upsample_sum_t(wmap_low * yt, r)
2926
+ up_pred = _upsample_sum_t(wmap_low * pred_low, r)
2927
+
2928
+ # accumulate via adjoint kernel on SR grid
2929
+ num += _conv_same_torch(up_y, wkT)
2930
+ den += _conv_same_torch(up_pred, wkT)
2931
+
2932
+ # free temps as aggressively as possible
2933
+ del yt, mt, vt, pred_sr, pred_low, wmap_low, up_y, up_pred
2934
+ if cuda_ok:
2935
+ try: torch.cuda.empty_cache()
2936
+ except Exception as e:
2937
+ import logging
2938
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2939
+
2940
+ if per_frame_logging and ((fidx & 7) == 0):
2941
+ status_cb(f"Iter {it}/{max_iters} — frame {fidx+1}/{len(paths)} (SR spatial)")
2942
+
2943
+ del yt_np
2944
+
2945
+ else:
2946
+ # -------- Native path (spatial conv, stream) --------
2947
+ for fidx, (wk, wkT) in enumerate(zip(psf_t, psfT_t)):
2948
+ yt_np = _open_frame_numpy(fidx) # CHW or HW (CPU → to Torch tensor)
2949
+ mt_np = mask_list[fidx]
2950
+ vt_np = var_list[fidx]
2951
+
2952
+ yt = torch.as_tensor(yt_np, dtype=x_t.dtype, device=x_t.device)
2953
+ mt = None if mt_np is None else torch.as_tensor(mt_np, dtype=x_t.dtype, device=x_t.device)
2954
+ vt = None if vt_np is None else torch.as_tensor(vt_np, dtype=x_t.dtype, device=x_t.device)
2955
+
2956
+ pred = _conv_same_torch(x_t, wk)
2957
+ wmap = _weight_map(yt, pred, local_delta, var_map=vt, mask=mt)
2958
+ up_y = wmap * yt
2959
+ up_pred = wmap * pred
2960
+ num += _conv_same_torch(up_y, wkT)
2961
+ den += _conv_same_torch(up_pred, wkT)
2962
+
2963
+ del yt, mt, vt, pred, wmap, up_y, up_pred
2964
+ if cuda_ok:
2965
+ try: torch.cuda.empty_cache()
2966
+ except Exception as e:
2967
+ import logging
2968
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
2969
+
2970
+ ratio = num / (den + EPS)
2971
+ neutral = (den.abs() < 1e-12) & (num.abs() < 1e-12)
2972
+ ratio = torch.where(neutral, torch.ones_like(ratio), ratio)
2973
+ upd = torch.clamp(ratio, 1.0 / kappa, kappa)
2974
+ x_next = torch.clamp(x_t * upd, min=0.0)
2975
+
2976
+ upd_med = torch.median(torch.abs(upd - 1))
2977
+ rel_change = (torch.median(torch.abs(x_next - x_t)) /
2978
+ (torch.median(torch.abs(x_t)) + 1e-8))
2979
+
2980
+ um = float(upd_med.detach().cpu().item())
2981
+ rc = float(rel_change.detach().cpu().item())
2982
+
2983
+ if early.step(it, max_iters, um, rc):
2984
+ x_t = x_next
2985
+ used_iters = it
2986
+ early_stopped = True
2987
+ status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
2988
+ _process_gui_events_safely()
2989
+ break
2990
+
2991
+
2992
+ x_t = (1.0 - relax) * x_t + relax * x_next
2993
+
2994
+ else:
2995
+ # -------- NumPy path (fixed, no 'data') --------
2996
+ num.fill(0.0); den.fill(0.0)
2997
+
2998
+ if r > 1:
2999
+ # -------- Super-resolution (NumPy) --------
3000
+ if cache_psf_ffts == "none":
3001
+ # No precomputed PSF FFTs → compute per-frame per-iter
3002
+ for fidx, (m2d, v2d) in enumerate(zip(mask_list, var_list)):
3003
+ # Load native frame on demand (CHW or HW)
3004
+ y_nat = _open_frame_numpy(fidx)
3005
+
3006
+ # PSF for this frame
3007
+ k, kT = psfs[fidx], flip_psf[fidx]
3008
+ kh, kw = k.shape
3009
+ fftH, fftW = _fftshape_same(Hs, Ws, kh, kw)
3010
+
3011
+ # Per-frame FFTs (same math as precomputed branch)
3012
+ Kf = _fft.rfftn(np.fft.ifftshift(k).astype(np.float32, copy=False), s=(fftH, fftW))
3013
+ KTf = _fft.rfftn(np.fft.ifftshift(kT).astype(np.float32, copy=False), s=(fftH, fftW))
3014
+
3015
+ # Convolve current estimate x_t → SR prediction, then downsample
3016
+ _fft_conv_same_np(x_t, Kf, kh, kw, fftH, fftW, pred_super)
3017
+ pred_low = _downsample_avg(pred_super, r)
3018
+
3019
+ # Weight map in native grid
3020
+ wmap_low = _weight_map(y_nat, pred_low, local_delta, var_map=v2d, mask=m2d)
3021
+
3022
+ # Lift back to SR via sum-replicate
3023
+ up_y = _upsample_sum(wmap_low * y_nat, r, target_hw=pred_super.shape[-2:])
3024
+ up_pred = _upsample_sum(wmap_low * pred_low, r, target_hw=pred_super.shape[-2:])
3025
+
3026
+ # Accumulate adjoint contributions
3027
+ _fft_conv_same_np(up_y, KTf, kh, kw, fftH, fftW, tmp_out); num += tmp_out
3028
+ _fft_conv_same_np(up_pred, KTf, kh, kw, fftH, fftW, tmp_out); den += tmp_out
3029
+
3030
+ del y_nat, up_y, up_pred, wmap_low, pred_low, Kf, KTf
3031
+
3032
+ else:
3033
+ # Precomputed PSF FFTs (RAM or disk memmap)
3034
+ for (Kf, KTf, (kh, kw, fftH, fftW)), m2d, pvar, fidx in zip(
3035
+ zip(Kfs, KTfs, meta),
3036
+ mask_list,
3037
+ (var_paths or [None] * len(frame_infos)),
3038
+ range(len(frame_infos)),
3039
+ ):
3040
+ y_nat = _open_frame_numpy(fidx) # CHW or HW
3041
+
3042
+ vt_np = None
3043
+ if use_variance_maps and pvar is not None:
3044
+ vt_np = np.memmap(pvar, mode="r", dtype=_VAR_DTYPE, shape=(Ht, Wt))
3045
+
3046
+ _fft_conv_same_np(x_t, Kf, kh, kw, fftH, fftW, pred_super)
3047
+ pred = pred_super
3048
+
3049
+ wmap = _weight_map(y_nat, pred, local_delta, var_map=vt_np, mask=m2d)
3050
+ up_y, up_pred = (wmap * y_nat), (wmap * pred)
3051
+
3052
+ _fft_conv_same_np(up_y, KTf, kh, kw, fftH, fftW, tmp_out); num += tmp_out
3053
+ _fft_conv_same_np(up_pred, KTf, kh, kw, fftH, fftW, tmp_out); den += tmp_out
3054
+
3055
+ if vt_np is not None:
3056
+ try:
3057
+ del vt_np
3058
+ except Exception as e:
3059
+ import logging
3060
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3061
+ del y_nat, up_y, up_pred, wmap, pred
3062
+
3063
+ else:
3064
+ # -------- Native (NumPy) --------
3065
+ if cache_psf_ffts == "none":
3066
+ # No precomputed PSF FFTs → compute per-frame per-iter
3067
+ for fidx, (m2d, v2d) in enumerate(zip(mask_list, var_list)):
3068
+ y_nat = _open_frame_numpy(fidx)
3069
+
3070
+ k, kT = psfs[fidx], flip_psf[fidx]
3071
+ kh, kw = k.shape
3072
+ fftH, fftW = _fftshape_same(Hs, Ws, kh, kw)
3073
+
3074
+ Kf = _fft.rfftn(np.fft.ifftshift(k).astype(np.float32, copy=False), s=(fftH, fftW))
3075
+ KTf = _fft.rfftn(np.fft.ifftshift(kT).astype(np.float32, copy=False), s=(fftH, fftW))
3076
+
3077
+ _fft_conv_same_np(x_t, Kf, kh, kw, fftH, fftW, pred_super)
3078
+ pred = pred_super
3079
+
3080
+ wmap = _weight_map(y_nat, pred, local_delta, var_map=v2d, mask=m2d)
3081
+ up_y, up_pred = (wmap * y_nat), (wmap * pred)
3082
+
3083
+ _fft_conv_same_np(up_y, KTf, kh, kw, fftH, fftW, tmp_out); num += tmp_out
3084
+ _fft_conv_same_np(up_pred, KTf, kh, kw, fftH, fftW, tmp_out); den += tmp_out
3085
+
3086
+ del y_nat, up_y, up_pred, wmap, pred, Kf, KTf
3087
+
3088
+ else:
3089
+ # Precomputed PSF FFTs (RAM or disk memmap)
3090
+ for (Kf, KTf, (kh, kw, fftH, fftW)), m2d, pvar, fidx in zip(
3091
+ zip(Kfs, KTfs, meta),
3092
+ mask_list,
3093
+ (var_paths or [None] * len(frame_infos)),
3094
+ range(len(frame_infos)),
3095
+ ):
3096
+ y_nat = _open_frame_numpy(fidx)
3097
+
3098
+ vt_np = None
3099
+ if use_variance_maps and pvar is not None:
3100
+ vt_np = np.memmap(pvar, mode="r", dtype=_VAR_DTYPE, shape=(Ht, Wt))
3101
+
3102
+ _fft_conv_same_np(x_t, Kf, kh, kw, fftH, fftW, pred_super)
3103
+ pred = pred_super
3104
+
3105
+ wmap = _weight_map(y_nat, pred, local_delta, var_map=vt_np, mask=m2d)
3106
+ up_y, up_pred = (wmap * y_nat), (wmap * pred)
3107
+
3108
+ _fft_conv_same_np(up_y, KTf, kh, kw, fftH, fftW, tmp_out); num += tmp_out
3109
+ _fft_conv_same_np(up_pred, KTf, kh, kw, fftH, fftW, tmp_out); den += tmp_out
3110
+
3111
+ if vt_np is not None:
3112
+ try:
3113
+ del vt_np
3114
+ except Exception as e:
3115
+ import logging
3116
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3117
+ del y_nat, up_y, up_pred, wmap, pred
3118
+
3119
+ # --- multiplicative update (NumPy) ---
3120
+ ratio = num / (den + EPS)
3121
+ neutral = (np.abs(den) < 1e-12) & (np.abs(num) < 1e-12)
3122
+ ratio[neutral] = 1.0
3123
+
3124
+ upd = np.clip(ratio, 1.0 / kappa, kappa)
3125
+ x_next = np.clip(x_t * upd, 0.0, None)
3126
+
3127
+ upd_med = np.median(np.abs(upd - 1.0))
3128
+ rel_change = (
3129
+ np.median(np.abs(x_next - x_t)) /
3130
+ (np.median(np.abs(x_t)) + 1e-8)
3131
+ )
3132
+
3133
+ um = float(upd_med)
3134
+ rc = float(rel_change)
3135
+
3136
+ if early.step(it, max_iters, um, rc):
3137
+ x_t = x_next
3138
+ used_iters = it
3139
+ early_stopped = True
3140
+ status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
3141
+ _process_gui_events_safely()
3142
+ break
3143
+
3144
+ x_t = (1.0 - relax) * x_t + relax * x_next
3145
+
3146
+
3147
+ # save intermediate
3148
+ if save_intermediate and (it % int(max(1, save_every)) == 0):
3149
+ try:
3150
+ x_np = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
3151
+ _save_iter_image(x_np, hdr0_seed, iter_dir, f"iter_{it:03d}", color_mode)
3152
+ except Exception as _e:
3153
+ status_cb(f"Intermediate save failed at iter {it}: {_e}")
3154
+
3155
+ frac = 0.25 + 0.70 * (it / float(max_iters))
3156
+ _emit_pct(frac, f"Iteration {it}/{max_iters}")
3157
+ status_cb(f"Iter {it}/{max_iters}")
3158
+ _process_gui_events_safely()
3159
+ _flush_memmaps(num, den)
3160
+
3161
+ # If present in your CPU path / SR path:
3162
+ if isinstance(pred_super, np.memmap):
3163
+ _flush_memmaps(pred_super)
3164
+ if isinstance(tmp_out, np.memmap):
3165
+ _flush_memmaps(tmp_out)
3166
+
3167
+ if not early_stopped:
3168
+ used_iters = max_iters
3169
+
3170
+ # ----------------------------
3171
+ # Save result (keep FITS-friendly order: (C,H,W))
3172
+ # ----------------------------
3173
+ _emit_pct(0.97, "saving")
3174
+ x_final = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
3175
+
3176
+ if x_final.ndim == 3:
3177
+ if x_final.shape[0] not in (1, 3) and x_final.shape[-1] in (1, 3):
3178
+ x_final = np.moveaxis(x_final, -1, 0)
3179
+ if x_final.shape[0] == 1:
3180
+ x_final = x_final[0]
3181
+
3182
+ try:
3183
+ hdr0 = _safe_primary_header(paths[0])
3184
+ except Exception:
3185
+ hdr0 = fits.Header()
3186
+ hdr0['MFDECONV'] = (True, 'Seti Astro multi-frame deconvolution')
3187
+ hdr0['MF_COLOR'] = (str(color_mode), 'Color mode used')
3188
+ hdr0['MF_RHO'] = (str(rho), 'Loss: huber|l2')
3189
+ hdr0['MF_HDEL'] = (float(huber_delta), 'Huber delta (>0 abs, <0 autoxRMS)')
3190
+ hdr0['MF_MASK'] = (bool(use_star_masks), 'Used auto star masks')
3191
+ hdr0['MF_VAR'] = (bool(use_variance_maps), 'Used auto variance maps')
3192
+
3193
+ hdr0['MF_SR'] = (int(r), 'Super-resolution factor (1 := native)')
3194
+ if r > 1:
3195
+ hdr0['MF_SRSIG'] = (float(sr_sigma), 'Gaussian sigma for SR PSF fit (pixels at native)')
3196
+ hdr0['MF_SRIT'] = (int(sr_psf_opt_iters), 'SR-PSF solver iters')
3197
+
3198
+ hdr0['MF_ITMAX'] = (int(max_iters), 'Requested max iterations')
3199
+ hdr0['MF_ITERS'] = (int(used_iters), 'Actual iterations run')
3200
+ hdr0['MF_ESTOP'] = (bool(early_stopped), 'Early stop triggered')
3201
+
3202
+ if isinstance(x_final, np.ndarray):
3203
+ if x_final.ndim == 2:
3204
+ hdr0['MF_SHAPE'] = (f"{x_final.shape[0]}x{x_final.shape[1]}", 'Saved as 2D image (HxW)')
3205
+ elif x_final.ndim == 3:
3206
+ C, H, W = x_final.shape
3207
+ hdr0['MF_SHAPE'] = (f"{C}x{H}x{W}", 'Saved as 3D cube (CxHxW)')
3208
+ _flush_memmaps(x_t, num, den)
3209
+ if isinstance(pred_super, np.memmap):
3210
+ _flush_memmaps(pred_super)
3211
+ if isinstance(tmp_out, np.memmap):
3212
+ _flush_memmaps(tmp_out)
3213
+ save_path = _sr_out_path(out_path, super_res_factor)
3214
+ safe_out_path = _nonclobber_path(str(save_path))
3215
+ if safe_out_path != str(save_path):
3216
+ status_cb(f"Output exists — saving as: {safe_out_path}")
3217
+ fits.PrimaryHDU(data=x_final, header=hdr0).writeto(safe_out_path, overwrite=False)
3218
+
3219
+ status_cb(f"✅ MFDeconv saved: {safe_out_path} (iters used: {used_iters}{', early stop' if early_stopped else ''})")
3220
+ _emit_pct(1.00, "done")
3221
+ _process_gui_events_safely()
3222
+
3223
+ try:
3224
+ if use_torch:
3225
+ try: del num, den
3226
+ except Exception as e:
3227
+ import logging
3228
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3229
+ try: del psf_t, psfT_t
3230
+ except Exception as e:
3231
+ import logging
3232
+ logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
3233
+ _free_torch_memory()
3234
+ except Exception:
3235
+ pass
3236
+
3237
+ return safe_out_path
3238
+
3239
+
3240
+
3241
+ # -----------------------------
3242
+ # Worker
3243
+ # -----------------------------
3244
+
3245
class MultiFrameDeconvWorkercuDNN(QObject):
    """Qt worker object that runs ``multiframe_deconv`` and reports via signals.

    Signals:
        progress(str): status text forwarded from ``multiframe_deconv``'s
            ``status_cb``.
        finished(bool, str, str): emitted exactly once per ``run()`` call with
            ``(success, message, out_path)``; ``out_path`` is ``""`` on failure.
    """

    progress = pyqtSignal(str)
    finished = pyqtSignal(bool, str, str)  # success, message, out_path

    def __init__(self, parent, aligned_paths, output_path, iters, kappa, color_mode,
                 huber_delta, min_iters, use_star_masks=False, use_variance_maps=False, rho="huber",
                 star_mask_cfg: dict | None = None, varmap_cfg: dict | None = None,
                 save_intermediate: bool = False,
                 seed_mode: str = "robust",
                 # Super-resolution parameters
                 super_res_factor: int = 1,
                 sr_sigma: float = 1.1,
                 sr_psf_opt_iters: int = 250,
                 sr_psf_opt_lr: float = 0.1,
                 star_mask_ref_path: str | None = None):
        """Store the deconvolution settings; no work is done until ``run()``.

        Every argument except ``parent`` is forwarded unchanged (by keyword)
        to ``multiframe_deconv`` in ``run()``. The exceptions are:
        ``star_mask_cfg`` / ``varmap_cfg`` (``None`` becomes ``{}``) and the
        super-resolution settings (coerced to ``int``/``float`` here).

        Args:
            parent: Qt parent object, passed to ``QObject.__init__``.
            aligned_paths: input frame paths handed to ``multiframe_deconv``.
            output_path: requested output path handed to ``multiframe_deconv``.
            star_mask_ref_path: optional reference path forwarded as-is.
        """
        super().__init__(parent)
        self.aligned_paths = aligned_paths
        self.output_path = output_path
        self.iters = iters
        self.kappa = kappa
        self.color_mode = color_mode
        self.huber_delta = huber_delta
        self.min_iters = min_iters
        # None → empty dict so downstream code can always treat these as mappings
        self.star_mask_cfg = star_mask_cfg or {}
        self.varmap_cfg = varmap_cfg or {}
        self.use_star_masks = use_star_masks
        self.use_variance_maps = use_variance_maps
        self.rho = rho
        self.save_intermediate = save_intermediate
        self.super_res_factor = int(super_res_factor)
        self.sr_sigma = float(sr_sigma)
        self.sr_psf_opt_iters = int(sr_psf_opt_iters)
        self.sr_psf_opt_lr = float(sr_psf_opt_lr)
        self.star_mask_ref_path = star_mask_ref_path
        self.seed_mode = seed_mode

    def _log(self, s):
        """Forward a status message through the ``progress`` signal."""
        self.progress.emit(s)

    def run(self):
        """Run the deconvolution and emit ``finished`` exactly once.

        On success, emits ``finished(True, "MF deconvolution complete.", out)``
        where ``out`` is the path returned by ``multiframe_deconv``. If
        ``multiframe_deconv`` raises, the traceback is logged and
        ``finished(False, "MF deconvolution failed: <error>", "")`` is emitted.
        """
        import logging

        # Keep the try body to the one call that can fail, so the success path
        # below can never trigger a second, contradictory finished(False, ...).
        try:
            out = multiframe_deconv(
                self.aligned_paths,
                self.output_path,
                iters=self.iters,
                kappa=self.kappa,
                color_mode=self.color_mode,
                seed_mode=self.seed_mode,
                huber_delta=self.huber_delta,
                use_star_masks=self.use_star_masks,
                use_variance_maps=self.use_variance_maps,
                rho=self.rho,
                min_iters=self.min_iters,
                status_cb=self._log,
                star_mask_cfg=self.star_mask_cfg,
                varmap_cfg=self.varmap_cfg,
                save_intermediate=self.save_intermediate,
                # Super-resolution forwards
                super_res_factor=self.super_res_factor,
                sr_sigma=self.sr_sigma,
                sr_psf_opt_iters=self.sr_psf_opt_iters,
                sr_psf_opt_lr=self.sr_psf_opt_lr,
                star_mask_ref_path=self.star_mask_ref_path,
            )
        except Exception as e:
            # Preserve the full traceback for diagnosis; the signal only carries str(e).
            logging.exception("MultiFrameDeconvWorkercuDNN: multiframe_deconv failed")
            self.finished.emit(False, f"MF deconvolution failed: {e}", "")
            return

        self.finished.emit(True, "MF deconvolution complete.", out)
        # Best-effort GUI pump: a failure here must not turn a completed
        # deconvolution into a reported failure.
        try:
            _process_gui_events_safely()
        except Exception as e:
            logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")