ppdmod 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ppdmod/data.py ADDED
@@ -0,0 +1,485 @@
1
+ import copy
2
+ from functools import partial
3
+ from pathlib import Path
4
+ from types import SimpleNamespace
5
+ from typing import Any, Dict, List, Tuple
6
+
7
+ import astropy.units as u
8
+ import numpy as np
9
+ from astropy.io import fits
10
+ from numpy.typing import NDArray
11
+ from scipy.stats import circmean, circstd
12
+
13
+ from .options import OPTIONS
14
+ from .utils import (
15
+ get_band,
16
+ get_binning_windows,
17
+ get_indices,
18
+ get_t3_indices,
19
+ )
20
+
21
+
22
def mean_and_err(key: str):
    """Gets the mean and error-combination functions for an observable.

    Both returned callables reduce along the last axis and return plain
    (non-masked) values in which fully masked results become NaN.

    Parameters
    ----------
    key : str
        The observable key. ``"t3"`` selects circular statistics on the
        interval [-180, 180] (with NaNs omitted); any other key uses
        masked-array (linear) statistics.

    Returns
    -------
    mean_func : callable
        Computes the (circular for ``"t3"``) mean along the last axis.
    err_func : callable
        Returns ``sqrt(sum(x**2) + std(x)**2)`` along the last axis, where
        the std is circular for ``"t3"``.
    """
    # NOTE(review): in this module err_func is applied to the *error* arrays,
    # so the std term measures the spread of the errors, and the quadrature
    # sum is not divided by the number of points — confirm this is intended.
    axis = -1
    t3_kwargs = {"low": -180, "high": 180, "nan_policy": "omit"}

    # ``np.ma.filled`` (unlike the ``.filled`` method) also accepts the plain
    # numpy scalars that masked reductions return for 1-D input.
    if key == "t3":
        mean_func = partial(circmean, axis=axis, **t3_kwargs)

        def err_func(x):
            quad_sum = np.ma.filled(np.ma.sum(x**2, axis=axis), np.nan)
            return np.sqrt(quad_sum + circstd(x, axis=axis, **t3_kwargs) ** 2)

    else:

        def mean_func(x):
            return np.ma.filled(np.ma.mean(x, axis=axis), np.nan)

        def err_func(x):
            combined = np.ma.sqrt(
                np.ma.sum(x**2, axis=axis) + np.ma.std(x, axis=axis) ** 2
            )
            return np.ma.filled(combined, np.nan)

    return mean_func, err_func
38
+
39
+
40
class ReadoutFits:
    """All functionality to work with (.fits) or flux files.

    Parameters
    ----------
    fits_file : pathlib.Path
        The path to the (.fits) or flux file.

    Attributes
    ----------
    fits_file : pathlib.Path
        The path to the file.
    name : str
        The file name.
    band : str
        The band determined by ``get_band`` from the wavelengths
        ("unknown" until the file has been read).
    wl : astropy.units.Quantity
        The effective wavelengths of the file, converted to micrometers.
    array : str
        "ats" if the first telescope name in ``OI_ARRAY`` contains "AT",
        otherwise "uts".
    flux, t3, vis, vis2 : types.SimpleNamespace
        The observables as returned by ``read_into_namespace``.
    """

    def __init__(self, fits_file: Path) -> None:
        """The class's constructor; reads the file right away."""
        self.fits_file = Path(fits_file)
        self.name = self.fits_file.name
        self.band = "unknown"
        self.read_file()

    def read_file(self) -> None:
        """Reads the data of the (.fits)-files into vectors.

        Sets ``wl``, ``band``, ``array``, ``flux``, ``t3``, ``vis`` and
        ``vis2``.
        """
        with fits.open(self.fits_file) as hdul:
            instrument = None
            if "instrume" in hdul[0].header:
                instrument = hdul[0].header["instrume"].lower()

            # For GRAVITY the configured index is passed as the version part
            # of the (name, version) HDU lookups below; otherwise None.
            sci_index = OPTIONS.data.gravity.index if instrument == "gravity" else None
            self.wl = (hdul["oi_wavelength", sci_index].data["eff_wave"] * u.m).to(u.um)
            self.band = get_band(self.wl)
            self.array = (
                "ats" if "AT" in hdul["oi_array"].data["tel_name"][0] else "uts"
            )
            self.flux = self.read_into_namespace(hdul, "flux", sci_index)
            self.t3 = self.read_into_namespace(hdul, "t3", sci_index)
            self.vis = self.read_into_namespace(hdul, "vis", sci_index)
            self.vis2 = self.read_into_namespace(hdul, "vis2", sci_index)

    def read_into_namespace(
        self,
        hdul: fits.HDUList,
        key: str,
        sci_index: int | None = None,
    ) -> SimpleNamespace:
        """Reads a (.fits) card into a namespace.

        Parameters
        ----------
        hdul : astropy.io.fits.HDUList
            The opened file.
        key : str
            The observable: "flux", "vis" or "vis2"; any other key is read
            as a closure phase table. The extension ``oi_<key>`` is read.
        sci_index : int, optional
            The extension version passed to the HDU lookup.

        Returns
        -------
        types.SimpleNamespace
            ``val`` and ``err`` as masked arrays (masked by the ``flag``
            column). "vis"/"vis2" additionally carry ``u``/``v`` of shape
            (1, n); closure phases carry ``u123``/``v123`` of shape (3, n)
            whose third row is the sum of the first two. Coordinates are
            rounded to two decimals. Missing extensions or flux columns
            yield empty placeholder arrays.
        """
        try:
            hdu = hdul[f"oi_{key}", sci_index]
            data = hdu.data
        except KeyError:
            return SimpleNamespace(
                val=np.array([]),
                err=np.array([]),
                u=np.array([]).reshape(1, -1),
                v=np.array([]).reshape(1, -1),
            )

        if key == "flux":
            try:
                return SimpleNamespace(
                    val=np.ma.masked_array(data["fluxdata"], mask=data["flag"]),
                    err=np.ma.masked_array(data["fluxerr"], mask=data["flag"]),
                )
            except KeyError:
                return SimpleNamespace(
                    val=np.array([]), err=np.array([]), flag=np.array([])
                )

        # TODO: Might err if vis is not included in datasets
        if key in ["vis", "vis2"]:
            if key == "vis":
                val_key, err_key = "visamp", "visamperr"
            else:
                val_key, err_key = "vis2data", "vis2err"

            ucoord = data["ucoord"].reshape(1, -1).astype(OPTIONS.data.dtype.real)
            vcoord = data["vcoord"].reshape(1, -1).astype(OPTIONS.data.dtype.real)
            return SimpleNamespace(
                val=np.ma.masked_array(data[val_key], mask=data["flag"]),
                err=np.ma.masked_array(data[err_key], mask=data["flag"]),
                u=np.round(ucoord, 2),
                v=np.round(vcoord, 2),
            )

        u1, u2 = map(lambda x: data[f"u{x}coord"], ["1", "2"])
        v1, v2 = map(lambda x: data[f"v{x}coord"], ["1", "2"])
        u123 = np.array([u1, u2, u1 + u2]).astype(OPTIONS.data.dtype.real)
        v123 = np.array([v1, v2, v1 + v2]).astype(OPTIONS.data.dtype.real)
        return SimpleNamespace(
            val=np.ma.masked_array(data["t3phi"], mask=data["flag"]),
            err=np.ma.masked_array(data["t3phierr"], mask=data["flag"]),
            u123=np.round(u123, 2),
            v123=np.round(v123, 2),
        )

    def get_data_for_wavelength(
        self,
        wavelength: u.Quantity,
        key: str,
        do_bin: bool = True,
        min_err: Dict[str, float] | None = None,
    ) -> Tuple[NDArray[Any], NDArray[Any]]:
        """Gets the data for the given wavelengths.

        If there is no data for any of the given wavelengths,
        a np.nan array is returned of the shape
        (wavelength.size, data.shape[0]).

        Parameters
        ----------
        wavelength : astropy.units.um
            The wavelengths to be returned.
        key : str
            The key (header) of the data to be returned.
        do_bin : bool, optional
            If the data should be binned or not.
        min_err : dict, optional
            A dictionary containing the minimum error for the data.
            Will be applied to the data before binning. For "t3" the value
            is an absolute floor; for all other keys it is a floor relative
            to the data values. The data stored on this instance is not
            modified.

        Returns
        -------
        numpy.typing.NDArray
            The values for the given wavelengths.
        numpy.typing.NDArray
            The errors for the given wavelengths.
        """
        if do_bin:
            # TODO: Check if binning is done correctly -> No duplicates, etc.
            # Should there be an error if the bins overlap? -> Maybe a good idea
            # Or at least a warning
            windows = get_binning_windows(wavelength)
            indices = get_indices(wavelength, array=self.wl, windows=windows)
        else:
            indices = [np.where(wl == self.wl)[0] for wl in wavelength]

        val, err = getattr(self, key).val, getattr(self, key).err
        if all(index.size == 0 for index in indices):
            # One NaN column per row of the stored data for every observable,
            # matching the width of the found-data result.
            wl_val = np.full((wavelength.size, val.shape[0]), np.nan)
            wl_err = np.full((wavelength.size, err.shape[0]), np.nan)
            return wl_val, wl_err

        if min_err is not None:
            # Floor a copy so the instance's stored errors stay untouched.
            err = err.copy()
            error_floor = min_err.get(key, 0.0)
            if key == "t3":
                err[err < error_floor] = error_floor
            else:
                ind = np.where(err < val * error_floor)
                err[ind] = val[ind] * error_floor

        def _select(array):
            # One entry per requested wavelength; NaN where it is absent.
            return [
                (
                    array[:, index].filled(np.nan).squeeze(-1)
                    if index.size != 0
                    else np.full((array.shape[0],), np.nan)
                )
                for index in indices
            ]

        mean_func, err_func = mean_and_err(key)
        if do_bin:
            wl_val = [mean_func(val[:, index]) for index in indices]
            wl_err = [err_func(err[:, index]) for index in indices]
        else:
            wl_val = _select(val)
            wl_err = _select(err)

        wl_val = np.array(wl_val, dtype=OPTIONS.data.dtype.real)
        wl_err = np.array(wl_err, dtype=OPTIONS.data.dtype.real)
        return wl_val, wl_err
214
+
215
+
216
def get_all_wavelengths(readouts: List[ReadoutFits] | None = None) -> NDArray[Any]:
    """Collects the wavelengths of several readouts into one sorted array.

    Parameters
    ----------
    readouts : list of ReadoutFits, optional
        The readouts to gather from. Defaults to ``OPTIONS.data.readouts``.

    Returns
    -------
    numpy.typing.NDArray
        The de-duplicated wavelengths of all readouts in ascending order.
    """
    if readouts is None:
        readouts = OPTIONS.data.readouts
    wavelength_arrays = [readout.wl for readout in readouts]
    merged = np.concatenate(wavelength_arrays)
    return np.sort(np.unique(merged))
220
+
221
+
222
def set_fit_wavelengths(
    wavelengths: u.Quantity[u.um] | None = None,
) -> NDArray[Any] | None:
    """Sets the wavelengths to be fitted for as a global option.

    ``OPTIONS.fit.wls`` is always reset first, so calling this without
    wavelengths only clears the fit wavelengths.

    Parameters
    ----------
    wavelengths : astropy.units.Quantity or array_like, optional
        The wavelengths to be fitted. Unitless values are interpreted as
        micrometers and quantities are converted to micrometers. Scalars
        and multi-dimensional input are flattened into a 1-D array.

    Returns
    -------
    astropy.units.Quantity or None
        The 1-D wavelengths now stored in ``OPTIONS.fit.wls``, or None if
        no wavelengths were given.
    """
    OPTIONS.fit.wls = None
    if wavelengths is None:
        return None

    # ``flatten`` also turns a 0-d (scalar) quantity into shape (1,).
    OPTIONS.fit.wls = u.Quantity(wavelengths, u.um).flatten()
    return OPTIONS.fit.wls
250
+
251
+
252
def get_counts_data() -> np.ndarray:
    """Counts the unmasked data points of every fitted observable.

    Only the first entry along the leading axis of each observable's
    ``val`` is counted (presumably the first time step — confirm).

    Returns
    -------
    numpy.ndarray
        One count per key of ``OPTIONS.fit.data``, in the same order.
    """
    return np.array(
        [
            getattr(OPTIONS.data, key).val[0].compressed().size
            for key in OPTIONS.fit.data
        ]
    )
260
+
261
+
262
def clear_data() -> List[str]:
    """Clears data and returns the keys of the cleared data.

    Resets ``OPTIONS.fit.wls`` as well as both the static and the
    time-dependent readouts. It also empties the values and errors of every
    observable, the (u, v)-coordinates of the visibilities and the
    ``u123``/``v123`` coordinates of the closure phases.

    Returns
    -------
    list of str
        The cleared observable keys: ["flux", "vis", "vis2", "t3"].
    """
    keys = ["flux", "vis", "vis2", "t3"]
    OPTIONS.fit.wls = None
    OPTIONS.data.readouts = []
    # Cleared alongside ``readouts`` so no stale time-dependent files remain.
    OPTIONS.data.readouts_t = []

    for key in keys:
        data = getattr(OPTIONS.data, key)
        data.val, data.err = np.array([]), np.array([])
        if key in ("vis", "vis2"):
            data.u = np.array([]).reshape(1, -1)
            data.v = np.array([]).reshape(1, -1)
        elif key == "t3":
            # Exact comparison; ``key in "t3"`` would be a substring test.
            data.u123, data.v123 = np.array([]), np.array([])

    return keys
276
+
277
+
278
def read_data(
    data_to_read: List[str], wavelengths: u.um, min_err: Dict[str, float] | None = None
) -> None:
    """Reads the requested observables of all readouts into ``OPTIONS.data``.

    The work happens in four stages:

    1. The static readouts (``OPTIONS.data.readouts``) are concatenated
       along the last axis. When an interferometric observable has no
       coordinates yet, a zero column is inserted in front of the first
       readout's coordinates.
    2. Every array is repeated ``OPTIONS.data.nt`` times along a new
       leading axis and converted to nested lists.
    3. Each time-dependent readout (``OPTIONS.data.readouts_t``) is appended
       to the time step with its index.
    4. Values and errors become masked arrays with invalid entries masked;
       coordinates become plain arrays again.

    Parameters
    ----------
    data_to_read : list of str
        The observable keys to read ("flux", "vis", "vis2", "t3").
    wavelengths : astropy.units.Quantity
        The wavelengths passed on to ``ReadoutFits.get_data_for_wavelength``.
    min_err : dict, optional
        Error floors per observable key, passed on as well.
    """
    store = OPTIONS.data

    def extract(readout, key):
        # Values and errors of one readout at the requested wavelengths.
        return readout.get_data_for_wavelength(
            wavelengths, key, store.do_bin, min_err
        )

    def repeat_over_time(array):
        # Leading axis of length ``nt``; lists allow per-step extension later.
        return np.tile(array, (nt,) + (1,) * len(array.shape)).tolist()

    # Stage 1: concatenate the static readouts.
    for readout in store.readouts:
        for key in data_to_read:
            target = getattr(store, key)
            source = getattr(readout, key)
            val, err = extract(readout, key)
            if target.val.size:
                target.val = np.hstack((target.val, val))
                target.err = np.hstack((target.err, err))
            else:
                target.val, target.err = val, err

            if key in ("vis", "vis2"):
                if target.u.size:
                    target.u = np.hstack((target.u, source.u))
                    target.v = np.hstack((target.v, source.v))
                else:
                    target.u = np.insert(source.u, 0, 0, axis=1)
                    target.v = np.insert(source.v, 0, 0, axis=1)
            elif key == "t3":
                if target.u123.size:
                    ucoord, vcoord, i123 = get_t3_indices(source.u123, source.v123)
                    target.u123 = np.hstack((target.u123, source.u123))
                    target.v123 = np.hstack((target.v123, source.v123))
                    target.u = np.hstack((target.u, ucoord))
                    target.v = np.hstack((target.v, vcoord))
                    target.i123 = np.hstack((target.i123, i123 + target.i123.max() + 1))
                else:
                    target.u123, target.v123 = (
                        np.insert(source.u123, 0, 0, axis=1),
                        np.insert(source.v123, 0, 0, axis=1),
                    )
                    target.u, target.v, target.i123 = get_t3_indices(
                        target.u123, target.v123
                    )

    # Stage 2: repeat everything over the time steps.
    # NOTE: Make this work with no time input files again
    nt = store.nt
    for key in data_to_read:
        target = getattr(store, key)
        target.val = repeat_over_time(np.ma.masked_invalid(target.val).filled(np.nan))
        target.err = repeat_over_time(np.ma.masked_invalid(target.err).filled(np.nan))
        if key in ("vis", "vis2", "t3"):
            target.u = repeat_over_time(target.u)
            target.v = repeat_over_time(target.v)
        if key == "t3":
            target.u123 = repeat_over_time(target.u123)
            target.v123 = repeat_over_time(target.v123)
            target.i123 = repeat_over_time(target.i123)

    # Stage 3: append each time-dependent readout to its own time step.
    # NOTE: This only works if all the time datasets have the same amount of (u, v)-coordinates
    # and also if there is the same number of them.
    for step, readout in enumerate(store.readouts_t):
        for key in data_to_read:
            target = getattr(store, key)
            source = getattr(readout, key)
            val, err = extract(readout, key)
            target.val[step] = np.hstack((target.val[step], val))
            target.err[step] = np.hstack((target.err[step], err))

            if key in ("vis", "vis2"):
                target.u[step] = np.hstack((target.u[step], source.u))
                target.v[step] = np.hstack((target.v[step], source.v))
            elif key == "t3":
                ucoord, vcoord, i123 = get_t3_indices(source.u123, source.v123)
                target.u[step] = np.hstack((target.u[step], ucoord))
                target.v[step] = np.hstack((target.v[step], vcoord))
                target.u123[step] = np.hstack((target.u123[step], source.u123))
                target.v123[step] = np.hstack((target.v123[step], source.v123))
                target.i123[step] = np.hstack(
                    (target.i123[step], i123 + np.max(target.i123[step]) + 1)
                )

    # Stage 4: convert back to (masked) arrays.
    for key in data_to_read:
        target = getattr(store, key)
        target.val = np.ma.masked_invalid(target.val)
        target.err = np.ma.masked_invalid(target.err)
        if key in ("vis", "vis2", "t3"):
            target.u, target.v = np.array(target.u), np.array(target.v)
        if key == "t3":
            target.u123, target.v123 = np.array(target.u123), np.array(target.v123)
            target.i123 = np.array(target.i123)
375
+
376
+
377
def get_weights(kind="general") -> NDArray[Any]:
    """Gets the fitting weights of the observables in ``OPTIONS.fit.data``.

    Parameters
    ----------
    kind : str, optional
        "general" returns each observable's general weight. Any other value
        returns, per observable, the weight of every band in
        ``OPTIONS.fit.bands``.

    Returns
    -------
    numpy.typing.NDArray
        One weight per observable for "general", otherwise one row of band
        weights per observable.
    """
    observable_weights = [
        getattr(OPTIONS.fit.weights, key) for key in OPTIONS.fit.data
    ]
    if kind != "general":
        return np.array(
            [
                [getattr(weights, band) for band in OPTIONS.fit.bands]
                for weights in observable_weights
            ]
        )
    return np.array([weights.general for weights in observable_weights])
393
+
394
+
395
def set_weights(weights, weights_bands) -> None:
    """Sets the general and per-band fitting weights in ``OPTIONS.fit.weights``.

    Parameters
    ----------
    weights : dict or str or None
        Maps observable keys to their general weight. The string "ndata"
        derives the weights from the data counts instead (the maximum count
        divided by each observable's count). None leaves the general weights
        untouched.
    weights_bands : dict or None
        Maps observable keys to a ``{band: weight}`` dict. None leaves the
        band weights untouched.
    """
    if weights is not None:
        general = weights
        if general == "ndata":
            counts = get_counts_data()
            general = dict(zip(OPTIONS.fit.data, (counts / counts.max()) ** -1))

        for key, weight in general.items():
            getattr(OPTIONS.fit.weights, key).general = weight

    if weights_bands is None:
        return

    for key, band_values in weights_bands.items():
        for band, value in band_values.items():
            setattr(getattr(OPTIONS.fit.weights, key), band, value)
409
+
410
+
411
def get_data(
    fits_files: Path | List[Path] | None = (),
    fits_files_t: Path | List[Path] | None = (),
    wavelengths: str | u.Quantity[u.um] | None = None,
    fit_data: List[str] | None = None,
    weights: Dict[str, float] | str | None = None,
    weights_bands: Dict[str, Dict[str, float]] | None = None,
    min_err: Dict[str, float] | None = None,
    **kwargs,
) -> SimpleNamespace:
    """Sets the data as a global variable from the input files.

    Previously read data is always cleared first. If ``fits_files`` is None,
    only the clearing happens and the (now empty) data is returned. All four
    observables ("flux", "vis", "vis2", "t3") are read, independent of
    ``fit_data``.

    Parameters
    ----------
    fits_files : pathlib.Path or list of pathlib.Path, optional
        Paths to (.fits)-files.
    fits_files_t : pathlib.Path or list of pathlib.Path, optional
        Paths to time-dependent (.fits)-files. None means no such files.
    wavelengths : str or astropy.units.Quantity or numpy.ndarray
        The wavelengths to be fitted, or "all" to use every wavelength of the
        given files (this also disables binning).
    fit_data : list of str, optional
        The data to be fitted. Defaults to ["flux", "vis", "t3"].
    weights : dict or str, optional
        The general fitting weights per observable, or "ndata" to derive
        them from the number of data points.
    weights_bands : dict, optional
        The fitting weights per observable and band as
        ``{key: {band: weight}}``.
    min_err : dict, optional
        The error floor for the datasets. The keys of the dictionary need to
        correspond to the entries of ``fit_data``.
    **kwargs
        Additional keyword arguments are accepted and ignored.

    Returns
    -------
    types.SimpleNamespace
        The global ``OPTIONS.data``.

    Raises
    ------
    ValueError
        If files are given but ``wavelengths`` is None.
    """
    from contextlib import ExitStack

    data_to_read = clear_data()
    if fits_files is None:
        return OPTIONS.data

    if not isinstance(fits_files, (list, tuple, np.ndarray)):
        fits_files = [fits_files]

    if fits_files_t is None:
        fits_files_t = []
    elif not isinstance(fits_files_t, (list, tuple, np.ndarray)):
        fits_files_t = [fits_files_t]

    # Store a fresh list so later changes to the global never alias a
    # caller's list or a shared default.
    OPTIONS.fit.data = ["flux", "vis", "t3"] if fit_data is None else list(fit_data)

    # Keep in-memory copies of the HDU lists; the ExitStack closes every
    # opened file even if opening or copying a later one fails.
    with ExitStack() as stack:
        hduls = [stack.enter_context(fits.open(f)) for f in fits_files]
        hduls_t = [stack.enter_context(fits.open(f)) for f in fits_files_t]
        OPTIONS.data.hduls = [copy.deepcopy(hdul) for hdul in hduls]
        OPTIONS.data.hduls_t = [copy.deepcopy(hdul) for hdul in hduls_t]
    OPTIONS.data.nt = len(hduls_t) or 1

    OPTIONS.data.readouts = list(map(ReadoutFits, fits_files))
    OPTIONS.data.readouts_t = list(map(ReadoutFits, fits_files_t))

    OPTIONS.data.bands = [readout.band for readout in OPTIONS.data.readouts]
    # The isinstance guard avoids an elementwise (ambiguous) comparison when
    # an array of wavelengths is passed.
    if isinstance(wavelengths, str) and wavelengths == "all":
        wavelengths = get_all_wavelengths(
            [*OPTIONS.data.readouts, *OPTIONS.data.readouts_t]
        )
        OPTIONS.data.do_bin = False

    if wavelengths is None:
        raise ValueError("No wavelengths given and/or not 'all' specified!")

    wavelengths = set_fit_wavelengths(wavelengths)
    read_data(data_to_read, wavelengths, min_err)

    # Reset all band weights to 1 before applying user-given weights.
    for key in OPTIONS.fit.data:
        for band in OPTIONS.fit.bands:
            setattr(getattr(OPTIONS.fit.weights, key), band, 1)

    set_weights(weights, weights_bands)
    return OPTIONS.data