setiastrosuitepro 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of setiastrosuitepro might be problematic. Click here for more details.

Files changed (174) hide show
  1. setiastro/__init__.py +2 -0
  2. setiastro/saspro/__init__.py +20 -0
  3. setiastro/saspro/__main__.py +784 -0
  4. setiastro/saspro/_generated/__init__.py +7 -0
  5. setiastro/saspro/_generated/build_info.py +2 -0
  6. setiastro/saspro/abe.py +1295 -0
  7. setiastro/saspro/abe_preset.py +196 -0
  8. setiastro/saspro/aberration_ai.py +694 -0
  9. setiastro/saspro/aberration_ai_preset.py +224 -0
  10. setiastro/saspro/accel_installer.py +218 -0
  11. setiastro/saspro/accel_workers.py +30 -0
  12. setiastro/saspro/add_stars.py +621 -0
  13. setiastro/saspro/astrobin_exporter.py +1007 -0
  14. setiastro/saspro/astrospike.py +153 -0
  15. setiastro/saspro/astrospike_python.py +1839 -0
  16. setiastro/saspro/autostretch.py +196 -0
  17. setiastro/saspro/backgroundneutral.py +560 -0
  18. setiastro/saspro/batch_convert.py +325 -0
  19. setiastro/saspro/batch_renamer.py +519 -0
  20. setiastro/saspro/blemish_blaster.py +488 -0
  21. setiastro/saspro/blink_comparator_pro.py +2923 -0
  22. setiastro/saspro/bundles.py +61 -0
  23. setiastro/saspro/bundles_dock.py +114 -0
  24. setiastro/saspro/cheat_sheet.py +168 -0
  25. setiastro/saspro/clahe.py +342 -0
  26. setiastro/saspro/comet_stacking.py +1377 -0
  27. setiastro/saspro/config.py +38 -0
  28. setiastro/saspro/config_bootstrap.py +40 -0
  29. setiastro/saspro/config_manager.py +316 -0
  30. setiastro/saspro/continuum_subtract.py +1617 -0
  31. setiastro/saspro/convo.py +1397 -0
  32. setiastro/saspro/convo_preset.py +414 -0
  33. setiastro/saspro/copyastro.py +187 -0
  34. setiastro/saspro/cosmicclarity.py +1564 -0
  35. setiastro/saspro/cosmicclarity_preset.py +407 -0
  36. setiastro/saspro/crop_dialog_pro.py +948 -0
  37. setiastro/saspro/crop_preset.py +189 -0
  38. setiastro/saspro/curve_editor_pro.py +2544 -0
  39. setiastro/saspro/curves_preset.py +375 -0
  40. setiastro/saspro/debayer.py +670 -0
  41. setiastro/saspro/debug_utils.py +29 -0
  42. setiastro/saspro/dnd_mime.py +35 -0
  43. setiastro/saspro/doc_manager.py +2634 -0
  44. setiastro/saspro/exoplanet_detector.py +2166 -0
  45. setiastro/saspro/file_utils.py +284 -0
  46. setiastro/saspro/fitsmodifier.py +744 -0
  47. setiastro/saspro/free_torch_memory.py +48 -0
  48. setiastro/saspro/frequency_separation.py +1343 -0
  49. setiastro/saspro/function_bundle.py +1594 -0
  50. setiastro/saspro/ghs_dialog_pro.py +660 -0
  51. setiastro/saspro/ghs_preset.py +284 -0
  52. setiastro/saspro/graxpert.py +634 -0
  53. setiastro/saspro/graxpert_preset.py +287 -0
  54. setiastro/saspro/gui/__init__.py +0 -0
  55. setiastro/saspro/gui/main_window.py +8494 -0
  56. setiastro/saspro/gui/mixins/__init__.py +33 -0
  57. setiastro/saspro/gui/mixins/dock_mixin.py +263 -0
  58. setiastro/saspro/gui/mixins/file_mixin.py +445 -0
  59. setiastro/saspro/gui/mixins/geometry_mixin.py +403 -0
  60. setiastro/saspro/gui/mixins/header_mixin.py +441 -0
  61. setiastro/saspro/gui/mixins/mask_mixin.py +421 -0
  62. setiastro/saspro/gui/mixins/menu_mixin.py +361 -0
  63. setiastro/saspro/gui/mixins/theme_mixin.py +367 -0
  64. setiastro/saspro/gui/mixins/toolbar_mixin.py +1324 -0
  65. setiastro/saspro/gui/mixins/update_mixin.py +309 -0
  66. setiastro/saspro/gui/mixins/view_mixin.py +435 -0
  67. setiastro/saspro/halobgon.py +462 -0
  68. setiastro/saspro/header_viewer.py +445 -0
  69. setiastro/saspro/headless_utils.py +88 -0
  70. setiastro/saspro/histogram.py +753 -0
  71. setiastro/saspro/history_explorer.py +939 -0
  72. setiastro/saspro/image_combine.py +414 -0
  73. setiastro/saspro/image_peeker_pro.py +1596 -0
  74. setiastro/saspro/imageops/__init__.py +37 -0
  75. setiastro/saspro/imageops/mdi_snap.py +292 -0
  76. setiastro/saspro/imageops/scnr.py +36 -0
  77. setiastro/saspro/imageops/starbasedwhitebalance.py +210 -0
  78. setiastro/saspro/imageops/stretch.py +244 -0
  79. setiastro/saspro/isophote.py +1179 -0
  80. setiastro/saspro/layers.py +208 -0
  81. setiastro/saspro/layers_dock.py +714 -0
  82. setiastro/saspro/lazy_imports.py +193 -0
  83. setiastro/saspro/legacy/__init__.py +2 -0
  84. setiastro/saspro/legacy/image_manager.py +2226 -0
  85. setiastro/saspro/legacy/numba_utils.py +3659 -0
  86. setiastro/saspro/legacy/xisf.py +1071 -0
  87. setiastro/saspro/linear_fit.py +534 -0
  88. setiastro/saspro/live_stacking.py +1830 -0
  89. setiastro/saspro/log_bus.py +5 -0
  90. setiastro/saspro/logging_config.py +460 -0
  91. setiastro/saspro/luminancerecombine.py +309 -0
  92. setiastro/saspro/main_helpers.py +201 -0
  93. setiastro/saspro/mask_creation.py +928 -0
  94. setiastro/saspro/masks_core.py +56 -0
  95. setiastro/saspro/mdi_widgets.py +353 -0
  96. setiastro/saspro/memory_utils.py +666 -0
  97. setiastro/saspro/metadata_patcher.py +75 -0
  98. setiastro/saspro/mfdeconv.py +3826 -0
  99. setiastro/saspro/mfdeconv_earlystop.py +71 -0
  100. setiastro/saspro/mfdeconvcudnn.py +3263 -0
  101. setiastro/saspro/mfdeconvsport.py +2382 -0
  102. setiastro/saspro/minorbodycatalog.py +567 -0
  103. setiastro/saspro/morphology.py +382 -0
  104. setiastro/saspro/multiscale_decomp.py +1290 -0
  105. setiastro/saspro/nbtorgb_stars.py +531 -0
  106. setiastro/saspro/numba_utils.py +3044 -0
  107. setiastro/saspro/numba_warmup.py +141 -0
  108. setiastro/saspro/ops/__init__.py +9 -0
  109. setiastro/saspro/ops/command_help_dialog.py +623 -0
  110. setiastro/saspro/ops/command_runner.py +217 -0
  111. setiastro/saspro/ops/commands.py +1594 -0
  112. setiastro/saspro/ops/script_editor.py +1102 -0
  113. setiastro/saspro/ops/scripts.py +1413 -0
  114. setiastro/saspro/ops/settings.py +560 -0
  115. setiastro/saspro/parallel_utils.py +554 -0
  116. setiastro/saspro/pedestal.py +121 -0
  117. setiastro/saspro/perfect_palette_picker.py +1053 -0
  118. setiastro/saspro/pipeline.py +110 -0
  119. setiastro/saspro/pixelmath.py +1600 -0
  120. setiastro/saspro/plate_solver.py +2435 -0
  121. setiastro/saspro/project_io.py +797 -0
  122. setiastro/saspro/psf_utils.py +136 -0
  123. setiastro/saspro/psf_viewer.py +549 -0
  124. setiastro/saspro/pyi_rthook_astroquery.py +95 -0
  125. setiastro/saspro/remove_green.py +314 -0
  126. setiastro/saspro/remove_stars.py +1625 -0
  127. setiastro/saspro/remove_stars_preset.py +404 -0
  128. setiastro/saspro/resources.py +472 -0
  129. setiastro/saspro/rgb_combination.py +207 -0
  130. setiastro/saspro/rgb_extract.py +19 -0
  131. setiastro/saspro/rgbalign.py +723 -0
  132. setiastro/saspro/runtime_imports.py +7 -0
  133. setiastro/saspro/runtime_torch.py +754 -0
  134. setiastro/saspro/save_options.py +72 -0
  135. setiastro/saspro/selective_color.py +1552 -0
  136. setiastro/saspro/sfcc.py +1425 -0
  137. setiastro/saspro/shortcuts.py +2807 -0
  138. setiastro/saspro/signature_insert.py +1099 -0
  139. setiastro/saspro/stacking_suite.py +17712 -0
  140. setiastro/saspro/star_alignment.py +7420 -0
  141. setiastro/saspro/star_alignment_preset.py +329 -0
  142. setiastro/saspro/star_metrics.py +49 -0
  143. setiastro/saspro/star_spikes.py +681 -0
  144. setiastro/saspro/star_stretch.py +470 -0
  145. setiastro/saspro/stat_stretch.py +502 -0
  146. setiastro/saspro/status_log_dock.py +78 -0
  147. setiastro/saspro/subwindow.py +3267 -0
  148. setiastro/saspro/supernovaasteroidhunter.py +1712 -0
  149. setiastro/saspro/swap_manager.py +99 -0
  150. setiastro/saspro/torch_backend.py +89 -0
  151. setiastro/saspro/torch_rejection.py +434 -0
  152. setiastro/saspro/view_bundle.py +1555 -0
  153. setiastro/saspro/wavescale_hdr.py +624 -0
  154. setiastro/saspro/wavescale_hdr_preset.py +100 -0
  155. setiastro/saspro/wavescalede.py +657 -0
  156. setiastro/saspro/wavescalede_preset.py +228 -0
  157. setiastro/saspro/wcs_update.py +374 -0
  158. setiastro/saspro/whitebalance.py +456 -0
  159. setiastro/saspro/widgets/__init__.py +48 -0
  160. setiastro/saspro/widgets/common_utilities.py +305 -0
  161. setiastro/saspro/widgets/graphics_views.py +122 -0
  162. setiastro/saspro/widgets/image_utils.py +518 -0
  163. setiastro/saspro/widgets/preview_dialogs.py +280 -0
  164. setiastro/saspro/widgets/spinboxes.py +275 -0
  165. setiastro/saspro/widgets/themed_buttons.py +13 -0
  166. setiastro/saspro/widgets/wavelet_utils.py +299 -0
  167. setiastro/saspro/window_shelf.py +185 -0
  168. setiastro/saspro/xisf.py +1123 -0
  169. setiastrosuitepro-1.6.0.dist-info/METADATA +266 -0
  170. setiastrosuitepro-1.6.0.dist-info/RECORD +174 -0
  171. setiastrosuitepro-1.6.0.dist-info/WHEEL +4 -0
  172. setiastrosuitepro-1.6.0.dist-info/entry_points.txt +6 -0
  173. setiastrosuitepro-1.6.0.dist-info/licenses/LICENSE +674 -0
  174. setiastrosuitepro-1.6.0.dist-info/licenses/license.txt +2580 -0
@@ -0,0 +1,3659 @@
1
+ #legacy.numba_utils.py
2
+ import numpy as np
3
+ from numba import njit, prange
4
+ from numba.typed import List
5
+ import cv2
6
+ import math
7
+
8
@njit(parallel=True, fastmath=True)
def blend_add_numba(A, B, alpha):
    """Additive blend: per element, out = A + alpha*B clamped into [0, 1]."""
    rows, cols, chans = A.shape
    out = np.empty_like(A)
    for r in prange(rows):
        for col in range(cols):
            for ch in range(chans):
                val = A[r, col, ch] + B[r, col, ch] * alpha
                # keep the result inside the display range
                if val < 0.0:
                    val = 0.0
                if val > 1.0:
                    val = 1.0
                out[r, col, ch] = val
    return out
21
+
22
@njit(parallel=True, fastmath=True)
def blend_subtract_numba(A, B, alpha):
    """Subtractive blend: per element, out = A - alpha*B clamped into [0, 1]."""
    rows, cols, chans = A.shape
    out = np.empty_like(A)
    for r in prange(rows):
        for col in range(cols):
            for ch in range(chans):
                val = A[r, col, ch] - B[r, col, ch] * alpha
                # keep the result inside the display range
                if val < 0.0:
                    val = 0.0
                if val > 1.0:
                    val = 1.0
                out[r, col, ch] = val
    return out
34
+
35
@njit(parallel=True, fastmath=True)
def blend_multiply_numba(A, B, alpha):
    """Multiply blend mixed with the original: out = A*(1-alpha) + A*B*alpha, clamped."""
    rows, cols, chans = A.shape
    out = np.empty_like(A)
    for r in prange(rows):
        for col in range(cols):
            for ch in range(chans):
                base = A[r, col, ch]
                val = (base * (1 - alpha)) + (base * B[r, col, ch] * alpha)
                if val < 0.0:
                    val = 0.0
                if val > 1.0:
                    val = 1.0
                out[r, col, ch] = val
    return out
47
+
48
@njit(parallel=True, fastmath=True)
def blend_divide_numba(A, B, alpha):
    """Divide blend: ratio A/(B+eps) is clamped, then alpha-mixed with A, clamped again."""
    rows, cols, chans = A.shape
    out = np.empty_like(A)
    eps = 1e-6
    for r in prange(rows):
        for col in range(cols):
            for ch in range(chans):
                base = A[r, col, ch]
                # eps keeps the denominator strictly nonzero
                ratio = base / (B[r, col, ch] + eps)
                if ratio < 0.0:
                    ratio = 0.0
                if ratio > 1.0:
                    ratio = 1.0
                # blend the clamped ratio with the original pixel
                val = base * (1.0 - alpha) + ratio * alpha
                if val < 0.0:
                    val = 0.0
                if val > 1.0:
                    val = 1.0
                out[r, col, ch] = val
    return out
68
+
69
@njit(parallel=True, fastmath=True)
def blend_screen_numba(A, B, alpha):
    """Screen blend (1 - (1-A)*(1-B)), alpha-mixed with the original, clamped to [0, 1]."""
    rows, cols, chans = A.shape
    out = np.empty_like(A)
    for r in prange(rows):
        for col in range(cols):
            for ch in range(chans):
                base = A[r, col, ch]
                screened = 1.0 - (1.0 - base) * (1.0 - B[r, col, ch])
                if screened < 0.0:
                    screened = 0.0
                if screened > 1.0:
                    screened = 1.0
                val = base * (1.0 - alpha) + screened * alpha
                if val < 0.0:
                    val = 0.0
                if val > 1.0:
                    val = 1.0
                out[r, col, ch] = val
    return out
85
+
86
@njit(parallel=True, fastmath=True)
def blend_overlay_numba(A, B, alpha):
    """Overlay blend (2ab below mid-grey, screen-like above), alpha-mixed with A, clamped."""
    rows, cols, chans = A.shape
    out = np.empty_like(A)
    for r in prange(rows):
        for col in range(cols):
            for ch in range(chans):
                base = A[r, col, ch]
                top = B[r, col, ch]
                # dark half multiplies, bright half screens
                if base <= 0.5:
                    overlaid = 2.0 * base * top
                else:
                    overlaid = 1.0 - 2.0 * (1.0 - base) * (1.0 - top)
                if overlaid < 0.0:
                    overlaid = 0.0
                if overlaid > 1.0:
                    overlaid = 1.0
                val = base * (1.0 - alpha) + overlaid * alpha
                if val < 0.0:
                    val = 0.0
                if val > 1.0:
                    val = 1.0
                out[r, col, ch] = val
    return out
107
+
108
@njit(parallel=True, fastmath=True)
def blend_difference_numba(A, B, alpha):
    """Difference blend |A - B|, alpha-mixed with the original, clamped to [0, 1]."""
    rows, cols, chans = A.shape
    out = np.empty_like(A)
    for r in prange(rows):
        for col in range(cols):
            for ch in range(chans):
                base = A[r, col, ch]
                diff = abs(base - B[r, col, ch])
                # abs() is already >= 0; only the upper bound needs clamping
                if diff > 1.0:
                    diff = 1.0
                val = base * (1.0 - alpha) + diff * alpha
                if val < 0.0:
                    val = 0.0
                if val > 1.0:
                    val = 1.0
                out[r, col, ch] = val
    return out
125
+
126
@njit(parallel=True, fastmath=True)
def rescale_image_numba(image, factor):
    """
    Custom rescale function using bilinear interpolation optimized with numba.
    Supports both mono (2D) and color (3D) images.

    Parameters
    ----------
    image : ndarray, (H, W) or (H, W, C)
    factor : float
        Scale factor; each output axis is int(dim * factor) (truncated).

    Returns
    -------
    float32 ndarray of the rescaled image.

    NOTE(review): numba only parallelizes the outermost prange; the inner
    prange behaves like a plain range.
    """
    if image.ndim == 2:
        height, width = image.shape
        new_width = int(width * factor)
        new_height = int(height * factor)
        output = np.zeros((new_height, new_width), dtype=np.float32)
        for y in prange(new_height):
            for x in prange(new_width):
                # map the destination pixel back into source coordinates
                src_x = x / factor
                src_y = y / factor
                x0, y0 = int(src_x), int(src_y)
                # clamp the +1 neighbor at the right/bottom edge
                x1 = x0 + 1 if x0 + 1 < width else width - 1
                y1 = y0 + 1 if y0 + 1 < height else height - 1
                # fractional offsets inside the source cell
                dx = src_x - x0
                dy = src_y - y0
                # bilinear mix of the four surrounding source pixels
                output[y, x] = (image[y0, x0] * (1 - dx) * (1 - dy) +
                                image[y0, x1] * dx * (1 - dy) +
                                image[y1, x0] * (1 - dx) * dy +
                                image[y1, x1] * dx * dy)
        return output
    else:
        height, width, channels = image.shape
        new_width = int(width * factor)
        new_height = int(height * factor)
        output = np.zeros((new_height, new_width, channels), dtype=np.float32)
        for y in prange(new_height):
            for x in prange(new_width):
                src_x = x / factor
                src_y = y / factor
                x0, y0 = int(src_x), int(src_y)
                x1 = x0 + 1 if x0 + 1 < width else width - 1
                y1 = y0 + 1 if y0 + 1 < height else height - 1
                dx = src_x - x0
                dy = src_y - y0
                # same bilinear weights applied per channel
                for c in range(channels):
                    output[y, x, c] = (image[y0, x0, c] * (1 - dx) * (1 - dy) +
                                       image[y0, x1, c] * dx * (1 - dy) +
                                       image[y1, x0, c] * (1 - dx) * dy +
                                       image[y1, x1, c] * dx * dy)
        return output
171
+
172
@njit(parallel=True, fastmath=True)
def bin2x2_numba(image):
    """
    Downsample the image by 2x2 via simple averaging ("integer binning").
    Works on 2D (H x W) or 3D (H x W x C) arrays. If dimensions aren't even,
    the last row/column is dropped (floor division below).

    Returns a float32 array of shape (H//2, W//2[, C]).
    """
    h, w = image.shape[:2]
    # integer division silently drops an odd trailing row/column
    h2 = h // 2
    w2 = w // 2

    # allocate output
    if image.ndim == 2:
        out = np.empty((h2, w2), dtype=np.float32)
        for i in prange(h2):
            for j in prange(w2):
                # average 2x2 block
                s = image[2*i , 2*j ] \
                  + image[2*i+1, 2*j ] \
                  + image[2*i , 2*j+1] \
                  + image[2*i+1, 2*j+1]
                # multiply by 0.25 instead of dividing by 4
                out[i, j] = s * 0.25
    else:
        c = image.shape[2]
        out = np.empty((h2, w2, c), dtype=np.float32)
        for i in prange(h2):
            for j in prange(w2):
                for k in range(c):
                    s = image[2*i , 2*j , k] \
                      + image[2*i+1, 2*j , k] \
                      + image[2*i , 2*j+1, k] \
                      + image[2*i+1, 2*j+1, k]
                    out[i, j, k] = s * 0.25

    return out
207
+
208
@njit(parallel=True, fastmath=True)
def flip_horizontal_numba(image):
    """Mirror an image left-to-right; handles mono (2D) and color (3D), dtype preserved."""
    if image.ndim == 2:
        rows, cols = image.shape
        out = np.empty((rows, cols), dtype=image.dtype)
        for r in prange(rows):
            for col in range(cols):
                out[r, col] = image[r, cols - 1 - col]
        return out
    else:
        rows, cols, chans = image.shape
        out = np.empty((rows, cols, chans), dtype=image.dtype)
        for r in prange(rows):
            for col in range(cols):
                src_col = cols - 1 - col
                for ch in range(chans):
                    out[r, col, ch] = image[r, src_col, ch]
        return out
229
+
230
+
231
@njit(parallel=True, fastmath=True)
def flip_vertical_numba(image):
    """Mirror an image top-to-bottom; handles mono (2D) and color (3D), dtype preserved."""
    if image.ndim == 2:
        rows, cols = image.shape
        out = np.empty((rows, cols), dtype=image.dtype)
        for r in prange(rows):
            src_row = rows - 1 - r
            for col in range(cols):
                out[r, col] = image[src_row, col]
        return out
    else:
        rows, cols, chans = image.shape
        out = np.empty((rows, cols, chans), dtype=image.dtype)
        for r in prange(rows):
            src_row = rows - 1 - r
            for col in range(cols):
                for ch in range(chans):
                    out[r, col, ch] = image[src_row, col, ch]
        return out
252
+
253
+
254
@njit(parallel=True, fastmath=True)
def rotate_90_clockwise_numba(image):
    """Rotate 90 degrees clockwise; (H,W[,C]) -> (W,H[,C]), dtype preserved."""
    if image.ndim == 2:
        rows, cols = image.shape
        out = np.empty((cols, rows), dtype=image.dtype)
        for r in prange(rows):
            # source row r lands in destination column rows-1-r
            dst_col = rows - 1 - r
            for col in range(cols):
                out[col, dst_col] = image[r, col]
        return out
    else:
        rows, cols, chans = image.shape
        out = np.empty((cols, rows, chans), dtype=image.dtype)
        for r in prange(rows):
            dst_col = rows - 1 - r
            for col in range(cols):
                for ch in range(chans):
                    out[col, dst_col, ch] = image[r, col, ch]
        return out
275
+
276
+
277
@njit(parallel=True, fastmath=True)
def rotate_90_counterclockwise_numba(image):
    """Rotate 90 degrees counterclockwise; (H,W[,C]) -> (W,H[,C]), dtype preserved."""
    if image.ndim == 2:
        rows, cols = image.shape
        out = np.empty((cols, rows), dtype=image.dtype)
        for r in prange(rows):
            for col in range(cols):
                # source column col lands in destination row cols-1-col
                out[cols - 1 - col, r] = image[r, col]
        return out
    else:
        rows, cols, chans = image.shape
        out = np.empty((cols, rows, chans), dtype=image.dtype)
        for r in prange(rows):
            for col in range(cols):
                dst_row = cols - 1 - col
                for ch in range(chans):
                    out[dst_row, r, ch] = image[r, col, ch]
        return out
298
+
299
+
300
@njit(parallel=True, fastmath=True)
def invert_image_numba(image):
    """Invert a normalized image (out = 1 - in); mono (2D) or color (3D), dtype preserved."""
    if image.ndim == 2:
        rows, cols = image.shape
        out = np.empty((rows, cols), dtype=image.dtype)
        for r in prange(rows):
            for col in range(cols):
                out[r, col] = 1.0 - image[r, col]
        return out
    else:
        rows, cols, chans = image.shape
        out = np.empty((rows, cols, chans), dtype=image.dtype)
        for r in prange(rows):
            for col in range(cols):
                for ch in range(chans):
                    out[r, col, ch] = 1.0 - image[r, col, ch]
        return out
321
+
322
@njit(parallel=True, fastmath=True)
def rotate_180_numba(image):
    """Rotate 180 degrees (flip both axes); mono (2D) or color (3D), dtype preserved."""
    if image.ndim == 2:
        rows, cols = image.shape
        out = np.empty((rows, cols), dtype=image.dtype)
        for r in prange(rows):
            src_row = rows - 1 - r
            for col in range(cols):
                out[r, col] = image[src_row, cols - 1 - col]
        return out
    else:
        rows, cols, chans = image.shape
        out = np.empty((rows, cols, chans), dtype=image.dtype)
        for r in prange(rows):
            src_row = rows - 1 - r
            for col in range(cols):
                src_col = cols - 1 - col
                for ch in range(chans):
                    out[r, col, ch] = image[src_row, src_col, ch]
        return out
343
+
344
def normalize_flat_cfa_inplace(flat2d: np.ndarray, pattern: str, *, combine_greens: bool = True) -> np.ndarray:
    """
    Scale each CFA plane of a Bayer/mosaic flat so its median becomes 1.0.

    Operates in place on *flat2d* (via strided views) and returns it.

    pattern: 'RGGB', 'BGGR', 'GRBG', 'GBRG'; anything else falls back to RGGB.
    combine_greens: if True, both green planes share one median
        (reduces checkerboard risk).
    """
    key = (pattern or "RGGB").strip().upper()
    if key not in ("RGGB", "BGGR", "GRBG", "GBRG"):
        key = "RGGB"

    # plane names at parities (row even, col even), (even, odd), (odd, even), (odd, odd)
    layout = {
        "RGGB": ("R", "G1", "G2", "B"),
        "BGGR": ("B", "G1", "G2", "R"),
        "GRBG": ("G1", "R", "B", "G2"),
        "GBRG": ("G1", "B", "R", "G2"),
    }[key]

    # strided views share memory with flat2d, so writing to them is in place
    planes = dict(zip(layout, (
        flat2d[0::2, 0::2],
        flat2d[0::2, 1::2],
        flat2d[1::2, 0::2],
        flat2d[1::2, 1::2],
    )))

    def _robust_median(arr: np.ndarray) -> float:
        # median of finite, positive samples; 1.0 when nothing usable remains
        vals = arr[np.isfinite(arr) & (arr > 0)]
        if vals.size == 0:
            return 1.0
        med = float(np.median(vals))
        return med if np.isfinite(med) and med > 0 else 1.0

    # greens: either one shared denominator or per-plane medians
    if combine_greens and ("G1" in planes) and ("G2" in planes):
        pooled = np.concatenate([
            planes["G1"][np.isfinite(planes["G1"]) & (planes["G1"] > 0)].ravel(),
            planes["G2"][np.isfinite(planes["G2"]) & (planes["G2"] > 0)].ravel(),
        ])
        shared = float(np.median(pooled)) if pooled.size else 1.0
        if not np.isfinite(shared) or shared <= 0:
            shared = 1.0
        planes["G1"][:] = planes["G1"] / shared
        planes["G2"][:] = planes["G2"] / shared
    else:
        for name in ("G1", "G2"):
            if name in planes:
                planes[name][:] = planes[name] / _robust_median(planes[name])

    # red and blue are always normalized independently
    for name in ("R", "B"):
        if name in planes:
            planes[name][:] = planes[name] / _robust_median(planes[name])

    # last-ditch cleanup so downstream division by this flat is always safe
    np.nan_to_num(flat2d, copy=False, nan=1.0, posinf=1.0, neginf=1.0)
    flat2d[flat2d == 0] = 1.0
    return flat2d
409
+
410
@njit(parallel=True, fastmath=True)
def apply_flat_division_numba_2d(image, master_flat, master_bias=None):
    """
    Mono flat-field division, element-wise: image /= (flat / mean(flat)).

    image, master_flat : (H, W) arrays.
    master_bias : optional (H, W) array, subtracted from both image and flat
        first (note this rebinds to new arrays, so the caller's input is not
        modified in that case).

    Returns the corrected image.

    NOTE: normalization uses the *mean* of the flat (the old local name
    `median_flat` was a misnomer). Zero or non-finite denominators are
    skipped (treated as 1.0), matching the behavior of `_flat_div_2d`.
    """
    if master_bias is not None:
        master_flat = master_flat - master_bias
        image = image - master_bias

    mean_flat = np.mean(master_flat)
    height, width = image.shape

    for y in prange(height):
        for x in range(width):
            denom = master_flat[y, x] / mean_flat
            # guard dead/hot flat pixels: avoid inf/NaN in the output
            if denom == 0.0 or not np.isfinite(denom):
                denom = 1.0
            image[y, x] /= denom

    return image
427
+
428
+
429
@njit(parallel=True, fastmath=True)
def apply_flat_division_numba_3d(image, master_flat, master_bias=None):
    """
    Color flat-field division, element-wise: image /= (flat / mean(flat)).

    image, master_flat : (H, W, C) arrays.
    master_bias : optional (H, W, C) array, subtracted from both image and
        flat first (rebinds to new arrays in that case).

    Returns the corrected image.

    NOTE: normalization uses the *mean* of the flat (the old local name
    `median_flat` was a misnomer). Zero or non-finite denominators are
    skipped (treated as 1.0), matching the behavior of `_flat_div_hwc`.
    """
    if master_bias is not None:
        master_flat = master_flat - master_bias
        image = image - master_bias

    mean_flat = np.mean(master_flat)
    height, width, channels = image.shape

    for y in prange(height):
        for x in range(width):
            for c in range(channels):
                denom = master_flat[y, x, c] / mean_flat
                # guard dead/hot flat pixels: avoid inf/NaN in the output
                if denom == 0.0 or not np.isfinite(denom):
                    denom = 1.0
                image[y, x, c] /= denom

    return image
447
+
448
@njit(parallel=True, fastmath=True)
def _flat_div_2d(img, flat):
    """In-place mono flat division; non-finite or non-positive flat pixels divide by 1."""
    rows, cols = img.shape
    for r in prange(rows):
        for col in range(cols):
            denom = flat[r, col]
            # only divide by usable flat values; x/1.0 == x so skipping is equivalent
            if np.isfinite(denom) and denom > 0.0:
                img[r, col] = img[r, col] / denom
    return img
458
+
459
@njit(parallel=True, fastmath=True)
def _flat_div_hwc(img, flat):
    """In-place HWC flat division; flat may be (H,W) shared or (H,W,C) per-channel."""
    rows, cols, chans = img.shape
    per_channel = (flat.ndim != 2)
    for r in prange(rows):
        for col in range(cols):
            if per_channel:
                for ch in range(chans):
                    denom = flat[r, col, ch]
                    if (not np.isfinite(denom)) or denom <= 0.0:
                        denom = 1.0
                    img[r, col, ch] = img[r, col, ch] / denom
            else:
                # one flat value shared by all channels at this pixel
                denom = flat[r, col]
                if (not np.isfinite(denom)) or denom <= 0.0:
                    denom = 1.0
                for ch in range(chans):
                    img[r, col, ch] = img[r, col, ch] / denom
    return img
478
+
479
@njit(parallel=True, fastmath=True)
def _flat_div_chw(img, flat):
    """In-place CHW flat division; flat may be (H,W) shared or (C,H,W) per-channel."""
    chans, rows, cols = img.shape
    per_channel = (flat.ndim != 2)
    for r in prange(rows):
        for col in range(cols):
            if per_channel:
                for ch in range(chans):
                    denom = flat[ch, r, col]
                    if (not np.isfinite(denom)) or denom <= 0.0:
                        denom = 1.0
                    img[ch, r, col] = img[ch, r, col] / denom
            else:
                # one flat value shared by all channels at this pixel
                denom = flat[r, col]
                if (not np.isfinite(denom)) or denom <= 0.0:
                    denom = 1.0
                for ch in range(chans):
                    img[ch, r, col] = img[ch, r, col] / denom
    return img
498
+
499
def apply_flat_division_numba(image, master_flat, master_bias=None):
    """
    Dispatch flat division by array layout.

    Supports:
      - 2D mono/bayer: (H,W)
      - Color HWC: (H,W,3)
      - Color CHW: (3,H,W)
    Anything else that is 3D falls back to the HWC kernel.

    NOTE: master_bias arg kept for API compatibility; do bias/dark
    subtraction outside.
    """
    ndim = image.ndim
    if ndim == 2:
        return _flat_div_2d(image, master_flat)

    if ndim == 3:
        # CHW layout: leading channel axis of 3, trailing axis not 3
        if image.shape[0] == 3 and image.shape[-1] != 3:
            return _flat_div_chw(image, master_flat)
        # HWC layout (also the fallback for ambiguous shapes)
        return _flat_div_hwc(image, master_flat)

    raise ValueError(f"apply_flat_division_numba: expected 2D or 3D, got shape {image.shape}")
523
+
524
def _bayerpat_to_id(pat: str) -> int:
    """Map a Bayer pattern string to its integer id; unknown/empty maps to 0 (RGGB)."""
    normalized = (pat or "RGGB").strip().upper()
    return {"RGGB": 0, "BGGR": 1, "GRBG": 2, "GBRG": 3}.get(normalized, 0)
531
+
532
def _bayer_plane_medians(flat2d: np.ndarray, pat: str) -> np.ndarray:
    """
    Median of each CFA plane of *flat2d*, returned as float32 [R, G1, G2, B].

    Non-finite or non-positive medians are replaced by 1.0 so they are
    always safe to divide by.
    """
    key = (pat or "RGGB").strip().upper()
    # plane name at each parity (row, col): (0,0), (0,1), (1,0), (1,1)
    layout = {
        "RGGB": ("R", "G1", "G2", "B"),
        "BGGR": ("B", "G1", "G2", "R"),
        "GRBG": ("G1", "R", "B", "G2"),
    }.get(key, ("G1", "B", "R", "G2"))  # default branch: GBRG (and unknowns)

    views = (
        flat2d[0::2, 0::2],
        flat2d[0::2, 1::2],
        flat2d[1::2, 0::2],
        flat2d[1::2, 1::2],
    )
    med = {name: np.median(v) for name, v in zip(layout, views)}

    med4 = np.array([med["R"], med["G1"], med["G2"], med["B"]], dtype=np.float32)
    med4[~np.isfinite(med4)] = 1.0
    med4[med4 <= 0] = 1.0
    return med4
559
+
560
@njit(parallel=True, fastmath=True)
def apply_flat_division_numba_bayer_2d(image, master_flat, med4, pat_id):
    """
    Bayer-aware mono division, in place. image/master_flat are (H,W).
    med4 is [R,G1,G2,B] for that master_flat, pat_id in {0..3}
    (0=RGGB, 1=BGGR, 2=GRBG, 3=GBRG; see _bayerpat_to_id).

    Each pixel is divided by its flat value normalized to the median of the
    pixel's own CFA plane, so the four planes keep their relative balance.
    Zero/non-finite denominators divide by 1.0 instead. Returns image.
    """
    H, W = image.shape
    for y in prange(H):
        # row parity (0 = even row, 1 = odd row)
        y1 = y & 1
        for x in range(W):
            # column parity
            x1 = x & 1

            # map parity->plane index (0=R, 1=G1, 2=G2, 3=B in med4)
            if pat_id == 0: # RGGB: (0,0)R (0,1)G1 (1,0)G2 (1,1)B
                pi = 0 if (y1==0 and x1==0) else 1 if (y1==0 and x1==1) else 2 if (y1==1 and x1==0) else 3
            elif pat_id == 1: # BGGR
                pi = 3 if (y1==1 and x1==1) else 1 if (y1==0 and x1==1) else 2 if (y1==1 and x1==0) else 0
            elif pat_id == 2: # GRBG
                pi = 1 if (y1==0 and x1==0) else 0 if (y1==0 and x1==1) else 3 if (y1==1 and x1==0) else 2
            else: # GBRG
                pi = 1 if (y1==0 and x1==0) else 3 if (y1==0 and x1==1) else 0 if (y1==1 and x1==0) else 2

            # normalize the flat by this plane's median before dividing
            denom = master_flat[y, x] / med4[pi]
            if denom == 0.0 or not np.isfinite(denom):
                denom = 1.0
            image[y, x] /= denom
    return image
587
+
588
def apply_flat_division_bayer(image2d: np.ndarray, flat2d: np.ndarray, bayerpat: str):
    """Bayer-aware flat division: each CFA plane is normalized by its own flat median."""
    plane_medians = _bayer_plane_medians(flat2d, bayerpat)
    pattern_id = _bayerpat_to_id(bayerpat)
    return apply_flat_division_numba_bayer_2d(image2d, flat2d, plane_medians, pattern_id)
592
+
593
@njit(parallel=True)
def subtract_dark_3d(frames, dark_frame):
    """
    Subtract one (H,W) dark from every frame of a mono stack.

    frames.shape == (F,H,W); dark_frame.shape == (H,W).
    Returns a new float32 array of shape (F,H,W).
    """
    n_frames = frames.shape[0]
    result = np.empty_like(frames, dtype=np.float32)
    # one whole-slice subtraction per frame, parallel over frames
    for idx in prange(n_frames):
        result[idx] = frames[idx] - dark_frame
    return result
609
+
610
+
611
@njit(parallel=True)
def subtract_dark_4d(frames, dark_frame):
    """
    Subtract one (H,W,C) dark from every frame of a color stack.

    frames.shape == (F,H,W,C); dark_frame.shape == (H,W,C).
    Returns a new float32 array of shape (F,H,W,C).
    """
    n_frames, rows, cols, chans = frames.shape
    result = np.empty_like(frames, dtype=np.float32)
    for idx in prange(n_frames):
        for r in range(rows):
            for col in range(cols):
                for ch in range(chans):
                    result[idx, r, col, ch] = frames[idx, r, col, ch] - dark_frame[r, col, ch]
    return result
629
+
630
def subtract_dark(frames, dark_frame):
    """
    Dispatch dark subtraction to the mono (3D) or color (4D) Numba kernel.

    frames: (F,H,W) with dark (H,W), or (F,H,W,C) with dark (H,W,C).
    Raises ValueError for any other rank.
    """
    ndim = frames.ndim
    if ndim == 3:
        return subtract_dark_3d(frames, dark_frame)
    if ndim == 4:
        return subtract_dark_4d(frames, dark_frame)
    raise ValueError(f"subtract_dark: frames must be 3D or 4D, got {frames.shape}")
643
+
644
+
645
+ import numpy as np
646
+ from numba import njit, prange
647
+
648
+ # -------------------------------
649
+ # Windsorized Sigma Clipping (Weighted, Iterative)
650
+ # -------------------------------
651
+
652
@njit(parallel=True, fastmath=True)
def windsorized_sigma_clip_weighted_3d_iter(stack, weights, lower=2.5, upper=2.5, iterations=2):
    """
    Iterative Weighted Windsorized Sigma Clipping for a 3D mono stack.
    stack.shape == (F,H,W)
    weights.shape can be (F,) or (F,H,W).
    Returns a tuple:
        (clipped, rejection_mask)
    where:
      clipped is a 2D image (H,W),
      rejection_mask is a boolean array of shape (F,H,W) with True indicating rejection.

    NOTE(review): despite the name, outliers are *rejected* (sigma-clipped),
    not Winsorized (replaced with boundary values). Zero-valued samples are
    excluded from the start.
    """
    num_frames, height, width = stack.shape
    clipped = np.zeros((height, width), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width), dtype=np.bool_)

    # Check weights shape: either one scalar weight per frame, or per-pixel
    if weights.ndim == 1 and weights.shape[0] == num_frames:
        pass
    elif weights.ndim == 3 and weights.shape == stack.shape:
        pass
    else:
        raise ValueError("windsorized_sigma_clip_weighted_3d_iter: mismatch in shapes for 3D stack & weights")

    for i in prange(height):
        for j in range(width):
            pixel_values = stack[:, i, j]  # shape=(F,)
            if weights.ndim == 1:
                pixel_weights = weights[:]  # shape (F,)
            else:
                pixel_weights = weights[:, i, j]
            # Start with nonzero pixels as valid (zeros treated as missing data)
            valid_mask = pixel_values != 0
            for _ in range(iterations):
                if np.sum(valid_mask) == 0:
                    break
                valid_vals = pixel_values[valid_mask]
                median_val = np.median(valid_vals)
                std_dev = np.std(valid_vals)
                # asymmetric clip bounds around the running median
                lower_bound = median_val - lower * std_dev
                upper_bound = median_val + upper * std_dev
                # shrink the valid set; a sample never comes back once rejected
                valid_mask = valid_mask & (pixel_values >= lower_bound) & (pixel_values <= upper_bound)
            # Record rejections: a pixel is rejected if not valid.
            for f in range(num_frames):
                rej_mask[f, i, j] = not valid_mask[f]
            valid_vals = pixel_values[valid_mask]
            valid_w = pixel_weights[valid_mask]
            wsum = np.sum(valid_w)
            if wsum > 0:
                # weighted mean of the surviving samples
                clipped[i, j] = np.sum(valid_vals * valid_w) / wsum
            else:
                # everything rejected or zero-weighted: fall back to the
                # median of the nonzero samples, or 0 if there are none
                nonzero = pixel_values[pixel_values != 0]
                if nonzero.size > 0:
                    clipped[i, j] = np.median(nonzero)
                else:
                    clipped[i, j] = 0.0
    return clipped, rej_mask
709
+
710
+
711
@njit(parallel=True, fastmath=True)
def windsorized_sigma_clip_weighted_4d_iter(stack, weights, lower=2.5, upper=2.5, iterations=2):
    """
    Iterative Weighted Windsorized Sigma Clipping for a 4D color stack.

    stack.shape == (F,H,W,C)
    weights.shape can be (F,) — one weight per frame — or (F,H,W,C) — per sample.
    lower/upper are the clip thresholds in units of the standard deviation,
    applied around the median of the currently-valid samples; the median/std
    are re-estimated `iterations` times as samples are rejected.

    Returns a tuple:
        (clipped, rejection_mask)
    where:
        clipped is a 3D image (H,W,C),
        rejection_mask is a boolean array of shape (F,H,W,C) with True
        indicating the sample was rejected (zero-valued samples are treated
        as invalid from the start and end up flagged as rejected).
    """
    num_frames, height, width, channels = stack.shape
    clipped = np.zeros((height, width, channels), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width, channels), dtype=np.bool_)

    # Check weights shape before entering the parallel loop.
    if weights.ndim == 1 and weights.shape[0] == num_frames:
        pass
    elif weights.ndim == 4 and weights.shape == stack.shape:
        pass
    else:
        raise ValueError("windsorized_sigma_clip_weighted_4d_iter: mismatch in shapes for 4D stack & weights")

    for i in prange(height):
        for j in range(width):
            for c in range(channels):
                pixel_values = stack[:, i, j, c]  # shape=(F,)
                if weights.ndim == 1:
                    pixel_weights = weights[:]
                else:
                    pixel_weights = weights[:, i, j, c]
                # Start with nonzero pixels as valid; zeros mark missing data.
                valid_mask = pixel_values != 0
                for _ in range(iterations):
                    if np.sum(valid_mask) == 0:
                        break
                    valid_vals = pixel_values[valid_mask]
                    median_val = np.median(valid_vals)
                    std_dev = np.std(valid_vals)
                    lower_bound = median_val - lower * std_dev
                    upper_bound = median_val + upper * std_dev
                    # Narrow the valid set; rejection is monotone across iterations.
                    valid_mask = valid_mask & (pixel_values >= lower_bound) & (pixel_values <= upper_bound)
                # Record rejections: a sample is rejected iff it is not valid.
                for f in range(num_frames):
                    rej_mask[f, i, j, c] = not valid_mask[f]
                valid_vals = pixel_values[valid_mask]
                valid_w = pixel_weights[valid_mask]
                wsum = np.sum(valid_w)
                if wsum > 0:
                    # Weighted mean of the surviving samples.
                    clipped[i, j, c] = np.sum(valid_vals * valid_w) / wsum
                else:
                    # No surviving weight: fall back to the median of nonzero samples.
                    nonzero = pixel_values[pixel_values != 0]
                    if nonzero.size > 0:
                        clipped[i, j, c] = np.median(nonzero)
                    else:
                        clipped[i, j, c] = 0.0
    return clipped, rej_mask
767
+
768
+
769
def windsorized_sigma_clip_weighted(stack, weights, lower=2.5, upper=2.5, iterations=2):
    """
    Dispatch to the mono (3D) or color (4D) iterative windsorized kernel.

    Returns (clipped, rejection_mask).
    """
    ndim = stack.ndim
    if ndim == 3:
        return windsorized_sigma_clip_weighted_3d_iter(stack, weights, lower, upper, iterations)
    if ndim == 4:
        return windsorized_sigma_clip_weighted_4d_iter(stack, weights, lower, upper, iterations)
    raise ValueError(f"windsorized_sigma_clip_weighted: stack must be 3D or 4D, got {stack.shape}")
780
+
781
+
782
+ # -------------------------------
783
+ # Kappa-Sigma Clipping (Weighted)
784
+ # -------------------------------
785
+
786
@njit(parallel=True, fastmath=True)
def kappa_sigma_clip_weighted_3d(stack, weights, kappa=2.5, iterations=3):
    """
    Kappa-Sigma Clipping for a 3D mono stack.

    stack.shape == (F,H,W); weights is (F,) or (F,H,W).
    Each pixel's samples are iteratively clipped to median +/- kappa*std
    (zero-valued samples are also discarded at every iteration).

    Returns a tuple: (clipped, rejection_mask)
    where clipped is the (H,W) weighted mean of surviving samples and
    rejection_mask is of shape (F,H,W) indicating per-frame rejections.
    """
    num_frames, height, width = stack.shape
    clipped = np.empty((height, width), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width), dtype=np.bool_)

    # Validate weights shape up front, consistent with the other weighted kernels.
    if weights.ndim == 1 and weights.shape[0] == num_frames:
        pass
    elif weights.ndim == 3 and weights.shape == stack.shape:
        pass
    else:
        raise ValueError("kappa_sigma_clip_weighted_3d: mismatch in shapes for 3D stack & weights")

    for i in prange(height):
        for j in range(width):
            pixel_values = stack[:, i, j].copy()
            if weights.ndim == 1:
                pixel_weights = weights[:]
            else:
                pixel_weights = weights[:, i, j].copy()
            # Track which original frame each surviving sample came from.
            current_idx = np.empty(num_frames, dtype=np.int64)
            for f in range(num_frames):
                current_idx[f] = f
            current_vals = pixel_values
            current_w = pixel_weights
            current_indices = current_idx
            med = 0.0
            for _ in range(iterations):
                if current_vals.size == 0:
                    break
                med = np.median(current_vals)
                std = np.std(current_vals)
                lower_bound = med - kappa * std
                upper_bound = med + kappa * std
                valid = (current_vals != 0) & (current_vals >= lower_bound) & (current_vals <= upper_bound)
                current_vals = current_vals[valid]
                current_w = current_w[valid]
                current_indices = current_indices[valid]
            # A frame is rejected iff its index did not survive the iterations.
            # (O(F) kept-flag pass instead of an O(F^2) linear search per frame.)
            kept = np.zeros(num_frames, dtype=np.bool_)
            for k in range(current_indices.size):
                kept[current_indices[k]] = True
            for f in range(num_frames):
                rej_mask[f, i, j] = not kept[f]
            if current_w.size > 0 and current_w.sum() > 0:
                clipped[i, j] = np.sum(current_vals * current_w) / current_w.sum()
            else:
                # Everything rejected or zero total weight: fall back to the
                # last computed median (0.0 if no iteration ever ran).
                clipped[i, j] = med
    return clipped, rej_mask
841
+
842
+
843
@njit(parallel=True, fastmath=True)
def kappa_sigma_clip_weighted_4d(stack, weights, kappa=2.5, iterations=3):
    """
    Kappa-Sigma Clipping for a 4D color stack.

    stack.shape == (F,H,W,C); weights is (F,) or (F,H,W,C).
    Each sample set is iteratively clipped to median +/- kappa*std
    (zero-valued samples are also discarded at every iteration).

    Returns (clipped, rejection_mask) where clipped is (H,W,C) and
    rejection_mask has shape (F,H,W,C).
    """
    num_frames, height, width, channels = stack.shape
    clipped = np.empty((height, width, channels), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width, channels), dtype=np.bool_)

    # Validate weights shape up front, consistent with the other weighted kernels.
    if weights.ndim == 1 and weights.shape[0] == num_frames:
        pass
    elif weights.ndim == 4 and weights.shape == stack.shape:
        pass
    else:
        raise ValueError("kappa_sigma_clip_weighted_4d: mismatch in shapes for 4D stack & weights")

    for i in prange(height):
        for j in range(width):
            for c in range(channels):
                pixel_values = stack[:, i, j, c].copy()
                if weights.ndim == 1:
                    pixel_weights = weights[:]
                else:
                    pixel_weights = weights[:, i, j, c].copy()
                # Track which original frame each surviving sample came from.
                current_idx = np.empty(num_frames, dtype=np.int64)
                for f in range(num_frames):
                    current_idx[f] = f
                current_vals = pixel_values
                current_w = pixel_weights
                current_indices = current_idx
                med = 0.0
                for _ in range(iterations):
                    if current_vals.size == 0:
                        break
                    med = np.median(current_vals)
                    std = np.std(current_vals)
                    lower_bound = med - kappa * std
                    upper_bound = med + kappa * std
                    valid = (current_vals != 0) & (current_vals >= lower_bound) & (current_vals <= upper_bound)
                    current_vals = current_vals[valid]
                    current_w = current_w[valid]
                    current_indices = current_indices[valid]
                # A frame is rejected iff its index did not survive the iterations.
                # (O(F) kept-flag pass instead of an O(F^2) linear search per frame.)
                kept = np.zeros(num_frames, dtype=np.bool_)
                for k in range(current_indices.size):
                    kept[current_indices[k]] = True
                for f in range(num_frames):
                    rej_mask[f, i, j, c] = not kept[f]
                if current_w.size > 0 and current_w.sum() > 0:
                    clipped[i, j, c] = np.sum(current_vals * current_w) / current_w.sum()
                else:
                    # Everything rejected or zero total weight: fall back to the
                    # last computed median (0.0 if no iteration ever ran).
                    clipped[i, j, c] = med
    return clipped, rej_mask
895
+
896
+
897
def kappa_sigma_clip_weighted(stack, weights, kappa=2.5, iterations=3):
    """
    Dispatch kappa-sigma clipping to the mono (3D) or color (4D) kernel.

    Returns (clipped, rejection_mask).
    """
    ndim = stack.ndim
    if ndim == 3:
        return kappa_sigma_clip_weighted_3d(stack, weights, kappa, iterations)
    if ndim == 4:
        return kappa_sigma_clip_weighted_4d(stack, weights, kappa, iterations)
    raise ValueError(f"kappa_sigma_clip_weighted: stack must be 3D or 4D, got {stack.shape}")
907
+
908
+
909
+ # -------------------------------
910
+ # Trimmed Mean (Weighted)
911
+ # -------------------------------
912
+
913
@njit(parallel=True, fastmath=True)
def trimmed_mean_weighted_3d(stack, weights, trim_fraction=0.1):
    """
    Trimmed Mean for a 3D mono stack.

    stack.shape == (F,H,W); weights is (F,) or (F,H,W).
    Per pixel: zero-valued samples are excluded, then the lowest and highest
    `int(trim_fraction * n)` samples (by value) are trimmed before taking a
    weighted mean of the remainder. If trimming would remove everything
    (n <= 2*trim), all valid samples are kept.

    Returns (clipped, rejection_mask) where rejection_mask (F,H,W) flags
    frames that were trimmed (or zero-valued).
    """
    num_frames, height, width = stack.shape
    clipped = np.empty((height, width), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width), dtype=np.bool_)

    for i in prange(height):
        for j in range(width):
            pix_all = stack[:, i, j]
            if weights.ndim == 1:
                w_all = weights[:]
            else:
                w_all = weights[:, i, j]
            # Exclude zeros and record original indices.
            valid = pix_all != 0
            pix = pix_all[valid]
            w = w_all[valid]
            # orig_idx[k] = original frame index of the k-th valid sample.
            orig_idx = np.empty(pix_all.shape[0], dtype=np.int64)
            count = 0
            for f in range(num_frames):
                if valid[f]:
                    orig_idx[count] = f
                    count += 1
            n = pix.size
            if n == 0:
                clipped[i, j] = 0.0
                # Mark all as rejected.
                for f in range(num_frames):
                    if not valid[f]:
                        rej_mask[f, i, j] = True
                continue
            trim = int(trim_fraction * n)
            order = np.argsort(pix)
            # Determine which indices (in the valid list) are kept.
            if n > 2 * trim:
                keep_order = order[trim:n - trim]
            else:
                keep_order = order
            # Build a mask for the valid pixels (length n) that are kept.
            keep_mask = np.zeros(n, dtype=np.bool_)
            for k in range(keep_order.size):
                keep_mask[keep_order[k]] = True
            # Map back to original frame indices.
            for idx in range(n):
                frame = orig_idx[idx]
                if not keep_mask[idx]:
                    rej_mask[frame, i, j] = True
                else:
                    rej_mask[frame, i, j] = False
            # Compute weighted average of kept values.
            sorted_pix = pix[order]
            sorted_w = w[order]
            if n > 2 * trim:
                trimmed_values = sorted_pix[trim:n - trim]
                trimmed_weights = sorted_w[trim:n - trim]
            else:
                trimmed_values = sorted_pix
                trimmed_weights = sorted_w
            wsum = trimmed_weights.sum()
            if wsum > 0:
                clipped[i, j] = np.sum(trimmed_values * trimmed_weights) / wsum
            else:
                # Zero total weight: fall back to the unweighted median of kept values.
                clipped[i, j] = np.median(trimmed_values)
    return clipped, rej_mask
982
+
983
+
984
@njit(parallel=True, fastmath=True)
def trimmed_mean_weighted_4d(stack, weights, trim_fraction=0.1):
    """
    Trimmed Mean for a 4D color stack.

    stack.shape == (F,H,W,C); weights is (F,) or (F,H,W,C).
    Per sample set: zero-valued samples are excluded, then the lowest and
    highest `int(trim_fraction * n)` samples are trimmed before the weighted
    mean. If n <= 2*trim, all valid samples are kept.

    Returns (clipped, rejection_mask) where rejection_mask has shape (F,H,W,C).
    """
    num_frames, height, width, channels = stack.shape
    clipped = np.empty((height, width, channels), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width, channels), dtype=np.bool_)

    for i in prange(height):
        for j in range(width):
            for c in range(channels):
                pix_all = stack[:, i, j, c]
                if weights.ndim == 1:
                    w_all = weights[:]
                else:
                    w_all = weights[:, i, j, c]
                # Exclude zeros and remember original frame indices.
                valid = pix_all != 0
                pix = pix_all[valid]
                w = w_all[valid]
                # orig_idx[k] = original frame index of the k-th valid sample.
                orig_idx = np.empty(pix_all.shape[0], dtype=np.int64)
                count = 0
                for f in range(num_frames):
                    if valid[f]:
                        orig_idx[count] = f
                        count += 1
                n = pix.size
                if n == 0:
                    clipped[i, j, c] = 0.0
                    for f in range(num_frames):
                        if not valid[f]:
                            rej_mask[f, i, j, c] = True
                    continue
                trim = int(trim_fraction * n)
                order = np.argsort(pix)
                # Indices (within the valid list) surviving the trim.
                if n > 2 * trim:
                    keep_order = order[trim:n - trim]
                else:
                    keep_order = order
                keep_mask = np.zeros(n, dtype=np.bool_)
                for k in range(keep_order.size):
                    keep_mask[keep_order[k]] = True
                # Map kept/trimmed decisions back to original frame indices.
                for idx in range(n):
                    frame = orig_idx[idx]
                    if not keep_mask[idx]:
                        rej_mask[frame, i, j, c] = True
                    else:
                        rej_mask[frame, i, j, c] = False
                # Weighted average of the kept (sorted, trimmed) values.
                sorted_pix = pix[order]
                sorted_w = w[order]
                if n > 2 * trim:
                    trimmed_values = sorted_pix[trim:n - trim]
                    trimmed_weights = sorted_w[trim:n - trim]
                else:
                    trimmed_values = sorted_pix
                    trimmed_weights = sorted_w
                wsum = trimmed_weights.sum()
                if wsum > 0:
                    clipped[i, j, c] = np.sum(trimmed_values * trimmed_weights) / wsum
                else:
                    # Zero total weight: fall back to the unweighted median.
                    clipped[i, j, c] = np.median(trimmed_values)
    return clipped, rej_mask
1048
+
1049
+
1050
def trimmed_mean_weighted(stack, weights, trim_fraction=0.1):
    """
    Dispatch the trimmed-mean combine to the mono (3D) or color (4D) kernel.

    Returns (clipped, rejection_mask).
    """
    ndim = stack.ndim
    if ndim == 3:
        return trimmed_mean_weighted_3d(stack, weights, trim_fraction)
    if ndim == 4:
        return trimmed_mean_weighted_4d(stack, weights, trim_fraction)
    raise ValueError(f"trimmed_mean_weighted: stack must be 3D or 4D, got {stack.shape}")
1060
+
1061
+
1062
+ # -------------------------------
1063
+ # Extreme Studentized Deviate (ESD) Clipping (Weighted)
1064
+ # -------------------------------
1065
+
1066
@njit(parallel=True, fastmath=True)
def esd_clip_weighted_3d(stack, weights, threshold=3.0):
    """
    ESD (Extreme Studentized Deviate) Clipping for a 3D mono stack.

    stack.shape == (F,H,W); weights is (F,) or (F,H,W).
    Zero-valued samples are treated as invalid; of the rest, samples whose
    z-score exceeds `threshold` are rejected before the weighted mean.

    Returns (clipped, rejection_mask) where rejection_mask has shape (F,H,W).
    """
    num_frames, height, width = stack.shape
    clipped = np.empty((height, width), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width), dtype=np.bool_)

    if weights.ndim == 1 and weights.shape[0] == num_frames:
        pass
    elif weights.ndim == 3 and weights.shape == stack.shape:
        pass
    else:
        raise ValueError("esd_clip_weighted_3d: mismatch in shapes for 3D stack & weights")

    for i in prange(height):
        for j in range(width):
            pix = stack[:, i, j]
            if weights.ndim == 1:
                w = weights[:]
            else:
                w = weights[:, i, j]
            valid = pix != 0
            values = pix[valid]
            wvals = w[valid]
            if values.size == 0:
                clipped[i, j] = 0.0
                for f in range(num_frames):
                    if not valid[f]:
                        rej_mask[f, i, j] = True
                continue
            mean_val = np.mean(values)
            std_val = np.std(values)
            if std_val == 0:
                # Degenerate distribution: keep every valid sample, but still
                # flag zero-valued (invalid) frames as rejected — they were
                # excluded from the mean. (Previously all frames were marked
                # accepted here, inconsistent with every other path.)
                clipped[i, j] = mean_val
                for f in range(num_frames):
                    rej_mask[f, i, j] = not valid[f]
                continue
            z_scores = np.abs((values - mean_val) / std_val)
            valid2 = z_scores < threshold
            # Mark rejections: for valid entries use valid2; zeros are rejected.
            idx = 0
            for f in range(num_frames):
                if valid[f]:
                    if not valid2[idx]:
                        rej_mask[f, i, j] = True
                    else:
                        rej_mask[f, i, j] = False
                    idx += 1
                else:
                    rej_mask[f, i, j] = True
            values = values[valid2]
            wvals = wvals[valid2]
            wsum = wvals.sum()
            if wsum > 0:
                clipped[i, j] = np.sum(values * wvals) / wsum
            else:
                # Zero surviving weight: fall back to the pre-clip mean.
                clipped[i, j] = mean_val
    return clipped, rej_mask
1128
+
1129
+
1130
@njit(parallel=True, fastmath=True)
def esd_clip_weighted_4d(stack, weights, threshold=3.0):
    """
    ESD (Extreme Studentized Deviate) Clipping for a 4D color stack.

    stack.shape == (F,H,W,C); weights is (F,) or (F,H,W,C).
    Zero-valued samples are treated as invalid; of the rest, samples whose
    z-score exceeds `threshold` are rejected before the weighted mean.

    Returns (clipped, rejection_mask) where rejection_mask has shape (F,H,W,C).
    """
    num_frames, height, width, channels = stack.shape
    clipped = np.empty((height, width, channels), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width, channels), dtype=np.bool_)

    if weights.ndim == 1 and weights.shape[0] == num_frames:
        pass
    elif weights.ndim == 4 and weights.shape == stack.shape:
        pass
    else:
        raise ValueError("esd_clip_weighted_4d: mismatch in shapes for 4D stack & weights")

    for i in prange(height):
        for j in range(width):
            for c in range(channels):
                pix = stack[:, i, j, c]
                if weights.ndim == 1:
                    w = weights[:]
                else:
                    w = weights[:, i, j, c]
                valid = pix != 0
                values = pix[valid]
                wvals = w[valid]
                if values.size == 0:
                    clipped[i, j, c] = 0.0
                    for f in range(num_frames):
                        if not valid[f]:
                            rej_mask[f, i, j, c] = True
                    continue
                mean_val = np.mean(values)
                std_val = np.std(values)
                if std_val == 0:
                    # Degenerate distribution: keep every valid sample, but
                    # still flag zero-valued (invalid) frames as rejected —
                    # they were excluded from the mean. (Previously all frames
                    # were marked accepted here, inconsistent with other paths.)
                    clipped[i, j, c] = mean_val
                    for f in range(num_frames):
                        rej_mask[f, i, j, c] = not valid[f]
                    continue
                z_scores = np.abs((values - mean_val) / std_val)
                valid2 = z_scores < threshold
                idx = 0
                for f in range(num_frames):
                    if valid[f]:
                        if not valid2[idx]:
                            rej_mask[f, i, j, c] = True
                        else:
                            rej_mask[f, i, j, c] = False
                        idx += 1
                    else:
                        rej_mask[f, i, j, c] = True
                values = values[valid2]
                wvals = wvals[valid2]
                wsum = wvals.sum()
                if wsum > 0:
                    clipped[i, j, c] = np.sum(values * wvals) / wsum
                else:
                    # Zero surviving weight: fall back to the pre-clip mean.
                    clipped[i, j, c] = mean_val
    return clipped, rej_mask
1192
+
1193
+
1194
def esd_clip_weighted(stack, weights, threshold=3.0):
    """
    Dispatch ESD clipping to the mono (3D) or color (4D) kernel.

    Returns (clipped, rejection_mask).
    """
    ndim = stack.ndim
    if ndim == 3:
        return esd_clip_weighted_3d(stack, weights, threshold)
    if ndim == 4:
        return esd_clip_weighted_4d(stack, weights, threshold)
    raise ValueError(f"esd_clip_weighted: stack must be 3D or 4D, got {stack.shape}")
1204
+
1205
+
1206
+ # -------------------------------
1207
+ # Biweight Location (Weighted)
1208
+ # -------------------------------
1209
+
1210
@njit(parallel=True, fastmath=True)
def biweight_location_weighted_3d(stack, weights, tuning_constant=6.0):
    """
    Biweight Location (Tukey's robust location estimate) for a 3D mono stack.

    stack.shape == (F,H,W); weights is (F,) or (F,H,W).
    Zero-valued samples are treated as invalid. Samples with |u| >= 1, where
    u = (x - median) / (tuning_constant * MAD), are excluded; the rest
    contribute via the biweight kernel (1 - u^2)^2, further scaled by the
    supplied per-sample weights.

    Returns (clipped, rejection_mask) where rejection_mask has shape (F,H,W).
    """
    num_frames, height, width = stack.shape
    clipped = np.empty((height, width), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width), dtype=np.bool_)

    if weights.ndim == 1 and weights.shape[0] == num_frames:
        pass
    elif weights.ndim == 3 and weights.shape == stack.shape:
        pass
    else:
        raise ValueError("biweight_location_weighted_3d: mismatch in shapes for 3D stack & weights")

    for i in prange(height):
        for j in range(width):
            x = stack[:, i, j]
            if weights.ndim == 1:
                w = weights[:]
            else:
                w = weights[:, i, j]
            valid = x != 0
            x_valid = x[valid]
            w_valid = w[valid]
            # Record rejections for zeros:
            for f in range(num_frames):
                if not valid[f]:
                    rej_mask[f, i, j] = True
                else:
                    rej_mask[f, i, j] = False  # initialize as accepted; may update below
            n = x_valid.size
            if n == 0:
                clipped[i, j] = 0.0
                continue
            M = np.median(x_valid)
            mad = np.median(np.abs(x_valid - M))
            if mad == 0:
                # Zero spread: every valid sample equals the median.
                clipped[i, j] = M
                continue
            u = (x_valid - M) / (tuning_constant * mad)
            mask = np.abs(u) < 1
            # Mark frames that were excluded by the biweight rejection:
            idx = 0
            for f in range(num_frames):
                if valid[f]:
                    if not mask[idx]:
                        rej_mask[f, i, j] = True
                    idx += 1
            x_masked = x_valid[mask]
            w_masked = w_valid[mask]
            # Biweight location: M + sum((x-M)*k(u)*w) / sum(k(u)*w)
            # with kernel k(u) = (1 - u^2)^2.
            numerator = ((x_masked - M) * (1 - u[mask]**2)**2 * w_masked).sum()
            denominator = ((1 - u[mask]**2)**2 * w_masked).sum()
            if denominator != 0:
                biweight = M + numerator / denominator
            else:
                biweight = M
            clipped[i, j] = biweight
    return clipped, rej_mask
1272
+
1273
+
1274
@njit(parallel=True, fastmath=True)
def biweight_location_weighted_4d(stack, weights, tuning_constant=6.0):
    """
    Biweight Location (Tukey's robust location estimate) for a 4D color stack.

    stack.shape == (F,H,W,C); weights is (F,) or (F,H,W,C).
    Zero-valued samples are treated as invalid. Samples with |u| >= 1, where
    u = (x - median) / (tuning_constant * MAD), are excluded; the rest
    contribute via the biweight kernel (1 - u^2)^2, further scaled by the
    supplied per-sample weights.

    Returns (clipped, rejection_mask) where rejection_mask has shape (F,H,W,C).
    """
    num_frames, height, width, channels = stack.shape
    clipped = np.empty((height, width, channels), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width, channels), dtype=np.bool_)

    if weights.ndim == 1 and weights.shape[0] == num_frames:
        pass
    elif weights.ndim == 4 and weights.shape == stack.shape:
        pass
    else:
        raise ValueError("biweight_location_weighted_4d: mismatch in shapes for 4D stack & weights")

    for i in prange(height):
        for j in range(width):
            for c in range(channels):
                x = stack[:, i, j, c]
                if weights.ndim == 1:
                    w = weights[:]
                else:
                    w = weights[:, i, j, c]
                valid = x != 0
                x_valid = x[valid]
                w_valid = w[valid]
                # Zeros start out rejected; valid samples start accepted.
                for f in range(num_frames):
                    if not valid[f]:
                        rej_mask[f, i, j, c] = True
                    else:
                        rej_mask[f, i, j, c] = False
                n = x_valid.size
                if n == 0:
                    clipped[i, j, c] = 0.0
                    continue
                M = np.median(x_valid)
                mad = np.median(np.abs(x_valid - M))
                if mad == 0:
                    # Zero spread: every valid sample equals the median.
                    clipped[i, j, c] = M
                    continue
                u = (x_valid - M) / (tuning_constant * mad)
                mask = np.abs(u) < 1
                # Flag samples excluded by the |u| >= 1 biweight cutoff.
                idx = 0
                for f in range(num_frames):
                    if valid[f]:
                        if not mask[idx]:
                            rej_mask[f, i, j, c] = True
                        idx += 1
                x_masked = x_valid[mask]
                w_masked = w_valid[mask]
                # Biweight location: M + sum((x-M)*k(u)*w) / sum(k(u)*w)
                # with kernel k(u) = (1 - u^2)^2.
                numerator = ((x_masked - M) * (1 - u[mask]**2)**2 * w_masked).sum()
                denominator = ((1 - u[mask]**2)**2 * w_masked).sum()
                if denominator != 0:
                    biweight = M + numerator / denominator
                else:
                    biweight = M
                clipped[i, j, c] = biweight
    return clipped, rej_mask
1335
+
1336
+
1337
def biweight_location_weighted(stack, weights, tuning_constant=6.0):
    """
    Dispatch the biweight-location combine to the mono (3D) or color (4D) kernel.

    Returns (clipped, rejection_mask).
    """
    ndim = stack.ndim
    if ndim == 3:
        return biweight_location_weighted_3d(stack, weights, tuning_constant)
    if ndim == 4:
        return biweight_location_weighted_4d(stack, weights, tuning_constant)
    raise ValueError(f"biweight_location_weighted: stack must be 3D or 4D, got {stack.shape}")
1347
+
1348
+
1349
+ # -------------------------------
1350
+ # Modified Z-Score Clipping (Weighted)
1351
+ # -------------------------------
1352
+
1353
@njit(parallel=True, fastmath=True)
def modified_zscore_clip_weighted_3d(stack, weights, threshold=3.5):
    """
    Modified Z-Score Clipping for a 3D mono stack.

    stack.shape == (F,H,W); weights is (F,) or (F,H,W).
    Zero-valued samples are treated as invalid. Samples with
    |0.6745 * (x - median) / MAD| >= threshold are rejected before the
    weighted mean.

    Returns (clipped, rejection_mask) with rejection_mask shape (F,H,W).
    """
    num_frames, height, width = stack.shape
    clipped = np.empty((height, width), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width), dtype=np.bool_)

    if weights.ndim == 1 and weights.shape[0] == num_frames:
        pass
    elif weights.ndim == 3 and weights.shape == stack.shape:
        pass
    else:
        raise ValueError("modified_zscore_clip_weighted_3d: mismatch in shapes for 3D stack & weights")

    for i in prange(height):
        for j in range(width):
            x = stack[:, i, j]
            if weights.ndim == 1:
                w = weights[:]
            else:
                w = weights[:, i, j]
            valid = x != 0
            x_valid = x[valid]
            w_valid = w[valid]
            if x_valid.size == 0:
                clipped[i, j] = 0.0
                for f in range(num_frames):
                    if not valid[f]:
                        rej_mask[f, i, j] = True
                continue
            median_val = np.median(x_valid)
            mad = np.median(np.abs(x_valid - median_val))
            if mad == 0:
                # Zero spread: keep every valid sample, but still flag
                # zero-valued (invalid) frames as rejected — they were
                # excluded from the estimate. (Previously all frames were
                # marked accepted here, inconsistent with the other paths.)
                clipped[i, j] = median_val
                for f in range(num_frames):
                    rej_mask[f, i, j] = not valid[f]
                continue
            # 0.6745 scales the MAD to be comparable to a standard deviation.
            modified_z = 0.6745 * (x_valid - median_val) / mad
            valid2 = np.abs(modified_z) < threshold
            idx = 0
            for f in range(num_frames):
                if valid[f]:
                    if not valid2[idx]:
                        rej_mask[f, i, j] = True
                    else:
                        rej_mask[f, i, j] = False
                    idx += 1
                else:
                    rej_mask[f, i, j] = True
            x_final = x_valid[valid2]
            w_final = w_valid[valid2]
            wsum = w_final.sum()
            if wsum > 0:
                clipped[i, j] = np.sum(x_final * w_final) / wsum
            else:
                # Zero surviving weight: fall back to the robust median.
                clipped[i, j] = median_val
    return clipped, rej_mask
1414
+
1415
+
1416
@njit(parallel=True, fastmath=True)
def modified_zscore_clip_weighted_4d(stack, weights, threshold=3.5):
    """
    Modified Z-Score Clipping for a 4D color stack.

    stack.shape == (F,H,W,C); weights is (F,) or (F,H,W,C).
    Zero-valued samples are treated as invalid. Samples with
    |0.6745 * (x - median) / MAD| >= threshold are rejected before the
    weighted mean.

    Returns (clipped, rejection_mask) with rejection_mask shape (F,H,W,C).
    """
    num_frames, height, width, channels = stack.shape
    clipped = np.empty((height, width, channels), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width, channels), dtype=np.bool_)

    if weights.ndim == 1 and weights.shape[0] == num_frames:
        pass
    elif weights.ndim == 4 and weights.shape == stack.shape:
        pass
    else:
        raise ValueError("modified_zscore_clip_weighted_4d: mismatch in shapes for 4D stack & weights")

    for i in prange(height):
        for j in range(width):
            for c in range(channels):
                x = stack[:, i, j, c]
                if weights.ndim == 1:
                    w = weights[:]
                else:
                    w = weights[:, i, j, c]
                valid = x != 0
                x_valid = x[valid]
                w_valid = w[valid]
                if x_valid.size == 0:
                    clipped[i, j, c] = 0.0
                    for f in range(num_frames):
                        if not valid[f]:
                            rej_mask[f, i, j, c] = True
                    continue
                median_val = np.median(x_valid)
                mad = np.median(np.abs(x_valid - median_val))
                if mad == 0:
                    # Zero spread: keep every valid sample, but still flag
                    # zero-valued (invalid) frames as rejected — they were
                    # excluded from the estimate. (Previously all frames were
                    # marked accepted here, inconsistent with the other paths.)
                    clipped[i, j, c] = median_val
                    for f in range(num_frames):
                        rej_mask[f, i, j, c] = not valid[f]
                    continue
                # 0.6745 scales the MAD to be comparable to a standard deviation.
                modified_z = 0.6745 * (x_valid - median_val) / mad
                valid2 = np.abs(modified_z) < threshold
                idx = 0
                for f in range(num_frames):
                    if valid[f]:
                        if not valid2[idx]:
                            rej_mask[f, i, j, c] = True
                        else:
                            rej_mask[f, i, j, c] = False
                        idx += 1
                    else:
                        rej_mask[f, i, j, c] = True
                x_final = x_valid[valid2]
                w_final = w_valid[valid2]
                wsum = w_final.sum()
                if wsum > 0:
                    clipped[i, j, c] = np.sum(x_final * w_final) / wsum
                else:
                    # Zero surviving weight: fall back to the robust median.
                    clipped[i, j, c] = median_val
    return clipped, rej_mask
1478
+
1479
+
1480
def modified_zscore_clip_weighted(stack, weights, threshold=3.5):
    """
    Dispatch modified z-score clipping to the mono (3D) or color (4D) kernel.

    Returns (clipped, rejection_mask).
    """
    ndim = stack.ndim
    if ndim == 3:
        return modified_zscore_clip_weighted_3d(stack, weights, threshold)
    if ndim == 4:
        return modified_zscore_clip_weighted_4d(stack, weights, threshold)
    raise ValueError(f"modified_zscore_clip_weighted: stack must be 3D or 4D, got {stack.shape}")
1490
+
1491
+
1492
+ # -------------------------------
1493
+ # Windsorized Sigma Clipping (Non-weighted)
1494
+ # -------------------------------
1495
+
1496
@njit(parallel=True, fastmath=True)
def windsorized_sigma_clip_3d(stack, lower=2.5, upper=2.5):
    """
    Windsorized Sigma Clipping for a 3D mono stack (non-weighted).

    stack.shape == (F,H,W)
    Single-pass clip: samples outside median +/- lower/upper * std are
    rejected, the rest are averaged. Unlike the weighted variants, zero
    values are NOT treated specially here.

    Returns (clipped, rejection_mask) where rejection_mask is (F,H,W).
    """
    num_frames, height, width = stack.shape
    clipped = np.zeros((height, width), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width), dtype=np.bool_)

    for i in prange(height):
        for j in range(width):
            pixel_values = stack[:, i, j]
            median_val = np.median(pixel_values)
            std_dev = np.std(pixel_values)
            lower_bound = median_val - lower * std_dev
            upper_bound = median_val + upper * std_dev
            valid = (pixel_values >= lower_bound) & (pixel_values <= upper_bound)
            # Record which frames were clipped out.
            for f in range(num_frames):
                rej_mask[f, i, j] = not valid[f]
            valid_vals = pixel_values[valid]
            if valid_vals.size > 0:
                clipped[i, j] = np.mean(valid_vals)
            else:
                # Everything rejected: fall back to the median.
                clipped[i, j] = median_val
    return clipped, rej_mask
1523
+
1524
+
1525
@njit(parallel=True, fastmath=True)
def windsorized_sigma_clip_4d(stack, lower=2.5, upper=2.5):
    """
    Windsorized Sigma Clipping for a 4D color stack (non-weighted).

    stack.shape == (F,H,W,C)
    Single-pass clip: samples outside median +/- lower/upper * std are
    rejected, the rest are averaged. Unlike the weighted variants, zero
    values are NOT treated specially here.

    Returns (clipped, rejection_mask) where rejection_mask is (F,H,W,C).
    """
    num_frames, height, width, channels = stack.shape
    clipped = np.zeros((height, width, channels), dtype=np.float32)
    rej_mask = np.zeros((num_frames, height, width, channels), dtype=np.bool_)

    for i in prange(height):
        for j in range(width):
            for c in range(channels):
                pixel_values = stack[:, i, j, c]
                median_val = np.median(pixel_values)
                std_dev = np.std(pixel_values)
                lower_bound = median_val - lower * std_dev
                upper_bound = median_val + upper * std_dev
                valid = (pixel_values >= lower_bound) & (pixel_values <= upper_bound)
                # Record which frames were clipped out.
                for f in range(num_frames):
                    rej_mask[f, i, j, c] = not valid[f]
                valid_vals = pixel_values[valid]
                if valid_vals.size > 0:
                    clipped[i, j, c] = np.mean(valid_vals)
                else:
                    # Everything rejected: fall back to the median.
                    clipped[i, j, c] = median_val
    return clipped, rej_mask
1553
+
1554
+
1555
def windsorized_sigma_clip(stack, lower=2.5, upper=2.5):
    """
    Dispatch the non-weighted windsorized clip to the 3D or 4D kernel,
    based on 'stack.ndim'.
    """
    ndim = stack.ndim
    if ndim == 3:
        return windsorized_sigma_clip_3d(stack, lower, upper)
    if ndim == 4:
        return windsorized_sigma_clip_4d(stack, lower, upper)
    raise ValueError(f"windsorized_sigma_clip: stack must be 3D or 4D, got {stack.shape}")
1566
+
1567
def max_value_stack(stack, weights=None):
    """
    Stack by taking the per-pixel maximum along the frame axis.

    `weights` is accepted but unused, keeping the signature uniform with the
    other stacking functions. Returns (clipped, rejection_mask) for
    compatibility: clipped is H×W (or H×W×C); rejection_mask matches the
    input stack's shape and is all False (nothing is ever rejected).
    """
    stacked_max = np.max(stack, axis=0)
    no_rejections = np.zeros(stack.shape, dtype=bool)
    return stacked_max, no_rejections
1577
+
1578
@njit(parallel=True)
def subtract_dark_with_pedestal_3d(frames, dark_frame, pedestal):
    """
    Dark subtraction with pedestal for a mono stack.

    frames.shape == (F,H,W)
    dark_frame.shape == (H,W)
    Adds 'pedestal' after subtracting dark_frame from each frame (the
    pedestal keeps values from going negative after subtraction).
    Returns a new float32 array of the same shape (F,H,W); the input
    is not modified.

    Raises ValueError if dark_frame is not 2D with shape (H,W).
    """
    num_frames, height, width = frames.shape
    result = np.empty_like(frames, dtype=np.float32)

    # Validate dark_frame shape
    if dark_frame.ndim != 2 or dark_frame.shape != (height, width):
        raise ValueError(
            "subtract_dark_with_pedestal_3d: for 3D frames, dark_frame must be 2D (H,W)"
        )

    # Parallelize over frames; each frame's subtraction is independent.
    for i in prange(num_frames):
        for y in range(height):
            for x in range(width):
                result[i, y, x] = frames[i, y, x] - dark_frame[y, x] + pedestal

    return result
1602
+
1603
@njit(parallel=True)
def subtract_dark_with_pedestal_4d(frames, dark_frame, pedestal):
    """
    Dark subtraction with pedestal for a color stack.

    frames.shape == (F,H,W,C)
    dark_frame.shape == (H,W,C)
    Adds 'pedestal' after subtracting dark_frame from each frame (the
    pedestal keeps values from going negative after subtraction).
    Returns a new float32 array of the same shape (F,H,W,C); the input
    is not modified.

    Raises ValueError if dark_frame is not 3D with shape (H,W,C).
    """
    num_frames, height, width, channels = frames.shape
    result = np.empty_like(frames, dtype=np.float32)

    # Validate dark_frame shape
    if dark_frame.ndim != 3 or dark_frame.shape != (height, width, channels):
        raise ValueError(
            "subtract_dark_with_pedestal_4d: for 4D frames, dark_frame must be 3D (H,W,C)"
        )

    # Parallelize over frames; each frame's subtraction is independent.
    for i in prange(num_frames):
        for y in range(height):
            for x in range(width):
                for c in range(channels):
                    result[i, y, x, c] = frames[i, y, x, c] - dark_frame[y, x, c] + pedestal

    return result
1628
+
1629
def subtract_dark_with_pedestal(frames, dark_frame, pedestal):
    """
    Dispatch dark subtraction to the mono (3D) or color (4D) Numba kernel,
    based on 'frames.ndim'.
    """
    ndim = frames.ndim
    if ndim == 3:
        return subtract_dark_with_pedestal_3d(frames, dark_frame, pedestal)
    if ndim == 4:
        return subtract_dark_with_pedestal_4d(frames, dark_frame, pedestal)
    raise ValueError(
        f"subtract_dark_with_pedestal: frames must be 3D or 4D, got {frames.shape}"
    )
1642
+
1643
+
1644
+ @njit(parallel=True, fastmath=True, cache=True)
1645
+ def _parallel_measure_frames_stack(stack): # stack: float32[N,H,W] or float32[N,H,W,C]
1646
+ n = stack.shape[0]
1647
+ means = np.empty(n, np.float32)
1648
+ for i in prange(n):
1649
+ # Option A: mean then cast
1650
+ # m = np.mean(stack[i])
1651
+ # means[i] = np.float32(m)
1652
+
1653
+ # Option B (often a hair faster): sum / size then cast
1654
+ s = np.sum(stack[i]) # no kwargs
1655
+ means[i] = np.float32(s / stack[i].size)
1656
+ return means
1657
+
1658
def parallel_measure_frames(images_py):
    """
    Compute the mean of every frame in 'images_py' via the parallel kernel.

    Accepts a list of 2-D (mono) or 3-D (H,W,C) arrays; mono frames get a
    trailing channel axis so everything stacks to a contiguous (N,H,W,C)
    float32 array before being handed to the Numba kernel.
    """
    frames = []
    for im in images_py:
        arr = np.ascontiguousarray(im, dtype=np.float32)
        if arr.ndim == 2:
            arr = arr[:, :, None]
        frames.append(arr)
    stacked = np.ascontiguousarray(np.stack(frames, axis=0))  # (N,H,W,C)
    return _parallel_measure_frames_stack(stacked)
1663
+
1664
@njit(fastmath=True)
def fast_mad(image):
    """
    Computes the Median Absolute Deviation (MAD) as a robust noise estimator.

    Returns MAD * 1.4826 so the value is comparable to a Gaussian sigma.
    """
    flat_image = image.ravel()  # Flatten the array into 1D
    median_val = np.median(flat_image)  # Compute median
    mad = np.median(np.abs(flat_image - median_val))  # Compute MAD
    return mad * 1.4826  # Scale MAD to match standard deviation (for Gaussian noise)
1671
+
1672
+
1673
+
1674
@njit(fastmath=True)
def compute_snr(image):
    """
    Computes the Signal-to-Noise Ratio (SNR) as mean / MAD-based noise.

    Returns 0 when the noise estimate is zero to avoid division by zero.
    """
    mean_signal = np.mean(image)
    noise = compute_noise(image)
    return mean_signal / noise if noise > 0 else 0
1680
+
1681
+
1682
+
1683
+
1684
@njit(fastmath=True)
def compute_noise(image):
    """Estimates noise using Median Absolute Deviation (MAD); see fast_mad()."""
    return fast_mad(image)
1688
+
1689
+ def _downsample_for_stars(img: np.ndarray, factor: int = 4) -> np.ndarray:
1690
+ """
1691
+ Very cheap spatial downsample for star counting.
1692
+ Works on mono or RGB. Returns float32 2D.
1693
+ """
1694
+ if img.ndim == 3 and img.shape[-1] == 3:
1695
+ # luma first
1696
+ r, g, b = img[..., 0], img[..., 1], img[..., 2]
1697
+ img = 0.2126*r + 0.7152*g + 0.0722*b
1698
+ img = np.asarray(img, dtype=np.float32, order="C")
1699
+ if factor <= 1:
1700
+ return img
1701
+ # stride (fast & cache friendly), not interpolation
1702
+ return img[::factor, ::factor]
1703
+
1704
+
1705
def fast_star_count_lite(img: np.ndarray,
                         sample_stride: int = 8,
                         localmax_k: int = 3,
                         thr_sigma: float = 4.0,
                         max_ecc_samples: int = 200) -> tuple[int, float]:
    """
    Super-fast star counter:
      • sample a tiny subset to estimate background mean/std
      • local-maxima on small image
      • optional rough eccentricity on a small random subset
    Returns (count, avg_ecc).

    NOTE(review): the eccentricity subset uses np.random.choice without a
    seed, so avg_ecc is not deterministic across calls when there are more
    than max_ecc_samples peaks.
    """
    # img is 2D float32, already downsampled
    H, W = img.shape
    # 1) quick background stats on a sparse grid
    samp = img[::sample_stride, ::sample_stride]
    mu = float(np.mean(samp))
    sigma = float(np.std(samp))
    thr = mu + thr_sigma * max(sigma, 1e-6)  # epsilon keeps thr finite on flat frames

    # 2) find local maxima above threshold
    # small structuring element; k must be odd
    k = localmax_k if (localmax_k % 2 == 1) else (localmax_k + 1)
    se = np.ones((k, k), np.uint8)
    # dilate the image (on float -> do it via cv2.dilate after scaling)
    # scale to 16-bit to keep numeric fidelity (cheap)
    scaled = (img * (65535.0 / max(np.max(img), 1e-6))).astype(np.uint16)
    dil = cv2.dilate(scaled, se)
    # peaks are pixels that equal the local max and exceed thr
    peaks = (scaled == dil) & (img > thr)
    count = int(np.count_nonzero(peaks))

    # 3) (optional) rough eccentricity on a tiny subset
    if count == 0:
        return 0, 0.0
    if max_ecc_samples <= 0:
        return count, 0.0

    ys, xs = np.where(peaks)
    if xs.size > max_ecc_samples:
        idx = np.random.choice(xs.size, max_ecc_samples, replace=False)
        xs, ys = xs[idx], ys[idx]

    ecc_vals = []
    # small window around each peak
    r = 2  # 5×5 window
    for x, y in zip(xs, ys):
        x0, x1 = max(0, x - r), min(W, x + r + 1)
        y0, y1 = max(0, y - r), min(H, y + r + 1)
        patch = img[y0:y1, x0:x1]
        if patch.size < 9:
            continue  # too close to the border for a meaningful fit
        # second moments for ellipse approximation
        yy, xx = np.mgrid[y0:y1, x0:x1]
        yy = yy.astype(np.float32) - y
        xx = xx.astype(np.float32) - x
        w = patch - patch.min()
        s = float(w.sum())
        if s <= 0:
            continue
        mxx = float((w * (xx*xx)).sum() / s)
        myy = float((w * (yy*yy)).sum() / s)
        # approximate major/minor from variances
        a = math.sqrt(max(mxx, myy))
        b = math.sqrt(min(mxx, myy))
        if a > 1e-6:
            e = math.sqrt(max(0.0, 1.0 - (b*b)/(a*a)))
            ecc_vals.append(e)
    avg_ecc = float(np.mean(ecc_vals)) if ecc_vals else 0.0
    return count, avg_ecc
1775
+
1776
+
1777
+
1778
def compute_star_count_fast_preview(preview_2d: np.ndarray) -> tuple[int, float]:
    """
    Wrapper used in measurement: downsample aggressively and run the lite counter.
    """
    small = _downsample_for_stars(preview_2d, factor=4)  # try 4-8 depending on your sensor
    return fast_star_count_lite(
        small,
        sample_stride=8,
        localmax_k=3,
        thr_sigma=4.0,
        max_ecc_samples=120,
    )
1784
+
1785
+
1786
+
1787
def compute_star_count(image):
    """
    Fast star detection with robust pre-stretch for linear data.

    Thin alias around fast_star_count(); kept as the stable public entry point.
    """
    return fast_star_count(image)
1790
+
1791
def fast_star_count(
    image,
    blur_size=None,          # adaptive if None
    threshold_factor=0.8,
    min_area=None,           # adaptive if None
    max_area=None,           # adaptive if None
    *,
    gamma=0.45,              # <1 brightens faint signal; 0.35–0.55 is a good range
    p_lo=0.1,                # robust low percentile for stretch
    p_hi=99.8                # robust high percentile for stretch
):
    """
    Estimate star count + avg eccentricity from a 2D float/uint8 image.
    Now does robust percentile stretch + gamma in float BEFORE 8-bit/Otsu.

    Returns:
        (star_count, avg_eccentricity) — (0, 0.0) when the frame is flat.
    """
    # 1) Ensure 2D grayscale (stay float32)
    if image.ndim == 3:
        # RGB -> luma (Rec.709 weights)
        r, g, b = image[..., 0], image[..., 1], image[..., 2]
        img = (0.2126 * r + 0.7152 * g + 0.0722 * b).astype(np.float32, copy=False)
    else:
        img = np.asarray(image, dtype=np.float32, order="C")

    H, W = img.shape[:2]
    short_side = max(1, min(H, W))

    # Adaptive params
    if blur_size is None:
        k = max(3, int(round(short_side / 80)))
        blur_size = k if (k % 2 == 1) else (k + 1)  # GaussianBlur needs odd kernel
    if min_area is None:
        min_area = 1
    if max_area is None:
        max_area = max(100, int(0.01 * H * W))  # reject blobs bigger than ~1% of the frame

    # 2) Robust percentile stretch in float (no 8-bit yet)
    # This lifts the sky background and pulls faint stars up before thresholding.
    lo = float(np.percentile(img, p_lo))
    hi = float(np.percentile(img, p_hi))
    if not (hi > lo):
        lo, hi = float(img.min()), float(img.max())
    if not (hi > lo):
        return 0, 0.0  # perfectly flat frame: nothing to detect

    norm = (img - lo) / max(1e-8, (hi - lo))
    norm = np.clip(norm, 0.0, 1.0)

    # 3) Gamma (<1 brightens low end)
    if gamma and gamma > 0:
        norm = np.power(norm, gamma, dtype=np.float32)

    # 4) Convert to 8-bit ONLY after stretch/gamma (preserves faint structure)
    image_8u = (norm * 255.0).astype(np.uint8)

    # 5) Blur + subtract (unsharp-ish)
    blurred = cv2.GaussianBlur(image_8u, (blur_size, blur_size), 0)
    sub = cv2.absdiff(image_8u, blurred)

    # 6) Otsu + threshold_factor
    otsu, _ = cv2.threshold(sub, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    thr = max(2, int(otsu * threshold_factor))
    _, mask = cv2.threshold(sub, thr, 255, cv2.THRESH_BINARY)

    # 7) Morph open *only* on larger frames (tiny frames lose stars otherwise)
    if short_side >= 600:
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((2, 2), np.uint8))

    # 8) Contours → area filter → eccentricity
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    star_count = 0
    ecc_values = []
    for c in contours:
        area = cv2.contourArea(c)
        if area < min_area or area > max_area:
            continue
        if len(c) < 5:
            continue  # fitEllipse requires at least 5 points
        (_, _), (a, b), _ = cv2.fitEllipse(c)
        if b > a: a, b = b, a
        if a > 0:
            e = math.sqrt(max(0.0, 1.0 - (b * b) / (a * a)))
        else:
            e = 0.0
        ecc_values.append(e)
        star_count += 1

    # 9) Gentle fallback if too few detections: lower threshold & smaller blur
    if star_count < 5:
        k2 = max(3, (blur_size // 2) | 1)  # halve the kernel, keep it odd
        blurred2 = cv2.GaussianBlur(image_8u, (k2, k2), 0)
        sub2 = cv2.absdiff(image_8u, blurred2)
        otsu2, _ = cv2.threshold(sub2, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
        thr2 = max(2, int(otsu2 * 0.6))  # more permissive
        _, mask2 = cv2.threshold(sub2, thr2, 255, cv2.THRESH_BINARY)
        contours2, _ = cv2.findContours(mask2, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        star_count = 0
        ecc_values = []
        for c in contours2:
            area = cv2.contourArea(c)
            if area < 1 or area > max_area:
                continue
            if len(c) < 5:
                continue
            (_, _), (a, b), _ = cv2.fitEllipse(c)
            if b > a: a, b = b, a
            e = math.sqrt(max(0.0, 1.0 - (b * b) / (a * a))) if a > 0 else 0.0
            ecc_values.append(e)
            star_count += 1

    avg_ecc = float(np.mean(ecc_values)) if star_count > 0 else 0.0
    return star_count, avg_ecc
1903
+
1904
@njit(parallel=True, fastmath=True)
def normalize_images_3d(stack, ref_median):
    """
    Normalizes each frame in a 3D mono stack (F,H,W)
    so that its median equals ref_median.

    Returns a 3D result (F,H,W).

    NOTE(review): the 1e-6 floor on the median assumes non-negative pixel
    data; a negative frame median would not be clamped — confirm upstream.
    """
    num_frames, height, width = stack.shape
    normalized_stack = np.zeros_like(stack, dtype=np.float32)

    for i in prange(num_frames):
        # shape of one frame: (H,W)
        img = stack[i]
        img_median = np.median(img)

        # Prevent division by zero
        scale_factor = ref_median / max(img_median, 1e-6)
        # Scale the entire 2D frame
        normalized_stack[i] = img * scale_factor

    return normalized_stack
1926
+
1927
@njit(parallel=True, fastmath=True)
def normalize_images_4d(stack, ref_median):
    """
    Normalizes each frame in a 4D color stack (F,H,W,C)
    so that its median equals ref_median.

    Returns a 4D result (F,H,W,C).

    The median is taken over all pixels and channels of a frame, so a single
    scale factor is applied to every channel of that frame.
    """
    num_frames, height, width, channels = stack.shape
    normalized_stack = np.zeros_like(stack, dtype=np.float32)

    for i in prange(num_frames):
        # shape of one frame: (H,W,C)
        img = stack[i]  # (H,W,C)
        # Flatten to 1D to compute median across all channels/pixels
        img_median = np.median(img.ravel())

        # Prevent division by zero
        scale_factor = ref_median / max(img_median, 1e-6)

        # Scale the entire 3D frame
        for y in range(height):
            for x in range(width):
                for c in range(channels):
                    normalized_stack[i, y, x, c] = img[y, x, c] * scale_factor

    return normalized_stack
1954
+
1955
def normalize_images(stack, ref_median):
    """
    Dispatch median normalization to the specialized Numba kernel.

    - 3-D input is treated as a mono stack (F,H,W).
    - 4-D input is treated as a color stack (F,H,W,C).

    Raises:
        ValueError: if 'stack' is neither 3-D nor 4-D.
    """
    ndim = stack.ndim
    if ndim == 3:
        return normalize_images_3d(stack, ref_median)
    if ndim == 4:
        return normalize_images_4d(stack, ref_median)
    raise ValueError(f"normalize_images: stack must be 3D or 4D, got shape {stack.shape}")
1969
+
1970
@njit(parallel=True, fastmath=True)
def _bilinear_interpolate_numba(out):
    """
    Fill missing CFA samples in place.

    A pixel with value 0 is treated as "missing" and replaced by the mean of
    the non-zero samples in its 3x3 neighborhood (despite the name, this is a
    neighborhood average, not a true bilinear kernel). Mutates 'out' and also
    returns it for convenience.
    """
    H, W, C = out.shape
    for c in range(C):
        for y in prange(H):
            for x in range(W):
                if out[y, x, c] == 0:
                    sumv = 0.0
                    cnt = 0
                    # 3x3 neighborhood average of non-zero samples (simple & fast)
                    for dy in (-1, 0, 1):
                        yy = y + dy
                        if yy < 0 or yy >= H:
                            continue
                        for dx in (-1, 0, 1):
                            xx = x + dx
                            if xx < 0 or xx >= W:
                                continue
                            v = out[yy, xx, c]
                            if v != 0:
                                sumv += v
                                cnt += 1
                    if cnt > 0:
                        out[y, x, c] = sumv / cnt
    return out
1995
+
1996
+
1997
@njit(parallel=True, fastmath=True)
def _edge_aware_interpolate_numba(out):
    """
    For each pixel in out (shape: (H,W,3)) where out[y,x,c] == 0,
    use a simple edge-aware approach:
      1) Compute horizontal gradient = abs( left - right )
      2) Compute vertical gradient   = abs( top - bottom )
      3) Choose the direction with the smaller gradient => average neighbors
      4) If neighbors are missing or zero, fallback to a small 3x3 average

    This is simpler than AHD but usually better than naive bilinear
    for high-contrast features like star cores.

    Mutates 'out' in place and also returns it.
    """
    H, W, C = out.shape

    for c in range(C):
        for y in prange(H):
            for x in range(W):
                if out[y, x, c] == 0:
                    # Gather immediate neighbors; 0 is treated as "missing"
                    left = 0.0
                    right = 0.0
                    top = 0.0
                    bottom = 0.0
                    have_left = False
                    have_right = False
                    have_top = False
                    have_bottom = False

                    # Left
                    if x - 1 >= 0:
                        val = out[y, x - 1, c]
                        if val != 0:
                            left = val
                            have_left = True

                    # Right
                    if x + 1 < W:
                        val = out[y, x + 1, c]
                        if val != 0:
                            right = val
                            have_right = True

                    # Top
                    if y - 1 >= 0:
                        val = out[y - 1, x, c]
                        if val != 0:
                            top = val
                            have_top = True

                    # Bottom
                    if y + 1 < H:
                        val = out[y + 1, x, c]
                        if val != 0:
                            bottom = val
                            have_bottom = True

                    # Compute gradients
                    # If we don't have valid neighbors for that direction,
                    # set the gradient to a large number => won't be chosen
                    gh = 1e6
                    gv = 1e6

                    if have_left and have_right:
                        gh = abs(left - right)
                    if have_top and have_bottom:
                        gv = abs(top - bottom)

                    # Decide which direction to interpolate
                    if gh < gv and have_left and have_right:
                        # Horizontal interpolation
                        out[y, x, c] = 0.5 * (left + right)
                    elif gv <= gh and have_top and have_bottom:
                        # Vertical interpolation
                        out[y, x, c] = 0.5 * (top + bottom)
                    else:
                        # Fallback: average 3×3 region of non-zero samples
                        sumv = 0.0
                        count = 0
                        for dy in range(-1, 2):
                            for dx in range(-1, 2):
                                yy = y + dy
                                xx = x + dx
                                if 0 <= yy < H and 0 <= xx < W:
                                    val = out[yy, xx, c]
                                    if val != 0:
                                        sumv += val
                                        count += 1
                        if count > 0:
                            out[y, x, c] = sumv / count

    return out
2089
+ # === Separate Full-Resolution Demosaicing Kernels ===
2090
+ # These njit functions assume the raw image is arranged in a Bayer pattern
2091
+ # and that we want a full (H,W,3) output.
2092
+
2093
@njit(parallel=True, fastmath=True)
def debayer_RGGB_fullres_fast(image, interpolate=True):
    """
    Scatter RGGB CFA samples of a 2D 'image' into an (H,W,3) array.

    Even rows alternate R,G; odd rows alternate G,B. Unsampled positions stay
    0 and are filled by edge-aware interpolation when 'interpolate' is True.
    """
    H, W = image.shape
    out = np.zeros((H, W, 3), dtype=image.dtype)
    for y in prange(H):
        for x in range(W):
            if (y & 1) == 0:
                if (x & 1) == 0: out[y, x, 0] = image[y, x]  # R
                else:            out[y, x, 1] = image[y, x]  # G
            else:
                if (x & 1) == 0: out[y, x, 1] = image[y, x]  # G
                else:            out[y, x, 2] = image[y, x]  # B
    if interpolate:
        _edge_aware_interpolate_numba(out)
    return out
2108
+
2109
@njit(parallel=True, fastmath=True)
def debayer_BGGR_fullres_fast(image, interpolate=True):
    """
    Scatter BGGR CFA samples of a 2D 'image' into an (H,W,3) array.

    Even rows alternate B,G; odd rows alternate G,R. Unsampled positions stay
    0 and are filled by edge-aware interpolation when 'interpolate' is True.
    """
    H, W = image.shape
    out = np.zeros((H, W, 3), dtype=image.dtype)
    for y in prange(H):
        for x in range(W):
            if (y & 1) == 0:
                if (x & 1) == 0: out[y, x, 2] = image[y, x]  # B
                else:            out[y, x, 1] = image[y, x]  # G
            else:
                if (x & 1) == 0: out[y, x, 1] = image[y, x]  # G
                else:            out[y, x, 0] = image[y, x]  # R
    if interpolate:
        _edge_aware_interpolate_numba(out)
    return out
2124
+
2125
@njit(parallel=True, fastmath=True)
def debayer_GRBG_fullres_fast(image, interpolate=True):
    """
    Scatter GRBG CFA samples of a 2D 'image' into an (H,W,3) array.

    Even rows alternate G,R; odd rows alternate B,G. Unsampled positions stay
    0 and are filled by edge-aware interpolation when 'interpolate' is True.
    """
    H, W = image.shape
    out = np.zeros((H, W, 3), dtype=image.dtype)
    for y in prange(H):
        for x in range(W):
            if (y & 1) == 0:
                if (x & 1) == 0: out[y, x, 1] = image[y, x]  # G
                else:            out[y, x, 0] = image[y, x]  # R
            else:
                if (x & 1) == 0: out[y, x, 2] = image[y, x]  # B
                else:            out[y, x, 1] = image[y, x]  # G
    if interpolate:
        _edge_aware_interpolate_numba(out)
    return out
2140
+
2141
@njit(parallel=True, fastmath=True)
def debayer_GBRG_fullres_fast(image, interpolate=True):
    """
    Scatter GBRG CFA samples of a 2D 'image' into an (H,W,3) array.

    Even rows alternate G,B; odd rows alternate R,G. Unsampled positions stay
    0 and are filled by edge-aware interpolation when 'interpolate' is True.
    """
    H, W = image.shape
    out = np.zeros((H, W, 3), dtype=image.dtype)
    for y in prange(H):
        for x in range(W):
            if (y & 1) == 0:
                if (x & 1) == 0: out[y, x, 1] = image[y, x]  # G
                else:            out[y, x, 2] = image[y, x]  # B
            else:
                if (x & 1) == 0: out[y, x, 0] = image[y, x]  # R
                else:            out[y, x, 1] = image[y, x]  # G
    if interpolate:
        _edge_aware_interpolate_numba(out)
    return out
2156
+
2157
+ # === Python-Level Dispatch Function ===
2158
+ # Since Numba cannot easily compare strings in nopython mode,
2159
+ # we do the if/elif check here in Python and then call the appropriate njit function.
2160
+
2161
+ def debayer_fits_fast(image_data, bayer_pattern, cfa_drizzle=False, method="edge"):
2162
+ bp = (bayer_pattern or "").upper()
2163
+ interpolate = not cfa_drizzle
2164
+
2165
+ # 1) lay down the known samples per CFA
2166
+ if bp == 'RGGB':
2167
+ out = debayer_RGGB_fullres_fast(image_data, interpolate=False)
2168
+ elif bp == 'BGGR':
2169
+ out = debayer_BGGR_fullres_fast(image_data, interpolate=False)
2170
+ elif bp == 'GRBG':
2171
+ out = debayer_GRBG_fullres_fast(image_data, interpolate=False)
2172
+ elif bp == 'GBRG':
2173
+ out = debayer_GBRG_fullres_fast(image_data, interpolate=False)
2174
+ else:
2175
+ raise ValueError(f"Unsupported Bayer pattern: {bayer_pattern}")
2176
+
2177
+ # 2) perform interpolation unless doing CFA-drizzle
2178
+ if interpolate:
2179
+ m = (method or "edge").lower()
2180
+ if m == "edge":
2181
+ _edge_aware_interpolate_numba(out)
2182
+ elif m == "bilinear":
2183
+ _bilinear_interpolate_numba(out)
2184
+ else:
2185
+ # fallback to edge-aware if unknown
2186
+ _edge_aware_interpolate_numba(out)
2187
+
2188
+ return out
2189
+
2190
def debayer_raw_fast(raw_image_data, bayer_pattern="RGGB", cfa_drizzle=False, method="edge"):
    """Alias for debayer_fits_fast() used for raw (non-FITS) sensor frames."""
    return debayer_fits_fast(
        raw_image_data,
        bayer_pattern,
        cfa_drizzle=cfa_drizzle,
        method=method,
    )
2192
+
2193
@njit(parallel=True, fastmath=True)
def applyPixelMath_numba(image_array, amount):
    """
    Apply the curve y = (3^a * x) / ((3^a - 1) * x + 1) per channel.

    Expects an (H,W,C) array; output is float32 clipped to [0,1].
    With amount == 0 the curve is the identity (factor 1, denominator 1).
    """
    factor = 3 ** amount
    denom_factor = 3 ** amount - 1
    height, width, channels = image_array.shape
    output = np.empty_like(image_array, dtype=np.float32)

    for y in prange(height):
        for x in prange(width):
            for c in prange(channels):
                val = (factor * image_array[y, x, c]) / (denom_factor * image_array[y, x, c] + 1)
                output[y, x, c] = min(max(val, 0.0), 1.0)  # Equivalent to np.clip()

    return output
2207
+
2208
@njit(parallel=True, fastmath=True)
def adjust_saturation_numba(image_array, saturation_factor):
    """
    Scale the HSV saturation of an (H,W,3) RGB image by 'saturation_factor'.

    Converts each pixel RGB -> HSV by hand, multiplies S (clipped to [0,1]),
    and converts back. Assumes RGB values in [0..1] — TODO confirm at callers.
    Returns a float32 array of the same shape.
    """
    height, width, channels = image_array.shape
    output = np.empty_like(image_array, dtype=np.float32)

    for y in prange(int(height)):  # Ensure y is an integer
        for x in prange(int(width)):  # Ensure x is an integer
            r, g, b = image_array[int(y), int(x)]  # Force integer indexing

            # Convert RGB to HSV manually
            max_val = max(r, g, b)
            min_val = min(r, g, b)
            delta = max_val - min_val

            # Compute Hue (H) in degrees [0, 360)
            if delta == 0:
                h = 0
            elif max_val == r:
                h = (60 * ((g - b) / delta) + 360) % 360
            elif max_val == g:
                h = (60 * ((b - r) / delta) + 120) % 360
            else:
                h = (60 * ((r - g) / delta) + 240) % 360

            # Compute Saturation (S)
            s = (delta / max_val) if max_val != 0 else 0
            s *= saturation_factor  # Apply saturation adjustment
            s = min(max(s, 0.0), 1.0)  # Clip saturation

            # Convert back to RGB (value V = max_val is preserved)
            if s == 0:
                r, g, b = max_val, max_val, max_val
            else:
                c = s * max_val
                x_val = c * (1 - abs((h / 60) % 2 - 1))
                m = max_val - c

                if 0 <= h < 60:
                    r, g, b = c, x_val, 0
                elif 60 <= h < 120:
                    r, g, b = x_val, c, 0
                elif 120 <= h < 180:
                    r, g, b = 0, c, x_val
                elif 180 <= h < 240:
                    r, g, b = 0, x_val, c
                elif 240 <= h < 300:
                    r, g, b = x_val, 0, c
                else:
                    r, g, b = c, 0, x_val

                r, g, b = r + m, g + m, b + m  # Add m to shift brightness

            # Explicitly cast indices to integers (keeps Numba typing happy)
            output[int(y), int(x), 0] = r
            output[int(y), int(x), 1] = g
            output[int(y), int(x), 2] = b

    return output
2266
+
2267
+
2268
+
2269
+
2270
@njit(parallel=True, fastmath=True)
def applySCNR_numba(image_array):
    """
    Average-neutral SCNR: clamp green to the mean of red and blue.

    Expects an (H,W,3) RGB array; returns a float32 copy with
    G' = min(G, (R + B) / 2). Red and blue pass through unchanged.
    """
    height, width, _ = image_array.shape
    output = np.empty_like(image_array, dtype=np.float32)

    for y in prange(int(height)):
        for x in prange(int(width)):
            r, g, b = image_array[y, x]
            g = min(g, (r + b) / 2)  # Reduce green to the average of red & blue

            # Assign channels individually (tuple assignment is not supported here)
            output[int(y), int(x), 0] = r
            output[int(y), int(x), 1] = g
            output[int(y), int(x), 2] = b


    return output
2287
+
2288
# D65 reference white point (CIE XYZ, normalized so Y = 1).
_Xn, _Yn, _Zn = 0.95047, 1.00000, 1.08883

# Matrix for RGB -> XYZ (sRGB primaries, D65 white point).
_M_rgb2xyz = np.array([
    [0.4124564, 0.3575761, 0.1804375],
    [0.2126729, 0.7151522, 0.0721750],
    [0.0193339, 0.1191920, 0.9503041]
], dtype=np.float32)

# Matrix for XYZ -> RGB (inverse of the above; sRGB, D65).
_M_xyz2rgb = np.array([
    [ 3.2404542, -1.5371385, -0.4985314],
    [-0.9692660,  1.8760108,  0.0415560],
    [ 0.0556434, -0.2040259,  1.0572252]
], dtype=np.float32)
2304
+
2305
+
2306
+
2307
@njit(parallel=True, fastmath=True)
def apply_lut_gray(image_in, lut):
    """
    Numba-accelerated application of 'lut' to a single-channel image_in in [0..1].
    'lut' is a 1D array of shape (size,) also in [0..1].

    Values are mapped by nearest-index lookup (round to nearest LUT entry,
    clamped to the table bounds). Returns a new array; input is untouched.
    """
    out = np.empty_like(image_in)
    height, width = image_in.shape
    size_lut = len(lut) - 1

    for y in prange(height):
        for x in range(width):
            v = image_in[y, x]
            idx = int(v * size_lut + 0.5)  # round-to-nearest LUT index
            if idx < 0: idx = 0
            elif idx > size_lut: idx = size_lut
            out[y, x] = lut[idx]

    return out
2326
+
2327
@njit(parallel=True, fastmath=True)
def apply_lut_color(image_in, lut):
    """
    Numba-accelerated application of 'lut' to a 3-channel image_in in [0..1].
    'lut' is a 1D array of shape (size,) also in [0..1].

    The same LUT is applied independently to every channel via nearest-index
    lookup. Returns a new array; input is untouched.
    """
    out = np.empty_like(image_in)
    height, width, channels = image_in.shape
    size_lut = len(lut) - 1

    for y in prange(height):
        for x in range(width):
            for c in range(channels):
                v = image_in[y, x, c]
                idx = int(v * size_lut + 0.5)  # round-to-nearest LUT index
                if idx < 0: idx = 0
                elif idx > size_lut: idx = size_lut
                out[y, x, c] = lut[idx]

    return out
2347
+
2348
@njit(parallel=True, fastmath=True)
def apply_lut_mono_inplace(array2d, lut):
    """
    In-place LUT application on a single-channel 2D array in [0..1].
    'lut' has shape (size,) also in [0..1].

    Same nearest-index lookup as apply_lut_gray(), but overwrites 'array2d'
    and returns nothing.
    """
    H, W = array2d.shape
    size_lut = len(lut) - 1
    for y in prange(H):
        for x in prange(W):
            v = array2d[y, x]
            idx = int(v * size_lut + 0.5)  # round-to-nearest LUT index
            if idx < 0:
                idx = 0
            elif idx > size_lut:
                idx = size_lut
            array2d[y, x] = lut[idx]
2365
+
2366
@njit(parallel=True, fastmath=True)
def apply_lut_color_inplace(array3d, lut):
    """
    In-place LUT application on a 3-channel array in [0..1].
    'lut' has shape (size,) also in [0..1].

    Same nearest-index lookup as apply_lut_color(), but overwrites 'array3d'
    and returns nothing.
    """
    H, W, C = array3d.shape
    size_lut = len(lut) - 1
    for y in prange(H):
        for x in prange(W):
            for c in range(C):
                v = array3d[y, x, c]
                idx = int(v * size_lut + 0.5)  # round-to-nearest LUT index
                if idx < 0:
                    idx = 0
                elif idx > size_lut:
                    idx = size_lut
                array3d[y, x, c] = lut[idx]
2384
+
2385
@njit(parallel=True, fastmath=True)
def rgb_to_xyz_numba(rgb):
    """
    Convert an image from sRGB to XYZ (D65).
    rgb: float32 array in [0..1], shape (H,W,3)
    returns xyz in [0..maybe >1], shape (H,W,3)

    Note: this applies the _M_rgb2xyz matrix directly; there is no
    sRGB gamma/linearization step here.
    """
    H, W, _ = rgb.shape
    out = np.empty((H, W, 3), dtype=np.float32)
    for y in prange(H):
        for x in prange(W):
            r = rgb[y, x, 0]
            g = rgb[y, x, 1]
            b = rgb[y, x, 2]
            # Multiply by M_rgb2xyz
            X = _M_rgb2xyz[0,0]*r + _M_rgb2xyz[0,1]*g + _M_rgb2xyz[0,2]*b
            Y = _M_rgb2xyz[1,0]*r + _M_rgb2xyz[1,1]*g + _M_rgb2xyz[1,2]*b
            Z = _M_rgb2xyz[2,0]*r + _M_rgb2xyz[2,1]*g + _M_rgb2xyz[2,2]*b
            out[y, x, 0] = X
            out[y, x, 1] = Y
            out[y, x, 2] = Z
    return out
2407
+
2408
@njit(parallel=True, fastmath=True)
def xyz_to_rgb_numba(xyz):
    """
    Convert an image from XYZ (D65) to sRGB.
    xyz: float32 array, shape (H,W,3)
    returns rgb in [0..1], shape (H,W,3)

    Note: applies the _M_xyz2rgb matrix directly and clips to [0,1];
    no gamma/companding step is performed here.
    """
    H, W, _ = xyz.shape
    out = np.empty((H, W, 3), dtype=np.float32)
    for y in prange(H):
        for x in prange(W):
            X = xyz[y, x, 0]
            Y = xyz[y, x, 1]
            Z = xyz[y, x, 2]
            # Multiply by M_xyz2rgb
            r = _M_xyz2rgb[0,0]*X + _M_xyz2rgb[0,1]*Y + _M_xyz2rgb[0,2]*Z
            g = _M_xyz2rgb[1,0]*X + _M_xyz2rgb[1,1]*Y + _M_xyz2rgb[1,2]*Z
            b = _M_xyz2rgb[2,0]*X + _M_xyz2rgb[2,1]*Y + _M_xyz2rgb[2,2]*Z
            # Clip to [0..1]
            if r < 0: r = 0
            elif r > 1: r = 1
            if g < 0: g = 0
            elif g > 1: g = 1
            if b < 0: b = 0
            elif b > 1: b = 1
            out[y, x, 0] = r
            out[y, x, 1] = g
            out[y, x, 2] = b
    return out
2437
+
2438
@njit
def f_lab_numba(t):
    """
    Elementwise CIE Lab companding function f(t).

    f(t) = t^(1/3) for t above (6/29)^3, otherwise the linear segment
    t/(3*delta^2) + 4/29. Works on any array shape via flat iteration;
    returns a float32 array of the same shape.
    """
    delta = 6/29
    out = np.empty_like(t, dtype=np.float32)
    for i in range(t.size):
        val = t.flat[i]
        if val > delta**3:
            out.flat[i] = val**(1/3)
        else:
            out.flat[i] = val/(3*delta*delta) + (4/29)
    return out
2449
+
2450
@njit(parallel=True, fastmath=True)
def xyz_to_lab_numba(xyz):
    """
    xyz => shape(H,W,3), in D65.
    returns lab in shape(H,W,3): L in [0..100], a,b in ~[-128..127].

    Inlines the Lab f(t) companding per channel (cube root above the
    (6/29)^3 knee, linear segment below) rather than calling f_lab_numba.
    """
    H, W, _ = xyz.shape
    out = np.empty((H,W,3), dtype=np.float32)
    for y in prange(H):
        for x in prange(W):
            # Normalize by the D65 white point
            X = xyz[y, x, 0] / _Xn
            Y = xyz[y, x, 1] / _Yn
            Z = xyz[y, x, 2] / _Zn
            fx = (X)**(1/3) if X > (6/29)**3 else X/(3*(6/29)**2) + 4/29
            fy = (Y)**(1/3) if Y > (6/29)**3 else Y/(3*(6/29)**2) + 4/29
            fz = (Z)**(1/3) if Z > (6/29)**3 else Z/(3*(6/29)**2) + 4/29
            L = 116*fy - 16
            a = 500*(fx - fy)
            b = 200*(fy - fz)
            out[y, x, 0] = L
            out[y, x, 1] = a
            out[y, x, 2] = b
    return out
2473
+
2474
@njit(parallel=True, fastmath=True)
def lab_to_xyz_numba(lab):
    """
    lab => shape(H,W,3): L in [0..100], a,b in ~[-128..127].
    returns xyz shape(H,W,3).

    Inverse of xyz_to_lab_numba: reconstructs f-values from L/a/b, applies
    the inverse companding (cube above the 6/29 knee, linear below), then
    rescales by the D65 white point.
    """
    H, W, _ = lab.shape
    out = np.empty((H,W,3), dtype=np.float32)
    delta = 6/29
    for y in prange(H):
        for x in prange(W):
            L = lab[y, x, 0]
            a = lab[y, x, 1]
            b = lab[y, x, 2]
            fy = (L+16)/116
            fx = fy + a/500
            fz = fy - b/200

            if fx > delta:
                xr = fx**3
            else:
                xr = 3*delta*delta*(fx - 4/29)
            if fy > delta:
                yr = fy**3
            else:
                yr = 3*delta*delta*(fy - 4/29)
            if fz > delta:
                zr = fz**3
            else:
                zr = 3*delta*delta*(fz - 4/29)

            # Scale back by the D65 white point
            X = _Xn * xr
            Y = _Yn * yr
            Z = _Zn * zr
            out[y, x, 0] = X
            out[y, x, 1] = Y
            out[y, x, 2] = Z
    return out
2512
+
2513
@njit(parallel=True, fastmath=True)
def rgb_to_hsv_numba(rgb):
    """
    Per-pixel RGB -> HSV conversion for an (H,W,3) array.

    Output channels: hue in degrees [0..360), saturation in [0..1],
    value = max(r,g,b). Assumes RGB in [0..1] — TODO confirm at callers.
    """
    H, W, _ = rgb.shape
    out = np.empty((H,W,3), dtype=np.float32)
    for y in prange(H):
        for x in prange(W):
            r = rgb[y,x,0]
            g = rgb[y,x,1]
            b = rgb[y,x,2]
            cmax = max(r,g,b)
            cmin = min(r,g,b)
            delta = cmax - cmin
            # Hue (0 for achromatic pixels)
            h = 0.0
            if delta != 0.0:
                if cmax == r:
                    h = 60*(((g-b)/delta) % 6)
                elif cmax == g:
                    h = 60*(((b-r)/delta) + 2)
                else:
                    h = 60*(((r-g)/delta) + 4)
            # Saturation
            s = 0.0
            if cmax > 0.0:
                s = delta / cmax
            v = cmax
            out[y,x,0] = h
            out[y,x,1] = s
            out[y,x,2] = v
    return out
2543
+
2544
@njit(parallel=True, fastmath=True)
def hsv_to_rgb_numba(hsv):
    """
    Per-pixel HSV -> RGB conversion for an (H,W,3) array.

    Expects hue in degrees, saturation and value in [0..1] (the layout
    produced by rgb_to_hsv_numba). Returns float32 RGB of the same shape.
    """
    H, W, _ = hsv.shape
    out = np.empty((H,W,3), dtype=np.float32)
    for y in prange(H):
        for x in prange(W):
            h = hsv[y,x,0]
            s = hsv[y,x,1]
            v = hsv[y,x,2]
            c = v*s                       # chroma
            hh = (h/60.0) % 6             # hue sector in [0,6)
            x_ = c*(1 - abs(hh % 2 - 1))  # intermediate component
            m = v - c                     # brightness offset
            r = 0.0
            g = 0.0
            b = 0.0
            if 0 <= hh < 1:
                r,g,b = c,x_,0
            elif 1 <= hh < 2:
                r,g,b = x_,c,0
            elif 2 <= hh < 3:
                r,g,b = 0,c,x_
            elif 3 <= hh < 4:
                r,g,b = 0,x_,c
            elif 4 <= hh < 5:
                r,g,b = x_,0,c
            else:
                r,g,b = c,0,x_
            out[y,x,0] = (r + m)
            out[y,x,1] = (g + m)
            out[y,x,2] = (b + m)
    return out
2576
+
2577
@njit(parallel=True, fastmath=True)
def _cosmetic_correction_core(src, dst, H, W, C,
                              hot_sigma, cold_sigma,
                              star_mean_ratio,   # e.g. 0.18..0.30
                              star_max_ratio,    # e.g. 0.45..0.65
                              sat_threshold,     # absolute cutoff in src units
                              cold_cluster_max   # max # of neighbors below low before we skip
                              ):
    """
    Read from src, write to dst. Center is EXCLUDED from stats.
    Star guard: if ring mean or ring max are a decent fraction of center, skip (likely a PSF).
    Cold guard: if many neighbors are also low, skip (structure/shadow, not a dead pixel).

    Statistics are a median/MAD over the 8-neighbor ring only; an outlier
    center is replaced by the ring median. Loops run over y in [1, H-1) and
    x in [1, W-1), so the 1-pixel border of dst is never written here —
    presumably the caller pre-fills dst; confirm at the call site.
    """
    local_vals = np.empty(8, dtype=np.float32)

    for y in prange(1, H-1):
        for x in range(1, W-1):
            for c in range(C if src.ndim == 3 else 1):
                # gather 8-neighbor ring (no center)
                k = 0
                ring_sum = 0.0
                ring_max = -1e30
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        if dy == 0 and dx == 0:
                            continue
                        if src.ndim == 3:
                            v = src[y+dy, x+dx, c]
                        else:
                            v = src[y+dy, x+dx]
                        local_vals[k] = v
                        ring_sum += v
                        if v > ring_max:
                            ring_max = v
                        k += 1

                # median and MAD from ring only
                M = np.median(local_vals)
                abs_devs = np.empty(8, dtype=np.float32)
                for i in range(8):
                    abs_devs[i] = abs(local_vals[i] - M)
                MAD = np.median(abs_devs)
                sigma = 1.4826 * MAD + 1e-8  # epsilon guard

                # center
                T = src[y, x, c] if src.ndim == 3 else src[y, x]

                # saturation guard: never touch pixels at/above the cutoff
                if T >= sat_threshold:
                    if src.ndim == 3: dst[y, x, c] = T
                    else:             dst[y, x] = T
                    continue

                high = M + hot_sigma * sigma
                low  = M - cold_sigma * sigma

                replace = False

                if T > high:
                    # Star guard for HOT: neighbors should not form a footprint
                    ring_mean = ring_sum / 8.0
                    if (ring_mean / (T + 1e-8) < star_mean_ratio) and (ring_max / (T + 1e-8) < star_max_ratio):
                        replace = True
                elif T < low:
                    # Cold pixel: only if it's isolated (few neighbors also low)
                    count_below = 0
                    for i in range(8):
                        if local_vals[i] < low:
                            count_below += 1
                    if count_below <= cold_cluster_max:
                        replace = True

                if replace:
                    if src.ndim == 3: dst[y, x, c] = M
                    else:             dst[y, x] = M
                else:
                    if src.ndim == 3: dst[y, x, c] = T
                    else:             dst[y, x] = T
2656
+
2657
def bulk_cosmetic_correction_numba(image,
                                   hot_sigma=5.0,
                                   cold_sigma=5.0,
                                   star_mean_ratio=0.22,
                                   star_max_ratio=0.55,
                                   sat_quantile=0.9995):
    """
    Star-safe cosmetic correction for mono (H, W) or color (H, W, C) data.

    Two-pass scheme: statistics are read from the untouched original while
    results are written into a separate output array.

    Parameters
    ----------
    hot_sigma, cold_sigma : MAD-sigma thresholds for hot/cold detection.
    star_mean_ratio : neighbor-mean / center ratio above which a bright
        outlier is treated as a star PSF and left alone.
    star_max_ratio : neighbor-max / center ratio with the same role.
    sat_quantile : per-channel quantile above which pixels are protected
        (bright star cores).

    Returns a float32 array with the same dimensionality as the input.
    """
    data = image.astype(np.float32, copy=False)
    mono = data.ndim == 2
    stack = data[:, :, None] if mono else data

    H, W, C = stack.shape
    result = stack.copy()

    for ch in range(C):
        # Per-channel saturation guard; np.quantile is evaluated here in
        # plain Python since Numba's quantile support is limited.
        sat = float(np.quantile(stack[:, :, ch], sat_quantile))
        _cosmetic_correction_core(stack[:, :, ch], result[:, :, ch],
                                  H, W, 1,
                                  float(hot_sigma), float(cold_sigma),
                                  float(star_mean_ratio), float(star_max_ratio),
                                  sat,
                                  1)  # cold_cluster_max: tolerate one low neighbor

    return result[:, :, 0] if mono else result
2699
+
2700
+
2701
def bulk_cosmetic_correction_bayer(image,
                                   hot_sigma=5.5,
                                   cold_sigma=5.0,
                                   star_mean_ratio=0.22,
                                   star_max_ratio=0.55,
                                   sat_quantile=0.9995,
                                   pattern="RGGB"):
    """
    Bayer-safe cosmetic correction for a raw CFA frame.

    Each same-color sub-plane (2-pixel stride) is corrected independently via
    bulk_cosmetic_correction_numba, then written back in place, so neighbors
    of a different color never contaminate the statistics. Defaults assume
    normalized or 16/32-bit float data. Unknown patterns fall back to RGGB.
    """
    H, W = image.shape
    out = image.astype(np.float32).copy()

    cfa = pattern.upper()
    if cfa not in ("RGGB", "BGGR", "GRBG", "GBRG"):
        cfa = "RGGB"

    # (row, col) offsets of the R, G1, G2 and B photosites for each layout.
    site_offsets = {
        "RGGB": ((0, 0), (0, 1), (1, 0), (1, 1)),
        "BGGR": ((1, 1), (1, 0), (0, 1), (0, 0)),
        "GRBG": ((0, 1), (0, 0), (1, 1), (1, 0)),
        "GBRG": ((1, 0), (0, 0), (1, 1), (0, 1)),
    }

    # The four sub-planes are disjoint pixel sets, so processing order
    # does not affect the result.
    for oy, ox in site_offsets[cfa]:
        plane = out[oy:H:2, ox:W:2]
        out[oy:H:2, ox:W:2] = bulk_cosmetic_correction_numba(
            plane,
            hot_sigma=hot_sigma,
            cold_sigma=cold_sigma,
            star_mean_ratio=star_mean_ratio,
            star_max_ratio=star_max_ratio,
            sat_quantile=sat_quantile
        )

    return out
2767
+
2768
def evaluate_polynomial(H: int, W: int, coeffs: np.ndarray, degree: int) -> np.ndarray:
    """
    Evaluate a fitted 2-D polynomial surface over the full H x W pixel grid.

    coeffs must match the column ordering produced by build_poly_terms for
    the given degree. Returns an (H, W) array of surface values.
    """
    cols = np.arange(W, dtype=np.float32)
    rows = np.arange(H, dtype=np.float32)
    xx, yy = np.meshgrid(cols, rows, indexing="xy")
    design = build_poly_terms(xx.ravel(), yy.ravel(), degree)
    return (design @ coeffs).reshape(H, W)
2775
+
2776
+
2777
+
2778
@njit(parallel=True, fastmath=True)
def numba_mono_final_formula(rescaled, median_rescaled, target_median):
    """
    Apply the statistical-stretch transfer function to a mono image whose
    black point has already been rescaled.

    With m = median_rescaled and t = target_median, each pixel r maps to

        ((m - 1) * t * r) / (m * (t + r - 1) - t * r)

    which places the image median at the target median.
    """
    rows, cols = rescaled.shape
    result = np.empty_like(rescaled)
    m = median_rescaled
    t = target_median

    for i in prange(rows):
        for j in range(cols):
            r = rescaled[i, j]
            top = (m - 1.0) * t * r
            bottom = m * (t + r - 1.0) - t * r
            # Clamp near-zero denominators to avoid division blow-up.
            if np.abs(bottom) < 1e-12:
                bottom = 1e-12
            result[i, j] = top / bottom

    return result
2802
+
2803
@njit(parallel=True, fastmath=True)
def numba_color_final_formula_linked(rescaled, median_rescaled, target_median):
    """
    Linked color stretch: one shared median drives all channels.

    rescaled : (H, W, 3) array already rescaled by the black point.
    median_rescaled : median over *all* pixels of `rescaled`.
    """
    rows, cols, chans = rescaled.shape
    result = np.empty_like(rescaled)
    m = median_rescaled
    t = target_median

    for i in prange(rows):
        for j in range(cols):
            for k in range(chans):
                r = rescaled[i, j, k]
                top = (m - 1.0) * t * r
                bottom = m * (t + r - 1.0) - t * r
                # Clamp near-zero denominators to avoid division blow-up.
                if np.abs(bottom) < 1e-12:
                    bottom = 1e-12
                result[i, j, k] = top / bottom

    return result
2824
+
2825
@njit(parallel=True, fastmath=True)
def numba_color_final_formula_unlinked(rescaled, medians_rescaled, target_median):
    """
    Unlinked color stretch: each channel uses its own median.

    rescaled : (H, W, 3) array; channel c already rescaled as
               (val - black_point[c]) / (1 - black_point[c]).
    medians_rescaled : shape (3,) per-channel medians of `rescaled`.
    """
    rows, cols, chans = rescaled.shape
    result = np.empty_like(rescaled)
    t = target_median

    for i in prange(rows):
        for j in range(cols):
            for k in range(chans):
                r = rescaled[i, j, k]
                m = medians_rescaled[k]
                top = (m - 1.0) * t * r
                bottom = m * (t + r - 1.0) - t * r
                # Clamp near-zero denominators to avoid division blow-up.
                if np.abs(bottom) < 1e-12:
                    bottom = 1e-12
                result[i, j, k] = top / bottom

    return result
2847
+
2848
+
2849
def build_poly_terms(x_array: np.ndarray, y_array: np.ndarray, degree: int) -> np.ndarray:
    """
    Build the 2-D polynomial design matrix for a least-squares surface fit.

    Columns are ordered by total degree t = 0..degree and, within each t,
    as x**(t-i) * y**i for i = 0..t:

        1, x, y, x^2, x*y, y^2, x^3, x^2*y, x*y^2, y^3, ...

    This is exactly the ordering of the previous hand-unrolled version
    (which stopped at degree 6); the loop form supports any degree >= 1
    and computes each power of x and y only once.

    Parameters
    ----------
    x_array, y_array : 1-D coordinate arrays of equal length N.
    degree : polynomial degree, >= 1.

    Returns
    -------
    (N, (degree+1)*(degree+2)//2) array of basis terms.

    Raises
    ------
    ValueError : if degree < 1.
    """
    if degree < 1:
        raise ValueError(f"Unsupported polynomial degree={degree}. Must be >= 1.")

    ones = np.ones_like(x_array, dtype=np.float32)
    # Cache x**k and y**k so each power is evaluated exactly once.
    x_pows = [ones] + [x_array ** k for k in range(1, degree + 1)]
    y_pows = [ones] + [y_array ** k for k in range(1, degree + 1)]

    cols = [x_pows[t - i] * y_pows[i]
            for t in range(degree + 1)
            for i in range(t + 1)]
    return np.column_stack(cols)
2890
+
2891
+
2892
+
2893
+
2894
def generate_sample_points(image: np.ndarray, num_points: int = 100) -> np.ndarray:
    """
    Generate background sample points on a uniform grid across the image.

    - Points are laid out on a regular grid (no randomization).
    - A border margin is kept (10 px, automatically reduced for tiny images;
      the previous fixed 10-px margin produced negative — i.e. wrapped —
      indices on images smaller than ~20 px).
    - Grid positions whose pixel value is exactly 0.000 (any channel) or
      above 0.85 (any channel) are skipped.

    Returns:
        np.ndarray: (N, 2) int32 array of (x, y) coordinates. Always has
        second dimension 2, even when empty (previously an empty result had
        shape (0,), contradicting the documented contract).
    """
    H, W = image.shape[:2]

    # Keep the 10-px border where possible, but never exceed the image size.
    margin_x = min(10, max(0, (W - 1) // 2))
    margin_y = min(10, max(0, (H - 1) // 2))

    grid_size = int(np.sqrt(num_points))  # roughly equal spacing
    x_vals = np.linspace(margin_x, W - margin_x, grid_size, dtype=int)
    y_vals = np.linspace(margin_y, H - margin_y, grid_size, dtype=int)

    points = []
    for y in y_vals:
        for x in x_vals:
            px = image[int(y), int(x)]  # single lookup per grid position
            # Skip values that are too dark (0.000) or too bright (> 0.85)
            if np.any(px == 0.000) or np.any(px > 0.85):
                continue
            points.append((int(x), int(y)))
            if len(points) >= num_points:
                return np.array(points, dtype=np.int32)

    if not points:
        return np.empty((0, 2), dtype=np.int32)
    return np.array(points, dtype=np.int32)
2925
+
2926
@njit(parallel=True, fastmath=True)
def numba_unstretch(image: np.ndarray, stretch_original_medians: np.ndarray, stretch_original_mins: np.ndarray) -> np.ndarray:
    """
    Undo an unlinked statistical stretch, restoring each channel separately.

    image : (H, W, C) stretched data in [0, 1].
    stretch_original_medians / stretch_original_mins : per-channel median and
    minimum recorded before the stretch was applied.

    Returns the unstretched image, clipped to [0, 1].

    BUG FIX: when a channel's stretched median or original median is 0, the
    inverse formula cannot be applied. Previously that channel of the
    np.empty_like output was left *uninitialized* (garbage) and still had
    orig_min added to it; it is now passed through unchanged instead.
    """
    H, W, C = image.shape
    out = np.empty_like(image, dtype=np.float32)

    for c in prange(C):  # Parallelize per channel
        cmed_stretched = np.median(image[..., c])
        orig_med = stretch_original_medians[c]
        orig_min = stretch_original_mins[c]

        if cmed_stretched != 0 and orig_med != 0:
            for y in range(H):
                for x in range(W):
                    r = image[y, x, c]
                    numerator = (cmed_stretched - 1) * orig_med * r
                    denominator = cmed_stretched * (orig_med + r - 1) - orig_med * r
                    if denominator == 0:
                        denominator = 1e-6  # Avoid division by zero
                    out[y, x, c] = numerator / denominator
        else:
            # Inverse formula undefined: copy the channel through so the
            # output is fully initialized before the black point is restored.
            for y in range(H):
                for x in range(W):
                    out[y, x, c] = image[y, x, c]

        # Restore the original black point
        out[..., c] += orig_min

    return np.clip(out, 0, 1)  # Clip to valid range
2954
+
2955
+
2956
@njit(fastmath=True)
def drizzle_deposit_numba_naive(
    img_data,       # shape (H, W), mono
    transform,      # shape (2, 3), e.g. [[a,b,tx],[c,d,ty]]
    drizzle_buffer, # shape (outH, outW)
    coverage_buffer,# shape (outH, outW)
    drizzle_factor: float,
    frame_weight: float
):
    """
    Naive deposit: each input pixel is mapped to exactly one output pixel,
    ignoring drop_shrink. 2D single-channel version (mono).

    Zero-valued pixels are treated as "no data" and skipped. Accumulates
    flux into drizzle_buffer and weight into coverage_buffer in place, and
    returns both buffers.
    """
    h, w = img_data.shape
    out_h, out_w = drizzle_buffer.shape

    # Build a 3×3 matrix M
    # transform is 2×3, so we expand to 3×3 for the standard [x, y, 1] approach
    M = np.zeros((3, 3), dtype=np.float32)
    M[0, 0] = transform[0, 0]  # a
    M[0, 1] = transform[0, 1]  # b
    M[0, 2] = transform[0, 2]  # tx
    M[1, 0] = transform[1, 0]  # c
    M[1, 1] = transform[1, 1]  # d
    M[1, 2] = transform[1, 2]  # ty
    M[2, 2] = 1.0

    # We'll reuse a small input vector for each pixel
    in_coords = np.zeros(3, dtype=np.float32)
    in_coords[2] = 1.0

    for y in range(h):
        for x in range(w):
            val = img_data[y, x]
            if val == 0:
                continue

            # Fill the input vector
            in_coords[0] = x
            in_coords[1] = y

            # Multiply
            out_coords = M @ in_coords
            X = out_coords[0]
            Y = out_coords[1]

            # Multiply by drizzle_factor (int() truncates toward zero)
            Xo = int(X * drizzle_factor)
            Yo = int(Y * drizzle_factor)

            if 0 <= Xo < out_w and 0 <= Yo < out_h:
                drizzle_buffer[Yo, Xo] += val * frame_weight
                coverage_buffer[Yo, Xo] += frame_weight

    return drizzle_buffer, coverage_buffer
3011
+
3012
+
3013
@njit(fastmath=True)
def drizzle_deposit_numba_footprint(
    img_data,       # shape (H, W), mono
    transform,      # shape (2, 3)
    drizzle_buffer, # shape (outH, outW)
    coverage_buffer,# shape (outH, outW)
    drizzle_factor: float,
    drop_shrink: float,
    frame_weight: float
):
    """
    Distributes each input pixel over a bounding box of width=drop_shrink
    in the drizzle (out) plane. (Mono 2D version)

    Flux is split uniformly over the (clipped) bounding box around the
    transformed position; zero-valued pixels are skipped as "no data".
    Returns the two accumulation buffers, updated in place.
    """
    h, w = img_data.shape
    out_h, out_w = drizzle_buffer.shape

    # Build a 3×3 matrix M
    M = np.zeros((3, 3), dtype=np.float32)
    M[0, 0] = transform[0, 0]  # a
    M[0, 1] = transform[0, 1]  # b
    M[0, 2] = transform[0, 2]  # tx
    M[1, 0] = transform[1, 0]  # c
    M[1, 1] = transform[1, 1]  # d
    M[1, 2] = transform[1, 2]  # ty
    M[2, 2] = 1.0

    in_coords = np.zeros(3, dtype=np.float32)
    in_coords[2] = 1.0

    footprint_radius = drop_shrink * 0.5

    for y in range(h):
        for x in range(w):
            val = img_data[y, x]
            if val == 0:
                continue

            # Transform to output coords
            in_coords[0] = x
            in_coords[1] = y
            out_coords = M @ in_coords
            X = out_coords[0]
            Y = out_coords[1]

            # Upsample
            Xo = X * drizzle_factor
            Yo = Y * drizzle_factor

            # bounding box
            min_x = int(np.floor(Xo - footprint_radius))
            max_x = int(np.floor(Xo + footprint_radius))
            min_y = int(np.floor(Yo - footprint_radius))
            max_y = int(np.floor(Yo + footprint_radius))

            # clip (footprints fully outside the output are dropped)
            if max_x < 0 or min_x >= out_w or max_y < 0 or min_y >= out_h:
                continue
            if min_x < 0:
                min_x = 0
            if max_x >= out_w:
                max_x = out_w - 1
            if min_y < 0:
                min_y = 0
            if max_y >= out_h:
                max_y = out_h - 1

            # Flux is divided by the clipped footprint area, so each covered
            # output pixel receives an equal share.
            width_foot = (max_x - min_x + 1)
            height_foot = (max_y - min_y + 1)
            area_pixels = width_foot * height_foot
            if area_pixels <= 0:
                continue

            deposit_val = (val * frame_weight) / area_pixels
            coverage_fraction = frame_weight / area_pixels

            for oy in range(min_y, max_y+1):
                for ox in range(min_x, max_x+1):
                    drizzle_buffer[oy, ox] += deposit_val
                    coverage_buffer[oy, ox] += coverage_fraction

    return drizzle_buffer, coverage_buffer
3095
+
3096
@njit(fastmath=True)
def _drizzle_kernel_weights(kernel_code: int, Xo: float, Yo: float,
                            min_x: int, max_x: int, min_y: int, max_y: int,
                            sigma_out: float,
                            weights_out): # preallocated 2D view (max_y-min_y+1, max_x-min_x+1)
    """
    Fill `weights_out` with unnormalized kernel weights centered at (Xo,Yo).
    Returns (sum_w, count_used).

    kernel_code: 0 = square (uniform over the box), 1 = circle (uniform
    inside radius sigma_out), anything else = gaussian with sigma sigma_out,
    truncated at ~3 sigma.

    NOTE(review): a second, identical definition of this function appears
    later in this module; at import time that later definition replaces this
    one. Consider removing the duplicate.
    """
    H = max_y - min_y + 1
    W = max_x - min_x + 1
    r2_limit = sigma_out * sigma_out  # for circle, sigma_out := radius

    sum_w = 0.0
    cnt = 0
    for j in range(H):
        oy = min_y + j
        cy = (oy + 0.5) - Yo  # pixel-center distance
        for i in range(W):
            ox = min_x + i
            cx = (ox + 0.5) - Xo
            w = 0.0

            if kernel_code == 0:
                # square = uniform weight in the bounding box
                w = 1.0
            elif kernel_code == 1:
                # circle = uniform weight if inside radius
                if (cx*cx + cy*cy) <= r2_limit:
                    w = 1.0
            else:  # gaussian
                # gaussian centered at (Xo,Yo) with sigma_out
                z = (cx*cx + cy*cy) / (2.0 * sigma_out * sigma_out)
                # drop tiny far-away contributions to keep perf ok
                if z <= 9.0:  # ~3σ
                    w = math.exp(-z)

            weights_out[j, i] = w
            sum_w += w
            if w > 0.0:
                cnt += 1

    return sum_w, cnt
3139
+
3140
+
3141
@njit(fastmath=True)
def drizzle_deposit_numba_kernel_mono(
    img_data, transform, drizzle_buffer, coverage_buffer,
    drizzle_factor: float, drop_shrink: float, frame_weight: float,
    kernel_code: int, gaussian_sigma_or_radius: float
):
    """
    Kernel-based drizzle deposit for a mono (H, W) frame.

    Each nonzero pixel is mapped through the 2x3 affine `transform`, scaled
    by `drizzle_factor`, and its flux spread over a small output tile using
    `_drizzle_kernel_weights` (kernel_code: 0 = square, 1 = circle,
    2 = gaussian). Degenerate footprints and empty weight tiles fall back to
    a nearest-pixel deposit. Accumulates into drizzle_buffer and
    coverage_buffer in place and returns both.
    """
    H, W = img_data.shape
    outH, outW = drizzle_buffer.shape

    # build 3x3
    M = np.zeros((3, 3), dtype=np.float32)
    M[0,0], M[0,1], M[0,2] = transform[0,0], transform[0,1], transform[0,2]
    M[1,0], M[1,1], M[1,2] = transform[1,0], transform[1,1], transform[1,2]
    M[2,2] = 1.0

    v = np.zeros(3, dtype=np.float32); v[2] = 1.0

    # interpret width parameter:
    #  - square/circle: radius = drop_shrink * 0.5 (pixfrac-like)
    #  - gaussian: sigma_out = max(gaussian_sigma_or_radius, drop_shrink * 0.5)
    radius = drop_shrink * 0.5
    sigma_out = gaussian_sigma_or_radius if kernel_code == 2 else radius
    if sigma_out < 1e-6:
        sigma_out = 1e-6

    # temp weights tile (safely sized later per pixel)
    for y in range(H):
        for x in range(W):
            val = img_data[y, x]
            if val == 0.0:
                continue  # zero = "no data"

            v[0] = x; v[1] = y
            out_coords = M @ v
            Xo = out_coords[0] * drizzle_factor
            Yo = out_coords[1] * drizzle_factor

            # choose bounds (gaussian extends to ~3 sigma)
            if kernel_code == 2:
                r = int(math.ceil(3.0 * sigma_out))
            else:
                r = int(math.ceil(radius))

            if r <= 0:
                # degenerate → nearest pixel
                ox = int(Xo); oy = int(Yo)
                if 0 <= ox < outW and 0 <= oy < outH:
                    drizzle_buffer[oy, ox] += val * frame_weight
                    coverage_buffer[oy, ox] += frame_weight
                continue

            min_x = int(math.floor(Xo - r))
            max_x = int(math.floor(Xo + r))
            min_y = int(math.floor(Yo - r))
            max_y = int(math.floor(Yo + r))
            if max_x < 0 or min_x >= outW or max_y < 0 or min_y >= outH:
                continue
            if min_x < 0: min_x = 0
            if min_y < 0: min_y = 0
            if max_x >= outW: max_x = outW - 1
            if max_y >= outH: max_y = outH - 1

            Ht = max_y - min_y + 1
            Wt = max_x - min_x + 1
            if Ht <= 0 or Wt <= 0:
                continue

            # allocate small tile (Numba-friendly: fixed-size via stack array)
            weights = np.zeros((Ht, Wt), dtype=np.float32)
            sum_w, cnt = _drizzle_kernel_weights(kernel_code, Xo, Yo,
                                                 min_x, max_x, min_y, max_y,
                                                 sigma_out, weights)
            if cnt == 0 or sum_w <= 1e-12:
                # fallback to nearest
                ox = int(Xo); oy = int(Yo)
                if 0 <= ox < outW and 0 <= oy < outH:
                    drizzle_buffer[oy, ox] += val * frame_weight
                    coverage_buffer[oy, ox] += frame_weight
                continue

            # Normalize by the weight sum so total deposited flux equals
            # val * frame_weight regardless of kernel shape.
            scale = (val * frame_weight) / sum_w
            cov_scale = frame_weight / sum_w
            for j in range(Ht):
                oy = min_y + j
                for i in range(Wt):
                    w = weights[j, i]
                    if w > 0.0:
                        ox = min_x + i
                        drizzle_buffer[oy, ox] += w * scale
                        coverage_buffer[oy, ox] += w * cov_scale

    return drizzle_buffer, coverage_buffer
3233
+
3234
+
3235
@njit(fastmath=True)
def drizzle_deposit_color_kernel(
    img_data, transform, drizzle_buffer, coverage_buffer,
    drizzle_factor: float, drop_shrink: float, frame_weight: float,
    kernel_code: int, gaussian_sigma_or_radius: float
):
    """
    Kernel-based drizzle deposit for a color (H, W, C) frame.

    Same scheme as drizzle_deposit_numba_kernel_mono, but the per-pixel
    kernel weight tile is computed once and reused for every channel, and
    pixels whose channels are all zero are skipped entirely. Accumulates
    into the per-channel buffers in place and returns both.
    """
    H, W, C = img_data.shape
    outH, outW, _ = drizzle_buffer.shape

    M = np.zeros((3, 3), dtype=np.float32)
    M[0,0], M[0,1], M[0,2] = transform[0,0], transform[0,1], transform[0,2]
    M[1,0], M[1,1], M[1,2] = transform[1,0], transform[1,1], transform[1,2]
    M[2,2] = 1.0

    v = np.zeros(3, dtype=np.float32); v[2] = 1.0

    # square/circle use a pixfrac-style radius; gaussian uses the sigma arg
    radius = drop_shrink * 0.5
    sigma_out = gaussian_sigma_or_radius if kernel_code == 2 else radius
    if sigma_out < 1e-6:
        sigma_out = 1e-6

    for y in range(H):
        for x in range(W):
            # (minor optimization) skip all-zero triplets
            nz = False
            for cc in range(C):
                if img_data[y, x, cc] != 0.0:
                    nz = True; break
            if not nz:
                continue

            v[0] = x; v[1] = y
            out_coords = M @ v
            Xo = out_coords[0] * drizzle_factor
            Yo = out_coords[1] * drizzle_factor

            # gaussian footprint extends to ~3 sigma
            if kernel_code == 2:
                r = int(math.ceil(3.0 * sigma_out))
            else:
                r = int(math.ceil(radius))

            if r <= 0:
                # degenerate footprint → nearest-pixel deposit per channel
                ox = int(Xo); oy = int(Yo)
                if 0 <= ox < outW and 0 <= oy < outH:
                    for c in range(C):
                        val = img_data[y, x, c]
                        if val != 0.0:
                            drizzle_buffer[oy, ox, c] += val * frame_weight
                            coverage_buffer[oy, ox, c] += frame_weight
                continue

            min_x = int(math.floor(Xo - r))
            max_x = int(math.floor(Xo + r))
            min_y = int(math.floor(Yo - r))
            max_y = int(math.floor(Yo + r))
            if max_x < 0 or min_x >= outW or max_y < 0 or min_y >= outH:
                continue
            if min_x < 0: min_x = 0
            if min_y < 0: min_y = 0
            if max_x >= outW: max_x = outW - 1
            if max_y >= outH: max_y = outH - 1

            Ht = max_y - min_y + 1
            Wt = max_x - min_x + 1
            if Ht <= 0 or Wt <= 0:
                continue

            weights = np.zeros((Ht, Wt), dtype=np.float32)
            sum_w, cnt = _drizzle_kernel_weights(kernel_code, Xo, Yo,
                                                 min_x, max_x, min_y, max_y,
                                                 sigma_out, weights)
            if cnt == 0 or sum_w <= 1e-12:
                # no usable weights → nearest-pixel fallback
                ox = int(Xo); oy = int(Yo)
                if 0 <= ox < outW and 0 <= oy < outH:
                    for c in range(C):
                        val = img_data[y, x, c]
                        if val != 0.0:
                            drizzle_buffer[oy, ox, c] += val * frame_weight
                            coverage_buffer[oy, ox, c] += frame_weight
                continue

            # Normalize once; the same weight tile is applied to each channel.
            inv_sum = 1.0 / sum_w
            for c in range(C):
                val = img_data[y, x, c]
                if val == 0.0:
                    continue
                scale = (val * frame_weight) * inv_sum
                cov_scale = frame_weight * inv_sum
                for j in range(Ht):
                    oy = min_y + j
                    for i in range(Wt):
                        w = weights[j, i]
                        if w > 0.0:
                            ox = min_x + i
                            drizzle_buffer[oy, ox, c] += w * scale
                            coverage_buffer[oy, ox, c] += w * cov_scale

    return drizzle_buffer, coverage_buffer
3333
+
3334
@njit(parallel=True)
def finalize_drizzle_2d(drizzle_buffer, coverage_buffer, final_out):
    """
    Normalize a mono drizzle accumulation buffer by its coverage map.

    final_out[y, x] = drizzle_buffer[y, x] / coverage_buffer[y, x];
    pixels with coverage below 1e-8 become 0. Fills final_out in place and
    returns it.
    """
    rows, cols = drizzle_buffer.shape
    for row in prange(rows):
        for col in range(cols):
            weight = coverage_buffer[row, col]
            if weight >= 1e-8:
                final_out[row, col] = drizzle_buffer[row, col] / weight
            else:
                final_out[row, col] = 0.0
    return final_out
3349
+
3350
@njit(fastmath=True)
def drizzle_deposit_color_naive(
    img_data,        # shape (H,W,C)
    transform,       # shape (2,3)
    drizzle_buffer,  # shape (outH,outW,C)
    coverage_buffer, # shape (outH,outW,C)
    drizzle_factor: float,
    drop_shrink: float,  # unused here
    frame_weight: float
):
    """
    Naive color deposit:
    Each input pixel is mapped to exactly one output pixel (ignores drop_shrink).

    Zero-valued channel samples are treated as "no data" and skipped.
    Accumulates into the per-channel buffers in place and returns both.
    """
    H, W, channels = img_data.shape
    outH, outW, outC = drizzle_buffer.shape

    # Build 3×3 matrix M for the standard [x, y, 1] transform
    M = np.zeros((3, 3), dtype=np.float32)
    M[0, 0] = transform[0, 0]
    M[0, 1] = transform[0, 1]
    M[0, 2] = transform[0, 2]
    M[1, 0] = transform[1, 0]
    M[1, 1] = transform[1, 1]
    M[1, 2] = transform[1, 2]
    M[2, 2] = 1.0

    in_coords = np.zeros(3, dtype=np.float32)
    in_coords[2] = 1.0

    for y in range(H):
        for x in range(W):
            # 1) Transform
            in_coords[0] = x
            in_coords[1] = y
            out_coords = M @ in_coords
            X = out_coords[0]
            Y = out_coords[1]

            # 2) Upsample (int() truncates toward zero)
            Xo = int(X * drizzle_factor)
            Yo = int(Y * drizzle_factor)

            # 3) Check bounds
            if 0 <= Xo < outW and 0 <= Yo < outH:
                # 4) For each channel
                for cidx in range(channels):
                    val = img_data[y, x, cidx]
                    if val != 0:
                        drizzle_buffer[Yo, Xo, cidx] += val * frame_weight
                        coverage_buffer[Yo, Xo, cidx] += frame_weight

    return drizzle_buffer, coverage_buffer
3403
@njit(fastmath=True)
def drizzle_deposit_color_footprint(
    img_data,        # shape (H,W,C)
    transform,       # shape (2,3)
    drizzle_buffer,  # shape (outH,outW,C)
    coverage_buffer, # shape (outH,outW,C)
    drizzle_factor: float,
    drop_shrink: float,
    frame_weight: float
):
    """
    Color version with a bounding-box footprint of width=drop_shrink
    for distributing flux in the output plane.

    The affine transform and footprint box are computed once per pixel and
    shared by all channels; zero-valued channel samples are skipped.
    Returns the two accumulation buffers, updated in place.
    """
    H, W, channels = img_data.shape
    outH, outW, outC = drizzle_buffer.shape

    # Build 3×3 matrix
    M = np.zeros((3, 3), dtype=np.float32)
    M[0, 0] = transform[0, 0]
    M[0, 1] = transform[0, 1]
    M[0, 2] = transform[0, 2]
    M[1, 0] = transform[1, 0]
    M[1, 1] = transform[1, 1]
    M[1, 2] = transform[1, 2]
    M[2, 2] = 1.0

    in_coords = np.zeros(3, dtype=np.float32)
    in_coords[2] = 1.0

    footprint_radius = drop_shrink * 0.5

    for y in range(H):
        for x in range(W):
            # Transform once per pixel
            in_coords[0] = x
            in_coords[1] = y
            out_coords = M @ in_coords
            X = out_coords[0]
            Y = out_coords[1]

            # Upsample
            Xo = X * drizzle_factor
            Yo = Y * drizzle_factor

            # bounding box
            min_x = int(np.floor(Xo - footprint_radius))
            max_x = int(np.floor(Xo + footprint_radius))
            min_y = int(np.floor(Yo - footprint_radius))
            max_y = int(np.floor(Yo + footprint_radius))

            # footprints fully outside the output plane are dropped
            if max_x < 0 or min_x >= outW or max_y < 0 or min_y >= outH:
                continue
            if min_x < 0:
                min_x = 0
            if max_x >= outW:
                max_x = outW - 1
            if min_y < 0:
                min_y = 0
            if max_y >= outH:
                max_y = outH - 1

            # Flux is split uniformly over the clipped footprint area.
            width_foot = (max_x - min_x + 1)
            height_foot = (max_y - min_y + 1)
            area_pixels = width_foot * height_foot
            if area_pixels <= 0:
                continue

            for cidx in range(channels):
                val = img_data[y, x, cidx]
                if val == 0:
                    continue

                deposit_val = (val * frame_weight) / area_pixels
                coverage_fraction = frame_weight / area_pixels

                for oy in range(min_y, max_y + 1):
                    for ox in range(min_x, max_x + 1):
                        drizzle_buffer[oy, ox, cidx] += deposit_val
                        coverage_buffer[oy, ox, cidx] += coverage_fraction

    return drizzle_buffer, coverage_buffer
3485
+
3486
+
3487
@njit
def finalize_drizzle_3d(drizzle_buffer, coverage_buffer, final_out):
    """
    Normalize a color drizzle accumulation buffer by its coverage map.

    final_out[y, x, c] = drizzle_buffer[y, x, c] / coverage_buffer[y, x, c];
    samples with coverage below 1e-8 become 0. Fills final_out in place and
    returns it.
    """
    rows, cols, chans = drizzle_buffer.shape
    for row in range(rows):
        for col in range(cols):
            for ch in range(chans):
                weight = coverage_buffer[row, col, ch]
                if weight >= 1e-8:
                    final_out[row, col, ch] = drizzle_buffer[row, col, ch] / weight
                else:
                    final_out[row, col, ch] = 0.0
    return final_out
3503
+
3504
+
3505
+
3506
@njit
def piecewise_linear(val, xvals, yvals):
    """
    Piecewise-linear interpolation of a scalar against the curve (xvals, yvals).

    Values at or below xvals[0] clamp to yvals[0]; values at or above
    xvals[-1] clamp to yvals[-1]; otherwise the result is interpolated
    linearly on the bracketing segment.
    """
    if val <= xvals[0]:
        return yvals[0]
    n = len(xvals)
    for idx in range(n - 1):
        hi = xvals[idx + 1]
        if val < hi:
            lo = xvals[idx]
            frac = (val - lo) / (hi - lo)
            return yvals[idx] + frac * (yvals[idx + 1] - yvals[idx])
    return yvals[-1]
3526
+
3527
@njit(parallel=True, fastmath=True)
def apply_curves_numba(image, xvals, yvals):
    """
    Apply a piecewise-linear curve to every pixel of a mono or RGB image.

    image : (H, W) or (H, W, C) array; xvals/yvals are the curve nodes in
    ascending x order. Returns a new float32 array; arrays of any other
    rank are returned unchanged.
    """
    if image.ndim == 2:
        rows, cols = image.shape
        result = np.empty((rows, cols), dtype=np.float32)
        for i in prange(rows):
            for j in range(cols):
                result[i, j] = piecewise_linear(image[i, j], xvals, yvals)
        return result
    elif image.ndim == 3:
        rows, cols, chans = image.shape
        result = np.empty((rows, cols, chans), dtype=np.float32)
        for i in prange(rows):
            for j in range(cols):
                for k in range(chans):
                    result[i, j, k] = piecewise_linear(image[i, j, k], xvals, yvals)
        return result
    else:
        # Unexpected rank: hand the input back untouched.
        return image
3556
+
3557
def fast_star_detect(image,
                     blur_size=9,
                     threshold_factor=0.7,
                     min_area=1,
                     max_area=5000):
    """
    Detect star centroids via background subtraction, thresholding and
    contour/ellipse fitting.

    Returns an (N, 2) float32 array of (x, y) ellipse centers in the same
    coordinate system as `image`; an empty (0, 2) array when nothing is
    found or the image is flat.
    """
    # Grayscale conversion if an RGB image was passed
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Rescale to 8-bit; a flat image has no stars to find
    lo, hi = image.min(), image.max()
    if hi <= lo:
        return np.empty((0, 2), dtype=np.float32)
    gray8 = (255.0 * (image - lo) / (hi - lo)).astype(np.uint8)

    # Unsharp-style residual: stars stand out against the blurred background
    background = cv2.GaussianBlur(gray8, (blur_size, blur_size), 0)
    residual = cv2.absdiff(gray8, background)

    # Otsu's threshold, scaled down by threshold_factor (never below 2)
    otsu_level, _ = cv2.threshold(residual, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    level = max(2, int(otsu_level * threshold_factor))
    _, mask = cv2.threshold(residual, level, 255, cv2.THRESH_BINARY)

    # Morphological opening removes single-pixel noise
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((2, 2), np.uint8))

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Filter contours by area and fit an ellipse; its center is the star.
    centers = []
    for contour in contours:
        area = cv2.contourArea(contour)
        if not (min_area <= area <= max_area):
            continue
        if len(contour) < 5:
            continue  # fitEllipse needs at least 5 contour points
        (cx, cy), _axes, _angle = cv2.fitEllipse(contour)
        centers.append((cx, cy))

    if not centers:
        return np.empty((0, 2), dtype=np.float32)
    return np.array(centers, dtype=np.float32)
3613
+
3614
+
3615
@njit(fastmath=True)
def _drizzle_kernel_weights(kernel_code: int, Xo: float, Yo: float,
                            min_x: int, max_x: int, min_y: int, max_y: int,
                            sigma_out: float,
                            weights_out): # preallocated 2D view (max_y-min_y+1, max_x-min_x+1)
    """
    Populate `weights_out` with the unnormalized drizzle-kernel weights
    centered at (Xo, Yo), evaluated at output-pixel centers.

    kernel_code selects the shape:
      0    -> square: uniform weight over the whole bounding box,
      1    -> circle: uniform weight inside radius `sigma_out`,
      else -> gaussian with std-dev `sigma_out`, truncated near 3 sigma.

    Returns (sum of all weights, count of non-zero weights).
    """
    n_rows = max_y - min_y + 1
    n_cols = max_x - min_x + 1
    radius_sq = sigma_out * sigma_out  # circle mode treats sigma_out as a radius

    total = 0.0
    nonzero = 0
    for row in range(n_rows):
        dy = (min_y + row + 0.5) - Yo  # offset from this pixel's center
        for col in range(n_cols):
            dx = (min_x + col + 0.5) - Xo

            if kernel_code == 0:
                # square: every pixel in the bounding box gets full weight
                w = 1.0
            elif kernel_code == 1:
                # circle: full weight inside the radius, zero outside
                w = 1.0 if (dx * dx + dy * dy) <= radius_sq else 0.0
            else:
                # gaussian centered at (Xo, Yo) with sigma_out; drop the
                # negligible tail beyond ~3 sigma to keep performance ok
                z = (dx * dx + dy * dy) / (2.0 * sigma_out * sigma_out)
                w = math.exp(-z) if z <= 9.0 else 0.0

            weights_out[row, col] = w
            total += w
            if w > 0.0:
                nonzero += 1

    return total, nonzero
