lightstack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lightstack/plot.py ADDED
@@ -0,0 +1,310 @@
1
+ import numpy as np
2
+ import os
3
+ import glob
4
+
5
+ import matplotlib.pyplot as plt
6
+
7
+ from astropy.io import fits
8
+ from astropy.wcs import WCS
9
+ from astropy.visualization import simple_norm
10
+
11
+ from .utils import find_ext, infer_filter, get_filter
12
+
13
+
14
def visualize_fits(fits_path, save_path=None, stretch='log',
                   min_percent=25., max_percent=99.98,
                   cmap='viridis', xlim=None, ylim=None):
    """
    Visualize a FITS image with both world (RA/Dec) and pixel axes.

    The extension selected by ``find_ext`` is drawn on a WCS projection built
    from its header; secondary axes on the top and right are labelled in pixel
    units. The panel title is the part of the file name before the first '_'.

    Parameters
    ----------
    fits_path : str
        Path to the FITS file.
    save_path : str or None
        Path to save the output image. If None, the figure is shown but not saved.
    stretch : str
        Stretch type for simple_norm (e.g., 'linear', 'log', 'sqrt').
    min_percent, max_percent : float
        Percentile limits for normalization, computed over the finite, strictly
        positive pixels only.
    cmap : str
        Colormap name.
    xlim, ylim : tuple or None
        Pixel axis limits.

    Notes
    -----
    Sets ``plt.rcParams['font.size']`` to 10 globally. If no data extension is
    found, a message is printed and the function returns without plotting.
    If the image has no finite positive pixel, matplotlib's default linear
    scaling is used instead of ``simple_norm``.
    """
    # memmap=False loads the data into memory, so it stays valid after close.
    with fits.open(fits_path, memmap=False) as hdul:
        ext = find_ext(hdul)
        if ext is None:
            print(f"No data extension found in {fits_path}")
            return

        data = hdul[ext].data
        header = hdul[ext].header
        wcs = WCS(header)

    # Percentiles are taken over finite, strictly positive pixels (a log
    # stretch is undefined for values <= 0).
    valid = np.isfinite(data) & (data > 0.)
    if np.any(valid):
        norm = simple_norm(
            data[valid],
            stretch=stretch,
            min_percent=min_percent,
            max_percent=max_percent)
    else:
        # Nothing to derive percentiles from; simple_norm would fail on an
        # empty array, so fall back to matplotlib's default scaling.
        norm = None

    plt.rcParams['font.size'] = 10
    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw={'projection': wcs})

    ax.imshow(
        data,
        cmap=cmap,
        origin='lower',
        norm=norm,
        interpolation='nearest')

    # World coordinates on the main axes.
    ax.set_xlabel('RA')
    ax.set_ylabel('Dec')

    # Pixel coordinates on secondary axes.
    secax = ax.secondary_xaxis('top')
    secax.set_xlabel('x (pixel)')
    secay = ax.secondary_yaxis('right')
    secay.set_ylabel('y (pixel)')

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    ax.grid()

    # Simple title from the file name, e.g. "REGION_F444W.fits" -> "REGION".
    base = os.path.basename(fits_path)
    ax.set_title(base.split('_')[0])

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        print(f"Saved image at {save_path}")
        plt.close(fig)
    else:
        plt.show()
95
+
96
+
97
def plot_datacube_filters(
        cube_fits_file,
        ncols=None,
        figsize=(15, 15),
        cmap="viridis",
        norm=None,
        stretch="log",
        min_percent=25.,
        max_percent=99.98,
        save_path=None):
    """
    Plot all filters from a datacube as a grid of images.

    Each slice along the first axis is drawn in its own panel. The panel title
    is the ``FILTER{i+1}`` header keyword, or the 0-based slice index when the
    keyword is absent. A plain 2D image is shown as a single panel.

    Parameters
    ----------
    cube_fits_file : str
        Path to the 3D FITS datacube (a 2D image is also accepted).

    ncols : int or None
        Number of columns in the grid. If None, a near-square layout is used.

    figsize : tuple
        Figure size.

    cmap : str
        Colormap.

    norm : matplotlib.colors.Normalize or None
        Custom normalization shared by all panels. If None, each panel gets its
        own ``simple_norm`` computed over its finite, strictly positive pixels;
        a panel with no such pixel uses matplotlib's default scaling.

    stretch : str
        Stretch for simple_norm (ignored if norm is provided).

    min_percent, max_percent : float
        Percentile limits for normalization (ignored if norm is provided).

    save_path : str or None
        If provided, the figure is saved there and closed; otherwise it is shown.

    Raises
    ------
    ValueError
        If no valid data extension is found, or the data is neither 2D nor 3D.
    """
    with fits.open(cube_fits_file) as hdul:
        ext = find_ext(hdul)
        if ext is None:
            raise ValueError(f"No valid data extension in {cube_fits_file}")

        datacube = hdul[ext].data
        header = hdul[ext].header

    # A single image would otherwise be iterated row by row.
    if datacube.ndim == 2:
        datacube = datacube[np.newaxis]
    elif datacube.ndim != 3:
        raise ValueError(
            f"Expected a 2D image or 3D datacube in {cube_fits_file}, "
            f"got {datacube.ndim} dimensions")

    n_filters = datacube.shape[0]
    filters = [header.get(f"FILTER{i+1}", f"{i}") for i in range(n_filters)]

    # Layout: near-square grid unless the caller fixes the column count.
    if ncols is None:
        ncols = int(np.ceil(np.sqrt(n_filters)))
    nrows = int(np.ceil(n_filters / ncols))

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    axes = np.atleast_1d(axes).flatten()

    # zip stops after n_filters panels; the grid may contain extra axes.
    for ax, data, label in zip(axes, datacube, filters):
        if norm is not None:
            norm_i = norm
        else:
            valid = np.isfinite(data) & (data > 0.)
            if np.any(valid):
                norm_i = simple_norm(
                    data[valid],
                    stretch=stretch,
                    min_percent=min_percent,
                    max_percent=max_percent)
            else:
                norm_i = None

        ax.imshow(
            data,
            origin="lower",
            cmap=cmap,
            norm=norm_i,
            interpolation="nearest")

        ax.set_title(f"{label}")
        ax.axis("off")

    # Hide unused panels at the end of the grid.
    for ax in axes[n_filters:]:
        ax.axis("off")

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        print(f"Saved image at {save_path}")
        plt.close(fig)
    else:
        plt.show()
204
+
205
def plot_psf_grid(
        psf_dir=None,
        psf_files=None,
        ncols=None,
        figsize=(12, 12),
        norm=None,
        stretch="log",
        percent=99.0,
        cmap="viridis",
        save_path=None):
    """
    Plot a grid of PSFs from FITS files, one panel (with colorbar) per file.

    Parameters
    ----------
    psf_dir : str, optional
        Directory containing PSF FITS files; all ``*.fits`` files in it are
        plotted in sorted order.

    psf_files : list, optional
        List of PSF FITS file paths (overrides psf_dir).

    ncols : int or None
        Number of columns in the grid. If None, chosen automatically
        (near-square layout).

    figsize : tuple
        Figure size.

    norm : astropy norm, optional
        Custom normalization shared by all panels.

    stretch : str
        Stretch for simple_norm (if norm is None).

    percent : float
        Percentile for normalization (if norm is None).

    cmap : str
        Colormap.

    save_path : str, optional
        Path to save the figure. When given, the figure is saved and closed
        (not shown), matching the other plotting helpers; otherwise it is shown.

    Raises
    ------
    ValueError
        If neither psf_dir nor psf_files is given, no files are found, or a
        file has no valid image extension.

    Notes
    -----
    Panel titles come from ``get_filter(psf_file)``; if that fails for any
    reason, the file name without ".fits" is used instead.
    """
    if psf_files is None:
        if psf_dir is None:
            raise ValueError("Provide either 'psf_dir' or 'psf_files'")
        psf_files = sorted(glob.glob(os.path.join(psf_dir, "*.fits")))

    n_psf = len(psf_files)
    if n_psf == 0:
        raise ValueError("No PSF files found")

    if ncols is None:
        ncols = int(np.ceil(np.sqrt(n_psf)))
    nrows = int(np.ceil(n_psf / ncols))

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    axes = np.atleast_1d(axes).flatten()

    for ax, psf_file in zip(axes, psf_files):
        with fits.open(psf_file) as hdul:
            ext = find_ext(hdul)
            if ext is None:
                raise ValueError(f"No valid image data in '{psf_file}'")
            psf_data = hdul[ext].data

        # Title is best effort: fall back to the bare file name.
        try:
            filt_name = get_filter(psf_file)
        except Exception:
            filt_name = os.path.basename(psf_file).replace(".fits", "")

        if norm is None:
            norm_used = simple_norm(psf_data, stretch=stretch, percent=percent)
        else:
            norm_used = norm

        im = ax.imshow(psf_data, norm=norm_used, origin="lower", cmap=cmap)

        ax.set_title(filt_name, fontsize=10)
        ax.set_xlabel("x (pixel)")
        ax.set_ylabel("y (pixel)")

        # Three ticks per axis: first, central and last pixel.
        ny, nx = psf_data.shape
        ax.set_xticks([0, nx // 2, nx - 1])
        ax.set_yticks([0, ny // 2, ny - 1])

        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Remove empty panels
    for ax in axes[n_psf:]:
        ax.axis("off")

    plt.tight_layout()

    # Save-and-close or show, so figures are not left open when saving in
    # scripts or non-interactive backends.
    if save_path is not None:
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        print(f"Saved image at {save_path}")
        plt.close(fig)
    else:
        plt.show()
310
+
lightstack/psf.py ADDED
@@ -0,0 +1,336 @@
1
+ import numpy as np
2
+ import os
3
+
4
+ from astropy.io import fits
5
+ from scipy.fft import fft2, ifft2, fftshift, ifftshift
6
+ from astropy.convolution import convolve_fft
7
+ from photutils.centroids import centroid_com
8
+ from scipy.ndimage import shift, zoom
9
+
10
+
11
def centroid_weighted(psf, threshold=1e-4):
    """
    Compute a flux-weighted centroid of a PSF, ignoring low-level noise.

    Pixels below ``threshold`` times the PSF peak are zeroed on a copy of the
    input, then the centre of mass is computed with
    ``photutils.centroids.centroid_com``.

    Parameters
    ----------
    psf : 2D ndarray
        Input PSF image. It is not modified.
    threshold : float, optional
        Fraction of the maximum PSF value below which pixels are ignored.
        Default is 1e-4.

    Returns
    -------
    xcen, ycen : float
        Coordinates of the centroid in pixel units, in the ``(x, y)`` order
        returned by ``centroid_com``.
    """
    psf_copy = np.array(psf, dtype=float)

    # nanmax: a single NaN pixel would otherwise make the peak NaN and
    # silently disable the threshold (every comparison would be False).
    peak = np.nanmax(psf_copy)
    psf_copy[psf_copy < threshold * peak] = 0

    return centroid_com(psf_copy)
34
+
35
+
36
def make_odd(psf):
    """
    Zero-pad a 2D PSF so that both of its dimensions are odd.

    An even-length axis gets exactly one row/column of zeros appended at its
    high-index end; an odd-length axis is left as is. ``np.pad`` is always
    called, so the result is a new array even when no padding is needed.

    Parameters
    ----------
    psf : 2D ndarray
        Input PSF image.

    Returns
    -------
    psf_padded : 2D ndarray
        PSF with odd dimensions.
    """
    rows, cols = psf.shape

    # 1 for an even length, 0 for an odd one; nothing is ever added in front.
    extra_rows = 1 - rows % 2
    extra_cols = 1 - cols % 2

    return np.pad(
        psf,
        ((0, extra_rows), (0, extra_cols)),
        mode='constant',
        constant_values=0)
69
+
70
+
71
def resample_psf(
        input_path,
        output_path,
        zoom_factor=None,
        psf_pixel_scale=None,
        target_pixel_scale=None,
        order=3,
        normalize=True,
        make_odd_shape=True,
        header_comment=True):
    """
    Resample a PSF to match a target pixel scale.

    The zoom factor is either given directly or computed as
    ``psf_pixel_scale / target_pixel_scale``. Always downsample to the worst
    resolution: this code has not been tested for upsampling, so that is not
    recommended.

    Parameters
    ----------
    input_path : str
        Path to input PSF FITS file.
    output_path : str
        Path to save the resampled PSF (overwritten if it exists).
    zoom_factor : float, optional
        Zoom factor to apply. If None, it will be computed from pixel scales.
    psf_pixel_scale : float, optional
        Original PSF pixel scale (arcsec/pixel).
    target_pixel_scale : float, optional
        Target pixel scale (arcsec/pixel). When given, CDELT1 and CDELT2 of
        the output header are set to it (converted to deg/pixel).
    order : int, optional
        Interpolation order for scipy.ndimage.zoom. Default is 3 (cubic).
    normalize : bool, optional
        If True, normalize PSF to unit sum (skipped if the sum is 0).
        Default is True.
    make_odd_shape : bool, optional
        If True, pad PSF to have odd dimensions. Default is True.
    header_comment : bool, optional
        If True, add history information to FITS header.

    Returns
    -------
    psf_resampled : 2D ndarray
        Resampled PSF array.

    Raises
    ------
    ValueError
        If no image extension is found, the zoom factor cannot be determined,
        or it is not strictly positive.
    """
    # psf.py does not import find_ext at module level; without this the
    # function always failed with NameError.
    from .utils import find_ext

    with fits.open(input_path) as hdul:
        ext = find_ext(hdul)
        if ext is None:
            raise ValueError(f"No valid image extension found in {input_path}")

        psf = hdul[ext].data.astype(float)
        header = hdul[ext].header.copy()

    if zoom_factor is None:
        if psf_pixel_scale is None or target_pixel_scale is None:
            raise ValueError("Provide either zoom_factor OR both psf_pixel_scale and target_pixel_scale")
        zoom_factor = psf_pixel_scale / target_pixel_scale

    # Accept a scalar or a per-axis sequence (as scipy.ndimage.zoom does).
    if not np.all(np.asarray(zoom_factor, dtype=float) > 0):
        raise ValueError(f"zoom_factor must be strictly positive, got {zoom_factor}")

    psf_resampled = zoom(psf, zoom_factor, order=order)

    if make_odd_shape:
        psf_resampled = make_odd(psf_resampled)

    if normalize:
        total = psf_resampled.sum()
        if total != 0:
            psf_resampled /= total

    if header_comment:
        header["HISTORY"] = "PSF resampled using scipy.ndimage.zoom"

    # The zoom is isotropic, so both axes get the new pixel scale.
    if target_pixel_scale is not None:
        header["CDELT1"] = (target_pixel_scale / 3600, "deg/pix")
        header["CDELT2"] = (target_pixel_scale / 3600, "deg/pix")

    fits.PrimaryHDU(psf_resampled, header=header).writeto(
        output_path, overwrite=True)

    return psf_resampled
152
+
153
+
154
+
155
def build_kernel(psf_source, psf_ref, shape=(101, 101), eps=1e-3):
    """
    Build a convolution kernel that transforms psf_source into psf_ref
    using fast Fourier transforms.

    Both PSFs are centred on an (Ny, Nx) zero canvas. The kernel is the
    Wiener-like ratio ``F_ref * conj(F_src) / (|F_src|**2 + eps)``, transformed
    back to image space and normalized to unit sum.

    Parameters
    ----------
    psf_source : 2D array
        PSF of the original image.
    psf_ref : 2D array
        PSF of the reference resolution.
    shape : tuple
        Shape (Ny, Nx) of the kernel. Each PSF must fit inside it.
    eps : float
        Regularization parameter added to ``|F_src|**2``.

    Returns
    -------
    kernel : 2D array
        Convolution kernel of shape ``shape``, summing to 1.

    Raises
    ------
    ValueError
        If a PSF is not 2D or is larger than ``shape``, or if the resulting
        kernel sums to zero (e.g. a PSF with zero total flux), which would
        otherwise yield an all-NaN kernel.
    """
    Ny, Nx = shape
    psf_source = np.asarray(psf_source, dtype=float)
    psf_ref = np.asarray(psf_ref, dtype=float)

    for name, psf in (("psf_source", psf_source), ("psf_ref", psf_ref)):
        if psf.ndim != 2:
            raise ValueError(f"{name} must be a 2D array, got shape {psf.shape}")
        if psf.shape[0] > Ny or psf.shape[1] > Nx:
            raise ValueError(
                f"{name} of shape {psf.shape} does not fit in kernel shape {shape}")

    def _centered_canvas(psf):
        # Place psf so its central pixel lands on (Ny//2, Nx//2), then move
        # that pixel to index (0, 0) as the FFT convention expects.
        canvas = np.zeros((Ny, Nx))
        py, px = psf.shape
        y0 = Ny // 2 - py // 2
        x0 = Nx // 2 - px // 2
        canvas[y0:y0 + py, x0:x0 + px] = psf
        return ifftshift(canvas)

    F_s = fft2(_centered_canvas(psf_source))
    F_t = fft2(_centered_canvas(psf_ref))

    # Wiener-like deconvolution of the source PSF; eps regularizes
    # frequencies where the source transfer function is close to zero.
    F_kernel = F_t * np.conj(F_s) / (np.abs(F_s) ** 2 + eps)

    kernel = fftshift(np.real(ifft2(F_kernel)))

    total = np.sum(kernel)
    if not np.isfinite(total) or total == 0:
        raise ValueError(
            "Kernel sums to zero; check that both PSFs have non-zero total flux")
    kernel /= total

    return kernel
210
+
211
def save_kernel(kernel, output_path, header=None):
    """
    Write a convolution kernel to a FITS file as its primary HDU.

    An existing file at ``output_path`` is overwritten, and a confirmation
    message is printed afterwards.

    Parameters
    ----------
    kernel : 2D array
        Convolution kernel.
    output_path : str
        Output FITS path.
    header : fits.Header or None
        Optional header to attach to the kernel.
    """
    primary = fits.PrimaryHDU(kernel, header=header)
    primary.writeto(output_path, overwrite=True)
    print(f"Kernel saved at {output_path}")
227
+
228
+
229
def apply_kernel(image, kernel):
    """
    Convolve an image with a given kernel via ``astropy.convolution.convolve_fft``.

    The kernel is used exactly as provided (``normalize_kernel=False``); it
    should already be normalized if flux must be conserved. ``allow_huge=True``
    is passed so large images are not rejected by astropy's size check.

    Parameters
    ----------
    image : 2D array
        Image to convolve.
    kernel : 2D array
        Convolution kernel.

    Returns
    -------
    image_conv : 2D array
        Convolved image.
    """
    image_conv = convolve_fft(
        image,
        kernel,
        normalize_kernel=False,
        allow_huge=True)
    return image_conv
243
+
244
+
245
def _psfmatched_output_path(cube_path):
    """Insert '_psfmatched' before the file's FITS suffix (directory untouched)."""
    directory, filename = os.path.split(cube_path)
    stem, suffix = os.path.splitext(filename)
    # Keep compound suffixes such as ".fits.gz" together.
    if suffix.lower() in (".gz", ".bz2", ".fz", ".zip"):
        stem, inner = os.path.splitext(stem)
        suffix = inner + suffix
    return os.path.join(directory, f"{stem}_psfmatched{suffix}")


def psf_match_datacube(
        cube_path,
        kernel_dir,
        ref_filter="F444W",
        output_path=None,
        overwrite=True):
    """
    Apply PSF matching to a datacube using precomputed convolution kernels.

    Each slice is convolved to match the PSF of a reference filter. The
    slice's filter name is read from the ``FILTER{i+1}`` header keyword, and
    its kernel from ``kernel_dir/kernel_<filt>_to_<ref>.fits`` (lower case).
    Slices are copied unchanged when they have no FILTER keyword, when they
    are the reference filter, or when no kernel file exists.

    Parameters
    ----------
    cube_path : str
        Path to input datacube FITS. The primary HDU is used when it holds
        data; otherwise the extension returned by ``find_ext``.

    kernel_dir : str
        Directory containing kernel FITS files.

    ref_filter : str
        Reference filter (e.g., "F444W"). Compared case-insensitively.

    output_path : str or None
        Output FITS file. If None, '_psfmatched' is inserted before the file
        suffix (e.g. "cube.fits" -> "cube_psfmatched.fits"), so the input file
        is never chosen as the output.

    overwrite : bool
        Overwrite output file.

    Returns
    -------
    output_path : str
        Path to saved PSF-matched datacube.

    Raises
    ------
    ValueError
        If no data is found in the file or the data is not 3D.
    """
    if output_path is None:
        output_path = _psfmatched_output_path(cube_path)

    with fits.open(cube_path) as hdul:
        ext = 0
        if hdul[0].data is None:
            # Cube stored in an extension (empty primary HDU).
            from .utils import find_ext
            ext = find_ext(hdul)
            if ext is None:
                raise ValueError(f"No valid data extension in {cube_path}")
        cube = hdul[ext].data
        header = hdul[ext].header.copy()

    if cube.ndim != 3:
        raise ValueError(
            f"Expected a 3D datacube in {cube_path}, got shape {cube.shape}")

    nfilters = cube.shape[0]
    # Promote integer cubes to float so convolved values are not truncated;
    # float32/float64 cubes keep their precision.
    cube_conv = np.zeros(cube.shape, dtype=np.result_type(cube.dtype, np.float32))
    ref_key = ref_filter.strip().lower()

    for i in range(nfilters):

        filt = header.get(f"FILTER{i+1}")
        if filt is None:
            print(f"No FILTER{i+1} keyword, skipping.")
            cube_conv[i] = cube[i]
            continue

        filt = str(filt).strip()
        filt_lower = filt.lower()

        print(f"Processing {i+1}/{nfilters}: {filt}")

        # The reference filter is already at the target resolution.
        if filt_lower == ref_key:
            cube_conv[i] = cube[i]
            continue

        kernel_path = os.path.join(
            kernel_dir,
            f"kernel_{filt_lower}_to_{ref_key}.fits")

        if not os.path.exists(kernel_path):
            print(f"Kernel not found for {filt}, skipping.")
            cube_conv[i] = cube[i]
            continue

        kernel = fits.getdata(kernel_path)
        cube_conv[i] = apply_kernel(cube[i], kernel)

    header.add_history(
        f"PSF matched to {ref_filter} using FFT convolution kernels")

    fits.PrimaryHDU(cube_conv, header=header).writeto(
        output_path,
        overwrite=overwrite)

    print(f"PSF-matched datacube saved at {output_path}")

    return output_path