setiastrosuitepro-1.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of setiastrosuitepro might be problematic.

Files changed (174)
  1. setiastro/__init__.py +2 -0
  2. setiastro/saspro/__init__.py +20 -0
  3. setiastro/saspro/__main__.py +784 -0
  4. setiastro/saspro/_generated/__init__.py +7 -0
  5. setiastro/saspro/_generated/build_info.py +2 -0
  6. setiastro/saspro/abe.py +1295 -0
  7. setiastro/saspro/abe_preset.py +196 -0
  8. setiastro/saspro/aberration_ai.py +694 -0
  9. setiastro/saspro/aberration_ai_preset.py +224 -0
  10. setiastro/saspro/accel_installer.py +218 -0
  11. setiastro/saspro/accel_workers.py +30 -0
  12. setiastro/saspro/add_stars.py +621 -0
  13. setiastro/saspro/astrobin_exporter.py +1007 -0
  14. setiastro/saspro/astrospike.py +153 -0
  15. setiastro/saspro/astrospike_python.py +1839 -0
  16. setiastro/saspro/autostretch.py +196 -0
  17. setiastro/saspro/backgroundneutral.py +560 -0
  18. setiastro/saspro/batch_convert.py +325 -0
  19. setiastro/saspro/batch_renamer.py +519 -0
  20. setiastro/saspro/blemish_blaster.py +488 -0
  21. setiastro/saspro/blink_comparator_pro.py +2923 -0
  22. setiastro/saspro/bundles.py +61 -0
  23. setiastro/saspro/bundles_dock.py +114 -0
  24. setiastro/saspro/cheat_sheet.py +168 -0
  25. setiastro/saspro/clahe.py +342 -0
  26. setiastro/saspro/comet_stacking.py +1377 -0
  27. setiastro/saspro/config.py +38 -0
  28. setiastro/saspro/config_bootstrap.py +40 -0
  29. setiastro/saspro/config_manager.py +316 -0
  30. setiastro/saspro/continuum_subtract.py +1617 -0
  31. setiastro/saspro/convo.py +1397 -0
  32. setiastro/saspro/convo_preset.py +414 -0
  33. setiastro/saspro/copyastro.py +187 -0
  34. setiastro/saspro/cosmicclarity.py +1564 -0
  35. setiastro/saspro/cosmicclarity_preset.py +407 -0
  36. setiastro/saspro/crop_dialog_pro.py +948 -0
  37. setiastro/saspro/crop_preset.py +189 -0
  38. setiastro/saspro/curve_editor_pro.py +2544 -0
  39. setiastro/saspro/curves_preset.py +375 -0
  40. setiastro/saspro/debayer.py +670 -0
  41. setiastro/saspro/debug_utils.py +29 -0
  42. setiastro/saspro/dnd_mime.py +35 -0
  43. setiastro/saspro/doc_manager.py +2634 -0
  44. setiastro/saspro/exoplanet_detector.py +2166 -0
  45. setiastro/saspro/file_utils.py +284 -0
  46. setiastro/saspro/fitsmodifier.py +744 -0
  47. setiastro/saspro/free_torch_memory.py +48 -0
  48. setiastro/saspro/frequency_separation.py +1343 -0
  49. setiastro/saspro/function_bundle.py +1594 -0
  50. setiastro/saspro/ghs_dialog_pro.py +660 -0
  51. setiastro/saspro/ghs_preset.py +284 -0
  52. setiastro/saspro/graxpert.py +634 -0
  53. setiastro/saspro/graxpert_preset.py +287 -0
  54. setiastro/saspro/gui/__init__.py +0 -0
  55. setiastro/saspro/gui/main_window.py +8494 -0
  56. setiastro/saspro/gui/mixins/__init__.py +33 -0
  57. setiastro/saspro/gui/mixins/dock_mixin.py +263 -0
  58. setiastro/saspro/gui/mixins/file_mixin.py +445 -0
  59. setiastro/saspro/gui/mixins/geometry_mixin.py +403 -0
  60. setiastro/saspro/gui/mixins/header_mixin.py +441 -0
  61. setiastro/saspro/gui/mixins/mask_mixin.py +421 -0
  62. setiastro/saspro/gui/mixins/menu_mixin.py +361 -0
  63. setiastro/saspro/gui/mixins/theme_mixin.py +367 -0
  64. setiastro/saspro/gui/mixins/toolbar_mixin.py +1324 -0
  65. setiastro/saspro/gui/mixins/update_mixin.py +309 -0
  66. setiastro/saspro/gui/mixins/view_mixin.py +435 -0
  67. setiastro/saspro/halobgon.py +462 -0
  68. setiastro/saspro/header_viewer.py +445 -0
  69. setiastro/saspro/headless_utils.py +88 -0
  70. setiastro/saspro/histogram.py +753 -0
  71. setiastro/saspro/history_explorer.py +939 -0
  72. setiastro/saspro/image_combine.py +414 -0
  73. setiastro/saspro/image_peeker_pro.py +1596 -0
  74. setiastro/saspro/imageops/__init__.py +37 -0
  75. setiastro/saspro/imageops/mdi_snap.py +292 -0
  76. setiastro/saspro/imageops/scnr.py +36 -0
  77. setiastro/saspro/imageops/starbasedwhitebalance.py +210 -0
  78. setiastro/saspro/imageops/stretch.py +244 -0
  79. setiastro/saspro/isophote.py +1179 -0
  80. setiastro/saspro/layers.py +208 -0
  81. setiastro/saspro/layers_dock.py +714 -0
  82. setiastro/saspro/lazy_imports.py +193 -0
  83. setiastro/saspro/legacy/__init__.py +2 -0
  84. setiastro/saspro/legacy/image_manager.py +2226 -0
  85. setiastro/saspro/legacy/numba_utils.py +3659 -0
  86. setiastro/saspro/legacy/xisf.py +1071 -0
  87. setiastro/saspro/linear_fit.py +534 -0
  88. setiastro/saspro/live_stacking.py +1830 -0
  89. setiastro/saspro/log_bus.py +5 -0
  90. setiastro/saspro/logging_config.py +460 -0
  91. setiastro/saspro/luminancerecombine.py +309 -0
  92. setiastro/saspro/main_helpers.py +201 -0
  93. setiastro/saspro/mask_creation.py +928 -0
  94. setiastro/saspro/masks_core.py +56 -0
  95. setiastro/saspro/mdi_widgets.py +353 -0
  96. setiastro/saspro/memory_utils.py +666 -0
  97. setiastro/saspro/metadata_patcher.py +75 -0
  98. setiastro/saspro/mfdeconv.py +3826 -0
  99. setiastro/saspro/mfdeconv_earlystop.py +71 -0
  100. setiastro/saspro/mfdeconvcudnn.py +3263 -0
  101. setiastro/saspro/mfdeconvsport.py +2382 -0
  102. setiastro/saspro/minorbodycatalog.py +567 -0
  103. setiastro/saspro/morphology.py +382 -0
  104. setiastro/saspro/multiscale_decomp.py +1290 -0
  105. setiastro/saspro/nbtorgb_stars.py +531 -0
  106. setiastro/saspro/numba_utils.py +3044 -0
  107. setiastro/saspro/numba_warmup.py +141 -0
  108. setiastro/saspro/ops/__init__.py +9 -0
  109. setiastro/saspro/ops/command_help_dialog.py +623 -0
  110. setiastro/saspro/ops/command_runner.py +217 -0
  111. setiastro/saspro/ops/commands.py +1594 -0
  112. setiastro/saspro/ops/script_editor.py +1102 -0
  113. setiastro/saspro/ops/scripts.py +1413 -0
  114. setiastro/saspro/ops/settings.py +560 -0
  115. setiastro/saspro/parallel_utils.py +554 -0
  116. setiastro/saspro/pedestal.py +121 -0
  117. setiastro/saspro/perfect_palette_picker.py +1053 -0
  118. setiastro/saspro/pipeline.py +110 -0
  119. setiastro/saspro/pixelmath.py +1600 -0
  120. setiastro/saspro/plate_solver.py +2435 -0
  121. setiastro/saspro/project_io.py +797 -0
  122. setiastro/saspro/psf_utils.py +136 -0
  123. setiastro/saspro/psf_viewer.py +549 -0
  124. setiastro/saspro/pyi_rthook_astroquery.py +95 -0
  125. setiastro/saspro/remove_green.py +314 -0
  126. setiastro/saspro/remove_stars.py +1625 -0
  127. setiastro/saspro/remove_stars_preset.py +404 -0
  128. setiastro/saspro/resources.py +472 -0
  129. setiastro/saspro/rgb_combination.py +207 -0
  130. setiastro/saspro/rgb_extract.py +19 -0
  131. setiastro/saspro/rgbalign.py +723 -0
  132. setiastro/saspro/runtime_imports.py +7 -0
  133. setiastro/saspro/runtime_torch.py +754 -0
  134. setiastro/saspro/save_options.py +72 -0
  135. setiastro/saspro/selective_color.py +1552 -0
  136. setiastro/saspro/sfcc.py +1425 -0
  137. setiastro/saspro/shortcuts.py +2807 -0
  138. setiastro/saspro/signature_insert.py +1099 -0
  139. setiastro/saspro/stacking_suite.py +17712 -0
  140. setiastro/saspro/star_alignment.py +7420 -0
  141. setiastro/saspro/star_alignment_preset.py +329 -0
  142. setiastro/saspro/star_metrics.py +49 -0
  143. setiastro/saspro/star_spikes.py +681 -0
  144. setiastro/saspro/star_stretch.py +470 -0
  145. setiastro/saspro/stat_stretch.py +502 -0
  146. setiastro/saspro/status_log_dock.py +78 -0
  147. setiastro/saspro/subwindow.py +3267 -0
  148. setiastro/saspro/supernovaasteroidhunter.py +1712 -0
  149. setiastro/saspro/swap_manager.py +99 -0
  150. setiastro/saspro/torch_backend.py +89 -0
  151. setiastro/saspro/torch_rejection.py +434 -0
  152. setiastro/saspro/view_bundle.py +1555 -0
  153. setiastro/saspro/wavescale_hdr.py +624 -0
  154. setiastro/saspro/wavescale_hdr_preset.py +100 -0
  155. setiastro/saspro/wavescalede.py +657 -0
  156. setiastro/saspro/wavescalede_preset.py +228 -0
  157. setiastro/saspro/wcs_update.py +374 -0
  158. setiastro/saspro/whitebalance.py +456 -0
  159. setiastro/saspro/widgets/__init__.py +48 -0
  160. setiastro/saspro/widgets/common_utilities.py +305 -0
  161. setiastro/saspro/widgets/graphics_views.py +122 -0
  162. setiastro/saspro/widgets/image_utils.py +518 -0
  163. setiastro/saspro/widgets/preview_dialogs.py +280 -0
  164. setiastro/saspro/widgets/spinboxes.py +275 -0
  165. setiastro/saspro/widgets/themed_buttons.py +13 -0
  166. setiastro/saspro/widgets/wavelet_utils.py +299 -0
  167. setiastro/saspro/window_shelf.py +185 -0
  168. setiastro/saspro/xisf.py +1123 -0
  169. setiastrosuitepro-1.6.0.dist-info/METADATA +266 -0
  170. setiastrosuitepro-1.6.0.dist-info/RECORD +174 -0
  171. setiastrosuitepro-1.6.0.dist-info/WHEEL +4 -0
  172. setiastrosuitepro-1.6.0.dist-info/entry_points.txt +6 -0
  173. setiastrosuitepro-1.6.0.dist-info/licenses/LICENSE +674 -0
  174. setiastrosuitepro-1.6.0.dist-info/licenses/license.txt +2580 -0

setiastro/saspro/swap_manager.py
@@ -0,0 +1,99 @@
+ import os
+ import shutil
+ import tempfile
+ import uuid
+ import pickle
+ import atexit
+ import threading
+ import numpy as np
+
+ class SwapManager:
+     _instance = None
+     _lock = threading.Lock()
+
+     def __new__(cls, *args, **kwargs):
+         with cls._lock:
+             if cls._instance is None:
+                 cls._instance = super(SwapManager, cls).__new__(cls)
+                 cls._instance._initialized = False
+             return cls._instance
+
+     def __init__(self):
+         if self._initialized:
+             return
+         self._initialized = True
+
+         # Create a unique temp directory for this session
+         self.temp_dir = os.path.join(tempfile.gettempdir(), "SetiAstroSuitePro_Swap", str(uuid.uuid4()))
+         os.makedirs(self.temp_dir, exist_ok=True)
+
+         # Register cleanup on exit
+         atexit.register(self.cleanup_all)
+
+     def get_swap_path(self, swap_id: str) -> str:
+         return os.path.join(self.temp_dir, f"{swap_id}.swap")
+
+     def save_state(self, image: np.ndarray) -> str | None:
+         """
+         Save the image array to a swap file.
+         Returns the unique swap_id, or None on failure.
+         """
+         swap_id = uuid.uuid4().hex
+         path = self.get_swap_path(swap_id)
+
+         # We only save the image data to disk; metadata is kept in RAM by the caller.
+         # Pickle is used for simplicity and robustness with numpy arrays. np.save might
+         # be slightly faster for pure arrays, but pickle stays flexible if what we store changes.
+         try:
+             with open(path, "wb") as f:
+                 pickle.dump(image, f, protocol=pickle.HIGHEST_PROTOCOL)
+         except Exception as e:
+             print(f"[SwapManager] Failed to save state {swap_id}: {e}")
+             return None
+
+         return swap_id
+
+     def load_state(self, swap_id: str) -> np.ndarray | None:
+         """
+         Load the image array from the swap file.
+         """
+         path = self.get_swap_path(swap_id)
+         if not os.path.exists(path):
+             print(f"[SwapManager] Swap file not found: {path}")
+             return None
+
+         try:
+             with open(path, "rb") as f:
+                 return pickle.load(f)
+         except Exception as e:
+             print(f"[SwapManager] Failed to load state {swap_id}: {e}")
+             return None
+
+     def delete_state(self, swap_id: str):
+         """
+         Delete a specific swap file.
+         """
+         path = self.get_swap_path(swap_id)
+         try:
+             if os.path.exists(path):
+                 os.remove(path)
+         except Exception as e:
+             print(f"[SwapManager] Failed to delete state {swap_id}: {e}")
+
+     def cleanup_all(self):
+         """
+         Delete the entire temporary directory for this session.
+         """
+         try:
+             if os.path.exists(self.temp_dir):
+                 shutil.rmtree(self.temp_dir, ignore_errors=True)
+             # print(f"[SwapManager] Cleaned up {self.temp_dir}")
+         except Exception as e:
+             print(f"[SwapManager] Cleanup failed: {e}")
+
+ # Global instance
+ _swap_mgr = SwapManager()
+
+ def get_swap_manager():
+     return _swap_mgr
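
For reference, a minimal usage sketch of the swap module above (illustrative only, not part of the wheel). It assumes nothing beyond the get_swap_manager() accessor defined in the diff and a small dummy numpy array:

    import numpy as np
    from setiastro.saspro.swap_manager import get_swap_manager

    mgr = get_swap_manager()                      # process-wide singleton
    img = np.random.rand(64, 64, 3).astype(np.float32)

    swap_id = mgr.save_state(img)                 # pickles the array into the session temp dir
    restored = mgr.load_state(swap_id)            # reads it back; None if the file is gone
    assert restored is not None and np.array_equal(img, restored)

    mgr.delete_state(swap_id)                     # remove this one swap file
    # cleanup_all() also runs automatically at interpreter exit via atexit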

setiastro/saspro/torch_backend.py
@@ -0,0 +1,89 @@
+ from __future__ import annotations
+ import contextlib
+ import os
+
+ # Resolve a single "torch-like" object for the whole app.
+ # Try your preferred/bundled build order first.
+ _TORCH = None
+ _err = None
+
+ # If you vendor or rename your build, try that FIRST (example):
+ # try:
+ #     import mybundled.torch as torch
+ #     _TORCH = torch
+ # except Exception as e:
+ #     _err = e
+
+ if _TORCH is None:
+     try:
+         import torch  # system/packaged torch
+         _TORCH = torch
+     except Exception as e:
+         _err = e
+
+ # Optional: DirectML fallback on Windows (comment out if not needed)
+ if _TORCH is None:
+     try:
+         import torch_directml as torch  # pip install torch-directml
+         _TORCH = torch
+     except Exception:
+         pass
+
+ def has_torch() -> bool:
+     return _TORCH is not None
+
+ def torch_module():
+     """Return the torch module or None."""
+     return _TORCH
+
+ def pick_device():
+     """Pick the best available device. Returns None if no torch."""
+     if _TORCH is None:
+         return None
+     try:
+         if hasattr(_TORCH, "cuda") and _TORCH.cuda.is_available():
+             return _TORCH.device("cuda")
+     except Exception:
+         pass
+     try:
+         mps = getattr(getattr(_TORCH, "backends", None), "mps", None)
+         if mps and getattr(mps, "is_available", lambda: False)():
+             return _TORCH.device("mps")
+     except Exception:
+         pass
+     return _TORCH.device("cpu")
+
+ def no_grad_decorator():
+     """
+     Returns a decorator:
+       • If torch exists, returns torch.no_grad()
+       • Else, an identity decorator
+     """
+     if _TORCH and hasattr(_TORCH, "no_grad"):
+         return _TORCH.no_grad()
+     def _identity(fn): return fn
+     return _identity
+
+ def inference_ctx():
+     """
+     Returns a context manager for inference if available (torch.inference_mode),
+     else a no-op context manager.
+     """
+     if _TORCH and hasattr(_TORCH, "inference_mode"):
+         return _TORCH.inference_mode()
+     return contextlib.nullcontext()
+
+ def free_torch_memory():
+     """Best-effort GPU memory cleanup."""
+     if _TORCH is None:
+         return
+     try:
+         if hasattr(_TORCH, "cuda") and hasattr(_TORCH.cuda, "empty_cache"):
+             _TORCH.cuda.empty_cache()
+     except Exception:
+         pass
+     try:
+         if hasattr(_TORCH, "mps") and hasattr(_TORCH.mps, "empty_cache"):
+             _TORCH.mps.empty_cache()
+     except Exception:
+         pass
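
A short usage sketch for the backend shim above (illustrative only, not part of the wheel), assuming a regular PyTorch build resolves at import time:

    from setiastro.saspro.torch_backend import (
        has_torch, torch_module, pick_device, inference_ctx, free_torch_memory,
    )

    if has_torch():
        torch = torch_module()
        dev = pick_device()          # cuda -> mps -> cpu, whichever is available
        with inference_ctx():        # torch.inference_mode() when present, else a no-op
            x = torch.ones(4, 4, device=dev)
            y = x @ x
        free_torch_memory()          # best-effort empty_cache on CUDA/MPS
    else:
        print("torch not available; GPU paths will be skipped")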

setiastro/saspro/torch_rejection.py
@@ -0,0 +1,434 @@
+ # pro/torch_rejection.py
+ from __future__ import annotations
+ import contextlib
+ import numpy as np
+
+ # Always route through our runtime shim so ALL GPU users share the same backend.
+ # Nothing heavy happens at import; we only resolve Torch when needed.
+ from .runtime_torch import import_torch, add_runtime_to_sys_path
+
+ # Algorithms supported by the GPU path here (names match your UI/CPU counterparts)
+ _SUPPORTED = {
+     "Comet Median",
+     "Simple Median (No Rejection)",
+     "Comet High-Clip Percentile",
+     "Comet Lower-Trim (30%)",
+     "Comet Percentile (40th)",
+     "Simple Average (No Rejection)",
+     "Weighted Windsorized Sigma Clipping",
+     "Windsorized Sigma Clipping",  # <<< NEW (unweighted)
+     "Kappa-Sigma Clipping",
+     "Trimmed Mean",
+     "Extreme Studentized Deviate (ESD)",
+     "Biweight Estimator",
+     "Modified Z-Score Clipping",
+     "Max Value",
+ }
+
+ # ---------------------------------------------------------------------------
+ # Lazy Torch resolution (so PyInstaller bootstrap and non-GPU users don’t break)
+ # ---------------------------------------------------------------------------
+ _TORCH = None
+ _DEVICE = None
+
+ def _get_torch(prefer_cuda: bool = True):
+     """
+     Resolve and cache the torch module via the SAS runtime shim.
+     This may install/repair torch into the per-user runtime if needed.
+     """
+     global _TORCH, _DEVICE
+     if _TORCH is not None:
+         return _TORCH
+
+     # In frozen builds, help the process see the runtime site-packages first.
+     try:
+         add_runtime_to_sys_path(lambda *_: None)
+     except Exception:
+         pass
+
+     # Import (and if necessary, install) torch using the unified runtime.
+     torch = import_torch(prefer_cuda=prefer_cuda, status_cb=lambda *_: None)
+     _TORCH = torch
+     _force_fp32_policy(torch)
+
+     # Choose the best device once; cheap calls, but cached anyway
+     try:
+         if hasattr(torch, "cuda") and torch.cuda.is_available():
+             _DEVICE = torch.device("cuda")
+         elif getattr(getattr(torch, "backends", None), "mps", None) and torch.backends.mps.is_available():
+             _DEVICE = torch.device("mps")
+         else:
+             # Try DirectML for AMD/Intel GPUs on Windows
+             try:
+                 import torch_directml
+                 dml_device = torch_directml.device()
+                 # Quick sanity check
+                 _ = (torch.ones(1, device=dml_device) + 1).item()
+                 _DEVICE = dml_device
+             except Exception:
+                 _DEVICE = torch.device("cpu")
+     except Exception:
+         _DEVICE = torch.device("cpu")
+
+     return _TORCH
+
+ def _device():
+     if _DEVICE is not None:
+         return _DEVICE
+     # Default to CPU if torch is not yet resolved; the first GPU call resolves it.
+     return None
+
+ def torch_available() -> bool:
+     """Return True iff we can import/resolve torch via the runtime shim."""
+     try:
+         _get_torch(prefer_cuda=True)
+         return True
+     except Exception:
+         return False
+
+ def gpu_algo_supported(algo_name: str) -> bool:
+     return algo_name in _SUPPORTED
+
+ # ---------------------------------------------------------------------------
+ # Helpers (nan-safe reducers) – assume torch is available *inside* callers
+ # ---------------------------------------------------------------------------
+ def _nanmedian(torch, x, dim: int):
+     try:
+         return torch.nanmedian(x, dim=dim).values
+     except Exception:
+         m = torch.isfinite(x)
+         x2 = x.clone()
+         x2[~m] = float("inf")
+         idx = x2.argsort(dim=dim)
+         cnt = m.sum(dim=dim).clamp_min(1)
+         mid = (cnt - 1) // 2
+         gather_idx = idx.gather(dim, mid.unsqueeze(dim))
+         return x.gather(dim, gather_idx).squeeze(dim)
+
+ def _nanstd(torch, x, dim: int):
+     try:
+         return torch.nanstd(x, dim=dim, unbiased=False)
+     except Exception:
+         m = torch.isfinite(x)
+         cnt = m.sum(dim=dim).clamp_min(1)
+         s1 = torch.where(m, x, torch.zeros_like(x)).sum(dim=dim)
+         s2 = torch.where(m, x * x, torch.zeros_like(x)).sum(dim=dim)
+         mean = s1 / cnt
+         var = (s2 / cnt) - mean * mean
+         return var.clamp_min(0).sqrt()
+
+ def _nanquantile(torch, x, q: float, dim: int):
+     try:
+         return torch.nanquantile(x, q, dim=dim)
+     except Exception:
+         m = torch.isfinite(x)
+         x2 = x.clone()
+         x2[~m] = float("inf")
+         idx = x2.argsort(dim=dim)
+         n = m.sum(dim=dim).clamp_min(1)
+         kth = (q * (n - 1)).round().to(torch.long)
+         kth = kth.clamp(min=0)
+         gather_idx = idx.gather(dim, kth.unsqueeze(dim))
+         return x.gather(dim, gather_idx).squeeze(dim)
+
+ def _no_amp_ctx(torch, dev):
+     """
+     Return a context that disables autocast on this thread for the current device.
+     Works across torch 1.13–2.x and CUDA/CPU/MPS. No-ops if unsupported.
+     """
+     import contextlib
+     # PyTorch 2.x unified API
+     try:
+         ac = getattr(torch, "autocast", None)
+         if ac is not None:
+             dt = "cuda" if getattr(dev, "type", "") == "cuda" else \
+                  "mps" if getattr(dev, "type", "") == "mps" else "cpu"
+             return ac(device_type=dt, enabled=False)
+     except Exception:
+         pass
+     # Older CUDA AMP API
+     try:
+         amp = getattr(getattr(torch, "cuda", None), "amp", None)
+         if amp and hasattr(amp, "autocast"):
+             return amp.autocast(enabled=False)
+     except Exception:
+         pass
+     return contextlib.nullcontext()
+
+
+ # --- add near the top (after imports) ---
+ def _safe_inference_ctx(torch):
+     """
+     Return a context manager for inference that won't explode on older or
+     backend-variant Torch builds (DirectML/MPS/CPU-only).
+     """
+     try:
+         # Prefer inference_mode if both the API and C++ backend support it
+         if getattr(torch, "inference_mode", None) is not None:
+             _C = getattr(torch, "_C", None)
+             if _C is not None and hasattr(_C, "_InferenceMode"):
+                 return torch.inference_mode()
+     except Exception:
+         pass
+     # Fallbacks
+     if getattr(torch, "no_grad", None) is not None:
+         return torch.no_grad()
+     import contextlib
+     return contextlib.nullcontext()
+
+ def _force_fp32_policy(torch):
+     try:
+         # default dtype for new tensors (does not upcast existing)
+         torch.set_default_dtype(torch.float32)
+     except Exception:
+         pass
+     # disable “helpful” lower-precision math
+     try:
+         if hasattr(torch.backends, "cudnn"):
+             torch.backends.cudnn.allow_tf32 = False
+         if hasattr(torch.backends, "cuda") and hasattr(torch.backends.cuda, "matmul"):
+             torch.backends.cuda.matmul.allow_tf32 = False
+     except Exception:
+         pass
+     try:
+         # prefer strict fp32 matmul kernels where supported
+         if hasattr(torch, "set_float32_matmul_precision"):
+             torch.set_float32_matmul_precision("highest")
+     except Exception:
+         pass
+
+ # ---------------------------------------------------------------------------
+ # Public GPU reducer – lazy-loads Torch, never decorates at import time
+ # ---------------------------------------------------------------------------
+ def torch_reduce_tile(
+     ts_np: np.ndarray,       # (F, th, tw, C) or (F, th, tw) -> treated as C=1
+     weights_np: np.ndarray,  # (F,) or (F, th, tw, C)
+     *,
+     algo_name: str,
+     kappa: float = 2.5,
+     iterations: int = 3,
+     sigma_low: float = 2.5,          # for winsorized
+     sigma_high: float = 2.5,         # for winsorized
+     trim_fraction: float = 0.1,      # for trimmed mean
+     esd_threshold: float = 3.0,      # for ESD
+     biweight_constant: float = 6.0,  # for biweight
+     modz_threshold: float = 3.5,     # for modified z
+     comet_hclip_k: float = 1.30,     # for comet high-clip percentile
+     comet_hclip_p: float = 25.0,     # for comet high-clip percentile
+ ) -> tuple[np.ndarray, np.ndarray]:
+     """
+     Returns: (tile_result, tile_rej_map)
+         tile_result: (th, tw, C) float32
+         tile_rej_map: (F, th, tw, C) bool (collapse C on caller if needed)
+     """
+     # Resolve torch on demand, using the SAME backend as the rest of the app.
+     torch = _get_torch(prefer_cuda=True)
+     dev = _device() or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
+
+     # Normalize shape to 4D float32
+     ts_np = np.asarray(ts_np, dtype=np.float32)
+     if ts_np.ndim == 3:
+         ts_np = ts_np[..., None]
+     F, H, W, C = ts_np.shape
+
+     if H == 0 or W == 0 or C < 1:
+         raise ValueError(
+             f"torch_reduce_tile received degenerate tile shape={ts_np.shape}. "
+             "This usually means a bad edge tile or corrupted frame; "
+             "try disabling GPU rejection or reducing chunk size."
+         )
+
+     # Sanity check: C must be at least 1
+     if C < 1:
+         raise ValueError(f"torch_reduce_tile received input with C={C} channels (shape={ts_np.shape}). Expected C >= 1.")
+
+     # Host → device
+     ts = torch.from_numpy(ts_np).to(dev, dtype=torch.float32, non_blocking=True)
+
+     # Weights broadcast to 4D
+     weights_np = np.asarray(weights_np, dtype=np.float32)
+     if weights_np.ndim == 1:
+         w = torch.from_numpy(weights_np).to(dev, dtype=torch.float32, non_blocking=True).view(F, 1, 1, 1)
+     else:
+         w = torch.from_numpy(weights_np).to(dev, dtype=torch.float32, non_blocking=True)
+
+     algo = algo_name
+     valid = torch.isfinite(ts)
+
+     # Use inference_mode if present; else nullcontext.
+     with _safe_inference_ctx(torch), _no_amp_ctx(torch, dev):
+         # ---------------- simple, no-rejection reducers ----------------
+         if algo in ("Comet Median", "Simple Median (No Rejection)"):
+             out = ts.median(dim=0).values
+             rej = torch.zeros((F, H, W, C), dtype=torch.bool, device=dev)
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo == "Comet Percentile (40th)":
+             out = _nanquantile(torch, ts, 0.40, dim=0)
+             rej = torch.zeros((F, H, W, C), dtype=torch.bool, device=dev)
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo_name == "Windsorized Sigma Clipping":
+             # Unweighted: mask by k*sigma around median, then plain mean of survivors
+             low = float(sigma_low)
+             high = float(sigma_high)
+             valid = torch.isfinite(ts)
+
+             keep = valid.clone()
+             for _ in range(int(iterations)):
+                 x = ts.masked_fill(~keep, float("nan"))
+                 med = _nanmedian(torch, x, dim=0)  # (H,W,C)
+                 std = _nanstd(torch, x, dim=0)     # (H,W,C)
+                 lo = med - low * std
+                 hi = med + high * std
+                 keep = valid & (ts >= lo.unsqueeze(0)) & (ts <= hi.unsqueeze(0))
+
+             # Numerator/denominator over frames -> (H,W,C)
+             kept = torch.where(keep, ts, torch.zeros_like(ts))
+             num = kept.sum(dim=0)                  # (H,W,C)
+             cnt = keep.sum(dim=0).to(ts.dtype)     # (H,W,C)
+
+             # Fallback to nanmedian where nothing survived
+             x = ts.masked_fill(~valid, float("nan"))
+             fallback = _nanmedian(torch, x, dim=0)  # (H,W,C)
+
+             out = torch.where(cnt <= 0, fallback, num / cnt.clamp_min(1))
+             rej = ~keep
+             assert out.dtype == torch.float32, f"reducer produced {out.dtype}, expected float32"
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo == "Comet Lower-Trim (30%)":
+             n = torch.isfinite(ts).sum(dim=0).clamp_min(1)
+             k_keep = torch.floor(n * (1.0 - 0.30)).to(torch.long).clamp(min=1)
+             vals, idx = ts.sort(dim=0, stable=True)
+             arangeF = torch.arange(F, device=dev).view(F, 1, 1, 1).expand_as(vals)
+             keep = arangeF < k_keep.unsqueeze(0).expand_as(vals)
+             den = keep.sum(dim=0).clamp_min(1).to(vals.dtype)
+             out = (vals * keep).sum(dim=0) / den
+             keep_orig = torch.zeros_like(keep)
+             keep_orig.scatter_(0, idx, keep)
+             rej = ~keep_orig
+             assert out.dtype == torch.float32, f"reducer produced {out.dtype}, expected float32"
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo == "Comet High-Clip Percentile":
+             med = _nanmedian(torch, ts, dim=0)
+             mad = _nanmedian(torch, (ts - med.unsqueeze(0)).abs(), dim=0) + 1e-6
+             hi = med + (float(comet_hclip_k) * 1.4826 * mad)
+             clipped = torch.minimum(ts, hi.unsqueeze(0))
+             out = _nanquantile(torch, clipped, float(comet_hclip_p) / 100.0, dim=0)
+             rej = torch.zeros((F, H, W, C), dtype=torch.bool, device=dev)
+             assert out.dtype == torch.float32, f"reducer produced {out.dtype}, expected float32"
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo == "Simple Average (No Rejection)":
+             num = (ts * w).sum(dim=0)
+             den = w.sum(dim=0).clamp_min(1e-20)
+             out = (num / den)
+             rej = torch.zeros((F, H, W, C), dtype=torch.bool, device=dev)
+             assert out.dtype == torch.float32, f"reducer produced {out.dtype}, expected float32"
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo == "Max Value":
+             out = ts.max(dim=0).values
+             rej = torch.zeros((F, H, W, C), dtype=torch.bool, device=dev)
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         # ---------------- rejection-based reducers ----------------
+         if algo == "Kappa-Sigma Clipping":
+             keep = valid.clone()
+             for _ in range(int(iterations)):
+                 x = ts.masked_fill(~keep, float("nan"))
+                 med = _nanmedian(torch, x, dim=0)
+                 std = _nanstd(torch, x, dim=0)
+                 lo = med - float(kappa) * std
+                 hi = med + float(kappa) * std
+                 keep = valid & (ts >= lo.unsqueeze(0)) & (ts <= hi.unsqueeze(0))
+             w_eff = torch.where(keep, w, torch.zeros_like(w))
+             den = w_eff.sum(dim=0).clamp_min(1e-20)
+             out = (ts.mul(w_eff)).sum(dim=0).div(den)
+             rej = ~keep
+             assert out.dtype == torch.float32, f"reducer produced {out.dtype}, expected float32"
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo == "Weighted Windsorized Sigma Clipping":
+             low = float(sigma_low); high = float(sigma_high)
+             keep = valid.clone()
+             for _ in range(int(iterations)):
+                 x = ts.masked_fill(~keep, float("nan"))
+                 med = _nanmedian(torch, x, dim=0)
+                 std = _nanstd(torch, x, dim=0)
+                 lo = med - low * std
+                 hi = med + high * std
+                 keep = valid & (ts >= lo.unsqueeze(0)) & (ts <= hi.unsqueeze(0))
+             w_eff = torch.where(keep, w, torch.zeros_like(w))
+             den = w_eff.sum(dim=0)
+             num = (ts * w_eff).sum(dim=0)
+             out = torch.empty((H, W, C), dtype=ts.dtype, device=dev)
+             mask_no = den <= 0
+             if mask_no.any():
+                 x = ts.masked_fill(~valid, float("nan"))
+                 out_fallback = _nanmedian(torch, x, dim=0)
+                 out[mask_no] = out_fallback[mask_no]
+             if (~mask_no).any():
+                 out[~mask_no] = (num[~mask_no] / den[~mask_no])
+             rej = ~keep
+             assert out.dtype == torch.float32, f"reducer produced {out.dtype}, expected float32"
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo == "Trimmed Mean":
+             x = ts.masked_fill(~valid, float("nan"))
+             qlo = _nanquantile(torch, x, trim_fraction, dim=0)
+             qhi = _nanquantile(torch, x, 1.0 - trim_fraction, dim=0)
+             keep = valid & (ts >= qlo.unsqueeze(0)) & (ts <= qhi.unsqueeze(0))
+             w_eff = torch.where(keep, w, torch.zeros_like(w))
+             den = w_eff.sum(dim=0).clamp_min(1e-20)
+             out = (ts.mul(w_eff)).sum(dim=0).div(den)
+             rej = ~keep
+             assert out.dtype == torch.float32, f"reducer produced {out.dtype}, expected float32"
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo == "Extreme Studentized Deviate (ESD)":
+             x = ts.masked_fill(~valid, float("nan"))
+             mean = torch.where(torch.isfinite(x), x, torch.zeros_like(x)).nanmean(dim=0)
+             std = _nanstd(torch, x, dim=0).clamp_min(1e-12)
+             z = (ts - mean.unsqueeze(0)).abs() / std.unsqueeze(0)
+             keep = valid & (z < float(esd_threshold))
+             w_eff = torch.where(keep, w, torch.zeros_like(w))
+             den = w_eff.sum(dim=0).clamp_min(1e-20)
+             out = (ts.mul(w_eff)).sum(dim=0).div(den)
+             rej = ~keep
+             assert out.dtype == torch.float32, f"reducer produced {out.dtype}, expected float32"
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo == "Biweight Estimator":
+             x = ts
+             m = _nanmedian(torch, x.masked_fill(~valid, float("nan")), dim=0)
+             mad = _nanmedian(torch, (x - m.unsqueeze(0)).abs().masked_fill(~valid, float("nan")), dim=0) + 1e-12
+             u = (x - m.unsqueeze(0)) / (float(biweight_constant) * mad.unsqueeze(0))
+             mask = valid & (u.abs() < 1.0)
+             w_eff = torch.where(mask, w, torch.zeros_like(w))
+             one_minus_u2 = (1 - u**2).clamp_min(0)
+             num = ((x - m.unsqueeze(0)) * (one_minus_u2**2) * w_eff).sum(dim=0)
+             den = ((one_minus_u2**2) * w_eff).sum(dim=0)
+             out = torch.where(den > 0, m + num / den, m)
+             rej = ~mask
+             assert out.dtype == torch.float32, f"reducer produced {out.dtype}, expected float32"
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+         if algo == "Modified Z-Score Clipping":
+             x = ts
+             med = _nanmedian(torch, x.masked_fill(~valid, float("nan")), dim=0)
+             mad = _nanmedian(torch, (x - med.unsqueeze(0)).abs().masked_fill(~valid, float("nan")), dim=0) + 1e-12
+             mz = 0.6745 * (x - med.unsqueeze(0)) / mad.unsqueeze(0)
+             keep = valid & (mz.abs() < float(modz_threshold))
+             w_eff = torch.where(keep, w, torch.zeros_like(w))
+             den = w_eff.sum(dim=0).clamp_min(1e-20)
+             out = (ts.mul(w_eff)).sum(dim=0).div(den)
+             rej = ~keep
+             assert out.dtype == torch.float32, f"reducer produced {out.dtype}, expected float32"
+             return out.to(dtype=torch.float32).contiguous().cpu().numpy(), rej.cpu().numpy()
+
+     raise NotImplementedError(f"GPU path not implemented for: {algo_name}")
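
For context, a minimal driver sketch for the public reducer above (illustrative only, not part of the wheel). The stack shape, weights shape, and algorithm name follow the torch_reduce_tile docstring and the _SUPPORTED set:

    import numpy as np
    from setiastro.saspro.torch_rejection import (
        torch_available, gpu_algo_supported, torch_reduce_tile,
    )

    # 8 frames of a 32x32 RGB tile, with one hot-pixel outlier injected in frame 3.
    rng = np.random.default_rng(0)
    tile_stack = rng.normal(0.5, 0.01, size=(8, 32, 32, 3)).astype(np.float32)
    tile_stack[3, 10, 10, :] = 5.0
    frame_weights = np.ones(8, dtype=np.float32)   # per-frame weights, shape (F,)

    algo = "Kappa-Sigma Clipping"
    if torch_available() and gpu_algo_supported(algo):
        result, rejected = torch_reduce_tile(
            tile_stack, frame_weights,
            algo_name=algo, kappa=2.5, iterations=3,
        )
        print(result.shape)          # (32, 32, 3) float32 stacked tile
        print(rejected[3, 10, 10])   # the outlier pixel should be flagged in frame 3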