visualastro 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
visualastro/io.py ADDED
@@ -0,0 +1,251 @@
1
+ from astropy.io import fits
2
+ import matplotlib as mpl
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from .numerical_utils import check_is_array
7
+ from .visual_classes import FitsFile
8
+
9
+
10
+ # Fits File I/O Operations
11
+ # ––––––––––––––––––––––––
12
def load_fits(filepath, header=True, error=True,
              print_info=True, transpose=False, dtype=None):
    '''
    Load a FITS file and return its data, optionally with the
    header and the uncertainty array.
    Parameters
    ––––––––––
    filepath : str
        Path to the FITS file to load.
    header : bool, default=True
        If True, read the FITS header and return it along with
        the data inside a FitsFile object.
    error : bool, default=True
        If True, look for an error (or variance) extension with
        'get_errors' and return it inside the FitsFile object.
        If False, no error extension is read.
    print_info : bool, default=True
        If True, print HDU information using 'hdul.info()'.
    transpose : bool, default=False
        If True, transpose the data (and error) array before returning.
    dtype : np.dtype, default=None
        Data type to convert the FITS data to. If None,
        the dtype is determined from the data by 'get_dtype'.
    Returns
    –––––––
    FitsFile
        If 'header' or 'error' is True, an object built from:
        - data: 'np.ndarray' of the FITS data
        - header: 'astropy.io.fits.Header' if 'header=True', else 'None'
        - errors: error array if 'error=True' and one was found, else 'None'
    data : np.ndarray
        If both 'header' and 'error' are False, just the data array.
    '''
    # Everything that touches 'hdul' stays inside the 'with' block so the
    # extensions are read while the file is still open.
    with fits.open(filepath) as hdul:
        if print_info:
            hdul.info()

        # getdata returns (data, header) when header=True, else only the data
        result = fits.getdata(filepath, header=header)
        data, fits_header = result if isinstance(result, tuple) else (result, None)

        dt = get_dtype(data, dtype)
        data = data.astype(dt, copy=False)
        if transpose:
            data = data.T

        # only search for uncertainties when the caller asked for them;
        # the errors are cast to the same dtype as the data
        errors = get_errors(hdul, dt, transpose) if error else None

    if header or error:
        return FitsFile(data, fits_header, errors)
    return data
62
+
63
+
64
def get_dtype(data, dtype=None, default_dtype=np.float64):
    '''
    Choose the NumPy dtype to use for an array.
    Parameters
    ––––––––––
    data : array-like
        Array whose dtype is inspected when 'dtype' is None.
    dtype : data-type, optional, default=None
        Explicit dtype. When given it is returned as-is
        (wrapped in 'np.dtype') and 'data' is not inspected.
    default_dtype : data-type, optional, default=np.float64
        Dtype returned when 'data' is not of a floating type
        (e.g. integer or unsigned).
    Returns
    –––––––
    dtype : np.dtype
        The user dtype if supplied, else the floating dtype of
        'data', else 'default_dtype'.
    '''
    if dtype is None:
        array = check_is_array(data)
        # keep floating dtypes untouched, promote everything else
        is_float = np.issubdtype(array.dtype, np.floating)
        chosen = array.dtype if is_float else default_dtype
        return np.dtype(chosen)

    # an explicit user choice always wins
    return np.dtype(dtype)
95
+
96
+
97
def get_errors(hdul, dtype=None, transpose=False):
    '''
    Return the error array from an HDUList, falling back to the
    square root of the variance if no explicit error extension exists.
    Parameters
    ––––––––––
    hdul : astropy.io.fits.HDUList
        The HDUList whose extensions (every HDU after the first) are
        searched by 'EXTNAME' for an error or variance array.
    dtype : data-type, optional, default=None
        The desired NumPy dtype of the returned error array. If None,
        the dtype is chosen from the extension data by 'get_dtype'.
    transpose : bool, optional, default=False
        If True, return the transposed error array.
    Returns
    –––––––
    errors : np.ndarray or None
        The error array if found, or None if no suitable extension
        (with data) is present.
    '''
    error_names = {'ERR', 'ERROR', 'UNCERT'}
    variance_names = {'VAR', 'VARIANCE', 'VAR_POISSON', 'VAR_RNOISE'}

    def _first_extension_data(names):
        # data of the first extension whose EXTNAME is in 'names';
        # extensions without data are skipped instead of crashing on None
        for hdu in hdul[1:]:
            extname = str(hdu.header.get('EXTNAME', '')).strip().upper()
            if extname in names and hdu.data is not None:
                return hdu.data
        return None

    errors = None
    raw = _first_extension_data(error_names)
    if raw is not None:
        errors = raw.astype(get_dtype(raw, dtype), copy=False)
    else:
        # fallback to variance if no explicit errors
        raw = _first_extension_data(variance_names)
        if raw is not None:
            # cast with the resolved dtype (not the raw 'dtype' argument,
            # which may be None) so float32 variances stay float32
            errors = np.sqrt(raw.astype(get_dtype(raw, dtype), copy=False))

    if transpose and errors is not None:
        errors = errors.T

    return errors
131
+
132
+
133
def write_cube_2_fits(cube, filename, overwrite=False):
    '''
    Save every frame of a 3D data cube as its own FITS file.
    Parameters
    ––––––––––
    cube : ndarray (N_frames, N, M)
        Data cube holding N_frames images of shape (N, M).
    filename : str
        Base filename (without extension). Frame 'i' is
        written to "{filename}_i.fits".
    overwrite : bool, optional, default=False
        Passed to 'fits.writeto'; if True, existing files
        with the same name are overwritten.
    Notes
    –––––
    Prints a message with the number of frames and the
    base filename, and shows a tqdm progress bar.
    '''
    # unpacking three values also enforces that the cube is 3D
    n_frames, _n_rows, _n_cols = cube.shape
    print(f"Writing {n_frames} fits files to {filename}_i.fits")
    for index in tqdm(range(n_frames)):
        fits.writeto(filename + f"_{index}.fits", cube[index], overwrite=overwrite)
156
+
157
+
158
+ # Figure I/O Operations
159
+ # –––––––––––––––––––––
160
def get_kwargs(kwargs, *names, default=None):
    '''
    Look up a keyword argument that may be given under several aliases.
    Parameters
    ––––––––––
    kwargs : dict
        Dictionary of keyword arguments, typically taken from ``**kwargs``.
    *names : str
        Candidate keyword names, in priority order. The first one that
        is present in ``kwargs`` with a non-None value wins.
    default : any, optional, default=None
        Value returned when no candidate name yields a non-None value.
    Returns
    –––––––
    value : any
        The value of the first matching keyword argument, or `default`
        if none match.
    '''
    # lazily walk the aliases that are present, in the order given
    present = (kwargs[alias] for alias in names if alias in kwargs)
    return next((value for value in present if value is not None), default)
183
+
184
+
185
def save_figure_2_disk(dpi=600, pdf_compression=6, transparent=False, bbox_inches='tight', **kwargs):
    '''
    Saves current figure to disk as a
    eps, pdf, png, or svg, and prompts
    user for a filename and format.
    Parameters
    ––––––––––
    dpi : float or int, optional, default=600
        Resolution in dots per inch.
    pdf_compression : int, optional, default=6
        'pdf.compression' value for matplotlib.rcParams.
        Accepts integers from 0-9, with 0 meaning no
        compression. Only applied when saving as pdf.
    transparent : bool, optional, default=False
        If True, the Axes patches will all be transparent;
        the Figure patch will also be transparent unless
        facecolor and/or edgecolor are specified via kwargs.
    bbox_inches : str or Bbox, default='tight'
        Bounding box in inches: only the given portion of the
        figure is saved. If 'tight', try to figure out the
        tight bbox of the figure.

    **kwargs : dict, optional
        Additional parameters.

        Supported keyword arguments include:

        - `facecolor`, `fc` : str, default='auto'
            The facecolor of the figure. If 'auto',
            use the current figure facecolor.
        - `edgecolor`, `ec` : str, default='auto'
            The edgecolor of the figure. If 'auto',
            use the current figure edgecolor.
    '''
    # –––– KWARGS ––––
    facecolor = get_kwargs(kwargs, 'facecolor', 'fc', default='auto')
    edgecolor = get_kwargs(kwargs, 'edgecolor', 'ec', default='auto')
    # a tuple keeps the order shown in the prompt deterministic
    allowed_formats = ('eps', 'pdf', 'png', 'svg')

    # prompt user for filename, and split off a possible extension
    filename = input("Input filename for image (ex: myimage.pdf): ").strip()
    _, dot, suffix = filename.rpartition('.')

    if dot and suffix.lower() in allowed_formats:
        # the filename is kept exactly as typed; the format is normalised to
        # lower case so 'figure.PDF' still takes the pdf-compression branch
        extension = suffix.lower()
    else:
        # no recognised extension: prompt until a valid format is given
        extension = ""
        while extension not in allowed_formats:
            extension = (
                input(f"Please choose a format from ({', '.join(allowed_formats)}): ")
                .strip()
                .lower()
            )
        # keep the whole typed name (e.g. 'fig_v1.2' -> 'fig_v1.2.png');
        # a dangling trailing dot is dropped to avoid 'name..png'
        filename = f"{filename.rstrip('.')}.{extension}"

    rc = {'pdf.compression': int(pdf_compression)} if extension == 'pdf' else {}
    with plt.rc_context(rc=rc):
        # save figure
        plt.savefig(
            fname=filename,
            format=extension,
            transparent=transparent,
            bbox_inches=bbox_inches,
            facecolor=facecolor,
            edgecolor=edgecolor,
            dpi=dpi
        )
@@ -0,0 +1,285 @@
1
+ import warnings
2
+ from astropy.io.fits import Header
3
+ from astropy import units as u
4
+ from astropy.units import Quantity, spectral, Unit, UnitConversionError
5
+ import numpy as np
6
+ from scipy.interpolate import interp1d, CubicSpline
7
+ from spectral_cube import SpectralCube
8
+ from .visual_classes import DataCube, ExtractedSpectrum, FitsFile
9
+
10
+
11
+ # Type Checking Arrays and Objects
12
+ # ––––––––––––––––––––––––––––––––
13
def check_is_array(data, keep_units=False):
    '''
    Coerce the input into a plain np.ndarray.
    Parameters
    ––––––––––
    data : np.ndarray, DataCube, FitsFile, or Quantity
        Array, or a container object holding one.
    keep_units : bool, optional, default=False
        If True and the (unwrapped) input is a Quantity,
        return it unchanged with its astropy units attached.
    Returns
    –––––––
    data : np.ndarray
        The array, the 'value' attribute of a DataCube, or the 'data'
        attribute of a FitsFile. A Quantity is returned instead when
        'keep_units' is True.
    '''
    # unwrap the project containers first
    if isinstance(data, DataCube):
        payload = data.value
    elif isinstance(data, FitsFile):
        payload = data.data
    else:
        payload = data

    if not isinstance(payload, Quantity):
        return np.asarray(payload)

    # Quantity: either hand it back untouched or strip the units
    return payload if keep_units else np.asarray(payload.value)
38
+
39
+
40
def check_units_consistency(datas):
    '''
    Check that all input objects have the same units and warn if they differ.
    Additionally ensure that the input is iterable by wrapping in a list.
    Parameters
    ––––––––––
    datas : object or list/tuple of objects
        Objects to check. Can be Quantity, SpectralCube, DataCube, etc.
        Units are looked up with 'get_units'.
    Returns
    –––––––
    datas : list or tuple
        The input sequence unchanged if it was already a list or tuple,
        otherwise the single input object wrapped in a list.
    '''
    datas = datas if isinstance(datas, (list, tuple)) else [datas]
    # nothing to compare in an empty sequence
    if not datas:
        return datas

    first_unit = get_units(datas[0])
    for i, obj in enumerate(datas[1:], start=1):
        unit = get_units(obj)
        if unit != first_unit:
            warnings.warn(
                f"\nInput at index {i} has unit `{unit}`, which differs from "
                f"unit `{first_unit}` at index 0.",
                stacklevel=2,
            )

    return datas
65
+
66
+
67
def get_data(obj):
    '''
    Return the 'data' attribute of a DataCube or FitsFile.
    Parameters
    ––––––––––
    obj : DataCube or FitsFile or np.ndarray
        Object to unwrap. Anything that is not a DataCube or
        FitsFile (e.g. a raw array) is passed through untouched.
    Returns
    –––––––
    np.ndarray, or data extension
        'obj.data' for a DataCube or FitsFile, otherwise 'obj' itself.
    '''
    if isinstance(obj, (DataCube, FitsFile)):
        return obj.data

    return obj
87
+
88
+
89
def get_units(obj):
    '''
    Extract the unit from an object, if it exists.
    Parameters
    ––––––––––
    obj : Quantity, SpectralCube, FITS-like object, or any
        The input object from which to extract a unit. This can be:
        - an astropy.units.Quantity
        - a SpectralCube
        - a DataCube or FitsFile
        - an ExtractedSpectrum
        - a FITS-like object with a header containing a 'BUNIT' keyword
        - any other object (returns None if no unit is found)
    Returns
    –––––––
    astropy.units.Unit or None
        The unit associated with the input object, if it exists.
        Returns None if the object has no unit or if the unit cannot be parsed.
    '''
    # unwrap DataCube / FitsFile to their data component
    data = get_data(obj)
    # check if unit extension exists
    if isinstance(data, (DataCube, FitsFile, Quantity, SpectralCube)):
        return data.unit

    if isinstance(obj, ExtractedSpectrum):
        # best effort: prefer the spectrum1d unit, then the flux unit.
        # 'except Exception' (rather than a bare except) keeps the lookup
        # forgiving without swallowing KeyboardInterrupt / SystemExit.
        for attribute in ('spectrum1d', 'flux'):
            try:
                return getattr(obj, attribute).unit
            except Exception:
                continue
        return None

    # try to extract unit from header
    # use either header extension or obj if obj is a header
    header = getattr(obj, 'header', obj if isinstance(obj, Header) else None)
    if isinstance(header, Header) and 'BUNIT' in header:
        try:
            return Unit(header['BUNIT'])
        except Exception:
            # unparsable BUNIT string: treat as "no unit"
            return None

    return None
131
+
132
+
133
def return_array_values(array):
    '''
    Strip the units from an 'astropy.units.Quantity'; pass
    anything else through untouched.
    Parameters
    ––––––––––
    array : astropy.units.Quantity or array-like
        The input array.
    Returns
    –––––––
    np.ndarray or array-like
        'array.value' if the input is a Quantity, otherwise the
        input object itself.
    '''
    if isinstance(array, Quantity):
        return array.value

    return array
151
+
152
+
153
+ # Science Operation Functions
154
+ # –––––––––––––––––––––––––––
155
def convert_units(quantity, unit):
    '''
    Convert an Astropy Quantity to a specified unit, with a fallback if conversion fails.
    Parameters
    ––––––––––
    quantity : astropy.units.Quantity
        The input quantity to convert.
    unit : str, astropy.units.Unit, or None
        The unit to convert to. If None, no conversion is performed.
    Returns
    –––––––
    astropy.units.Quantity
        The quantity converted to the requested unit if possible; otherwise,
        the original quantity with its existing unit.
    Notes
    –––––
    - Uses 'spectral()' equivalencies to allow conversions between
      wavelength, frequency, and velocity units.
    - If the conversion raises 'UnitConversionError', prints a message and
      returns the original quantity.
    '''
    if unit is None:
        return quantity

    # convert string unit to Unit if necessary
    target_unit = Unit(unit) if isinstance(unit, str) else unit
    try:
        return quantity.to(target_unit, equivalencies=spectral())
    except UnitConversionError:
        # note the space separating the two sentences of the message
        print(
            f'Could not convert to unit: {unit}. '
            f'Defaulting to unit: {quantity.unit}.'
        )
        return quantity
187
+
188
+
189
def shift_by_radial_vel(spectral_axis, radial_vel):
    '''
    Shift spectral axis to rest frame using a radial velocity.
    Parameters
    ––––––––––
    spectral_axis : astropy.units.Quantity
        The spectral axis to shift. Can be in frequency or wavelength units.
        The input is not modified.
    radial_vel : float, astropy.units.Quantity or None
        Radial velocity in km/s (astropy units are optional). Positive values
        correspond to a redshift (moving away). If None, no shift is applied.
    Returns
    –––––––
    shifted_axis : astropy.units.Quantity
        A new spectral axis shifted to the rest frame according to the given
        radial velocity. If the input is in frequency units, the axis is
        divided by (1 - v/c); otherwise it is divided by (1 + v/c).
        If 'radial_vel' is None the input object itself is returned.
    '''
    if radial_vel is None:
        return spectral_axis

    # speed of light in km/s in vacuum
    c = 299792.458 # [km/s]
    if isinstance(radial_vel, Quantity):
        radial_vel = radial_vel.to(u.km/u.s).value # type: ignore

    # if spectral axis in units of frequency
    if spectral_axis.unit.is_equivalent(u.Unit('Hz')):
        doppler_factor = 1 - radial_vel / c
    # if spectral axis in units of wavelength
    else:
        doppler_factor = 1 + radial_vel / c

    # out-of-place division: an in-place '/=' would silently modify the
    # caller's array, so calling this twice would shift the axis twice
    return spectral_axis / doppler_factor
220
+
221
+
222
+ # Numerical Operation Functions
223
+ # –––––––––––––––––––––––––––––
224
def interpolate_arrays(xp, yp, x_range, N_samples, method='linear'):
    '''
    Resample 1D data onto an evenly spaced grid by interpolation.
    Parameters
    ––––––––––
    xp : array-like
        The x-coordinates of the data points.
    yp : array-like
        The y-coordinates of the data points.
    x_range : tuple of float
        The (min, max) range of the new grid.
    N_samples : int
        Number of points in the new grid.
    method : str, default='linear'
        Interpolation method. Options:
        - 'linear' : linear interpolation using 'interp1d'
        - 'cubic' : cubic interpolation using 'interp1d'
        - 'cubic_spline' : cubic spline interpolation using 'CubicSpline'
        Any other value falls back to 'linear'.
    Returns
    –––––––
    x_interp : np.ndarray
        The evenly spaced x-coordinates over the specified range.
    y_interp : np.ndarray
        The interpolated y-values corresponding to 'x_interp'.
    '''
    # evenly spaced grid on which the data are resampled
    grid = np.linspace(x_range[0], x_range[1], N_samples)

    # build the interpolating function for the requested method
    if method == 'cubic_spline':
        interpolator = CubicSpline(xp, yp)
    elif method in ('linear', 'cubic'):
        interpolator = interp1d(xp, yp, kind=method)
    else:
        # unknown method: fall back to linear
        interpolator = interp1d(xp, yp, kind='linear')

    return grid, interpolator(grid)
262
+
263
+
264
def mask_within_range(x, xlim=None):
    '''
    Return a boolean mask for values of x within the given limits.
    Parameters
    ––––––––––
    x : array-like
        Data array (e.g., wavelength or flux values). Astropy units,
        if present, are stripped before comparing.
    xlim : tuple or list, optional
        (xmin, xmax) range. If None, uses the min/max of x
        (ignoring NaNs).
    Returns
    –––––––
    mask : ndarray of bool
        True where xmin <= x <= xmax. The limits are inclusive, so with
        'xlim=None' every non-NaN value of x is selected.
    '''
    # np.asarray lets plain lists/tuples be compared element-wise
    x = np.asarray(return_array_values(x))

    if xlim is None:
        xmin, xmax = np.nanmin(x), np.nanmax(x)
    else:
        xlim = return_array_values(xlim)
        xmin, xmax = xlim[0], xlim[1]

    # inclusive bounds: strict inequalities would drop the extreme samples
    # whenever the limits default to the min/max of x
    mask = (x >= xmin) & (x <= xmax)

    return mask
@@ -0,0 +1,40 @@
1
+ import numpy as np
2
+ from scipy.ndimage import center_of_mass
3
+ from tqdm import tqdm
4
+ from .numerical_utils import check_is_array
5
+ from .visual_plots import va
6
+
7
def compute_flux(cube, target_pixel_loc, star_radius, sky_radii=None, window_half_width=100, plot=False):
    '''
    Measure the flux of a target in every frame of a cube using a
    circular aperture, with an optional sky annulus.
    Parameters
    ––––––––––
    cube : array-like or DataCube
        Stack of images; converted with 'check_is_array' and iterated
        over its first axis.
    target_pixel_loc : tuple of int
        Initial guess of the target location, given as indices along
        (axis 0, axis 1) of each frame. The same guess is used for
        every frame.
    star_radius : float
        Radius in pixels of the target aperture.
    sky_radii : tuple of float, optional, default=None
        (inner, outer) radii in pixels of the sky annulus. If None,
        no sky value is measured or subtracted.
    window_half_width : int, optional, default=100
        Half width in pixels of the sub-image cut out around the
        target guess. The window is clipped at the lower frame edges.
    plot : bool, optional, default=False
        If True, show each sub-image with the apertures drawn
        using 'va.imshow'.
    Returns
    –––––––
    star_flux : np.ndarray
        Sum of the pixels inside the target aperture for each frame,
        minus the sky value of that frame if 'sky_radii' is given.
    sky_flux : np.ndarray
        Median pixel value inside the sky annulus for each frame
        (zeros if 'sky_radii' is None).
    '''
    cube = check_is_array(cube)
    # initial target location guess
    x_pixel, y_pixel = target_pixel_loc
    # window around the target (the same for every frame); the lower bounds
    # are clamped at 0 because a negative slice start would wrap around to
    # the far edge of the frame and give a wrong or empty sub-image
    xmin, xmax = max(x_pixel - window_half_width, 0), x_pixel + window_half_width
    ymin, ymax = max(y_pixel - window_half_width, 0), y_pixel + window_half_width
    # array to store target and sky flux
    star_flux = np.zeros((cube.shape[0]))
    sky_flux = np.zeros_like(star_flux)
    # loop through each image in cube
    for i, frame in enumerate(tqdm(cube)):
        # extract subimage centered around target
        sub_image = frame[xmin:xmax, ymin:ymax]
        # compute approximate center pixels of target
        cenx, ceny = center_of_mass(sub_image)
        # compute distance between target center and each pixel in subimage
        x, y = np.indices(sub_image.shape)
        distance_from_center = np.sqrt((x - cenx)**2 + (y - ceny)**2)
        # keep only pixels inside the star aperture
        star_aperture = distance_from_center < star_radius
        star_flux[i] = np.nansum(sub_image[star_aperture])
        if sky_radii is not None:
            sky_inner_r, sky_outer_r = sky_radii
            sky_aperture = (distance_from_center < sky_outer_r) & (distance_from_center > sky_inner_r)
            sky_flux[i] = np.nanmedian(sub_image[sky_aperture])
            # NOTE(review): this subtracts the per-pixel sky median once from
            # the summed aperture flux; aperture photometry normally scales
            # the sky level by the number of aperture pixels. Confirm intent.
            star_flux[i] -= sky_flux[i]
        if plot:
            circles = [[cenx, ceny, star_radius]]
            if sky_radii is not None:
                circles.extend([cenx, ceny, r] for r in sky_radii)
            va.imshow(sub_image, circles=circles)

    return star_flux, sky_flux