teareduce 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
teareduce/cleanest.py ADDED
@@ -0,0 +1,694 @@
1
+ #
2
+ # Copyright 2025 Universidad Complutense de Madrid
3
+ #
4
+ # This file is part of teareduce
5
+ #
6
+ # SPDX-License-Identifier: GPL-3.0+
7
+ # License-Filename: LICENSE.txt
8
+ #
9
+
10
+ """Interactive Cosmic Ray cleaning tool."""
11
+
12
+ import argparse
13
+ import tkinter as tk
14
+ from tkinter import filedialog
15
+ from tkinter import simpledialog
16
+
17
+ from astropy.io import fits
18
+ from ccdproc import cosmicray_lacosmic
19
+ import matplotlib.pyplot as plt
20
+ from matplotlib.backend_bases import key_press_handler
21
+ from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
22
+ import numpy as np
23
+ import os
24
+ from scipy import ndimage
25
+
26
+ from .imshow import imshow
27
+ from .sliceregion import SliceRegion2D
28
+ from .zscale import zscale
29
+
30
+ import matplotlib
31
+ matplotlib.use("TkAgg")
32
+
33
+
34
+ class ReviewCosmicRay():
35
+ """Class to review suspected cosmic ray pixels."""
36
+
37
+ def __init__(self, root, data, mask_fixed, mask_crfound):
38
+ """Initialize the review window.
39
+
40
+ Parameters
41
+ ----------
42
+ root : tk.Tk
43
+ The main Tkinter window.
44
+ data : 2D numpy array
45
+ The original image data.
46
+ mask_fixed : 2D numpy array
47
+ Mask of previously corrected pixels.
48
+ mask_crfound : 2D numpy array
49
+ Mask of new pixels identified as cosmic rays.
50
+ """
51
+ self.root = root
52
+ self.data = data
53
+ self.data_original = data.copy()
54
+ self.mask_fixed = mask_fixed
55
+ self.mask_crfound = mask_crfound
56
+ self.first_plot = True
57
+ self.degree = 1 # Degree of polynomial for interpolation
58
+ self.npoints = 2 # Number of points at each side of the CR pixel for interpolation
59
+ # Label connected components in the mask; note that by default,
60
+ # structure is a cross [0,1,0;1,1,1;0,1,0], but we want to consider
61
+ # diagonal connections too, so we define a 3x3 square.
62
+ structure = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
63
+ self.cr_labels, self.num_features = ndimage.label(self.mask_crfound, structure=structure)
64
+ # Make a copy of the original labels to allow pixel re-marking
65
+ self.cr_labels_original = self.cr_labels.copy()
66
+ print(f"Number of cosmic ray pixels detected: {np.sum(self.mask_crfound)}")
67
+ print(f"Number of cosmic rays detected: {self.num_features}")
68
+ if self.num_features == 0:
69
+ print('No CR hits found!')
70
+ else:
71
+ self.cr_index = 1
72
+ self.create_widgets()
73
+
74
+ def create_widgets(self):
75
+ self.review_window = tk.Toplevel(self.root)
76
+ self.review_window.title("Review Cosmic Rays")
77
+ self.review_window.geometry("800x700+100+100")
78
+
79
+ self.button_frame1 = tk.Frame(self.review_window)
80
+ self.button_frame1.pack(pady=5)
81
+ self.remove_crosses_button = tk.Button(self.button_frame1, text="remove all X's", command=self.remove_crosses)
82
+ self.remove_crosses_button.pack(side=tk.LEFT, padx=5)
83
+ self.restore_cr_button = tk.Button(self.button_frame1, text="[r]estore CR", command=self.restore_cr)
84
+ self.restore_cr_button.pack(side=tk.LEFT, padx=5)
85
+ self.restore_cr_button.config(state=tk.DISABLED)
86
+ self.next_button = tk.Button(self.button_frame1, text="[c]ontinue", command=self.continue_cr)
87
+ self.next_button.pack(side=tk.LEFT, padx=5)
88
+
89
+ self.button_frame2 = tk.Frame(self.review_window)
90
+ self.button_frame2.pack(pady=5)
91
+ self.ndeg_label = tk.Button(self.button_frame2, text=f"deg={self.degree}, n={self.npoints}",
92
+ command=self.set_ndeg)
93
+ self.ndeg_label.pack(side=tk.LEFT, padx=5)
94
+ self.interp_x_button = tk.Button(self.button_frame2, text="[x] interp.", command=self.interp_x)
95
+ self.interp_x_button.pack(side=tk.LEFT, padx=5)
96
+ self.interp_y_button = tk.Button(self.button_frame2, text="[y] interp.", command=self.interp_y)
97
+ self.interp_y_button.pack(side=tk.LEFT, padx=5)
98
+ self.interp_s_button = tk.Button(self.button_frame2, text="[s] interp.", command=self.interp_s)
99
+ self.interp_s_button.pack(side=tk.LEFT, padx=5)
100
+
101
+ self.button_frame3 = tk.Frame(self.review_window)
102
+ self.button_frame3.pack(pady=5)
103
+ vmin, vmax = zscale(self.data)
104
+ self.vmin_button = tk.Button(self.button_frame3, text=f"vmin: {vmin:.2f}", command=self.set_vmin)
105
+ self.vmin_button.pack(side=tk.LEFT, padx=5)
106
+ self.vmax_button = tk.Button(self.button_frame3, text=f"vmax: {vmax:.2f}", command=self.set_vmax)
107
+ self.vmax_button.pack(side=tk.LEFT, padx=5)
108
+ self.set_minmax_button = tk.Button(self.button_frame3, text="minmax [,]", command=self.set_minmax)
109
+ self.set_minmax_button.pack(side=tk.LEFT, padx=5)
110
+ self.set_zscale_button = tk.Button(self.button_frame3, text="zscale [/]", command=self.set_zscale)
111
+ self.set_zscale_button.pack(side=tk.LEFT, padx=5)
112
+ self.exit_button = tk.Button(self.button_frame3, text="[e]xit review", command=self.exit_review)
113
+ self.exit_button.pack(side=tk.LEFT, padx=5)
114
+
115
+ self.fig, self.ax = plt.subplots(figsize=(8, 5))
116
+ self.canvas = FigureCanvasTkAgg(self.fig, master=self.review_window)
117
+ # The next two instructions prevent a segmentation fault when pressing "q"
118
+ self.canvas.mpl_disconnect(self.canvas.mpl_connect("key_press_event", key_press_handler))
119
+ self.canvas.mpl_connect("key_press_event", self.on_key)
120
+ self.canvas.mpl_connect("button_press_event", self.on_click)
121
+ self.canvas_widget = self.canvas.get_tk_widget()
122
+ self.canvas_widget.pack(fill=tk.BOTH, expand=True)
123
+
124
+ # Matplotlib toolbar
125
+ self.toolbar_frame = tk.Frame(self.review_window)
126
+ self.toolbar_frame.pack(fill=tk.X, expand=False, pady=5)
127
+ self.toolbar = NavigationToolbar2Tk(self.canvas, self.toolbar_frame)
128
+ self.toolbar.update()
129
+
130
+ self.update_display()
131
+
132
+ self.root.wait_window(self.review_window)
133
+
134
+ def update_display(self):
135
+ ycr_list, xcr_list = np.where(self.cr_labels == self.cr_index)
136
+ ycr_list_original, xcr_list_original = np.where(self.cr_labels_original == self.cr_index)
137
+ if self.first_plot:
138
+ print(f"Cosmic ray {self.cr_index}: "
139
+ f"Number of pixels = {len(xcr_list)}, "
140
+ f"Centroid = ({np.mean(xcr_list):.2f}, {np.mean(ycr_list):.2f})")
141
+ # Use original positions to define the region to display in order
142
+ # to avoid image shifts when some pixels are unmarked or new ones are marked
143
+ i0 = int(np.mean(ycr_list_original) + 0.5)
144
+ j0 = int(np.mean(xcr_list_original) + 0.5)
145
+ jmin = j0 - 15 if j0 - 15 >= 0 else 0
146
+ jmax = j0 + 15 if j0 + 15 < self.data.shape[1] else self.data.shape[1] - 1
147
+ imin = i0 - 15 if i0 - 15 >= 0 else 0
148
+ imax = i0 + 15 if i0 + 15 < self.data.shape[0] else self.data.shape[0] - 1
149
+ self.region = SliceRegion2D(f'[{jmin+1}:{jmax+1}, {imin+1}:{imax+1}]', mode='fits').python
150
+ self.ax.clear()
151
+ vmin = self.get_vmin()
152
+ vmax = self.get_vmax()
153
+ xlabel = 'X pixel (from 1 to NAXIS1)'
154
+ ylabel = 'Y pixel (from 1 to NAXIS2)'
155
+ self.image_review, _, _ = imshow(self.fig, self.ax, self.data[self.region], colorbar=False,
156
+ xlabel=xlabel, ylabel=ylabel,
157
+ vmin=vmin, vmax=vmax)
158
+ self.image_review.set_extent([jmin + 0.5, jmax + 1.5, imin + 0.5, imax + 1.5])
159
+ xlim = self.ax.get_xlim()
160
+ ylim = self.ax.get_ylim()
161
+ for xcr, ycr in zip(xcr_list, ycr_list):
162
+ xcr += 1 # from index to pixel
163
+ ycr += 1 # from index to pixel
164
+ self.ax.plot([xcr - 0.5, xcr + 0.5], [ycr + 0.5, ycr - 0.5], 'r-')
165
+ self.ax.plot([xcr - 0.5, xcr + 0.5], [ycr - 0.5, ycr + 0.5], 'r-')
166
+ self.ax.set_xlim(xlim)
167
+ self.ax.set_ylim(ylim)
168
+ self.ax.set_title(f"Cosmic ray #{self.cr_index}/{self.num_features}")
169
+ if self.first_plot:
170
+ self.first_plot = False
171
+ self.fig.tight_layout()
172
+ self.canvas.draw()
173
+
174
+ def set_vmin(self):
175
+ old_vmin = self.get_vmin()
176
+ new_vmin = simpledialog.askfloat("Set vmin", "Enter new vmin:", initialvalue=old_vmin)
177
+ if new_vmin is None:
178
+ return
179
+ self.vmin_button.config(text=f"vmin: {new_vmin:.2f}")
180
+ self.image_review.set_clim(vmin=new_vmin)
181
+ self.canvas.draw()
182
+
183
+ def set_vmax(self):
184
+ old_vmax = self.get_vmax()
185
+ new_vmax = simpledialog.askfloat("Set vmax", "Enter new vmax:", initialvalue=old_vmax)
186
+ if new_vmax is None:
187
+ return
188
+ self.vmax_button.config(text=f"vmax: {new_vmax:.2f}")
189
+ self.image_review.set_clim(vmax=new_vmax)
190
+ self.canvas.draw()
191
+
192
+ def get_vmin(self):
193
+ return float(self.vmin_button.cget("text").split(":")[1])
194
+
195
+ def get_vmax(self):
196
+ return float(self.vmax_button.cget("text").split(":")[1])
197
+
198
+ def set_minmax(self):
199
+ vmin_new = np.min(self.data[self.region])
200
+ vmax_new = np.max(self.data[self.region])
201
+ self.vmin_button.config(text=f"vmin: {vmin_new:.2f}")
202
+ self.vmax_button.config(text=f"vmax: {vmax_new:.2f}")
203
+ self.image_review.set_clim(vmin=vmin_new)
204
+ self.image_review.set_clim(vmax=vmax_new)
205
+ self.canvas.draw()
206
+
207
+ def set_zscale(self):
208
+ vmin_new, vmax_new = zscale(self.data[self.region])
209
+ self.vmin_button.config(text=f"vmin: {vmin_new:.2f}")
210
+ self.vmax_button.config(text=f"vmax: {vmax_new:.2f}")
211
+ self.image_review.set_clim(vmin=vmin_new)
212
+ self.image_review.set_clim(vmax=vmax_new)
213
+ self.canvas.draw()
214
+
215
+ def set_ndeg(self):
216
+ new_degree = simpledialog.askinteger("Set degree", "Enter new degree (min=0):",
217
+ initialvalue=self.degree, minvalue=0)
218
+ if new_degree is None:
219
+ return
220
+ new_npoints = simpledialog.askinteger("Set n", f"Enter new n (min={2*new_degree}):",
221
+ initialvalue=self.npoints, minvalue=2*new_degree)
222
+ if new_npoints is None:
223
+ return
224
+ self.degree = new_degree
225
+ self.npoints = new_npoints
226
+ self.ndeg_label.config(text=f"deg={self.degree}, n={self.npoints}")
227
+
228
+ def interp_x(self):
229
+ print(f"X-interpolation of cosmic ray {self.cr_index}")
230
+ ycr_list, xcr_list = np.where(self.cr_labels == self.cr_index)
231
+ ycr_min = np.min(ycr_list)
232
+ ycr_max = np.max(ycr_list)
233
+ xfit_all = []
234
+ yfit_all = []
235
+ for ycr in range(ycr_min, ycr_max + 1):
236
+ xmarked = xcr_list[np.where(ycr_list == ycr)]
237
+ if len(xmarked) > 0:
238
+ jmin = np.min(xmarked)
239
+ jmax = np.max(xmarked)
240
+ # mark intermediate pixels too
241
+ for ix in range(jmin, jmax + 1):
242
+ self.cr_labels[ycr, ix] = self.cr_index
243
+ xmarked = xcr_list[np.where(ycr_list == ycr)]
244
+ xfit = []
245
+ zfit = []
246
+ for i in range(jmin - self.npoints, jmin):
247
+ if 0 <= i < self.data.shape[1]:
248
+ xfit.append(i)
249
+ xfit_all.append(i)
250
+ yfit_all.append(ycr)
251
+ zfit.append(self.data[ycr, i])
252
+ for i in range(jmax + 1, jmax + 1 + self.npoints):
253
+ if 0 <= i < self.data.shape[1]:
254
+ xfit.append(i)
255
+ xfit_all.append(i)
256
+ yfit_all.append(ycr)
257
+ zfit.append(self.data[ycr, i])
258
+ if len(xfit) > self.degree:
259
+ p = np.polyfit(xfit, zfit, self.degree)
260
+ for i in range(jmin, jmax + 1):
261
+ if 0 <= i < self.data.shape[1]:
262
+ self.data[ycr, i] = np.polyval(p, i)
263
+ self.mask_fixed[ycr, i] = True
264
+ else:
265
+ print(f"Not enough points to fit at y={ycr+1}")
266
+ self.update_display()
267
+ return
268
+ self.restore_cr_button.config(state=tk.NORMAL)
269
+ self.remove_crosses_button.config(state=tk.DISABLED)
270
+ self.interp_x_button.config(state=tk.DISABLED)
271
+ self.interp_y_button.config(state=tk.DISABLED)
272
+ self.interp_s_button.config(state=tk.DISABLED)
273
+ self.update_display()
274
+ if len(xfit_all) > 0:
275
+ self.ax.plot(np.array(xfit_all) + 1, np.array(yfit_all) + 1, 'mo', markersize=4) # +1: from index to pixel
276
+ self.canvas.draw()
277
+
278
+ def interp_y(self):
279
+ print(f"Y-interpolation of cosmic ray {self.cr_index}")
280
+ ycr_list, xcr_list = np.where(self.cr_labels == self.cr_index)
281
+ xcr_min = np.min(xcr_list)
282
+ xcr_max = np.max(xcr_list)
283
+ xfit_all = []
284
+ yfit_all = []
285
+ for xcr in range(xcr_min, xcr_max + 1):
286
+ ymarked = ycr_list[np.where(xcr_list == xcr)]
287
+ if len(ymarked) > 0:
288
+ imin = np.min(ymarked)
289
+ imax = np.max(ymarked)
290
+ # mark intermediate pixels too
291
+ for iy in range(imin, imax + 1):
292
+ self.cr_labels[iy, xcr] = self.cr_index
293
+ ymarked = ycr_list[np.where(xcr_list == xcr)]
294
+ yfit = []
295
+ zfit = []
296
+ for i in range(imin - self.npoints, imin):
297
+ if 0 <= i < self.data.shape[0]:
298
+ yfit.append(i)
299
+ yfit_all.append(i)
300
+ xfit_all.append(xcr)
301
+ zfit.append(self.data[i, xcr])
302
+ for i in range(imax + 1, imax + 1 + self.npoints):
303
+ if 0 <= i < self.data.shape[0]:
304
+ yfit.append(i)
305
+ yfit_all.append(i)
306
+ xfit_all.append(xcr)
307
+ zfit.append(self.data[i, xcr])
308
+ if len(yfit) > self.degree:
309
+ p = np.polyfit(yfit, zfit, self.degree)
310
+ for i in range(imin, imax + 1):
311
+ if 0 <= i < self.data.shape[1]:
312
+ self.data[i, xcr] = np.polyval(p, i)
313
+ self.mask_fixed[i, xcr] = True
314
+ else:
315
+ print(f"Not enough points to fit at x={xcr+1}")
316
+ self.update_display()
317
+ return
318
+ self.restore_cr_button.config(state=tk.NORMAL)
319
+ self.remove_crosses_button.config(state=tk.DISABLED)
320
+ self.interp_x_button.config(state=tk.DISABLED)
321
+ self.interp_y_button.config(state=tk.DISABLED)
322
+ self.interp_s_button.config(state=tk.DISABLED)
323
+ self.update_display()
324
+ if len(xfit_all) > 0:
325
+ self.ax.plot(np.array(xfit_all) + 1, np.array(yfit_all) + 1, 'mo', markersize=4) # +1: from index to pixel
326
+ self.canvas.draw()
327
+
328
+ def interp_s(self):
329
+ print(f"S-interpolation of cosmic ray {self.cr_index}")
330
+ ycr_list, xcr_list = np.where(self.cr_labels == self.cr_index)
331
+ ycr_min = np.min(ycr_list)
332
+ ycr_max = np.max(ycr_list)
333
+ xfit_all = []
334
+ yfit_all = []
335
+ zfit_all = []
336
+ # First do horizontal lines
337
+ for ycr in range(ycr_min, ycr_max + 1):
338
+ xmarked = xcr_list[np.where(ycr_list == ycr)]
339
+ if len(xmarked) > 0:
340
+ jmin = np.min(xmarked)
341
+ jmax = np.max(xmarked)
342
+ # mark intermediate pixels too
343
+ for ix in range(jmin, jmax + 1):
344
+ self.cr_labels[ycr, ix] = self.cr_index
345
+ xmarked = xcr_list[np.where(ycr_list == ycr)]
346
+ for i in range(jmin - self.npoints, jmin):
347
+ if 0 <= i < self.data.shape[1]:
348
+ xfit_all.append(i)
349
+ yfit_all.append(ycr)
350
+ zfit_all.append(self.data[ycr, i])
351
+ for i in range(jmax + 1, jmax + 1 + self.npoints):
352
+ if 0 <= i < self.data.shape[1]:
353
+ xfit_all.append(i)
354
+ yfit_all.append(ycr)
355
+ zfit_all.append(self.data[ycr, i])
356
+ xcr_min = np.min(xcr_list)
357
+ # Now do vertical lines
358
+ xcr_max = np.max(xcr_list)
359
+ for xcr in range(xcr_min, xcr_max + 1):
360
+ ymarked = ycr_list[np.where(xcr_list == xcr)]
361
+ if len(ymarked) > 0:
362
+ imin = np.min(ymarked)
363
+ imax = np.max(ymarked)
364
+ # mark intermediate pixels too
365
+ for iy in range(imin, imax + 1):
366
+ self.cr_labels[iy, xcr] = self.cr_index
367
+ ymarked = ycr_list[np.where(xcr_list == xcr)]
368
+ for i in range(imin - self.npoints, imin):
369
+ if 0 <= i < self.data.shape[0]:
370
+ yfit_all.append(i)
371
+ xfit_all.append(xcr)
372
+ zfit_all.append(self.data[i, xcr])
373
+ for i in range(imax + 1, imax + 1 + self.npoints):
374
+ if 0 <= i < self.data.shape[0]:
375
+ yfit_all.append(i)
376
+ xfit_all.append(xcr)
377
+ zfit_all.append(self.data[i, xcr])
378
+ if len(xfit_all) > 3:
379
+ # Construct the design matrix for a 2D polynomial fit to a plane,
380
+ # where each row corresponds to a point (x, y, z) and the model
381
+ # is z = C[0]*x + C[1]*y + C[2]
382
+ A = np.c_[xfit_all, yfit_all, np.ones(len(xfit_all))]
383
+ # Least squares polynomial fit
384
+ C, _, _, _ = np.linalg.lstsq(A, zfit_all, rcond=None)
385
+ # recompute all CR pixels to take into account "holes" between marked pixels
386
+ ycr_list, xcr_list = np.where(self.cr_labels == self.cr_index)
387
+ for iy, ix in zip(ycr_list, xcr_list):
388
+ self.data[iy, ix] = C[0] * ix + C[1] * iy + C[2]
389
+ self.mask_fixed[iy, ix] = True
390
+ else:
391
+ print("Not enough points to fit a plane")
392
+ self.update_display()
393
+ return
394
+ self.restore_cr_button.config(state=tk.NORMAL)
395
+ self.remove_crosses_button.config(state=tk.DISABLED)
396
+ self.interp_x_button.config(state=tk.DISABLED)
397
+ self.interp_y_button.config(state=tk.DISABLED)
398
+ self.interp_s_button.config(state=tk.DISABLED)
399
+ self.update_display()
400
+ if len(xfit_all) > 0:
401
+ self.ax.plot(np.array(xfit_all) + 1, np.array(yfit_all) + 1, 'mo', markersize=4) # +1: from index to pixel
402
+ self.canvas.draw()
403
+
404
+ def remove_crosses(self):
405
+ ycr_list, xcr_list = np.where(self.cr_labels == self.cr_index)
406
+ for iy, ix in zip(ycr_list, xcr_list):
407
+ self.cr_labels[iy, ix] = 0
408
+ print(f"Removed all pixels of cosmic ray {self.cr_index}")
409
+ self.remove_crosses_button.config(state=tk.DISABLED)
410
+ self.interp_x_button.config(state=tk.DISABLED)
411
+ self.interp_y_button.config(state=tk.DISABLED)
412
+ self.interp_s_button.config(state=tk.DISABLED)
413
+ self.update_display()
414
+
415
+ def restore_cr(self):
416
+ ycr_list, xcr_list = np.where(self.cr_labels == self.cr_index)
417
+ for iy, ix in zip(ycr_list, xcr_list):
418
+ self.data[iy, ix] = self.data_original[iy, ix]
419
+ self.interp_x_button.config(state=tk.NORMAL)
420
+ self.interp_y_button.config(state=tk.NORMAL)
421
+ self.interp_s_button.config(state=tk.NORMAL)
422
+ print(f"Restored all pixels of cosmic ray {self.cr_index}")
423
+ self.remove_crosses_button.config(state=tk.NORMAL)
424
+ self.restore_cr_button.config(state=tk.DISABLED)
425
+ self.update_display()
426
+
427
+ def continue_cr(self):
428
+ self.cr_index += 1
429
+ if self.cr_index > self.num_features:
430
+ self.cr_index = 1
431
+ self.first_plot = True
432
+ self.restore_cr_button.config(state=tk.DISABLED)
433
+ self.interp_x_button.config(state=tk.NORMAL)
434
+ self.interp_y_button.config(state=tk.NORMAL)
435
+ self.interp_s_button.config(state=tk.NORMAL)
436
+ self.update_display()
437
+
438
+ def exit_review(self):
439
+ self.review_window.destroy()
440
+
441
+ def on_key(self, event):
442
+ if event.key == 'q':
443
+ pass # Ignore the "q" key to prevent closing the window
444
+ elif event.key == 'r':
445
+ if self.restore_cr_button.cget("state") != "disabled":
446
+ self.restore_cr()
447
+ elif event.key == 'x':
448
+ if self.interp_x_button.cget("state") != "disabled":
449
+ self.interp_x()
450
+ elif event.key == 'y':
451
+ if self.interp_y_button.cget("state") != "disabled":
452
+ self.interp_y()
453
+ elif event.key == 's':
454
+ if self.interp_s_button.cget("state") != "disabled":
455
+ self.interp_s()
456
+ elif event.key == 'right' or event.key == 'c':
457
+ self.continue_cr()
458
+ elif event.key == ',':
459
+ self.set_minmax()
460
+ elif event.key == '/':
461
+ self.set_zscale()
462
+ elif event.key == 'e':
463
+ self.exit_review()
464
+ else:
465
+ print(f"Key pressed: {event.key}")
466
+
467
+ def on_click(self, event):
468
+ if event.inaxes:
469
+ x, y = event.xdata, event.ydata
470
+ print(f"Clicked at image coordinates: ({x:.2f}, {y:.2f})")
471
+ ix = int(x+0.5) - 1 # from pixel to index
472
+ iy = int(y+0.5) - 1 # from pixel to index
473
+ if int(self.cr_labels[iy, ix]) == self.cr_index:
474
+ self.cr_labels[iy, ix] = 0
475
+ print(f"Pixel ({ix+1}, {iy+1}) unmarked as cosmic ray.")
476
+ else:
477
+ self.cr_labels[iy, ix] = self.cr_index
478
+ print(f"Pixel ({ix+1}, {iy+1}) marked as cosmic ray.")
479
+ xcr_list, ycr_list = np.where(self.cr_labels == self.cr_index)
480
+ if len(xcr_list) == 0:
481
+ self.interp_x_button.config(state=tk.DISABLED)
482
+ self.interp_y_button.config(state=tk.DISABLED)
483
+ self.interp_s_button.config(state=tk.DISABLED)
484
+ self.remove_crosses_button.config(state=tk.DISABLED)
485
+ else:
486
+ self.interp_x_button.config(state=tk.NORMAL)
487
+ self.interp_y_button.config(state=tk.NORMAL)
488
+ self.interp_s_button.config(state=tk.NORMAL)
489
+ self.remove_crosses_button.config(state=tk.NORMAL)
490
+ # Update the display to reflect the change
491
+ self.update_display()
492
+
493
+
494
+ class CosmicRayCleanerApp():
495
+ """Main application class for cosmic ray cleaning."""
496
+
497
+ def __init__(self, root, input_fits, extension=0, output_fits=None):
498
+ """
499
+ Initialize the application.
500
+
501
+ Parameters
502
+ ----------
503
+ root : tk.Tk
504
+ The main Tkinter window.
505
+ input_fits : str
506
+ Path to the FITS file to be cleaned.
507
+ extension : int, optional
508
+ FITS extension to use (default is 0).
509
+ output_fits : str, optional
510
+ Path to save the cleaned FITS file (default is None, which prompts
511
+ for a save location).
512
+ """
513
+ self.root = root
514
+ self.root.title("Cosmic Ray Cleaner")
515
+ self.root.geometry("800x700+50+0")
516
+ self.input_fits = input_fits
517
+ self.extension = extension
518
+ self.output_fits = output_fits
519
+ self.load_fits_file()
520
+ self.create_widgets()
521
+
522
+ def load_fits_file(self):
523
+ try:
524
+ with fits.open(self.input_fits, mode='readonly') as hdul:
525
+ self.data = hdul[self.extension].data
526
+ if 'CRMASK' in hdul:
527
+ self.mask_fixed = hdul['CRMASK'].data.astype(bool)
528
+ else:
529
+ self.mask_fixed = np.zeros(self.data.shape, dtype=bool)
530
+ except Exception as e:
531
+ print(f"Error loading FITS file: {e}")
532
+
533
+ def save_fits_file(self):
534
+ if self.output_fits is None:
535
+ base, ext = os.path.splitext(self.input_fits)
536
+ suggested_name = f"{base}_cleaned"
537
+ else:
538
+ suggested_name, _ = os.path.splitext(self.output_fits)
539
+ self.output_fits = filedialog.asksaveasfilename(
540
+ initialdir=os.getcwd(),
541
+ title="Save cleaned FITS file",
542
+ defaultextension=".fits",
543
+ filetypes=[("FITS files", "*.fits"), ("All files", "*.*")],
544
+ initialfile=suggested_name
545
+ )
546
+ try:
547
+ with fits.open(self.input_fits, mode='readonly') as hdul:
548
+ hdul[self.extension].data = self.data
549
+ if 'CRMASK' in hdul:
550
+ hdul['CRMASK'].data = self.mask_fixed.astype(np.uint8)
551
+ else:
552
+ crmask_hdu = fits.ImageHDU(self.mask_fixed.astype(np.uint8), name='CRMASK')
553
+ hdul.append(crmask_hdu)
554
+ hdul.writeto(self.output_fits, overwrite=True)
555
+ print(f"Cleaned data saved to {self.output_fits}")
556
+ except Exception as e:
557
+ print(f"Error saving FITS file: {e}")
558
+
559
+ def create_widgets(self):
560
+ # Row 1
561
+ self.button_frame1 = tk.Frame(self.root)
562
+ self.button_frame1.grid(row=0, column=0, pady=5)
563
+ self.run_lacosmic_button = tk.Button(self.button_frame1, text="Run L.A.Cosmic", command=self.run_lacosmic)
564
+ self.run_lacosmic_button.pack(side=tk.LEFT, padx=5)
565
+ self.save_button = tk.Button(self.button_frame1, text="Save cleaned FITS", command=self.save_fits_file)
566
+ self.save_button.pack(side=tk.LEFT, padx=5)
567
+
568
+ # Row 2
569
+ self.button_frame2 = tk.Frame(self.root)
570
+ self.button_frame2.grid(row=1, column=0, pady=5)
571
+ vmin, vmax = zscale(self.data)
572
+ self.vmin_button = tk.Button(self.button_frame2, text=f"vmin: {vmin:.2f}", command=self.set_vmin)
573
+ self.vmin_button.pack(side=tk.LEFT, padx=5)
574
+ self.vmax_button = tk.Button(self.button_frame2, text=f"vmax: {vmax:.2f}", command=self.set_vmax)
575
+ self.vmax_button.pack(side=tk.LEFT, padx=5)
576
+ self.stop_button = tk.Button(self.button_frame2, text="Stop program", command=self.stop_app)
577
+ self.stop_button.pack(side=tk.LEFT, padx=5)
578
+
579
+ # Main frame for figure and toolbar
580
+ self.main_frame = tk.Frame(self.root)
581
+ self.main_frame.grid(row=2, column=0, sticky="nsew")
582
+ self.root.grid_rowconfigure(2, weight=1)
583
+ self.root.grid_columnconfigure(0, weight=1)
584
+ self.main_frame.grid_rowconfigure(0, weight=1)
585
+ self.main_frame.grid_columnconfigure(0, weight=1)
586
+
587
+ # Create figure and axis
588
+ self.fig, self.ax = plt.subplots(figsize=(8, 6))
589
+ xlabel = 'X pixel (from 1 to NAXIS1)'
590
+ ylabel = 'Y pixel (from 1 to NAXIS2)'
591
+ extent = [0.5, self.data.shape[1] + 0.5, 0.5, self.data.shape[0] + 0.5]
592
+ self.image, _, _ = imshow(self.fig, self.ax, self.data, vmin=vmin, vmax=vmax,
593
+ xlabel=xlabel, ylabel=ylabel, extent=extent)
594
+ # Note: tight_layout should be called before defining the canvas
595
+ self.fig.tight_layout()
596
+
597
+ # Create canvas and toolbar
598
+ self.canvas = FigureCanvasTkAgg(self.fig, master=self.main_frame)
599
+ # The next two instructions prevent a segmentation fault when pressing "q"
600
+ self.canvas.mpl_disconnect(self.canvas.mpl_connect("key_press_event", key_press_handler))
601
+ self.canvas.mpl_connect("key_press_event", self.on_key)
602
+ self.canvas.mpl_connect("button_press_event", self.on_click)
603
+ canvas_widget = self.canvas.get_tk_widget()
604
+ canvas_widget.grid(row=0, column=0, sticky="nsew")
605
+
606
+ # Matplotlib toolbar
607
+ self.toolbar_frame = tk.Frame(self.main_frame)
608
+ self.toolbar_frame.grid(row=1, column=0, sticky="ew")
609
+ self.toolbar = NavigationToolbar2Tk(self.canvas, self.toolbar_frame)
610
+ self.toolbar.update()
611
+
612
+ def set_vmin(self):
613
+ old_vmin = self.get_vmin()
614
+ new_vmin = simpledialog.askfloat("Set vmin", "Enter new vmin:", initialvalue=old_vmin)
615
+ if new_vmin is None:
616
+ return
617
+ self.vmin_button.config(text=f"vmin: {new_vmin:.2f}")
618
+ self.image.set_clim(vmin=new_vmin)
619
+ self.canvas.draw()
620
+
621
+ def set_vmax(self):
622
+ old_vmax = self.get_vmax()
623
+ new_vmax = simpledialog.askfloat("Set vmax", "Enter new vmax:", initialvalue=old_vmax)
624
+ if new_vmax is None:
625
+ return
626
+ self.vmax_button.config(text=f"vmax: {new_vmax:.2f}")
627
+ self.image.set_clim(vmax=new_vmax)
628
+ self.canvas.draw()
629
+
630
+ def get_vmin(self):
631
+ return float(self.vmin_button.cget("text").split(":")[1])
632
+
633
+ def get_vmax(self):
634
+ return float(self.vmax_button.cget("text").split(":")[1])
635
+
636
+ def run_lacosmic(self):
637
+ self.run_lacosmic_button.config(state=tk.DISABLED)
638
+ self.stop_button.config(state=tk.DISABLED)
639
+ # Parameters for L.A.Cosmic can be adjusted as needed
640
+ _, mask_crfound = cosmicray_lacosmic(self.data, sigclip=4.5, sigfrac=0.3, objlim=5.0, verbose=True)
641
+ ReviewCosmicRay(
642
+ root=self.root,
643
+ data=self.data,
644
+ mask_fixed=self.mask_fixed,
645
+ mask_crfound=mask_crfound
646
+ )
647
+ print("L.A.Cosmic cleaning applied.")
648
+ self.run_lacosmic_button.config(state=tk.NORMAL)
649
+ self.stop_button.config(state=tk.NORMAL)
650
+
651
+ def stop_app(self):
652
+ self.root.quit()
653
+ self.root.destroy()
654
+
655
+ def on_key(self, event):
656
+ if event.key == 'q':
657
+ pass # Ignore the "q" key to prevent closing the window
658
+ else:
659
+ print(f"Key pressed: {event.key}")
660
+
661
+ def on_click(self, event):
662
+ if event.inaxes:
663
+ x, y = event.xdata, event.ydata
664
+ print(f"Clicked at image coordinates: ({x:.2f}, {y:.2f})")
665
+
666
+
667
+ def main():
668
+ parser = argparse.ArgumentParser(description="Interactive cosmic ray cleaner for FITS images.")
669
+ parser.add_argument("input_fits", help="Path to the FITS file to be cleaned.")
670
+ parser.add_argument("--extension", type=int, default=0,
671
+ help="FITS extension to use (default: 0).")
672
+ parser.add_argument("--output_fits", type=str, default=None,
673
+ help="Path to save the cleaned FITS file")
674
+ args = parser.parse_args()
675
+
676
+ if not os.path.isfile(args.input_fits):
677
+ print(f"Error: File '{args.input_fits}' does not exist.")
678
+ return
679
+ if args.output_fits is not None and os.path.isfile(args.output_fits):
680
+ print(f"Error: Output file '{args.output_fits}' already exists.")
681
+ return
682
+
683
+ # Initialize Tkinter root
684
+ root = tk.Tk()
685
+
686
+ # Create and run the application
687
+ CosmicRayCleanerApp(root, args.input_fits, args.extension, args.output_fits)
688
+
689
+ # Execute
690
+ root.mainloop()
691
+
692
+
693
+ if __name__ == "__main__":
694
+ main()