arvi 0.2.8__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arvi/timeseries.py CHANGED
@@ -4,8 +4,9 @@ from typing import Union
4
4
  from functools import partial, partialmethod
5
5
  from glob import glob
6
6
  import warnings
7
- from copy import deepcopy
7
+ from copy import copy, deepcopy
8
8
  from datetime import datetime, timezone
9
+
9
10
  import numpy as np
10
11
 
11
12
  from .setup_logger import setup_logger
@@ -24,10 +25,12 @@ from .HZ import getHZ_period
24
25
  from .instrument_specific import ISSUES
25
26
  from .reports import REPORTS
26
27
  from .utils import sanitize_path, strtobool, there_is_internet, timer, chdir
27
- from .utils import lazy_import
28
+ from .setup_logger import setup_logger
29
+ logger = setup_logger()
28
30
 
29
- units = lazy_import('astropy.units')
30
- # from astropy import units
31
+ # units = lazy_import('astropy.units')
32
+ # units = lazy.load('astropy.units')
33
+ from astropy import units
31
34
 
32
35
  class ExtraFields:
33
36
  @property
@@ -408,35 +411,71 @@ class RV(ISSUES, REPORTS):
408
411
  self._did_correct_berv = False
409
412
  self.__post_init__()
410
413
 
411
- def snapshot(self, directory=None, delete_others=False):
412
- import pickle
414
+ def snapshot(self, directory=None, delete_others=False, compress=False):
415
+ if compress:
416
+ try:
417
+ import compress_pickle as pickle
418
+ except ImportError:
419
+ logger.warning('compress_pickle not installed, not compressing')
420
+ import pickle
421
+ compress = False
422
+ else:
423
+ import pickle
424
+ import re
413
425
  from datetime import datetime
426
+
414
427
  ts = datetime.now().timestamp()
415
428
  star_name = self.star.replace(' ', '')
416
429
  file = f'{star_name}_{ts}.pkl'
417
430
 
431
+ server = None
418
432
  if directory is None:
419
433
  directory = '.'
420
434
  else:
421
- os.makedirs(directory, exist_ok=True)
422
-
423
- file = os.path.join(directory, file)
424
-
425
- if delete_others:
426
- import re
427
- other_pkls = [
428
- f for f in os.listdir(directory)
429
- if re.search(fr'{star_name}_\d+.\d+.pkl', f)
430
- ]
431
- for pkl in other_pkls:
432
- os.remove(os.path.join(directory, pkl))
435
+ if ':' in directory:
436
+ server, directory = directory.split(':')
437
+ delete_others = False
438
+ else:
439
+ os.makedirs(directory, exist_ok=True)
433
440
 
434
441
  metadata = {
435
442
  'star': self.star,
436
443
  'timestamp': ts,
437
444
  'description': 'arvi snapshot'
438
445
  }
439
- pickle.dump((self, metadata), open(file, 'wb'), protocol=0)
446
+
447
+
448
+ if server:
449
+ import posixpath
450
+ from .utils import server_sftp, server_file
451
+ with server_sftp(server=server) as sftp:
452
+ try:
453
+ sftp.chdir(directory)
454
+ except FileNotFoundError:
455
+ sftp.mkdir(directory)
456
+ finally:
457
+ sftp.chdir(directory)
458
+ with sftp.open(file, 'wb') as f:
459
+ print('saving snapshot to server...', end='', flush=True)
460
+ pickle.dump((self, metadata), f, protocol=0)
461
+ print('done')
462
+ file = posixpath.join(directory, file)
463
+ else:
464
+ if delete_others:
465
+ other_pkls = [
466
+ f for f in os.listdir(directory)
467
+ if re.search(fr'{star_name}_\d+.\d+.pkl', f)
468
+ ]
469
+ for pkl in other_pkls:
470
+ os.remove(os.path.join(directory, pkl))
471
+
472
+ file = os.path.join(directory, file)
473
+
474
+ if compress:
475
+ file += '.gz'
476
+
477
+ with open(file, 'wb') as f:
478
+ pickle.dump((self, metadata), f)
440
479
 
441
480
  if self.verbose:
442
481
  logger.info(f'saved snapshot to {file}')
@@ -511,6 +550,15 @@ class RV(ISSUES, REPORTS):
511
550
  def instrument_array(self):
512
551
  return np.concatenate([[i] * n for i, n in self.NN.items()])
513
552
 
553
+ def _instrument_mask(self, instrument):
554
+ if isinstance(instrument, str):
555
+ return np.char.find(self.instrument_array, instrument) == 0
556
+ elif isinstance(instrument, (list, tuple, np.ndarray)):
557
+ m = np.full_like(self.time, False, dtype=bool)
558
+ for i in instrument:
559
+ m |= np.char.find(self.instrument_array, i) == 0
560
+ return m
561
+
514
562
  @property
515
563
  def rms(self) -> float:
516
564
  """ Weighted rms of the (masked) radial velocities """
@@ -537,6 +585,11 @@ class RV(ISSUES, REPORTS):
537
585
  def _mtime_sorter(self):
538
586
  return np.argsort(self.mtime)
539
587
 
588
+ @property
589
+ def timespan(self):
590
+ """ Total time span of the (masked) observations """
591
+ return np.ptp(self.mtime)
592
+
540
593
  def _index_from_instrument_index(self, index, instrument):
541
594
  ind = np.where(self.instrument_array == instrument)[0]
542
595
  return ind[getattr(self, instrument).mask][index]
@@ -577,7 +630,8 @@ class RV(ISSUES, REPORTS):
577
630
  # --> not just in rhk and rhk_err...
578
631
  if data[arr].dtype == float and (bad := data[arr] == -99999).any():
579
632
  data[arr][bad] = np.nan
580
-
633
+ if data[arr].dtype == float and (bad := data[arr] == -99).any():
634
+ data[arr][bad] = np.nan
581
635
  setattr(s, arr, data[arr][ind])
582
636
  s._quantities.append(arr)
583
637
 
@@ -629,22 +683,28 @@ class RV(ISSUES, REPORTS):
629
683
  import pickle
630
684
  from datetime import datetime
631
685
  if star is None:
632
- assert file.endswith('.pkl'), 'expected a .pkl file'
633
- star, timestamp = file.replace('.pkl', '').split('_')
686
+ assert file.endswith(('.pkl', '.pkl.gz')), 'expected a .pkl file'
687
+ basefile = os.path.basename(file)
688
+ star, timestamp = basefile.replace('.pkl.gz', '').replace('.pkl', '').split('_')
634
689
  else:
635
690
  try:
636
- file = sorted(glob(f'{star}_*.*.pkl'))[-1]
691
+ file = sorted(glob(f'{star}_*.*.pkl*'))[-1]
637
692
  except IndexError:
638
693
  raise ValueError(f'cannot find any file matching {star}_*.pkl')
639
- star, timestamp = file.replace('.pkl', '').split('_')
694
+ star, timestamp = file.replace('.pkl.gz', '').replace('.pkl', '').split('_')
640
695
 
641
696
  dt = datetime.fromtimestamp(float(timestamp))
642
697
  if verbose:
643
698
  logger.info(f'reading snapshot of {star} from {dt}')
644
699
 
645
- s = pickle.load(open(file, 'rb'))
700
+ with open(file, 'rb') as f:
701
+ if file.endswith('.gz'):
702
+ import compress_pickle as pickle
703
+ s = pickle.load(f)
704
+
646
705
  if isinstance(s, tuple) and len(s) == 2:
647
706
  s, _metadata = s
707
+
648
708
  s._snapshot = file
649
709
  return s
650
710
 
@@ -1504,7 +1564,7 @@ class RV(ISSUES, REPORTS):
1504
1564
  """ Remove all observations that satisfy a condition
1505
1565
 
1506
1566
  Args:
1507
- condition (np.ndarray):
1567
+ condition (ndarray):
1508
1568
  Boolean array of the same length as the observations
1509
1569
  """
1510
1570
  if self.verbose:
@@ -1664,16 +1724,18 @@ class RV(ISSUES, REPORTS):
1664
1724
  self._propagate_mask_changes()
1665
1725
 
1666
1726
 
1667
- def _propagate_mask_changes(self):
1727
+ def _propagate_mask_changes(self, _remove_instrument=True):
1668
1728
  """ link self.mask with each self.`instrument`.mask """
1669
1729
  masked = np.where(~self.mask)[0]
1670
1730
  for m in masked:
1671
1731
  inst = self.instruments[self.obs[m] - 1]
1672
1732
  n_before = (self.obs < self.obs[m]).sum()
1673
1733
  getattr(self, inst).mask[m - n_before] = False
1674
- for inst in self.instruments:
1675
- if getattr(self, inst).mtime.size == 0:
1676
- self.remove_instrument(inst, strict=True)
1734
+ if _remove_instrument:
1735
+ instruments = copy(self.instruments)
1736
+ for inst in instruments:
1737
+ if getattr(self, inst).mtime.size == 0:
1738
+ self.remove_instrument(inst, strict=True)
1677
1739
 
1678
1740
  def secular_acceleration(self, epoch=None, just_compute=False, force_simbad=False):
1679
1741
  """
@@ -1691,9 +1753,12 @@ class RV(ISSUES, REPORTS):
1691
1753
  force_simbad (bool, optional):
1692
1754
  Use Simbad proper motions even if Gaia is available
1693
1755
  """
1694
- if self._did_secular_acceleration and not just_compute: # don't do it twice
1756
+ # don't do it twice
1757
+ if self._did_secular_acceleration and not just_compute:
1695
1758
  return
1696
1759
 
1760
+ from astropy import units
1761
+
1697
1762
  #as_yr = units.arcsec / units.year
1698
1763
  mas_yr = units.milliarcsecond / units.year
1699
1764
  mas = units.milliarcsecond
@@ -1825,15 +1890,21 @@ class RV(ISSUES, REPORTS):
1825
1890
 
1826
1891
  self._did_secular_acceleration = False
1827
1892
 
1828
- def sigmaclip(self, sigma=5, instrument=None, strict=True):
1893
+ def sigmaclip(self, sigma=5, quantity='vrad', instrument=None,
1894
+ strict=True):
1829
1895
  """
1830
- Sigma-clip RVs (per instrument!), by MAD away from the median.
1896
+ Sigma-clip RVs or other quantities (per instrument!), by MAD away from
1897
+ the median.
1831
1898
 
1832
1899
  Args:
1833
1900
  sigma (float):
1834
- Number of MADs to clip
1901
+ Number of MADs away from the median
1902
+ quantity (str):
1903
+ Quantity to sigma-clip (by default the RVs)
1835
1904
  instrument (str, list):
1836
1905
  Instrument(s) to sigma-clip
1906
+ strict (bool):
1907
+ Passed directly to self._check_instrument
1837
1908
  """
1838
1909
  #from scipy.stats import sigmaclip as dosigmaclip
1839
1910
  from .stats import sigmaclip_median as dosigmaclip
@@ -1842,20 +1913,26 @@ class RV(ISSUES, REPORTS):
1842
1913
  return
1843
1914
 
1844
1915
  instruments = self._check_instrument(instrument, strict)
1916
+ if instruments is None:
1917
+ return
1845
1918
  changed_instruments = []
1846
1919
 
1847
1920
  for inst in instruments:
1848
1921
  m = self.instrument_array == inst
1849
- result = dosigmaclip(self.vrad[m], low=sigma, high=sigma)
1922
+ d = getattr(self, quantity)
1923
+
1924
+ if np.isnan(d[m]).all():
1925
+ continue
1926
+
1927
+ result = dosigmaclip(d[m], low=sigma, high=sigma)
1850
1928
  # n = self.vrad[m].size - result.clipped.size
1851
1929
 
1852
- ind = m & self.mask & \
1853
- ((self.vrad < result.lower) | (self.vrad > result.upper))
1930
+ ind = m & self.mask & ((d < result.lower) | (d > result.upper))
1854
1931
  n = ind.sum()
1855
1932
 
1856
1933
  if self.verbose and n > 0:
1857
1934
  s = 's' if (n == 0 or n > 1) else ''
1858
- logger.warning(f'sigma-clip RVs will remove {n} point{s} for {inst}')
1935
+ logger.warning(f'sigma-clip {quantity} will remove {n} point{s} for {inst}')
1859
1936
 
1860
1937
  if n > 0:
1861
1938
  self.mask[ind] = False
@@ -1880,21 +1957,32 @@ class RV(ISSUES, REPORTS):
1880
1957
  if config.return_self:
1881
1958
  return self
1882
1959
 
1883
- def clip_maxerror(self, maxerror:float):
1884
- """ Mask out points with RV error larger than a given value
1960
+ def clip_maxerror(self, maxerror:float, instrument=None):
1961
+ """
1962
+ Mask out points with RV error larger than a given value. If `instrument`
1963
+ is given, mask only observations from that instrument.
1885
1964
 
1886
1965
  Args:
1887
1966
  maxerror (float): Maximum error to keep.
1967
+ instrument (str, list, tuple, ndarray): Instrument(s) to clip
1888
1968
  """
1889
1969
  if self._child:
1890
1970
  return
1891
1971
 
1892
1972
  self.maxerror = maxerror
1973
+
1974
+ if instrument is None:
1975
+ inst_mask = np.ones_like(self.svrad, dtype=bool)
1976
+ else:
1977
+ inst_mask = self._instrument_mask(instrument)
1978
+
1893
1979
  above = self.svrad > maxerror
1894
- n = above.sum()
1895
- self.mask[above] = False
1980
+ old_mask = self.mask.copy()
1981
+
1982
+ self.mask[inst_mask & above] = False
1896
1983
 
1897
1984
  if self.verbose and above.sum() > 0:
1985
+ n = (above[inst_mask] & old_mask[inst_mask]).sum()
1898
1986
  s = 's' if (n == 0 or n > 1) else ''
1899
1987
  logger.warning(f'clip_maxerror ({maxerror} {self.units}) removed {n} point' + s)
1900
1988
 
@@ -1902,6 +1990,36 @@ class RV(ISSUES, REPORTS):
1902
1990
  if config.return_self:
1903
1991
  return self
1904
1992
 
1993
+ def sigmaclip_ew(self, sigma=5):
1994
+ """ Sigma-clip EW (FWHM x contrast), by MAD away from the median """
1995
+ from .stats import sigmaclip_median as dosigmaclip, weighted_median
1996
+
1997
+ S = deepcopy(self)
1998
+ for _s in S:
1999
+ m = _s.mask
2000
+ _s.fwhm -= weighted_median(_s.fwhm[m], 1 / _s.fwhm_err[m])
2001
+ _s.contrast -= weighted_median(_s.contrast[m], 1 / _s.contrast_err[m])
2002
+ S._build_arrays()
2003
+ ew = S.fwhm * S.contrast
2004
+ ew_err = np.hypot(S.fwhm_err * S.contrast, S.fwhm * S.contrast_err)
2005
+
2006
+ wmed = weighted_median(ew[S.mask], 1 / ew_err[S.mask])
2007
+ data = (ew - wmed) / ew_err
2008
+ result = dosigmaclip(data, low=sigma, high=sigma)
2009
+ ind = (data < result.lower) | (data > result.upper)
2010
+ self.mask[ind] = False
2011
+
2012
+ if self.verbose and ind.sum() > 0:
2013
+ n = ind.sum()
2014
+ s = 's' if (n == 0 or n > 1) else ''
2015
+ logger.warning(f'sigmaclip_ew removed {n} point' + s)
2016
+
2017
+ self._propagate_mask_changes()
2018
+ if config.return_self:
2019
+ return self
2020
+
2021
+
2022
+
1905
2023
  def bin(self):
1906
2024
  """
1907
2025
  Nightly bin the observations.
@@ -1912,6 +2030,8 @@ class RV(ISSUES, REPORTS):
1912
2030
 
1913
2031
  # create copy of self to be returned
1914
2032
  snew = deepcopy(self)
2033
+ # store original object
2034
+ snew._unbinned = deepcopy(self)
1915
2035
 
1916
2036
  all_bad_quantities = []
1917
2037
 
@@ -1943,7 +2063,8 @@ class RV(ISSUES, REPORTS):
1943
2063
 
1944
2064
  # treat ccf_mask specially, doing a 'unique' bin
1945
2065
  if q == 'ccf_mask':
1946
- setattr(s, q, bin_ccf_mask(s.mtime, getattr(s, q)))
2066
+ ccf_mask = getattr(s, q)[s.mask]
2067
+ setattr(s, q, bin_ccf_mask(s.mtime, ccf_mask))
1947
2068
  continue
1948
2069
 
1949
2070
  if Q.dtype != np.float64:
@@ -2101,23 +2222,32 @@ class RV(ISSUES, REPORTS):
2101
2222
  if config.return_self:
2102
2223
  return self
2103
2224
 
2104
- def detrend(self, degree=1):
2105
- """ Detrend the RVs of all instruments """
2225
+ def detrend(self, degree: int=1):
2226
+ """
2227
+ Detrend the RVs of all instruments using a polynomial of degree `degree`
2228
+ """
2106
2229
  instrument_indices = np.unique_inverse(self.instrument_array).inverse_indices
2107
- def fun(p, t, degree, ninstruments, just_model=False, index=None):
2230
+ instrument_indices_masked = np.unique_inverse(self.instrument_array[self.mask]).inverse_indices
2231
+
2232
+ def fun(p, t, degree, ninstruments, just_model=False, index=None, masked=True):
2108
2233
  polyp, offsets = p[:degree], p[-ninstruments:]
2109
2234
  polyp = np.r_[polyp, 0.0]
2110
2235
  if index is None:
2111
- model = offsets[instrument_indices] + np.polyval(polyp, t)
2236
+ if masked:
2237
+ model = offsets[instrument_indices_masked] + np.polyval(polyp, t)
2238
+ else:
2239
+ model = offsets[instrument_indices] + np.polyval(polyp, t)
2112
2240
  else:
2113
2241
  model = offsets[index] + np.polyval(polyp, t)
2114
2242
  if just_model:
2115
2243
  return model
2116
2244
  return self.mvrad - model
2245
+
2117
2246
  coef = np.polyfit(self.mtime, self.mvrad, degree)
2118
2247
  x0 = np.append(coef, [0.0] * (len(self.instruments) - 1))
2119
- print(x0)
2248
+ # print(x0)
2120
2249
  fun(x0, self.mtime, degree, len(self.instruments))
2250
+
2121
2251
  from scipy.optimize import leastsq
2122
2252
  xbest, _ = leastsq(fun, x0, args=(self.mtime, degree, len(self.instruments)))
2123
2253
 
@@ -2127,12 +2257,13 @@ class RV(ISSUES, REPORTS):
2127
2257
  self.plot(ax=ax)
2128
2258
  for i, inst in enumerate(self.instruments):
2129
2259
  s = getattr(self, inst)
2130
- ax.plot(s.time, fun(xbest, s.time, degree, len(self.instruments), just_model=True, index=i),
2260
+ ax.plot(s.time,
2261
+ fun(xbest, s.time, degree, len(self.instruments), just_model=True, index=i, masked=False),
2131
2262
  color=f'C{i}')
2132
2263
  ax.set_title('original', loc='left', fontsize=10)
2133
2264
  ax.set_title(f'coefficients: {xbest[:degree]}', loc='right', fontsize=10)
2134
2265
 
2135
- self.add_to_vrad(-fun(xbest, self.time, degree, len(self.instruments), just_model=True))
2266
+ self.add_to_vrad(-fun(xbest, self.time, degree, len(self.instruments), just_model=True, masked=False))
2136
2267
  ax = fig.add_subplot(2, 1, 2)
2137
2268
  self.plot(ax=ax)
2138
2269
  ax.set_title('detrended', loc='left', fontsize=10)
@@ -2141,7 +2272,7 @@ class RV(ISSUES, REPORTS):
2141
2272
  # axs[1].errorbar(self.mtime, fun(xbest, self.mtime, degree, len(self.instruments)), self.msvrad, fmt='o')
2142
2273
 
2143
2274
  return
2144
-
2275
+
2145
2276
 
2146
2277
 
2147
2278
 
@@ -2301,24 +2432,31 @@ class RV(ISSUES, REPORTS):
2301
2432
  self.units = new_units
2302
2433
 
2303
2434
 
2304
- def put_at_systemic_velocity(self):
2435
+ def put_at_systemic_velocity(self, factor=1.0, ignore=None):
2305
2436
  """
2306
- For instruments in which mean(RV) < ptp(RV), "move" RVs to the systemic
2307
- velocity from simbad. This is useful if some instruments are centered
2308
- at zero while others are not, and instead of calling `.adjust_means()`,
2309
- but it only works when the systemic velocity is smaller than ptp(RV).
2437
+ For instruments in which mean(RV) < `factor` * ptp(RV), "move" RVs to
2438
+ the systemic velocity from simbad. This is useful if some instruments
2439
+ are centered at zero while others are not, and instead of calling
2440
+ `.adjust_means()`, but it only works when the systemic velocity is
2441
+ smaller than `factor` * ptp(RV).
2310
2442
  """
2311
2443
  changed = False
2312
2444
  for inst in self.instruments:
2445
+ if ignore is not None:
2446
+ if inst in ignore or any([i in inst for i in ignore]):
2447
+ continue
2448
+ changed_inst = False
2313
2449
  s = getattr(self, inst)
2314
2450
  if s.mask.any():
2315
- if np.abs(s.mvrad.mean()) < np.ptp(s.mvrad):
2451
+ if np.abs(s.mvrad.mean()) < factor * np.ptp(s.mvrad):
2316
2452
  s.vrad += self.simbad.rvz_radvel * 1e3
2317
- changed = True
2453
+ changed = changed_inst = True
2318
2454
  else: # all observations are masked, use non-masked arrays
2319
- if np.abs(s.vrad.mean()) < np.ptp(s.vrad):
2455
+ if np.abs(s.vrad.mean()) < factor * np.ptp(s.vrad):
2320
2456
  s.vrad += self.simbad.rvz_radvel * 1e3
2321
- changed = True
2457
+ changed = changed_inst = True
2458
+ if changed_inst and self.verbose:
2459
+ logger.info(f"putting {inst} RVs at systemic velocity")
2322
2460
  if changed:
2323
2461
  self._build_arrays()
2324
2462
 
@@ -2340,34 +2478,72 @@ class RV(ISSUES, REPORTS):
2340
2478
  self.instruments = sorted(self.instruments, key=lambda i: getattr(self, i).time.max())
2341
2479
  self._build_arrays()
2342
2480
 
2481
+ def put_instrument_last(self, instrument):
2482
+ if not self._check_instrument(instrument, strict=True, log=True):
2483
+ return
2484
+ self.instruments = [i for i in self.instruments if i != instrument] + [instrument]
2485
+ self._build_arrays()
2343
2486
 
2344
- def save(self, directory=None, instrument=None, full=False, postfix=None,
2345
- save_masked=False, save_nans=True):
2346
- """ Save the observations in .rdb files.
2487
+ def save(self, directory=None, instrument=None, format='rdb',
2488
+ indicators=False, join_instruments=False, postfix=None,
2489
+ save_masked=False, save_nans=True, **kwargs):
2490
+ """ Save the observations in .rdb or .csv files.
2347
2491
 
2348
2492
  Args:
2349
2493
  directory (str, optional):
2350
2494
  Directory where to save the .rdb files.
2351
2495
  instrument (str, optional):
2352
2496
  Instrument for which to save observations.
2353
- full (bool, optional):
2354
- Save just RVs and errors (False) or more indicators (True).
2497
+ format (str, optional):
2498
+ Format to use ('rdb' or 'csv').
2499
+ indicators (bool, str, list[str], optional):
2500
+ Save only RVs and errors (False) or more indicators. If True,
2501
+ use a default list, if `str`, use an existing list, if list[str]
2502
+ provide a sequence of specific indicators.
2503
+ join_instruments (bool, optional):
2504
+ Join all instruments in a single file.
2355
2505
  postfix (str, optional):
2356
2506
  Postfix to add to the filenames ([star]_[instrument]_[postfix].rdb).
2507
+ save_masked (bool, optional)
2508
+ If True, also save masked observations (those for which
2509
+ self.mask == False)
2357
2510
  save_nans (bool, optional)
2358
2511
  Whether to save NaN values in the indicators, if they exist. If
2359
2512
  False, the full observation which contains NaN values is not saved.
2360
2513
  """
2514
+ if format not in ('rdb', 'csv'):
2515
+ logger.error(f"format must be 'rdb' or 'csv', got '{format}'")
2516
+ return
2517
+
2361
2518
  star_name = self.star.replace(' ', '')
2362
2519
 
2363
- if directory is None:
2364
- directory = '.'
2365
- else:
2520
+ if directory is not None:
2366
2521
  os.makedirs(directory, exist_ok=True)
2367
2522
 
2523
+ indicator_sets = {
2524
+ "default": [
2525
+ "fwhm", "fwhm_err",
2526
+ "bispan", "bispan_err",
2527
+ "contrast", "contrast_err",
2528
+ "rhk", "rhk_err",
2529
+ "berv",
2530
+ ],
2531
+ "CORALIE": [
2532
+ "fwhm", "fwhm_err",
2533
+ "bispan", "bispan_err",
2534
+ "contrast", "contrast_err",
2535
+ "haindex", "haindex_err",
2536
+ "berv",
2537
+ ],
2538
+ }
2539
+
2540
+ if 'full' in kwargs:
2541
+ logger.warning('argument `full` is deprecated, use `indicators` instead')
2542
+ indicators = kwargs['full']
2543
+
2368
2544
  files = []
2369
2545
 
2370
- for inst in self.instruments:
2546
+ for _i, inst in enumerate(self.instruments):
2371
2547
  if instrument is not None:
2372
2548
  if instrument not in inst:
2373
2549
  continue
@@ -2377,75 +2553,95 @@ class RV(ISSUES, REPORTS):
2377
2553
  if not _s.mask.any(): # all observations are masked, don't save
2378
2554
  continue
2379
2555
 
2380
- if full:
2381
- if save_masked:
2382
- arrays = [
2383
- _s.time, _s.vrad, _s.svrad,
2384
- _s.fwhm, _s.fwhm_err,
2385
- _s.bispan, _s.bispan_err,
2386
- _s.contrast, _s.contrast_err,
2387
- _s.rhk, _s.rhk_err,
2388
- _s.berv,
2389
- ]
2390
- else:
2391
- arrays = [
2392
- _s.mtime, _s.mvrad, _s.msvrad,
2393
- _s.fwhm[_s.mask], _s.fwhm_err[_s.mask],
2394
- _s.bispan[_s.mask], _s.bispan_err[_s.mask],
2395
- _s.contrast[_s.mask], _s.contrast_err[_s.mask],
2396
- _s.rhk[_s.mask], _s.rhk_err[_s.mask],
2397
- _s.berv[_s.mask],
2398
- ]
2399
- if not save_nans:
2400
- raise NotImplementedError
2401
- # if np.isnan(d).any():
2402
- # # remove observations where any of the indicators are # NaN
2403
- # nan_mask = np.isnan(d[:, 3:]).any(axis=1)
2404
- # d = d[~nan_mask]
2405
- # if self.verbose:
2406
- # logger.warning(f'masking {nan_mask.sum()} observations with NaN in indicators')
2407
-
2408
- header = '\t'.join(['rjd', 'vrad', 'svrad',
2409
- 'fwhm', 'sig_fwhm',
2410
- 'bispan', 'sig_bispan',
2411
- 'contrast', 'sig_contrast',
2412
- 'rhk', 'sig_rhk',
2413
- 'berv',
2414
- ])
2415
- header += '\n'
2416
- header += '\t'.join(['-' * len(c) for c in header.strip().split('\t')])
2556
+ if save_masked:
2557
+ arrays = [_s.time, _s.vrad, _s.svrad]
2558
+ if join_instruments:
2559
+ arrays += [_s.instrument_array]
2560
+ else:
2561
+ arrays = [_s.mtime, _s.mvrad, _s.msvrad]
2562
+ if join_instruments:
2563
+ arrays += [_s.instrument_array[_s.mask]]
2564
+
2565
+ if indicators in (False, None):
2566
+ indicator_names = []
2567
+ else:
2568
+ if indicators is True:
2569
+ indicator_names = indicator_sets["default"]
2570
+ elif isinstance(indicators, str):
2571
+ try:
2572
+ indicator_names = indicator_sets[indicators]
2573
+ except KeyError:
2574
+ logger.error(f"unknown indicator set '{indicators}'")
2575
+ logger.error(f"available: {list(indicator_sets.keys())}")
2576
+ return
2577
+ elif isinstance(indicators, list) and all(isinstance(i, str) for i in indicators):
2578
+ indicator_names = indicators
2417
2579
 
2580
+ if save_masked:
2581
+ arrays += [getattr(_s, ind) for ind in indicator_names]
2418
2582
  else:
2419
- if save_masked:
2420
- arrays = [_s.time, _s.vrad, _s.svrad]
2421
- else:
2422
- arrays = [_s.mtime, _s.mvrad, _s.msvrad]
2583
+ arrays += [getattr(_s, ind)[_s.mask] for ind in indicator_names]
2584
+
2585
+ d = np.stack(arrays, axis=1)
2586
+ if not save_nans:
2587
+ # raise NotImplementedError
2588
+ if np.isnan(d).any():
2589
+ # remove observations where any of the indicators are # NaN
2590
+ nan_mask = np.isnan(d[:, 3:]).any(axis=1)
2591
+ d = d[~nan_mask]
2592
+ if self.verbose:
2593
+ msg = f'{inst}: masking {nan_mask.sum()} observations with NaN in indicators'
2594
+ logger.warning(msg)
2595
+
2596
+ cols = ['rjd', 'vrad', 'svrad']
2597
+ cols += ['inst'] if join_instruments else []
2598
+ cols += indicator_names
2423
2599
 
2424
- # d = np.stack(arrays, axis=1)
2425
- header = 'rjd\tvrad\tsvrad\n---\t----\t-----'
2600
+ if format == 'rdb':
2601
+ header = '\t'.join(cols)
2602
+ header += '\n'
2603
+ header += '\t'.join(['-' * len(c) for c in header.strip().split('\t')])
2604
+ else:
2605
+ header = ','.join(cols)
2426
2606
 
2427
- file = f'{star_name}_{inst}.rdb'
2428
- if postfix is not None:
2429
- file = f'{star_name}_{inst}_{postfix}.rdb'
2607
+ if join_instruments:
2608
+ file = f'{star_name}.{format}'
2609
+ if postfix is not None:
2610
+ file = f'{star_name}_{postfix}.{format}'
2611
+ else:
2612
+ file = f'{star_name}_{inst}.{format}'
2613
+ if postfix is not None:
2614
+ file = f'{star_name}_{inst}_{postfix}.{format}'
2430
2615
 
2616
+ if directory is not None:
2617
+ file = os.path.join(directory, file)
2431
2618
  files.append(file)
2432
- file = os.path.join(directory, file)
2433
2619
 
2434
2620
  N = len(arrays[0])
2435
- with open(file, 'w') as f:
2436
- f.write(header + '\n')
2621
+ with open(file, 'a' if join_instruments and _i != 0 else 'w') as f:
2622
+ if join_instruments and _i != 0:
2623
+ pass
2624
+ else:
2625
+ f.write(header + '\n')
2626
+
2437
2627
  for i in range(N):
2438
2628
  for j, a in enumerate(arrays):
2439
2629
  f.write(str(a[i]))
2440
2630
  if j < len(arrays) - 1:
2441
- f.write('\t')
2631
+ f.write('\t' if format == 'rdb' else ',')
2442
2632
  f.write('\n')
2443
2633
 
2444
2634
  # np.savetxt(file, d, header=header, delimiter='\t', comments='', fmt='%f')
2445
2635
 
2446
- if self.verbose:
2636
+ if self.verbose and not join_instruments:
2447
2637
  logger.info(f'saving to {file}')
2448
2638
 
2639
+ if self.verbose and join_instruments:
2640
+ logger.info(f'saving to {files[0]}')
2641
+
2642
+ if join_instruments:
2643
+ files = [files[0]]
2644
+
2449
2645
  return files
2450
2646
 
2451
2647
  def checksum(self, write_to=None):