setiastrosuitepro 1.6.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- setiastro/__init__.py +2 -0
- setiastro/data/SASP_data.fits +0 -0
- setiastro/data/catalogs/List_of_Galaxies_with_Distances_Gly.csv +488 -0
- setiastro/data/catalogs/astrobin_filters.csv +890 -0
- setiastro/data/catalogs/astrobin_filters_page1_local.csv +51 -0
- setiastro/data/catalogs/cali2.csv +63 -0
- setiastro/data/catalogs/cali2color.csv +65 -0
- setiastro/data/catalogs/celestial_catalog - original.csv +16471 -0
- setiastro/data/catalogs/celestial_catalog.csv +24031 -0
- setiastro/data/catalogs/detected_stars.csv +24784 -0
- setiastro/data/catalogs/fits_header_data.csv +46 -0
- setiastro/data/catalogs/test.csv +8 -0
- setiastro/data/catalogs/updated_celestial_catalog.csv +16471 -0
- setiastro/images/Astro_Spikes.png +0 -0
- setiastro/images/Background_startup.jpg +0 -0
- setiastro/images/HRDiagram.png +0 -0
- setiastro/images/LExtract.png +0 -0
- setiastro/images/LInsert.png +0 -0
- setiastro/images/Oxygenation-atm-2.svg.png +0 -0
- setiastro/images/RGB080604.png +0 -0
- setiastro/images/abeicon.png +0 -0
- setiastro/images/aberration.png +0 -0
- setiastro/images/andromedatry.png +0 -0
- setiastro/images/andromedatry_satellited.png +0 -0
- setiastro/images/annotated.png +0 -0
- setiastro/images/aperture.png +0 -0
- setiastro/images/astrosuite.ico +0 -0
- setiastro/images/astrosuite.png +0 -0
- setiastro/images/astrosuitepro.icns +0 -0
- setiastro/images/astrosuitepro.ico +0 -0
- setiastro/images/astrosuitepro.png +0 -0
- setiastro/images/background.png +0 -0
- setiastro/images/background2.png +0 -0
- setiastro/images/benchmark.png +0 -0
- setiastro/images/big_moon_stabilizer_timeline.png +0 -0
- setiastro/images/big_moon_stabilizer_timeline_clean.png +0 -0
- setiastro/images/blaster.png +0 -0
- setiastro/images/blink.png +0 -0
- setiastro/images/clahe.png +0 -0
- setiastro/images/collage.png +0 -0
- setiastro/images/colorwheel.png +0 -0
- setiastro/images/contsub.png +0 -0
- setiastro/images/convo.png +0 -0
- setiastro/images/copyslot.png +0 -0
- setiastro/images/cosmic.png +0 -0
- setiastro/images/cosmicsat.png +0 -0
- setiastro/images/crop1.png +0 -0
- setiastro/images/cropicon.png +0 -0
- setiastro/images/curves.png +0 -0
- setiastro/images/cvs.png +0 -0
- setiastro/images/debayer.png +0 -0
- setiastro/images/denoise_cnn_custom.png +0 -0
- setiastro/images/denoise_cnn_graph.png +0 -0
- setiastro/images/disk.png +0 -0
- setiastro/images/dse.png +0 -0
- setiastro/images/exoicon.png +0 -0
- setiastro/images/eye.png +0 -0
- setiastro/images/fliphorizontal.png +0 -0
- setiastro/images/flipvertical.png +0 -0
- setiastro/images/font.png +0 -0
- setiastro/images/freqsep.png +0 -0
- setiastro/images/functionbundle.png +0 -0
- setiastro/images/graxpert.png +0 -0
- setiastro/images/green.png +0 -0
- setiastro/images/gridicon.png +0 -0
- setiastro/images/halo.png +0 -0
- setiastro/images/hdr.png +0 -0
- setiastro/images/histogram.png +0 -0
- setiastro/images/hubble.png +0 -0
- setiastro/images/imagecombine.png +0 -0
- setiastro/images/invert.png +0 -0
- setiastro/images/isophote.png +0 -0
- setiastro/images/isophote_demo_figure.png +0 -0
- setiastro/images/isophote_demo_image.png +0 -0
- setiastro/images/isophote_demo_model.png +0 -0
- setiastro/images/isophote_demo_residual.png +0 -0
- setiastro/images/jwstpupil.png +0 -0
- setiastro/images/linearfit.png +0 -0
- setiastro/images/livestacking.png +0 -0
- setiastro/images/mask.png +0 -0
- setiastro/images/maskapply.png +0 -0
- setiastro/images/maskcreate.png +0 -0
- setiastro/images/maskremove.png +0 -0
- setiastro/images/morpho.png +0 -0
- setiastro/images/mosaic.png +0 -0
- setiastro/images/multiscale_decomp.png +0 -0
- setiastro/images/nbtorgb.png +0 -0
- setiastro/images/neutral.png +0 -0
- setiastro/images/nuke.png +0 -0
- setiastro/images/openfile.png +0 -0
- setiastro/images/pedestal.png +0 -0
- setiastro/images/pen.png +0 -0
- setiastro/images/pixelmath.png +0 -0
- setiastro/images/platesolve.png +0 -0
- setiastro/images/ppp.png +0 -0
- setiastro/images/pro.png +0 -0
- setiastro/images/project.png +0 -0
- setiastro/images/psf.png +0 -0
- setiastro/images/redo.png +0 -0
- setiastro/images/redoicon.png +0 -0
- setiastro/images/rescale.png +0 -0
- setiastro/images/rgbalign.png +0 -0
- setiastro/images/rgbcombo.png +0 -0
- setiastro/images/rgbextract.png +0 -0
- setiastro/images/rotate180.png +0 -0
- setiastro/images/rotatearbitrary.png +0 -0
- setiastro/images/rotateclockwise.png +0 -0
- setiastro/images/rotatecounterclockwise.png +0 -0
- setiastro/images/satellite.png +0 -0
- setiastro/images/script.png +0 -0
- setiastro/images/selectivecolor.png +0 -0
- setiastro/images/simbad.png +0 -0
- setiastro/images/slot0.png +0 -0
- setiastro/images/slot1.png +0 -0
- setiastro/images/slot2.png +0 -0
- setiastro/images/slot3.png +0 -0
- setiastro/images/slot4.png +0 -0
- setiastro/images/slot5.png +0 -0
- setiastro/images/slot6.png +0 -0
- setiastro/images/slot7.png +0 -0
- setiastro/images/slot8.png +0 -0
- setiastro/images/slot9.png +0 -0
- setiastro/images/spcc.png +0 -0
- setiastro/images/spin_precession_vs_lunar_distance.png +0 -0
- setiastro/images/spinner.gif +0 -0
- setiastro/images/stacking.png +0 -0
- setiastro/images/staradd.png +0 -0
- setiastro/images/staralign.png +0 -0
- setiastro/images/starnet.png +0 -0
- setiastro/images/starregistration.png +0 -0
- setiastro/images/starspike.png +0 -0
- setiastro/images/starstretch.png +0 -0
- setiastro/images/statstretch.png +0 -0
- setiastro/images/supernova.png +0 -0
- setiastro/images/uhs.png +0 -0
- setiastro/images/undoicon.png +0 -0
- setiastro/images/upscale.png +0 -0
- setiastro/images/viewbundle.png +0 -0
- setiastro/images/whitebalance.png +0 -0
- setiastro/images/wimi_icon_256x256.png +0 -0
- setiastro/images/wimilogo.png +0 -0
- setiastro/images/wims.png +0 -0
- setiastro/images/wrench_icon.png +0 -0
- setiastro/images/xisfliberator.png +0 -0
- setiastro/qml/ResourceMonitor.qml +126 -0
- setiastro/saspro/__init__.py +20 -0
- setiastro/saspro/__main__.py +958 -0
- setiastro/saspro/_generated/__init__.py +7 -0
- setiastro/saspro/_generated/build_info.py +3 -0
- setiastro/saspro/abe.py +1346 -0
- setiastro/saspro/abe_preset.py +196 -0
- setiastro/saspro/aberration_ai.py +698 -0
- setiastro/saspro/aberration_ai_preset.py +224 -0
- setiastro/saspro/accel_installer.py +218 -0
- setiastro/saspro/accel_workers.py +30 -0
- setiastro/saspro/add_stars.py +624 -0
- setiastro/saspro/astrobin_exporter.py +1010 -0
- setiastro/saspro/astrospike.py +153 -0
- setiastro/saspro/astrospike_python.py +1841 -0
- setiastro/saspro/autostretch.py +198 -0
- setiastro/saspro/backgroundneutral.py +611 -0
- setiastro/saspro/batch_convert.py +328 -0
- setiastro/saspro/batch_renamer.py +522 -0
- setiastro/saspro/blemish_blaster.py +491 -0
- setiastro/saspro/blink_comparator_pro.py +3149 -0
- setiastro/saspro/bundles.py +61 -0
- setiastro/saspro/bundles_dock.py +114 -0
- setiastro/saspro/cheat_sheet.py +213 -0
- setiastro/saspro/clahe.py +368 -0
- setiastro/saspro/comet_stacking.py +1442 -0
- setiastro/saspro/common_tr.py +107 -0
- setiastro/saspro/config.py +38 -0
- setiastro/saspro/config_bootstrap.py +40 -0
- setiastro/saspro/config_manager.py +316 -0
- setiastro/saspro/continuum_subtract.py +1617 -0
- setiastro/saspro/convo.py +1400 -0
- setiastro/saspro/convo_preset.py +414 -0
- setiastro/saspro/copyastro.py +190 -0
- setiastro/saspro/cosmicclarity.py +1589 -0
- setiastro/saspro/cosmicclarity_preset.py +407 -0
- setiastro/saspro/crop_dialog_pro.py +983 -0
- setiastro/saspro/crop_preset.py +189 -0
- setiastro/saspro/curve_editor_pro.py +2562 -0
- setiastro/saspro/curves_preset.py +375 -0
- setiastro/saspro/debayer.py +673 -0
- setiastro/saspro/debug_utils.py +29 -0
- setiastro/saspro/dnd_mime.py +35 -0
- setiastro/saspro/doc_manager.py +2664 -0
- setiastro/saspro/exoplanet_detector.py +2166 -0
- setiastro/saspro/file_utils.py +284 -0
- setiastro/saspro/fitsmodifier.py +748 -0
- setiastro/saspro/fix_bom.py +32 -0
- setiastro/saspro/free_torch_memory.py +48 -0
- setiastro/saspro/frequency_separation.py +1349 -0
- setiastro/saspro/function_bundle.py +1596 -0
- setiastro/saspro/generate_translations.py +3092 -0
- setiastro/saspro/ghs_dialog_pro.py +663 -0
- setiastro/saspro/ghs_preset.py +284 -0
- setiastro/saspro/graxpert.py +637 -0
- setiastro/saspro/graxpert_preset.py +287 -0
- setiastro/saspro/gui/__init__.py +0 -0
- setiastro/saspro/gui/main_window.py +8792 -0
- setiastro/saspro/gui/mixins/__init__.py +33 -0
- setiastro/saspro/gui/mixins/dock_mixin.py +375 -0
- setiastro/saspro/gui/mixins/file_mixin.py +450 -0
- setiastro/saspro/gui/mixins/geometry_mixin.py +503 -0
- setiastro/saspro/gui/mixins/header_mixin.py +441 -0
- setiastro/saspro/gui/mixins/mask_mixin.py +421 -0
- setiastro/saspro/gui/mixins/menu_mixin.py +390 -0
- setiastro/saspro/gui/mixins/theme_mixin.py +367 -0
- setiastro/saspro/gui/mixins/toolbar_mixin.py +1619 -0
- setiastro/saspro/gui/mixins/update_mixin.py +323 -0
- setiastro/saspro/gui/mixins/view_mixin.py +435 -0
- setiastro/saspro/gui/statistics_dialog.py +47 -0
- setiastro/saspro/halobgon.py +488 -0
- setiastro/saspro/header_viewer.py +448 -0
- setiastro/saspro/headless_utils.py +88 -0
- setiastro/saspro/histogram.py +756 -0
- setiastro/saspro/history_explorer.py +941 -0
- setiastro/saspro/i18n.py +168 -0
- setiastro/saspro/image_combine.py +417 -0
- setiastro/saspro/image_peeker_pro.py +1604 -0
- setiastro/saspro/imageops/__init__.py +37 -0
- setiastro/saspro/imageops/mdi_snap.py +292 -0
- setiastro/saspro/imageops/scnr.py +36 -0
- setiastro/saspro/imageops/starbasedwhitebalance.py +210 -0
- setiastro/saspro/imageops/stretch.py +236 -0
- setiastro/saspro/isophote.py +1182 -0
- setiastro/saspro/layers.py +208 -0
- setiastro/saspro/layers_dock.py +714 -0
- setiastro/saspro/lazy_imports.py +193 -0
- setiastro/saspro/legacy/__init__.py +2 -0
- setiastro/saspro/legacy/image_manager.py +2360 -0
- setiastro/saspro/legacy/numba_utils.py +3676 -0
- setiastro/saspro/legacy/xisf.py +1213 -0
- setiastro/saspro/linear_fit.py +537 -0
- setiastro/saspro/live_stacking.py +1854 -0
- setiastro/saspro/log_bus.py +5 -0
- setiastro/saspro/logging_config.py +460 -0
- setiastro/saspro/luminancerecombine.py +510 -0
- setiastro/saspro/main_helpers.py +201 -0
- setiastro/saspro/mask_creation.py +1086 -0
- setiastro/saspro/masks_core.py +56 -0
- setiastro/saspro/mdi_widgets.py +353 -0
- setiastro/saspro/memory_utils.py +666 -0
- setiastro/saspro/metadata_patcher.py +75 -0
- setiastro/saspro/mfdeconv.py +3909 -0
- setiastro/saspro/mfdeconv_earlystop.py +71 -0
- setiastro/saspro/mfdeconvcudnn.py +3312 -0
- setiastro/saspro/mfdeconvsport.py +2459 -0
- setiastro/saspro/minorbodycatalog.py +567 -0
- setiastro/saspro/morphology.py +407 -0
- setiastro/saspro/multiscale_decomp.py +1747 -0
- setiastro/saspro/nbtorgb_stars.py +541 -0
- setiastro/saspro/numba_utils.py +3145 -0
- setiastro/saspro/numba_warmup.py +141 -0
- setiastro/saspro/ops/__init__.py +9 -0
- setiastro/saspro/ops/command_help_dialog.py +623 -0
- setiastro/saspro/ops/command_runner.py +217 -0
- setiastro/saspro/ops/commands.py +1594 -0
- setiastro/saspro/ops/script_editor.py +1105 -0
- setiastro/saspro/ops/scripts.py +1476 -0
- setiastro/saspro/ops/settings.py +637 -0
- setiastro/saspro/parallel_utils.py +554 -0
- setiastro/saspro/pedestal.py +121 -0
- setiastro/saspro/perfect_palette_picker.py +1105 -0
- setiastro/saspro/pipeline.py +110 -0
- setiastro/saspro/pixelmath.py +1604 -0
- setiastro/saspro/plate_solver.py +2445 -0
- setiastro/saspro/project_io.py +797 -0
- setiastro/saspro/psf_utils.py +136 -0
- setiastro/saspro/psf_viewer.py +549 -0
- setiastro/saspro/pyi_rthook_astroquery.py +95 -0
- setiastro/saspro/remove_green.py +331 -0
- setiastro/saspro/remove_stars.py +1599 -0
- setiastro/saspro/remove_stars_preset.py +446 -0
- setiastro/saspro/resources.py +503 -0
- setiastro/saspro/rgb_combination.py +208 -0
- setiastro/saspro/rgb_extract.py +19 -0
- setiastro/saspro/rgbalign.py +723 -0
- setiastro/saspro/runtime_imports.py +7 -0
- setiastro/saspro/runtime_torch.py +754 -0
- setiastro/saspro/save_options.py +73 -0
- setiastro/saspro/selective_color.py +1611 -0
- setiastro/saspro/sfcc.py +1472 -0
- setiastro/saspro/shortcuts.py +3116 -0
- setiastro/saspro/signature_insert.py +1102 -0
- setiastro/saspro/stacking_suite.py +19066 -0
- setiastro/saspro/star_alignment.py +7380 -0
- setiastro/saspro/star_alignment_preset.py +329 -0
- setiastro/saspro/star_metrics.py +49 -0
- setiastro/saspro/star_spikes.py +765 -0
- setiastro/saspro/star_stretch.py +507 -0
- setiastro/saspro/stat_stretch.py +538 -0
- setiastro/saspro/status_log_dock.py +78 -0
- setiastro/saspro/subwindow.py +3407 -0
- setiastro/saspro/supernovaasteroidhunter.py +1719 -0
- setiastro/saspro/swap_manager.py +134 -0
- setiastro/saspro/torch_backend.py +89 -0
- setiastro/saspro/torch_rejection.py +434 -0
- setiastro/saspro/translations/all_source_strings.json +4726 -0
- setiastro/saspro/translations/ar_translations.py +4096 -0
- setiastro/saspro/translations/de_translations.py +3728 -0
- setiastro/saspro/translations/es_translations.py +4169 -0
- setiastro/saspro/translations/fr_translations.py +4090 -0
- setiastro/saspro/translations/hi_translations.py +3803 -0
- setiastro/saspro/translations/integrate_translations.py +271 -0
- setiastro/saspro/translations/it_translations.py +4728 -0
- setiastro/saspro/translations/ja_translations.py +3834 -0
- setiastro/saspro/translations/pt_translations.py +3847 -0
- setiastro/saspro/translations/ru_translations.py +3082 -0
- setiastro/saspro/translations/saspro_ar.qm +0 -0
- setiastro/saspro/translations/saspro_ar.ts +16019 -0
- setiastro/saspro/translations/saspro_de.qm +0 -0
- setiastro/saspro/translations/saspro_de.ts +14548 -0
- setiastro/saspro/translations/saspro_es.qm +0 -0
- setiastro/saspro/translations/saspro_es.ts +16202 -0
- setiastro/saspro/translations/saspro_fr.qm +0 -0
- setiastro/saspro/translations/saspro_fr.ts +15870 -0
- setiastro/saspro/translations/saspro_hi.qm +0 -0
- setiastro/saspro/translations/saspro_hi.ts +14855 -0
- setiastro/saspro/translations/saspro_it.qm +0 -0
- setiastro/saspro/translations/saspro_it.ts +19046 -0
- setiastro/saspro/translations/saspro_ja.qm +0 -0
- setiastro/saspro/translations/saspro_ja.ts +14980 -0
- setiastro/saspro/translations/saspro_pt.qm +0 -0
- setiastro/saspro/translations/saspro_pt.ts +15024 -0
- setiastro/saspro/translations/saspro_ru.qm +0 -0
- setiastro/saspro/translations/saspro_ru.ts +11835 -0
- setiastro/saspro/translations/saspro_sw.qm +0 -0
- setiastro/saspro/translations/saspro_sw.ts +15237 -0
- setiastro/saspro/translations/saspro_uk.qm +0 -0
- setiastro/saspro/translations/saspro_uk.ts +15248 -0
- setiastro/saspro/translations/saspro_zh.qm +0 -0
- setiastro/saspro/translations/saspro_zh.ts +15289 -0
- setiastro/saspro/translations/sw_translations.py +3897 -0
- setiastro/saspro/translations/uk_translations.py +3929 -0
- setiastro/saspro/translations/zh_translations.py +3910 -0
- setiastro/saspro/versioning.py +77 -0
- setiastro/saspro/view_bundle.py +1558 -0
- setiastro/saspro/wavescale_hdr.py +645 -0
- setiastro/saspro/wavescale_hdr_preset.py +101 -0
- setiastro/saspro/wavescalede.py +680 -0
- setiastro/saspro/wavescalede_preset.py +230 -0
- setiastro/saspro/wcs_update.py +374 -0
- setiastro/saspro/whitebalance.py +513 -0
- setiastro/saspro/widgets/__init__.py +48 -0
- setiastro/saspro/widgets/common_utilities.py +306 -0
- setiastro/saspro/widgets/graphics_views.py +122 -0
- setiastro/saspro/widgets/image_utils.py +518 -0
- setiastro/saspro/widgets/minigame/game.js +991 -0
- setiastro/saspro/widgets/minigame/index.html +53 -0
- setiastro/saspro/widgets/minigame/style.css +241 -0
- setiastro/saspro/widgets/preview_dialogs.py +280 -0
- setiastro/saspro/widgets/resource_monitor.py +263 -0
- setiastro/saspro/widgets/spinboxes.py +290 -0
- setiastro/saspro/widgets/themed_buttons.py +13 -0
- setiastro/saspro/widgets/wavelet_utils.py +331 -0
- setiastro/saspro/wimi.py +7996 -0
- setiastro/saspro/wims.py +578 -0
- setiastro/saspro/window_shelf.py +185 -0
- setiastro/saspro/xisf.py +1213 -0
- setiastrosuitepro-1.6.5.post3.dist-info/METADATA +278 -0
- setiastrosuitepro-1.6.5.post3.dist-info/RECORD +368 -0
- setiastrosuitepro-1.6.5.post3.dist-info/WHEEL +4 -0
- setiastrosuitepro-1.6.5.post3.dist-info/entry_points.txt +6 -0
- setiastrosuitepro-1.6.5.post3.dist-info/licenses/LICENSE +674 -0
- setiastrosuitepro-1.6.5.post3.dist-info/licenses/license.txt +2580 -0
|
@@ -0,0 +1,3909 @@
|
|
|
1
|
+
# pro/mfdeconv.py non-sport normal version
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
import os, sys
|
|
4
|
+
import math
|
|
5
|
+
import re
|
|
6
|
+
import numpy as np
|
|
7
|
+
import time
|
|
8
|
+
from astropy.io import fits
|
|
9
|
+
from PyQt6.QtCore import QObject, pyqtSignal
|
|
10
|
+
from setiastro.saspro.psf_utils import compute_psf_kernel_for_image
|
|
11
|
+
from PyQt6.QtWidgets import QApplication
|
|
12
|
+
from PyQt6.QtCore import QThread
|
|
13
|
+
import contextlib
|
|
14
|
+
from threadpoolctl import threadpool_limits
|
|
15
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
|
|
16
|
+
_USE_PROCESS_POOL_FOR_ASSETS = not getattr(sys, "frozen", False)
|
|
17
|
+
import gc
|
|
18
|
+
try:
|
|
19
|
+
import sep
|
|
20
|
+
except Exception:
|
|
21
|
+
sep = None
|
|
22
|
+
from setiastro.saspro.free_torch_memory import _free_torch_memory
|
|
23
|
+
from setiastro.saspro.mfdeconv_earlystop import EarlyStopper
|
|
24
|
+
torch = None # filled by runtime loader if available
|
|
25
|
+
TORCH_OK = False
|
|
26
|
+
NO_GRAD = contextlib.nullcontext # fallback
|
|
27
|
+
_XISF_READERS = []
|
|
28
|
+
try:
|
|
29
|
+
# e.g. your legacy module
|
|
30
|
+
from setiastro.saspro.legacy import xisf as _legacy_xisf
|
|
31
|
+
if hasattr(_legacy_xisf, "read"):
|
|
32
|
+
_XISF_READERS.append(lambda p: _legacy_xisf.read(p))
|
|
33
|
+
elif hasattr(_legacy_xisf, "open"):
|
|
34
|
+
_XISF_READERS.append(lambda p: _legacy_xisf.open(p)[0])
|
|
35
|
+
except Exception:
|
|
36
|
+
pass
|
|
37
|
+
try:
|
|
38
|
+
# sometimes projects expose a generic load_image
|
|
39
|
+
from setiastro.saspro.legacy.image_manager import load_image as _generic_load_image # adjust if needed
|
|
40
|
+
_XISF_READERS.append(lambda p: _generic_load_image(p)[0])
|
|
41
|
+
except Exception:
|
|
42
|
+
pass
|
|
43
|
+
|
|
44
|
+
from pathlib import Path
|
|
45
|
+
|
|
46
|
+
# at top of file with the other imports
|
|
47
|
+
|
|
48
|
+
from queue import SimpleQueue
|
|
49
|
+
from setiastro.saspro.memory_utils import LRUDict
|
|
50
|
+
|
|
51
|
+
# ── XISF decode cache → memmap on disk ─────────────────────────────────
|
|
52
|
+
import tempfile
|
|
53
|
+
import threading
|
|
54
|
+
import uuid
|
|
55
|
+
import atexit
|
|
56
|
+
_XISF_CACHE = LRUDict(50)
|
|
57
|
+
_XISF_LOCK = threading.Lock()
|
|
58
|
+
_XISF_TMPFILES = []
|
|
59
|
+
|
|
60
|
+
from collections import OrderedDict
|
|
61
|
+
|
|
62
|
+
# ── CHW LRU (float32) built on top of FITS memmap & XISF memmap ────────────────
|
|
63
|
+
class _FrameCHWLRU:
|
|
64
|
+
def __init__(self, capacity=8):
|
|
65
|
+
self.cap = int(max(1, capacity))
|
|
66
|
+
self.od = OrderedDict()
|
|
67
|
+
|
|
68
|
+
def clear(self):
|
|
69
|
+
self.od.clear()
|
|
70
|
+
|
|
71
|
+
def get(self, path, Ht, Wt, color_mode):
|
|
72
|
+
key = (path, Ht, Wt, str(color_mode).lower())
|
|
73
|
+
hit = self.od.get(key)
|
|
74
|
+
if hit is not None:
|
|
75
|
+
self.od.move_to_end(key)
|
|
76
|
+
return hit
|
|
77
|
+
|
|
78
|
+
# Load backing array cheaply (memmap for FITS, cached memmap for XISF)
|
|
79
|
+
ext = os.path.splitext(path)[1].lower()
|
|
80
|
+
if ext == ".xisf":
|
|
81
|
+
a = _xisf_cached_array(path) # float32, HW/HWC/CHW
|
|
82
|
+
else:
|
|
83
|
+
# FITS path: use astropy memmap (no data copy)
|
|
84
|
+
with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
|
|
85
|
+
arr = None
|
|
86
|
+
for h in hdul:
|
|
87
|
+
if getattr(h, "data", None) is not None:
|
|
88
|
+
arr = h.data
|
|
89
|
+
break
|
|
90
|
+
if arr is None:
|
|
91
|
+
raise ValueError(f"No image data in {path}")
|
|
92
|
+
a = np.asarray(arr)
|
|
93
|
+
# dtype normalize once; keep float32
|
|
94
|
+
if a.dtype.kind in "ui":
|
|
95
|
+
a = a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0)
|
|
96
|
+
else:
|
|
97
|
+
a = a.astype(np.float32, copy=False)
|
|
98
|
+
|
|
99
|
+
# Center-crop to (Ht, Wt) and convert to CHW
|
|
100
|
+
a = np.asarray(a) # float32
|
|
101
|
+
a = _center_crop(a, Ht, Wt)
|
|
102
|
+
|
|
103
|
+
# Respect color_mode: “luma” → 1×H×W, “PerChannel” → 3×H×W if RGB present
|
|
104
|
+
cm = str(color_mode).lower()
|
|
105
|
+
if cm == "luma":
|
|
106
|
+
a_chw = _as_chw(_to_luma_local(a)).astype(np.float32, copy=False)
|
|
107
|
+
else:
|
|
108
|
+
a_chw = _as_chw(a).astype(np.float32, copy=False)
|
|
109
|
+
if a_chw.shape[0] == 1 and cm != "luma":
|
|
110
|
+
# still OK (mono data)
|
|
111
|
+
pass
|
|
112
|
+
|
|
113
|
+
# LRU insert
|
|
114
|
+
self.od[key] = a_chw
|
|
115
|
+
if len(self.od) > self.cap:
|
|
116
|
+
self.od.popitem(last=False)
|
|
117
|
+
return a_chw
|
|
118
|
+
|
|
119
|
+
_FRAME_LRU = _FrameCHWLRU(capacity=8) # tune if you like
|
|
120
|
+
|
|
121
|
+
def _clear_all_caches():
|
|
122
|
+
try: _clear_xisf_cache()
|
|
123
|
+
except Exception as e:
|
|
124
|
+
import logging
|
|
125
|
+
logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
|
|
126
|
+
try: _FRAME_LRU.clear()
|
|
127
|
+
except Exception as e:
|
|
128
|
+
import logging
|
|
129
|
+
logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _normalize_to_float32(a: np.ndarray) -> np.ndarray:
|
|
133
|
+
if a.dtype.kind in "ui":
|
|
134
|
+
return (a.astype(np.float32) / (float(np.iinfo(a.dtype).max) or 1.0))
|
|
135
|
+
if a.dtype == np.float32:
|
|
136
|
+
return a
|
|
137
|
+
return a.astype(np.float32, copy=False)
|
|
138
|
+
|
|
139
|
+
def _xisf_cached_array(path: str) -> np.memmap:
    """
    Decode an XISF image once and serve it from a read-only float32 memmap.

    The first call for a given ``path`` loads the image with
    ``_load_image_array``, converts it with ``_normalize_to_float32``, writes
    it to a uniquely named ``xisf_cache_*.mmap`` file in the system temp
    directory and records ``(filename, shape)`` in ``_XISF_CACHE``. Later calls
    just re-open that file read-only. The whole operation runs under
    ``_XISF_LOCK``.

    Returns:
        A read-only ``np.memmap`` (dtype float32) over the cached file.

    Raises:
        ValueError: if the loader hands back ``None``.
        Exception: whatever the loader or the memmap write raises; when the
            write fails, the partly written temp file is removed first.
    """
    with _XISF_LOCK:
        hit = _XISF_CACHE.get(path)
        if hit is not None:
            fn, shape = hit
            return np.memmap(fn, dtype=np.float32, mode="r", shape=shape)

        # Decode once
        arr, _ = _load_image_array(path)
        if arr is None:
            raise ValueError(f"XISF loader returned None for {path}")
        arrf = _normalize_to_float32(np.asarray(arr))

        # Create a temp file-backed memmap
        fn = os.path.join(tempfile.gettempdir(), f"xisf_cache_{uuid.uuid4().hex}.mmap")
        mm = None
        try:
            mm = np.memmap(fn, dtype=np.float32, mode="w+", shape=arrf.shape)
            mm[...] = arrf[...]
            mm.flush()
        except Exception:
            # The file is not yet tracked in _XISF_TMPFILES, so nothing else
            # would ever delete it. Drop our mapping first, then unlink.
            mm = None
            try:
                os.remove(fn)
            except OSError:
                pass  # best effort: file may not exist or may still be locked
            raise
        mm = None  # close writer handle; re-open below as read-only

        _XISF_CACHE[path] = (fn, arrf.shape)
        _XISF_TMPFILES.append(fn)
        return np.memmap(fn, dtype=np.float32, mode="r", shape=arrf.shape)
|
|
168
|
+
|
|
169
|
+
def _clear_xisf_cache():
    """Delete every temp file recorded in ``_XISF_TMPFILES`` and reset the cache.

    Runs under ``_XISF_LOCK``. A file that cannot be removed is logged at debug
    level and skipped; the cache dict and the temp-file list are emptied
    regardless.
    """
    with _XISF_LOCK:
        for tmp_path in _XISF_TMPFILES:
            try:
                os.unlink(tmp_path)
            except Exception as e:
                import logging
                logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
        _XISF_CACHE.clear()
        _XISF_TMPFILES.clear()
|
|
178
|
+
|
|
179
|
+
atexit.register(_clear_xisf_cache)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _is_xisf(path: str) -> bool:
|
|
183
|
+
return os.path.splitext(path)[1].lower() == ".xisf"
|
|
184
|
+
|
|
185
|
+
def _read_xisf_numpy(path: str) -> np.ndarray:
    """Read an XISF file with the first registered reader that succeeds.

    Readers in ``_XISF_READERS`` are tried in registration order. A reader may
    return either the array itself or a tuple whose first item is the array.

    Returns:
        The decoded image wrapped with ``np.asarray``.

    Raises:
        RuntimeError: if no reader is registered, or if every reader raised.
            In the latter case the last reader's exception is chained as
            ``__cause__`` so its traceback is not lost.
    """
    if not _XISF_READERS:
        raise RuntimeError(
            "No XISF readers registered. Ensure one of "
            "legacy.xisf.read/open or legacy.image_manager.load_image is importable."
        )
    last_err = None
    for reader in _XISF_READERS:
        try:
            arr = reader(path)
            if isinstance(arr, tuple):
                arr = arr[0]
            return np.asarray(arr)
        except Exception as e:
            # Remember the failure and fall through to the next reader.
            last_err = e
    raise RuntimeError(f"All XISF readers failed for {path}: {last_err}") from last_err
|
|
201
|
+
|
|
202
|
+
def _fits_open_data(path: str):
    """Open a FITS file and return ``(data, header)`` of its first image HDU.

    The primary HDU is used when it has data; otherwise the first later HDU
    whose ``data`` is not ``None`` is used. The file is opened with
    ``memmap=True`` and ``ignore_missing_simple=True`` (the latter lets files
    whose header lacks the SIMPLE card be opened).

    Raises:
        ValueError: if no HDU in the file carries data. This matches the error
            ``_FrameCHWLRU.get`` raises for the same situation, instead of
            handing callers a 0-d object array wrapping ``None``.
    """
    with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
        # Primary first, then extensions, in file order.
        hdu = next((h for h in hdul if getattr(h, "data", None) is not None), None)
        if hdu is None:
            raise ValueError(f"No image data in {path}")
        data = np.asanyarray(hdu.data)
        hdr = hdu.header
        return data, hdr
|
|
215
|
+
|
|
216
|
+
def _load_image_array(path: str) -> tuple[np.ndarray, "fits.Header | None"]:
    """
    Load an image file and return ``(array, header)``.

    ``.xisf`` paths go through ``_read_xisf_numpy`` and get ``None`` as header;
    every other path goes through ``_fits_open_data``. A 3-D array whose first
    axis has length 1 or 3 (and whose last axis does not) is rearranged so that
    axis comes last. The dtype is left untouched. The result is copied when it
    is not C-contiguous or not writeable.
    """
    if _is_xisf(path):
        arr, hdr = _read_xisf_numpy(path), None
    else:
        arr, hdr = _fits_open_data(path)

    a = np.asarray(arr)

    # Channel-first (C,H,W) -> channel-last (H,W,C)
    leading_channels = (
        a.ndim == 3 and a.shape[0] in (1, 3) and a.shape[-1] not in (1, 3)
    )
    if leading_channels:
        a = np.moveaxis(a, 0, -1)

    # Callers get a C-contiguous, writeable array; copy only when required.
    if not (a.flags.c_contiguous and a.flags.writeable):
        a = np.array(a, copy=True)
    return a, hdr
|
|
235
|
+
|
|
236
|
+
def _probe_hw(path: str) -> tuple[int, int, int | None]:
    """
    Report ``(H, W, C_or_None)`` for the image at ``path``.

    The image is loaded through ``_load_image_array``. 2-D data yields
    ``C = None``. For 3-D data the channel count is reported only when it is
    1 or 3, otherwise ``None``.

    Raises:
        ValueError: if the array is neither 2-D nor 3-D.
    """
    a, _ = _load_image_array(path)

    if a.ndim not in (2, 3):
        raise ValueError(f"Unsupported ndim={a.ndim} for {path}")

    if a.ndim == 2:
        rows, cols = a.shape
        return rows, cols, None

    # Channel-first layout: bring the channel axis to the end before reading sizes.
    if a.shape[-1] not in (1, 3) and a.shape[0] in (1, 3):
        a = np.moveaxis(a, 0, -1)
    h, w, c = a.shape
    return h, w, (c if c in (1, 3) else None)
|
|
251
|
+
|
|
252
|
+
def _common_hw_from_paths(paths: list[str]) -> tuple[int, int]:
|
|
253
|
+
Hs, Ws = [], []
|
|
254
|
+
for p in paths:
|
|
255
|
+
h, w, _ = _probe_hw(p)
|
|
256
|
+
h = int(h); w = int(w)
|
|
257
|
+
if h > 0 and w > 0:
|
|
258
|
+
Hs.append(h); Ws.append(w)
|
|
259
|
+
|
|
260
|
+
if not Hs:
|
|
261
|
+
raise ValueError("Could not determine any valid frame sizes.")
|
|
262
|
+
Ht = min(Hs); Wt = min(Ws)
|
|
263
|
+
if Ht < 8 or Wt < 8:
|
|
264
|
+
raise ValueError(f"Intersection too small: {Ht}x{Wt}")
|
|
265
|
+
return Ht, Wt
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def _to_chw_float32(img: np.ndarray, color_mode: str) -> np.ndarray:
|
|
269
|
+
"""
|
|
270
|
+
Convert to CHW float32:
|
|
271
|
+
- mono → (1,H,W)
|
|
272
|
+
- RGB → (3,H,W) if 'PerChannel'; (1,H,W) if 'luma'
|
|
273
|
+
"""
|
|
274
|
+
x = np.asarray(img)
|
|
275
|
+
if x.ndim == 2:
|
|
276
|
+
y = x.astype(np.float32, copy=False)[None, ...] # (1,H,W)
|
|
277
|
+
return y
|
|
278
|
+
if x.ndim == 3:
|
|
279
|
+
# color-last (H,W,C) expected
|
|
280
|
+
if x.shape[-1] == 1:
|
|
281
|
+
return x[..., 0].astype(np.float32, copy=False)[None, ...]
|
|
282
|
+
if x.shape[-1] == 3:
|
|
283
|
+
if str(color_mode).lower() in ("perchannel", "per_channel", "perchannelrgb"):
|
|
284
|
+
r, g, b = x[..., 0], x[..., 1], x[..., 2]
|
|
285
|
+
return np.stack([r.astype(np.float32, copy=False),
|
|
286
|
+
g.astype(np.float32, copy=False),
|
|
287
|
+
b.astype(np.float32, copy=False)], axis=0)
|
|
288
|
+
# luma
|
|
289
|
+
r, g, b = x[..., 0].astype(np.float32, copy=False), x[..., 1].astype(np.float32, copy=False), x[..., 2].astype(np.float32, copy=False)
|
|
290
|
+
L = 0.2126*r + 0.7152*g + 0.0722*b
|
|
291
|
+
return L[None, ...]
|
|
292
|
+
# rare mono-3D
|
|
293
|
+
if x.shape[0] in (1, 3) and x.shape[-1] not in (1, 3):
|
|
294
|
+
x = np.moveaxis(x, 0, -1)
|
|
295
|
+
return _to_chw_float32(x, color_mode)
|
|
296
|
+
raise ValueError(f"Unsupported image shape {x.shape}")
|
|
297
|
+
|
|
298
|
+
def _center_crop_hw(img: np.ndarray, Ht: int, Wt: int) -> np.ndarray:
|
|
299
|
+
h, w = img.shape[:2]
|
|
300
|
+
y0 = max(0, (h - Ht)//2); x0 = max(0, (w - Wt)//2)
|
|
301
|
+
return img[y0:y0+Ht, x0:x0+Wt, ...].copy() if (Ht < h or Wt < w) else img
|
|
302
|
+
|
|
303
|
+
def _stack_loader_memmap(paths: list[str], Ht: int, Wt: int, color_mode: str):
|
|
304
|
+
"""
|
|
305
|
+
Drop-in replacement of the old FITS-only helper.
|
|
306
|
+
Returns (ys, hdrs):
|
|
307
|
+
ys : list of CHW float32 arrays cropped to (Ht,Wt)
|
|
308
|
+
hdrs : list of fits.Header or None (XISF)
|
|
309
|
+
"""
|
|
310
|
+
ys, hdrs = [], []
|
|
311
|
+
for p in paths:
|
|
312
|
+
arr, hdr = _load_image_array(p)
|
|
313
|
+
arr = _center_crop_hw(arr, Ht, Wt)
|
|
314
|
+
# normalize integer data to [0,1] like the rest of your code
|
|
315
|
+
if arr.dtype.kind in "ui":
|
|
316
|
+
mx = np.float32(np.iinfo(arr.dtype).max)
|
|
317
|
+
arr = arr.astype(np.float32, copy=False) / (mx if mx > 0 else 1.0)
|
|
318
|
+
elif arr.dtype.kind == "f":
|
|
319
|
+
arr = arr.astype(np.float32, copy=False)
|
|
320
|
+
else:
|
|
321
|
+
arr = arr.astype(np.float32, copy=False)
|
|
322
|
+
|
|
323
|
+
y = _to_chw_float32(arr, color_mode)
|
|
324
|
+
if (not y.flags.c_contiguous) or (not y.flags.writeable):
|
|
325
|
+
y = np.ascontiguousarray(y.astype(np.float32, copy=True))
|
|
326
|
+
ys.append(y)
|
|
327
|
+
hdrs.append(hdr if isinstance(hdr, fits.Header) else None)
|
|
328
|
+
return ys, hdrs
|
|
329
|
+
|
|
330
|
+
def _safe_primary_header(path: str) -> fits.Header:
    """Return a FITS header for ``path`` without ever raising on read failure.

    For ``.xisf`` paths a minimal synthetic header is built (SIMPLE, BITPIX=-32,
    NAXIS=2). For any other path the header of extension 0 is read; if that
    fails for any reason an empty header is returned instead.
    """
    if not _is_xisf(path):
        try:
            return fits.getheader(path, ext=0, ignore_missing_simple=True)
        except Exception:
            return fits.Header()

    # best-effort synthetic header
    synthetic = fits.Header()
    synthetic["SIMPLE"] = (True, "created by MFDeconv")
    synthetic["BITPIX"] = -32
    synthetic["NAXIS"] = 2
    return synthetic
|
|
342
|
+
|
|
343
|
+
# --- CUDA busy/unavailable detector (runtime fallback helper) ---
|
|
344
|
+
def _is_cuda_busy_error(e: Exception) -> bool:
|
|
345
|
+
"""
|
|
346
|
+
Return True for 'device busy/unavailable' style CUDA errors that can pop up
|
|
347
|
+
mid-run on shared Linux systems. We *exclude* OOM here (handled elsewhere).
|
|
348
|
+
"""
|
|
349
|
+
em = str(e).lower()
|
|
350
|
+
if "out of memory" in em:
|
|
351
|
+
return False # OOM handled by your batch size backoff
|
|
352
|
+
return (
|
|
353
|
+
("cuda" in em and ("busy" in em or "unavailable" in em))
|
|
354
|
+
or ("device-side" in em and "assert" in em)
|
|
355
|
+
or ("driver shutting down" in em)
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def _compute_frame_assets(i, arr, hdr, *, make_masks, make_varmaps,
                          star_mask_cfg, varmap_cfg, status_sink=lambda s: None):
    """
    Worker function: compute PSF and optional star mask / varmap for one frame.

    Parameters
    ----------
    i : frame index, echoed back in the result and used in the summary log line.
    arr : frame pixels (2D, or a colour layout accepted by `_to_luma_local`).
    hdr : header object queried for an FWHM keyword.
    make_masks, make_varmaps : whether to build the star mask / variance map.
    star_mask_cfg, varmap_cfg : optional dicts overriding the module defaults.
    status_sink : accepted for API compatibility; not used here (messages are
        collected in the returned log list instead).

    Returns (index, psf, mask_or_None, var_or_None, log_lines)
    """
    logs = []

    def log(s):
        logs.append(s)

    # --- PSF sizing by FWHM ---
    # Prefer the header value. The image-based estimate runs a SEP background
    # fit plus source extraction over the frame, so only pay for it when the
    # header does not provide a usable FWHM.
    f_whm = _estimate_fwhm_from_header(hdr)
    if not np.isfinite(f_whm):
        f_whm = _estimate_fwhm_from_image(arr)
    if not np.isfinite(f_whm) or f_whm <= 0:
        f_whm = 2.5
    k_auto = _auto_ksize_from_fwhm(f_whm)

    # --- Star-derived PSF with retries (dynamic det_sigma ladder) ---
    psf = None

    # Kernel sizes to try, largest/most specific first (duplicates are skipped below).
    k_ladder = [k_auto, max(k_auto - 4, 11), 21, 17, 15, 13, 11]

    # Start high to avoid detecting 10k stars; step down only if needed.
    sigma_ladder = [50.0, 25.0, 12.0, 6.0]

    tried = set()
    for det_sigma in sigma_ladder:
        for k_try in k_ladder:
            if (det_sigma, k_try) in tried:
                continue
            tried.add((det_sigma, k_try))
            try:
                out = compute_psf_kernel_for_image(arr, ksize=k_try, det_sigma=det_sigma, max_stars=80)
                # The helper may return either the kernel or a tuple starting with it.
                psf_try = out[0] if (isinstance(out, tuple) and len(out) >= 1) else out
                if psf_try is not None:
                    psf = psf_try
                    break
            except Exception:
                # Best-effort: a failed attempt just moves on to the next rung.
                psf = None
        if psf is not None:
            break

    # No star-derived PSF at any rung: fall back to an analytic Gaussian.
    if psf is None:
        psf = _gaussian_psf(f_whm, ksize=k_auto)

    psf = _soften_psf(_normalize_psf(psf.astype(np.float32, copy=False)), sigma_px=0.25)

    mask = None
    var = None

    if make_masks or make_varmaps:
        # one background per frame (reused by both)
        luma = _to_luma_local(arr)
        vmc = varmap_cfg or {}
        sky_map, rms_map, err_scalar = _sep_background_precompute(
            luma, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
        )

        if make_masks:
            smc = star_mask_cfg or {}
            mask = _star_mask_from_precomputed(
                luma, sky_map, err_scalar,
                thresh_sigma=smc.get("thresh_sigma", THRESHOLD_SIGMA),
                max_objs=smc.get("max_objs", STAR_MASK_MAXOBJS),
                grow_px=smc.get("grow_px", GROW_PX),
                ellipse_scale=smc.get("ellipse_scale", ELLIPSE_SCALE),
                soft_sigma=smc.get("soft_sigma", SOFT_SIGMA),
                max_radius_px=smc.get("max_radius_px", MAX_STAR_RADIUS),
                keep_floor=smc.get("keep_floor", KEEP_FLOOR),
                max_side=smc.get("max_side", STAR_MASK_MAXSIDE),
                status_cb=log,
            )

        if make_varmaps:
            var = _variance_map_from_precomputed(
                luma, sky_map, rms_map, hdr,
                smooth_sigma=vmc.get("smooth_sigma", 1.0),
                floor=vmc.get("floor", 1e-8),
                status_cb=log,
            )

    # small per-frame summary (placed first so it precedes the helper logs)
    fwhm_est = _psf_fwhm_px(psf)
    logs.insert(0, f"MFDeconv: PSF{i}: ksize={psf.shape[0]} | FWHM≈{fwhm_est:.2f}px")

    return i, psf, mask, var, logs
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def _compute_one_worker(args):
    """
    Top-level (hence picklable) entry point for ProcessPoolExecutor.

    `args` is the tuple
    (i, path, make_masks_in_worker, make_varmaps, star_mask_cfg, varmap_cfg).
    The frame is loaded inside the worker from `path`, then handed to
    `_compute_frame_assets`.

    Returns (i, psf, mask, var, logs).
    """
    i, path, make_masks_in_worker, make_varmaps, star_mask_cfg, varmap_cfg = args

    # avoid BLAS/OMP storm inside each process
    with threadpool_limits(limits=1):
        frame, header = _load_image_array(path)  # FITS or XISF
        frame = np.asarray(frame, dtype=np.float32, order="C")

        # Drop a trailing singleton channel axis: (H,W,1) -> (H,W)
        trailing_singleton = frame.ndim == 3 and frame.shape[-1] == 1
        if trailing_singleton:
            frame = np.squeeze(frame, axis=-1)

        # Loader gave us something that is not a FITS header (e.g. XISF):
        # synthesize a FITS-like one.
        if not isinstance(header, fits.Header):
            header = _safe_primary_header(path)

        return _compute_frame_assets(
            i,
            frame,
            header,
            make_masks=bool(make_masks_in_worker),
            make_varmaps=bool(make_varmaps),
            star_mask_cfg=star_mask_cfg,
            varmap_cfg=varmap_cfg,
        )
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def _build_psf_and_assets(
    paths,                      # list[str]
    make_masks=False,
    make_varmaps=False,
    status_cb=lambda s: None,
    save_dir: str | None = None,
    star_mask_cfg: dict | None = None,
    varmap_cfg: dict | None = None,
    max_workers: int | None = None,
    star_mask_ref_path: str | None = None,  # build one mask from this frame if provided
    # NEW (passed from multiframe_deconv so we don’t re-probe/convert):
    Ht: int | None = None,
    Wt: int | None = None,
    color_mode: str = "luma",
):
    """
    Parallel PSF + (optional) star mask + variance map per frame.

    Changes from the original:
    • Reuses the decoded frame cache (_FRAME_LRU) for FITS/XISF so we never re-decode.
    • Automatically switches to threads for XISF (so memmaps are shared across workers).
    • Builds a single reference star mask (if requested) from the cached frame and
      center-pads/crops it for all frames (no extra I/O).
    • Preserves return order and streams worker logs back to the UI.

    Parameters
    ----------
    paths : frame file paths; results are returned in this order.
    make_masks / make_varmaps : enable star-mask / variance-map generation.
    status_cb : callable receiving progress and log strings (called from this
        function's own thread only; worker logs are queued and flushed here).
    save_dir : if given, created if needed and each PSF is written there as
        ``psf_###.fit`` (1-based index).
    star_mask_cfg / varmap_cfg : optional overrides for the module defaults.
    max_workers : pool size; defaults to ``min(8, cpu_count)`` (at least 1).
    star_mask_ref_path : when set together with ``make_masks``, one mask is
        built from this frame and reused for every frame; on failure the code
        falls back to per-frame masks.
    Ht, Wt : target frame size; derived via `_common_hw_from_paths` if either
        is missing.
    color_mode : forwarded to the frame cache on the thread path.

    Returns
    -------
    (psfs, masks, vars_) : ``psfs`` is a list of length ``len(paths)``;
    ``masks`` / ``vars_`` are lists of the same length, or None when the
    corresponding ``make_*`` flag is false.

    Any exception raised inside a worker propagates out of this function via
    ``fut.result()``.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    n = len(paths)

    # Resolve target intersection size if caller didn't pass it
    if Ht is None or Wt is None:
        Ht, Wt = _common_hw_from_paths(paths)

    # Sensible default worker count (cap at 8)
    if max_workers is None:
        try:
            hw = os.cpu_count() or 4
        except Exception:
            hw = 4
        max_workers = max(1, min(8, hw))

    # Decide executor: for any XISF, prefer threads so the memmap/cache is shared
    any_xisf = any(os.path.splitext(p)[1].lower() == ".xisf" for p in paths)
    use_proc_pool = (not any_xisf) and _USE_PROCESS_POOL_FOR_ASSETS
    Executor = ProcessPoolExecutor if use_proc_pool else ThreadPoolExecutor
    pool_kind = "process" if use_proc_pool else "thread"
    status_cb(f"MFDeconv: measuring PSFs/masks/varmaps with {max_workers} {pool_kind}s…")

    # ---- helper: pad-or-crop a 2D array to (Ht,Wt), centered ----
    # Note: the helper's Ht/Wt parameters shadow the outer ones on purpose;
    # padding uses `fill` (1.0 by default) outside the source array.
    def _center_pad_or_crop_2d(a2d: np.ndarray, Ht: int, Wt: int, fill: float = 1.0) -> np.ndarray:
        a2d = np.asarray(a2d, dtype=np.float32)
        H, W = int(a2d.shape[0]), int(a2d.shape[1])
        # crop first if bigger
        y0 = max(0, (H - Ht) // 2); x0 = max(0, (W - Wt) // 2)
        y1 = min(H, y0 + Ht); x1 = min(W, x0 + Wt)
        cropped = a2d[y0:y1, x0:x1]
        ch, cw = cropped.shape
        if ch == Ht and cw == Wt:
            return np.ascontiguousarray(cropped, dtype=np.float32)
        # pad if smaller
        out = np.full((Ht, Wt), float(fill), dtype=np.float32)
        oy = (Ht - ch) // 2; ox = (Wt - cw) // 2
        out[oy:oy+ch, ox:ox+cw] = cropped
        return out

    # ---- optional: build one mask from the reference frame and reuse ----
    base_ref_mask = None
    if make_masks and star_mask_ref_path:
        try:
            status_cb(f"Star mask: using reference frame for all masks → {os.path.basename(star_mask_ref_path)}")
            # Pull from the shared frame cache as luma on (Ht,Wt)
            ref_chw = _FRAME_LRU.get(star_mask_ref_path, Ht, Wt, "luma")  # (1,H,W) or (H,W)
            L = ref_chw[0] if (ref_chw.ndim == 3) else ref_chw            # 2D float32

            # Background tile sizes come from the varmap config, even for the mask.
            vmc = (varmap_cfg or {})
            sky_map, rms_map, err_scalar = _sep_background_precompute(
                L, bw=int(vmc.get("bw", 64)), bh=int(vmc.get("bh", 64))
            )
            smc = (star_mask_cfg or {})
            base_ref_mask = _star_mask_from_precomputed(
                L, sky_map, err_scalar,
                thresh_sigma = smc.get("thresh_sigma", THRESHOLD_SIGMA),
                max_objs     = smc.get("max_objs", STAR_MASK_MAXOBJS),
                grow_px      = smc.get("grow_px", GROW_PX),
                ellipse_scale= smc.get("ellipse_scale", ELLIPSE_SCALE),
                soft_sigma   = smc.get("soft_sigma", SOFT_SIGMA),
                max_radius_px= smc.get("max_radius_px", MAX_STAR_RADIUS),
                keep_floor   = smc.get("keep_floor", KEEP_FLOOR),
                max_side     = smc.get("max_side", STAR_MASK_MAXSIDE),
                status_cb    = status_cb,
            )
        except Exception as e:
            status_cb(f"⚠️ Star mask (reference) failed: {e}. Falling back to per-frame masks.")
            base_ref_mask = None

    # for GUI safety, queue logs from workers and flush in the main thread
    log_queue: SimpleQueue = SimpleQueue()

    def enqueue_logs(lines):
        for s in lines:
            log_queue.put(s)

    psfs  = [None] * n
    masks = ([None] * n) if make_masks else None
    vars_ = ([None] * n) if make_varmaps else None
    # Workers only build masks when no shared reference mask is available.
    make_masks_in_worker = bool(make_masks and (base_ref_mask is None))

    # --- thread worker: get frame from cache and compute assets ---
    def _compute_one(i: int, path: str):
        # avoid heavy BLAS oversubscription inside each worker
        with threadpool_limits(limits=1):
            # Pull frame from cache honoring color_mode & target (Ht,Wt)
            img_chw = _FRAME_LRU.get(path, Ht, Wt, color_mode)  # (C,H,W) float32
            # For PSF/mask/varmap we operate on a 2D plane (luma/mono)
            arr2d = img_chw[0] if (img_chw.ndim == 3) else img_chw  # (H,W) float32

            # Header: synthesize a safe FITS-like header (works for XISF too)
            try:
                hdr = _safe_primary_header(path)
            except Exception:
                hdr = fits.Header()

            return _compute_frame_assets(
                i, arr2d, hdr,
                make_masks=bool(make_masks_in_worker),
                make_varmaps=bool(make_varmaps),
                star_mask_cfg=star_mask_cfg,
                varmap_cfg=varmap_cfg,
            )

    # --- submit jobs ---
    with Executor(max_workers=max_workers) as ex:
        futs = []
        for i, p in enumerate(paths, start=1):
            status_cb(f"MFDeconv: measuring PSF {i}/{n} …")
            if use_proc_pool:
                # Process-safe path: worker re-loads inside the subprocess
                # NOTE(review): _compute_one_worker receives neither (Ht, Wt) nor
                # color_mode, so per-frame masks/varmaps from this path appear to be
                # at the frame's native size — confirm the caller handles that.
                futs.append(ex.submit(
                    _compute_one_worker,
                    (i, p, bool(make_masks_in_worker), bool(make_varmaps), star_mask_cfg, varmap_cfg)
                ))
            else:
                # Thread path: hits the shared cache (fast path for XISF/FITS)
                futs.append(ex.submit(_compute_one, i, p))

        done_cnt = 0
        for fut in as_completed(futs):
            # Results arrive in completion order; the 1-based index restores input order.
            i, psf, m, v, logs = fut.result()
            idx = i - 1
            psfs[idx] = psf
            if masks is not None:
                masks[idx] = m
            if vars_ is not None:
                vars_[idx] = v
            enqueue_logs(logs)

            # Flush queued worker logs every 4 completed frames and at the end.
            done_cnt += 1
            if (done_cnt % 4) == 0 or done_cnt == n:
                while not log_queue.empty():
                    try:
                        status_cb(log_queue.get_nowait())
                    except Exception:
                        break

    # If we built a single reference mask, apply it to every frame (center pad/crop)
    if base_ref_mask is not None and masks is not None:
        for idx in range(n):
            masks[idx] = _center_pad_or_crop_2d(base_ref_mask, int(Ht), int(Wt), fill=1.0)

    # final flush of any remaining logs
    while not log_queue.empty():
        try:
            status_cb(log_queue.get_nowait())
        except Exception:
            break

    # save PSFs if requested
    if save_dir:
        for i, k in enumerate(psfs, start=1):
            if k is not None:
                fits.PrimaryHDU(k.astype(np.float32, copy=False)).writeto(
                    os.path.join(save_dir, f"psf_{i:03d}.fit"), overwrite=True
                )

    return psfs, masks, vars_
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
_ALLOWED = re.compile(r"[^A-Za-z0-9_-]+")
|
|
662
|
+
|
|
663
|
+
# known FITS-style multi-extensions (rightmost-first match)
|
|
664
|
+
_KNOWN_EXTS = [
|
|
665
|
+
".fits.fz", ".fit.fz", ".fits.gz", ".fit.gz",
|
|
666
|
+
".fz", ".gz",
|
|
667
|
+
".fits", ".fit"
|
|
668
|
+
]
|
|
669
|
+
|
|
670
|
+
def _sanitize_token(s: str) -> str:
|
|
671
|
+
s = _ALLOWED.sub("_", s)
|
|
672
|
+
s = re.sub(r"_+", "_", s).strip("_")
|
|
673
|
+
return s
|
|
674
|
+
|
|
675
|
+
def _split_known_exts(p: Path) -> tuple[str, str]:
|
|
676
|
+
"""
|
|
677
|
+
Return (name_body, full_ext) where full_ext is a REAL extension block
|
|
678
|
+
(e.g. '.fits.fz'). Any junk like '.0s (1310x880)_MFDeconv' stays in body.
|
|
679
|
+
"""
|
|
680
|
+
name = p.name
|
|
681
|
+
for ext in _KNOWN_EXTS:
|
|
682
|
+
if name.lower().endswith(ext):
|
|
683
|
+
body = name[:-len(ext)]
|
|
684
|
+
return body, ext
|
|
685
|
+
# fallback: single suffix
|
|
686
|
+
return p.stem, "".join(p.suffixes)
|
|
687
|
+
|
|
688
|
+
_SIZE_RE = re.compile(r"\(?\s*(\d{2,5})x(\d{2,5})\s*\)?", re.IGNORECASE)
|
|
689
|
+
_EXP_RE = re.compile(r"(?<![A-Za-z0-9])(\d+(?:\.\d+)?)\s*s\b", re.IGNORECASE)
|
|
690
|
+
_RX_RE = re.compile(r"(?<![A-Za-z0-9])(\d+)x\b", re.IGNORECASE)
|
|
691
|
+
|
|
692
|
+
def _extract_size(body: str) -> str | None:
|
|
693
|
+
m = _SIZE_RE.search(body)
|
|
694
|
+
return f"{m.group(1)}x{m.group(2)}" if m else None
|
|
695
|
+
|
|
696
|
+
def _extract_exposure_secs(body: str) -> str | None:
|
|
697
|
+
m = _EXP_RE.search(body)
|
|
698
|
+
if not m:
|
|
699
|
+
return None
|
|
700
|
+
secs = int(round(float(m.group(1))))
|
|
701
|
+
return f"{secs}s"
|
|
702
|
+
|
|
703
|
+
def _strip_metadata_from_base(body: str) -> str:
    """
    Strip size / exposure / scale tokens and a trailing 'MFDeconv' marker
    from a file-name body and return a sanitized base name.

    Returns "output" if nothing is left after cleaning.
    """
    # normalize common separators first
    cleaned = body.replace(" - ", "_")

    # remove known trailing marker '_MFDeconv'
    cleaned = re.sub(r"(?i)[\s_]+MFDeconv$", "", cleaned)

    # remove a trailing parenthetical copy counter e.g. '(1)'
    cleaned = re.sub(r"\(\s*\d+\s*\)$", "", cleaned)

    # remove, in order: size tokens, exposure tokens, '#x' tokens (anywhere)
    for pattern in (_SIZE_RE, _EXP_RE, _RX_RE):
        cleaned = pattern.sub("", cleaned)

    # whitespace -> underscores, then sanitize
    cleaned = _sanitize_token(re.sub(r"[\s]+", "_", cleaned))
    return cleaned if cleaned else "output"
|
|
728
|
+
|
|
729
|
+
def _canonical_out_name_prefix(base: str, r: int, size: str | None,
                               exposure_secs: str | None, tag: str = "MFDeconv") -> str:
    """
    Assemble '<tag>_<base>[_<size>][_<exposure>][_<r>x]' from sanitized parts.

    `size` and `exposure_secs` are included only when truthy; the '<r>x'
    suffix is added only when `r` is greater than 1.
    """
    pieces = [_sanitize_token(tag), _sanitize_token(base)]
    pieces.extend(_sanitize_token(extra) for extra in (size, exposure_secs) if extra)
    if int(max(1, r)) > 1:
        pieces.append(f"{int(r)}x")
    return "_".join(pieces)
|
|
739
|
+
|
|
740
|
+
def _sr_out_path(out_path: str, r: int) -> Path:
    """
    Build: MFDeconv_<base>[_<HxW>][_<secs>s][_2x], preserving REAL extensions.

    Size and exposure are harvested from the whole name body (not Path.stem)
    before the body is cleaned into the base name. The result stays in the
    same directory as `out_path`.
    """
    target = Path(out_path)
    body, real_ext = _split_known_exts(target)

    stem = _canonical_out_name_prefix(
        _strip_metadata_from_base(body),
        r=int(max(1, r)),
        size=_extract_size(body),
        exposure_secs=_extract_exposure_secs(body),
        tag="MFDeconv",
    )
    return target.with_name(f"{stem}{real_ext}")
|
|
756
|
+
|
|
757
|
+
def _nonclobber_path(path: str) -> str:
    """
    Return `path` if it is free, otherwise the first free sibling named
    '<body>_vN<ext>' (no spaces/parentheses).

    Numbering starts at 2, or at one past an existing '_vN' suffix already
    present on the name body.
    """
    target = Path(path)
    if not target.exists():
        return str(target)

    # keep the true extension(s)
    body, real_ext = _split_known_exts(target)

    # if already has _vN, bump it
    versioned = re.search(r"(.*)_v(\d+)$", body)
    if versioned is None:
        stem, version = body, 2
    else:
        stem, version = versioned.group(1), int(versioned.group(2)) + 1

    candidate = target.with_name(f"{stem}_v{version}{real_ext}")
    while candidate.exists():
        version += 1
        candidate = target.with_name(f"{stem}_v{version}{real_ext}")
    return str(candidate)
|
|
780
|
+
|
|
781
|
+
def _iter_folder(basefile: str) -> str:
|
|
782
|
+
d, fname = os.path.split(basefile)
|
|
783
|
+
root, ext = os.path.splitext(fname)
|
|
784
|
+
tgt = os.path.join(d, f"{root}.iters")
|
|
785
|
+
if not os.path.exists(tgt):
|
|
786
|
+
try:
|
|
787
|
+
os.makedirs(tgt, exist_ok=True)
|
|
788
|
+
except Exception:
|
|
789
|
+
# last resort: suffix (n)
|
|
790
|
+
n = 1
|
|
791
|
+
while True:
|
|
792
|
+
cand = os.path.join(d, f"{root}.iters ({n})")
|
|
793
|
+
try:
|
|
794
|
+
os.makedirs(cand, exist_ok=True)
|
|
795
|
+
return cand
|
|
796
|
+
except Exception:
|
|
797
|
+
n += 1
|
|
798
|
+
return tgt
|
|
799
|
+
|
|
800
|
+
def _save_iter_image(arr, hdr_base, folder, tag, color_mode):
    """
    Write one intermediate image to '<folder>/<tag>.fit' and return that path.

    arr: numpy array (H,W) or (C,H,W) float32; a channels-last array with
         1 or 3 channels is moved to channels-first, and a single-channel
         (1,H,W) array is written as 2D.
    hdr_base: copied into the output header when it is a fits.Header,
         otherwise a fresh header is used.
    tag: 'seed' or 'iter_###' (also stored in MF_PART).
    color_mode: stored in MF_COLOR.

    An existing file with the same name is overwritten.
    """
    if arr.ndim == 3:
        channels_last = arr.shape[0] not in (1, 3) and arr.shape[-1] in (1, 3)
        if channels_last:
            arr = np.moveaxis(arr, -1, 0)
        if arr.shape[0] == 1:
            arr = arr[0]

    if isinstance(hdr_base, fits.Header):
        header = fits.Header(hdr_base)
    else:
        header = fits.Header()
    header['MF_PART'] = (str(tag), 'MFDeconv intermediate (seed/iter)')
    header['MF_COLOR'] = (str(color_mode), 'Color mode used')

    out_file = os.path.join(folder, f"{tag}.fit")
    # overwrite allowed inside the dedicated folder
    hdu = fits.PrimaryHDU(data=arr.astype(np.float32, copy=False), header=header)
    hdu.writeto(out_file, overwrite=True)
    return out_file
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
def _process_gui_events_safely():
    """
    Pump pending Qt events, but only when called from the application's
    own thread and only if a QApplication instance exists. Otherwise do nothing.
    """
    qt_app = QApplication.instance()
    if not qt_app:
        return
    if QThread.currentThread() is not qt_app.thread():
        return
    qt_app.processEvents()
|
|
823
|
+
|
|
824
|
+
# Small positive constant used by the PSF helpers in this module: added to
# normalization denominators and used as a "sum is effectively zero" threshold.
EPS = 1e-6
|
|
825
|
+
|
|
826
|
+
# -----------------------------
|
|
827
|
+
# Helpers: image prep / shapes
|
|
828
|
+
# -----------------------------
|
|
829
|
+
|
|
830
|
+
# new: lightweight loader that yields one frame at a time
|
|
831
|
+
def _iter_fits(paths):
    """
    Lazily yield (array, header) for the primary HDU of each FITS path.

    Files are opened without memory-mapping, and both the float32 pixel data
    and the header are copied before the file is closed. A trailing
    singleton axis (H,W,1) is squeezed to (H,W).
    """
    for fits_path in paths:
        with fits.open(fits_path, memmap=False) as hdul:
            primary = hdul[0]
            pixels = np.array(primary.data, dtype=np.float32, copy=True)
            header = primary.header.copy()
        if pixels.ndim == 3 and pixels.shape[-1] == 1:
            pixels = np.squeeze(pixels, axis=-1)
        yield pixels, header
|
|
839
|
+
|
|
840
|
+
def _to_luma_local(a: np.ndarray) -> np.ndarray:
|
|
841
|
+
a = np.asarray(a, dtype=np.float32)
|
|
842
|
+
if a.ndim == 2:
|
|
843
|
+
return a
|
|
844
|
+
# (H,W,3) or (3,H,W)
|
|
845
|
+
if a.ndim == 3 and a.shape[-1] == 3:
|
|
846
|
+
try:
|
|
847
|
+
import cv2
|
|
848
|
+
return cv2.cvtColor(a, cv2.COLOR_RGB2GRAY).astype(np.float32, copy=False)
|
|
849
|
+
except Exception:
|
|
850
|
+
pass
|
|
851
|
+
r, g, b = a[..., 0], a[..., 1], a[..., 2]
|
|
852
|
+
return (0.2126*r + 0.7152*g + 0.0722*b).astype(np.float32, copy=False)
|
|
853
|
+
if a.ndim == 3 and a.shape[0] == 3:
|
|
854
|
+
r, g, b = a[0], a[1], a[2]
|
|
855
|
+
return (0.2126*r + 0.7152*g + 0.0722*b).astype(np.float32, copy=False)
|
|
856
|
+
return a.mean(axis=-1).astype(np.float32, copy=False)
|
|
857
|
+
|
|
858
|
+
def _stack_loader(paths):
    """
    Eagerly load the primary HDU of every FITS path.

    Returns (arrays, headers): two parallel lists. Pixel data is copied to
    float32 and headers are copied while the file is open (no memmap); a
    trailing singleton axis (H,W,1) is squeezed to (H,W).
    """
    frames, headers = [], []
    for fits_path in paths:
        with fits.open(fits_path, memmap=False) as hdul:
            primary = hdul[0]
            data = np.array(primary.data, dtype=np.float32, copy=True)
            header = primary.header.copy()
        if data.ndim == 3 and data.shape[-1] == 1:
            data = np.squeeze(data, axis=-1)
        frames.append(data)
        headers.append(header)
    return frames, headers
|
|
869
|
+
|
|
870
|
+
def _normalize_layout_single(a, color_mode):
|
|
871
|
+
"""
|
|
872
|
+
Coerce to:
|
|
873
|
+
- 'luma' -> (H, W)
|
|
874
|
+
- 'perchannel' -> (C, H, W); mono stays (1,H,W), RGB → (3,H,W)
|
|
875
|
+
Accepts (H,W), (H,W,3), or (3,H,W).
|
|
876
|
+
"""
|
|
877
|
+
a = np.asarray(a, dtype=np.float32)
|
|
878
|
+
|
|
879
|
+
if color_mode == "luma":
|
|
880
|
+
return _to_luma_local(a) # returns (H,W)
|
|
881
|
+
|
|
882
|
+
# perchannel
|
|
883
|
+
if a.ndim == 2:
|
|
884
|
+
return a[None, ...] # (1,H,W) ← keep mono as 1 channel
|
|
885
|
+
if a.ndim == 3 and a.shape[-1] == 3:
|
|
886
|
+
return np.moveaxis(a, -1, 0) # (3,H,W)
|
|
887
|
+
if a.ndim == 3 and a.shape[0] in (1, 3):
|
|
888
|
+
return a # already (1,H,W) or (3,H,W)
|
|
889
|
+
# fallback: average any weird shape into luma 1×H×W
|
|
890
|
+
l = _to_luma_local(a)
|
|
891
|
+
return l[None, ...]
|
|
892
|
+
|
|
893
|
+
|
|
894
|
+
def _normalize_layout_batch(arrs, color_mode):
|
|
895
|
+
return [_normalize_layout_single(a, color_mode) for a in arrs]
|
|
896
|
+
|
|
897
|
+
def _common_hw(data_list):
|
|
898
|
+
"""Return minimal (H,W) across items; items are (H,W) or (C,H,W)."""
|
|
899
|
+
Hs, Ws = [], []
|
|
900
|
+
for a in data_list:
|
|
901
|
+
if a.ndim == 2:
|
|
902
|
+
H, W = a.shape
|
|
903
|
+
else:
|
|
904
|
+
_, H, W = a.shape
|
|
905
|
+
Hs.append(H); Ws.append(W)
|
|
906
|
+
return int(min(Hs)), int(min(Ws))
|
|
907
|
+
|
|
908
|
+
def _center_crop(arr, Ht, Wt):
|
|
909
|
+
"""Center-crop arr (H,W) or (C,H,W) to (Ht,Wt)."""
|
|
910
|
+
if arr.ndim == 2:
|
|
911
|
+
H, W = arr.shape
|
|
912
|
+
if H == Ht and W == Wt:
|
|
913
|
+
return arr
|
|
914
|
+
y0 = max(0, (H - Ht) // 2)
|
|
915
|
+
x0 = max(0, (W - Wt) // 2)
|
|
916
|
+
return arr[y0:y0+Ht, x0:x0+Wt]
|
|
917
|
+
else:
|
|
918
|
+
C, H, W = arr.shape
|
|
919
|
+
if H == Ht and W == Wt:
|
|
920
|
+
return arr
|
|
921
|
+
y0 = max(0, (H - Ht) // 2)
|
|
922
|
+
x0 = max(0, (W - Wt) // 2)
|
|
923
|
+
return arr[:, y0:y0+Ht, x0:x0+Wt]
|
|
924
|
+
|
|
925
|
+
def _sanitize_numeric(a):
|
|
926
|
+
"""Replace NaN/Inf, clip negatives, make contiguous float32."""
|
|
927
|
+
a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
|
|
928
|
+
a = np.clip(a, 0.0, None).astype(np.float32, copy=False)
|
|
929
|
+
return np.ascontiguousarray(a)
|
|
930
|
+
|
|
931
|
+
# -----------------------------
|
|
932
|
+
# PSF utilities
|
|
933
|
+
# -----------------------------
|
|
934
|
+
|
|
935
|
+
def _gaussian_psf(fwhm_px: float, ksize: int) -> np.ndarray:
    """
    Build a (ksize, ksize) circular Gaussian kernel as float32.

    `fwhm_px` is floored at 1.0 before conversion to sigma (FWHM / 2.3548).
    The kernel is divided by (sum + EPS), so it sums to just under 1.
    """
    sigma = max(fwhm_px, 1.0) / 2.3548
    half = (ksize - 1) / 2
    yy, xx = np.mgrid[-half:half + 1, -half:half + 1]
    radius_sq = xx * xx + yy * yy
    kernel = np.exp(-radius_sq / (2 * sigma * sigma))
    kernel /= np.sum(kernel) + EPS
    return kernel.astype(np.float32, copy=False)
|
|
942
|
+
|
|
943
|
+
def _estimate_fwhm_from_header(hdr) -> float:
|
|
944
|
+
for key in ("FWHM", "FWHM_PIX", "PSF_FWHM"):
|
|
945
|
+
if key in hdr:
|
|
946
|
+
try:
|
|
947
|
+
val = float(hdr[key])
|
|
948
|
+
if np.isfinite(val) and val > 0:
|
|
949
|
+
return val
|
|
950
|
+
except Exception:
|
|
951
|
+
pass
|
|
952
|
+
return float("nan")
|
|
953
|
+
|
|
954
|
+
def _estimate_fwhm_from_image(arr) -> float:
    """
    Fast FWHM estimate from SEP 'a','b' parameters (≈ sigma in px).

    The frame is reduced to luma, background-subtracted with SEP, and sources
    are extracted at a threshold of 6.0 (in units of the background RMS).
    The median of the positive, finite per-source (a+b)/2 is taken as sigma
    and scaled by 2.3548. Returns NaN when SEP is unavailable, nothing is
    detected, or anything goes wrong.
    """
    nan = float("nan")
    if sep is None:
        return nan

    try:
        luma = _contig(_to_luma_local(arr))            # C-contiguous float32 for SEP
        background = sep.Background(luma)
        residual = _contig(luma - background.back())   # must be C-contiguous as well

        try:
            noise = background.globalrms
        except Exception:
            noise = float(np.median(background.rms()))

        detections = sep.extract(residual, 6.0, err=noise)
        if detections is None or len(detections) == 0:
            return nan

        semi_major = np.asarray(detections["a"], dtype=np.float32)
        semi_minor = np.asarray(detections["b"], dtype=np.float32)
        mean_axis = (semi_major + semi_minor) * 0.5
        usable = np.isfinite(mean_axis) & (mean_axis > 0)
        sigma = float(np.median(mean_axis[usable]))

        if np.isfinite(sigma) and sigma > 0:
            return 2.3548 * sigma
        return nan
    except Exception:
        return nan
|
|
978
|
+
|
|
979
|
+
def _auto_ksize_from_fwhm(fwhm_px: float, kmin: int = 11, kmax: int = 51) -> int:
|
|
980
|
+
"""
|
|
981
|
+
Choose odd kernel size to cover about ±4σ.
|
|
982
|
+
"""
|
|
983
|
+
sigma = max(fwhm_px, 1.0) / 2.3548
|
|
984
|
+
r = int(math.ceil(4.0 * sigma))
|
|
985
|
+
k = 2 * r + 1
|
|
986
|
+
k = max(kmin, min(k, kmax))
|
|
987
|
+
if (k % 2) == 0:
|
|
988
|
+
k += 1
|
|
989
|
+
return k
|
|
990
|
+
|
|
991
|
+
def _flip_kernel(psf):
|
|
992
|
+
# PyTorch dislikes negative strides; make it contiguous.
|
|
993
|
+
return np.flip(np.flip(psf, -1), -2).copy()
|
|
994
|
+
|
|
995
|
+
def _conv_same_np(img, psf):
    """
    NumPy FFT-based SAME convolution for (H,W) or (C,H,W).
    IMPORTANT: ifftshift the PSF so its peak is at [0,0] before FFT.

    `psf` must be 2D (kh, kw). A 2D image yields a 2D result; any other
    input is treated as channels-first and convolved plane by plane with the
    same kernel, then re-stacked along axis 0. The FFT size comes from
    `_fftshape_same`.

    NOTE(review): the kernel is ifftshift-ed within its own (kh, kw) extent
    *before* zero-padding, and the product is then additionally cropped
    starting at ((kh-1)//2, (kw-1)//2). A plain linear "same" convolution
    normally needs only one of those two centring steps - verify against
    `_fftshape_same` and the GPU implementation that this does not introduce
    a half-kernel offset.
    """
    import numpy as _np
    import numpy.fft as _fft

    kh, kw = psf.shape

    def fftconv2(a, k):
        # a is (1,H,W); k is (kh,kw)
        H, W = a.shape[-2:]
        fftH, fftW = _fftshape_same(H, W, kh, kw)
        # rfftn with s=... zero-pads (at the end of each axis) up to the FFT size
        A = _fft.rfftn(a, s=(fftH, fftW), axes=(-2, -1))
        K = _fft.rfftn(_np.fft.ifftshift(k), s=(fftH, fftW), axes=(-2, -1))
        y = _fft.irfftn(A * K, s=(fftH, fftW), axes=(-2, -1))
        # crop an (H,W) window offset by half the kernel size
        sh, sw = (kh - 1)//2, (kw - 1)//2
        return y[..., sh:sh+H, sw:sw+W]

    if img.ndim == 2:
        return fftconv2(img[None], psf)[0]
    else:
        # per-channel
        return _np.stack([fftconv2(img[c:c+1], psf)[0] for c in range(img.shape[0])], axis=0)
|
|
1020
|
+
|
|
1021
|
+
def _normalize_psf(psf):
|
|
1022
|
+
psf = np.maximum(psf, 0.0).astype(np.float32, copy=False)
|
|
1023
|
+
s = float(psf.sum())
|
|
1024
|
+
if not np.isfinite(s) or s <= 1e-6:
|
|
1025
|
+
return psf
|
|
1026
|
+
return (psf / s).astype(np.float32, copy=False)
|
|
1027
|
+
|
|
1028
|
+
def _soften_psf(psf, sigma_px=0.25):
|
|
1029
|
+
if sigma_px <= 0:
|
|
1030
|
+
return psf
|
|
1031
|
+
r = int(max(1, round(3 * sigma_px)))
|
|
1032
|
+
y, x = np.mgrid[-r:r+1, -r:r+1]
|
|
1033
|
+
g = np.exp(-(x*x + y*y) / (2 * sigma_px * sigma_px)).astype(np.float32)
|
|
1034
|
+
g /= g.sum() + 1e-6
|
|
1035
|
+
return _conv_same_np(psf[None], g)[0]
|
|
1036
|
+
|
|
1037
|
+
def _psf_fwhm_px(psf: np.ndarray) -> float:
    """
    Approximate FWHM (pixels) from second moments of a 2D kernel.

    Negative values are ignored (clipped to zero). The per-axis variances
    about the intensity centroid are averaged and converted with
    FWHM ≈ 2.3548·σ. Rectangular kernels are supported.

    Returns NaN when the clipped kernel sums to (almost) zero or contains
    non-finite values.
    """
    psf = np.maximum(psf, 0).astype(np.float32, copy=False)
    s = float(psf.sum())
    # A NaN/Inf sum would otherwise slip past the `<= EPS` test and end up
    # reported as a width of 0.0 (because max(0.0, nan) is 0.0).
    if not np.isfinite(s) or s <= EPS:
        return float("nan")

    rows, cols = psf.shape
    y, x = np.mgrid[:rows, :cols].astype(np.float32)

    # intensity-weighted centroid
    cy = float((psf * y).sum() / s)
    cx = float((psf * x).sum() / s)

    # intensity-weighted variance about the centroid, per axis
    var_y = float((psf * (y - cy) ** 2).sum() / s)
    var_x = float((psf * (x - cx) ** 2).sum() / s)

    sigma = math.sqrt(max(0.0, 0.5 * (var_x + var_y)))
    return 2.3548 * sigma  # FWHM≈2.355σ
|
|
1051
|
+
|
|
1052
|
+
# Default tuning values for star-mask / variance-map construction. Each one is
# the fallback used when the matching key is absent from a caller-supplied
# star-mask config dict, and the default of the same-named keyword argument of
# the star-mask builder.
STAR_MASK_MAXSIDE   = 2048       # longest image side above which detection runs on a downscaled copy
STAR_MASK_MAXOBJS   = 2000       # cap number of objects
VARMAP_SAMPLE_STRIDE = 8         # (kept for compat; currently unused internally)
THRESHOLD_SIGMA     = 2.0        # default `thresh_sigma` (presumably detection threshold in sigma - confirm)
KEEP_FLOOR          = 0.20       # default `keep_floor`
GROW_PX             = 8          # default `grow_px`
MAX_STAR_RADIUS     = 16         # default `max_radius_px`
SOFT_SIGMA          = 2.0        # default `soft_sigma`
ELLIPSE_SCALE       = 1.2        # default `ellipse_scale`
|
|
1061
|
+
|
|
1062
|
+
def _sep_background_precompute(img_2d: np.ndarray, bw: int = 64, bh: int = 64):
    """
    One-time SEP background build; returns (sky_map, rms_map, err_scalar).

    Contract:
      - the result is always a 3-tuple (sky, rms, err)
      - sky and rms are float32 arrays with the shape of img_2d
      - if SEP is missing or fails, a flat median/MAD-based estimate is
        returned instead of raising
      - non-finite pixels are replaced (median of the finite ones, else 0)
        before SEP sees the data

    Raises ValueError only when the input is not 2D.
    """
    a = np.asarray(img_2d, dtype=np.float32)
    if a.ndim != 2:
        # be strict; callers expect 2D
        raise ValueError(f"_sep_background_precompute expects 2D, got shape={a.shape}")

    H, W = int(a.shape[0]), int(a.shape[1])
    if H == 0 or W == 0:
        # degenerate frame: still hand back a well-formed triple
        return (
            np.zeros((H, W), dtype=np.float32),
            np.ones((H, W), dtype=np.float32),
            1.0,
        )

    def _flat_estimate(pixels):
        """Constant sky (median) and constant rms (1.4826 * MAD) from finite pixels."""
        good = np.isfinite(pixels)
        if good.any():
            sample = pixels[good]
            center = float(np.median(sample))
            spread = float(np.median(np.abs(sample - center))) + 1e-6
        else:
            center, spread = 0.0, 1.0
        flat_sky = np.full((H, W), center, dtype=np.float32)
        flat_rms = np.full((H, W), 1.4826 * spread, dtype=np.float32)
        return flat_sky, flat_rms, float(np.median(flat_rms))

    # If sep isn't available, always fall back
    if sep is None:
        return _flat_estimate(a)

    # SEP is present: it can choke on NaNs/Infs, so patch them first
    finite = np.isfinite(a)
    if not finite.all():
        fill = float(np.median(a[finite])) if finite.any() else 0.0
        a = np.where(finite, a, fill).astype(np.float32, copy=False)
    a = np.ascontiguousarray(a, dtype=np.float32)

    # Limit the tile size to the image size, but never go below 8
    bw = int(max(8, min(int(bw), W)))
    bh = int(max(8, min(int(bh), H)))

    try:
        background = sep.Background(a, bw=bw, bh=bh, fw=3, fh=3)
        sky = np.asarray(background.back(), dtype=np.float32)
        rms = np.asarray(background.rms(), dtype=np.float32)

        # Shape sanity (SEP should match, but be paranoid)
        if sky.shape != a.shape or rms.shape != a.shape:
            return _flat_estimate(a)

        # globalrms may be missing depending on the SEP build
        err = float(getattr(background, "globalrms", np.nan))
        if not np.isfinite(err) or err <= 0:
            # robust scalar: median rms
            err = float(np.median(rms)) if rms.size else 1.0

        return sky, rms, err
    except Exception:
        # If SEP blows up for any reason, degrade gracefully
        return _flat_estimate(a)
|
|
1138
|
+
|
|
1139
|
+
|
|
1140
|
+
|
|
1141
|
+
def _auto_star_mask_sep(
    img_2d: np.ndarray,
    thresh_sigma: float = THRESHOLD_SIGMA,
    grow_px: int = GROW_PX,
    max_objs: int = STAR_MASK_MAXOBJS,
    max_side: int = STAR_MASK_MAXSIDE,
    ellipse_scale: float = ELLIPSE_SCALE,
    soft_sigma: float = SOFT_SIGMA,
    max_semiaxis_px: float | None = None,   # kept for API compat; unused
    max_area_px2: float | None = None,      # kept for API compat; unused
    max_radius_px: int = MAX_STAR_RADIUS,
    keep_floor: float = KEEP_FLOOR,
    status_cb=lambda s: None
) -> np.ndarray:
    """
    Build a float32 KEEP weight map in [0, 1] from SEP source detections.

    Pipeline: subtract a SEP background, optionally downscale the residual
    (detection only), walk a ladder of thresholds until a usable number of
    sources is found, draw one filled disk per source into a fresh full-res
    uint8 buffer, feather it with a Gaussian, and convert to
    ``keep = keep_floor + (1 - keep_floor) * (1 - mask)``.

    `img_2d` is never written to. An all-ones map is returned when SEP is not
    available or when no sources are detected. One summary line is sent to
    `status_cb`.
    """
    if sep is None:
        return np.ones_like(img_2d, dtype=np.float32, order="C")

    # OpenCV is optional; it only accelerates resize / draw / blur.
    try:
        import cv2 as _cv2
    except Exception:
        _cv2 = None  # type: ignore
    have_cv2 = _cv2 is not None

    h, w = map(int, img_2d.shape)

    # Work on a private contiguous copy so the caller's image stays untouched.
    frame = np.ascontiguousarray(img_2d.astype(np.float32))
    bkg = sep.Background(frame)
    residual = np.ascontiguousarray(frame - bkg.back(), dtype=np.float32)
    try:
        noise = float(bkg.globalrms)
    except Exception:
        noise = float(np.median(np.asarray(bkg.rms(), dtype=np.float32)))

    # Detection may run on a reduced image; drawing always happens at full res.
    det = residual
    scale = 1.0
    if max_side and max(h, w) > int(max_side):
        scale = float(max(h, w)) / float(max_side)
        if have_cv2:
            det = _cv2.resize(
                det,
                (max(1, int(round(w / scale))), max(1, int(round(h / scale)))),
                interpolation=_cv2.INTER_AREA
            )
        else:
            step = int(max(1, round(scale)))
            rows, cols = h // step, w // step
            det = det[:rows * step, :cols * step].reshape(rows, step, cols, step).mean(axis=(1, 3))
            scale = float(step)

    # When averaging down by 'scale', per-pixel noise scales ~ 1/scale
    err_det = float(noise) / float(max(1.0, scale))

    ladder = [thresh_sigma, thresh_sigma*2, thresh_sigma*4,
              thresh_sigma*8, thresh_sigma*16]
    objs = None
    used = float("nan")
    raw = 0
    for level in ladder:
        try:
            found = sep.extract(det, thresh=float(level), err=err_det)
        except Exception:
            found = None
        count = 0 if found is None else len(found)
        # Nothing found, or far too many: try the next (higher) threshold.
        if count == 0 or count > max_objs*12:
            continue
        objs, raw, used = found, count, float(level)
        break

    if objs is None or len(objs) == 0:
        # Last resort: highest threshold with a larger minimum area.
        try:
            found = sep.extract(det, thresh=float(ladder[-1]), err=err_det, minarea=9)
        except Exception:
            found = None
        if found is None or len(found) == 0:
            status_cb("Star mask: no sources found (mask disabled for this frame).")
            return np.ones((h, w), dtype=np.float32, order="C")
        objs, raw, used = found, len(found), float(ladder[-1])

    # Keep brightest max_objs
    if "flux" in objs.dtype.names:
        objs = objs[np.argsort(objs["flux"])[-int(max_objs):]]
    else:
        objs = objs[:int(max_objs)]
    kept_after_cap = int(len(objs))

    # Draw on full-res fresh buffer
    mask_u8 = np.zeros((h, w), dtype=np.uint8, order="C")
    s_back = float(scale)
    MR = int(max(1, max_radius_px))
    G = int(max(0, grow_px))
    ES = float(max(0.1, ellipse_scale))

    drawn = 0
    for src in objs:
        cx = int(round(float(src["x"]) * s_back))
        cy = int(round(float(src["y"]) * s_back))
        if not (0 <= cx < w and 0 <= cy < h):
            continue
        semi_a = float(src["a"]) * s_back
        semi_b = float(src["b"]) * s_back
        rad = int(math.ceil(ES * max(semi_a, semi_b)))
        rad = min(max(rad, 0) + G, MR)
        if rad <= 0:
            continue
        if have_cv2:
            _cv2.circle(mask_u8, (cx, cy), rad, 1, thickness=-1, lineType=_cv2.LINE_8)
        else:
            y0 = max(0, cy - rad); y1 = min(h, cy + rad + 1)
            x0 = max(0, cx - rad); x1 = min(w, cx + rad + 1)
            yy, xx = np.ogrid[y0:y1, x0:x1]
            dy = yy - cy
            dx = xx - cx
            disk = (dy * dy + dx * dx) <= (rad * rad)
            mask_u8[y0:y1, x0:x1][disk] = 1
        drawn += 1

    masked_px_hard = int(mask_u8.sum())

    # Feather and convert to KEEP weights in [0,1]
    m = mask_u8.astype(np.float32, copy=False)
    if soft_sigma and soft_sigma > 0.0:
        try:
            if have_cv2:
                k = int(max(1, math.ceil(soft_sigma * 3)) * 2 + 1)
                m = _cv2.GaussianBlur(m, (k, k), float(soft_sigma),
                                      borderType=_cv2.BORDER_REFLECT)
            else:
                from scipy.ndimage import gaussian_filter
                m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
        except Exception:
            # Feathering is best-effort; fall through with the hard mask.
            pass
        np.clip(m, 0.0, 1.0, out=m)

    keep = 1.0 - m
    kf = float(max(0.0, min(0.99, keep_floor)))
    keep = kf + (1.0 - kf) * keep
    np.clip(keep, 0.0, 1.0, out=keep)

    status_cb(
        f"Star mask: thresh={used:.3g} | detected={raw} | kept={kept_after_cap} | "
        f"drawn={drawn} | masked_px={masked_px_hard} | grow_px={G} | soft_sigma={soft_sigma} | keep_floor={keep_floor}"
    )
    return np.ascontiguousarray(keep, dtype=np.float32)
|
|
1302
|
+
|
|
1303
|
+
|
|
1304
|
+
|
|
1305
|
+
def _auto_variance_map(
    img_2d: np.ndarray,
    hdr,
    status_cb=lambda s: None,
    sample_stride: int = VARMAP_SAMPLE_STRIDE,  # kept for signature compat; not used
    bw: int = 64,                               # SEP background box width  (pixels)
    bh: int = 64,                               # SEP background box height (pixels)
    smooth_sigma: float = 1.0,                  # Gaussian sigma (px) to smooth the variance map
    floor: float = 1e-8,                        # hard floor to prevent blow-up in 1/var
) -> np.ndarray:
    """
    Build a per-pixel variance map in DN^2:

        var_DN ≈ (object_only_DN)/gain + var_bg_DN^2

    where:
      - object_only_DN = max(img - sky_DN, 0)
      - var_bg_DN^2 is the square of SEP's local background rms
      - if no usable gain keyword is present, 1/gain is estimated as
        median(var_bg)/median(sky), clipped to [0, 10]

    Args:
        img_2d: 2-D image; negative values are clipped to 0 before use.
        hdr: mapping supporting ``key in hdr`` / ``hdr[key]`` (may be None).
        status_cb: receives one summary line.
        sample_stride: unused, kept for signature compatibility.
        bw, bh: SEP background box size.
        smooth_sigma: Gaussian sigma for smoothing; falsy disables smoothing.
        floor: lower clip applied to the final map.

    Returns:
        float32 variance map, clipped below by `floor`.

    If SEP is missing or fails, sky and rms fall back to the global median
    and 1.4826 * MAD of the image.
    """
    img = np.clip(np.asarray(img_2d, dtype=np.float32), 0.0, None)

    def _first_header_value(keys, accept):
        """Return the first finite header value under `keys` satisfying `accept`, else None."""
        if hdr is None:
            return None
        for key in keys:
            if key in hdr:
                try:
                    val = float(hdr[key])
                    if np.isfinite(val) and accept(val):
                        return val
                except Exception:
                    # Unparseable card: try the next candidate keyword.
                    pass
        return None

    def _median_mad_maps():
        """Constant sky / rms maps from the global median and MAD of the image."""
        med = float(np.median(img))
        mad = float(np.median(np.abs(img - med))) + 1e-6
        return (
            np.full_like(img, med, dtype=np.float32),
            np.full_like(img, float(1.4826 * mad), dtype=np.float32),
        )

    # --- Parse header for camera params (optional) ---
    gain = _first_header_value(("EGAIN", "GAIN", "GAIN1", "GAIN2"), lambda g: g > 0)
    # Read noise is only reported in the status line, not used in the model.
    readnoise = _first_header_value(("RDNOISE", "READNOISE", "RN"), lambda rn: rn >= 0)

    # --- Local background (full-res) ---
    sky_dn_map = rms_dn_map = None
    if sep is not None:
        try:
            b = sep.Background(img, bw=int(bw), bh=int(bh), fw=3, fh=3)
            sky_dn_map = np.asarray(b.back(), dtype=np.float32)
            rms_dn_map = np.asarray(b.rms(), dtype=np.float32)
        except Exception:
            sky_dn_map = rms_dn_map = None
    if sky_dn_map is None or rms_dn_map is None:
        sky_dn_map, rms_dn_map = _median_mad_maps()

    # Background variance in DN^2
    var_bg_dn2 = np.maximum(rms_dn_map, 1e-6) ** 2

    # Object-only DN
    obj_dn = np.clip(img - sky_dn_map, 0.0, None)

    # Shot-noise coefficient (kept as a Python float so the float32 maps are
    # not promoted to float64 by a NumPy scalar).
    if gain is not None:
        a_shot = 1.0 / gain
    else:
        sky_med = float(np.median(sky_dn_map))
        varbg_med = float(np.median(var_bg_dn2))
        if sky_med > 1e-6:
            a_shot = float(np.clip(varbg_med / sky_med, 0.0, 10.0))  # ~ 1/gain estimate
        else:
            a_shot = 0.0

    # Total variance: background + shot noise from object-only flux
    v = var_bg_dn2 + a_shot * obj_dn

    # Optional mild smoothing: OpenCV, then SciPy, then a plain NumPy kernel.
    if smooth_sigma and smooth_sigma > 0:
        try:
            import cv2 as _cv2
            k = int(max(1, int(round(3 * float(smooth_sigma)))) * 2 + 1)
            v = _cv2.GaussianBlur(v, (k, k), float(smooth_sigma), borderType=_cv2.BORDER_REFLECT)
        except Exception:
            try:
                from scipy.ndimage import gaussian_filter
                v = gaussian_filter(v, sigma=float(smooth_sigma), mode="reflect")
            except Exception:
                r = int(max(1, round(3 * float(smooth_sigma))))
                yy, xx = np.mgrid[-r:r+1, -r:r+1].astype(np.float32)
                gk = np.exp(-(xx*xx + yy*yy) / (2.0 * float(smooth_sigma) * float(smooth_sigma))).astype(np.float32)
                gk /= (gk.sum() + EPS)
                v = _conv_same_np(v, gk)

    # Clip to avoid zero/negative variances
    v = np.clip(v, float(floor), None).astype(np.float32, copy=False)

    # Emit telemetry (never allowed to break the computation)
    try:
        sky_med = float(np.median(sky_dn_map))
        rms_med = float(np.median(np.sqrt(var_bg_dn2)))
        floor_pct = float((v <= floor).mean() * 100.0)
        status_cb(
            "Variance map: "
            f"sky_med={sky_med:.3g} DN | rms_med={rms_med:.3g} DN | "
            f"gain={(gain if gain is not None else 'NA')} | rn={(readnoise if readnoise is not None else 'NA')} | "
            f"smooth_sigma={smooth_sigma} | floor={floor} ({floor_pct:.2f}% at floor)"
        )
    except Exception:
        pass

    return v
|
|
1429
|
+
|
|
1430
|
+
|
|
1431
|
+
def _star_mask_from_precomputed(
    img_2d: np.ndarray,
    sky_map: np.ndarray,
    err_scalar: float,
    *,
    thresh_sigma: float,
    max_objs: int,
    grow_px: int,
    ellipse_scale: float,
    soft_sigma: float,
    max_radius_px: int,
    keep_floor: float,
    max_side: int,
    status_cb=lambda s: None
) -> np.ndarray:
    """
    Build a KEEP weight map using a *downscaled detection / full-res draw* path.

    Uses an already computed sky map and scalar noise estimate instead of
    running a SEP background pass. **Never writes to img_2d**; all drawing
    happens in a fresh `mask_u8`.

    Args:
        img_2d: 2-D image.
        sky_map: background map subtracted from `img_2d` before detection.
        err_scalar: scalar noise passed to `sep.extract` as `err`.
        thresh_sigma: base detection threshold; a x1/x2/x4/x8/x16 ladder is tried.
        max_objs: cap on sources kept (brightest by flux when available).
        grow_px, ellipse_scale, max_radius_px: control the drawn disk radius.
        soft_sigma: Gaussian feather sigma; falsy disables feathering.
        keep_floor: minimum keep weight inside masked regions (clamped to [0, 0.99]).
        max_side: images whose longer side exceeds this are downscaled for detection.
        status_cb: receives one summary line.

    Returns:
        C-contiguous float32 array of shape ``img_2d.shape`` in [0, 1];
        all ones when SEP is unavailable or nothing is detected.
    """
    H, W = map(int, img_2d.shape)

    # Without SEP there is no detector: disable the mask, like _auto_star_mask_sep.
    if sep is None:
        return np.ones((H, W), dtype=np.float32, order="C")

    # Optional OpenCV fast path
    try:
        import cv2 as _cv2
        _HAS_CV2 = True
    except Exception:
        _HAS_CV2 = False
        _cv2 = None  # type: ignore

    # Residual for detection (contiguous, separate buffer)
    data_sub = np.ascontiguousarray((img_2d - sky_map).astype(np.float32))

    # Downscale *detection only* to speed up, never the draw step
    det = data_sub
    scale = 1.0
    if max_side and max(H, W) > int(max_side):
        scale = float(max(H, W)) / float(max_side)
        if _HAS_CV2:
            det = _cv2.resize(
                det,
                (max(1, int(round(W / scale))), max(1, int(round(H / scale)))),
                interpolation=_cv2.INTER_AREA
            )
        else:
            s = int(max(1, round(scale)))
            det = det[:(H // s) * s, :(W // s) * s].reshape(H // s, s, W // s, s).mean(axis=(1, 3))
            scale = float(s)

    # NOTE(review): err_scalar is used as-is even when `det` was averaged down,
    # whereas _auto_star_mask_sep divides its noise by `scale` — confirm intent.
    err = float(err_scalar)

    # Threshold ladder. A failing extract (e.g. too many pixels above a low
    # threshold) must not abort the frame: move on to the next, higher level.
    thresholds = [thresh_sigma, thresh_sigma*2, thresh_sigma*4,
                  thresh_sigma*8, thresh_sigma*16]
    objs = None; used = float("nan"); raw = 0
    for t in thresholds:
        try:
            cand = sep.extract(det, thresh=float(t), err=err)
        except Exception:
            cand = None
        n = 0 if cand is None else len(cand)
        if n == 0:
            continue
        if n > max_objs*12:
            continue
        objs, raw, used = cand, n, float(t)
        break

    if objs is None or len(objs) == 0:
        # Last resort: highest threshold with a larger minimum area.
        try:
            cand = sep.extract(det, thresh=float(thresholds[-1]), err=err, minarea=9)
        except Exception:
            cand = None
        if cand is None or len(cand) == 0:
            status_cb("Star mask: no sources found (mask disabled for this frame).")
            return np.ones((H, W), dtype=np.float32, order="C")
        objs, raw, used = cand, len(cand), float(thresholds[-1])

    # Brightest max_objs
    if "flux" in objs.dtype.names:
        idx = np.argsort(objs["flux"])[-int(max_objs):]
        objs = objs[idx]
    else:
        objs = objs[:int(max_objs)]
    kept = len(objs)

    # ---- draw back on full-res into a brand-new buffer ----
    mask_u8 = np.zeros((H, W), dtype=np.uint8, order="C")
    s_back = float(scale)               # detection coords -> full-res coords
    MR = int(max(1, max_radius_px))
    G = int(max(0, grow_px))
    ES = float(max(0.1, ellipse_scale))

    drawn = 0
    for o in objs:
        x = int(round(float(o["x"]) * s_back))
        y = int(round(float(o["y"]) * s_back))
        if not (0 <= x < W and 0 <= y < H):
            continue
        a = float(o["a"]) * s_back
        b = float(o["b"]) * s_back
        r = int(math.ceil(ES * max(a, b)))
        r = min(max(r, 0) + G, MR)
        if r <= 0:
            continue
        if _HAS_CV2:
            _cv2.circle(mask_u8, (x, y), r, 1, thickness=-1, lineType=_cv2.LINE_8)
        else:
            y0 = max(0, y - r); y1 = min(H, y + r + 1)
            x0 = max(0, x - r); x1 = min(W, x + r + 1)
            yy, xx = np.ogrid[y0:y1, x0:x1]
            disk = (yy - y)*(yy - y) + (xx - x)*(xx - x) <= r*r
            mask_u8[y0:y1, x0:x1][disk] = 1
        drawn += 1

    # Feather + convert to keep weights
    m = mask_u8.astype(np.float32, copy=False)
    if soft_sigma and soft_sigma > 0:
        try:
            if _HAS_CV2:
                k = int(max(1, int(round(3*soft_sigma)))*2 + 1)
                m = _cv2.GaussianBlur(m, (k, k), float(soft_sigma),
                                      borderType=_cv2.BORDER_REFLECT)
            else:
                from scipy.ndimage import gaussian_filter
                m = gaussian_filter(m, sigma=float(soft_sigma), mode="reflect")
        except Exception:
            # Feathering is best-effort; keep the hard mask on failure.
            pass
        np.clip(m, 0.0, 1.0, out=m)

    keep = 1.0 - m
    kf = float(max(0.0, min(0.99, keep_floor)))
    keep = kf + (1.0 - kf) * keep
    np.clip(keep, 0.0, 1.0, out=keep)

    status_cb(f"Star mask: thresh={used:.3g} | detected={raw} | kept={kept} | drawn={drawn} | keep_floor={keep_floor}")
    return np.ascontiguousarray(keep, dtype=np.float32)
|
|
1572
|
+
|
|
1573
|
+
|
|
1574
|
+
def _variance_map_from_precomputed(
|
|
1575
|
+
img_2d: np.ndarray,
|
|
1576
|
+
sky_map: np.ndarray,
|
|
1577
|
+
rms_map: np.ndarray,
|
|
1578
|
+
hdr,
|
|
1579
|
+
*,
|
|
1580
|
+
smooth_sigma: float,
|
|
1581
|
+
floor: float,
|
|
1582
|
+
status_cb=lambda s: None
|
|
1583
|
+
) -> np.ndarray:
|
|
1584
|
+
img = np.clip(np.asarray(img_2d, dtype=np.float32), 0.0, None)
|
|
1585
|
+
var_bg_dn2 = np.maximum(rms_map, 1e-6) ** 2
|
|
1586
|
+
obj_dn = np.clip(img - sky_map, 0.0, None)
|
|
1587
|
+
|
|
1588
|
+
gain = None
|
|
1589
|
+
for k in ("EGAIN", "GAIN", "GAIN1", "GAIN2"):
|
|
1590
|
+
if k in hdr:
|
|
1591
|
+
try:
|
|
1592
|
+
g = float(hdr[k]); gain = g if (np.isfinite(g) and g > 0) else None
|
|
1593
|
+
if gain is not None: break
|
|
1594
|
+
except Exception as e:
|
|
1595
|
+
import logging
|
|
1596
|
+
logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
|
|
1597
|
+
|
|
1598
|
+
if gain is not None:
|
|
1599
|
+
a_shot = 1.0 / gain
|
|
1600
|
+
else:
|
|
1601
|
+
sky_med = float(np.median(sky_map))
|
|
1602
|
+
varbg_med= float(np.median(var_bg_dn2))
|
|
1603
|
+
a_shot = (varbg_med / sky_med) if sky_med > 1e-6 else 0.0
|
|
1604
|
+
a_shot = float(np.clip(a_shot, 0.0, 10.0))
|
|
1605
|
+
|
|
1606
|
+
v = var_bg_dn2 + a_shot * obj_dn
|
|
1607
|
+
if smooth_sigma > 0:
|
|
1608
|
+
try:
|
|
1609
|
+
import cv2 as _cv2
|
|
1610
|
+
k = int(max(1, int(round(3*smooth_sigma)))*2 + 1)
|
|
1611
|
+
v = _cv2.GaussianBlur(v, (k,k), float(smooth_sigma), borderType=_cv2.BORDER_REFLECT)
|
|
1612
|
+
except Exception:
|
|
1613
|
+
try:
|
|
1614
|
+
from scipy.ndimage import gaussian_filter
|
|
1615
|
+
v = gaussian_filter(v, sigma=float(smooth_sigma), mode="reflect")
|
|
1616
|
+
except Exception:
|
|
1617
|
+
pass
|
|
1618
|
+
|
|
1619
|
+
np.clip(v, float(floor), None, out=v)
|
|
1620
|
+
try:
|
|
1621
|
+
rms_med = float(np.median(np.sqrt(var_bg_dn2)))
|
|
1622
|
+
status_cb(f"Variance map: sky_med={float(np.median(sky_map)):.3g} DN | rms_med={rms_med:.3g} DN | smooth_sigma={smooth_sigma} | floor={floor}")
|
|
1623
|
+
except Exception:
|
|
1624
|
+
pass
|
|
1625
|
+
return v.astype(np.float32, copy=False)
|
|
1626
|
+
|
|
1627
|
+
|
|
1628
|
+
|
|
1629
|
+
# -----------------------------
|
|
1630
|
+
# Robust weighting (Huber)
|
|
1631
|
+
# -----------------------------
|
|
1632
|
+
|
|
1633
|
+
# Small additive guard used by the numeric helpers in this file to keep
# denominators (e.g. `v + EPS`, `g.sum() + EPS`) away from zero.
EPS = 1e-6
|
|
1634
|
+
|
|
1635
|
+
def _estimate_scalar_variance(a):
|
|
1636
|
+
med = np.median(a)
|
|
1637
|
+
mad = np.median(np.abs(a - med)) + 1e-6
|
|
1638
|
+
return float((1.4826 * mad) ** 2)
|
|
1639
|
+
|
|
1640
|
+
def _weight_map(y, pred, huber_delta, var_map=None, mask=None):
    """
    Per-pixel robust weights: W = [psi(r)/r] * 1/(var + EPS) * mask, psi = Huber.

    - r = y - pred.
    - huber_delta < 0 selects an automatic delta of (-huber_delta) times the
      MAD-based RMS of the residual; otherwise huber_delta is used directly.
      A non-positive delta disables the Huber term (factor of 1 everywhere).
    - var_map None: a scalar variance is estimated from the residual.
    - y, pred are (H,W) or (C,H,W); a 2-D var_map / mask gets a leading axis
      when the residual is 3-D.
    """
    resid = y - pred

    # Resolve the Huber threshold.
    if huber_delta < 0:
        centre = np.median(resid)
        mad = np.median(np.abs(resid - centre)) + 1e-6
        rms = 1.4826 * mad
        delta = (-huber_delta) * max(rms, 1e-6)
    else:
        delta = huber_delta

    # psi(r)/r: 1 inside the quadratic zone, delta/|r| outside.
    if float(delta) > 0:
        magnitude = np.abs(resid)
        psi_over_r = np.where(magnitude <= delta, 1.0, delta / (magnitude + EPS)).astype(np.float32)
    else:
        psi_over_r = np.ones_like(resid, dtype=np.float32)

    # Variance term.
    if var_map is None:
        variance = _estimate_scalar_variance(resid)
    elif var_map.ndim == 2 and resid.ndim == 3:
        variance = var_map[None, ...]
    else:
        variance = var_map
    weights = psi_over_r / (variance + EPS)

    # Optional mask, given a leading axis only when ranks differ and weights are 3-D.
    if mask is not None:
        if mask.ndim != weights.ndim and weights.ndim == 3:
            aligned = mask[None, ...]
        else:
            aligned = mask
        weights = weights * aligned
    return weights
|
|
1673
|
+
|
|
1674
|
+
|
|
1675
|
+
# -----------------------------
|
|
1676
|
+
# Torch / conv
|
|
1677
|
+
# -----------------------------
|
|
1678
|
+
|
|
1679
|
+
def _fftshape_same(H, W, kh, kw):
|
|
1680
|
+
return H + kh - 1, W + kw - 1
|
|
1681
|
+
|
|
1682
|
+
# ---------- Torch FFT helpers (FIXED: carry padH/padW) ----------
|
|
1683
|
+
def _precompute_torch_psf_ffts(psfs, flip_psf, H, W, device, dtype):
|
|
1684
|
+
"""
|
|
1685
|
+
Pack (Kf,padH,padW,kh,kw) so the conv can crop correctly to SAME.
|
|
1686
|
+
Kernel is ifftshifted before padding.
|
|
1687
|
+
"""
|
|
1688
|
+
tfft = torch.fft
|
|
1689
|
+
psf_fft, psfT_fft = [], []
|
|
1690
|
+
for k, kT in zip(psfs, flip_psf):
|
|
1691
|
+
kh, kw = k.shape
|
|
1692
|
+
padH, padW = _fftshape_same(H, W, kh, kw)
|
|
1693
|
+
k_small = torch.as_tensor(np.fft.ifftshift(k), device=device, dtype=dtype)
|
|
1694
|
+
kT_small = torch.as_tensor(np.fft.ifftshift(kT), device=device, dtype=dtype)
|
|
1695
|
+
Kf = tfft.rfftn(k_small, s=(padH, padW))
|
|
1696
|
+
KTf = tfft.rfftn(kT_small, s=(padH, padW))
|
|
1697
|
+
psf_fft.append((Kf, padH, padW, kh, kw))
|
|
1698
|
+
psfT_fft.append((KTf, padH, padW, kh, kw))
|
|
1699
|
+
return psf_fft, psfT_fft
|
|
1700
|
+
|
|
1701
|
+
|
|
1702
|
+
def _fft_conv_same_torch(x, Kf_pack, out_spatial):
|
|
1703
|
+
tfft = torch.fft
|
|
1704
|
+
Kf, padH, padW, kh, kw = Kf_pack
|
|
1705
|
+
H, W = x.shape[-2], x.shape[-1]
|
|
1706
|
+
if x.ndim == 2:
|
|
1707
|
+
X = tfft.rfftn(x, s=(padH, padW))
|
|
1708
|
+
y = tfft.irfftn(X * Kf, s=(padH, padW))
|
|
1709
|
+
sh, sw = (kh - 1)//2, (kw - 1)//2
|
|
1710
|
+
out_spatial.copy_(y[sh:sh+H, sw:sw+W])
|
|
1711
|
+
return out_spatial
|
|
1712
|
+
else:
|
|
1713
|
+
X = tfft.rfftn(x, s=(padH, padW), dim=(-2,-1))
|
|
1714
|
+
y = tfft.irfftn(X * Kf, s=(padH, padW), dim=(-2,-1))
|
|
1715
|
+
sh, sw = (kh - 1)//2, (kw - 1)//2
|
|
1716
|
+
out_spatial.copy_(y[..., sh:sh+H, sw:sw+W])
|
|
1717
|
+
return out_spatial
|
|
1718
|
+
|
|
1719
|
+
# ---------- NumPy FFT helpers ----------
|
|
1720
|
+
def _precompute_np_psf_ffts(psfs, flip_psf, H, W):
|
|
1721
|
+
import numpy.fft as fft
|
|
1722
|
+
meta, Kfs, KTfs = [], [], []
|
|
1723
|
+
for k, kT in zip(psfs, flip_psf):
|
|
1724
|
+
kh, kw = k.shape
|
|
1725
|
+
fftH, fftW = _fftshape_same(H, W, kh, kw)
|
|
1726
|
+
Kfs.append( fft.rfftn(np.fft.ifftshift(k), s=(fftH, fftW)) )
|
|
1727
|
+
KTfs.append(fft.rfftn(np.fft.ifftshift(kT), s=(fftH, fftW)) )
|
|
1728
|
+
meta.append((kh, kw, fftH, fftW))
|
|
1729
|
+
return Kfs, KTfs, meta
|
|
1730
|
+
|
|
1731
|
+
def _fft_conv_same_np(a, Kf, kh, kw, fftH, fftW, out):
|
|
1732
|
+
import numpy.fft as fft
|
|
1733
|
+
if a.ndim == 2:
|
|
1734
|
+
A = fft.rfftn(a, s=(fftH, fftW))
|
|
1735
|
+
y = fft.irfftn(A * Kf, s=(fftH, fftW))
|
|
1736
|
+
sh, sw = (kh - 1)//2, (kw - 1)//2
|
|
1737
|
+
out[...] = y[sh:sh+a.shape[0], sw:sw+a.shape[1]]
|
|
1738
|
+
return out
|
|
1739
|
+
else:
|
|
1740
|
+
C, H, W = a.shape
|
|
1741
|
+
acc = []
|
|
1742
|
+
for c in range(C):
|
|
1743
|
+
A = fft.rfftn(a[c], s=(fftH, fftW))
|
|
1744
|
+
y = fft.irfftn(A * Kf, s=(fftH, fftW))
|
|
1745
|
+
sh, sw = (kh - 1)//2, (kw - 1)//2
|
|
1746
|
+
acc.append(y[sh:sh+H, sw:sw+W])
|
|
1747
|
+
out[...] = np.stack(acc, 0)
|
|
1748
|
+
return out
|
|
1749
|
+
|
|
1750
|
+
|
|
1751
|
+
|
|
1752
|
+
def _torch_device():
    """
    Pick the torch device to run on.

    Preference order: CUDA, then MPS, then a DirectML device published in the
    module globals (`dml_ok` / `dml_device`), then CPU.
    """
    torch_ready = TORCH_OK and (torch is not None)
    if torch_ready and hasattr(torch, "cuda") and torch.cuda.is_available():
        return torch.device("cuda")
    if torch_ready and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")

    # DirectML: we passed dml_device from outer scope; keep a module-global
    module_vars = globals()
    if module_vars.get("dml_ok", False) and module_vars.get("dml_device", None) is not None:
        return module_vars["dml_device"]
    return torch.device("cpu")
|
|
1762
|
+
|
|
1763
|
+
def _to_t(x: np.ndarray):
    """
    Wrap a NumPy array as a torch tensor on the device chosen by `_torch_device`.

    Raises RuntimeError when torch is unavailable. The tensor is moved with
    `.to(device, non_blocking=True)` unless the device prints as "cpu", in
    which case the `torch.from_numpy` result is returned directly.
    """
    if not TORCH_OK or torch is None:
        raise RuntimeError("Torch path requested but torch is unavailable")
    target = _torch_device()
    tensor = torch.from_numpy(x)
    if str(target) == "cpu":
        return tensor
    # DirectML wants explicit .to(device)
    return tensor.to(target, non_blocking=True)
|
|
1770
|
+
|
|
1771
|
+
def _contig(x):
|
|
1772
|
+
return np.ascontiguousarray(x, dtype=np.float32)
|
|
1773
|
+
|
|
1774
|
+
def _conv_same_torch(img_t, psf_t):
|
|
1775
|
+
"""
|
|
1776
|
+
img_t: torch tensor on DEVICE, (H,W) or (C,H,W)
|
|
1777
|
+
psf_t: torch tensor on DEVICE, (1,1,kh,kw) (single kernel)
|
|
1778
|
+
Pads with 'reflect' to avoid zero-padding ringing.
|
|
1779
|
+
"""
|
|
1780
|
+
kh, kw = psf_t.shape[-2:]
|
|
1781
|
+
pad = (kw // 2, kw - kw // 2 - 1, # left, right
|
|
1782
|
+
kh // 2, kh - kh // 2 - 1) # top, bottom
|
|
1783
|
+
|
|
1784
|
+
if img_t.ndim == 2:
|
|
1785
|
+
x = img_t[None, None]
|
|
1786
|
+
x = torch.nn.functional.pad(x, pad, mode="reflect")
|
|
1787
|
+
y = torch.nn.functional.conv2d(x, psf_t, padding=0)
|
|
1788
|
+
return y[0, 0]
|
|
1789
|
+
else:
|
|
1790
|
+
C = img_t.shape[0]
|
|
1791
|
+
x = img_t[None]
|
|
1792
|
+
x = torch.nn.functional.pad(x, pad, mode="reflect")
|
|
1793
|
+
w = psf_t.repeat(C, 1, 1, 1)
|
|
1794
|
+
y = torch.nn.functional.conv2d(x, w, padding=0, groups=C)
|
|
1795
|
+
return y[0]
|
|
1796
|
+
|
|
1797
|
+
def _safe_inference_context():
    """
    Return a callable that produces a working no-grad context manager.

    - torch unavailable -> NO_GRAD
    - torch has no `inference_mode` -> torch.no_grad
    - `inference_mode` exists but cannot be entered -> torch.no_grad
    - otherwise -> torch.inference_mode
    """
    if not (TORCH_OK and (torch is not None)):
        return NO_GRAD

    factory = getattr(torch, "inference_mode", None)
    if factory is None:
        return torch.no_grad

    # Probe inference_mode once; if it explodes on this build, fall back.
    try:
        with factory():
            pass
    except Exception:
        return torch.no_grad
    return factory
|
|
1818
|
+
|
|
1819
|
+
def _ensure_mask_list(masks, data):
|
|
1820
|
+
# 1s where valid, 0s where invalid (soft edges allowed)
|
|
1821
|
+
if masks is None:
|
|
1822
|
+
return [np.ones_like(a if a.ndim==2 else a[0], dtype=np.float32) for a in data]
|
|
1823
|
+
out = []
|
|
1824
|
+
for a, m in zip(data, masks):
|
|
1825
|
+
base = a if a.ndim==2 else a[0] # mask is 2D; shared across channels
|
|
1826
|
+
if m is None:
|
|
1827
|
+
out.append(np.ones_like(base, dtype=np.float32))
|
|
1828
|
+
else:
|
|
1829
|
+
mm = np.asarray(m, dtype=np.float32)
|
|
1830
|
+
if mm.ndim == 3: # tolerate (1,H,W) or (C,H,W)
|
|
1831
|
+
mm = mm[0]
|
|
1832
|
+
if mm.shape != base.shape:
|
|
1833
|
+
# center crop to match (common intersection already applied)
|
|
1834
|
+
Ht, Wt = base.shape
|
|
1835
|
+
mm = _center_crop(mm, Ht, Wt)
|
|
1836
|
+
# keep as float weights in [0,1] (do not threshold!)
|
|
1837
|
+
out.append(np.clip(mm.astype(np.float32, copy=False), 0.0, 1.0))
|
|
1838
|
+
return out
|
|
1839
|
+
|
|
1840
|
+
def _ensure_var_list(variances, data):
|
|
1841
|
+
# If None, we’ll estimate a robust scalar per frame on-the-fly.
|
|
1842
|
+
if variances is None:
|
|
1843
|
+
return [None]*len(data)
|
|
1844
|
+
out = []
|
|
1845
|
+
for a, v in zip(data, variances):
|
|
1846
|
+
if v is None:
|
|
1847
|
+
out.append(None)
|
|
1848
|
+
else:
|
|
1849
|
+
vv = np.asarray(v, dtype=np.float32)
|
|
1850
|
+
if vv.ndim == 3:
|
|
1851
|
+
vv = vv[0]
|
|
1852
|
+
base = a if a.ndim==2 else a[0]
|
|
1853
|
+
if vv.shape != base.shape:
|
|
1854
|
+
Ht, Wt = base.shape
|
|
1855
|
+
vv = _center_crop(vv, Ht, Wt)
|
|
1856
|
+
# clip tiny/negatives
|
|
1857
|
+
vv = np.nan_to_num(vv, nan=1e-8, posinf=1e8, neginf=1e8, copy=False)
|
|
1858
|
+
vv = np.clip(vv, 1e-8, None).astype(np.float32, copy=False)
|
|
1859
|
+
out.append(vv)
|
|
1860
|
+
return out
|
|
1861
|
+
|
|
1862
|
+
# ---- SR operators (downsample / upsample-sum) ----
|
|
1863
|
+
def _downsample_avg(img, r: int):
|
|
1864
|
+
"""Average-pool over non-overlapping r×r blocks. Works for (H,W) or (C,H,W)."""
|
|
1865
|
+
if r <= 1:
|
|
1866
|
+
return img
|
|
1867
|
+
a = np.asarray(img, dtype=np.float32)
|
|
1868
|
+
if a.ndim == 2:
|
|
1869
|
+
H, W = a.shape
|
|
1870
|
+
Hs, Ws = (H // r) * r, (W // r) * r
|
|
1871
|
+
a = a[:Hs, :Ws].reshape(Hs//r, r, Ws//r, r).mean(axis=(1,3))
|
|
1872
|
+
return a
|
|
1873
|
+
else:
|
|
1874
|
+
C, H, W = a.shape
|
|
1875
|
+
Hs, Ws = (H // r) * r, (W // r) * r
|
|
1876
|
+
a = a[:, :Hs, :Ws].reshape(C, Hs//r, r, Ws//r, r).mean(axis=(2,4))
|
|
1877
|
+
return a
|
|
1878
|
+
|
|
1879
|
+
def _upsample_sum(img, r: int, target_hw: tuple[int,int] | None = None):
|
|
1880
|
+
"""Adjoint of average-pooling: replicate-sum each pixel into an r×r block.
|
|
1881
|
+
For (H,W) or (C,H,W). If target_hw given, center-crop/pad to that size.
|
|
1882
|
+
"""
|
|
1883
|
+
if r <= 1:
|
|
1884
|
+
return img
|
|
1885
|
+
a = np.asarray(img, dtype=np.float32)
|
|
1886
|
+
if a.ndim == 2:
|
|
1887
|
+
H, W = a.shape
|
|
1888
|
+
out = np.kron(a, np.ones((r, r), dtype=np.float32))
|
|
1889
|
+
else:
|
|
1890
|
+
C, H, W = a.shape
|
|
1891
|
+
out = np.stack([np.kron(a[c], np.ones((r, r), dtype=np.float32)) for c in range(C)], axis=0)
|
|
1892
|
+
if target_hw is not None:
|
|
1893
|
+
Ht, Wt = target_hw
|
|
1894
|
+
out = _center_crop(out, Ht, Wt)
|
|
1895
|
+
return out
|
|
1896
|
+
|
|
1897
|
+
def _gaussian2d(ksize: int, sigma: float) -> np.ndarray:
    """
    Return a float32 isotropic Gaussian kernel normalized by (sum + EPS).

    The grid spans -r..r on both axes with r = (ksize - 1) // 2, so the
    kernel side is 2*r + 1.
    """
    half = (ksize - 1) // 2
    grid = np.mgrid[-half:half+1, -half:half+1].astype(np.float32)
    y, x = grid
    kernel = np.exp(-(x*x + y*y)/(2.0*sigma*sigma)).astype(np.float32)
    kernel /= kernel.sum() + EPS
    return kernel
|
|
1903
|
+
|
|
1904
|
+
def _conv2_same_np(a: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Thin wrapper over `_conv_same_np`: 2-D input is given a leading axis for the call and unwrapped afterwards."""
    if a.ndim == 2:
        return _conv_same_np(a[None], k)[0]
    return _conv_same_np(a, k)
|
|
1907
|
+
|
|
1908
|
+
def _solve_super_psf_from_native(f_native: np.ndarray, r: int, sigma: float = 1.1,
                                 iters: int = 500, lr: float = 0.1) -> np.ndarray:
    """
    Estimate a super-resolved PSF h from a native-resolution PSF.

    Solves  h* = argmin_h || f_native - (D(h) * g_sigma) ||_2^2,  where D is
    r×r average pooling and g_sigma comes from ``_gaussian2d``.

    The native PSF is first squeezed to 2-D, center-cropped to a square and
    trimmed to an odd side k; h then has shape (r*k, r*k).  When TORCH_OK is
    set, an Adam loop runs on ``_torch_device()``; if that loop raises, or
    Torch is unavailable, a projected gradient descent in NumPy is used.
    The result is returned through ``_normalize_psf``.

    Raises:
        ValueError: the PSF is not 2-D even after squeezing.
    """
    native = np.asarray(f_native, dtype=np.float32)

    # --- sanitize: must be 2-D ---
    if native.ndim != 2:
        native = np.squeeze(native)
        if native.ndim != 2:
            raise ValueError(f"PSF must be 2D, got shape {native.shape}")

    # --- sanitize: center-crop to a square ---
    n_rows, n_cols = int(native.shape[0]), int(native.shape[1])
    if n_rows != n_cols:
        side = min(n_rows, n_cols)
        top = (n_rows - side) // 2
        left = (n_cols - side) // 2
        native = native[top:top + side, left:left + side]
        n_rows = side

    # --- sanitize: odd side (SAME padding math needs it); drop first row/col ---
    if n_rows % 2 == 0:
        native = native[1:, 1:]
        n_rows = native.shape[0]

    k = int(n_rows)
    kr = int(k * r)

    gauss = _gaussian2d(k, max(sigma, 1e-3)).astype(np.float32)

    # Initial guess: native samples placed on every r-th super-resolution pixel.
    h_init = np.zeros((kr, kr), dtype=np.float32)
    h_init[::r, ::r] = native
    h_init = _normalize_psf(h_init)

    h = None  # stays None when Torch is skipped or its loop fails
    if TORCH_OK:
        import torch.nn.functional as F
        dev = _torch_device()

        # Guarantee an odd Gaussian so that padding = size // 2 keeps the shape.
        gauss_odd = gauss
        if (gauss.shape[-1] % 2) == 0:
            padded = _pad_kernel_to(gauss, gauss.shape[-1] + 1)
            gauss_odd = padded.astype(np.float32, copy=False)

        h_t = torch.tensor(h_init, device=dev, dtype=torch.float32, requires_grad=True)
        native_t = torch.tensor(native, device=dev, dtype=torch.float32)
        gauss_t = torch.tensor(gauss_odd, device=dev, dtype=torch.float32)
        optimizer = torch.optim.Adam([h_t], lr=lr)

        # Shape sanity check before pooling.
        sr_rows, sr_cols = h_t.shape
        assert (sr_rows % r) == 0 and (sr_cols % r) == 0, f"h shape {h_t.shape} not divisible by r={r}"
        pooled_rows, pooled_cols = sr_rows // r, sr_cols // r

        try:
            blur_pad = gauss_t.shape[-1] // 2
            for _ in range(max(10, iters)):
                optimizer.zero_grad(set_to_none=True)

                # Downsample with avg_pool2d on a contiguous crop.
                crop = h_t.narrow(0, 0, pooled_rows * r).narrow(1, 0, pooled_cols * r).contiguous()
                pooled = F.avg_pool2d(crop[None, None], kernel_size=r, stride=r)[0, 0]  # (k,k)

                # Native-space blur, then mean-squared residual.
                blurred = F.conv2d(pooled[None, None], gauss_t[None, None], padding=blur_pad)[0, 0]
                loss = torch.mean((blurred - native_t) ** 2)
                loss.backward()
                optimizer.step()

                # Project back onto non-negative, unit-sum kernels.
                with torch.no_grad():
                    h_t.clamp_(min=0.0)
                    h_t /= (h_t.sum() + 1e-8)
            h = h_t.detach().cpu().numpy().astype(np.float32)
        except Exception:
            # Safety net: if the backend balks, use the NumPy solver for this kernel.
            h = None

    if h is None:
        # NumPy fallback: projected gradient descent with a decaying step.
        h = h_init.copy()
        step = float(lr)
        gauss_flipped = np.flip(np.flip(gauss, 0), 1)
        for _ in range(max(50, iters)):
            pooled = _downsample_avg(h, r)
            residual = _conv2_same_np(pooled, gauss) - native
            grad_pooled = _conv2_same_np(residual, gauss_flipped)
            grad_h = _upsample_sum(grad_pooled, r, target_hw=h.shape)
            h = np.clip(h - step * grad_h, 0.0, None)
            h /= (float(h.sum()) + 1e-8)
            step *= 0.995

    return _normalize_psf(h)
|
|
2007
|
+
|
|
2008
|
+
|
|
2009
|
+
def _downsample_avg_t(x, r: int):
    """Average-pool a tensor over non-overlapping r×r blocks.

    ``x`` is (H,W) or (C,H,W).  Trailing rows/columns that do not fill a whole
    block are dropped.  ``x`` itself is returned when ``r <= 1`` or when either
    spatial side is smaller than ``r``.
    """
    if r <= 1:
        return x

    if x.ndim == 2:
        rows, cols = x.shape
        lead = ()
    else:
        n_ch, rows, cols = x.shape  # fails for anything that is not 3-D
        lead = (n_ch,)

    keep_rows = (rows // r) * r
    keep_cols = (cols // r) * r
    if keep_rows == 0 or keep_cols == 0:
        return x

    trimmed = x[..., :keep_rows, :keep_cols]
    # reshape (not view): the cropped slice may be non-contiguous
    blocks = trimmed.reshape(*lead, keep_rows // r, r, keep_cols // r, r)
    return blocks.mean(dim=(-3, -1))
|
|
2028
|
+
|
|
2029
|
+
def _upsample_sum_t(x, r: int):
    """Repeat every element of the last two axes ``r`` times (tensor version).

    ``r <= 1`` returns ``x`` unchanged.  Negative axis indices cover both the
    (H,W) and the (...,H,W) case with one expression.
    """
    if r <= 1:
        return x
    stretched_rows = x.repeat_interleave(r, dim=-2)
    return stretched_rows.repeat_interleave(r, dim=-1)
|
|
2036
|
+
|
|
2037
|
+
def _sep_bg_rms(frames):
    """Estimate a background RMS from the first frame via ``sep.Background``.

    Returns None when ``sep`` is unavailable, ``frames`` is empty, or anything
    in the estimation raises.  For a non-2-D first frame its first plane is
    used.  ``globalrms`` is preferred; if reading it fails, the median of the
    ``rms()`` map is returned instead.
    """
    if sep is None or not frames:
        return None
    try:
        first = frames[0]
        plane = first if first.ndim == 2 else first[0]  # luma / first channel
        data = np.ascontiguousarray(plane, dtype=np.float32)
        background = sep.Background(data, bw=64, bh=64, fw=3, fh=3)
        try:
            return float(background.globalrms)
        except Exception:
            # some SEP builds don't expose globalrms; use the RMS map's median
            rms_map = np.asarray(background.rms(), dtype=np.float32)
            return float(np.median(rms_map))
    except Exception:
        return None
|
|
2053
|
+
|
|
2054
|
+
# =========================
|
|
2055
|
+
# Memory/streaming helpers
|
|
2056
|
+
# =========================
|
|
2057
|
+
|
|
2058
|
+
def _approx_bytes(arr_like_shape, dtype=np.float32):
    """Rough byte estimator for a given shape/dtype.

    Args:
        arr_like_shape: iterable of dimension sizes (a scalar is accepted too).
        dtype: anything ``np.dtype`` understands; float32 by default.

    Returns:
        int: element count times the dtype's item size.

    The element count is multiplied out with Python integers rather than
    ``np.prod``: ``np.prod`` accumulates in the platform's default integer
    type and wraps around silently for very large shapes.
    """
    count = 1
    for dim in np.asarray(arr_like_shape).ravel().tolist():
        count *= dim
    return int(count) * np.dtype(dtype).itemsize
|
|
2061
|
+
|
|
2062
|
+
def _mem_model(
    grid_hw: tuple[int,int],
    r: int,
    ksize: int,
    channels: int,
    mem_target_mb: int,
    prefer_tiles: bool = False,
    min_tile: int = 256,
    max_tile: int = 2048,
) -> dict:
    """
    Pick a batch size (#frames) and optional tile size (HxW) given a memory budget.
    Very conservative — aims to bound peak working-set on CPU/GPU.

    Args:
        grid_hw: (height, width) of the working grid.
        r: super-resolution factor; scales the halo when > 1.
        ksize: PSF kernel side; the halo is ``(ksize // 2) * max(1, r)``.
        channels: channel count (values below 1 are treated as 1).
        mem_target_mb: memory budget in MiB.
        prefer_tiles: force tile mode even if full frames would fit.
        min_tile / max_tile: bounds for the square tile side.

    Returns:
        dict with keys ``batch_frames`` (int), ``tiles`` (None for full-frame
        mode, else ``(t, t)``), ``halo`` (int) and ``ksize`` (int).

    Note:
        The frame counts are no longer clamped to 1 before being tested, so a
        budget too small for one full frame now really selects tile mode and
        the tile side really shrinks until a tile fits.  Only when nothing
        fits is the worst case (1 frame, minimal tile) returned.
    """
    Hs, Ws = grid_hw
    halo = (ksize // 2) * max(1, r)  # SR grid halo if r>1
    C = max(1, channels)
    min_tile = max(1, int(min_tile))  # keeps the halving loop finite

    # float32 working set: 3 scratch buffers per frame + 2 global accumulators
    frame_bytes = _approx_bytes((C, Hs, Ws))
    per_frame_scratch = 3 * frame_bytes   # tmp/pred + in/out buffers
    global_accum = 2 * frame_bytes        # num + den

    budget = int(mem_target_mb * 1024 * 1024)
    free = budget - global_accum          # what is left for per-frame scratch

    # Full-frame mode whenever at least one frame fits (unless tiles preferred).
    B_full = free // max(per_frame_scratch, 1)
    if not prefer_tiles and B_full >= 1:
        return dict(batch_frames=int(B_full), tiles=None, halo=int(halo), ksize=int(ksize))

    # Tile mode: start from the largest power of two not exceeding the shorter
    # side (clamped to [min_tile, max_tile]) and halve until a tile fits.
    shortest = max(1, int(min(Hs, Ws)))   # avoids log2(0)
    t = int(min(max_tile, max(min_tile, 1 << int(np.floor(np.log2(shortest))))))
    while t >= min_tile:
        th = t + 2 * halo
        per_tile = 3 * _approx_bytes((C, th, th))
        B_tile = free // max(per_tile, 1)
        if B_tile >= 1:
            return dict(batch_frames=int(B_tile), tiles=(t, t), halo=int(halo), ksize=int(ksize))
        t //= 2

    # Worst case: 1 frame, minimal tile
    return dict(batch_frames=1, tiles=(min_tile, min_tile), halo=int(halo), ksize=int(ksize))
|
|
2107
|
+
|
|
2108
|
+
def _build_seed_running_mu_sigma_from_paths(
    paths, Ht, Wt, color_mode,
    *, bootstrap_frames=24, clip_sigma=3.5,  # clip_sigma used for streaming updates
    status_cb=lambda s: None, progress_cb=None
):
    """
    Build a seed image from frames on disk in three stages:
      1) load the first B frames and take their mean (mean0);
      2) per-pixel MAD around mean0 -> keep samples within 4·MAD -> masked mean
         (pixels with no surviving sample keep mean0);
      3) stream the remaining frames one by one with σ-clipped Welford updates.
    Frames are loaded through ``_stack_loader_memmap``.
    Returns the running mean clipped at 0 as float32, shaped like one loaded frame.
    """
    def _report(fraction, text):
        # progress is optional; fractions are clamped into [0, 1]
        if not progress_cb:
            return
        progress_cb(float(max(0.0, min(1.0, fraction))), text)

    def _load_frame(path):
        frames, _ = _stack_loader_memmap([path], Ht, Wt, color_mode)
        return frames[0].astype(np.float32, copy=False)

    n_total = len(paths)
    n_boot = int(max(1, min(int(bootstrap_frames), n_total)))
    status_cb(f"MFDeconv: Seed bootstrap {n_boot} frame(s) with ±4·MAD clip on the average…")
    _report(0.00, f"bootstrap load 0/{n_boot}")

    # ---------- stage 1: bootstrap frames ----------
    loaded = []
    for idx, path in enumerate(paths[:n_boot], start=1):
        loaded.append(_load_frame(path))
        if idx == n_boot or idx % 4 == 0:
            _report(0.25 * (idx / float(n_boot)), f"bootstrap load {idx}/{n_boot}")

    stack = np.stack(loaded, axis=0)  # (B,H,W) or (B,C,H,W)
    del loaded

    mean0 = np.mean(stack, axis=0, dtype=np.float32)
    _report(0.28, "bootstrap mean computed")

    # ---------- stage 2: ±4·MAD mask around mean0, one masked-mean pass ----------
    abs_dev = np.abs(stack - mean0[None, ...])
    mad = np.median(abs_dev, axis=0).astype(np.float32, copy=False)

    thr = 4.0 * mad + EPS
    mask = (abs_dev <= thr)

    weights = mask.astype(np.float32, copy=False)
    kept_sum = np.sum(stack * weights, axis=0, dtype=np.float32)
    kept_cnt = np.sum(weights, axis=0, dtype=np.float32)
    seed = mean0.copy()  # fallback value where every sample was rejected
    np.divide(kept_sum, np.maximum(kept_cnt, 1.0), out=seed, where=(kept_cnt > 0.5))
    _report(0.36, "±4·MAD masked mean computed")

    # ---------- Welford state: μ = seed, M2 from the bootstrap dispersion ----------
    spread = stack - seed[None, ...]
    M2 = np.sum(spread * spread, axis=0, dtype=np.float32)
    cnt = np.full_like(seed, float(n_boot), dtype=np.float32)
    mu = seed.astype(np.float32, copy=False)
    del stack, abs_dev, mad, weights, kept_sum, kept_cnt, spread

    _report(0.40, "seed initialized; streaming refinements…")

    # ---------- stage 3: σ-clipped streaming updates ----------
    remain = n_total - n_boot
    if remain > 0:
        status_cb(f"MFDeconv: Seed μ–σ clipping {remain} remaining frame(s) (k={clip_sigma:.2f})…")

        k = float(clip_sigma)
        for j, path in enumerate(paths[n_boot:], start=1):
            x = _load_frame(path)

            var = M2 / np.maximum(cnt - 1.0, 1.0)
            sigma = np.sqrt(np.maximum(var, 1e-12, dtype=np.float32))

            # per-pixel acceptance: 1.0 where the sample lies within k·σ of μ
            accept = (np.abs(x - mu) <= (k * sigma))
            acc = accept.astype(np.float32, copy=False)

            n_new = cnt + acc
            delta = x - mu
            mu_n = mu + (acc * delta) / np.maximum(n_new, 1.0)
            M2 = M2 + acc * delta * (x - mu_n)

            mu, cnt = mu_n, n_new

            if j == remain or j % 8 == 0:
                _report(0.40 + 0.60 * (j / float(remain)), f"μ–σ refine {j}/{remain}")

    _report(1.0, "seed ready")
    return np.clip(mu, 0.0, None).astype(np.float32, copy=False)
|
|
2198
|
+
|
|
2199
|
+
|
|
2200
|
+
def _chunk(seq, n):
    """Lazily yield consecutive slices of ``seq``, each at most ``n`` items long."""
    yield from (seq[start:start + n] for start in range(0, len(seq), n))
|
|
2204
|
+
|
|
2205
|
+
def _read_shape_fast(path) -> tuple[int,int,int]:
    """Return ``(channels, height, width)`` for an XISF or FITS image.

    XISF files go through ``_load_image_array``.  FITS files are opened
    memory-mapped and the first HDU that carries data is used, so files with a
    header-only primary HDU work here the same way they do in the tile
    readers.  Only the shape is read, and it is captured while the file is
    still open.

    Shape interpretation:
      * 2-D                      -> (1, H, W)
      * 3-D, last axis in (1, 3) -> HWC, returns (C, H, W)
      * 3-D, first axis in (1, 3)-> CHW, returned as-is
      * anything else            -> (1, shape[-2], shape[-1])

    Raises:
        ValueError: no image data found, or the data has fewer than 2 axes.
    """
    if _is_xisf(path):
        a, _ = _load_image_array(path)
        if a is None:
            raise ValueError(f"No data in {path}")
        shape = tuple(int(d) for d in np.asarray(a).shape)
    else:
        with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
            a = None
            for hdu in hdul:
                if getattr(hdu, "data", None) is not None:
                    a = hdu.data
                    break
            if a is None:
                raise ValueError(f"No data in {path}")
            shape = tuple(int(d) for d in a.shape)

    # common logic for both XISF and FITS
    if len(shape) < 2:
        raise ValueError(f"Unsupported image shape {shape} in {path}")
    if len(shape) == 2:
        return (1, shape[0], shape[1])
    if len(shape) == 3:
        if shape[-1] in (1, 3):   # HWC
            return (shape[-1], shape[0], shape[1])
        if shape[0] in (1, 3):    # CHW
            return (shape[0], shape[1], shape[2])
    return (1, shape[-2], shape[-1])
|
|
2230
|
+
|
|
2231
|
+
|
|
2232
|
+
|
|
2233
|
+
|
|
2234
|
+
def _tiles_of(hw: tuple[int,int], tile_hw: tuple[int,int], halo: int):
    """
    Walk the (H, W) grid in row-major order and yield one dict per tile with
    keys y0, y1, x0, x1 (outer box: core grown by ``halo`` and clamped to the
    image) and yc0, yc1, xc0, xc1 (core box without halo).
    Tile sides below 1 are raised to 1.
    """
    full_h, full_w = hw
    step_h, step_w = tile_hw
    step_h = max(1, int(step_h))
    step_w = max(1, int(step_w))

    for row in range(0, full_h, step_h):
        row_end = min(row + step_h, full_h)
        outer_top = max(0, row - halo)
        outer_bottom = min(full_h, row_end + halo)
        for col in range(0, full_w, step_w):
            col_end = min(col + step_w, full_w)
            yield {
                "y0": outer_top,
                "y1": outer_bottom,
                "x0": max(0, col - halo),
                "x1": min(full_w, col_end + halo),
                "yc0": row,
                "yc1": row_end,
                "xc0": col,
                "xc1": col_end,
            }
|
|
2249
|
+
|
|
2250
|
+
def _extract_with_halo(a, tile):
    """
    Return the outer (halo-inclusive) window ``[y0:y1, x0:x1]`` of ``a``.
    For a 2-D array the two axes are sliced directly; otherwise the leading
    axis is kept whole and axes 1 and 2 are sliced.
    """
    rows = slice(tile["y0"], tile["y1"])
    cols = slice(tile["x0"], tile["x1"])
    if a.ndim == 2:
        return a[rows, cols]
    return a[:, rows, cols]
|
|
2259
|
+
|
|
2260
|
+
def _add_core(accum, tile_val, tile):
    """
    Accumulate (in place) the core part of ``tile_val`` into ``accum``.

    ``tile_val`` covers the tile's outer box, so the core sits at an offset of
    (yc0 - y0, xc0 - x0) inside it.  Works for 2-D arrays and for arrays with
    one leading axis.
    """
    yc0, yc1, xc0, xc1 = tile["yc0"], tile["yc1"], tile["xc0"], tile["xc1"]

    # position of the core inside the halo-padded tile
    row_off = yc0 - tile["y0"]
    col_off = xc0 - tile["x0"]
    src_rows = slice(row_off, row_off + (yc1 - yc0))
    src_cols = slice(col_off, col_off + (xc1 - xc0))

    if accum.ndim == 2:
        accum[yc0:yc1, xc0:xc1] += tile_val[src_rows, src_cols]
    else:
        accum[:, yc0:yc1, xc0:xc1] += tile_val[:, src_rows, src_cols]
|
|
2274
|
+
|
|
2275
|
+
def _prepare_np_fft_packs_batch(psfs, flip_psf, Hs, Ws):
    """Precompute rFFTs of each PSF and its flipped partner for the NumPy path.

    For every (psf, flipped) pair the FFT size comes from
    ``_fftshape_same(Hs, Ws, kh, kw)``; both kernels are ifftshift-ed before
    the transform.  Returns three lists aligned with ``psfs``: forward packs,
    flipped packs, and ``(kh, kw, fftH, fftW)`` tuples.
    """
    forward_packs, flipped_packs, meta = [], [], []
    for psf, psf_flipped in zip(psfs, flip_psf):
        kh, kw = psf.shape
        fftH, fftW = _fftshape_same(Hs, Ws, kh, kw)
        fft_size = (fftH, fftW)
        forward_packs.append(np.fft.rfftn(np.fft.ifftshift(psf), s=fft_size))
        flipped_packs.append(np.fft.rfftn(np.fft.ifftshift(psf_flipped), s=fft_size))
        meta.append((kh, kw, fftH, fftW))
    return forward_packs, flipped_packs, meta
|
|
2286
|
+
|
|
2287
|
+
def _prepare_torch_fft_packs_batch(psfs, flip_psf, Hs, Ws, device, dtype):
    """Torch counterpart of the NumPy packer: forwards to ``_precompute_torch_psf_ffts``."""
    packs = _precompute_torch_psf_ffts(psfs, flip_psf, Hs, Ws, device, dtype)
    return packs
|
|
2290
|
+
|
|
2291
|
+
def _as_chw(np_img: np.ndarray) -> np.ndarray:
    """Convert an image to float32 with the channel axis first.

    Layout rules, in order:
      * 2-D                       -> (1, H, W)
      * 3-D, first axis in (1, 3) -> already CHW, returned as-is
      * 3-D, last axis in (1, 3)  -> HWC, channel axis moved to the front
      * anything else             -> returned unchanged (first axis is taken
        to be the channel axis)

    Raises:
        RuntimeError: the input is empty or is a 0-d scalar.

    The former "zero channels" checks were removed: they could never fire,
    because an array with a zero-length axis is already rejected as empty.
    """
    x = np.asarray(np_img, dtype=np.float32, order="C")
    if x.size == 0:
        # np.shape() also works when the caller passed a list instead of an array
        raise RuntimeError(f"Empty image array after load; raw shape={np.shape(np_img)}")
    if x.ndim == 0:
        raise RuntimeError(f"Scalar input cannot be interpreted as an image; shape={x.shape}")
    if x.ndim == 2:
        return x[None, ...]  # 1,H,W
    if x.ndim == 3:
        if x.shape[0] in (1, 3):
            return x
        if x.shape[-1] in (1, 3):
            return np.moveaxis(x, -1, 0)
    # last resort: treat first dim as channels
    return x
|
|
2309
|
+
|
|
2310
|
+
|
|
2311
|
+
|
|
2312
|
+
def _conv_same_np_spatial(a: np.ndarray, k: np.ndarray, out: np.ndarray | None = None):
|
|
2313
|
+
try:
|
|
2314
|
+
import cv2
|
|
2315
|
+
except Exception:
|
|
2316
|
+
return None # no opencv -> caller falls back to FFT
|
|
2317
|
+
|
|
2318
|
+
# cv2 wants HxW single-channel float32
|
|
2319
|
+
kf = np.ascontiguousarray(k.astype(np.float32))
|
|
2320
|
+
kf = np.flip(np.flip(kf, 0), 1) # OpenCV uses correlation; flip to emulate conv
|
|
2321
|
+
|
|
2322
|
+
if a.ndim == 2:
|
|
2323
|
+
y = cv2.filter2D(a, -1, kf, borderType=cv2.BORDER_REFLECT)
|
|
2324
|
+
if out is None: return y
|
|
2325
|
+
out[...] = y; return out
|
|
2326
|
+
else:
|
|
2327
|
+
C, H, W = a.shape
|
|
2328
|
+
if out is None:
|
|
2329
|
+
out = np.empty_like(a)
|
|
2330
|
+
for c in range(C):
|
|
2331
|
+
out[c] = cv2.filter2D(a[c], -1, kf, borderType=cv2.BORDER_REFLECT)
|
|
2332
|
+
return out
|
|
2333
|
+
|
|
2334
|
+
def _grouped_conv_same_torch_per_sample(x_bc_hw, w_b1kk, B, C):
    """Filter each sample of a (B,C,H,W) batch with its own (1,kh,kw) kernel.

    The batch is folded into a single image with ``B*C`` channels, reflect
    padded so the output keeps H×W, and run through one grouped ``conv2d``
    (one group per channel).  Each sample's kernel is repeated ``C`` times so
    that all channels of that sample share it.  Returns a contiguous
    (B,C,H,W) tensor.
    """
    nn_f = torch.nn.functional
    images = x_bc_hw.to(memory_format=torch.contiguous_format).contiguous()
    kernels = w_b1kk.to(memory_format=torch.contiguous_format).contiguous()

    kh = int(kernels.shape[-2])
    kw = int(kernels.shape[-1])
    pad_left, pad_top = kw // 2, kh // 2
    # (left, right, top, bottom) — asymmetric for even kernel sides
    pad = (pad_left, kw - pad_left - 1, pad_top, kh - pad_top - 1)

    # unified path (CUDA/CPU/MPS): one grouped conv with G=B*C
    n_groups = int(B * C)
    folded = images.reshape(1, n_groups, images.shape[-2], images.shape[-1])
    folded = nn_f.pad(folded, pad, mode="reflect")
    per_group_kernels = kernels.repeat_interleave(C, dim=0)  # (G,1,kh,kw)
    result = nn_f.conv2d(folded, per_group_kernels, padding=0, groups=n_groups)
    return result.reshape(B, C, result.shape[-2], result.shape[-1]).contiguous()
|
|
2349
|
+
|
|
2350
|
+
|
|
2351
|
+
# put near other small helpers
|
|
2352
|
+
def _robust_med_mad_t(x, max_elems_per_sample: int = 2_000_000):
    """
    Per-sample median and MAD of a batched tensor.

    x: tensor whose first axis is the batch; the rest is flattened per sample.
    Returns (median, mad), each viewed as (B,1,1,1); 1e-6 is added to the MAD.
    Samples longer than ``max_elems_per_sample`` are subsampled with a regular
    stride first, to avoid 'quantile() input tensor is too large'.
    """
    import math
    import torch

    n_batch = x.shape[0]
    samples = x.reshape(n_batch, -1)
    per_sample = samples.shape[1]
    if per_sample > max_elems_per_sample:
        step = int(math.ceil(per_sample / float(max_elems_per_sample)))
        samples = samples[:, ::step]  # strided subsample

    med = torch.quantile(samples, 0.5, dim=1, keepdim=True)
    abs_dev = (samples - med).abs()
    mad = torch.quantile(abs_dev, 0.5, dim=1, keepdim=True) + 1e-6
    return med.view(n_batch, 1, 1, 1), mad.view(n_batch, 1, 1, 1)
|
|
2369
|
+
|
|
2370
|
+
def _torch_should_use_spatial(psf_ksize: int) -> bool:
    """Decide whether the spatial (non-FFT) Torch convolution should be used.

    * device type "mps" or "privateuseone" (DirectML): always True
    * device type "cuda": True for kernels up to 51 pixels, else False
    * any other device, or a failure while querying it: True only when the
      environment variable ``MF_SPATIAL`` equals "1"
    """
    import os as _os

    try:
        device_kind = _torch_device().type
        if device_kind in ("mps", "privateuseone"):  # privateuseone = DirectML
            return True
        if device_kind == "cuda":
            return psf_ksize <= 51  # typical PSF sizes; spatial is fast & stable
    except Exception:
        pass

    # env override, consulted only when the device branch did not decide
    return _os.environ.get("MF_SPATIAL", "") == "1"
|
|
2385
|
+
|
|
2386
|
+
def _read_tile_fits(path: str, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
    """Return a (H,W) or (H,W,3|1) tile via FITS memmap, without loading whole image.

    The first HDU that carries data is used (the primary one when it has
    data).  A trailing singleton axis is squeezed away, then the first two
    axes are sliced to ``[y0:y1, x0:x1]`` and the result is copied so the
    caller owns the buffer.

    NOTE(review): the slicing treats axes 0 and 1 as rows/columns, i.e. it
    assumes HW or HWC data; channel-first arrays are not handled here
    (``_read_tile_fits_any`` covers them).

    Raises:
        ValueError: no HDU in the file contains data.  Previously this case
        fell through and failed later with an unrelated indexing error.
    """
    with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
        a = None
        for hdu in hdul:
            if getattr(hdu, "data", None) is not None:
                a = hdu.data
                break
        if a is None:
            raise ValueError(f"No image data in {path}")

        a = np.asarray(a)  # still lazy until sliced
        # squeeze trailing singleton if present to keep the file's conventions
        if a.ndim == 3 and a.shape[-1] == 1:
            a = np.squeeze(a, axis=-1)
        # copy so we own the buffer (callers cast/normalize it)
        return np.array(a[y0:y1, x0:x1, ...], copy=True)
|
|
2403
|
+
|
|
2404
|
+
def _read_tile_fits_any(path: str, y0: int, y1: int, x0: int, x1: int) -> np.ndarray:
    """FITS/XISF-aware tile read: returns an owned copy of the spatial window.

    Supports 2-D, HWC and CHW data; a single-channel axis is dropped from the
    result.  XISF files (by extension) come from ``_xisf_cached_array`` and
    test the HWC layout first; FITS files use the first HDU with data, test
    the CHW layout first, and finally fall back to slicing the last two axes.

    Raises:
        ValueError: unsupported XISF layout, no FITS image data, or a FITS
        shape that cannot be sliced.
    """
    rows, cols = slice(y0, y1), slice(x0, x1)

    def _owned(view):
        return np.array(view, copy=True)

    def _cut_hwc(arr):
        window = arr[rows, cols, :]
        if window.shape[-1] == 1:
            window = window[..., 0]
        return _owned(window)

    def _cut_chw(arr):
        window = arr[:, rows, cols]
        if window.shape[0] == 1:
            window = window[0]
        return _owned(window)

    extension = os.path.splitext(path)[1].lower()

    if extension == ".xisf":
        a = _xisf_cached_array(path)  # float32 memmap; cheap slicing
        # a is HW, HWC, or CHW (whatever the loader returned)
        if a.ndim == 2:
            return _owned(a[rows, cols])
        if a.ndim != 3:
            raise ValueError(f"Unsupported XISF ndim {a.ndim} in {path}")
        if a.shape[-1] in (1, 3):   # HWC
            return _cut_hwc(a)
        if a.shape[0] in (1, 3):    # CHW
            return _cut_chw(a)
        raise ValueError(f"Unsupported XISF 3D shape {a.shape} in {path}")

    # FITS
    with fits.open(path, memmap=True, ignore_missing_simple=True) as hdul:
        a = next((h.data for h in hdul if getattr(h, "data", None) is not None), None)
        if a is None:
            raise ValueError(f"No image data in {path}")

        a = np.asarray(a)

        if a.ndim == 2:  # HW
            return _owned(a[rows, cols])

        if a.ndim == 3:
            if a.shape[0] in (1, 3):   # CHW (planes, rows, cols)
                return _cut_chw(a)
            if a.shape[-1] in (1, 3):  # HWC
                return _cut_hwc(a)

        # Fallback: assume last two axes are spatial (…, H, W)
        try:
            return _owned(a[(..., rows, cols)])
        except Exception:
            raise ValueError(f"Unsupported FITS data shape {a.shape} in {path}")
|
|
2460
|
+
|
|
2461
|
+
def _infer_channels_from_tile(p: str, Ht: int, Wt: int) -> int:
    """Infer the channel count (1 or 3) from a 1×1 tile of the file.

    A 2-D tile means mono.  For a 3-D tile the answer is 3 as soon as the
    first or the last axis has length 3, otherwise 1 — this covers CHW, HWC
    and the ambiguous tiny-tile shapes (3×1×1, 1×1×3) alike.  Any other rank
    is reported as 1.
    """
    tile = _read_tile_fits_any(p, 0, min(1, Ht), 0, min(1, Wt))

    if tile.ndim != 3:
        return 1
    return 3 if 3 in (tile.shape[0], tile.shape[-1]) else 1
|
|
2485
|
+
|
|
2486
|
+
|
|
2487
|
+
|
|
2488
|
+
def _seed_median_streaming(
    paths,
    Ht,
    Wt,
    *,
    color_mode="luma",
    tile_hw=(256, 256),
    status_cb=lambda s: None,
    progress_cb=lambda f, m="": None,
    use_torch: bool | None = None,  # auto by default
):
    """
    Exact per-pixel median across all frames, computed tile by tile so that
    only one (n_frames, h, w) slab is held at a time.

    * ``color_mode == "luma"`` gives a (Ht, Wt) seed; otherwise the channel
      count is inferred from the first file and the seed is (C, Ht, Wt), or
      (Ht, Wt) when that count is 1.
    * Tiles of every frame are read in parallel worker threads.
    * Unless ``use_torch`` is False, a CUDA or MPS device is used for the
      median when available; a GPU allocation error on a tile falls back to
      NumPy for that tile only.
    * ``progress_cb`` gets per-tile progress (80 % of a tile's span while
      loading, the rest on completion); ``status_cb`` gets text messages.
    """
    tile_h, tile_w = int(tile_hw[0]), int(tile_hw[1])
    if str(color_mode).lower() == "luma":
        want_c = 1
    else:
        want_c = _infer_channels_from_tile(paths[0], Ht, Wt)

    seed_shape = (Ht, Wt) if want_c == 1 else (want_c, Ht, Wt)
    seed = np.zeros(seed_shape, np.float32)

    tiles = []
    for top in range(0, Ht, tile_h):
        for left in range(0, Wt, tile_w):
            tiles.append((top, min(top + tile_h, Ht), left, min(left + tile_w, Wt)))
    total = len(tiles)
    n_frames = len(paths)

    # Choose a sensible number of I/O workers (bounded by frames)
    try:
        cpu_count = (os.cpu_count() or 4)
    except Exception:
        cpu_count = 4
    io_workers = max(1, min(8, cpu_count, n_frames))

    # Torch autodetect (once); CPU tensors are not used for the median
    gpu_ok = False
    device = None
    if use_torch is not False:
        try:
            from setiastro.saspro.runtime_torch import import_torch
            torch_mod = import_torch(prefer_cuda=True, status_cb=status_cb)
            if hasattr(torch_mod, "cuda") and torch_mod.cuda.is_available():
                device = torch_mod.device("cuda")
            elif hasattr(torch_mod.backends, "mps") and torch_mod.backends.mps.is_available():
                device = torch_mod.device("mps")
            if device is not None:
                gpu_ok = True
                status_cb(f"Median seed: using Torch device {device}")
        except Exception as e:
            status_cb(f"Median seed: Torch unavailable → NumPy fallback ({e})")
            gpu_ok = False
            device = None

    def _tile_msg(ti, tn):
        return f"median tiles {ti}/{tn}"

    def _load_one(idx, box, csel):
        """Worker: read one frame's tile, convert to float32, reduce to 2-D."""
        ty0, ty1, tx0, tx1 = box
        t = _read_tile_fits_any(paths[idx], ty0, ty1, tx0, tx1)

        # integer data is scaled by the dtype's maximum
        if t.dtype.kind in "ui":
            t = t.astype(np.float32) / (float(np.iinfo(t.dtype).max) or 1.0)
        else:
            t = t.astype(np.float32, copy=False)

        # luma / channel selection
        if want_c == 1:
            if t.ndim != 2:
                t = _to_luma_local(t)
        elif t.ndim == 2:
            pass
        elif t.ndim == 3 and t.shape[-1] == 3:  # HWC
            t = t[..., csel]
        elif t.ndim == 3 and t.shape[0] == 3:   # CHW
            t = t[csel]
        else:
            t = _to_luma_local(t)
        return idx, np.ascontiguousarray(t, dtype=np.float32)

    def _read_slab(box, tiles_done, csel=None):
        """Load the tile ``box`` from every frame into an (N, h, w) float32 slab."""
        ty0, ty1, tx0, tx1 = box
        hh, ww = (ty1 - ty0), (tx1 - tx0)
        slab = np.empty((n_frames, hh, ww), np.float32)
        loaded = 0
        # cap workers by n_frames so we don't spawn useless threads
        with ThreadPoolExecutor(max_workers=min(io_workers, n_frames)) as ex:
            futures = [ex.submit(_load_one, i, box, csel) for i in range(n_frames)]
            for fut in as_completed(futures):
                i, t2d = fut.result()
                # quick sanity (avoid silent mis-shapes)
                if t2d.shape != (hh, ww):
                    raise RuntimeError(
                        f"Tile read mismatch at frame {i}: got {t2d.shape}, expected {(hh, ww)} "
                        f"tile={(ty0,ty1,tx0,tx1)}"
                    )
                slab[i] = t2d
                loaded += 1
                if (loaded & 7) == 0 or loaded == n_frames:
                    tile_base = tiles_done / total
                    tile_span = 1.0 / total
                    inner = loaded / n_frames
                    progress_cb(tile_base + 0.8 * tile_span * inner, _tile_msg(tiles_done + 1, total))
        return slab

    def _median_cpu(slab):
        return np.median(slab, axis=0).astype(np.float32, copy=False)

    def _median(slab):
        if not gpu_ok:
            return _median_cpu(slab)
        import torch as _t
        slab_t = _t.as_tensor(slab, device=device, dtype=_t.float32)  # one H2D
        med_t = slab_t.median(dim=0).values
        # no per-tile empty_cache() — avoid forced syncs
        return med_t.detach().cpu().numpy().astype(np.float32, copy=False)

    done = 0

    for box in tiles:
        y0, y1, x0, x1 = box
        h, w = (y1 - y0), (x1 - x0)

        try:
            if want_c == 1:
                t0 = time.perf_counter()
                slab = _read_slab(box, done)
                t1 = time.perf_counter()
                med_np = _median(slab)
                t2 = time.perf_counter()
                seed[y0:y1, x0:x1] = med_np
                # lightweight telemetry to confirm bottleneck
                status_cb(f"seed tile {y0}:{y1},{x0}:{x1} I/O={t1-t0:.3f}s median={'GPU' if gpu_ok else 'CPU'}={t2-t1:.3f}s")
            else:
                for c in range(want_c):
                    slab = _read_slab(box, done, csel=c)
                    seed[c, y0:y1, x0:x1] = _median(slab)

        except RuntimeError as e:
            # Per-tile GPU OOM or device issues → fallback to NumPy for this tile
            msg = str(e).lower()
            looks_like_oom = ("out of memory" in msg or "resource" in msg or "alloc" in msg)
            if not (gpu_ok and looks_like_oom):
                raise
            status_cb(f"Median seed: GPU OOM on tile ({h}x{w}); falling back to NumPy for this tile.")
            if want_c == 1:
                slab = _read_slab(box, done)
                seed[y0:y1, x0:x1] = _median_cpu(slab)
            else:
                for c in range(want_c):
                    slab = _read_slab(box, done, csel=c)
                    seed[c, y0:y1, x0:x1] = _median_cpu(slab)

        done += 1
        # report tile completion (remaining 20% of tile span reserved for median compute)
        progress_cb(done / total, _tile_msg(done, total))

        if (done & 3) == 0:
            _process_gui_events_safely()

    status_cb(f"Median seed: want_c={want_c}, seed_shape={seed.shape}")
    return seed
|
|
2661
|
+
|
|
2662
|
+
def _seed_bootstrap_streaming(paths, Ht, Wt, color_mode,
                              bootstrap_frames: int = 20,
                              clip_sigma: float = 5.0,
                              status_cb=lambda s: None,
                              progress_cb=None):
    """
    Build a robust seed image while holding only one decoded frame at a time.

    Passes:
      1. running mean μ0 over the first B frames;
      2. one global (scalar) MAD of |x - μ0|, estimated from strided samples of those frames;
      3. masked mean over the first B frames (pixels within 4·MAD of μ0), collecting the
         per-pixel accepted count and sum of squared deviations alongside;
      4. σ-clipped Welford μ–σ updates over the remaining frames, started from the
         statistics gathered in pass 3.

    Args:
        paths: sequence of frame paths; each is loaded via _stack_loader_memmap.
        Ht, Wt: target height/width, forwarded to _stack_loader_memmap.
        color_mode: forwarded to _stack_loader_memmap.
        bootstrap_frames: number of leading frames B (clamped to [1, len(paths)]).
        clip_sigma: half-width of the acceptance band in pass 4, in units of σ.
        status_cb: callable receiving a status string.
        progress_cb: optional callable(frac, msg); frac is clamped to [0, 1].

    Returns:
        float32 array shaped like one loaded frame, clipped to be non-negative.
        NOTE(review): the original docstring states CHW for per-channel mode and (1, H, W)
        for luma; that layout comes from _stack_loader_memmap — confirm there.

    Raises:
        ValueError: if paths is empty.
    """
    def p(frac, msg):
        if progress_cb:
            progress_cb(float(max(0.0, min(1.0, frac))), msg)

    def _load(pth):
        # Decode a single frame as float32; callers never keep more than one alive.
        ys, _ = _stack_loader_memmap([pth], Ht, Wt, color_mode)
        return ys[0].astype(np.float32, copy=False)

    n = len(paths)
    if n == 0:
        # Previously this surfaced later as AttributeError on `None.ndim`.
        raise ValueError("_seed_bootstrap_streaming: no frames supplied")
    B = int(max(1, min(int(bootstrap_frames), n)))
    status_cb(f"Seed: bootstrap={B}, clip_sigma={clip_sigma}")
    head = paths[:B]

    # ---------- pass 1: running mean over the first B frames (Welford μ only) ----------
    mu = None
    for i, pth in enumerate(head, 1):
        x = _load(pth)
        if mu is None:
            mu = x.copy()
        else:
            # every pixel has seen exactly i samples, so a scalar count suffices
            mu += (x - mu) / float(i)
        if (i == B) or (i % 4) == 0:
            p(0.20 * (i / float(B)), f"bootstrap mean {i}/{B}")

    # ---------- pass 2: estimate a *global* MAD around μ using strided samples ----------
    stride = 4 if max(Ht, Wt) > 1500 else 2
    mad_samples = []
    for i, pth in enumerate(head, 1):
        d = np.abs(_load(pth) - mu)
        # Ellipsis keeps any leading channel axis and strides only the last two axes
        mad_samples.append(d[..., ::stride, ::stride].ravel())
        if (i == B) or (i % 8) == 0:
            p(0.35 + 0.10 * (i / float(B)), f"bootstrap MAD {i}/{B}")

    # robust MAD estimate (scalar) → 4·MAD clip band
    mad_est = float(np.median(np.concatenate(mad_samples).astype(np.float32)))
    del mad_samples  # release the sample buffers before allocating the accumulators
    thr = 4.0 * max(mad_est, 1e-6)

    # ---------- pass 3: masked mean over first B frames using the global threshold ----------
    sum_acc = np.zeros_like(mu, dtype=np.float32)
    cnt_acc = np.zeros_like(mu, dtype=np.float32)
    sq_acc = np.zeros_like(mu, dtype=np.float32)   # Σ m·(x - μ0)², needed to seed pass 4
    for i, pth in enumerate(head, 1):
        x = _load(pth)
        d = x - mu
        m = (np.abs(d) <= thr).astype(np.float32, copy=False)
        sum_acc += x * m
        cnt_acc += m
        sq_acc += (d * d) * m
        if (i == B) or (i % 4) == 0:
            p(0.48 + 0.12 * (i / float(B)), f"masked mean {i}/{B}")

    # pixels where every bootstrap sample was rejected keep the plain mean μ0
    seed = mu.copy()
    np.divide(sum_acc, np.maximum(cnt_acc, 1.0), out=seed, where=(cnt_acc > 0.5))

    # ---------- pass 4: μ–σ streaming on the remaining frames (σ-clipped Welford) ----------
    # Start Welford from the bootstrap statistics. Starting from M2 = 0 pins σ at the
    # 1e-6 floor, so the k·σ band rejects practically every later sample and the
    # remaining frames never contribute.
    # Σ m·(x - seed)² = Σ m·(x - μ0)² - cnt·(seed - μ0)²   (exact where cnt > 0; 0 otherwise)
    shift = seed - mu
    M2 = np.maximum(sq_acc - cnt_acc * shift * shift, 0.0).astype(np.float32, copy=False)
    cnt = cnt_acc
    mu = seed.astype(np.float32, copy=False)

    remain = n - B
    k = float(clip_sigma)
    for j, pth in enumerate(paths[B:], 1):
        x = _load(pth)

        var = M2 / np.maximum(cnt - 1.0, 1.0)
        sigma = np.sqrt(np.maximum(var, 1e-12, dtype=np.float32))

        delta = x - mu
        acc = (np.abs(delta) <= (k * sigma)).astype(np.float32, copy=False)  # {0,1} mask
        n_new = cnt + acc
        mu_n = mu + (acc * delta) / np.maximum(n_new, 1.0)
        M2 = M2 + acc * delta * (x - mu_n)

        mu, cnt = mu_n, n_new

        if (j == remain) or (j % 8) == 0:
            p(0.60 + 0.40 * (j / float(max(1, remain))), f"μ–σ refine {j}/{remain}")

    return np.clip(mu, 0.0, None).astype(np.float32, copy=False)
|
|
2759
|
+
|
|
2760
|
+
|
|
2761
|
+
def _coerce_sr_factor(srf, *, default_on_bad=2):
    """
    Parse a super-resolution factor robustly.

    Accepts numbers (2, 2.0) and strings such as '2', '2x', ' 2 X ' or '3×'.
    The parsed value is rounded with Python's round() and must be >= 1.

    Args:
        srf: requested factor (int, float, string-like) or None.
        default_on_bad: returned as int(default_on_bad) whenever srf is None,
            cannot be parsed, is NaN/±inf, or rounds to less than 1.

    Returns:
        The integer factor (>= 1) for valid input, otherwise int(default_on_bad).
    """
    if srf is None:
        return int(default_on_bad)

    if isinstance(srf, (float, int)):
        raw = srf
    else:
        raw = str(srf).strip().lower()
        # common GUIs pass e.g. "2x", "3×", etc.
        raw = raw.replace("×", "x")
        if raw.endswith("x"):
            raw = raw[:-1]

    try:
        factor = int(round(float(raw)))
    except (TypeError, ValueError, OverflowError):
        # unparsable text, NaN (ValueError from round) or ±inf / huge ints (OverflowError);
        # numeric NaN/inf used to escape as an exception instead of yielding the default
        return int(default_on_bad)

    return int(factor if factor >= 1 else default_on_bad)
|
|
2783
|
+
|
|
2784
|
+
def _pad_kernel_to(k: np.ndarray, K: int) -> np.ndarray:
    """
    Center an odd-sized 2-D kernel on a zero-filled K×K canvas (K odd).

    A kernel that is already K×K is returned as float32 without renormalization.
    Otherwise the padded canvas is rescaled to unit sum, unless its sum is <= 0,
    in which case it is returned as-is.
    """
    kernel = np.asarray(k, dtype=np.float32)
    rows = int(kernel.shape[0])
    cols = int(kernel.shape[1])
    assert rows % 2 == 1 and cols % 2 == 1

    if (rows, cols) == (K, K):
        return kernel

    canvas = np.zeros((K, K), dtype=np.float32)
    top = (K - rows) // 2
    left = (K - cols) // 2
    canvas[top:top + rows, left:left + cols] = kernel

    total = float(canvas.sum())
    if total <= 0:
        return canvas
    return (canvas / total).astype(np.float32, copy=False)
|
|
2796
|
+
|
|
2797
|
+
# -----------------------------
|
|
2798
|
+
# Core
|
|
2799
|
+
# -----------------------------
|
|
2800
|
+
def multiframe_deconv(
|
|
2801
|
+
paths,
|
|
2802
|
+
out_path,
|
|
2803
|
+
iters=20,
|
|
2804
|
+
kappa=2.0,
|
|
2805
|
+
color_mode="luma",
|
|
2806
|
+
seed_mode: str = "robust",
|
|
2807
|
+
huber_delta=0.0,
|
|
2808
|
+
masks=None,
|
|
2809
|
+
variances=None,
|
|
2810
|
+
rho="huber",
|
|
2811
|
+
status_cb=lambda s: None,
|
|
2812
|
+
min_iters: int = 3,
|
|
2813
|
+
use_star_masks: bool = False,
|
|
2814
|
+
use_variance_maps: bool = False,
|
|
2815
|
+
star_mask_cfg: dict | None = None,
|
|
2816
|
+
varmap_cfg: dict | None = None,
|
|
2817
|
+
save_intermediate: bool = False,
|
|
2818
|
+
save_every: int = 1,
|
|
2819
|
+
# SR options
|
|
2820
|
+
super_res_factor: int = 1,
|
|
2821
|
+
sr_sigma: float = 1.1,
|
|
2822
|
+
sr_psf_opt_iters: int = 250,
|
|
2823
|
+
sr_psf_opt_lr: float = 0.1,
|
|
2824
|
+
# NEW
|
|
2825
|
+
batch_frames: int | None = None,
|
|
2826
|
+
# GPU tuning (optional knobs)
|
|
2827
|
+
mixed_precision: bool | None = None, # default: auto (True on CUDA/MPS)
|
|
2828
|
+
fft_kernel_threshold: int = 1024, # switch to FFT if K >= this (or lower if SR)
|
|
2829
|
+
prefetch_batches: bool = True, # CPU→GPU double-buffer prefetch
|
|
2830
|
+
use_channels_last: bool | None = None, # default: auto (True on CUDA/MPS)
|
|
2831
|
+
force_cpu: bool = False,
|
|
2832
|
+
star_mask_ref_path: str | None = None,
|
|
2833
|
+
low_mem: bool = False,
|
|
2834
|
+
):
|
|
2835
|
+
"""
|
|
2836
|
+
Streaming multi-frame deconvolution with optional SR (r>1).
|
|
2837
|
+
Optimized GPU path: AMP for convs, channels-last, pinned-memory prefetch, optional FFT for large kernels.
|
|
2838
|
+
"""
|
|
2839
|
+
mixed_precision = False
|
|
2840
|
+
DEBUG_FLAT_WEIGHTS = False
|
|
2841
|
+
# ---------- local helpers (kept self-contained) ----------
|
|
2842
|
+
def _emit_pct(pct: float, msg: str | None = None):
|
|
2843
|
+
pct = float(max(0.0, min(1.0, pct)))
|
|
2844
|
+
status_cb(f"__PROGRESS__ {pct:.4f}" + (f" {msg}" if msg else ""))
|
|
2845
|
+
|
|
2846
|
+
def _pad_kernel_to(k: np.ndarray, K: int) -> np.ndarray:
|
|
2847
|
+
"""Pad/center an odd-sized kernel to K×K (K odd)."""
|
|
2848
|
+
k = np.asarray(k, dtype=np.float32)
|
|
2849
|
+
kh, kw = int(k.shape[0]), int(k.shape[1])
|
|
2850
|
+
assert (kh % 2 == 1) and (kw % 2 == 1)
|
|
2851
|
+
if kh == K and kw == K:
|
|
2852
|
+
return k
|
|
2853
|
+
out = np.zeros((K, K), dtype=np.float32)
|
|
2854
|
+
y0 = (K - kh)//2; x0 = (K - kw)//2
|
|
2855
|
+
out[y0:y0+kh, x0:x0+kw] = k
|
|
2856
|
+
s = float(out.sum())
|
|
2857
|
+
return out if s <= 0 else (out / s).astype(np.float32, copy=False)
|
|
2858
|
+
|
|
2859
|
+
max_iters = max(1, int(iters))
|
|
2860
|
+
min_iters = max(1, int(min_iters))
|
|
2861
|
+
if min_iters > max_iters:
|
|
2862
|
+
min_iters = max_iters
|
|
2863
|
+
|
|
2864
|
+
n_frames = len(paths)
|
|
2865
|
+
status_cb(f"MFDeconv: scanning {n_frames} aligned frames (memmap)…")
|
|
2866
|
+
_emit_pct(0.02, "scanning")
|
|
2867
|
+
|
|
2868
|
+
# choose common intersection size without loading full pixels
|
|
2869
|
+
Ht, Wt = _common_hw_from_paths(paths)
|
|
2870
|
+
_emit_pct(0.05, "preparing")
|
|
2871
|
+
|
|
2872
|
+
# --- LOW-MEM PATCH (begin) ---
|
|
2873
|
+
if low_mem:
|
|
2874
|
+
# Cap decoded-frame LRU to keep peak RAM sane on 16 GB laptops
|
|
2875
|
+
try:
|
|
2876
|
+
_FRAME_LRU.cap = max(1, min(getattr(_FRAME_LRU, "cap", 8), 2))
|
|
2877
|
+
except Exception:
|
|
2878
|
+
pass
|
|
2879
|
+
|
|
2880
|
+
# Disable CPU→GPU prefetch to avoid double-buffering allocations
|
|
2881
|
+
prefetch_batches = False
|
|
2882
|
+
|
|
2883
|
+
# Relax SEP background grid & star detection canvas when requested
|
|
2884
|
+
if use_variance_maps:
|
|
2885
|
+
varmap_cfg = {**(varmap_cfg or {})}
|
|
2886
|
+
# fewer, larger tiles → fewer big temporaries
|
|
2887
|
+
varmap_cfg.setdefault("bw", 96)
|
|
2888
|
+
varmap_cfg.setdefault("bh", 96)
|
|
2889
|
+
|
|
2890
|
+
if use_star_masks:
|
|
2891
|
+
star_mask_cfg = {**(star_mask_cfg or {})}
|
|
2892
|
+
# shrink detection canvas to limit temp buffers inside SEP/mask draw
|
|
2893
|
+
star_mask_cfg["max_side"] = int(min(1024, int(star_mask_cfg.get("max_side", 2048))))
|
|
2894
|
+
# --- LOW-MEM PATCH (end) ---
|
|
2895
|
+
|
|
2896
|
+
|
|
2897
|
+
if any(os.path.splitext(p)[1].lower() == ".xisf" for p in paths):
|
|
2898
|
+
status_cb("MFDeconv: priming XISF cache (one-time decode per frame)…")
|
|
2899
|
+
for i, p in enumerate(paths, 1):
|
|
2900
|
+
try:
|
|
2901
|
+
_ = _xisf_cached_array(p) # decode once, store memmap
|
|
2902
|
+
except Exception as e:
|
|
2903
|
+
status_cb(f"XISF cache failed for {p}: {e}")
|
|
2904
|
+
if (i & 7) == 0 or i == len(paths):
|
|
2905
|
+
_process_gui_events_safely()
|
|
2906
|
+
|
|
2907
|
+
# per-frame loader & sequence view (closures capture Ht/Wt/color_mode/paths)
|
|
2908
|
+
def _load_frame_chw(i: int):
|
|
2909
|
+
return _FRAME_LRU.get(paths[i], Ht, Wt, color_mode)
|
|
2910
|
+
|
|
2911
|
+
class _FrameSeq:
|
|
2912
|
+
def __len__(self): return len(paths)
|
|
2913
|
+
def __getitem__(self, i): return _load_frame_chw(i)
|
|
2914
|
+
data = _FrameSeq()
|
|
2915
|
+
|
|
2916
|
+
# ---- torch detection (optional) ----
|
|
2917
|
+
global torch, TORCH_OK
|
|
2918
|
+
torch = None
|
|
2919
|
+
TORCH_OK = False
|
|
2920
|
+
cuda_ok = mps_ok = dml_ok = False
|
|
2921
|
+
dml_device = None
|
|
2922
|
+
|
|
2923
|
+
try:
|
|
2924
|
+
from setiastro.saspro.runtime_torch import import_torch
|
|
2925
|
+
torch = import_torch(prefer_cuda=True, status_cb=status_cb)
|
|
2926
|
+
TORCH_OK = True
|
|
2927
|
+
try: cuda_ok = hasattr(torch, "cuda") and torch.cuda.is_available()
|
|
2928
|
+
except Exception as e:
|
|
2929
|
+
import logging
|
|
2930
|
+
logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
|
|
2931
|
+
try: mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
|
2932
|
+
except Exception as e:
|
|
2933
|
+
import logging
|
|
2934
|
+
logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
|
|
2935
|
+
try:
|
|
2936
|
+
import torch_directml
|
|
2937
|
+
dml_device = torch_directml.device()
|
|
2938
|
+
_ = (torch.ones(1, device=dml_device) + 1).item()
|
|
2939
|
+
dml_ok = True
|
|
2940
|
+
# NEW: expose to _torch_device()
|
|
2941
|
+
globals()["dml_ok"] = True
|
|
2942
|
+
globals()["dml_device"] = dml_device
|
|
2943
|
+
except Exception:
|
|
2944
|
+
dml_ok = False
|
|
2945
|
+
globals()["dml_ok"] = False
|
|
2946
|
+
globals()["dml_device"] = None
|
|
2947
|
+
|
|
2948
|
+
if cuda_ok:
|
|
2949
|
+
status_cb(f"PyTorch CUDA available: True | device={torch.cuda.get_device_name(0)}")
|
|
2950
|
+
elif mps_ok:
|
|
2951
|
+
status_cb("PyTorch MPS (Apple) available: True")
|
|
2952
|
+
elif dml_ok:
|
|
2953
|
+
status_cb("PyTorch DirectML (Windows) available: True")
|
|
2954
|
+
else:
|
|
2955
|
+
status_cb("PyTorch present, using CPU backend.")
|
|
2956
|
+
status_cb(
|
|
2957
|
+
f"PyTorch {getattr(torch,'__version__','?')} backend: "
|
|
2958
|
+
+ ("CUDA" if cuda_ok else "MPS" if mps_ok else "DirectML" if dml_ok else "CPU")
|
|
2959
|
+
)
|
|
2960
|
+
|
|
2961
|
+
try:
|
|
2962
|
+
# keep cuDNN autotune on
|
|
2963
|
+
if getattr(torch.backends, "cudnn", None) is not None:
|
|
2964
|
+
torch.backends.cudnn.benchmark = True
|
|
2965
|
+
torch.backends.cudnn.allow_tf32 = False
|
|
2966
|
+
except Exception:
|
|
2967
|
+
pass
|
|
2968
|
+
|
|
2969
|
+
try:
|
|
2970
|
+
# disable TF32 matmul shortcuts if present (CUDA-only; safe no-op elsewhere)
|
|
2971
|
+
if getattr(getattr(torch.backends, "cuda", None), "matmul", None) is not None:
|
|
2972
|
+
torch.backends.cuda.matmul.allow_tf32 = False
|
|
2973
|
+
except Exception:
|
|
2974
|
+
pass
|
|
2975
|
+
|
|
2976
|
+
try:
|
|
2977
|
+
# prefer highest FP32 precision on PT 2.x
|
|
2978
|
+
if hasattr(torch, "set_float32_matmul_precision"):
|
|
2979
|
+
torch.set_float32_matmul_precision("highest")
|
|
2980
|
+
except Exception:
|
|
2981
|
+
pass
|
|
2982
|
+
except Exception as e:
|
|
2983
|
+
TORCH_OK = False
|
|
2984
|
+
status_cb(f"PyTorch not available → CPU path. ({e})")
|
|
2985
|
+
|
|
2986
|
+
if force_cpu:
|
|
2987
|
+
status_cb("⚠️ CPU-only debug mode: disabling PyTorch path.")
|
|
2988
|
+
TORCH_OK = False
|
|
2989
|
+
|
|
2990
|
+
_process_gui_events_safely()
|
|
2991
|
+
|
|
2992
|
+
# ---- PSFs + optional assets (computed in parallel, streaming I/O) ----
|
|
2993
|
+
psfs, masks_auto, vars_auto = _build_psf_and_assets(
|
|
2994
|
+
paths,
|
|
2995
|
+
make_masks=bool(use_star_masks),
|
|
2996
|
+
make_varmaps=bool(use_variance_maps),
|
|
2997
|
+
status_cb=status_cb,
|
|
2998
|
+
save_dir=None,
|
|
2999
|
+
star_mask_cfg=star_mask_cfg,
|
|
3000
|
+
varmap_cfg=varmap_cfg,
|
|
3001
|
+
star_mask_ref_path=star_mask_ref_path,
|
|
3002
|
+
# NEW:
|
|
3003
|
+
Ht=Ht, Wt=Wt, color_mode=color_mode,
|
|
3004
|
+
)
|
|
3005
|
+
|
|
3006
|
+
try:
|
|
3007
|
+
import gc as _gc
|
|
3008
|
+
_gc.collect()
|
|
3009
|
+
except Exception:
|
|
3010
|
+
pass
|
|
3011
|
+
|
|
3012
|
+
psfs_native = psfs # keep a reference to the original native PSFs for fallback
|
|
3013
|
+
|
|
3014
|
+
# ---- SR lift of PSFs if needed ----
|
|
3015
|
+
r_req = _coerce_sr_factor(super_res_factor, default_on_bad=2)
|
|
3016
|
+
status_cb(f"MFDeconv: SR factor requested={super_res_factor!r} → using r={r_req}")
|
|
3017
|
+
r = int(r_req)
|
|
3018
|
+
|
|
3019
|
+
if r > 1:
|
|
3020
|
+
status_cb(f"MFDeconv: Super-resolution r={r} with σ={sr_sigma} — solving SR PSFs…")
|
|
3021
|
+
_process_gui_events_safely()
|
|
3022
|
+
|
|
3023
|
+
def _naive_sr_from_native(f_nat: np.ndarray, r_: int) -> np.ndarray:
|
|
3024
|
+
f2 = np.asarray(f_nat, np.float32)
|
|
3025
|
+
H, W = f2.shape[:2]
|
|
3026
|
+
k_sq = min(H, W)
|
|
3027
|
+
if H != W:
|
|
3028
|
+
y0 = (H - k_sq) // 2; x0 = (W - k_sq) // 2
|
|
3029
|
+
f2 = f2[y0:y0+k_sq, x0:x0+k_sq]
|
|
3030
|
+
if (f2.shape[0] % 2) == 0:
|
|
3031
|
+
f2 = f2[1:, 1:] # make odd
|
|
3032
|
+
f2 = _normalize_psf(f2)
|
|
3033
|
+
h0 = np.zeros((f2.shape[0]*r_, f2.shape[1]*r_), np.float32)
|
|
3034
|
+
h0[::r_, ::r_] = f2
|
|
3035
|
+
return _normalize_psf(h0)
|
|
3036
|
+
|
|
3037
|
+
sr_psfs = []
|
|
3038
|
+
for i, k_native in enumerate(psfs, start=1):
|
|
3039
|
+
try:
|
|
3040
|
+
status_cb(f" SR-PSF{i}: native shape={np.asarray(k_native).shape}")
|
|
3041
|
+
h = _solve_super_psf_from_native(
|
|
3042
|
+
k_native, r=r, sigma=float(sr_sigma),
|
|
3043
|
+
iters=int(sr_psf_opt_iters), lr=float(sr_psf_opt_lr)
|
|
3044
|
+
)
|
|
3045
|
+
except Exception as e:
|
|
3046
|
+
status_cb(f" SR-PSF{i} failed: {e!r} → using naïve upsample")
|
|
3047
|
+
h = _naive_sr_from_native(k_native, r)
|
|
3048
|
+
|
|
3049
|
+
# guarantee odd size for downstream SAME-padding math
|
|
3050
|
+
if (h.shape[0] % 2) == 0:
|
|
3051
|
+
h = h[:-1, :-1]
|
|
3052
|
+
sr_psfs.append(h.astype(np.float32, copy=False))
|
|
3053
|
+
status_cb(f" SR-PSF{i}: native {np.asarray(k_native).shape[0]} → {h.shape[0]} (sum={h.sum():.6f})")
|
|
3054
|
+
psfs = sr_psfs
|
|
3055
|
+
|
|
3056
|
+
|
|
3057
|
+
|
|
3058
|
+
# ---- Seed (streaming) with robust bootstrap already in file helpers ----
|
|
3059
|
+
_emit_pct(0.25, "Calculating Seed Image...")
|
|
3060
|
+
def _seed_progress(frac, msg):
|
|
3061
|
+
_emit_pct(0.25 + 0.15 * float(frac), f"seed: {msg}")
|
|
3062
|
+
|
|
3063
|
+
seed_mode_s = str(seed_mode).lower().strip()
|
|
3064
|
+
if seed_mode_s not in ("robust", "median"):
|
|
3065
|
+
seed_mode_s = "robust"
|
|
3066
|
+
if seed_mode_s == "median":
|
|
3067
|
+
status_cb("MFDeconv: Building median seed (tiled, streaming)…")
|
|
3068
|
+
seed_native = _seed_median_streaming(
|
|
3069
|
+
paths, Ht, Wt,
|
|
3070
|
+
color_mode=color_mode,
|
|
3071
|
+
tile_hw=(256, 256),
|
|
3072
|
+
status_cb=status_cb,
|
|
3073
|
+
progress_cb=_seed_progress,
|
|
3074
|
+
use_torch=TORCH_OK, # ← auto: GPU if available, else NumPy
|
|
3075
|
+
)
|
|
3076
|
+
else:
|
|
3077
|
+
seed_native = _seed_bootstrap_streaming(
|
|
3078
|
+
paths, Ht, Wt, color_mode,
|
|
3079
|
+
bootstrap_frames=20, clip_sigma=5,
|
|
3080
|
+
status_cb=status_cb, progress_cb=_seed_progress
|
|
3081
|
+
)
|
|
3082
|
+
|
|
3083
|
+
# lift seed if SR
|
|
3084
|
+
if r > 1:
|
|
3085
|
+
target_hw = (Ht * r, Wt * r)
|
|
3086
|
+
if seed_native.ndim == 2:
|
|
3087
|
+
x = _upsample_sum(seed_native / (r*r), r, target_hw=target_hw)
|
|
3088
|
+
else:
|
|
3089
|
+
C, Hn, Wn = seed_native.shape
|
|
3090
|
+
x = np.stack(
|
|
3091
|
+
[_upsample_sum(seed_native[c] / (r*r), r, target_hw=target_hw) for c in range(C)],
|
|
3092
|
+
axis=0
|
|
3093
|
+
)
|
|
3094
|
+
else:
|
|
3095
|
+
x = seed_native
|
|
3096
|
+
|
|
3097
|
+
# FINAL SHAPE CHECKS (auto-correct if a GUI sent something odd)
|
|
3098
|
+
if x.ndim == 2: x = x[None, ...]
|
|
3099
|
+
Hs, Ws = x.shape[-2], x.shape[-1]
|
|
3100
|
+
if r > 1:
|
|
3101
|
+
expected_H, expected_W = Ht * r, Wt * r
|
|
3102
|
+
if (Hs, Ws) != (expected_H, expected_W):
|
|
3103
|
+
status_cb(f"SR seed grid mismatch: got {(Hs, Ws)}, expected {(expected_H, expected_W)} → correcting")
|
|
3104
|
+
# Rebuild from the native mean to ensure exact SR size
|
|
3105
|
+
x = _upsample_sum(x if x.ndim==2 else x[0], r, target_hw=(expected_H, expected_W))
|
|
3106
|
+
if x.ndim == 2: x = x[None, ...]
|
|
3107
|
+
try: del seed_native
|
|
3108
|
+
except Exception as e:
|
|
3109
|
+
import logging
|
|
3110
|
+
logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
|
|
3111
|
+
|
|
3112
|
+
try:
|
|
3113
|
+
import gc as _gc
|
|
3114
|
+
_gc.collect()
|
|
3115
|
+
except Exception:
|
|
3116
|
+
pass
|
|
3117
|
+
|
|
3118
|
+
flip_psf = [_flip_kernel(k) for k in psfs]
|
|
3119
|
+
_emit_pct(0.20, "PSF Ready")
|
|
3120
|
+
|
|
3121
|
+
# --- Harmonize seed channels with the actual frames ---
|
|
3122
|
+
# Probe the first frame's channel count (CHW from the cache)
|
|
3123
|
+
try:
|
|
3124
|
+
y0_probe = _load_frame_chw(0) # CHW float32
|
|
3125
|
+
C_ref = 1 if y0_probe.ndim == 2 else int(y0_probe.shape[0])
|
|
3126
|
+
except Exception:
|
|
3127
|
+
# Fallback: infer from seed if probing fails
|
|
3128
|
+
C_ref = 1 if x.ndim == 2 else int(x.shape[0])
|
|
3129
|
+
|
|
3130
|
+
# If median seed came back mono but frames are RGB, broadcast it
|
|
3131
|
+
if x.ndim == 2 and C_ref == 3:
|
|
3132
|
+
x = np.stack([x] * 3, axis=0)
|
|
3133
|
+
elif x.ndim == 3 and x.shape[0] == 1 and C_ref == 3:
|
|
3134
|
+
x = np.repeat(x, 3, axis=0)
|
|
3135
|
+
# If seed is RGB but frames are mono (rare), collapse to luma
|
|
3136
|
+
elif x.ndim == 3 and x.shape[0] == 3 and C_ref == 1:
|
|
3137
|
+
# ITU-R BT.709 luma (safe in float)
|
|
3138
|
+
x = (0.2126 * x[0] + 0.7152 * x[1] + 0.0722 * x[2]).astype(np.float32)
|
|
3139
|
+
|
|
3140
|
+
# Ensure CHW shape for the rest of the pipeline
|
|
3141
|
+
if x.ndim == 2:
|
|
3142
|
+
x = x[None, ...]
|
|
3143
|
+
elif x.ndim == 3 and x.shape[0] not in (1, 3) and x.shape[-1] in (1, 3):
|
|
3144
|
+
x = np.moveaxis(x, -1, 0)
|
|
3145
|
+
|
|
3146
|
+
# Set the expected channel count from the FRAMES, not from the seed
|
|
3147
|
+
C_EXPECTED = int(C_ref)
|
|
3148
|
+
_, Hs, Ws = x.shape
|
|
3149
|
+
|
|
3150
|
+
|
|
3151
|
+
|
|
3152
|
+
# ---- choose default batch size ----
|
|
3153
|
+
if batch_frames is None:
|
|
3154
|
+
px = Hs * Ws
|
|
3155
|
+
if px >= 16_000_000: auto_B = 2
|
|
3156
|
+
elif px >= 8_000_000: auto_B = 4
|
|
3157
|
+
else: auto_B = 8
|
|
3158
|
+
else:
|
|
3159
|
+
auto_B = int(max(1, batch_frames))
|
|
3160
|
+
|
|
3161
|
+
# --- LOW-MEM PATCH: clamp batch size hard ---
|
|
3162
|
+
if low_mem:
|
|
3163
|
+
auto_B = max(1, min(auto_B, 2))
|
|
3164
|
+
|
|
3165
|
+
# ---- background/MAD telemetry (first frame) ----
|
|
3166
|
+
status_cb("MFDeconv: Calculating Backgrounds and MADs…")
|
|
3167
|
+
_process_gui_events_safely()
|
|
3168
|
+
try:
|
|
3169
|
+
y0 = data[0]; y0l = y0 if y0.ndim == 2 else y0[0]
|
|
3170
|
+
med = float(np.median(y0l)); mad = float(np.median(np.abs(y0l - med))) + 1e-6
|
|
3171
|
+
bg_est = 1.4826 * mad
|
|
3172
|
+
except Exception:
|
|
3173
|
+
bg_est = 0.0
|
|
3174
|
+
status_cb(f"MFDeconv: color_mode={color_mode}, huber_delta={huber_delta} (bg RMS~{bg_est:.3g})")
|
|
3175
|
+
|
|
3176
|
+
# ---- mask/variance accessors ----
|
|
3177
|
+
def _mask_for(i, like_img):
|
|
3178
|
+
src = (masks if masks is not None else masks_auto)
|
|
3179
|
+
if src is None: # no masks at all
|
|
3180
|
+
return np.ones((like_img.shape[-2], like_img.shape[-1]), dtype=np.float32)
|
|
3181
|
+
m = src[i]
|
|
3182
|
+
if m is None:
|
|
3183
|
+
return np.ones((like_img.shape[-2], like_img.shape[-1]), dtype=np.float32)
|
|
3184
|
+
m = np.asarray(m, dtype=np.float32)
|
|
3185
|
+
if m.ndim == 3: m = m[0]
|
|
3186
|
+
return _center_crop(m, like_img.shape[-2], like_img.shape[-1]).astype(np.float32, copy=False)
|
|
3187
|
+
|
|
3188
|
+
def _var_for(i, like_img):
|
|
3189
|
+
src = (variances if variances is not None else vars_auto)
|
|
3190
|
+
if src is None: return None
|
|
3191
|
+
v = src[i]
|
|
3192
|
+
if v is None: return None
|
|
3193
|
+
v = np.asarray(v, dtype=np.float32)
|
|
3194
|
+
if v.ndim == 3: v = v[0]
|
|
3195
|
+
v = _center_crop(v, like_img.shape[-2], like_img.shape[-1])
|
|
3196
|
+
return np.clip(v, 1e-8, None).astype(np.float32, copy=False)
|
|
3197
|
+
|
|
3198
|
+
# ---- NumPy path conv helper (keep as-is) ----
|
|
3199
|
+
def _conv_np_same(a, k, out=None):
|
|
3200
|
+
y = _conv_same_np_spatial(a, k, out)
|
|
3201
|
+
if y is not None:
|
|
3202
|
+
return y
|
|
3203
|
+
# No OpenCV → always use the ifftshifted FFT path
|
|
3204
|
+
import numpy as _np
|
|
3205
|
+
import numpy.fft as _fft
|
|
3206
|
+
H, W = a.shape[-2:]
|
|
3207
|
+
kh, kw = k.shape
|
|
3208
|
+
fftH, fftW = _fftshape_same(H, W, kh, kw)
|
|
3209
|
+
Kf = _fft.rfftn(_np.fft.ifftshift(k), s=(fftH, fftW))
|
|
3210
|
+
if out is None:
|
|
3211
|
+
out = _np.empty_like(a, dtype=_np.float32)
|
|
3212
|
+
return _fft_conv_same_np(a, Kf, kh, kw, fftH, fftW, out)
|
|
3213
|
+
|
|
3214
|
+
# ---- allocate scratch & prepare PSF tensors if torch ----
|
|
3215
|
+
relax = 0.7
|
|
3216
|
+
use_torch = bool(TORCH_OK)
|
|
3217
|
+
cm = _safe_inference_context() if use_torch else NO_GRAD
|
|
3218
|
+
rho_is_l2 = (str(rho).lower() == "l2")
|
|
3219
|
+
local_delta = 0.0 if rho_is_l2 else huber_delta
|
|
3220
|
+
|
|
3221
|
+
if use_torch:
|
|
3222
|
+
F = torch.nn.functional
|
|
3223
|
+
device = _torch_device()
|
|
3224
|
+
|
|
3225
|
+
# Force FP32 tensors
|
|
3226
|
+
x_t = _to_t(_contig(x)).to(torch.float32)
|
|
3227
|
+
num = torch.zeros_like(x_t, dtype=torch.float32)
|
|
3228
|
+
den = torch.zeros_like(x_t, dtype=torch.float32)
|
|
3229
|
+
|
|
3230
|
+
# channels-last preference kept (does not change dtype)
|
|
3231
|
+
if use_channels_last is None:
|
|
3232
|
+
use_channels_last = bool(cuda_ok) # <- never enable on MPS
|
|
3233
|
+
if mps_ok:
|
|
3234
|
+
use_channels_last = False # <- force NCHW on MPS
|
|
3235
|
+
|
|
3236
|
+
# PSF tensors strictly FP32
|
|
3237
|
+
psf_t = [_to_t(_contig(k))[None, None] for k in psfs] # (1,1,kh,kw)
|
|
3238
|
+
psfT_t = [_to_t(_contig(kT))[None, None] for kT in flip_psf]
|
|
3239
|
+
use_spatial = _torch_should_use_spatial(psf_t[0].shape[-1])
|
|
3240
|
+
|
|
3241
|
+
# No mixed precision, no autocast
|
|
3242
|
+
use_amp = False
|
|
3243
|
+
amp_cm = None
|
|
3244
|
+
amp_kwargs = {}
|
|
3245
|
+
|
|
3246
|
+
# FFT gate (worth it for large kernels / SR); still FP32
|
|
3247
|
+
# Decide FFT vs spatial FIRST, based on native kernel sizes (no global padding!)
|
|
3248
|
+
K_native_max = max(int(k.shape[0]) for k in psfs)
|
|
3249
|
+
use_fft = False
|
|
3250
|
+
if device.type == "cuda" and ((K_native_max >= int(fft_kernel_threshold)) or (r > 1 and K_native_max >= max(21, int(fft_kernel_threshold) - 4))):
|
|
3251
|
+
use_fft = True
|
|
3252
|
+
# Force NCHW for FFT branch to keep channel slices contiguous/sane
|
|
3253
|
+
use_channels_last = False
|
|
3254
|
+
# Precompute FFT packs…
|
|
3255
|
+
psf_fft, psfT_fft = _precompute_torch_psf_ffts(
|
|
3256
|
+
psfs, flip_psf, Hs, Ws, device=x_t.device, dtype=torch.float32
|
|
3257
|
+
)
|
|
3258
|
+
else:
|
|
3259
|
+
psf_fft = psfT_fft = None
|
|
3260
|
+
Kmax = K_native_max
|
|
3261
|
+
if (Kmax % 2) == 0:
|
|
3262
|
+
Kmax += 1
|
|
3263
|
+
if any(int(k.shape[0]) != Kmax for k in psfs):
|
|
3264
|
+
status_cb(f"MFDeconv: normalizing PSF sizes → {Kmax}×{Kmax}")
|
|
3265
|
+
psfs = [_pad_kernel_to(k, Kmax) for k in psfs]
|
|
3266
|
+
flip_psf = [_flip_kernel(k) for k in psfs]
|
|
3267
|
+
|
|
3268
|
+
# (Re)build spatial kernels strictly as contiguous FP32 tensors
|
|
3269
|
+
psf_t = [_to_t(_contig(k))[None, None].to(torch.float32).contiguous() for k in psfs]
|
|
3270
|
+
psfT_t = [_to_t(_contig(kT))[None, None].to(torch.float32).contiguous() for kT in flip_psf]
|
|
3271
|
+
else:
|
|
3272
|
+
x_t = _contig(x).astype(np.float32, copy=False)
|
|
3273
|
+
num = np.zeros_like(x_t, dtype=np.float32)
|
|
3274
|
+
den = np.zeros_like(x_t, dtype=np.float32)
|
|
3275
|
+
use_amp = False
|
|
3276
|
+
use_fft = False
|
|
3277
|
+
|
|
3278
|
+
# ---- batched torch helper (grouped depthwise per-sample) ----
|
|
3279
|
+
if use_torch:
|
|
3280
|
+
# inside `if use_torch:` block in multiframe_deconv — replace the whole inner helper
|
|
3281
|
+
def _grouped_conv_same_torch_per_sample(x_bc_hw, w_b1kk, B, C):
|
|
3282
|
+
"""
|
|
3283
|
+
x_bc_hw : (B,C,H,W), torch.float32 on device
|
|
3284
|
+
w_b1kk : (B,1,kh,kw), torch.float32 on device
|
|
3285
|
+
Returns (B,C,H,W) contiguous (NCHW).
|
|
3286
|
+
"""
|
|
3287
|
+
F = torch.nn.functional
|
|
3288
|
+
|
|
3289
|
+
# Force standard NCHW contiguous tensors
|
|
3290
|
+
x_bc_hw = x_bc_hw.to(memory_format=torch.contiguous_format).contiguous()
|
|
3291
|
+
w_b1kk = w_b1kk.to(memory_format=torch.contiguous_format).contiguous()
|
|
3292
|
+
|
|
3293
|
+
kh, kw = int(w_b1kk.shape[-2]), int(w_b1kk.shape[-1])
|
|
3294
|
+
pad = (kw // 2, kw - kw // 2 - 1, kh // 2, kh - kh // 2 - 1)
|
|
3295
|
+
|
|
3296
|
+
if x_bc_hw.device.type == "mps":
|
|
3297
|
+
# Safe, slower path: convolve each channel separately, no groups
|
|
3298
|
+
ys = []
|
|
3299
|
+
for j in range(B): # per sample
|
|
3300
|
+
xj = x_bc_hw[j:j+1] # (1,C,H,W)
|
|
3301
|
+
# reflect pad once per sample
|
|
3302
|
+
xj = F.pad(xj, pad, mode="reflect")
|
|
3303
|
+
cj_out = []
|
|
3304
|
+
# one shared kernel per sample j: (1,1,kh,kw)
|
|
3305
|
+
kj = w_b1kk[j:j+1] # keep shape (1,1,kh,kw)
|
|
3306
|
+
for c in range(C):
|
|
3307
|
+
# slice that channel as its own (1,1,H,W) tensor
|
|
3308
|
+
xjc = xj[:, c:c+1, ...]
|
|
3309
|
+
yjc = F.conv2d(xjc, kj, padding=0, groups=1) # no groups
|
|
3310
|
+
cj_out.append(yjc)
|
|
3311
|
+
ys.append(torch.cat(cj_out, dim=1)) # (1,C,H,W)
|
|
3312
|
+
return torch.stack([y[0] for y in ys], 0).contiguous()
|
|
3313
|
+
|
|
3314
|
+
|
|
3315
|
+
# ---- FAST PATH (CUDA/CPU): single grouped conv with G=B*C ----
|
|
3316
|
+
G = int(B * C)
|
|
3317
|
+
x_1ghw = x_bc_hw.reshape(1, G, x_bc_hw.shape[-2], x_bc_hw.shape[-1])
|
|
3318
|
+
x_1ghw = F.pad(x_1ghw, pad, mode="reflect")
|
|
3319
|
+
w_g1kk = w_b1kk.repeat_interleave(C, dim=0) # (G,1,kh,kw)
|
|
3320
|
+
y_1ghw = F.conv2d(x_1ghw, w_g1kk, padding=0, groups=G)
|
|
3321
|
+
return y_1ghw.reshape(B, C, y_1ghw.shape[-2], y_1ghw.shape[-1]).contiguous()
|
|
3322
|
+
|
|
3323
|
+
|
|
3324
|
+
def _downsample_avg_bt_t(x, r_):
|
|
3325
|
+
if r_ <= 1:
|
|
3326
|
+
return x
|
|
3327
|
+
B, C, H, W = x.shape
|
|
3328
|
+
Hr, Wr = (H // r_) * r_, (W // r_) * r_
|
|
3329
|
+
if Hr == 0 or Wr == 0:
|
|
3330
|
+
return x
|
|
3331
|
+
return x[:, :, :Hr, :Wr].reshape(B, C, Hr // r_, r_, Wr // r_, r_).mean(dim=(3, 5))
|
|
3332
|
+
|
|
3333
|
+
def _upsample_sum_bt_t(x, r_):
|
|
3334
|
+
if r_ <= 1:
|
|
3335
|
+
return x
|
|
3336
|
+
return x.repeat_interleave(r_, dim=-2).repeat_interleave(r_, dim=-1)
|
|
3337
|
+
|
|
3338
|
+
def _make_pinned_batch(idx, C_expected, to_device_dtype):
    """Load a batch of frames plus their masks/variance maps onto ``x_t``'s device.

    Args:
        idx: iterable of frame indices to load.
        C_expected: required channel count (size of axis 0 of every frame),
            or ``None`` to skip the check.
        to_device_dtype: accepted for call compatibility but not read here;
            every tensor produced is float32.

    Returns:
        ``(yb, mb, vb)``: the stacked frames, the stacked masks, and the
        stacked variance maps. ``vb`` is ``None`` unless every frame in the
        batch supplied a variance map.

    Raises:
        RuntimeError: if a frame is not 3-D or its channel count differs from
            ``C_expected``.
    """
    y_list, m_list, v_list = [], [], []
    for fi in idx:
        # presumably a CHW float32 array served from a cache -- confirm in _load_frame_chw
        y_chw = _load_frame_chw(fi)
        if y_chw.ndim != 3:
            raise RuntimeError(f"Frame {fi}: expected CHW, got {tuple(y_chw.shape)}")
        C_here = int(y_chw.shape[0])
        if C_expected is not None and C_here != C_expected:
            raise RuntimeError(f"Mixed channel counts: expected C={C_expected}, got C={C_here} (frame {fi})")

        frame_mask = _mask_for(fi, y_chw)
        frame_var = _var_for(fi, y_chw)
        y_list.append(y_chw)
        m_list.append(frame_mask)
        v_list.append(frame_var)

    def _stack_f32(arrays):
        # Stack along a new leading (batch) axis and make a contiguous float32 CPU tensor.
        return torch.from_numpy(np.stack(arrays, 0)).to(torch.float32).contiguous()

    def _to_target(cpu_tensor):
        # Copy to the device that currently holds x_t, staying float32 and contiguous.
        return cpu_tensor.to(x_t.device, dtype=torch.float32, non_blocking=True).contiguous()

    y_cpu = _stack_f32(y_list)
    m_cpu = _stack_f32(m_list)
    # Variance maps are all-or-nothing for a batch: one missing map drops them all.
    vb_cpu = _stack_f32(v_list) if all(v is not None for v in v_list) else None

    # Page-lock the host buffers whenever CUDA is available so the
    # non_blocking copies below can run asynchronously.
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        y_cpu = y_cpu.pin_memory()
        m_cpu = m_cpu.pin_memory()
        if vb_cpu is not None:
            vb_cpu = vb_cpu.pin_memory()

    yb = _to_target(y_cpu)
    mb = _to_target(m_cpu)
    vb = _to_target(vb_cpu) if vb_cpu is not None else None

    return yb, mb, vb
|
|
3369
|
+
|
|
3370
|
+
|
|
3371
|
+
|
|
3372
|
+
# ---- intermediates folder ----
|
|
3373
|
+
iter_dir = None
|
|
3374
|
+
hdr0_seed = None
|
|
3375
|
+
if save_intermediate:
|
|
3376
|
+
iter_dir = _iter_folder(out_path)
|
|
3377
|
+
status_cb(f"MFDeconv: Intermediate outputs → {iter_dir}")
|
|
3378
|
+
try:
|
|
3379
|
+
hdr0 = _safe_primary_header(paths[0])
|
|
3380
|
+
except Exception:
|
|
3381
|
+
hdr0 = fits.Header()
|
|
3382
|
+
_save_iter_image(x, hdr0, iter_dir, "seed", color_mode)
|
|
3383
|
+
|
|
3384
|
+
# ---- iterative loop ----
|
|
3385
|
+
|
|
3386
|
+
|
|
3387
|
+
|
|
3388
|
+
auto_delta_cache = None
|
|
3389
|
+
if use_torch and (huber_delta < 0) and (not rho_is_l2):
|
|
3390
|
+
auto_delta_cache = [None] * n_frames
|
|
3391
|
+
early = EarlyStopper(
|
|
3392
|
+
tol_upd_floor=2e-4,
|
|
3393
|
+
tol_rel_floor=5e-4,
|
|
3394
|
+
early_frac=0.40,
|
|
3395
|
+
ema_alpha=0.5,
|
|
3396
|
+
patience=2,
|
|
3397
|
+
min_iters=min_iters,
|
|
3398
|
+
status_cb=status_cb
|
|
3399
|
+
)
|
|
3400
|
+
|
|
3401
|
+
used_iters = 0
|
|
3402
|
+
early_stopped = False
|
|
3403
|
+
with cm():
|
|
3404
|
+
for it in range(1, max_iters + 1):
|
|
3405
|
+
# reset accumulators
|
|
3406
|
+
if use_torch:
|
|
3407
|
+
num.zero_(); den.zero_()
|
|
3408
|
+
else:
|
|
3409
|
+
num.fill(0.0); den.fill(0.0)
|
|
3410
|
+
|
|
3411
|
+
if use_torch:
|
|
3412
|
+
# ---- batched GPU path (hardened with busy→CPU fallback) ----
|
|
3413
|
+
retry_cpu = False # if True, we’ll redo this iteration’s accumulators on CPU
|
|
3414
|
+
|
|
3415
|
+
try:
|
|
3416
|
+
frame_idx = list(range(n_frames))
|
|
3417
|
+
B_cur = int(max(1, auto_B))
|
|
3418
|
+
ci = 0
|
|
3419
|
+
Cn = None
|
|
3420
|
+
|
|
3421
|
+
# AMP context helper
|
|
3422
|
+
def _maybe_amp():
|
|
3423
|
+
if use_amp and (amp_cm is not None):
|
|
3424
|
+
return amp_cm(**amp_kwargs)
|
|
3425
|
+
from contextlib import nullcontext
|
|
3426
|
+
return nullcontext()
|
|
3427
|
+
|
|
3428
|
+
# prefetch first batch (unchanged)
|
|
3429
|
+
if prefetch_batches:
|
|
3430
|
+
idx0 = frame_idx[ci:ci+B_cur]
|
|
3431
|
+
if idx0:
|
|
3432
|
+
Cn = C_EXPECTED
|
|
3433
|
+
yb_next, mb_next, vb_next = _make_pinned_batch(idx0, Cn, x_t.dtype)
|
|
3434
|
+
else:
|
|
3435
|
+
yb_next = mb_next = vb_next = None
|
|
3436
|
+
|
|
3437
|
+
while ci < n_frames:
|
|
3438
|
+
idx = frame_idx[ci:ci + B_cur]
|
|
3439
|
+
B = len(idx)
|
|
3440
|
+
try:
|
|
3441
|
+
# ---- existing gather/conv/weights/backproject code (UNCHANGED) ----
|
|
3442
|
+
if prefetch_batches and (yb_next is not None):
|
|
3443
|
+
yb, mb, vb = yb_next, mb_next, vb_next
|
|
3444
|
+
ci2 = ci + B_cur
|
|
3445
|
+
if ci2 < n_frames:
|
|
3446
|
+
idx2 = frame_idx[ci2:ci2+B_cur]
|
|
3447
|
+
Cn = Cn or (_as_chw(_sanitize_numeric(data[idx[0]])).shape[0])
|
|
3448
|
+
yb_next, mb_next, vb_next = _make_pinned_batch(idx2, Cn, torch.float32)
|
|
3449
|
+
else:
|
|
3450
|
+
yb_next = mb_next = vb_next = None
|
|
3451
|
+
else:
|
|
3452
|
+
Cn = C_EXPECTED
|
|
3453
|
+
yb, mb, vb = _make_pinned_batch(idx, Cn, torch.float32)
|
|
3454
|
+
if use_channels_last:
|
|
3455
|
+
yb = yb.contiguous(memory_format=torch.channels_last)
|
|
3456
|
+
|
|
3457
|
+
wk = torch.cat([psf_t[fi] for fi in idx], dim=0).to(memory_format=torch.contiguous_format).contiguous()
|
|
3458
|
+
wkT = torch.cat([psfT_t[fi] for fi in idx], dim=0).to(memory_format=torch.contiguous_format).contiguous()
|
|
3459
|
+
|
|
3460
|
+
# --- predict on SR grid ---
|
|
3461
|
+
x_bc_hw = x_t.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
|
|
3462
|
+
if use_channels_last:
|
|
3463
|
+
x_bc_hw = x_bc_hw.contiguous(memory_format=torch.channels_last)
|
|
3464
|
+
|
|
3465
|
+
if use_fft:
|
|
3466
|
+
pred_list = []
|
|
3467
|
+
chan_tmp = torch.empty_like(x_t[0], dtype=torch.float32) # (H,W) single-channel
|
|
3468
|
+
for j, fi in enumerate(idx):
|
|
3469
|
+
C_here = x_bc_hw.shape[1]
|
|
3470
|
+
pred_j = torch.empty_like(x_t, dtype=torch.float32) # (C,H,W)
|
|
3471
|
+
Kf_pack = psf_fft[fi]
|
|
3472
|
+
for c in range(C_here):
|
|
3473
|
+
_fft_conv_same_torch(x_bc_hw[j, c], Kf_pack, out_spatial=chan_tmp) # write into chan_tmp
|
|
3474
|
+
pred_j[c].copy_(chan_tmp) # copy once
|
|
3475
|
+
pred_list.append(pred_j)
|
|
3476
|
+
pred_super = torch.stack(pred_list, 0).contiguous()
|
|
3477
|
+
else:
|
|
3478
|
+
pred_super = _grouped_conv_same_torch_per_sample(
|
|
3479
|
+
x_bc_hw, wk, B, Cn
|
|
3480
|
+
).contiguous()
|
|
3481
|
+
|
|
3482
|
+
pred_low = _downsample_avg_bt_t(pred_super, r) if r > 1 else pred_super
|
|
3483
|
+
|
|
3484
|
+
# --- robust weights (UNCHANGED) ---
|
|
3485
|
+
rnat = yb - pred_low
|
|
3486
|
+
|
|
3487
|
+
# Build/estimate variance map on the native grid (per batch)
|
|
3488
|
+
if vb is None:
|
|
3489
|
+
# robust per-batch variance estimate
|
|
3490
|
+
med, mad = _robust_med_mad_t(rnat, max_elems_per_sample=2_000_000)
|
|
3491
|
+
vmap = (1.4826 * mad) ** 2
|
|
3492
|
+
# repeat across channels
|
|
3493
|
+
vmap = vmap.repeat(1, Cn, rnat.shape[-2], rnat.shape[-1]).contiguous()
|
|
3494
|
+
else:
|
|
3495
|
+
vmap = vb.unsqueeze(1).repeat(1, Cn, 1, 1).contiguous()
|
|
3496
|
+
|
|
3497
|
+
if str(rho).lower() == "l2":
|
|
3498
|
+
# L2 ⇒ psi/r = 1 (no robust clipping). Weighted LS with mask/variance.
|
|
3499
|
+
psi_over_r = torch.ones_like(rnat, dtype=torch.float32, device=x_t.device)
|
|
3500
|
+
else:
|
|
3501
|
+
# Huber ⇒ auto-delta or fixed delta
|
|
3502
|
+
if huber_delta < 0:
|
|
3503
|
+
# ensure we have auto deltas
|
|
3504
|
+
if (auto_delta_cache is None) or any(auto_delta_cache[fi] is None for fi in idx) or (it % 5 == 1):
|
|
3505
|
+
Btmp, C_here, H0, W0 = rnat.shape
|
|
3506
|
+
med, mad = _robust_med_mad_t(rnat, max_elems_per_sample=2_000_000)
|
|
3507
|
+
rms = 1.4826 * mad
|
|
3508
|
+
# store per-frame deltas
|
|
3509
|
+
if auto_delta_cache is None:
|
|
3510
|
+
# should not happen because we gated creation earlier, but be safe
|
|
3511
|
+
auto_delta_cache = [None] * n_frames
|
|
3512
|
+
for j, fi in enumerate(idx):
|
|
3513
|
+
auto_delta_cache[fi] = float((-huber_delta) * torch.clamp(rms[j, 0, 0, 0], min=1e-6).item())
|
|
3514
|
+
deltas = torch.tensor([auto_delta_cache[fi] for fi in idx],
|
|
3515
|
+
device=x_t.device, dtype=torch.float32).view(B, 1, 1, 1)
|
|
3516
|
+
else:
|
|
3517
|
+
deltas = torch.tensor(float(huber_delta), device=x_t.device,
|
|
3518
|
+
dtype=torch.float32).view(1, 1, 1, 1)
|
|
3519
|
+
|
|
3520
|
+
absr = rnat.abs()
|
|
3521
|
+
psi_over_r = torch.where(absr <= deltas, torch.ones_like(absr, dtype=torch.float32),
|
|
3522
|
+
deltas / (absr + EPS))
|
|
3523
|
+
|
|
3524
|
+
# compose weights with mask and variance
|
|
3525
|
+
if DEBUG_FLAT_WEIGHTS:
|
|
3526
|
+
wmap_low = mb.unsqueeze(1) # debug: mask-only weighting
|
|
3527
|
+
else:
|
|
3528
|
+
m1 = mb.unsqueeze(1).repeat(1, Cn, 1, 1).contiguous()
|
|
3529
|
+
wmap_low = (psi_over_r / (vmap + EPS)) * m1
|
|
3530
|
+
wmap_low = torch.nan_to_num(wmap_low, nan=0.0, posinf=0.0, neginf=0.0)
|
|
3531
|
+
|
|
3532
|
+
# --- adjoint + backproject (UNCHANGED) ---
|
|
3533
|
+
if r > 1:
|
|
3534
|
+
up_y = _upsample_sum_bt_t(wmap_low * yb, r)
|
|
3535
|
+
up_pred = _upsample_sum_bt_t(wmap_low * pred_low, r)
|
|
3536
|
+
else:
|
|
3537
|
+
up_y, up_pred = wmap_low * yb, wmap_low * pred_low
|
|
3538
|
+
|
|
3539
|
+
if use_fft:
|
|
3540
|
+
back_num_list, back_den_list = [], []
|
|
3541
|
+
for j, fi in enumerate(idx):
|
|
3542
|
+
C_here = up_y.shape[1]
|
|
3543
|
+
bn_j = torch.empty_like(x_t, dtype=torch.float32)
|
|
3544
|
+
bd_j = torch.empty_like(x_t, dtype=torch.float32)
|
|
3545
|
+
KTf_pack = psfT_fft[fi]
|
|
3546
|
+
for c in range(C_here):
|
|
3547
|
+
_fft_conv_same_torch(up_y[j, c], KTf_pack, out_spatial=bn_j[c])
|
|
3548
|
+
_fft_conv_same_torch(up_pred[j, c], KTf_pack, out_spatial=bd_j[c])
|
|
3549
|
+
back_num_list.append(bn_j)
|
|
3550
|
+
back_den_list.append(bd_j)
|
|
3551
|
+
back_num = torch.stack(back_num_list, 0).sum(dim=0)
|
|
3552
|
+
back_den = torch.stack(back_den_list, 0).sum(dim=0)
|
|
3553
|
+
else:
|
|
3554
|
+
back_num = _grouped_conv_same_torch_per_sample(up_y, wkT, B, Cn).sum(dim=0)
|
|
3555
|
+
back_den = _grouped_conv_same_torch_per_sample(up_pred, wkT, B, Cn).sum(dim=0)
|
|
3556
|
+
|
|
3557
|
+
num += back_num
|
|
3558
|
+
den += back_den
|
|
3559
|
+
ci += B
|
|
3560
|
+
|
|
3561
|
+
if os.environ.get("MFDECONV_DEBUG_SYNC", "0") == "1":
|
|
3562
|
+
try:
|
|
3563
|
+
torch.cuda.synchronize()
|
|
3564
|
+
except Exception:
|
|
3565
|
+
pass
|
|
3566
|
+
|
|
3567
|
+
except RuntimeError as e:
|
|
3568
|
+
emsg = str(e).lower()
|
|
3569
|
+
# existing OOM backoff stays the same
|
|
3570
|
+
if ("out of memory" in emsg or "resource" in emsg or "alloc" in emsg) and B_cur > 1:
|
|
3571
|
+
B_cur = max(1, B_cur // 2)
|
|
3572
|
+
status_cb(f"GPU OOM: reducing batch_frames → {B_cur} and retrying this chunk.")
|
|
3573
|
+
if prefetch_batches:
|
|
3574
|
+
yb_next = mb_next = vb_next = None
|
|
3575
|
+
continue
|
|
3576
|
+
# NEW: busy/unavailable → trigger CPU retry for this iteration
|
|
3577
|
+
if _is_cuda_busy_error(e):
|
|
3578
|
+
status_cb("CUDA became unavailable mid-run → retrying this iteration on CPU and switching to CPU for the rest.")
|
|
3579
|
+
retry_cpu = True
|
|
3580
|
+
break # break the inner while; we'll recompute on CPU
|
|
3581
|
+
raise
|
|
3582
|
+
|
|
3583
|
+
_process_gui_events_safely()
|
|
3584
|
+
|
|
3585
|
+
except RuntimeError as e:
|
|
3586
|
+
# Catch outer-level CUDA busy/unavailable too (e.g., from syncs)
|
|
3587
|
+
if _is_cuda_busy_error(e):
|
|
3588
|
+
status_cb("CUDA error indicates device busy/unavailable → retrying this iteration on CPU and switching to CPU for the rest.")
|
|
3589
|
+
retry_cpu = True
|
|
3590
|
+
else:
|
|
3591
|
+
raise
|
|
3592
|
+
|
|
3593
|
+
if retry_cpu:
|
|
3594
|
+
# flip to CPU for this and subsequent iterations
|
|
3595
|
+
globals()["cuda_usable"] = False
|
|
3596
|
+
use_torch = False
|
|
3597
|
+
try:
|
|
3598
|
+
torch.cuda.empty_cache()
|
|
3599
|
+
except Exception:
|
|
3600
|
+
pass
|
|
3601
|
+
|
|
3602
|
+
try:
|
|
3603
|
+
import gc as _gc
|
|
3604
|
+
_gc.collect()
|
|
3605
|
+
except Exception:
|
|
3606
|
+
pass
|
|
3607
|
+
|
|
3608
|
+
# Rebuild accumulators on CPU for this iteration:
|
|
3609
|
+
# 1) convert x_t (seed/current estimate) to NumPy
|
|
3610
|
+
x_t = x_t.detach().cpu().numpy()
|
|
3611
|
+
# 2) reset accumulators as NumPy arrays
|
|
3612
|
+
num = np.zeros_like(x_t, dtype=np.float32)
|
|
3613
|
+
den = np.zeros_like(x_t, dtype=np.float32)
|
|
3614
|
+
|
|
3615
|
+
# 3) RUN THE EXISTING NUMPY ACCUMULATION (same as your else: block below)
|
|
3616
|
+
# 3) recompute this iteration’s accumulators on CPU
|
|
3617
|
+
for fi in range(n_frames):
|
|
3618
|
+
# load one frame as CHW (float32, sanitized)
|
|
3619
|
+
|
|
3620
|
+
y_chw = _load_frame_chw(fi) # CHW float32 from cache
|
|
3621
|
+
|
|
3622
|
+
# forward predict (super grid) with this frame's PSF
|
|
3623
|
+
pred_super = _conv_np_same(x_t, psfs[fi])
|
|
3624
|
+
|
|
3625
|
+
# downsample to native if SR was on
|
|
3626
|
+
pred_low = _downsample_avg(pred_super, r) if r > 1 else pred_super
|
|
3627
|
+
|
|
3628
|
+
# per-frame mask/variance
|
|
3629
|
+
m2d = _mask_for(fi, y_chw) # 2D, [0..1]
|
|
3630
|
+
v2d = _var_for(fi, y_chw) # 2D or None
|
|
3631
|
+
|
|
3632
|
+
# robust weights per pixel/channel
|
|
3633
|
+
w = _weight_map(
|
|
3634
|
+
y=y_chw, pred=pred_low,
|
|
3635
|
+
huber_delta=local_delta, # 0.0 for L2, else huber_delta
|
|
3636
|
+
var_map=v2d, mask=m2d
|
|
3637
|
+
).astype(np.float32, copy=False)
|
|
3638
|
+
|
|
3639
|
+
# adjoint (backproject) on super grid
|
|
3640
|
+
if r > 1:
|
|
3641
|
+
up_y = _upsample_sum(w * y_chw, r, target_hw=x_t.shape[-2:])
|
|
3642
|
+
up_pred = _upsample_sum(w * pred_low, r, target_hw=x_t.shape[-2:])
|
|
3643
|
+
else:
|
|
3644
|
+
up_y, up_pred = (w * y_chw), (w * pred_low)
|
|
3645
|
+
|
|
3646
|
+
num += _conv_np_same(up_y, flip_psf[fi])
|
|
3647
|
+
den += _conv_np_same(up_pred, flip_psf[fi])
|
|
3648
|
+
|
|
3649
|
+
# ensure strictly positive denominator
|
|
3650
|
+
den = np.clip(den, 1e-8, None).astype(np.float32, copy=False)
|
|
3651
|
+
|
|
3652
|
+
# switch everything to NumPy for the remainder of the run
|
|
3653
|
+
psf_fft = psfT_fft = None # (Torch packs no longer used)
|
|
3654
|
+
|
|
3655
|
+
# ---- multiplicative RL/MM step with clamping ----
|
|
3656
|
+
if use_torch:
|
|
3657
|
+
ratio = num / (den + EPS)
|
|
3658
|
+
neutral = (den.abs() < 1e-12) & (num.abs() < 1e-12)
|
|
3659
|
+
ratio = torch.where(neutral, torch.ones_like(ratio), ratio)
|
|
3660
|
+
upd = torch.clamp(ratio, 1.0 / kappa, kappa)
|
|
3661
|
+
x_next = torch.clamp(x_t * upd, min=0.0)
|
|
3662
|
+
# Robust scalars
|
|
3663
|
+
upd_med = torch.median(torch.abs(upd - 1))
|
|
3664
|
+
rel_change = (torch.median(torch.abs(x_next - x_t)) /
|
|
3665
|
+
(torch.median(torch.abs(x_t)) + 1e-8))
|
|
3666
|
+
|
|
3667
|
+
um = float(upd_med.detach().item())
|
|
3668
|
+
rc = float(rel_change.detach().item())
|
|
3669
|
+
|
|
3670
|
+
if early.step(it, max_iters, um, rc):
|
|
3671
|
+
x_t = x_next
|
|
3672
|
+
used_iters = it
|
|
3673
|
+
early_stopped = True
|
|
3674
|
+
status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
|
|
3675
|
+
_process_gui_events_safely()
|
|
3676
|
+
break
|
|
3677
|
+
|
|
3678
|
+
x_t = (1.0 - relax) * x_t + relax * x_next
|
|
3679
|
+
else:
|
|
3680
|
+
ratio = num / (den + EPS)
|
|
3681
|
+
neutral = (np.abs(den) < 1e-12) & (np.abs(num) < 1e-12)
|
|
3682
|
+
if np.any(neutral): ratio[neutral] = 1.0
|
|
3683
|
+
upd = np.clip(ratio, 1.0 / kappa, kappa)
|
|
3684
|
+
x_next = np.clip(x_t * upd, 0.0, None)
|
|
3685
|
+
|
|
3686
|
+
um = float(np.median(np.abs(upd - 1.0)))
|
|
3687
|
+
rc = float(np.median(np.abs(x_next - x_t)) / (np.median(np.abs(x_t)) + 1e-8))
|
|
3688
|
+
|
|
3689
|
+
if early.step(it, max_iters, um, rc):
|
|
3690
|
+
x_t = x_next
|
|
3691
|
+
used_iters = it
|
|
3692
|
+
early_stopped = True
|
|
3693
|
+
status_cb(f"MFDeconv: Iteration {it}/{max_iters} (early stop)")
|
|
3694
|
+
_process_gui_events_safely()
|
|
3695
|
+
break
|
|
3696
|
+
|
|
3697
|
+
x_t = (1.0 - relax) * x_t + relax * x_next
|
|
3698
|
+
|
|
3699
|
+
# --- LOW-MEM CLEANUP (per-iteration) ---
|
|
3700
|
+
if low_mem:
|
|
3701
|
+
# Torch temporaries we created in the iteration (best-effort deletes)
|
|
3702
|
+
to_kill = [
|
|
3703
|
+
"pred_super", "pred_low", "wmap_low", "yb", "mb", "vb", "wk", "wkT",
|
|
3704
|
+
"back_num", "back_den", "pred_list", "back_num_list", "back_den_list",
|
|
3705
|
+
"x_bc_hw", "up_y", "up_pred", "psi_over_r", "vmap", "rnat", "deltas", "chan_tmp"
|
|
3706
|
+
]
|
|
3707
|
+
loc = locals()
|
|
3708
|
+
for _name in to_kill:
|
|
3709
|
+
if _name in loc:
|
|
3710
|
+
try:
|
|
3711
|
+
del loc[_name]
|
|
3712
|
+
except Exception:
|
|
3713
|
+
pass
|
|
3714
|
+
|
|
3715
|
+
# Proactively release CUDA cache every other iter
|
|
3716
|
+
if use_torch:
|
|
3717
|
+
try:
|
|
3718
|
+
dev = _torch_device()
|
|
3719
|
+
if dev.type == "cuda" and (it % 2) == 0:
|
|
3720
|
+
import torch as _t
|
|
3721
|
+
_t.cuda.empty_cache()
|
|
3722
|
+
except Exception:
|
|
3723
|
+
pass
|
|
3724
|
+
|
|
3725
|
+
# Encourage Python to return big NumPy buffers to the OS sooner
|
|
3726
|
+
import gc as _gc
|
|
3727
|
+
_gc.collect()
|
|
3728
|
+
|
|
3729
|
+
|
|
3730
|
+
# ---- save intermediates ----
|
|
3731
|
+
if save_intermediate and (it % int(max(1, save_every)) == 0):
|
|
3732
|
+
try:
|
|
3733
|
+
x_np = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
|
|
3734
|
+
_save_iter_image(x_np, hdr0_seed, iter_dir, f"iter_{it:03d}", color_mode)
|
|
3735
|
+
except Exception as _e:
|
|
3736
|
+
status_cb(f"Intermediate save failed at iter {it}: {_e}")
|
|
3737
|
+
|
|
3738
|
+
frac = 0.25 + 0.70 * (it / float(max_iters))
|
|
3739
|
+
_emit_pct(frac, f"Iteration {it}/{max_iters}")
|
|
3740
|
+
status_cb(f"Iter {it}/{max_iters}")
|
|
3741
|
+
_process_gui_events_safely()
|
|
3742
|
+
|
|
3743
|
+
if not early_stopped:
|
|
3744
|
+
used_iters = max_iters
|
|
3745
|
+
|
|
3746
|
+
# ---- save result ----
|
|
3747
|
+
_emit_pct(0.97, "saving")
|
|
3748
|
+
x_final = x_t.detach().cpu().numpy().astype(np.float32) if use_torch else x_t.astype(np.float32)
|
|
3749
|
+
if x_final.ndim == 3:
|
|
3750
|
+
if x_final.shape[0] not in (1, 3) and x_final.shape[-1] in (1, 3):
|
|
3751
|
+
x_final = np.moveaxis(x_final, -1, 0)
|
|
3752
|
+
if x_final.shape[0] == 1:
|
|
3753
|
+
x_final = x_final[0]
|
|
3754
|
+
|
|
3755
|
+
try:
|
|
3756
|
+
hdr0 = _safe_primary_header(paths[0])
|
|
3757
|
+
except Exception:
|
|
3758
|
+
hdr0 = fits.Header()
|
|
3759
|
+
|
|
3760
|
+
hdr0['MFDECONV'] = (True, 'Seti Astro multi-frame deconvolution')
|
|
3761
|
+
hdr0['MF_COLOR'] = (str(color_mode), 'Color mode used')
|
|
3762
|
+
hdr0['MF_RHO'] = (str(rho), 'Loss: huber|l2')
|
|
3763
|
+
hdr0['MF_HDEL'] = (float(huber_delta), 'Huber delta (>0 abs, <0 autoxRMS)')
|
|
3764
|
+
hdr0['MF_MASK'] = (bool(use_star_masks), 'Used auto star masks')
|
|
3765
|
+
hdr0['MF_VAR'] = (bool(use_variance_maps), 'Used auto variance maps')
|
|
3766
|
+
r = int(max(1, super_res_factor))
|
|
3767
|
+
hdr0['MF_SR'] = (int(r), 'Super-resolution factor (1 := native)')
|
|
3768
|
+
if r > 1:
|
|
3769
|
+
hdr0['MF_SRSIG'] = (float(sr_sigma), 'Gaussian sigma for SR PSF fit (native px)')
|
|
3770
|
+
hdr0['MF_SRIT'] = (int(sr_psf_opt_iters), 'SR-PSF solver iters')
|
|
3771
|
+
hdr0['MF_ITMAX'] = (int(max_iters), 'Requested max iterations')
|
|
3772
|
+
hdr0['MF_ITERS'] = (int(used_iters), 'Actual iterations run')
|
|
3773
|
+
hdr0['MF_ESTOP'] = (bool(early_stopped), 'Early stop triggered')
|
|
3774
|
+
|
|
3775
|
+
if isinstance(x_final, np.ndarray):
|
|
3776
|
+
if x_final.ndim == 2:
|
|
3777
|
+
hdr0['MF_SHAPE'] = (f"{x_final.shape[0]}x{x_final.shape[1]}", 'Saved as 2D image (HxW)')
|
|
3778
|
+
elif x_final.ndim == 3:
|
|
3779
|
+
C, H, W = x_final.shape
|
|
3780
|
+
hdr0['MF_SHAPE'] = (f"{C}x{H}x{W}", 'Saved as 3D cube (CxHxW)')
|
|
3781
|
+
|
|
3782
|
+
save_path = _sr_out_path(out_path, super_res_factor)
|
|
3783
|
+
safe_out_path = _nonclobber_path(str(save_path))
|
|
3784
|
+
if safe_out_path != str(save_path):
|
|
3785
|
+
status_cb(f"Output exists — saving as: {safe_out_path}")
|
|
3786
|
+
fits.PrimaryHDU(data=x_final, header=hdr0).writeto(safe_out_path, overwrite=False)
|
|
3787
|
+
|
|
3788
|
+
status_cb(f"✅ MFDeconv saved: {safe_out_path} (iters used: {used_iters}{', early stop' if early_stopped else ''})")
|
|
3789
|
+
_emit_pct(1.00, "done")
|
|
3790
|
+
_process_gui_events_safely()
|
|
3791
|
+
|
|
3792
|
+
try:
|
|
3793
|
+
if use_torch:
|
|
3794
|
+
try: del num, den
|
|
3795
|
+
except Exception as e:
|
|
3796
|
+
import logging
|
|
3797
|
+
logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
|
|
3798
|
+
try: del psf_t, psfT_t
|
|
3799
|
+
except Exception as e:
|
|
3800
|
+
import logging
|
|
3801
|
+
logging.debug(f"Exception suppressed: {type(e).__name__}: {e}")
|
|
3802
|
+
_free_torch_memory()
|
|
3803
|
+
except Exception:
|
|
3804
|
+
pass
|
|
3805
|
+
|
|
3806
|
+
try:
|
|
3807
|
+
_clear_all_caches()
|
|
3808
|
+
except Exception:
|
|
3809
|
+
pass
|
|
3810
|
+
|
|
3811
|
+
return safe_out_path
|
|
3812
|
+
|
|
3813
|
+
# -----------------------------
|
|
3814
|
+
# Worker
|
|
3815
|
+
# -----------------------------
|
|
3816
|
+
|
|
3817
|
+
class MultiFrameDeconvWorker(QObject):
    """Qt worker that runs ``multiframe_deconv`` and reports through signals.

    Signals:
        progress(str): status text forwarded from the deconvolution routine.
        finished(bool, str, str): success flag, message, and output path
            (an empty string on failure).
    """

    progress = pyqtSignal(str)
    finished = pyqtSignal(bool, str, str)  # success, message, out_path

    def __init__(self, parent, aligned_paths, output_path, iters, kappa, color_mode,
                 huber_delta, min_iters, use_star_masks=False, use_variance_maps=False, rho="huber",
                 star_mask_cfg: dict | None = None, varmap_cfg: dict | None = None,
                 save_intermediate: bool = False,
                 seed_mode: str = "robust",
                 # super-resolution parameters
                 super_res_factor: int = 1,
                 sr_sigma: float = 1.1,
                 sr_psf_opt_iters: int = 250,
                 sr_psf_opt_lr: float = 0.1,
                 star_mask_ref_path: str | None = None):
        """Store the run configuration; no processing happens until ``run()``.

        Every argument is kept on the instance under the same name. The
        ``*_cfg`` dictionaries default to empty dicts when falsy, and the
        super-resolution numbers are coerced to ``int``/``float``.
        """
        super().__init__(parent)

        # Input / output locations.
        self.aligned_paths = aligned_paths
        self.output_path = output_path

        # Solver settings.
        self.iters = iters
        self.kappa = kappa
        self.color_mode = color_mode
        self.huber_delta = huber_delta
        self.min_iters = min_iters

        # Optional weighting inputs and their configuration.
        self.star_mask_cfg = star_mask_cfg if star_mask_cfg else {}
        self.varmap_cfg = varmap_cfg if varmap_cfg else {}
        self.use_star_masks = use_star_masks
        self.use_variance_maps = use_variance_maps
        self.rho = rho
        self.save_intermediate = save_intermediate

        # Super-resolution settings, normalised to plain numeric types.
        self.super_res_factor = int(super_res_factor)
        self.sr_sigma = float(sr_sigma)
        self.sr_psf_opt_iters = int(sr_psf_opt_iters)
        self.sr_psf_opt_lr = float(sr_psf_opt_lr)

        self.star_mask_ref_path = star_mask_ref_path
        self.seed_mode = seed_mode

    def _log(self, s):
        """Emit ``s`` on the ``progress`` signal."""
        self.progress.emit(s)

    def run(self):
        """Run the deconvolution and emit ``finished`` with the outcome.

        Any ``Exception`` from the run is reported as a failure through
        ``finished`` instead of propagating. Cleanup always runs afterwards.
        """
        try:
            positional = (self.aligned_paths, self.output_path)
            options = dict(
                iters=self.iters,
                kappa=self.kappa,
                color_mode=self.color_mode,
                seed_mode=self.seed_mode,
                huber_delta=self.huber_delta,
                use_star_masks=self.use_star_masks,
                use_variance_maps=self.use_variance_maps,
                rho=self.rho,
                min_iters=self.min_iters,
                status_cb=self._log,
                star_mask_cfg=self.star_mask_cfg,
                varmap_cfg=self.varmap_cfg,
                save_intermediate=self.save_intermediate,
                super_res_factor=self.super_res_factor,
                sr_sigma=self.sr_sigma,
                sr_psf_opt_iters=self.sr_psf_opt_iters,
                sr_psf_opt_lr=self.sr_psf_opt_lr,
                star_mask_ref_path=self.star_mask_ref_path,
            )
            out = multiframe_deconv(*positional, **options)
            self.finished.emit(True, "MF deconvolution complete.", out)
            _process_gui_events_safely()
        except Exception as e:
            self.finished.emit(False, f"MF deconvolution failed: {e}", "")
        finally:
            self._release_resources()

    def _release_resources(self):
        """Best-effort cleanup after a run; each stage swallows its own errors."""
        # Stage 1: drop references held by the worker itself.
        try:
            self.aligned_paths = []
            self.star_mask_cfg = {}
            self.varmap_cfg = {}
        except Exception:
            pass

        # Stage 2: module-level helper (presumably frees tensors and runs GC -- confirm).
        try:
            _free_torch_memory()
        except Exception:
            pass

        # Stage 3: ask the active torch backends to release cached memory.
        try:
            import torch as _t
            if hasattr(_t, "cuda") and _t.cuda.is_available():
                _t.cuda.synchronize()
                _t.cuda.empty_cache()
            has_mps = hasattr(_t, "mps") and getattr(_t.backends, "mps", None) and _t.backends.mps.is_available()
            # empty_cache is looked up defensively because it may be missing on this torch build.
            if has_mps and hasattr(_t.mps, "empty_cache"):
                _t.mps.empty_cache()
            # No explicit call is made for other backends (e.g. DirectML).
        except Exception:
            pass
|