arvi 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arvi might be problematic. Click here for more details.

arvi/plots.py CHANGED
@@ -6,13 +6,101 @@ import matplotlib.collections
6
6
  import numpy as np
7
7
  import matplotlib
8
8
  import matplotlib.pyplot as plt
9
- from matplotlib.collections import LineCollection
10
9
  import mplcursors
11
10
 
12
11
  from astropy.timeseries import LombScargle
13
12
 
14
13
  from .setup_logger import logger
15
14
  from . import config
15
+ from .stats import wmean
16
+
17
+
18
class BlittedCursor:
    """A cross-hair cursor, drawn on one or more axes, using blitting for faster redraw.

    A copy of each axes' background (everything except the cross-hair) is
    cached whenever the canvas is drawn. Moving the mouse then only restores
    that background and re-draws the cross-hair artists, instead of redrawing
    the whole figure.

    Args:
        axes (matplotlib.axes.Axes or sequence of Axes):
            The axes on which to draw the cross-hair.
        vertical (bool):
            Whether to draw a vertical line. Default is True.
        horizontal (bool):
            Whether to draw a horizontal line. Default is True.
        show_text (bool or None):
            If truthy, show the cursor position as text in each axes.
        transforms_x (sequence of callables, optional):
            One function per axes, mapping the cursor's x data coordinate to
            the x position of that axes' vertical line. Default is identity.
        transforms_y (sequence of callables, optional):
            As `transforms_x`, for the y coordinate and the horizontal lines.

    Connect `on_mouse_move` to the figure's 'motion_notify_event' to activate
    the cursor.

    Note:
        `event.xdata`/`event.ydata` are in the data coordinates of whichever
        axes the mouse is over, and they go through the same transforms
        regardless of which axes that is. NOTE(review): the transforms thus
        assume the mouse is over an axes whose coordinates they expect as
        input.
    """

    def __init__(self, axes, vertical=True, horizontal=True, show_text=None,
                 transforms_x=None, transforms_y=None):
        if isinstance(axes, matplotlib.axes.Axes):
            axes = [axes]
        self.axes = axes
        # per-axes cached backgrounds, filled by create_new_background();
        # `background` is kept as an alias for backward compatibility
        self.backgrounds = None
        self.background = None
        self.vertical = vertical
        self.horizontal = horizontal

        def identity(value):
            return value

        self.transforms_x = [identity for _ in axes] if transforms_x is None else transforms_x
        self.transforms_y = [identity for _ in axes] if transforms_y is None else transforms_y

        # empty lists when disabled, so the attributes always exist
        self.horizontal_line = ([ax.axhline(color='k', lw=0.8, ls='--') for ax in axes]
                                if horizontal else [])
        self.vertical_line = ([ax.axvline(color='k', lw=0.8, ls='--') for ax in axes]
                              if vertical else [])

        self.show_text = show_text
        if show_text is not None:  # text location in axes coordinates
            self.text = [ax.text(0.72, 0.9, '', transform=ax.transAxes) for ax in axes]

        self._creating_background = False
        # connect once per canvas: axes sharing a figure would otherwise
        # trigger several full redraws per draw event
        for canvas in self._canvases():
            canvas.mpl_connect('draw_event', self.on_draw)

    def _canvases(self):
        """Return the distinct canvases of `self.axes`, in order of first appearance."""
        canvases = {}
        for ax in self.axes:
            canvases.setdefault(id(ax.figure.canvas), ax.figure.canvas)
        return list(canvases.values())

    def _artists_for(self, i):
        """Return the enabled cross-hair artists belonging to the i-th axes."""
        artists = []
        if self.horizontal:
            artists.append(self.horizontal_line[i])
        if self.vertical:
            artists.append(self.vertical_line[i])
        if self.show_text:
            artists.append(self.text[i])
        return artists

    def on_draw(self, event):
        """Refresh the cached backgrounds after every full canvas draw."""
        self.create_new_background()

    def set_cross_hair_visible(self, visible):
        """Show or hide all cross-hair artists.

        Returns:
            list of bool: for each axes, whether the visibility of its
            horizontal line (or of its vertical line, if horizontal lines are
            disabled) actually changed, i.e. whether it needs to be redrawn.
        """
        reference = self.horizontal_line if self.horizontal else self.vertical_line
        need_redraw = [line.get_visible() != visible for line in reference]

        artists = []
        if self.horizontal:
            artists.extend(self.horizontal_line)
        if self.vertical:
            artists.extend(self.vertical_line)
        if self.show_text:
            artists.extend(self.text)
        for artist in artists:
            artist.set_visible(visible)
        return need_redraw

    def create_new_background(self):
        """Redraw the canvas(es) without the cross-hair and cache each axes' background."""
        if self._creating_background:
            # discard calls triggered from within this function (canvas.draw()
            # emits a 'draw_event' which calls on_draw again)
            return
        self._creating_background = True
        try:
            self.set_cross_hair_visible(False)
            for canvas in self._canvases():
                canvas.draw()
            self.backgrounds = [ax.figure.canvas.copy_from_bbox(ax.bbox) for ax in self.axes]
            self.background = self.backgrounds
            self.set_cross_hair_visible(True)
        finally:
            self._creating_background = False

    def on_mouse_move(self, event):
        """Move the cross-hair to the mouse position, redrawing only the cached regions."""
        if self.backgrounds is None:
            self.create_new_background()

        if not event.inaxes:
            need_redraw = self.set_cross_hair_visible(False)
            if any(need_redraw):
                for ax, bkgd in zip(self.axes, self.backgrounds):
                    ax.figure.canvas.restore_region(bkgd)
                    ax.figure.canvas.blit(ax.bbox)
            return

        self.set_cross_hair_visible(True)
        x, y = event.xdata, event.ydata
        for i, (ax, bkgd) in enumerate(zip(self.axes, self.backgrounds)):
            # update this axes' artists
            if self.horizontal:
                self.horizontal_line[i].set_ydata([self.transforms_y[i](y)])
            if self.vertical:
                self.vertical_line[i].set_xdata([self.transforms_x[i](x)])
            if self.show_text:
                self.text[i].set_text(f'x={x:1.2f}, y={y:1.2f}')

            # restore the cached background and draw only this axes' artists on it
            canvas = ax.figure.canvas
            canvas.restore_region(bkgd)
            for artist in self._artists_for(i):
                ax.draw_artist(artist)
            canvas.blit(ax.bbox)
16
104
 
17
105
 
18
106
  def plot(self, ax=None, show_masked=False, instrument=None, time_offset=0,
@@ -72,11 +160,12 @@ def plot(self, ax=None, show_masked=False, instrument=None, time_offset=0,
72
160
 
73
161
  strict = kwargs.pop('strict', False)
74
162
  instruments = self._check_instrument(instrument, strict=strict)
163
+ marker = kwargs.pop('marker', 'o')
75
164
 
76
165
  if bw:
77
- markers = cycle(('o', 'P', 's', '^', '*'))
166
+ markers = cycle((marker, 'P', 's', '^', '*'))
78
167
  else:
79
- markers = cycle(('o',) * len(instruments))
168
+ markers = cycle((marker,) * len(instruments))
80
169
 
81
170
  try:
82
171
  zorders = cycle(-np.argsort([getattr(self, i).error for i in instruments])[::-1])
@@ -313,12 +402,14 @@ def plot_quantity(self, quantity, ax=None, show_masked=False, instrument=None,
313
402
  ax.legend()
314
403
  ax.minorticks_on()
315
404
 
316
- if quantity == 'fwhm':
317
- ax.set_ylabel(f'FWHM [{self.units}]')
318
- elif quantity == 'bispan':
319
- ax.set_ylabel(f'BIS [{self.units}]')
320
- elif quantity == 'rhk':
321
- ax.set_ylabel(r"$\log$ R'$_{HK}$")
405
+ ylabel = {
406
+ 'fwhm': f'FWHM [{self.units}]',
407
+ 'bispan': f'BIS [{self.units}]',
408
+ 'rhk': r"$\log$ R'$_{HK}$",
409
+ 'berv': f'BERV [km/s]',
410
+ quantity: quantity,
411
+ }
412
+ ax.set_ylabel(ylabel[quantity])
322
413
 
323
414
  if remove_50000:
324
415
  ax.set_xlabel('BJD - 2450000 [days]')
@@ -334,9 +425,11 @@ def plot_quantity(self, quantity, ax=None, show_masked=False, instrument=None,
334
425
  plot_fwhm = partialmethod(plot_quantity, quantity='fwhm')
335
426
  plot_bis = partialmethod(plot_quantity, quantity='bispan')
336
427
  plot_rhk = partialmethod(plot_quantity, quantity='rhk')
428
+ plot_berv = partialmethod(plot_quantity, quantity='berv')
337
429
 
338
430
 
339
- def gls(self, ax=None, label=None, fap=True, picker=True, instrument=None, **kwargs):
431
+ def gls(self, ax=None, label=None, fap=True, instrument=None, adjust_means=config.adjust_means_gls,
432
+ picker=True, **kwargs):
340
433
  """
341
434
  Calculate and plot the Generalised Lomb-Scargle periodogram of the radial
342
435
  velocities.
@@ -348,42 +441,82 @@ def gls(self, ax=None, label=None, fap=True, picker=True, instrument=None, **kwa
348
441
  label (str):
349
442
  The label to use for the plot.
350
443
  fap (bool):
351
- Whether to show the false alarm probability.
444
+ Whether to show the false alarm probability. Default is True.
352
445
  instrument (str or list):
353
- Which instruments' data to include in the periodogram.
446
+ Which instruments' data to include in the periodogram. Default is
447
+ all instruments.
448
+ adjust_means (bool):
449
+ Whether to adjust (subtract) the weighted means of each instrument.
450
+ Default is `config.adjust_means_gls`.
354
451
  """
355
452
  if self.N == 0:
356
453
  if self.verbose:
357
454
  logger.error('no data to compute gls')
358
455
  return
359
456
 
360
- if ax is None:
361
- fig, ax = plt.subplots(1, 1, constrained_layout=True)
362
- else:
363
- fig = ax.figure
457
+ if not self._did_adjust_means and not adjust_means:
458
+ logger.warning('gls() called before adjusting instrument means, '
459
+ 'consider using `adjust_means` argument')
364
460
 
365
461
  if instrument is not None:
366
462
  strict = kwargs.pop('strict', False)
367
- instrument = self._check_instrument(instrument, strict=strict)
368
- if instrument is not None:
369
- instrument_mask = np.isin(self.instrument_array, instrument)
370
- t = self.time[instrument_mask & self.mask]
371
- y = self.vrad[instrument_mask & self.mask]
372
- e = self.svrad[instrument_mask & self.mask]
463
+ instrument = self._check_instrument(instrument, strict=strict, log=True)
464
+ if instrument is None:
465
+ return
466
+
467
+ instrument_mask = np.isin(self.instrument_array, instrument)
468
+ t = self.time[instrument_mask & self.mask].copy()
469
+ y = self.vrad[instrument_mask & self.mask].copy()
470
+ e = self.svrad[instrument_mask & self.mask].copy()
471
+ if self.verbose:
472
+ logger.info(f'calculating periodogram for instrument {instrument}')
473
+
474
+ if adjust_means:
373
475
  if self.verbose:
374
- logger.info(f'calculating periodogram for instrument {instrument}')
476
+ logger.info('adjusting instrument means before gls')
477
+ means = np.empty_like(y)
478
+ for i in instrument:
479
+ mask = self.instrument_array[instrument_mask & self.mask] == i
480
+ if len(y[mask]) > 0:
481
+ means += wmean(y[mask], e[mask]) * mask
482
+ y = y - means
483
+
375
484
  else:
376
- t = self.time[self.mask]
377
- y = self.vrad[self.mask]
378
- e = self.svrad[self.mask]
485
+ t = self.time[self.mask].copy()
486
+ y = self.vrad[self.mask].copy()
487
+ e = self.svrad[self.mask].copy()
488
+
489
+ if adjust_means:
490
+ if self.verbose:
491
+ logger.info('adjusting instrument means before gls')
492
+ means = np.empty_like(y)
493
+ for i in self.instruments:
494
+ mask = self.instrument_array[self.mask] == i
495
+ if len(y[mask]) > 0:
496
+ means += wmean(y[mask], e[mask]) * mask
497
+ y = y - means
379
498
 
380
499
  self._gls = gls = LombScargle(t, y, e)
500
+
381
501
  maximum_frequency = kwargs.pop('maximum_frequency', 1.0)
382
502
  minimum_frequency = kwargs.pop('minimum_frequency', None)
503
+ samples_per_peak = kwargs.pop('samples_per_peak', 10)
504
+
383
505
  freq, power = gls.autopower(maximum_frequency=maximum_frequency,
384
506
  minimum_frequency=minimum_frequency,
385
- samples_per_peak=10)
386
- ax.semilogx(1/freq, power, picker=picker, label=label, **kwargs)
507
+ samples_per_peak=samples_per_peak)
508
+
509
+ if ax is None:
510
+ fig, ax = plt.subplots(1, 1, constrained_layout=True)
511
+ else:
512
+ fig = ax.figure
513
+
514
+ if kwargs.pop('fill_between', False):
515
+ kwargs.pop('lw', None)
516
+ ax.fill_between(1/freq, 0, power, label=label, lw=0, **kwargs)
517
+ ax.set_xscale('log')
518
+ else:
519
+ ax.semilogx(1/freq, power, picker=picker, label=label, **kwargs)
387
520
 
388
521
  if fap:
389
522
  ax.axhline(gls.false_alarm_level(0.01),
@@ -471,6 +604,73 @@ gls_bis = partialmethod(gls_quantity, quantity='bispan')
471
604
  gls_rhk = partialmethod(gls_quantity, quantity='rhk')
472
605
 
473
606
 
607
+
608
def window_function(self, ax1=None, ax2=None, instrument=None, crosshair=False, **kwargs):
    """
    Calculate and plot the window function of the observed times.

    Args:
        ax1 (matplotlib.axes.Axes):
            An axes to plot the window function vs period.
        ax2 (matplotlib.axes.Axes):
            An axes to plot the window function vs frequency.
            If both `ax1` and `ax2` are None, a new figure with two panels is
            created. If only one of them is given, the other is created in a
            new figure.
        instrument (str or list):
            Which instruments' data to include in the window function. Default
            is all instruments.
        crosshair (bool):
            If True, a crosshair will be drawn on the plot.
        **kwargs:
            `strict` is passed to `_check_instrument`. `maximum_frequency`
            (default 1.1), `minimum_frequency` (default None) and
            `samples_per_peak` (default 20) are passed to
            `LombScargle.autopower`. Everything else goes to the plotting calls.

    Returns:
        (fig, (ax1, ax2)), or (fig, (ax1, ax2), cursor) if `crosshair` is True.
        None if there is no data or the requested instrument is not found.
    """
    if self.N == 0:
        if self.verbose:
            logger.error('no data to compute window function')
        return

    # options that must not reach the plotting functions
    strict = kwargs.pop('strict', False)
    maximum_frequency = kwargs.pop('maximum_frequency', 1.1)
    minimum_frequency = kwargs.pop('minimum_frequency', None)
    samples_per_peak = kwargs.pop('samples_per_peak', 20)

    if instrument is not None:
        instrument = self._check_instrument(instrument, strict=strict, log=True)
        if instrument is None:
            # do not silently fall back to all the data (consistent with gls)
            return
        mask = np.isin(self.instrument_array, instrument) & self.mask
        if self.verbose:
            logger.info(f'calculating window function for instrument {instrument}')
    else:
        mask = self.mask

    t = self.time[mask]
    ye = self.svrad[mask]
    if t.size == 0:
        if self.verbose:
            logger.error('no data to compute window function')
        return

    # rescale the uncertainties by their spread, unless they are all equal
    # (std == 0 would otherwise produce inf/nan uncertainties)
    scale = np.std(ye)
    dy = ye / scale if scale > 0 else ye

    wf = LombScargle(t, np.ones_like(t), dy, fit_mean=False, center_data=False)
    freq, power = wf.autopower(minimum_frequency=minimum_frequency,
                               maximum_frequency=maximum_frequency,
                               samples_per_peak=samples_per_peak, method='cython')

    if ax1 is None and ax2 is None:
        fig, (ax1, ax2) = plt.subplots(2, 1, constrained_layout=True)
    else:
        if ax1 is None:
            _, ax1 = plt.subplots(1, 1, constrained_layout=True)
        if ax2 is None:
            _, ax2 = plt.subplots(1, 1, constrained_layout=True)
        fig = ax1.figure

    ax1.semilogx(1/freq, power, **kwargs)
    ax1.set(xlabel='Period [days]', ylabel='Window function')

    ax2.plot(freq, power, **kwargs)
    ax2.set(xlabel='Frequency [1/day]', ylabel='Window function')

    # reference lines at one year, one day and the yearly alias of one day
    for x in (365.25, 1.0, 1 - 1.0/365.25):
        ax1.axvline(x, color='k', alpha=0.2, zorder=-1)
        ax2.axvline(1/x, color='k', alpha=0.2, zorder=-1)

    if crosshair:
        # ax1 shows period and ax2 frequency, hence the 1/x transform for ax2
        blitted_cursor = BlittedCursor((ax1, ax2), horizontal=False,
                                       transforms_x=(lambda x: x, lambda x: 1/x))
        # ax1 and ax2 may live on different figures; connect each canvas once
        canvases = {id(ax.figure.canvas): ax.figure.canvas for ax in (ax1, ax2)}
        for canvas in canvases.values():
            canvas.mpl_connect('motion_notify_event', blitted_cursor.on_mouse_move)
        return fig, (ax1, ax2), blitted_cursor

    return fig, (ax1, ax2)
673
+
474
674
  def histogram_svrad(self, ax=None, instrument=None, label=None):
475
675
  """ Plot an histogram of the radial velocity uncertainties.
476
676
 
arvi/programs.py CHANGED
@@ -20,12 +20,14 @@ def get_star(star, instrument=None):
20
20
 
21
21
 
22
22
  class LazyRV:
23
- def __init__(self, stars: list, instrument: str = None):
23
+ def __init__(self, stars: list, instrument: str = None,
24
+ _parallel_limit=10):
24
25
  self.stars = stars
25
26
  if isinstance(self.stars, str):
26
27
  self.stars = [self.stars]
27
28
  self.instrument = instrument
28
29
  self._saved = None
30
+ self._parallel_limit = _parallel_limit
29
31
 
30
32
  @property
31
33
  def N(self):
@@ -35,7 +37,7 @@ class LazyRV:
35
37
  return f"RV({self.N} stars)"
36
38
 
37
39
  def _get(self):
38
- if self.N > 10:
40
+ if self.N > self._parallel_limit:
39
41
  # logger.info('Querying DACE...')
40
42
  _get_star = partial(get_star, instrument=self.instrument)
41
43
  with multiprocessing.Pool() as pool:
arvi/simbad_wrapper.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import os
2
- from dataclasses import dataclass, field
3
2
  import requests
4
3
 
5
4
  from astropy.coordinates import SkyCoord
@@ -126,7 +125,11 @@ class simbad:
126
125
  except IndexError:
127
126
  raise ValueError(f'simbad query for {star} failed')
128
127
 
129
- self.gaia_id = int([i for i in self.ids if 'Gaia DR3' in i][0].split('Gaia DR3')[-1])
128
+ try:
129
+ self.gaia_id = int([i for i in self.ids if 'Gaia DR3' in i][0]
130
+ .split('Gaia DR3')[-1])
131
+ except IndexError:
132
+ self.gaia_id = None
130
133
 
131
134
  for col, val in zip(cols, values):
132
135
  if col == 'oid':
arvi/spectra.py ADDED
@@ -0,0 +1,208 @@
1
+ import os
2
+ from glob import glob
3
+ import pickle
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+
7
+ from .setup_logger import logger
8
+
9
+ from tqdm import tqdm
10
+ import astropy.units as u, astropy.constants as const
11
+ from astropy.io import fits
12
+
13
def doppler_shift(wave: np.ndarray, flux: np.ndarray, velocity: float):
    """ Doppler shift a spectrum by -`velocity`, i.e. remove a radial velocity.

    The flux is linearly interpolated, on the *input* wavelength grid, at the
    wavelengths ``wave * D``, where ``D = sqrt((1 + v/c) / (1 - v/c))`` is the
    relativistic Doppler factor. A feature at wavelength λ in `flux` therefore
    appears at λ / D in the returned flux, so a positive `velocity`
    (a redshift) is removed.

    Args:
        wave (np.ndarray): wavelength array (increasing, as `np.interp` requires)
        flux (np.ndarray): flux array
        velocity (float): velocity in km/s

    Returns:
        (np.ndarray, np.ndarray): the wavelength grid the shifted flux is
        sampled on (a copy of `wave`) and the shifted flux.
    """
    c = 299792.458  # speed of light in km/s (exact, by definition of the metre)
    doppler_factor = np.sqrt((1 + velocity/c) / (1 - velocity/c))
    # outside the range of `wave`, np.interp repeats the edge values of `flux`
    new_flux = np.interp(wave * doppler_factor, wave, flux)
    # `new_flux` is sampled on the input grid, so that is the grid returned
    return np.asanyarray(wave).copy(), new_flux
26
+
27
def fit_gaussian_to_line(wave, flux, center_wavelength, around=0.15 * u.angstrom,
                         careful_continuum=False, plot=True, ax=None):
    """ Fit a Gaussian plus constant continuum to a line and measure its equivalent width.

    Args:
        wave (array or Quantity):
            Wavelengths. Plain numbers are interpreted as Angstrom.
        flux (array):
            Flux values, same length as `wave`.
        center_wavelength (float or Quantity):
            Approximate line centre. Plain numbers are interpreted as Angstrom.
        around (float or Quantity):
            Half-width of the fitting window (default 0.15 Å).
        careful_continuum (bool):
            If True, also fit the points above the 80th flux percentile within
            10 x `around` of the line, to anchor the continuum level.
        plot (bool):
            Whether to plot the data and the fit.
        ax (matplotlib.axes.Axes):
            Axes for the plot. A new figure is created if None.

    Returns:
        popt (array or None):
            Best-fit (amplitude, centre, width, continuum), or None if the fit failed.
        EW (float):
            Equivalent width in mÅ (nan if the fit failed).
        EW_err (float):
            Uncertainty on EW in mÅ, propagated from the fit covariance
            (nan if the fit failed).

    Raises:
        ValueError: if an input cannot be converted to Angstrom, or if
            `center_wavelength` is outside the (non-zero flux) wavelength range.
    """
    from scipy.optimize import curve_fit

    # Convert to Angstrom without modifying the caller's objects (an in-place
    # `<<=` would convert a Quantity argument in place).
    try:
        wave = u.Quantity(wave, u.angstrom)
    except u.UnitConversionError as e:
        raise ValueError(f'could not convert `wave` to Angstroms: {e}') from None

    try:
        center_wavelength = u.Quantity(center_wavelength, u.angstrom)
    except u.UnitConversionError as e:
        raise ValueError(f'could not convert `center_wavelength` to Angstroms: {e}') from None

    try:
        around = u.Quantity(around, u.angstrom)
    except u.UnitConversionError as e:
        raise ValueError(f'could not convert `around` to Angstroms: {e}') from None

    # range checks after the conversion, so that like is compared with like
    if center_wavelength < wave.min() or center_wavelength > wave.max():
        raise ValueError('`center_wavelength` is outside the wavelength range')
    nonzero = np.nonzero(flux)
    if center_wavelength < wave[nonzero].min() or center_wavelength > wave[nonzero].max():
        raise ValueError('`center_wavelength` is outside the wavelength range where flux is not zero')

    def gaussian(x, amp, cen, wid, off):
        return amp * np.exp(-(x - cen)**2 / (2 * wid**2)) + off

    wave_around = (wave > center_wavelength - around) & (wave < center_wavelength + around)
    w, f = wave[wave_around].value, flux[wave_around]

    if careful_continuum:
        # add the brightest points within 10x the window to anchor the continuum
        wave_around_continuum = ((wave > center_wavelength - 10*around)
                                 & (wave < center_wavelength + 10*around))
        wc, fc = wave[wave_around_continuum].value, flux[wave_around_continuum]
        lim = np.percentile(fc, 80)
        wc = wc[fc > lim]
        fc = fc[fc > lim]
        w, f = np.r_[wc, w], np.r_[fc, f]
        order = np.argsort(w)
        w, f = w[order], f[order]

    n_params = 4  # amp, cen, wid, off
    if len(f) < n_params:
        logger.warning(f'fit_gaussian_to_line: only {len(f)} points around '
                       f'{center_wavelength}, cannot fit')
        return None, np.nan, np.nan

    # amplitude <= 0 (absorption line), width in [0, 0.11] Å,
    # continuum within 10% of the maximum flux
    lower, upper = np.array([
        [-np.inf, 0],
        [-np.inf, np.inf],
        [0.0, 0.11],
        [0.9*f.max(), 1.1*f.max()]
    ]).T

    try:
        popt, pcov = curve_fit(gaussian, w, f,
                               p0=[-np.ptp(f), center_wavelength.value, 0.1, f.max()],
                               bounds=(lower, upper))
    except (RuntimeError, ValueError) as e:
        # RuntimeError: no convergence; ValueError: e.g. infeasible bounds or nan in data
        logger.warning(f'fit_gaussian_to_line: {e}')
        return None, np.nan, np.nan

    amp, cen, wid, off = popt
    # equivalent width (in Å) = area of the Gaussian / continuum level
    EW = A = (np.sqrt(2) * np.abs(amp) * np.abs(wid) * np.sqrt(np.pi)) / off
    # first-order propagation of the fit covariance, using the partial
    # derivatives of EW with respect to (amp, cen, wid, off)
    jac = np.array([
        np.sqrt(2 * np.pi) * np.sign(amp) * np.abs(wid) / off,
        0.0,
        np.sqrt(2 * np.pi) * np.abs(amp) * np.sign(wid) / off,
        -EW / off,
    ])
    EW_err = np.sqrt(jac @ pcov @ jac)

    if plot:
        if ax is None:
            fig, ax = plt.subplots(figsize=(8, 4))
        if careful_continuum:
            ax.plot(w, f, 'ko', ms=4, zorder=1)
            wave_around_plot = wave_around_continuum
        else:
            ax.plot(wave[wave_around], flux[wave_around], 'ko', ms=4, zorder=1)
            ax.plot(wave[wave_around], flux[wave_around] - gaussian(w, *popt), 'o', ms=2)
            wave_around_plot = ((wave > center_wavelength - 2*around)
                                & (wave < center_wavelength + 2*around))
        w_plot = wave[wave_around_plot].value
        ax.plot(w_plot, gaussian(w_plot, *popt), 'r-')
        # NOTE(review): shaded box spans centre ± EW (in Å) between the line
        # bottom and the continuum; it is illustrative, not an area-true EW box
        ax.fill_between([cen - A, cen + A], off + amp, off,
                        color='C2', alpha=0.1, lw=0)

    return popt, EW*1e3, EW_err*1e3
102
+
103
def detrend(w, f):
    """ Remove a linear trend from each row of a 2D spectrum.

    For each row, a straight line is fit to the non-zero flux values (zeros
    are treated as missing) and subtracted. The median of those values is
    then added back, so each row keeps its flux level.

    Args:
        w (np.ndarray): 2D wavelength array
        f (np.ndarray): 2D flux array, same shape as `w`

    Returns:
        (np.ndarray, np.ndarray): the wavelength array and the detrended flux.
        If the input has more rows than columns, both are transposed first,
        so the shorter axis is always the row axis in the outputs. Rows with
        fewer than two non-zero flux values are returned unchanged.
    """
    if w.shape[0] > w.shape[1]:
        w = np.copy(w).T
        f = np.copy(f).T

    f_detrended = np.zeros_like(f)
    for i, (ww, ff) in enumerate(zip(w, f)):
        valid = np.nonzero(ff)
        if valid[0].size < 2:
            # not enough points to fit a line (np.polyfit fails on empty input)
            f_detrended[i] = ff
            continue
        # centre the wavelengths for a better-conditioned fit
        w_ref = np.median(ww[valid])
        trend = np.polyval(np.polyfit(ww[valid] - w_ref, ff[valid], 1), ww - w_ref)
        f_detrended[i] = ff - trend + np.median(ff[valid])
    return w, f_detrended
114
+
115
def build_master(self, limit=None, plot=True):
    """ Build a master (mean) spectrum from the downloaded S1D files.

    Each file matching '{self.star}_downloads/*S1D_A.fits' is shifted with
    `doppler_shift`, using the first value of its '*CCF RV' header keyword(s).
    If its wavelength grid differs from the first file's, it is resampled onto
    that grid. All spectra are then averaged.

    Args:
        limit (int, optional):
            Use at most this many files (in sorted order). Default: all.
        plot (bool):
            Whether to plot the individual shifted spectra and the master.

    Returns:
        (np.ndarray, np.ndarray): the air wavelengths of the first file and
        the master flux, or None if no S1D files are found.
    """
    files = sorted(glob(f'{self.star}_downloads/*S1D_A.fits'))
    if self.verbose:
        logger.info(f'Found {len(files)} S1D files')

    files = files[:limit]

    if len(files) == 0:
        if self.verbose:
            logger.warning('Should probably run `download_s1d` first')
        return

    if plot:
        fig, axs = plt.subplots(2, 1, figsize=(9, 6), sharex=True, constrained_layout=True)
        for ax in axs:
            ax.set(xlabel=r'wavelength air [$\AA$]', ylabel='flux')
        axs[0].set_title(self.star, loc='right', fontsize=10)

    w0 = fits.getdata(files[0])['wavelength_air']
    master_flux = np.zeros_like(w0)
    for file in files:
        rv = fits.getval(file, '*CCF RV')[0]
        data = fits.getdata(file)
        wave, flux = data['wavelength_air'], data['flux']
        _, new_flux = doppler_shift(wave, flux, rv)
        if not np.array_equal(wave, w0):
            # different wavelength grid: resample onto the first file's grid
            new_flux = np.interp(w0, wave, new_flux)
        master_flux += new_flux
        if plot:
            axs[0].plot(w0, new_flux, alpha=0.5)

    master_flux /= len(files)
    if plot:
        axs[1].plot(w0, master_flux, 'k', label='master')
        axs[1].legend()
        axs[0].legend([], [], title=f'{len(files)} S1D spectra')

    return w0, master_flux
150
+
151
+
152
def determine_stellar_parameters(self, linelist: str, plot=True, **kwargs):
    """ Determine stellar parameters from equivalent widths, using Korg.jl.

    Builds a master spectrum with `build_master` and measures the EW of every
    line in `linelist` with `fit_gaussian_to_line`. It then passes the lines
    with a valid EW to `Korg.Fit.ews_to_stellar_parameters`. The result is
    also pickled to '{self.star}_stellar_parameters.pkl'.

    Args:
        linelist (str):
            Path to a text file with a header row and (at least) the columns
            'wl', 'loggf', 'elem' and 'EP'.
        plot (bool):
            Whether to plot the master spectrum and the line fits.
        **kwargs:
            `careful_continuum` (bool, default False) is passed to
            `fit_gaussian_to_line` for every line.

    Returns:
        dict: 'teff', 'logg', 'vmic', 'moh' -> (value, sys_err) as returned
        by Korg, or None if juliacall/Korg.jl is unavailable or no master
        spectrum could be built.
    """
    try:
        from juliacall import Main as jl
    except ModuleNotFoundError:
        msg = 'this function requires juliacall and Korg.jl, please `pip install juliacall`'
        logger.error(msg)
        return
    try:
        jl.seval("using Korg")
    except Exception as e:  # errors raised by Julia, e.g. Korg.jl not installed
        logger.error(f'could not load Korg.jl: {e}')
        return
    Korg = jl.Korg

    master = build_master(self, plot=plot)
    if master is None:
        return
    w, f = master

    # read once, so that the option applies to every line (not just the first)
    careful_continuum = kwargs.pop('careful_continuum', False)

    table = np.atleast_1d(np.genfromtxt(linelist, dtype=None, encoding=None, names=True))
    lines = [
        Korg.Line(line['wl'], line['loggf'], Korg.Species(line['elem'].replace('Fe', 'Fe ')), line['EP'])
        for line in table
    ]

    if self.verbose:
        logger.info(f'Found {len(lines)} lines in linelist')
        logger.info('Measuring EWs...')

    EW = []
    pbar = tqdm(table)
    for line in pbar:
        pbar.set_description(f'{line["elem"]} {line["wl"]}')
        try:
            _, ew, _ = fit_gaussian_to_line(w, f, line['wl'], plot=plot,
                                            careful_continuum=careful_continuum)
        except ValueError as e:
            # e.g. line outside the spectral range: skip it instead of aborting
            logger.warning(f'skipping line {line["elem"]} {line["wl"]}: {e}')
            ew = np.nan
        EW.append(ew)

    keep = ~np.isnan(EW)
    lines = [korg_line for korg_line, ok in zip(lines, keep) if ok]
    EW = np.array(EW)[keep]

    if self.verbose:
        logger.info('Determining stellar parameters (can take a few minutes)...')

    callback = lambda p, r, A: print('current parameters:', p)
    result = Korg.Fit.ews_to_stellar_parameters(lines, EW, callback=callback)
    par, stat_err, sys_err = result

    if self.verbose:
        logger.info('Best fit stellar parameters:')
        logger.info(f'  Teff: {par[0]:.0f} ± {sys_err[0]:.0f} K')
        logger.info(f'  logg: {par[1]:.2f} ± {sys_err[1]:.2f} dex')
        logger.info(f'  m/H : {par[3]:.2f} ± {sys_err[3]:.2f} dex')

    r = {
        'teff': (par[0], sys_err[0]),
        'logg': (par[1], sys_err[1]),
        'vmic': (par[2], sys_err[2]),
        'moh': (par[3], sys_err[3]),
    }

    with open(f'{self.star}_stellar_parameters.pkl', 'wb') as fp:
        pickle.dump(r, fp)

    return r
arvi/stellar.py ADDED
@@ -0,0 +1,89 @@
1
+ import numpy as np
2
+
3
class prot_age_result:
    """ Container for the rotation periods and age computed by `calc_prot_age`.

    Every attribute defaults to NaN, so the object can always be printed.
    """
    prot_n84: float = np.nan      # rotation period via Noyes et al. (1984)
    prot_n84_err: float = np.nan  # error on prot_n84
    prot_m08: float = np.nan      # rotation period via Mamajek & Hillenbrand (2008)
    prot_m08_err: float = np.nan  # error on prot_m08
    age_m08: float = np.nan       # gyrochronology age via Mamajek & Hillenbrand (2008)
    age_m08_err: float = np.nan   # error on age_m08

    def __repr__(self):
        return (f'prot_n84={self.prot_n84:.2f} ± {self.prot_n84_err:.2f}, '
                f'prot_m08={self.prot_m08:.2f} ± {self.prot_m08_err:.2f}, '
                f'age_m08={self.age_m08:.2f} ± {self.age_m08_err:.2f}')
17
+
18
+
19
def calc_prot_age(self, bv=None):
    """
    Estimate the star's rotation period and age from its mean logR'HK.

    Applies the activity-rotation relations of Noyes et al. (1984) and
    Mamajek & Hillenbrand (2008), and the gyrochronology relation of
    Mamajek & Hillenbrand (2008).

    Args:
        self (`arvi.RV`):
            RV object; the NaN-ignoring mean of `self.rhk` over the unmasked
            points is used as the activity level.
        bv (float, optional):
            B-V colour. If None, `self.simbad.B - self.simbad.V` is used.

    Returns:
        prot_age_result: object with the attributes
            prot_n84, prot_n84_err: rotation period via Noyes et al. (1984) and its error
            prot_m08, prot_m08_err: rotation period via Mamajek & Hillenbrand (2008) and its error
            age_m08, age_m08_err: gyrochronology age via Mamajek & Hillenbrand (2008) and its error
        Values outside the validity ranges below are NaN.

    Validity ranges:
        rotation periods: -5.5 < logR'HK < -4.3
        age: 0.5 < B-V < 0.9 (and a positive prot_m08)
    """
    log_rhk = np.nanmean(self.rhk[self.mask])
    if bv is None:
        bv = self.simbad.B - self.simbad.V

    # NaN unless the activity level is within the calibrated range
    prot_n84 = prot_n84_err = prot_m08 = prot_m08_err = np.nan

    if np.any(log_rhk < -4.3) & np.any(log_rhk > -5.5):
        # log of the convective turnover time, as a function of B-V
        x = 1 - bv
        if bv < 1:
            log_tau = 1.362 - 0.166*x + 0.025*x**2 - 5.323*x**3
        else:
            log_tau = 1.362 - 0.14*x

        y = 5 + log_rhk
        log_prot = 0.324 - 0.400*y - 0.283*y**2 - 1.325*y**3 + log_tau
        prot_n84 = 10**log_prot
        prot_n84_err = np.log(10)*0.08*prot_n84  # i.e. a 0.08 dex uncertainty

        prot_m08 = (0.808 - 2.966*(log_rhk + 4.52))*10**log_tau
        # NOTE(review): origin of this error expression is not given here; confirm against MH08
        prot_m08_err = 4.4*bv*1.7 - 1.7

    # gyrochronology age (NaN outside its validity range)
    age_m08 = age_m08_err = np.nan
    if np.any(prot_m08 > 0.0) & (bv > 0.50) & (bv < 0.9):
        age_m08 = 1e-3*(prot_m08/0.407/(bv - 0.495)**0.325)**(1./0.566)
        age_m08_err = 0.2 * age_m08 * np.log(10)  # using 0.2 dex typical error from paper

    result = prot_age_result()
    values = {
        'prot_n84': prot_n84, 'prot_n84_err': prot_n84_err,
        'prot_m08': prot_m08, 'prot_m08_err': prot_m08_err,
        'age_m08': age_m08, 'age_m08_err': age_m08_err,
    }
    for name, value in values.items():
        setattr(result, name, value)
    return result