arvi 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arvi might be problematic. Click here for more details.

arvi/timeseries.py CHANGED
@@ -1,17 +1,15 @@
1
1
  import os
2
2
  from dataclasses import dataclass, field
3
3
  from typing import Union
4
- from functools import partial
4
+ from functools import partial, partialmethod
5
5
  from glob import glob
6
6
  import warnings
7
7
  from copy import deepcopy
8
8
  from datetime import datetime, timezone
9
9
  import numpy as np
10
10
 
11
- from astropy import units
12
-
13
11
  from .setup_logger import logger
14
- from . import config
12
+ from .config import config
15
13
  from .translations import translate
16
14
  from .dace_wrapper import do_download_filetype, do_symlink_filetype, get_observations, get_arrays
17
15
  from .simbad_wrapper import simbad
@@ -20,8 +18,11 @@ from .extra_data import get_extra_data
20
18
  from .stats import wmean, wrms
21
19
  from .binning import bin_ccf_mask, binRV
22
20
  from .HZ import getHZ_period
23
- from .utils import strtobool, there_is_internet, timer
21
+ from .utils import strtobool, there_is_internet, timer, chdir
22
+ from .utils import lazy_import
24
23
 
24
+ units = lazy_import('astropy.units')
25
+ # from astropy import units
25
26
 
26
27
  class ExtraFields:
27
28
  pass
@@ -54,12 +55,19 @@ class RV:
54
55
  do_adjust_means: bool = field(init=True, repr=False, default=True)
55
56
  only_latest_pipeline: bool = field(init=True, repr=False, default=True)
56
57
  load_extra_data: Union[bool, str] = field(init=True, repr=False, default=False)
58
+ check_drs_qc: bool = field(init=True, repr=False, default=True)
57
59
  #
60
+ units = 'm/s'
58
61
  _child: bool = field(init=True, repr=False, default=False)
59
62
  _did_secular_acceleration: bool = field(init=False, repr=False, default=False)
60
63
  _did_sigma_clip: bool = field(init=False, repr=False, default=False)
61
64
  _did_adjust_means: bool = field(init=False, repr=False, default=False)
65
+ _did_simbad_query: bool = field(init=False, repr=False, default=False)
66
+ _did_gaia_query: bool = field(init=False, repr=False, default=False)
62
67
  _raise_on_error: bool = field(init=True, repr=False, default=True)
68
+ #
69
+ _simbad = None
70
+ _gaia = None
63
71
 
64
72
  def __repr__(self):
65
73
  if self.N == 0:
@@ -70,60 +78,126 @@ class RV:
70
78
  nmasked = self.N - self.mtime.size
71
79
  return f"RV(star='{self.star}', N={self.N}, masked={nmasked})"
72
80
 
@property
def simbad(self):
    """ Simbad information for this star, queried lazily and cached

    The query is made at most once. Child (per-instrument) objects never
    query and return None. If the query for ``self.__star__`` fails, it is
    retried once with a trailing 'A' removed from the name (if there is one).

    Returns:
        The result of the Simbad query, or None if the query failed, was
        already attempted unsuccessfully, or this is a child object.
    """
    if self._simbad is not None:
        return self._simbad

    if self._child or self._did_simbad_query:
        return None

    if self.verbose:
        logger.info('querying Simbad...')

    # Try the translated name first. Then, only if it ends in 'A', try the
    # same name without that trailing 'A'. A list keeps the order
    # deterministic: iterating a set of strings depends on hash
    # randomization. Only the trailing 'A' is removed, so names such as
    # "WASP-..." are not mangled.
    targets = [self.__star__]
    if self.__star__.endswith('A'):
        targets.append(self.__star__[:-1].rstrip())

    for target in targets:
        try:
            self._simbad = simbad(target)
            break
        except ValueError:
            continue
    else:
        if self.verbose:
            logger.error(f'simbad query for {self.__star__} failed')

    self._did_simbad_query = True
    return self._simbad
109
+
@property
def gaia(self):
    """ Gaia information for this star, queried lazily and cached

    The query is made at most once. Child (per-instrument) objects never
    query and return None. If the query for ``self.__star__`` fails, it is
    retried once with a trailing 'A' removed from the name (if there is one).

    Returns:
        The result of the Gaia query, or None if the query failed, was
        already attempted unsuccessfully, or this is a child object.
    """
    if self._gaia is not None:
        return self._gaia

    if self._child or self._did_gaia_query:
        return None

    if self.verbose:
        logger.info('querying Gaia...')

    # Try the translated name first. Then, only if it ends in 'A', try the
    # same name without that trailing 'A'. A list keeps the order
    # deterministic, unlike iterating a set of strings. Only the trailing
    # 'A' is removed, not every 'A' in the name.
    targets = [self.__star__]
    if self.__star__.endswith('A'):
        targets.append(self.__star__[:-1].rstrip())

    for target in targets:
        try:
            self._gaia = gaia(target)
            break
        except ValueError:
            continue
    else:
        if self.verbose:
            logger.error(f'Gaia query for {self.__star__} failed')

    self._did_gaia_query = True
    return self._gaia
138
+
def __post_init_special_sun(self):
    """ Load the solar data file into ``self.dace_result``

    The file path comes from ``get_sun_data``. It is only allowed to
    download when this is not a child object.
    """
    import pickle
    from .extra_data import get_sun_data
    path = get_sun_data(download=not self._child)
    # NOTE(review): unpickling can execute arbitrary code; this is only safe
    # as long as the file provided by get_sun_data comes from a trusted source
    with open(path, 'rb') as f:
        self.dace_result = pickle.load(f)
144
+
145
+
73
146
  def __post_init__(self):
74
147
  self.__star__ = translate(self.star)
75
148
 
76
- if not self._child:
77
- if config.check_internet and not there_is_internet():
78
- raise ConnectionError('There is no internet connection?')
149
+ if self.star.lower() == 'sun':
150
+ self.__post_init_special_sun()
151
+ self.do_secular_acceleration = False
152
+ self.units = 'km/s'
79
153
 
80
- # complicated way to query Simbad with self.__star__ or, if that
81
- # fails, try after removing a trailing 'A'
82
- for target in (self.__star__, self.__star__.replace('A', '')):
83
- try:
84
- self.simbad = simbad(target)
85
- break
86
- except ValueError:
87
- continue
88
- else:
154
+ else:
155
+ if not self._child:
156
+ if config.check_internet and not there_is_internet():
157
+ raise ConnectionError('There is no internet connection?')
158
+
159
+ # make Simbad and Gaia queries in parallel
160
+ import concurrent.futures
161
+ with timer('simbad and gaia queries'):
162
+ with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
163
+ executor.map(self.__getattribute__, ('simbad', 'gaia'))
164
+
165
+ # with timer('simbad query'):
166
+ # self.simbad
167
+ # with timer('gaia query'):
168
+ # self.gaia
169
+
170
+ # query DACE
89
171
  if self.verbose:
90
- logger.error(f'simbad query for {self.__star__} failed')
91
-
92
- # complicated way to query Gaia with self.__star__ or, if that
93
- # fails, try after removing a trailing 'A'
94
- for target in (self.__star__, self.__star__.replace('A', '')):
172
+ logger.info(f'querying DACE for {self.__star__}...')
95
173
  try:
96
- self.gaia = gaia(target)
97
- break
98
- except ValueError:
99
- continue
100
- else:
101
- if self.verbose:
102
- logger.error(f'Gaia query for {self.__star__} failed')
103
-
104
- # query DACE
105
- if self.verbose:
106
- logger.info(f'querying DACE for {self.__star__}...')
107
- try:
108
- with timer():
109
- mid = self.simbad.main_id if hasattr(self, 'simbad') else None
110
- self.dace_result = get_observations(self.__star__, self.instrument,
111
- main_id=mid, verbose=self.verbose)
112
- except ValueError as e:
113
- # querying DACE failed, should we raise an error?
114
- if self._raise_on_error:
115
- raise e
116
- else:
117
- self.time = np.array([])
118
- self.instruments = []
119
- self.units = ''
120
- return
174
+ if hasattr(self, 'simbad') and self.simbad is not None:
175
+ mid = self.simbad.main_id
176
+ else:
177
+ mid = None
178
+
179
+ with timer():
180
+ self.dace_result = get_observations(self.__star__, self.instrument,
181
+ main_id=mid, verbose=self.verbose)
182
+ except ValueError as e:
183
+ # querying DACE failed, should we raise an error?
184
+ if self._raise_on_error:
185
+ raise e
186
+ else:
187
+ self.time = np.array([])
188
+ self.instruments = []
189
+ self.units = ''
190
+ return
121
191
 
122
- # store the date of the last DACE query
123
- time_stamp = datetime.now(timezone.utc) #.isoformat().split('.')[0]
124
- self._last_dace_query = time_stamp
192
+ # store the date of the last DACE query
193
+ time_stamp = datetime.now(timezone.utc) #.isoformat().split('.')[0]
194
+ self._last_dace_query = time_stamp
125
195
 
126
- self.units = 'm/s'
196
+ _replacements = (('-', '_'), ('.', '_'), ('__', '_'))
197
+ def do_replacements(s):
198
+ for a, b in _replacements:
199
+ s = s.replace(a, b)
200
+ return s
127
201
 
128
202
  # build children
129
203
  if not self._child:
@@ -133,9 +207,9 @@ class RV:
133
207
 
134
208
  for (inst, pipe, mode), data in arrays:
135
209
  child = RV.from_dace_data(self.star, inst, pipe, mode, data, _child=True,
136
- verbose=self.verbose)
137
- inst = inst.replace('-', '_')
138
- pipe = pipe.replace('.', '_').replace('__', '_')
210
+ check_drs_qc=self.check_drs_qc, verbose=self.verbose)
211
+ inst = do_replacements(inst)
212
+ pipe = do_replacements(pipe)
139
213
  if self.only_latest_pipeline:
140
214
  # save as self.INST
141
215
  setattr(self, inst, child)
@@ -148,16 +222,14 @@ class RV:
148
222
  #! sorted?
149
223
  if self.only_latest_pipeline:
150
224
  self.instruments = [
151
- inst.replace('-', '_')
225
+ do_replacements(inst)
152
226
  for (inst, _, _), _ in arrays
153
227
  ]
154
228
  else:
155
229
  self.instruments = [
156
- inst.replace('-', '_') + '_' + pipe.replace('.', '_').replace('__', '_')
230
+ do_replacements(inst) + '_' + do_replacements(pipe)
157
231
  for (inst, pipe, _), _ in arrays
158
232
  ]
159
- # self.pipelines =
160
-
161
233
  # all other quantities
162
234
  self._build_arrays()
163
235
 
@@ -190,6 +262,8 @@ class RV:
190
262
 
191
263
  if self.do_adjust_means:
192
264
  self.adjust_means()
265
+
266
+ self._download_directory = f'{self.star.replace(" ", "")}_downloads'
193
267
 
194
268
  def __add__(self, other, inplace=False):
195
269
  # if not isinstance(other, self.__class__):
@@ -232,7 +306,7 @@ class RV:
232
306
  file = f'{star_name}_{ts}.pkl'
233
307
  pickle.dump(self, open(file, 'wb'), protocol=0)
234
308
  if self.verbose:
235
- logger.info(f'Saved snapshot to {file}')
309
+ logger.info(f'saved snapshot to {file}')
236
310
 
237
311
  @property
238
312
  def N(self) -> int:
@@ -312,6 +386,10 @@ class RV:
312
386
  def _mtime_sorter(self):
313
387
  return np.argsort(self.mtime)
314
388
 
def _index_from_instrument_index(self, index, instrument):
    """ Map an index into one instrument's unmasked points to an index into
    the full arrays

    Args:
        index (int or array of int):
            Position(s) among the points of `instrument` whose mask is True.
        instrument (str):
            Name of the instrument, which is also the attribute holding it.
    """
    child = getattr(self, instrument)
    (positions,) = np.nonzero(self.instrument_array == instrument)
    unmasked_positions = positions[child.mask]
    return unmasked_positions[index]
392
+
315
393
  @property
316
394
  def _tt(self) -> np.ndarray:
317
395
  return np.linspace(self.mtime.min(), self.mtime.max(), 20*self.N)
@@ -319,6 +397,7 @@ class RV:
319
397
  @classmethod
320
398
  def from_dace_data(cls, star, inst, pipe, mode, data, **kwargs):
321
399
  verbose = kwargs.pop('verbose', False)
400
+ check_drs_qc = kwargs.pop('check_drs_qc', True)
322
401
  s = cls(star, **kwargs)
323
402
  #
324
403
  ind = np.argsort(data['rjd'])
@@ -344,9 +423,9 @@ class RV:
344
423
  s._quantities.append('ccf_mask')
345
424
  else:
346
425
  # be careful with bogus values in rhk and rhk_err
347
- if arr in ('rhk', 'rhk_err'):
348
- mask99999 = (data[arr] == -99999) | (data[arr] == -99)
349
- data[arr][mask99999] = np.nan
426
+ # --> not just in rhk and rhk_err...
427
+ if data[arr].dtype == float and (bad := data[arr] == -99999).any():
428
+ data[arr][bad] = np.nan
350
429
 
351
430
  setattr(s, arr, data[arr][ind])
352
431
  s._quantities.append(arr)
@@ -354,7 +433,7 @@ class RV:
354
433
  s._quantities = np.array(s._quantities)
355
434
 
356
435
  # mask out drs_qc = False
357
- if not s.drs_qc.all():
436
+ if check_drs_qc and not s.drs_qc.all():
358
437
  n = (~s.drs_qc).sum()
359
438
  if verbose:
360
439
  logger.warning(f'masking {n} points where DRS QC failed for {inst}')
@@ -406,8 +485,11 @@ class RV:
406
485
 
407
486
  dt = datetime.fromtimestamp(float(timestamp))
408
487
  if verbose:
409
- logger.info(f'Reading snapshot of {star} from {dt}')
410
- return pickle.load(open(file, 'rb'))
488
+ logger.info(f'reading snapshot of {star} from {dt}')
489
+
490
+ s = pickle.load(open(file, 'rb'))
491
+ s._snapshot = file
492
+ return s
411
493
 
412
494
  @classmethod
413
495
  def from_rdb(cls, files, star=None, instrument=None, units='ms', **kwargs):
@@ -474,12 +556,16 @@ class RV:
474
556
  names = header.split()
475
557
 
476
558
  if len(names) > 3:
477
- kw = dict(skip_header=0, comments='--', names=True, dtype=None, encoding=None)
559
+ if f.endswith('.rdb'):
560
+ kw = dict(skip_header=2, dtype=None, encoding=None)
561
+ else:
562
+ kw = dict(skip_header=0, comments='--', names=True, dtype=None, encoding=None)
478
563
  if '\t' in header:
479
564
  data = np.genfromtxt(f, **kw, delimiter='\t')
480
565
  else:
481
566
  data = np.genfromtxt(f, **kw)
482
- # data.dtype.names = names
567
+ if len(names) == len(data.dtype.names):
568
+ data.dtype.names = names
483
569
  else:
484
570
  data = np.array([], dtype=np.dtype([]))
485
571
 
@@ -591,13 +677,11 @@ class RV:
591
677
 
592
678
  _s.fwhm = np.array([i.FWHM*1e3 for i in CCFs])
593
679
  _s.fwhm_err = np.array([i.FWHMerror*1e3 for i in CCFs])
594
-
595
680
  _quantities.append('fwhm')
596
681
  _quantities.append('fwhm_err')
597
682
 
598
683
  _s.contrast = np.array([i.contrast for i in CCFs])
599
684
  _s.contrast_err = np.array([i.contrast_error for i in CCFs])
600
-
601
685
  _quantities.append('contrast')
602
686
  _quantities.append('contrast_err')
603
687
 
@@ -618,7 +702,6 @@ class RV:
618
702
  if verbose:
619
703
  logger.warning(f'masking {n} points where DRS QC failed for {instrument}')
620
704
  _s.mask &= _s.drs_qc
621
- print(_s.mask)
622
705
 
623
706
  _s._quantities = np.array(_quantities)
624
707
  setattr(s, instrument, _s)
@@ -714,8 +797,17 @@ class RV:
714
797
  )
715
798
  setattr(self, q, arr)
716
799
 
@property
def download_directory(self):
    """ Path of the directory where downloaded data files are stored """
    directory = self._download_directory
    return directory

@download_directory.setter
def download_directory(self, new_directory):
    self._download_directory = new_directory
808
+
717
809
  def download_ccf(self, instrument=None, index=None, limit=None,
718
- directory=None, symlink=False, **kwargs):
810
+ directory=None, symlink=False, load=True, **kwargs):
719
811
  """ Download CCFs from DACE
720
812
 
721
813
  Args:
@@ -724,17 +816,13 @@ class RV:
724
816
  limit (int): Maximum number of files to download.
725
817
  directory (str): Directory where to store data.
726
818
  """
727
- if directory is None:
728
- directory = f'{self.star}_downloads'
819
+ directory = directory or self.download_directory
729
820
 
730
- if instrument is None:
731
- files = [file for file in self.raw_file if file.endswith('.fits')]
732
- else:
733
- strict = kwargs.pop('strict', False)
734
- instrument = self._check_instrument(instrument, strict=strict)
735
- files = []
736
- for inst in instrument:
737
- files += list(getattr(self, inst).raw_file)
821
+ strict = kwargs.pop('strict', False)
822
+ instrument = self._check_instrument(instrument, strict=strict)
823
+ files = []
824
+ for inst in instrument:
825
+ files += list(getattr(self, inst).raw_file)
738
826
 
739
827
  if index is not None:
740
828
  index = np.atleast_1d(index)
@@ -750,6 +838,23 @@ class RV:
750
838
  else:
751
839
  do_download_filetype('CCF', files[:limit], directory, verbose=self.verbose, **kwargs)
752
840
 
841
+ if load:
842
+ try:
843
+ from os.path import basename, join
844
+ from .utils import sanitize_path
845
+ import iCCF
846
+ downloaded = [
847
+ sanitize_path(join(directory, basename(f).replace('.fits', '_CCF_A.fits')))
848
+ for f in files[:limit]
849
+ ]
850
+ if self.verbose:
851
+ logger.info('loading the CCF(s) into `.CCF` attribute')
852
+
853
+ self.CCF = iCCF.from_file(downloaded)
854
+
855
+ except (ImportError, ValueError):
856
+ pass
857
+
753
858
  def download_s1d(self, instrument=None, index=None, limit=None,
754
859
  directory=None, symlink=False, **kwargs):
755
860
  """ Download S1Ds from DACE
@@ -760,17 +865,13 @@ class RV:
760
865
  limit (int): Maximum number of files to download.
761
866
  directory (str): Directory where to store data.
762
867
  """
763
- if directory is None:
764
- directory = f'{self.star}_downloads'
868
+ directory = directory or self.download_directory
765
869
 
766
- if instrument is None:
767
- files = [file for file in self.raw_file if file.endswith('.fits')]
768
- else:
769
- strict = kwargs.pop('strict', False)
770
- instrument = self._check_instrument(instrument, strict=strict)
771
- files = []
772
- for inst in instrument:
773
- files += list(getattr(self, inst).raw_file)
870
+ strict = kwargs.pop('strict', False)
871
+ instrument = self._check_instrument(instrument, strict=strict)
872
+ files = []
873
+ for inst in instrument:
874
+ files += list(getattr(self, inst).raw_file)
774
875
 
775
876
  if index is not None:
776
877
  index = np.atleast_1d(index)
@@ -796,17 +897,13 @@ class RV:
796
897
  limit (int): Maximum number of files to download.
797
898
  directory (str): Directory where to store data.
798
899
  """
799
- if directory is None:
800
- directory = f'{self.star}_downloads'
900
+ directory = directory or self.download_directory
801
901
 
802
- if instrument is None:
803
- files = [file for file in self.raw_file if file.endswith('.fits')]
804
- else:
805
- strict = kwargs.pop('strict', False)
806
- instrument = self._check_instrument(instrument, strict=strict)
807
- files = []
808
- for inst in instrument:
809
- files += list(getattr(self, inst).raw_file)
902
+ strict = kwargs.pop('strict', False)
903
+ instrument = self._check_instrument(instrument, strict=strict)
904
+ files = []
905
+ for inst in instrument:
906
+ files += list(getattr(self, inst).raw_file)
810
907
 
811
908
  if index is not None:
812
909
  index = np.atleast_1d(index)
@@ -859,8 +956,9 @@ class RV:
859
956
  instruments = self._check_instrument(instrument, strict)
860
957
 
861
958
  if instruments is None:
862
- logger.error(f"No data from instrument '{instrument}'")
863
- logger.info(f'available: {self.instruments}')
959
+ if self.verbose:
960
+ logger.error(f"No data from instrument '{instrument}'")
961
+ logger.info(f'available: {self.instruments}')
864
962
  return
865
963
 
866
964
  for instrument in instruments:
@@ -922,7 +1020,11 @@ class RV:
922
1020
  return
923
1021
 
924
1022
  if self.verbose:
925
- logger.info(f'removing points {index}')
1023
+ inst = np.unique(self.instrument_array[index])
1024
+ if len(index) == 1:
1025
+ logger.info(f'removing point {index[0]} from {inst[0]}')
1026
+ else:
1027
+ logger.info(f'removing points {index} from {inst}')
926
1028
 
927
1029
  self.mask[index] = False
928
1030
  self._propagate_mask_changes()
@@ -932,6 +1034,31 @@ class RV:
932
1034
  if config.return_self:
933
1035
  return self
934
1036
 
def restore_point(self, index):
    """
    Restore previously deleted individual observations at a given index (or
    indices). NOTE: Like Python, the index is 0-based.

    Args:
        index (int, list, ndarray):
            Single index, list, or array of indices to restore.

    Returns:
        self if ``config.return_self`` is set, otherwise None. Also returns
        None, after logging an error, when an index is out of bounds.
    """
    index = np.atleast_1d(index)
    try:
        # Indexing `obs` (and then the instrument list with it) raises
        # IndexError for invalid indices before the mask is modified.
        # NOTE(review): `obs` appears to hold 1-based instrument numbers,
        # given the `- 1` below; confirm.
        instrument_index = self.obs[index]
        np.array(self.instruments)[instrument_index - 1]
    except IndexError:
        logger.error(f'index {index} is out of bounds for N={self.N}')
        return

    if self.verbose:
        logger.info(f'restoring point{"s" if index.size > 1 else ""} {index}')

    self.mask[index] = True
    self._propagate_mask_changes()
    if config.return_self:
        return self
1061
+
935
1062
  def remove_non_public(self):
936
1063
  """ Remove non-public observations """
937
1064
  if self.verbose:
@@ -1040,6 +1167,11 @@ class RV:
1040
1167
  self.gaia
1041
1168
  self.gaia.plx
1042
1169
 
1170
+ if self.gaia.plx < 0:
1171
+ if self.verbose:
1172
+ logger.error('negative Gaia parallax, falling back to Simbad')
1173
+ raise AttributeError
1174
+
1043
1175
  if self.verbose:
1044
1176
  logger.info('using Gaia information to remove secular acceleration')
1045
1177
 
@@ -1054,10 +1186,11 @@ class RV:
1054
1186
  μ = μα**2 + μδ**2
1055
1187
  sa = (μ * d).to(units.m / units.second / units.year,
1056
1188
  equivalencies=units.dimensionless_angles())
1057
-
1058
1189
  except AttributeError:
1059
1190
  try:
1060
1191
  self.simbad
1192
+ if self.simbad is None:
1193
+ raise AttributeError
1061
1194
  except AttributeError:
1062
1195
  if self.verbose:
1063
1196
  logger.error('no information from simbad, cannot remove secular acceleration')
@@ -1158,6 +1291,7 @@ class RV:
1158
1291
  return
1159
1292
 
1160
1293
  instruments = self._check_instrument(instrument, strict)
1294
+ changed_instruments = []
1161
1295
 
1162
1296
  for inst in instruments:
1163
1297
  m = self.instrument_array == inst
@@ -1170,6 +1304,10 @@ class RV:
1170
1304
  s = 's' if (n == 0 or n > 1) else ''
1171
1305
  logger.warning(f'sigma-clip RVs will remove {n} point{s} for {inst}')
1172
1306
 
1307
+ if n > 0:
1308
+ self.mask[ind] = False
1309
+ changed_instruments.append(inst)
1310
+
1173
1311
  # # check if going to remove all observations from one instrument
1174
1312
  # if n in self.NN.values(): # all observations
1175
1313
  # # insts = np.unique(self.instrument_array[~ind])
@@ -1180,13 +1318,11 @@ class RV:
1180
1318
  # return self
1181
1319
  # continue
1182
1320
 
1183
- self.mask[ind] = False
1184
-
1185
1321
  self._propagate_mask_changes()
1186
1322
 
1187
1323
  if self._did_adjust_means:
1188
1324
  self._did_adjust_means = False
1189
- self.adjust_means()
1325
+ self.adjust_means(instrument=changed_instruments)
1190
1326
 
1191
1327
  if config.return_self:
1192
1328
  return self
@@ -1308,11 +1444,16 @@ class RV:
1308
1444
  snew._build_arrays()
1309
1445
  return snew
1310
1446
 
def nth_day_mean(self, n=1.0, masked=True):
    """ Calculate the n-th day rolling mean of the radial velocities

    For every observation, average the RVs of all observations whose time
    is strictly less than `n` days away from it (including itself).

    Args:
        n (float):
            Half-width of the window, in the same units as the times.
        masked (bool):
            If True, use only the unmasked points (``mtime``/``mvrad``);
            otherwise use all points (``time``/``vrad``).
    """
    times, rvs = (self.mtime, self.mvrad) if masked else (self.time, self.vrad)
    size = times.size
    within = np.abs(times[:, None] - times[None, :]) < n
    window = np.full((size, size), np.nan)
    window[within] = np.repeat(rvs[:, None], size, axis=1)[within]
    return np.nanmean(window, axis=0)
1317
1458
 
1318
1459
  def subtract_mean(self):
@@ -1334,13 +1475,26 @@ class RV:
1334
1475
  s.vrad += self._meanRV
1335
1476
  self._build_arrays()
1336
1477
 
1337
- def adjust_means(self, just_rv=False):
1338
- """ Subtract individual mean RVs from each instrument """
1478
+ def adjust_means(self, just_rv=False, instrument=None, **kwargs):
1479
+ """
1480
+ Subtract individual mean RVs from each instrument or from specific
1481
+ instruments
1482
+ """
1339
1483
  if self._child or self._did_adjust_means:
1340
1484
  return
1341
1485
 
1486
+ # if self.verbose:
1487
+ # print_as_table = len(self.instruments) > 2 and len(self.instruments) < 7
1488
+ # rows = [self.instruments]
1489
+ # row = []
1490
+ # if print_as_table:
1491
+ # logger.info('subtracted weighted average from each instrument:')
1492
+
1342
1493
  others = ('fwhm', 'bispan', )
1343
- for inst in self.instruments:
1494
+
1495
+ instruments = self._check_instrument(instrument, strict=kwargs.get('strict', False))
1496
+
1497
+ for inst in instruments:
1344
1498
  s = getattr(self, inst)
1345
1499
 
1346
1500
  if s.mtime.size == 0:
@@ -1361,33 +1515,105 @@ class RV:
1361
1515
  s.vrad -= s.rv_mean
1362
1516
 
1363
1517
  if self.verbose:
1518
+ # if print_as_table:
1519
+ # row.append(f'{s.rv_mean:.3f}')
1520
+ # else:
1364
1521
  logger.info(f'subtracted weighted average from {inst:10s}: ({s.rv_mean:.3f} {self.units})')
1365
1522
 
1366
1523
  if just_rv:
1367
1524
  continue
1368
1525
 
1369
1526
  for i, other in enumerate(others):
1370
- y, ye = getattr(s, other), getattr(s, other + '_err')
1527
+ try:
1528
+ y, ye = getattr(s, other), getattr(s, other + '_err')
1529
+ except AttributeError:
1530
+ continue
1371
1531
  m = wmean(y[s.mask], ye[s.mask])
1372
1532
  setattr(s, f'{other}_mean', m)
1373
1533
  setattr(s, other, getattr(s, other) - m)
1374
1534
 
1535
+ # if print_as_table:
1536
+ # from .utils import pretty_print_table
1537
+ # rows.append(row)
1538
+ # pretty_print_table(rows, logger=logger)
1539
+
1375
1540
  self._build_arrays()
1376
1541
  self._did_adjust_means = True
1377
1542
  if config.return_self:
1378
1543
  return self
1379
1544
 
def add_to_vrad(self, values):
    """ Add a value or array of values to the RVs of all instruments

    Args:
        values (float or array):
            Either a single value, added to every RV (masked or not); or an
            array with one value per observation (same size as
            ``self.vrad``); or an array with one value per unmasked
            observation (same size as ``self.mvrad``), in which case only
            the unmasked RVs are changed.

    Raises:
        ValueError: if the size of `values` matches none of the above.
    """
    values = np.atleast_1d(values)
    if values.size == 1:
        values = np.full_like(self.vrad, values)

    masked = False
    if values.size != self.vrad.size:
        if values.size == self.mvrad.size:
            logger.warning('adding to masked RVs only')
            masked = True
        else:
            raise ValueError(
                f"incompatible sizes: len(values) must be 1, {self.vrad.size} "
                f"(all points) or {self.mvrad.size} (unmasked points), "
                f"got {values.size}"
            )

    for inst in self.instruments:
        s = getattr(self, inst)
        if masked:
            # `values` is aligned with the unmasked points only
            mask = self.instrument_array[self.mask] == inst
            s.vrad[s.mask] += values[mask]
        else:
            mask = self.instrument_array == inst
            s.vrad += values[mask]
    self._build_arrays()
1568
+
def add_to_quantity(self, quantity, values):
    """
    Add a value or array of values to the given quantity of all instruments

    Args:
        quantity (str):
            Name of the attribute to modify (e.g. 'fwhm').
        values (float or array):
            Either a single value, added to every point, or an array with the
            same size as ``getattr(self, quantity)``.

    Raises:
        ValueError: if the size of `values` is incompatible with `quantity`.
    """
    if not hasattr(self, quantity):
        logger.error(f"cannot find '{quantity}' attribute")
        return
    q = getattr(self, quantity)

    values = np.atleast_1d(values)
    if values.size == 1:
        values = np.full_like(q, values)
    if values.size != q.size:
        raise ValueError(
            f"incompatible sizes: len(values) must be 1 or equal the size of "
            f"'{quantity}', got {values.size} != {q.size}"
        )

    for inst in self.instruments:
        s = getattr(self, inst)
        mask = self.instrument_array == inst
        setattr(s, quantity, getattr(s, quantity) + values[mask])
    self._build_arrays()
1589
+
def change_units(self, new_units):
    """ Convert the RVs (and FWHM) of all instruments to new units

    Args:
        new_units (str):
            One of 'm/s', 'km/s', or the aliases 'ms' and 'kms'.

    Raises:
        ValueError: if `new_units` is not accepted, or if there is no known
            conversion from the current ``self.units``.
    """
    possible = {'m/s': 'm/s', 'km/s': 'km/s', 'ms': 'm/s', 'kms': 'km/s'}
    if new_units not in possible:
        msg = f"new_units must be one of 'm/s', 'km/s', 'ms', 'kms', got '{new_units}'"
        raise ValueError(msg)

    new_units = possible[new_units]
    if new_units == self.units:
        return

    # An explicit table avoids an unbound factor when self.units is
    # something else (e.g. '' after a failed DACE query).
    factors = {('km/s', 'm/s'): 1e3, ('m/s', 'km/s'): 1e-3}
    try:
        factor = factors[(self.units, new_units)]
    except KeyError:
        raise ValueError(
            f"cannot convert from units '{self.units}' to '{new_units}'"
        ) from None

    if self.verbose:
        logger.info(f"changing units from {self.units} to {new_units}")

    # NOTE(review): other velocity-like quantities (e.g. bispan, stored
    # means) are not converted here; confirm whether that is intended.
    for inst in self.instruments:
        s = getattr(self, inst)
        # Some instruments may lack FWHM data, so skip missing attributes.
        # Reassigning (instead of *=) also works for integer arrays.
        for name in ('vrad', 'svrad', 'fwhm', 'fwhm_err'):
            if hasattr(s, name):
                setattr(s, name, getattr(s, name) * factor)
        # keep each instrument's own `units` consistent with its data
        s.units = new_units

    self._build_arrays()
    self.units = new_units
1391
1617
 
1392
1618
 
1393
1619
  def put_at_systemic_velocity(self):
@@ -1474,12 +1700,14 @@ class RV:
1474
1700
  _s.time, _s.vrad, _s.svrad,
1475
1701
  _s.fwhm, _s.fwhm_err,
1476
1702
  _s.rhk, _s.rhk_err,
1703
+ _s.bispan, _s.bispan_err,
1477
1704
  ]
1478
1705
  else:
1479
1706
  d = np.c_[
1480
1707
  _s.mtime, _s.mvrad, _s.msvrad,
1481
1708
  _s.fwhm[_s.mask], _s.fwhm_err[_s.mask],
1482
1709
  _s.rhk[_s.mask], _s.rhk_err[_s.mask],
1710
+ _s.bispan[_s.mask], _s.bispan_err[_s.mask],
1483
1711
  ]
1484
1712
  if not save_nans:
1485
1713
  if np.isnan(d).any():
@@ -1489,8 +1717,14 @@ class RV:
1489
1717
  if self.verbose:
1490
1718
  logger.warning(f'masking {nan_mask.sum()} observations with NaN in indicators')
1491
1719
 
1492
- header = 'bjd\tvrad\tsvrad\tfwhm\tsfwhm\trhk\tsrhk\n'
1493
- header += '---\t----\t-----\t----\t-----\t---\t----'
1720
+ header = '\t'.join(['bjd', 'vrad', 'svrad',
1721
+ 'fwhm', 'sfwhm',
1722
+ 'rhk', 'srhk',
1723
+ 'bispan', 'sbispan'
1724
+ ])
1725
+ header += '\n'
1726
+ header += '\t'.join(['-' * len(c) for c in header.strip().split('\t')])
1727
+
1494
1728
  else:
1495
1729
  if save_masked:
1496
1730
  d = np.c_[_s.time, _s.vrad, _s.svrad]
@@ -1678,3 +1912,62 @@ def fit_sine(t, y, yerr=None, period='gls', fix_period=False):
1678
1912
 
1679
1913
  xbest, _ = leastsq(f, p0, args=(t, y, yerr))
1680
1914
  return xbest, partial(sine, p=xbest)
1915
+
1916
+
def fit_n_sines(t, y, yerr=None, n=1, period='gls', fix_period=False):
    """ Fit N sine curves of the form y = ∑i Ai * sin(2π * t / Pi + φi) + c

    Args:
        t (ndarray):
            Time array
        y (ndarray):
            Array of observed values
        yerr (ndarray, optional):
            Array of uncertainties. Defaults to None (unit uncertainties).
        n (int, optional):
            Number of sine curves to fit. Defaults to 1.
        period (str, float or sequence of floats, optional):
            Initial guesses for the n periods, or 'gls' to find them one at
            a time from Lomb-Scargle periodograms of the successive
            residuals. Defaults to 'gls'.
        fix_period (bool, optional):
            Whether to keep the periods fixed. Defaults to False.

    Returns:
        p (ndarray):
            Best-fit parameters: [A1, φ1, ..., An, φn, c] if `fix_period`
            is True, otherwise [A1, P1, φ1, ..., An, Pn, φn, c]
        f (callable):
            Function that returns the best-fit curve for input times

    Raises:
        ValueError: if `period` does not contain exactly `n` values
    """
    from scipy.optimize import leastsq

    t = np.asarray(t, dtype=float)
    # float copy so that integer inputs can hold residuals
    y = np.asarray(y, dtype=float)

    if isinstance(period, str) and period == 'gls':
        from astropy.timeseries import LombScargle
        residuals = y.copy()
        periods = []
        for i in range(n):
            if i > 0:
                # remove the sinusoid at the previously found period from
                # the current residuals (not from the original data)
                _, previous = fit_sine(t, residuals, yerr,
                                       period=periods[-1], fix_period=True)
                residuals = residuals - previous(t)
            freq, power = LombScargle(t, residuals, yerr).autopower()
            periods.append(1 / freq[power.argmax()])
        periods = np.array(periods)
    else:
        periods = np.atleast_1d(np.asarray(period, dtype=float))
        if periods.size != n:
            raise ValueError(f'wrong number of periods, expected {n} but got {periods.size}')

    if yerr is None:
        yerr = np.ones_like(y)

    if fix_period:
        def sine(t, p):
            return p[-1] + np.sum([p[2*i] * np.sin(2 * np.pi * t / periods[i] + p[2*i+1])
                                   for i in range(n)], axis=0)
        p0 = np.r_[np.tile([y.std(), 0.0], n), y.mean()]
    else:
        def sine(t, p):
            return p[-1] + np.sum([p[3*i] * np.sin(2 * np.pi * t / p[3*i+1] + p[3*i+2])
                                   for i in range(n)], axis=0)
        # Interleave the periods into [A1, P1, φ1, ..., An, Pn, φn]: each
        # period goes before every second entry of [A1, φ1, ..., An, φn].
        amp_phase = np.tile([y.std(), 0.0], n)
        p0 = np.r_[np.insert(amp_phase, np.arange(1, 2 * n, 2), periods), y.mean()]

    def residual(p, t, y, ye):
        return (sine(t, p) - y) / ye

    xbest, _ = leastsq(residual, p0, args=(t, y, yerr))
    return xbest, partial(sine, p=xbest)