pyreduce-astro 0.6.0b5__cp311-cp311-win_amd64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. pyreduce/__init__.py +67 -0
  2. pyreduce/__main__.py +106 -0
  3. pyreduce/clib/Release/_slitfunc_2d.cp311-win_amd64.exp +0 -0
  4. pyreduce/clib/Release/_slitfunc_2d.cp311-win_amd64.lib +0 -0
  5. pyreduce/clib/Release/_slitfunc_2d.obj +0 -0
  6. pyreduce/clib/Release/_slitfunc_bd.cp311-win_amd64.exp +0 -0
  7. pyreduce/clib/Release/_slitfunc_bd.cp311-win_amd64.lib +0 -0
  8. pyreduce/clib/Release/_slitfunc_bd.obj +0 -0
  9. pyreduce/clib/__init__.py +0 -0
  10. pyreduce/clib/_slitfunc_2d.cp311-win_amd64.pyd +0 -0
  11. pyreduce/clib/_slitfunc_bd.cp311-win_amd64.pyd +0 -0
  12. pyreduce/clib/build_extract.py +75 -0
  13. pyreduce/clib/slit_func_2d_xi_zeta_bd.c +1313 -0
  14. pyreduce/clib/slit_func_2d_xi_zeta_bd.h +55 -0
  15. pyreduce/clib/slit_func_bd.c +362 -0
  16. pyreduce/clib/slit_func_bd.h +17 -0
  17. pyreduce/clipnflip.py +147 -0
  18. pyreduce/combine_frames.py +855 -0
  19. pyreduce/configuration.py +186 -0
  20. pyreduce/continuum_normalization.py +329 -0
  21. pyreduce/cwrappers.py +404 -0
  22. pyreduce/datasets.py +231 -0
  23. pyreduce/echelle.py +413 -0
  24. pyreduce/estimate_background_scatter.py +129 -0
  25. pyreduce/extract.py +1361 -0
  26. pyreduce/extraction_width.py +77 -0
  27. pyreduce/instruments/__init__.py +0 -0
  28. pyreduce/instruments/andes.json +61 -0
  29. pyreduce/instruments/andes.py +102 -0
  30. pyreduce/instruments/common.json +46 -0
  31. pyreduce/instruments/common.py +675 -0
  32. pyreduce/instruments/crires_plus.json +63 -0
  33. pyreduce/instruments/crires_plus.py +103 -0
  34. pyreduce/instruments/filters.py +195 -0
  35. pyreduce/instruments/harpn.json +136 -0
  36. pyreduce/instruments/harpn.py +201 -0
  37. pyreduce/instruments/harps.json +155 -0
  38. pyreduce/instruments/harps.py +310 -0
  39. pyreduce/instruments/instrument_info.py +140 -0
  40. pyreduce/instruments/instrument_schema.json +221 -0
  41. pyreduce/instruments/jwst_miri.json +53 -0
  42. pyreduce/instruments/jwst_miri.py +29 -0
  43. pyreduce/instruments/jwst_niriss.json +52 -0
  44. pyreduce/instruments/jwst_niriss.py +98 -0
  45. pyreduce/instruments/lick_apf.json +53 -0
  46. pyreduce/instruments/lick_apf.py +35 -0
  47. pyreduce/instruments/mcdonald.json +59 -0
  48. pyreduce/instruments/mcdonald.py +123 -0
  49. pyreduce/instruments/metis_ifu.json +63 -0
  50. pyreduce/instruments/metis_ifu.py +45 -0
  51. pyreduce/instruments/metis_lss.json +65 -0
  52. pyreduce/instruments/metis_lss.py +45 -0
  53. pyreduce/instruments/micado.json +53 -0
  54. pyreduce/instruments/micado.py +45 -0
  55. pyreduce/instruments/neid.json +51 -0
  56. pyreduce/instruments/neid.py +154 -0
  57. pyreduce/instruments/nirspec.json +56 -0
  58. pyreduce/instruments/nirspec.py +215 -0
  59. pyreduce/instruments/nte.json +47 -0
  60. pyreduce/instruments/nte.py +42 -0
  61. pyreduce/instruments/uves.json +59 -0
  62. pyreduce/instruments/uves.py +46 -0
  63. pyreduce/instruments/xshooter.json +66 -0
  64. pyreduce/instruments/xshooter.py +39 -0
  65. pyreduce/make_shear.py +606 -0
  66. pyreduce/masks/mask_crires_plus_det1.fits.gz +0 -0
  67. pyreduce/masks/mask_crires_plus_det2.fits.gz +0 -0
  68. pyreduce/masks/mask_crires_plus_det3.fits.gz +0 -0
  69. pyreduce/masks/mask_ctio_chiron.fits.gz +0 -0
  70. pyreduce/masks/mask_elodie.fits.gz +0 -0
  71. pyreduce/masks/mask_feros3.fits.gz +0 -0
  72. pyreduce/masks/mask_flames_giraffe.fits.gz +0 -0
  73. pyreduce/masks/mask_harps_blue.fits.gz +0 -0
  74. pyreduce/masks/mask_harps_red.fits.gz +0 -0
  75. pyreduce/masks/mask_hds_blue.fits.gz +0 -0
  76. pyreduce/masks/mask_hds_red.fits.gz +0 -0
  77. pyreduce/masks/mask_het_hrs_2x5.fits.gz +0 -0
  78. pyreduce/masks/mask_jwst_miri_lrs_slitless.fits.gz +0 -0
  79. pyreduce/masks/mask_jwst_niriss_gr700xd.fits.gz +0 -0
  80. pyreduce/masks/mask_lick_apf_.fits.gz +0 -0
  81. pyreduce/masks/mask_mcdonald.fits.gz +0 -0
  82. pyreduce/masks/mask_nes.fits.gz +0 -0
  83. pyreduce/masks/mask_nirspec_nirspec.fits.gz +0 -0
  84. pyreduce/masks/mask_sarg.fits.gz +0 -0
  85. pyreduce/masks/mask_sarg_2x2a.fits.gz +0 -0
  86. pyreduce/masks/mask_sarg_2x2b.fits.gz +0 -0
  87. pyreduce/masks/mask_subaru_hds_red.fits.gz +0 -0
  88. pyreduce/masks/mask_uves_blue.fits.gz +0 -0
  89. pyreduce/masks/mask_uves_blue_binned_2_2.fits.gz +0 -0
  90. pyreduce/masks/mask_uves_middle.fits.gz +0 -0
  91. pyreduce/masks/mask_uves_middle_2x2_split.fits.gz +0 -0
  92. pyreduce/masks/mask_uves_middle_binned_2_2.fits.gz +0 -0
  93. pyreduce/masks/mask_uves_red.fits.gz +0 -0
  94. pyreduce/masks/mask_uves_red_2x2.fits.gz +0 -0
  95. pyreduce/masks/mask_uves_red_2x2_split.fits.gz +0 -0
  96. pyreduce/masks/mask_uves_red_binned_2_2.fits.gz +0 -0
  97. pyreduce/masks/mask_xshooter_nir.fits.gz +0 -0
  98. pyreduce/rectify.py +138 -0
  99. pyreduce/reduce.py +2205 -0
  100. pyreduce/settings/settings_ANDES.json +89 -0
  101. pyreduce/settings/settings_CRIRES_PLUS.json +89 -0
  102. pyreduce/settings/settings_HARPN.json +73 -0
  103. pyreduce/settings/settings_HARPS.json +69 -0
  104. pyreduce/settings/settings_JWST_MIRI.json +55 -0
  105. pyreduce/settings/settings_JWST_NIRISS.json +55 -0
  106. pyreduce/settings/settings_LICK_APF.json +62 -0
  107. pyreduce/settings/settings_MCDONALD.json +58 -0
  108. pyreduce/settings/settings_METIS_IFU.json +77 -0
  109. pyreduce/settings/settings_METIS_LSS.json +77 -0
  110. pyreduce/settings/settings_MICADO.json +78 -0
  111. pyreduce/settings/settings_NEID.json +73 -0
  112. pyreduce/settings/settings_NIRSPEC.json +58 -0
  113. pyreduce/settings/settings_NTE.json +60 -0
  114. pyreduce/settings/settings_UVES.json +54 -0
  115. pyreduce/settings/settings_XSHOOTER.json +78 -0
  116. pyreduce/settings/settings_pyreduce.json +178 -0
  117. pyreduce/settings/settings_schema.json +827 -0
  118. pyreduce/tools/__init__.py +0 -0
  119. pyreduce/tools/combine.py +117 -0
  120. pyreduce/trace_orders.py +645 -0
  121. pyreduce/util.py +1288 -0
  122. pyreduce/wavecal/MICADO_HK_3arcsec_chip5.npz +0 -0
  123. pyreduce/wavecal/atlas/thar.fits +4946 -13
  124. pyreduce/wavecal/atlas/thar_list.txt +4172 -0
  125. pyreduce/wavecal/atlas/une.fits +0 -0
  126. pyreduce/wavecal/convert.py +38 -0
  127. pyreduce/wavecal/crires_plus_J1228_Open_det1.npz +0 -0
  128. pyreduce/wavecal/crires_plus_J1228_Open_det2.npz +0 -0
  129. pyreduce/wavecal/crires_plus_J1228_Open_det3.npz +0 -0
  130. pyreduce/wavecal/harpn_harpn_2D.npz +0 -0
  131. pyreduce/wavecal/harps_blue_2D.npz +0 -0
  132. pyreduce/wavecal/harps_blue_pol_2D.npz +0 -0
  133. pyreduce/wavecal/harps_red_2D.npz +0 -0
  134. pyreduce/wavecal/harps_red_pol_2D.npz +0 -0
  135. pyreduce/wavecal/mcdonald.npz +0 -0
  136. pyreduce/wavecal/metis_lss_l_2D.npz +0 -0
  137. pyreduce/wavecal/metis_lss_m_2D.npz +0 -0
  138. pyreduce/wavecal/nirspec_K2.npz +0 -0
  139. pyreduce/wavecal/uves_blue_360nm_2D.npz +0 -0
  140. pyreduce/wavecal/uves_blue_390nm_2D.npz +0 -0
  141. pyreduce/wavecal/uves_blue_437nm_2D.npz +0 -0
  142. pyreduce/wavecal/uves_middle_2x2_2D.npz +0 -0
  143. pyreduce/wavecal/uves_middle_565nm_2D.npz +0 -0
  144. pyreduce/wavecal/uves_middle_580nm_2D.npz +0 -0
  145. pyreduce/wavecal/uves_middle_600nm_2D.npz +0 -0
  146. pyreduce/wavecal/uves_middle_665nm_2D.npz +0 -0
  147. pyreduce/wavecal/uves_middle_860nm_2D.npz +0 -0
  148. pyreduce/wavecal/uves_red_580nm_2D.npz +0 -0
  149. pyreduce/wavecal/uves_red_600nm_2D.npz +0 -0
  150. pyreduce/wavecal/uves_red_665nm_2D.npz +0 -0
  151. pyreduce/wavecal/uves_red_760nm_2D.npz +0 -0
  152. pyreduce/wavecal/uves_red_860nm_2D.npz +0 -0
  153. pyreduce/wavecal/xshooter_nir.npz +0 -0
  154. pyreduce/wavelength_calibration.py +1873 -0
  155. pyreduce_astro-0.6.0b5.dist-info/METADATA +113 -0
  156. pyreduce_astro-0.6.0b5.dist-info/RECORD +158 -0
  157. pyreduce_astro-0.6.0b5.dist-info/WHEEL +4 -0
  158. pyreduce_astro-0.6.0b5.dist-info/licenses/LICENSE +674 -0
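
The largest addition in this release is pyreduce/wavelength_calibration.py, whose full content is diffed below. As rough orientation, here is a minimal usage sketch (illustrative only — the file names are placeholders, and it assumes the reference line list is stored under the "cs_lines" key that LineList.load expects):

    import numpy as np
    from pyreduce.wavelength_calibration import WavelengthCalibration, LineList

    obs = np.load("wavecal_spectrum.npy")          # extracted calibration spectrum, shape (nord, ncol)
    lines = LineList.load("reference_lines.npz")   # reference lines stored under "cs_lines"

    module = WavelengthCalibration(
        threshold=100, degree=(6, 6), dimensionality="2D", plot=False
    )
    wave_img, wave_solution, linelist = module.execute(obs, lines.data)

The returned wave_img holds the wavelength of every pixel in every order, while wave_solution contains the fitted polynomial coefficients.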
@@ -0,0 +1,1873 @@
1
+ """
2
+ Wavelength Calibration
3
+ by comparison to a reference spectrum
4
+ Loosely based on the IDL wavecal function
5
+ """
6
+
7
+ import logging
8
+ from os.path import dirname, join
9
+
10
+ import corner
11
+ import emcee
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ from astropy.io import fits
15
+ from numpy.polynomial.polynomial import Polynomial, polyval2d
16
+ from scipy import signal
17
+ from scipy.constants import speed_of_light
18
+ from scipy.interpolate import interp1d
19
+ from scipy.ndimage import gaussian_filter1d
20
+ from scipy.ndimage import grey_closing
21
+ from scipy.optimize import curve_fit
22
+ from tqdm import tqdm
23
+
24
+ from . import util
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ def polyfit(x, y, deg):
30
+ res = Polynomial.fit(x, y, deg, domain=[])
31
+ coef = res.coef[::-1]
32
+ return coef
33
+
34
+
35
+ class AlignmentPlot:
36
+ """
37
+ Makes a plot which can be clicked to align the two spectra, reference and observed
38
+ """
39
+
40
+ def __init__(self, ax, obs, lines, offset=(0, 0), plot_title=None):
41
+ self.im = ax
42
+ self.first = True
43
+ self.nord, self.ncol = obs.shape
44
+ self.RED, self.GREEN, self.BLUE = 0, 1, 2
45
+
46
+ self.obs = obs
47
+ self.lines = lines
48
+ self.plot_title = plot_title
49
+
50
+ self.order_first = 0
51
+ self.spec_first = ""
52
+ self.x_first = 0
53
+ self.offset = list(offset)
54
+
55
+ self.make_ref_image()
56
+
57
+ def make_ref_image(self):
58
+ """create and show the reference plot, with the two spectra"""
59
+ ref_image = np.zeros((self.nord * 2, self.ncol, 3))
60
+ for iord in range(self.nord):
61
+ ref_image[iord * 2, :, self.RED] = 10 * np.ma.filled(self.obs[iord], 0)
62
+ if 0 <= iord + self.offset[0] < self.nord:
63
+ for line in self.lines[self.lines["order"] == iord]:
64
+ first = int(np.clip(line["xfirst"] + self.offset[1], 0, self.ncol))
65
+ last = int(np.clip(line["xlast"] + self.offset[1], 0, self.ncol))
66
+ order = (iord + self.offset[0]) * 2 + 1
67
+ ref_image[order, first:last, self.GREEN] = (
68
+ 10
69
+ * line["height"]
70
+ * signal.windows.gaussian(last - first, line["width"])
71
+ )
72
+ ref_image = np.clip(ref_image, 0, 1)
73
+ ref_image[ref_image < 0.1] = 0
74
+
75
+ self.im.imshow(
76
+ ref_image,
77
+ aspect="auto",
78
+ origin="lower",
79
+ extent=(-0.5, self.ncol - 0.5, -0.5, self.nord - 0.5),
80
+ )
81
+ title = "Alignment, Observed: RED, Reference: GREEN\nGreen should be above red!"
82
+ if self.plot_title is not None:
83
+ title = f"{self.plot_title}\n{title}"
84
+ self.im.figure.suptitle(title)
85
+ self.im.axes.set_xlabel("x [pixel]")
86
+ self.im.axes.set_ylabel("Order")
87
+
88
+ self.im.figure.canvas.draw()
89
+
90
+ def connect(self):
91
+ """connect the click event with the appropiate function"""
92
+ self.cidclick = self.im.figure.canvas.mpl_connect(
93
+ "button_press_event", self.on_click
94
+ )
95
+
96
+ def on_click(self, event):
97
+ """On click offset the reference by the distance between click positions"""
98
+ if event.ydata is None:
99
+ return
100
+ order = int(np.floor(event.ydata))
101
+ spec = "ref" if (event.ydata - order) > 0.5 else "obs" # if True then reference
102
+ x = event.xdata
103
+ print("Order: %i, Spectrum: %s, x: %g" % (order, "ref" if spec else "obs", x))
104
+
105
+ # on every second click
106
+ if self.first:
107
+ self.first = False
108
+ self.order_first = order
109
+ self.spec_first = spec
110
+ self.x_first = x
111
+ else:
112
+ # Clicked different spectra
113
+ if spec != self.spec_first:
114
+ self.first = True
115
+ direction = -1 if spec == "ref" else 1
116
+ offset_orders = int(order - self.order_first) * direction
117
+ offset_x = int(x - self.x_first) * direction
118
+ self.offset[0] -= offset_orders - 1
119
+ self.offset[1] -= offset_x
120
+ self.make_ref_image()
121
+
122
+
123
+ class LineAtlas:
124
+ def __init__(self, element, medium="vac"):
125
+ self.element = element
126
+ self.medium = medium
127
+
128
+ fname = element.lower() + ".fits"
129
+ folder = dirname(__file__)
130
+ self.fname = join(folder, "wavecal/atlas", fname)
131
+ self.wave, self.flux = self.load_fits(self.fname)
132
+
133
+ try:
134
+ # If a specific linelist file is provided
135
+ fname_list = element.lower() + "_list.txt"
136
+ self.fname_list = join(folder, "wavecal/atlas", fname_list)
137
+ linelist = np.genfromtxt(self.fname_list, dtype="f8,U8")
138
+ wpos, element = linelist["f0"], linelist["f1"]
139
+ indices = self.wave.searchsorted(wpos)
140
+ heights = self.flux[indices]
141
+ self.linelist = np.rec.fromarrays(
142
+ [wpos, heights, element], names=["wave", "heights", "element"]
143
+ )
144
+ except (OSError, FileNotFoundError):
145
+ # Otherwise fit the line positions from the spectrum
146
+ logger.warning(
147
+ "No dedicated linelist found for %s, determining peaks based on the reference spectrum instead.",
148
+ element,
149
+ )
150
+ module = WavelengthCalibration(plot=False)
151
+ n, peaks = module._find_peaks(self.flux)
152
+ wpos = np.interp(peaks, np.arange(len(self.wave)), self.wave)
153
+ element = np.full(len(wpos), element)
154
+ indices = self.wave.searchsorted(wpos)
155
+ heights = self.flux[indices]
156
+ self.linelist = np.rec.fromarrays(
157
+ [wpos, heights, element], names=["wave", "heights", "element"]
158
+ )
159
+
160
+ # The data files are in vacuum; if the instrument is in air, we need to convert
161
+ if medium == "air":
162
+ self.wave = util.vac2air(self.wave)
163
+ self.linelist["wave"] = util.vac2air(self.linelist["wave"])
164
+
165
+ def load_fits(self, fname):
166
+ with fits.open(fname, memmap=False) as hdu:
167
+ if len(hdu) == 1:
168
+ # It's just the spectrum
169
+ # with the wavelength defined via the header keywords
170
+ header = hdu[0].header
171
+ spec = hdu[0].data.ravel()
172
+ wmin = header["CRVAL1"]
173
+ wdel = header["CDELT1"]
174
+ wave = np.arange(spec.size) * wdel + wmin
175
+ else:
176
+ # It's a binary table, with two columns for the wavelength and the
177
+ # spectrum
178
+ data = hdu[1].data
179
+ wave = data["wave"]
180
+ spec = data["spec"]
181
+
182
+ spec /= np.nanmax(spec)
183
+ spec = np.clip(spec, 0, None)
184
+ return wave, spec
185
+
186
+
187
+ class LineList:
188
+ dtype = np.dtype(
189
+ (
190
+ np.record,
191
+ [
192
+ (("wlc", "WLC"), ">f8"), # Wavelength (before fit)
193
+ (("wll", "WLL"), ">f8"), # Wavelength (after fit)
194
+ (("posc", "POSC"), ">f8"), # Pixel Position (before fit)
195
+ (("posm", "POSM"), ">f8"), # Pixel Position (after fit)
196
+ (("xfirst", "XFIRST"), ">i2"), # first pixel of the line
197
+ (("xlast", "XLAST"), ">i2"), # last pixel of the line
198
+ (
199
+ ("approx", "APPROX"),
200
+ "O",
201
+ ), # Not used. Describes the shape used to approximate the line. "G" for Gaussian
202
+ (("width", "WIDTH"), ">f8"), # width of the line in pixels
203
+ (("height", "HEIGHT"), ">f8"), # relative strength of the line
204
+ (("order", "ORDER"), ">i2"), # echelle order the line is found in
205
+ ("flag", "?"), # flag that tells us if we should use the line or not
206
+ ],
207
+ )
208
+ )
209
+
210
+ def __init__(self, lines=None):
211
+ if lines is None:
212
+ lines = np.array([], dtype=self.dtype)
213
+ self.data = lines
214
+ self.dtype = self.data.dtype
215
+
216
+ def __getitem__(self, key):
217
+ return self.data[key]
218
+
219
+ def __setitem__(self, key, value):
220
+ self.data[key] = value
221
+
222
+ def __len__(self):
223
+ return len(self.data)
224
+
225
+ @classmethod
226
+ def load(cls, filename):
227
+ data = np.load(filename, allow_pickle=True)
228
+ linelist = cls(data["cs_lines"])
229
+ return linelist
230
+
231
+ def save(self, filename):
232
+ np.savez(filename, cs_lines=self.data)
233
+
234
+ def append(self, linelist):
235
+ if isinstance(linelist, LineList):
236
+ linelist = linelist.data
237
+ self.data = np.append(self.data, linelist)
238
+
239
+ def add_line(self, wave, order, pos, width, height, flag):
240
+ lines = self.from_list([wave], [order], [pos], [width], [height], [flag])
241
+ self.data = np.append(self.data, lines)
242
+
243
+ @classmethod
244
+ def from_list(cls, wave, order, pos, width, height, flag):
245
+ lines = [
246
+ (w, w, p, p, p - wi / 2, p + wi / 2, b"G", wi, h, o, f)
247
+ for w, o, p, wi, h, f in zip(
248
+ wave, order, pos, width, height, flag, strict=False
249
+ )
250
+ ]
251
+ lines = np.array(lines, dtype=cls.dtype)
252
+ return cls(lines)
253
+
254
+
255
+ class WavelengthCalibration:
256
+ """
257
+ Wavelength Calibration Module
258
+
259
+ Takes an observed wavelength image and the reference linelist
260
+ and returns the wavelength at each pixel
261
+ """
262
+
263
+ def __init__(
264
+ self,
265
+ threshold=100,
266
+ degree=(6, 6),
267
+ iterations=3,
268
+ dimensionality="2D",
269
+ nstep=0,
270
+ correlate_cols=0,
271
+ shift_window=0.01,
272
+ manual=False,
273
+ polarim=False,
274
+ lfc_peak_width=3,
275
+ closing=5,
276
+ element=None,
277
+ medium="vac",
278
+ plot=True,
279
+ plot_title=None,
280
+ ):
281
+ #:float: Residual threshold in m/s above which to remove lines
282
+ self.threshold = threshold
283
+ #:tuple(int, int): polynomial degree of the wavelength fit in (pixel, order) direction
284
+ self.degree = degree
285
+ if dimensionality == "1D":
286
+ self.degree = int(degree)
287
+ elif dimensionality == "2D":
288
+ self.degree = (int(degree[0]), int(degree[1]))
289
+ #:int: Number of iterations of the outlier removal and auto-identification loop
290
+ self.iterations = iterations
291
+ #:{"1D", "2D"}: Whether to use 1d or 2d fit
292
+ self.dimensionality = dimensionality
293
+ #:int: Number of pixel steps (detector stitching offsets) to fit for; 0 disables step fitting
294
+ self.nstep = nstep
295
+ #:int: How many columns to use in the 2D cross correlation alignment. 0 means all pixels (slow).
296
+ self.correlate_cols = correlate_cols
297
+ #:float: Fraction of the number of columns to use in the alignment of individual orders. Set to 0 to disable
298
+ self.shift_window = shift_window
299
+ #:bool: Whether to manually align the reference instead of using cross correlation
300
+ self.manual = manual
301
+ #:bool: Whether to use polarimetric orders instead of the usual ones, i.e. each pair of orders represents the same data. Not supported yet
302
+ self.polarim = polarim
303
+ #:int: Whether to plot the results. Set to 2 to plot during all steps.
304
+ self.plot = plot
305
+ self.plot_title = plot_title
306
+ #:str: Elements used in the wavelength calibration. Used in AutoId to find more lines from the Atlas
307
+ self.element = element
308
+ #:str: Medium of the detector, vac or air
309
+ self.medium = medium
310
+ #:int: Laser Frequency Peak width (for scipy.signal.find_peaks)
311
+ self.lfc_peak_width = lfc_peak_width
312
+ #:int: grey closing range for the input image
313
+ self.closing = closing
314
+ #:int: Number of orders in the observation
315
+ self.nord = None
316
+ #:int: Number of columns in the observation
317
+ self.ncol = None
318
+
319
+ @property
320
+ def step_mode(self):
321
+ return self.nstep > 0
322
+
323
+ @property
324
+ def dimensionality(self):
325
+ """{"1D", "2D"}: Whether to use 1D or 2D polynomials for the wavelength solution"""
326
+ return self._dimensionality
327
+
328
+ @dimensionality.setter
329
+ def dimensionality(self, value):
330
+ accepted_values = ["1D", "2D"]
331
+ if value in accepted_values:
332
+ self._dimensionality = value
333
+ else:
334
+ raise ValueError(
335
+ f"Value for 'dimensionality' not understood. Expected one of {accepted_values} but got {value} instead"
336
+ )
337
+
338
+ def normalize(self, obs, lines):
339
+ """
340
+ Normalize the observation and reference list in each order individually
341
+ Copies the data of the image, but not of the linelist
342
+
343
+ Parameters
344
+ ----------
345
+ obs : array of shape (nord, ncol)
346
+ observed image
347
+ lines : recarray of shape (nlines,)
348
+ reference linelist
349
+
350
+ Returns
351
+ -------
352
+ obs : array of shape (nord, ncol)
353
+ normalized image
354
+ lines : recarray of shape (nlines,)
355
+ normalized reference linelist
356
+ """
357
+ # normalize order by order
358
+ obs = np.ma.copy(obs)
359
+ for i in range(len(obs)):
360
+ if self.closing > 0:
361
+ obs[i] = grey_closing(obs[i], self.closing)
362
+ try:
363
+ obs[i] -= np.ma.median(obs[i][obs[i] > 0])
364
+ except ValueError:
365
+ logger.warning(
366
+ "Could not determine the minimum value in order %i. No positive values found",
367
+ i,
368
+ )
369
+ obs[i] /= np.ma.max(obs[i])
370
+
371
+ # Remove negative outliers
372
+ std = np.std(obs, axis=1)[:, None]
373
+ obs[obs <= -2 * std] = np.ma.masked
374
+ # obs[obs <= 0] = np.ma.masked
375
+
376
+ # Normalize lines in each order
377
+ for order in np.unique(lines["order"]):
378
+ select = lines["order"] == order
379
+ topheight = np.max(lines[select]["height"])
380
+ lines["height"][select] /= topheight
381
+
382
+ return obs, lines
383
+
384
+ def create_image_from_lines(self, lines):
385
+ """
386
+ Create a reference image based on a line list
387
+ Each line will be approximated by a Gaussian
388
+ Space between lines is 0
389
+ The number of orders is from 0 to the maximum order
390
+
391
+ Parameters
392
+ ----------
393
+ lines : recarray of shape (nlines,)
394
+ line data
395
+
396
+ Returns
397
+ -------
398
+ img : array of shape (nord, ncol)
399
+ New reference image
400
+ """
401
+ min_order = int(np.min(lines["order"]))
402
+ max_order = int(np.max(lines["order"]))
403
+ img = np.zeros((max_order - min_order + 1, self.ncol))
404
+ for line in lines:
405
+ if line["order"] < 0:
406
+ continue
407
+ if line["xlast"] < 0 or line["xfirst"] > self.ncol:
408
+ continue
409
+ first = int(max(line["xfirst"], 0))
410
+ last = int(min(line["xlast"], self.ncol))
411
+ img[int(line["order"]) - min_order, first:last] = line[
412
+ "height"
413
+ ] * signal.windows.gaussian(last - first, line["width"])
414
+ return img
415
+
416
+ def align_manual(self, obs, lines):
417
+ """
418
+ Open an AlignmentPlot window for manual selection of the alignment
419
+
420
+ Parameters
421
+ ----------
422
+ obs : array of shape (nord, ncol)
423
+ observed image
424
+ lines : recarray of shape (nlines,)
425
+ reference linelist
426
+
427
+ Returns
428
+ -------
429
+ offset : tuple(int, int)
430
+ offset in order and column to be applied to each line in the linelist
431
+ """
432
+ _, ax = plt.subplots()
433
+ ap = AlignmentPlot(ax, obs, lines, plot_title=self.plot_title)
434
+ ap.connect()
435
+ plt.show()
436
+ offset = ap.offset
437
+ return offset
438
+
439
+ def apply_alignment_offset(self, lines, offset, select=None):
440
+ """
441
+ Apply an offset to the linelist
442
+
443
+ Parameters
444
+ ----------
445
+ lines : recarray of shape (nlines,)
446
+ reference linelist
447
+ offset : tuple(int, int)
448
+ offset in (order, column)
449
+ select : array of shape(nlines,), optional
450
+ Mask that defines which lines the offset applies to
451
+
452
+ Returns
453
+ -------
454
+ lines : recarray of shape (nlines,)
455
+ linelist with offset applied
456
+ """
457
+ if select is None:
458
+ select = slice(None)
459
+ lines["xfirst"][select] += offset[1]
460
+ lines["xlast"][select] += offset[1]
461
+ lines["posm"][select] += offset[1]
462
+ lines["order"][select] += offset[0]
463
+ return lines
464
+
465
+ def align(self, obs, lines):
466
+ """
467
+ Align the observation with the reference spectrum
468
+ Either automatically using cross correlation or manually (visually)
469
+
470
+ Parameters
471
+ ----------
472
+ obs : array[nrow, ncol]
473
+ observed wavelength calibration spectrum (e.g. obs=ThoriumArgon)
474
+ lines : struct_array
475
+ reference line data
476
+ manual : bool, optional
477
+ wether to manually align the spectra (default: False)
478
+ plot : bool, optional
479
+ wether to plot the alignment (default: False)
480
+
481
+ Returns
482
+ -------
483
+ offset: tuple(int, int)
484
+ offset in order and column
485
+ """
486
+ obs = np.ma.filled(obs, 0)
487
+
488
+ if not self.manual:
489
+ # make image from lines
490
+ img = self.create_image_from_lines(lines)
491
+
492
+ # Crop the image to speed up cross correlation
493
+ if self.correlate_cols != 0:
494
+ _slice = slice(
495
+ (self.ncol - self.correlate_cols) // 2,
496
+ (self.ncol + self.correlate_cols) // 2 + 1,
497
+ )
498
+ ccimg = img[:, _slice]
499
+ ccobs = obs[:, _slice]
500
+ else:
501
+ ccimg, ccobs = img, obs
502
+
503
+ # Cross correlate with obs image
504
+ # And determine overall offset
505
+ correlation = signal.correlate2d(ccobs, ccimg, mode="same")
506
+ offset_order, offset_x = np.unravel_index(
507
+ np.argmax(correlation), correlation.shape
508
+ )
509
+
510
+ if self.plot >= 2:
511
+ plt.imshow(correlation, aspect="auto")
512
+ plt.vlines(offset_x, -0.5, correlation.shape[0] - 0.5, color="red")
513
+ plt.hlines(offset_order, -0.5, correlation.shape[1] - 0.5, color="red")
514
+ if self.plot_title is not None:
515
+ plt.title(self.plot_title)
516
+ plt.show()
517
+
518
+ offset_order = offset_order - ccimg.shape[0] / 2 + 1
519
+ offset_x = offset_x - ccimg.shape[1] / 2 + 1
520
+ offset = [int(offset_order), int(offset_x)]
521
+
522
+ # apply offset
523
+ lines = self.apply_alignment_offset(lines, offset)
524
+
525
+ if self.shift_window != 0:
526
+ # Shift individual orders to fit reference
527
+ # Only allow a small shift here (1%) ?
528
+ img = self.create_image_from_lines(lines)
529
+ for i in range(max(offset[0], 0), min(len(obs), len(img))):
530
+ correlation = signal.correlate(obs[i], img[i], mode="same")
531
+ width = int(self.ncol * self.shift_window) // 2
532
+ low, high = self.ncol // 2 - width, self.ncol // 2 + width
533
+ offset_x = np.argmax(correlation[low:high]) + low
534
+ offset_x = int(offset_x - self.ncol / 2 + 1)
535
+
536
+ select = lines["order"] == i
537
+ lines = self.apply_alignment_offset(lines, (0, offset_x), select)
538
+
539
+ if self.plot or self.manual:
540
+ offset = self.align_manual(obs, lines)
541
+ lines = self.apply_alignment_offset(lines, offset)
542
+
543
+ logger.debug(f"Offset order: {offset[0]}, Offset pixel: {offset[1]}")
544
+
545
+ return lines
546
+
547
+ def _fit_single_line(self, obs, center, width, plot=False):
548
+ low = int(center - width * 5)
549
+ low = max(low, 0)
550
+ high = int(center + width * 5)
551
+ high = min(high, len(obs))
552
+
553
+ section = obs[low:high]
554
+ x = np.arange(low, high, 1)
555
+ x = np.ma.masked_array(x, mask=np.ma.getmaskarray(section))
556
+ coef = util.gaussfit2(x, section)
557
+
558
+ if self.plot >= 2 and plot:
559
+ x2 = np.linspace(x.min(), x.max(), len(x) * 100)
560
+ plt.plot(x, section, label="Observation")
561
+ plt.plot(x2, util.gaussval2(x2, *coef), label="Fit")
562
+ title = "Gaussian Fit to spectral line"
563
+ if self.plot_title is not None:
564
+ title = f"{self.plot_title}\n{title}"
565
+ plt.title(title)
566
+ plt.xlabel("x [pixel]")
567
+ plt.ylabel("Intensity [a.u.]")
568
+ plt.legend()
569
+ plt.show()
570
+ return coef
571
+
572
+ def fit_lines(self, obs, lines):
573
+ """
574
+ Determine exact position of each line on the detector based on initial guess
575
+
576
+ This fits a Gaussian to each line, and uses the peak position as a new solution
577
+
578
+ Parameters
579
+ ----------
580
+ obs : array of shape (nord, ncol)
581
+ observed wavelength calibration image
582
+ lines : recarray of shape (nlines,)
583
+ reference line data
584
+
585
+ Returns
586
+ -------
587
+ lines : recarray of shape (nlines,)
588
+ Updated line information (posm is changed)
589
+ """
590
+ # For each line fit a gaussian to the observation
591
+ for i, line in tqdm(
592
+ enumerate(lines), total=len(lines), leave=False, desc="Lines"
593
+ ):
594
+ if line["posm"] < 0 or line["posm"] >= obs.shape[1]:
595
+ # Line outside pixel range
596
+ continue
597
+ if line["order"] < 0 or line["order"] >= len(obs):
598
+ # Line outside order range
599
+ continue
600
+
601
+ try:
602
+ coef = self._fit_single_line(
603
+ obs[int(line["order"])],
604
+ line["posm"],
605
+ line["width"],
606
+ plot=line["flag"],
607
+ )
608
+ lines[i]["posm"] = coef[1]
609
+ except:
610
+ # Gaussian fit failed, don't use the line
611
+ lines[i]["flag"] = False
612
+
613
+ return lines
614
+
615
+ def build_2d_solution(self, lines, plot=False):
616
+ """
617
+ Create a 2D polynomial fit to the flagged lines.
618
+ The polynomial degree of the fit in the (column, order) dimensions is taken
619
+ from self.degree (default: (6, 6))
620
+
621
+ Parameters
622
+ ----------
623
+ lines : struct_array
624
+ line data
625
+ plot : bool, optional
626
+ whether to plot the solution (default: False)
627
+
628
+ Returns
629
+ -------
630
+ coef : array[degree_x, degree_y]
631
+ 2d polynomial coefficients
632
+ """
633
+
634
+ if self.step_mode:
635
+ return self.build_step_solution(lines, plot=plot)
636
+
637
+ # Only use flagged data
638
+ mask = lines["flag"] # True: use line, False: dont use line
639
+ m_wave = lines["wll"][mask]
640
+ m_pix = lines["posm"][mask]
641
+ m_ord = lines["order"][mask]
642
+
643
+ if self.dimensionality == "1D":
644
+ nord = self.nord
645
+ coef = np.zeros((nord, self.degree + 1))
646
+ for i in range(nord):
647
+ select = m_ord == i
648
+ if np.count_nonzero(select) < 2:
649
+ # Not enough lines for wavelength solution
650
+ logger.warning(
651
+ "Not enough valid lines found wavelength calibration in order % i",
652
+ i,
653
+ )
654
+ coef[i] = np.nan
655
+ continue
656
+
657
+ deg = max(min(self.degree, np.count_nonzero(select) - 2), 0)
658
+ coef[i, -(deg + 1) :] = np.polyfit(
659
+ m_pix[select], m_wave[select], deg=deg
660
+ )
661
+ elif self.dimensionality == "2D":
662
+ # 2d polynomial fit with: x = column, y = order, z = wavelength
663
+ coef = util.polyfit2d(m_pix, m_ord, m_wave, degree=self.degree, plot=False)
664
+ else:
665
+ raise ValueError(
666
+ f"Parameter 'mode' not understood. Expected '1D' or '2D' but got {self.dimensionality}"
667
+ )
668
+
669
+ if plot or self.plot >= 2: # pragma: no cover
670
+ self.plot_residuals(lines, coef, title="Residuals")
671
+
672
+ return coef
673
+
674
+ def g(self, x, step_coef_pos, step_coef_diff):
675
+ try:
676
+ bins = step_coef_pos
677
+ digits = np.digitize(x, bins) - 1
678
+ except ValueError:
679
+ return np.inf
680
+
681
+ cumsum = np.cumsum(step_coef_diff)
682
+ x = x + cumsum[digits]
683
+ return x
684
+
685
+ def f(self, x, poly_coef, step_coef_pos, step_coef_diff):
686
+ xdash = self.g(x, step_coef_pos, step_coef_diff)
687
+ if np.all(np.isinf(xdash)):
688
+ return np.inf
689
+ y = np.polyval(poly_coef, xdash)
690
+ return y
691
+
692
+ def build_step_solution(self, lines, plot=False):
693
+ """
694
+ Fit the least squares fit to the wavelength points,
695
+ with additional free parameters for detector gaps, e.g. due to stitching.
696
+
697
+ The exact method of the fit depends on the dimensionality.
698
+ Either way we are using the usual polynomial fit for the wavelength, but
699
+ the x points are modified beforehand by shifting them some amount, at specific
700
+ indices. We assume that the stitching effects are distributed evenly and we know how
701
+ many steps we expect (this is set as "nstep").
702
+
703
+ Parameters
704
+ ----------
705
+ lines : np.recarray
706
+ linedata
707
+ plot : bool, optional
708
+ whether to plot results or not, by default False
709
+
710
+ Returns
711
+ -------
712
+ coef
713
+ coefficients of the best fit
714
+ """
715
+ mask = lines["flag"] # True: use line, False: dont use line
716
+ m_wave = lines["wll"][mask]
717
+ m_pix = lines["posm"][mask]
718
+ m_ord = lines["order"][mask]
719
+
720
+ nstep = self.nstep
721
+ ncol = self.ncol
722
+
723
+ if self.dimensionality == "1D":
724
+ coef = {}
725
+ for order in np.unique(m_ord):
726
+ select = m_ord == order
727
+ x = xl = m_pix[select]
728
+ y = m_wave[select]
729
+ step_coef = np.zeros((nstep, 2))
730
+ step_coef[:, 0] = np.linspace(ncol / (nstep + 1), ncol, nstep + 1)[:-1]
731
+
732
+ def func(x, *param):
733
+ return self.f(x, poly_coef, step_coef[:, 0], param) # noqa: B023
734
+
735
+ for _ in range(5):
736
+ poly_coef = np.polyfit(xl, y, self.degree)
737
+ res, _ = curve_fit(func, x, y, p0=step_coef[:, 1], bounds=[-1, 1])
738
+ step_coef[:, 1] = res
739
+ xl = self.g(x, step_coef[:, 0], step_coef[:, 1])
740
+
741
+ coef[order] = [poly_coef, step_coef]
742
+ elif self.dimensionality == "2D":
743
+ unique = np.unique(m_ord)
744
+ nord = len(unique)
745
+ shape = (self.degree[0] + 1, self.degree[1] + 1)
746
+ np.prod(shape)
747
+
748
+ step_coef = np.zeros((nord, nstep, 2))
749
+ step_coef[:, :, 0] = np.linspace(ncol / (nstep + 1), ncol, nstep + 1)[:-1]
750
+
751
+ def func(x, *param):
752
+ x, y = x[: len(x) // 2], x[len(x) // 2 :]
753
+ theta = np.asarray(param).reshape((nord, nstep))
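# [Editor's note — illustrative sketch, not part of the released file; values below are made up.]
# A LineList can also be built directly from plain Python lists via from_list, e.g.:
#   lines = LineList.from_list(
#       wave=[5015.68, 5187.75],  # reference wavelengths
#       order=[0, 0],             # echelle order of each line
#       pos=[1024.3, 1530.8],     # approximate pixel position on the detector
#       width=[3.0, 2.5],         # line width in pixels
#       height=[0.8, 0.4],        # relative line strength
#       flag=[True, True],        # whether to use the line in the fit
#   )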
754
+ xl = np.copy(x)
755
+ for j, i in enumerate(unique):
756
+ xl[y == i] = self.g(x[y == i], step_coef[j, :, 0], theta[j])
757
+ z = polyval2d(xl, y, poly_coef)
758
+ return z
759
+
760
+ # TODO: this could use some optimization
761
+ x = np.copy(m_pix)
762
+ x0 = np.concatenate((m_pix, m_ord))
763
+ resid_old = np.inf
764
+ for k in tqdm(range(5)):
765
+ poly_coef = util.polyfit2d(
766
+ x, m_ord, m_wave, degree=self.degree, plot=False
767
+ )
768
+
769
+ res, _ = curve_fit(func, x0, m_wave, p0=step_coef[:, :, 1])
770
+ step_coef[:, :, 1] = res.reshape((nord, nstep))
771
+ for j, i in enumerate(unique):
772
+ x[m_ord == i] = self.g(
773
+ m_pix[m_ord == i], step_coef[j][:, 0], step_coef[j][:, 1]
774
+ )
775
+
776
+ resid = polyval2d(x, m_ord, poly_coef) - m_wave
777
+ resid = np.sum(resid**2)
778
+ improvement = resid_old - resid
779
+ resid_old = resid
780
+ logger.info(
781
+ "Iteration: %i, Residuals: %.5g, Improvement: %.5g",
782
+ k,
783
+ resid,
784
+ improvement,
785
+ )
786
+
787
+ poly_coef = util.polyfit2d(x, m_ord, m_wave, degree=self.degree, plot=False)
788
+ step_coef = {i: step_coef[j] for j, i in enumerate(unique)}
789
+ coef = (poly_coef, step_coef)
790
+ else:
791
+ raise ValueError(
792
+ f"Parameter 'dimensionality' not understood. Expected '1D' or '2D' but got {self.dimensionality}"
793
+ )
794
+
795
+ return coef
796
+
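# [Editor's note — illustrative summary, not part of the released file.]
# The step model above wraps the ordinary polynomial fit: g() shifts every pixel position by
# the cumulative offset of the detector segment it falls into (np.digitize against the fixed,
# evenly spaced step positions), and f() evaluates the polynomial at the shifted positions.
# Only the offsets are free parameters, fitted with curve_fit while alternating with re-fits
# of the polynomial itself.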
797
+ def evaluate_step_solution(self, pos, order, solution):
798
+ if not np.array_equal(np.shape(pos), np.shape(order)):
799
+ raise ValueError("pos and order must have the same shape")
800
+ if self.dimensionality == "1D":
801
+ result = np.zeros(pos.shape)
802
+ for i in np.unique(order):
803
+ select = order == i
804
+ result[select] = self.f(
805
+ pos[select],
806
+ solution[i][0],
807
+ solution[i][1][:, 0],
808
+ solution[i][1][:, 1],
809
+ )
810
+ elif self.dimensionality == "2D":
811
+ poly_coef, step_coef = solution
812
+ pos = np.copy(pos)
813
+ for i in np.unique(order):
814
+ pos[order == i] = self.g(
815
+ pos[order == i], step_coef[i][:, 0], step_coef[i][:, 1]
816
+ )
817
+ result = polyval2d(pos, order, poly_coef)
818
+ else:
819
+ raise ValueError(
820
+ f"Parameter 'mode' not understood, expected '1D' or '2D' but got {self.dimensionality}"
821
+ )
822
+ return result
823
+
824
+ def evaluate_solution(self, pos, order, solution):
825
+ """
826
+ Evaluate the 1d or 2d wavelength solution at the given pixel positions and orders
827
+
828
+ Parameters
829
+ ----------
830
+ pos : array
831
+ pixel position on the detector (i.e. x axis)
832
+ order : array
833
+ order of each point
834
+ solution : array of shape (nord, ndegree) or (degree_x, degree_y)
835
+ polynomial coefficients. For mode=1D, one set of coefficients per order.
836
+ For mode=2D, the first dimension is for the positions and the second for the orders
837
+ mode : str, optional
838
+ Whether to interpret the solution as 1D or 2D polynomials, by default "1D"
839
+
840
+ Returns
841
+ -------
842
+ result: array
843
+ Evaluated polynomial
844
+
845
+ Raises
846
+ ------
847
+ ValueError
848
+ If pos and order have different shapes, or mode is of the wrong value
849
+ """
850
+ if not np.array_equal(np.shape(pos), np.shape(order)):
851
+ raise ValueError("pos and order must have the same shape")
852
+
853
+ if self.step_mode:
854
+ return self.evaluate_step_solution(pos, order, solution)
855
+
856
+ if self.dimensionality == "1D":
857
+ result = np.zeros(pos.shape)
858
+ for i in np.unique(order):
859
+ select = order == i
860
+ result[select] = np.polyval(solution[int(i)], pos[select])
861
+ elif self.dimensionality == "2D":
862
+ result = np.polynomial.polynomial.polyval2d(pos, order, solution)
863
+ else:
864
+ raise ValueError(
865
+ f"Parameter 'mode' not understood, expected '1D' or '2D' but got {self.dimensionality}"
866
+ )
867
+ return result
868
+
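# [Editor's note — illustrative, not part of the released file.]
# For dimensionality="2D" the coefficients follow numpy's polyval2d convention,
# wavelength = sum_ij c[i, j] * pixel**i * order**j, so a degree=(6, 6) fit gives a
# (7, 7) coefficient array. For "1D" there is one coefficient vector per order,
# evaluated highest power first by np.polyval.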
869
+ def make_wave(self, wave_solution, plot=False):
870
+ """Expand polynomial wavelength solution into full image
871
+
872
+ Parameters
873
+ ----------
874
+ wave_solution : array of shape(degree,)
875
+ polynomial coefficients of wavelength solution
876
+ plot : bool, optional
877
+ whether to plot the solution (default: False)
878
+
879
+ Returns
880
+ -------
881
+ wave_img : array of shape (nord, ncol)
882
+ wavelength solution for each point in the spectrum
883
+ """
884
+
885
+ y, x = np.indices((self.nord, self.ncol))
886
+ wave_img = self.evaluate_solution(x, y, wave_solution)
887
+
888
+ return wave_img
889
+
890
+ def auto_id(self, obs, wave_img, lines):
891
+ """Automatically identify peaks that are close to known lines
892
+
893
+ Parameters
894
+ ----------
895
+ obs : array of shape (nord, ncol)
896
+ observed spectrum
897
+ wave_img : array of shape (nord, ncol)
898
+ wavelength solution image
899
+ lines : struct_array
900
+ line data
901
+ threshold : int, optional
902
+ difference threshold between line positions in m/s, below which a line is considered identified
903
+ plot : bool, optional
904
+ whether to plot the new lines
905
+
906
+ Returns
907
+ -------
908
+ lines : struct_array
909
+ line data with new flags
910
+ """
911
+
912
+ new_lines = []
913
+ if self.atlas is not None:
914
+ # For each order, find the corresponding section in the Atlas
915
+ # Look for strong lines in the atlas and the spectrum that match in position
916
+ # Add new lines to the linelist
917
+ width_of_atlas_peaks = 3
918
+ for order in range(obs.shape[0]):
919
+ mask = ~np.ma.getmask(obs[order])
920
+ index_mask = np.arange(len(mask))[mask]
921
+ data_obs = obs[order, mask]
922
+ wave_obs = wave_img[order, mask]
923
+
924
+ threshold_of_peak_closeness = (
925
+ np.diff(wave_obs) / wave_obs[:-1] * speed_of_light
926
+ )
927
+ threshold_of_peak_closeness = np.max(threshold_of_peak_closeness)
928
+
929
+ wmin, wmax = wave_obs[0], wave_obs[-1]
930
+ imin, imax = np.searchsorted(self.atlas.wave, (wmin, wmax))
931
+ wave_atlas = self.atlas.wave[imin:imax]
932
+ data_atlas = self.atlas.flux[imin:imax]
933
+ if len(data_atlas) == 0:
934
+ continue
935
+ data_atlas = data_atlas / data_atlas.max()
936
+
937
+ line = lines[
938
+ (lines["order"] == order)
939
+ & (lines["wll"] > wmin)
940
+ & (lines["wll"] < wmax)
941
+ ]
942
+
943
+ peaks_atlas, peak_info_atlas = signal.find_peaks(
944
+ data_atlas, height=0.01, width=width_of_atlas_peaks
945
+ )
946
+ peaks_obs, peak_info_obs = signal.find_peaks(
947
+ data_obs, height=0.01, width=0
948
+ )
949
+
950
+ for _, p in enumerate(peaks_atlas):
951
+ # Look for an existing line in the vicinity
952
+ wpeak = wave_atlas[p]
953
+ diff = np.abs(line["wll"] - wpeak) / wpeak * speed_of_light
954
+ if np.any(diff < threshold_of_peak_closeness):
955
+ # Line already in the linelist, ignore
956
+ continue
957
+ else:
958
+ # Look for matching peak in observation
959
+ diff = (
960
+ np.abs(wpeak - wave_obs[peaks_obs]) / wpeak * speed_of_light
961
+ )
962
+ imin = np.argmin(diff)
963
+
964
+ if diff[imin] < threshold_of_peak_closeness:
965
+ # Add line to linelist
966
+ # Location on the detector
967
+ # Include the masked areas!!!
968
+ ipeak = peaks_obs[imin]
969
+ ipeak = index_mask[ipeak]
970
+
971
+ # relative height of the peak
972
+ hpeak = data_obs[peaks_obs[imin]]
973
+ wipeak = peak_info_obs["widths"][imin]
974
+ # wave, order, pos, width, height, flag
975
+ new_lines.append([wpeak, order, ipeak, wipeak, hpeak, True])
976
+
977
+ # Add new lines to the linelist
978
+ if len(new_lines) != 0:
979
+ new_lines = np.array(new_lines).T
980
+ new_lines = LineList.from_list(*new_lines)
981
+ new_lines = self.fit_lines(obs, new_lines)
982
+ lines.append(new_lines)
983
+
984
+ # Option 1:
985
+ # Step 1: Loop over unused lines in lines
986
+ # Step 2: find peaks in neighbourhood
987
+ # Step 3: Toggle flag on if close
988
+ counter = 0
989
+ for i, line in enumerate(lines):
990
+ if line["flag"]:
991
+ # Line is already in use
992
+ continue
993
+ if line["order"] < 0 or line["order"] >= self.nord:
994
+ # Line outside order range
995
+ continue
996
+ iord = int(line["order"])
997
+ if line["wll"] < wave_img[iord][0] or line["wll"] >= wave_img[iord][-1]:
998
+ # Line outside pixel range
999
+ continue
1000
+
1001
+ wl = line["wll"]
1002
+ width = line["width"] * 5
1003
+ wave = wave_img[iord]
1004
+ order_obs = obs[iord]
1005
+ # Find where the line should be
1006
+ try:
1007
+ idx = np.digitize(wl, wave)
1008
+ except ValueError:
1009
+ # Wavelength solution is not monotonic
1010
+ idx = np.where(wave >= wl)[0][0]
1011
+
1012
+ low = int(idx - width)
1013
+ low = max(low, 0)
1014
+ high = int(idx + width)
1015
+ high = min(high, len(order_obs))
1016
+
1017
+ vec = order_obs[low:high]
1018
+ if np.all(np.ma.getmaskarray(vec)):
1019
+ continue
1020
+ # Find the best fitting peak
1021
+ # TODO use gaussian fit?
1022
+ peak_idx, _ = signal.find_peaks(vec, height=np.ma.median(vec), width=3)
1023
+ if len(peak_idx) > 0:
1024
+ peak_pos = np.copy(peak_idx).astype(float)
1025
+ for j in range(len(peak_idx)):
1026
+ try:
1027
+ coef = self._fit_single_line(vec, peak_idx[j], line["width"])
1028
+ peak_pos[j] = coef[1]
1029
+ except:
1030
+ peak_pos[j] = np.nan
1031
+ pass
1032
+
1033
+ pos_wave = np.interp(peak_pos, np.arange(high - low), wave[low:high])
1034
+ residual = np.abs(wl - pos_wave) / wl * speed_of_light
1035
+ idx = np.argmin(residual)
1036
+ if residual[idx] < self.threshold:
1037
+ counter += 1
1038
+ lines["flag"][i] = True
1039
+ lines["posm"][i] = low + peak_pos[idx]
1040
+
1041
+ logger.info("AutoID identified %i new lines", counter + len(new_lines))
1042
+
1043
+ return lines
1044
+
1045
+ def calculate_residual(self, wave_solution, lines):
1046
+ """
1047
+ Calculate all residuals of all given lines
1048
+
1049
+ Residual = (Wavelength Solution - Expected Wavelength) / Expected Wavelength * speed of light
1050
+
1051
+ Parameters
1052
+ ----------
1053
+ wave_solution : array of shape (degree_x, degree_y)
1054
+ polynomial coefficients of the wavelength solution (in numpy format)
1055
+ lines : recarray of shape (nlines,)
1056
+ contains the position of the line on the detector (posm), the order (order), and the expected wavelength (wll)
1057
+
1058
+ Returns
1059
+ -------
1060
+ residual : array of shape (nlines,)
1061
+ Residual of each line in m/s
1062
+ """
1063
+ x = lines["posm"]
1064
+ y = lines["order"]
1065
+ mask = ~lines["flag"]
1066
+
1067
+ solution = self.evaluate_solution(x, y, wave_solution)
1068
+
1069
+ residual = (solution - lines["wll"]) / lines["wll"] * speed_of_light
1070
+ residual = np.ma.masked_array(residual, mask=mask)
1071
+ return residual
1072
+
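# [Editor's note — illustrative arithmetic, not part of the released file.]
# The residual is a velocity, (lambda_fit - lambda_ref) / lambda_ref * c. For example, an
# error of 0.01 Angstrom at 5000 Angstrom corresponds to 0.01 / 5000 * 3e8 m/s ~ 600 m/s,
# well above the default rejection threshold of 100 m/s.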
1073
+ def reject_outlier(self, residual, lines):
1074
+ """
1075
+ Reject the strongest outlier
1076
+
1077
+ Parameters
1078
+ ----------
1079
+ residual : array of shape (nlines,)
1080
+ residuals of all lines
1081
+ lines : recarray of shape (nlines,)
1082
+ line data
1083
+
1084
+ Returns
1085
+ -------
1086
+ lines : struct_array
1087
+ line data with one more flagged line
1088
+ residual : array of shape (nlines,)
1089
+ residuals of each line, with outliers masked (including the new one)
1090
+ """
1091
+
1092
+ # Strongest outlier
1093
+ ibad = np.ma.argmax(np.abs(residual))
1094
+ lines["flag"][ibad] = False
1095
+
1096
+ return lines
1097
+
1098
+ def reject_lines(self, lines, plot=False):
1099
+ """
1100
+ Reject the largest outliers one by one until all residuals are lower than the threshold
1101
+
1102
+ Parameters
1103
+ ----------
1104
+ lines : recarray of shape (nlines,)
1105
+ Line data with pixel position, and expected wavelength
1106
+ threshold : float, optional
1107
+ upper limit for the residual, by default 100
1108
+ degree : tuple, optional
1109
+ polynomial degree of the wavelength solution (pixel, column) (default: (6, 6))
1110
+ plot : bool, optional
1111
+ Whether to plot the results (default: False)
1112
+
1113
+ Returns
1114
+ -------
1115
+ lines : recarray of shape (nlines,)
1116
+ Line data with updated flags
1117
+ """
1118
+
1119
+ wave_solution = self.build_2d_solution(lines)
1120
+ residual = self.calculate_residual(wave_solution, lines)
1121
+ nbad = 0
1122
+ while np.ma.any(np.abs(residual) > self.threshold):
1123
+ lines = self.reject_outlier(residual, lines)
1124
+ wave_solution = self.build_2d_solution(lines)
1125
+ residual = self.calculate_residual(wave_solution, lines)
1126
+ nbad += 1
1127
+ logger.info("Discarding %i lines", nbad)
1128
+
1129
+ if plot or self.plot >= 2: # pragma: no cover
1130
+ mask = lines["flag"]
1131
+ _, axis = plt.subplots()
1132
+ axis.plot(lines["order"][mask], residual[mask], "X", label="Accepted Lines")
1133
+ axis.plot(
1134
+ lines["order"][~mask], residual[~mask], "D", label="Rejected Lines"
1135
+ )
1136
+ axis.set_xlabel("Order")
1137
+ axis.set_ylabel("Residual [m/s]")
1138
+ axis.set_title("Residuals versus order")
1139
+ axis.legend()
1140
+
1141
+ fig, ax = plt.subplots(
1142
+ nrows=self.nord // 2, ncols=2, sharex=True, squeeze=False
1143
+ )
1144
+ plt.subplots_adjust(hspace=0)
1145
+ fig.suptitle("Residuals of each order versus image columns")
1146
+
1147
+ for iord in range(self.nord):
1148
+ order_lines = lines[lines["order"] == iord]
1149
+ solution = self.evaluate_solution(
1150
+ order_lines["posm"], order_lines["order"], wave_solution
1151
+ )
1152
+ # Residual in m/s
1153
+ residual = (
1154
+ (solution - order_lines["wll"])
1155
+ / order_lines["wll"]
1156
+ * speed_of_light
1157
+ )
1158
+ mask = order_lines["flag"]
1159
+ ax[iord // 2, iord % 2].plot(
1160
+ order_lines["posm"][mask],
1161
+ residual[mask],
1162
+ "X",
1163
+ label="Accepted Lines",
1164
+ )
1165
+ ax[iord // 2, iord % 2].plot(
1166
+ order_lines["posm"][~mask],
1167
+ residual[~mask],
1168
+ "D",
1169
+ label="Rejected Lines",
1170
+ )
1171
+ # ax[iord // 2, iord % 2].tick_params(labelleft=False)
1172
+ ax[iord // 2, iord % 2].set_ylim(
1173
+ -self.threshold * 1.5, +self.threshold * 1.5
1174
+ )
1175
+
1176
+ ax[-1, 0].set_xlabel("x [pixel]")
1177
+ ax[-1, 1].set_xlabel("x [pixel]")
1178
+
1179
+ ax[0, 0].legend()
1180
+
1181
+ plt.show()
1182
+ return lines
1183
+
1184
+ def plot_results(self, wave_img, obs):
1185
+ plt.subplot(211)
1186
+ title = "Wavelength solution with Wavelength calibration spectrum\nOrders are in different colours"
1187
+ if self.plot_title is not None:
1188
+ title = f"{self.plot_title}\n{title}"
1189
+ plt.title(title)
1190
+ plt.xlabel("Wavelength")
1191
+ plt.ylabel("Observed spectrum")
1192
+ for i in range(self.nord):
1193
+ plt.plot(wave_img[i], obs[i], label="Order %i" % i)
1194
+
1195
+ plt.subplot(212)
1196
+ plt.title("2D Wavelength solution")
1197
+ plt.imshow(
1198
+ wave_img, aspect="auto", origin="lower", extent=(0, self.ncol, 0, self.nord)
1199
+ )
1200
+ cbar = plt.colorbar()
1201
+ plt.xlabel("Column")
1202
+ plt.ylabel("Order")
1203
+ cbar.set_label("Wavelength [Å]")
1204
+ plt.show()
1205
+
1206
+ def plot_residuals(self, lines, coef, title="Residuals"):
1207
+ orders = np.unique(lines["order"])
1208
+ norders = len(orders)
1209
+ if self.plot_title is not None:
1210
+ title = f"{self.plot_title}\n{title}"
1211
+ plt.suptitle(title)
1212
+ nplots = int(np.ceil(norders / 2))
1213
+ for i, order in enumerate(orders):
1214
+ plt.subplot(nplots, 2, i + 1)
1215
+ order_lines = lines[lines["order"] == order]
1216
+ if len(order_lines) > 0:
1217
+ residual = self.calculate_residual(coef, order_lines)
1218
+ plt.plot(order_lines["posm"], residual, "rX")
1219
+ plt.hlines([0], 0, self.ncol)
1220
+
1221
+ plt.xlim(0, self.ncol)
1222
+ plt.ylim(-self.threshold, self.threshold)
1223
+
1224
+ if (i + 1) not in [norders, norders - 1]:
1225
+ plt.xticks([])
1226
+ else:
1227
+ plt.xlabel("x [Pixel]")
1228
+
1229
+ if (i + 1) % 2 == 0:
1230
+ plt.yticks([])
1231
+ # else:
1232
+ # plt.yticks([-self.threshold, 0, self.threshold])
1233
+
1234
+ plt.subplots_adjust(hspace=0, wspace=0.1)
1235
+
1236
+ # order = 0
1237
+ # order_lines = lines[lines["order"] == order]
1238
+ # if len(order_lines) > 0:
1239
+ # residual = self.calculate_residual(coef, order_lines)
1240
+ # plt.plot(order_lines["posm"], residual, "rX")
1241
+ # plt.hlines([0], 0, self.ncol)
1242
+ # plt.xlim(0, self.ncol)
1243
+ # plt.ylim(-self.threshold, self.threshold)
1244
+ # plt.xlabel("x [Pixel]")
1245
+ # plt.ylabel("Residual [m/s]")
1246
+
1247
+ plt.show()
1248
+
1249
+ def _find_peaks(self, comb):
1250
+ # Find peaks in the comb spectrum
1251
+ # Run find_peaks twice
1252
+ # once to find the average distance between peaks
1253
+ # once for real (disregarding close peaks)
1254
+ c = comb - np.ma.min(comb)
1255
+ width = self.lfc_peak_width
1256
+ height = np.ma.median(c)
1257
+ peaks, _ = signal.find_peaks(c, height=height, width=width)
1258
+ distance = np.median(np.diff(peaks)) // 4
1259
+ peaks, _ = signal.find_peaks(c, height=height, distance=distance, width=width)
1260
+
1261
+ # Fit peaks with gaussian to get accurate position
1262
+ new_peaks = peaks.astype(float)
1263
+ width = np.mean(np.diff(peaks)) // 2
1264
+ for j, p in enumerate(peaks):
1265
+ idx = p + np.arange(-width, width + 1, 1)
1266
+ idx = np.clip(idx, 0, len(c) - 1).astype(int)
1267
+ try:
1268
+ coef = util.gaussfit3(np.arange(len(idx)), c[idx])
1269
+ new_peaks[j] = coef[1] + p - width
1270
+ except RuntimeError:
1271
+ new_peaks[j] = p
1272
+
1273
+ n = np.arange(len(peaks))
1274
+
1275
+ # keep peaks within the range
1276
+ mask = (new_peaks > 0) & (new_peaks < len(c))
1277
+ n, new_peaks = n[mask], new_peaks[mask]
1278
+
1279
+ return n, new_peaks
1280
+
1281
+ def calculate_AIC(self, lines, wave_solution):
1282
+ if self.step_mode:
1283
+ if self.dimensionality == "1D":
1284
+ k = 1
1285
+ for _, v in wave_solution.items():
1286
+ k += np.size(v[0])
1287
+ k += np.size(v[1])
1288
+ elif self.dimensionality == "2D":
1289
+ k = 1
1290
+ poly_coef, steps_coef = wave_solution
1291
+ for _, v in steps_coef.items():
1292
+ k += np.size(v)
1293
+ k += np.size(poly_coef)
1294
+ else:
1295
+ k = np.size(wave_solution) + 1
1296
+
1297
+ # We get the residuals in velocity space
1298
+ # but need to remove the speed of light component, to get dimensionless parameters
1299
+ x = lines["posm"]
1300
+ y = lines["order"]
1301
+ ~lines["flag"]
1302
+ solution = self.evaluate_solution(x, y, wave_solution)
1303
+ rss = (solution - lines["wll"]) / lines["wll"]
1304
+
1305
+ # rss = self.calculate_residual(wave_solution, lines)
1306
+ # rss /= speed_of_light
1307
+ n = rss.size
1308
+ rss = np.ma.sum(rss**2)
1309
+
1310
+ # As per Wikipedia https://en.wikipedia.org/wiki/Akaike_information_criterion
1311
+ logl = np.log(rss)
1312
+ aic = 2 * k + n * logl
1313
+ self.logl = logl
1314
+ self.aicc = aic + (2 * k**2 + 2 * k) / (n - k - 1)
1315
+ self.aic = aic
1316
+ return aic
1317
+
1318
+ def execute(self, obs, lines):
1319
+ """
1320
+ Perform the whole wavelength calibration procedure with the current settings
1321
+
1322
+ Parameters
1323
+ ----------
1324
+ obs : array of shape (nord, ncol)
1325
+ observed image
1326
+ lines : recarray of shape (nlines,)
1327
+ reference linelist
1328
+
1329
+ Returns
1330
+ -------
1331
+ wave_img : array of shape (nord, ncol)
1332
+ Wavelength solution for each pixel
1333
+
1334
+ Raises
1335
+ ------
1336
+ NotImplementedError
1337
+ If the polarimetry flag is set
1338
+ """
1339
+
1340
+ if self.polarim:
1341
+ raise NotImplementedError("polarized orders not implemented yet")
1342
+
1343
+ self.nord, self.ncol = obs.shape
1344
+ lines = LineList(lines)
1345
+ if self.element is not None:
1346
+ try:
1347
+ self.atlas = LineAtlas(self.element, self.medium)
1348
+ except FileNotFoundError:
1349
+ logger.warning("No Atlas file found for element %s", self.element)
1350
+ self.atlas = None
1351
+ except:
1352
+ self.atlas = None
1353
+ else:
1354
+ self.atlas = None
1355
+
1356
+ obs, lines = self.normalize(obs, lines)
1357
+ # Step 1: align obs and reference
1358
+ lines = self.align(obs, lines)
1359
+
1360
+ # Keep original positions for reference
1361
+ lines["posc"] = np.copy(lines["posm"])
1362
+
1363
+ # Step 2: Locate the lines on the detector, and update the pixel position
1364
+ # lines["flag"] = True
1365
+ lines = self.fit_lines(obs, lines)
1366
+
1367
+ for i in range(self.iterations):
1368
+ logger.info(f"Wavelength calibration iteration: {i}")
1369
+ # Step 3: Create a wavelength solution on known lines
1370
+ wave_solution = self.build_2d_solution(lines)
1371
+ wave_img = self.make_wave(wave_solution)
1372
+ # Step 4: Identify lines that fit into the solution
1373
+ lines = self.auto_id(obs, wave_img, lines)
1374
+ # Step 5: Reject outliers
1375
+ lines = self.reject_lines(lines)
1376
+ # lines = self.reject_lines(lines)
1377
+
1378
+ logger.info(
1379
+ "Number of lines used for wavelength calibration: %i",
1380
+ np.count_nonzero(lines["flag"]),
1381
+ )
1382
+
1383
+ # Step 6: build final 2d solution
1384
+ wave_solution = self.build_2d_solution(lines, plot=self.plot)
1385
+ wave_img = self.make_wave(wave_solution)
1386
+
1387
+ if self.plot:
1388
+ self.plot_results(wave_img, obs)
1389
+
1390
+ aic = self.calculate_AIC(lines, wave_solution)
1391
+ logger.info("AIC of wavelength fit: %f", aic)
1392
+
1393
+ # np.savez("cs_lines.npz", cs_lines=lines.data)
1394
+
1395
+ return wave_img, wave_solution, lines
1396
+
1397
+
1398
+ class WavelengthCalibrationComb(WavelengthCalibration):
1399
+ def execute(self, comb, wave, lines=None):
1400
+ self.nord, self.ncol = comb.shape
1401
+
1402
+ # TODO give everything better names
1403
+ pixel, order, wavelengths = [], [], []
1404
+ n_all, f_all = [], []
1405
+ comb = np.ma.masked_array(comb, mask=comb <= 0)
1406
+
1407
+ for i in range(self.nord):
1408
+ # Find Peak positions in current order
1409
+ n, peaks = self._find_peaks(comb[i])
1410
+
1411
+ # Determine the n-offset of this order, relative to the anchor frequency
1412
+ # Use the existing absolute wavelength calibration as reference
1413
+ y_ord = np.full(len(peaks), i)
1414
+ w_old = interp1d(np.arange(len(wave[i])), wave[i], kind="cubic")(peaks)
1415
+ f_old = speed_of_light / w_old
1416
+
1417
+ # fr: repeating frequency
1418
+ # fd: anchor frequency of this order, needs to be shifted to the absolute reference frame
1419
+ fr = np.median(np.diff(f_old))
1420
+ fd = np.median(f_old % fr)
1421
+ n_raw = (f_old - fd) / fr
1422
+ n = np.round(n_raw)
1423
+
1424
+ if np.any(np.abs(n_raw - n) > 0.3):
1425
+ logger.warning(
1426
+ "Bad peaks detected in the frequency comb in order %i", i
1427
+ )
1428
+
1429
+ fr, fd = polyfit(n, f_old, deg=1)
1430
+
1431
+ n_offset = 0
1432
+ # The first order is used as the baseline for all other orders
1433
+ # The choice is arbitrary and doesn't matter
1434
+ if i == 0:
1435
+ f0 = fd
1436
+ n_offset = 0
1437
+ else:
1438
+ # n0: shift in n, relative to the absolute reference
1439
+ # shift n to the absolute grid, so that all peaks are given by the same f0
1440
+ n_offset = (f0 - fd) / fr
1441
+ n_offset = int(round(n_offset))
1442
+ n -= n_offset
1443
+ fd += n_offset * fr
1444
+
1445
+ n = np.abs(n)
1446
+
1447
+ n_all += [n]
1448
+ f_all += [f_old]
1449
+ pixel += [peaks]
1450
+ order += [y_ord]
1451
+
1452
+ logger.debug(
1453
+ "LFC Order: %i, f0: %.3f, fr: %.5f, n0: %.2f", i, fd, fr, n_offset
1454
+ )
1455
+
1456
+ # Here we postulate that m * lambda = const
1457
+ # where m is the peak number
1458
+ # this is the result of the grating equation
1459
+ # at least const is roughly constant for neighbouring peaks
1460
+ correct = True
1461
+ if correct:
1462
+ w_all = [speed_of_light / f for f in f_all]
1463
+ mw_all = [m * w for m, w in zip(n_all, w_all, strict=False)]
1464
+ y = np.concatenate(mw_all)
1465
+ gap = np.median(y)
1466
+
1467
+ corr = np.zeros(self.nord)
1468
+ for i in range(self.nord):
1469
+ corri = gap / w_all[i] - n_all[i]
1470
+ corri = np.median(corri)
1471
+ corr[i] = np.round(corri)
1472
+ n_all[i] += corr[i]
1473
+
1474
+ logger.debug("LFC order offset correction: %s", corr)
1475
+
1476
+ for i in range(self.nord):
1477
+ coef = polyfit(n_all[i], n_all[i] * w_all[i], deg=5)
1478
+ mw = np.polyval(coef, n_all[i])
1479
+ w_all[i] = mw / n_all[i]
1480
+ f_all[i] = speed_of_light / w_all[i]
1481
+
1482
+ # Merge Data
1483
+ n_all = np.concatenate(n_all)
1484
+ f_all = np.concatenate(f_all)
1485
+ pixel = np.concatenate(pixel)
1486
+ order = np.concatenate(order)
1487
+
1488
+         # Fit f0 and fr to all data
+         # (fr, f0), cov = np.polyfit(n_all, f_all, deg=1, cov=True)
+         fr, f0 = polyfit(n_all, f_all, deg=1)
+
+         logger.debug("Laser Frequency Comb Anchor Frequency: %.3f 10**10 Hz", f0)
+         logger.debug("Laser Frequency Comb Repeating Frequency: %.5f 10**10 Hz", fr)
+
+         # All peaks are then given by f0 + n * fr
+         wavelengths = speed_of_light / (f0 + n_all * fr)
+
+         flag = np.full(len(wavelengths), True)
+         laser_lines = np.rec.fromarrays(
+             (wavelengths, pixel, pixel, order, flag),
+             names=("wll", "posm", "posc", "order", "flag"),
+         )
+
+         # Now use the improved resolution to find the new solution
+         # A single pass of discarding outliers should be enough
+         coef = self.build_2d_solution(laser_lines)
+         # resid = self.calculate_residual(coef, laser_lines)
+         # laser_lines["flag"] = np.abs(resid) < self.threshold
+         # coef = self.build_2d_solution(laser_lines)
+         new_wave = self.make_wave(coef)
+
+         self.calculate_AIC(laser_lines, coef)
+
+         self.n_lines_good = np.count_nonzero(laser_lines["flag"])
+         logger.info(
+             f"Laser Frequency Comb solution based on {self.n_lines_good} lines."
+         )
+         if self.plot:
+             residual = wave - new_wave
+             residual = residual.ravel()
+
+             area = np.percentile(residual, (32, 50, 68))
+             area = area[0] - 5 * (area[1] - area[0]), area[0] + 5 * (area[2] - area[1])
+             plt.hist(residual, bins=100, range=area)
+             title = "ThAr - LFC"
+             if self.plot_title is not None:
+                 title = f"{self.plot_title}\n{title}"
+             plt.title(title)
+             plt.xlabel(r"$\Delta\lambda$ [Å]")
+             plt.ylabel("N")
+             plt.show()
+
+         if self.plot:
+             if lines is not None:
+                 self.plot_residuals(
+                     lines,
+                     coef,
+                     title="GasLamp Line Residuals in the Laser Frequency Comb Solution",
+                 )
+             self.plot_residuals(
+                 laser_lines,
+                 coef,
+                 title="Laser Frequency Comb Peak Residuals in the LFC Solution",
+             )
+
+         if self.plot:
+             wave_img = wave
+             title = "Difference between GasLamp Solution and Laser Frequency Comb solution\nEach plot shows one order"
+             if self.plot_title is not None:
+                 title = f"{self.plot_title}\n{title}"
+             plt.suptitle(title)
+             for i in range(len(new_wave)):
+                 plt.subplot(len(new_wave) // 4 + 1, 4, i + 1)
+                 plt.plot(wave_img[i] - new_wave[i])
+             plt.show()
+
+         if self.plot:
+             self.plot_results(new_wave, comb)
+
+         return new_wave
+
+
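The comb bookkeeping above reduces to the model f_n = f0 + n * fr: every peak sits at the anchor frequency plus an integer multiple of the repetition frequency, and the per-order offsets exist only because each order initially counts its modes from its own local anchor fd. A small self-contained sketch of that mode-number recovery, using synthetic numbers rather than instrument values:

    import numpy as np

    # Synthetic comb: anchor and repetition frequency (arbitrary units), true mode numbers
    f0_true, fr_true = 350_000.0, 18.0
    n_true = np.arange(100, 160)
    f_obs = f0_true + n_true * fr_true      # peak frequencies predicted by the comb model

    fr = np.median(np.diff(f_obs))          # repetition frequency from the peak spacing
    fd = np.median(f_obs % fr)              # order-local anchor, somewhere within one spacing
    n = np.round((f_obs - fd) / fr)         # integer mode numbers counted from fd
    # A linear fit of f_obs against n then refines (fr, fd), as polyfit(n, f_old, deg=1) does above;
    # shifting every order onto the f0 of the first order is what the n_offset step achieves.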
+ class WavelengthCalibrationInitialize(WavelengthCalibration):
+     def __init__(
+         self,
+         degree=2,
+         plot=False,
+         plot_title="Wavecal Initial",
+         wave_delta=20,
+         nwalkers=100,
+         steps=50_000,
+         resid_delta=1000,
+         cutoff=5,
+         smoothing=0,
+         element="thar",
+         medium="vac",
+     ):
+         super().__init__(
+             degree=degree,
+             element=element,
+             medium=medium,
+             plot=plot,
+             plot_title=plot_title,
+             dimensionality="1D",
+         )
+         #:float: wavelength uncertainty on the initial guess in Angstrom
+         self.wave_delta = wave_delta
+         #:int: number of walkers in the MCMC
+         self.nwalkers = nwalkers
+         #:int: number of steps in the MCMC
+         self.steps = steps
+         #:float: residual uncertainty allowed when matching observation with known lines
+         self.resid_delta = resid_delta
+         #:float: gaussian smoothing applied to the wavecal spectrum before the MCMC, in pixels; set to 0 to disable
+         self.smoothing = smoothing
+         #:float: minimum value in the spectrum to count as a spectral line; values of 1 or more are interpreted as a percentile of the spectrum
+         self.cutoff = cutoff
+
+     def get_cutoff(self, spectrum):
+         if self.cutoff == 0:
+             cutoff = None
+         elif self.cutoff < 1:
+             cutoff = self.cutoff
+         else:
+             cutoff = np.nanpercentile(spectrum[spectrum != 0], self.cutoff)
+         return cutoff
+
+     def normalize(self, spectrum):
+         smoothing = self.smoothing
+         spectrum = np.copy(spectrum)
+         spectrum -= np.nanmedian(spectrum)
+         if smoothing != 0:
+             spectrum = gaussian_filter1d(spectrum, smoothing)
+         spectrum[spectrum < 0] = 0
+         spectrum /= np.max(spectrum)
+         return spectrum
+
+     def determine_wavelength_coefficients(
+         self,
+         spectrum,
+         atlas,
+         wave_range,
+     ) -> np.ndarray:
+         """
+         Determine the wavelength polynomial coefficients of a spectrum,
+         based on a line atlas with known spectral lines
+         and an initial guess for the wavelength range.
+         The calculation uses an MCMC approach to sample the probability space and
+         find the best cross correlation value between observation and atlas.
+
+         Parameters
+         ----------
+         spectrum : array
+             observed spectrum at each pixel
+         atlas : LineAtlas
+             atlas containing a known spectrum with wavelength and flux
+         wave_range : 2-tuple
+             initial wavelength guess (begin, end)
+         degrees : int, optional
+             number of degrees of the wavelength polynomial,
+             lower numbers yield better results, by default 2
+         w_range : float, optional
+             uncertainty on the initial wavelength guess in Angstrom, by default 20
+         nwalkers : int, optional
+             number of walkers for the MCMC, more is better but increases
+             the time, by default 100
+         steps : int, optional
+             number of steps in the MCMC per walker, more is better but increases
+             the time, by default 20_000
+         plot : bool, optional
+             whether to plot the results or not, by default False
+
+         Returns
+         -------
+         coef : array
+             polynomial coefficients in numpy order
+         """
+         spectrum = np.asarray(spectrum)
+
+         assert self.degree >= 2, "The polynomial degree must be at least 2"
+         assert spectrum.ndim == 1, "The spectrum should only have 1 dimension"
+         assert self.wave_delta > 0, "The wavelength uncertainty needs to be positive"
+
+         n_features = spectrum.shape[0]
+         n_output = ndim = self.degree + 1
+
+         # Normalize the spectrum, and copy it just in case
+         spectrum = self.normalize(spectrum)
+         cutoff = self.get_cutoff(spectrum)
+
+         # The pixel scale used for everything else
+         x = np.arange(n_features)
+         # Initial guess for the wavelength solution
+         coef = np.zeros(n_output)
+         coef[-1] = wave_range[0]
+         coef[-2] = (wave_range[-1] - wave_range[0]) / n_features
+
+         # We scale every coefficient to roughly order 1
+         # this is then in units of the maximum offset due to a change in this value
+         # in angstrom
+         w_scale = 1 / np.power(n_features, range(n_output))
+         factors = w_scale[::-1]
+         coef /= factors
+
+         # Here we define the functions we need for the MCMC
+         def polyval_vectorize(p, x, where=None):
+             n_poly, n_coef = p.shape
+             n_points = x.shape[0]
+             y = np.zeros((n_poly, n_points))
+             if where is not None:
+                 for i in range(n_coef):
+                     y[where] *= x
+                     y[where] += p[where, i, None]
+             else:
+                 for i in range(n_coef):
+                     y *= x
+                     y += p[:, i, None]
+             return y
+
+         def log_prior(p):
+             prior = np.zeros(p.shape[0])
+             prior[np.any(~np.isfinite(p), axis=1)] = -np.inf
+             prior[np.any(np.abs(p - coef) > self.wave_delta, axis=1)] = -np.inf
+             return prior
+
+         def log_prior_2(w):
+             # Check that w is increasing
+             prior = np.zeros(w.shape[0])
+             prior[np.any(w[:, 1:] < w[:, :-1], axis=1)] = -np.inf
+             prior[w[:, 0] < wave_range[0] - self.wave_delta] = -np.inf
+             prior[w[:, -1] > wave_range[1] + self.wave_delta] = -np.inf
+             return prior
+
+         def log_prob(p):
+             # Check that p is within bounds
+             prior = log_prior(p)
+             where = np.isfinite(prior)
+             # Calculate the wavelength scale
+             w = polyval_vectorize(p * factors, x, where=where)
+             # Check that it is monotonically increasing
+             prior += log_prior_2(w)
+             where = np.isfinite(prior)
+
+             y = np.zeros((p.shape[0], x.shape[0]))
+             y[where, :] = np.interp(w[where, :], atlas.wave, atlas.flux)
+             y[where, :] /= np.max(y[where, :], axis=1)[:, None]
+             # This is the cross correlation value squared
+             cross = np.sum(y * spectrum, axis=1) ** 2
+             # chi2 = - np.sum((y - spectrum)**2, axis=1)
+             # chi2 = - np.sum((np.where(y > 0.01, 1, 0) - np.where(spectrum > 0.01, 1, 0))**2, axis=1)
+             # the bitwise xor below counts the same mismatches as the line above, but much faster
+             if cutoff is not None:
+                 chi2 = (y > cutoff) ^ (spectrum > cutoff)
+                 chi2 = -np.count_nonzero(chi2, axis=1) / 20
+             else:
+                 chi2 = -np.sum((y - spectrum) ** 2, axis=1) / 20
+             return prior + cross + chi2
+
+         p0 = np.zeros((self.nwalkers, ndim))
+         p0 += coef[None, :]
+         p0 += np.random.uniform(
+             low=-self.wave_delta, high=self.wave_delta, size=(self.nwalkers, ndim)
+         )
+         sampler = emcee.EnsembleSampler(
+             self.nwalkers,
+             ndim,
+             log_prob,
+             vectorize=True,
+             moves=[(emcee.moves.DEMove(), 0.8), (emcee.moves.DESnookerMove(), 0.2)],
+         )
+         sampler.run_mcmc(p0, self.steps, progress=True)
+
+         tau = sampler.get_autocorr_time(quiet=True)
+         burnin = int(2 * np.max(tau))
+         thin = int(0.5 * np.min(tau))
+         samples = sampler.get_chain(discard=burnin, thin=thin, flat=True)
+
+         low, mid, high = np.percentile(samples, [32, 50, 68], axis=0)
+         coef = mid * factors
+
+         if self.plot:
+             corner.corner(samples, truths=mid)
+             plt.show()
+
+             wave = np.polyval(coef, x)
+             y = np.interp(wave, atlas.wave, atlas.flux)
+             y /= np.max(y)
+             plt.plot(wave, spectrum)
+             plt.plot(wave, y)
+             plt.show()
+
+         return coef
+
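Two details of the sampler above are easy to miss. First, the polynomial coefficients are rescaled by powers of the pixel count, so a unit step in any sampled dimension moves the wavelength solution by roughly one Angstrom and a single flat ±wave_delta prior can bound every polynomial order alike. Second, when a cutoff is set, the fit quality is scored with a bitwise XOR of the thresholded spectra, i.e. the number of pixels where exactly one of the two spectra shows a line. A minimal standalone sketch of both ideas, with synthetic arrays and an assumed cutoff value:

    import numpy as np

    # Coefficient scaling: the coefficient of x**k is effectively multiplied by n_features**k,
    # so each scaled value is that term's contribution in Angstrom near the last pixel,
    # and a unit step in any sampled dimension shifts the wavelength there by about 1 Angstrom.
    n_features = 4096
    factors = (1 / np.power(n_features, range(3)))[::-1]   # degree-2 polynomial, numpy coefficient order
    coef = np.array([1e-7, 0.05, 5000.0])                  # synthetic [quadratic, linear, constant] terms
    coef_scaled = coef / factors                           # ~[1.7, 204.8, 5000.0] Angstrom at the last pixel

    # Line-overlap penalty via bitwise XOR: count pixels where only one spectrum has a line.
    cutoff = 0.05                                          # assumed threshold for this sketch
    observed = np.array([0.0, 0.9, 0.0, 0.0, 0.7, 0.0])    # synthetic normalized wavecal spectrum
    template = np.array([0.0, 0.8, 0.0, 0.6, 0.0, 0.0])    # synthetic atlas spectrum on a trial grid
    mismatch = (template > cutoff) ^ (observed > cutoff)
    penalty = -np.count_nonzero(mismatch) / 20             # -0.1 here: two mismatched pixels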
+     def create_new_linelist_from_solution(
+         self,
+         spectrum,
+         wavelength,
+         atlas,
+         order,
+     ) -> LineList:
+         """
+         Create a new linelist based on an existing wavelength solution for a spectrum
+         and a line atlas with known lines. The linelist is in the format used by the
+         rest of the PyReduce wavelength calibration.
+
+         Observed lines are matched with the lines in the atlas to
+         improve the wavelength solution.
+
+         Parameters
+         ----------
+         spectrum : array
+             Observed spectrum at each pixel
+         wavelength : array
+             Wavelength of the spectrum at each pixel
+         atlas : LineAtlas
+             Atlas with wavelengths of known lines
+         order : int
+             Order of the spectrum within the detector
+         resid_delta : float, optional
+             Maximum residual allowed between a peak and the closest line in the atlas
+             for them to still be matched, in m/s, by default 1000.
+
+         Returns
+         -------
+         linelist : LineList
+             new linelist with lines from this order
+         """
+         # The new linelist
+         linelist = LineList()
+         spectrum = np.asarray(spectrum)
+         wavelength = np.asarray(wavelength)
+
+         assert self.resid_delta > 0, "Residuals Delta must be positive"
+         assert spectrum.ndim == 1, "Spectrum must have only 1 dimension"
+         assert wavelength.ndim == 1, "Wavelength must have only 1 dimension"
+         assert spectrum.size == wavelength.size, (
+             "Spectrum and Wavelength must have the same size"
+         )
+
+         n_features = spectrum.shape[0]
+         x = np.arange(n_features)
+
+         # Normalize just in case
+         spectrum = self.normalize(spectrum)
+         cutoff = self.get_cutoff(spectrum)
+
+         # TODO: make this use another function, and pass the height as a parameter
+         scopy = np.copy(spectrum)
+         if cutoff is not None:
+             scopy[scopy < cutoff] = 0
+         _, peaks = self._find_peaks(scopy)
+
+         peak_wave = np.interp(peaks, x, wavelength)
+         peak_height = np.interp(peaks, x, spectrum)
+
+         # Here we only look at the lines within range
+         atlas_linelist = atlas.linelist[
+             (atlas.linelist["wave"] > wavelength[0])
+             & (atlas.linelist["wave"] < wavelength[-1])
+         ]
+
+         residuals = np.zeros_like(peak_wave)
+         for i, pw in enumerate(peak_wave):
+             resid = np.abs(pw - atlas_linelist["wave"])
+             j = np.argmin(resid)
+             residuals[i] = resid[j] / pw * speed_of_light
+             if residuals[i] < self.resid_delta:
+                 linelist.add_line(
+                     atlas_linelist["wave"][j],
+                     order,
+                     peaks[i],
+                     3,
+                     peak_height[i],
+                     True,
+                 )
+
+         return linelist
+
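The acceptance test in the loop above converts the wavelength mismatch between a detected peak and its nearest atlas line into a velocity and keeps the match only if it stays below resid_delta (in m/s). A short worked instance of that criterion, with illustrative wavelengths and scipy's speed of light in m/s, which is what the comparison against a 1000 m/s default implies:

    from scipy.constants import speed_of_light  # ~2.998e8 m/s

    peak_wavelength = 5500.000    # Angstrom, synthetic detected peak
    atlas_wavelength = 5500.010   # Angstrom, nearest catalogued line

    residual = abs(peak_wavelength - atlas_wavelength) / peak_wavelength * speed_of_light
    # ~545 m/s, below a resid_delta of 1000 m/s, so this peak would be added to the linelist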
+     def execute(self, spectrum, wave_range) -> LineList:
+         atlas = LineAtlas(self.element, self.medium)
+         linelist = LineList()
+         orders = range(spectrum.shape[0])
+         x = np.arange(spectrum.shape[1])
+         for order in orders:
+             spec = spectrum[order]
+             wrange = wave_range[order]
+             coef = self.determine_wavelength_coefficients(spec, atlas, wrange)
+             wave = np.polyval(coef, x)
+             linelist_loc = self.create_new_linelist_from_solution(
+                 spec, wave, atlas, order
+             )
+             linelist.append(linelist_loc)
+         return linelist
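Putting the pieces together, a hedged usage sketch of this bootstrap step. The array shapes follow from the code above (spectrum: orders × columns, wave_range: one (begin, end) guess in Angstrom per order), and the constructor values simply repeat the defaults shown in __init__ rather than a tuned recommendation:

    # Hedged sketch, not part of the package source
    module = WavelengthCalibrationInitialize(
        degree=2, wave_delta=20, nwalkers=100, steps=50_000,
        resid_delta=1000, cutoff=5, smoothing=0, element="thar", medium="vac",
    )
    linelist = module.execute(spectrum, wave_range)
    # The resulting LineList can then seed the regular WavelengthCalibration step.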