arvi 0.1.8__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

arvi/timeseries.py CHANGED
@@ -11,13 +11,15 @@ import numpy as np
  from astropy import units
 
  from .setup_logger import logger
- from .config import return_self
+ from .config import return_self, check_internet
  from .translations import translate
- from .dace_wrapper import get_observations, get_arrays
- from .dace_wrapper import do_download_ccf, do_download_s1d, do_download_s2d
+ from .dace_wrapper import do_download_filetype, get_observations, get_arrays
  from .simbad_wrapper import simbad
+ from .extra_data import get_extra_data
  from .stats import wmean, wrms
- from .binning import binRV
+ from .binning import bin_ccf_mask, binRV
+ from .HZ import getHZ_period
+ from .utils import strtobool, there_is_internet
 
 
  @dataclass
@@ -41,13 +43,13 @@ class RV:
  """
  star: str
  instrument: str = field(init=True, repr=False, default=None)
- N: int = field(init=False, repr=True)
  verbose: bool = field(init=True, repr=False, default=True)
  do_maxerror: Union[bool, float] = field(init=True, repr=False, default=False)
  do_secular_acceleration: bool = field(init=True, repr=False, default=True)
  do_sigma_clip: bool = field(init=True, repr=False, default=False)
  do_adjust_means: bool = field(init=True, repr=False, default=True)
  only_latest_pipeline: bool = field(init=True, repr=False, default=True)
+ load_extra_data: Union[bool, str] = field(init=True, repr=False, default=True)
  #
  _child: bool = field(init=True, repr=False, default=False)
  _did_secular_acceleration: bool = field(init=False, repr=False, default=False)
@@ -68,17 +70,28 @@ class RV:
  self.__star__ = translate(self.star)
 
  if not self._child:
- try:
- self.simbad = simbad(self.__star__)
- except ValueError as e:
- logger.error(e)
+ if check_internet and not there_is_internet():
+ raise ConnectionError('There is no internet connection?')
+ # complicated way to query Simbad with self.__star__ or, if that
+ # fails, try after removing a trailing 'A'
+ for target in (self.__star__, self.__star__.replace('A', '')):
+ try:
+ self.simbad = simbad(target)
+ break
+ except ValueError:
+ continue
+ else:
+ if self.verbose:
+ logger.error(f'simbad query for {self.__star__} failed')
 
+ # query DACE
  if self.verbose:
  logger.info(f'querying DACE for {self.__star__}...')
  try:
  self.dace_result = get_observations(self.__star__, self.instrument,
  verbose=self.verbose)
  except ValueError as e:
+ # querying DACE failed, should we raise an error?
  if self._raise_on_error:
  raise e
  else:
@@ -87,7 +100,6 @@ class RV:
  self.units = ''
  return
 
-
  # store the date of the last DACE query
  time_stamp = datetime.now(timezone.utc) #.isoformat().split('.')[0]
  self._last_dace_query = time_stamp
@@ -101,7 +113,8 @@ class RV:
  verbose=self.verbose)
 
  for (inst, pipe, mode), data in arrays:
- child = RV.from_dace_data(self.star, inst, pipe, mode, data, _child=True)
+ child = RV.from_dace_data(self.star, inst, pipe, mode, data, _child=True,
+ verbose=self.verbose)
  inst = inst.replace('-', '_')
  pipe = pipe.replace('.', '_').replace('__', '_')
  if self.only_latest_pipeline:
@@ -129,6 +142,22 @@ class RV:
  # all other quantities
  self._build_arrays()
 
+
+ if self.load_extra_data:
+ if isinstance(self.load_extra_data, str):
+ path = self.load_extra_data
+ else:
+ path = None
+ try:
+ self.__add__(get_extra_data(self.star, instrument=self.instrument, path=path),
+ inplace=True)
+
+ except FileNotFoundError:
+ pass
+
+ # all other quantities
+ self._build_arrays()
+
  # do clip_maxerror, secular_acceleration, sigmaclip, adjust_means
  if not self._child:
  if self.do_maxerror:
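The new `load_extra_data` option controls whether additional, locally stored RVs are merged into the object when it is built; a string value is taken as the path to read from. A minimal usage sketch (assuming `RV` is importable from the package top level, and using a hypothetical star name and path):

```python
from arvi import RV

s = RV('HD10180')                                           # default: also load extra data if found
s_local = RV('HD10180', load_extra_data='/path/to/extra')   # read extra data from this directory
s_dace = RV('HD10180', load_extra_data=False)               # DACE data only
```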
@@ -143,6 +172,31 @@ class RV:
  if self.do_adjust_means:
  self.adjust_means()
 
+ def __add__(self, other, inplace=False):
+ # if not isinstance(other, self.__class__):
+ # raise TypeError('unsupported operand type(s) for +: '
+ # f"'{self.__class__.__name__}' and '{other.__class__.__name__}'")
+
+ if np.isin(self.instruments, other.instruments).any():
+ logger.error('the two objects share instrument(s), cannot add them')
+ return
+
+ if inplace:
+ #? could it be as simple as this?
+ for i in other.instruments:
+ self.instruments.append(i)
+ setattr(self, i, getattr(other, i))
+ self._build_arrays()
+ else:
+ # make a copy of ourselves
+ new_self = deepcopy(self)
+ #? could it be as simple as this?
+ for i in other.instruments:
+ new_self.instruments.append(i)
+ setattr(new_self, i, getattr(other, i))
+ new_self._build_arrays()
+ return new_self
+
 
  def reload(self):
  self._did_secular_acceleration = False
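The new `__add__` method merges the instruments of another RV object into this one (or into a deep copy), refusing to combine objects that share an instrument. A sketch of how this might be used, assuming `s1` and `s2` are RV objects for the same star built from different instruments:

```python
combined = s1 + s2             # returns a new object; s1 and s2 are left untouched
s1.__add__(s2, inplace=True)   # or merge s2's instruments directly into s1
```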
@@ -161,23 +215,20 @@ class RV:
  logger.info(f'Saved snapshot to {file}')
 
  @property
- def N(self):
+ def N(self) -> int:
  """Total number of observations"""
  return self.time.size
 
- @N.setter
- def N(self, value):
- if not isinstance(value, property):
- logger.error('Cannot set N directly')
-
  @property
  def NN(self):
  """ Total number of observations per instrument """
  return {inst: getattr(self, inst).N for inst in self.instruments}
 
  @property
- def N_nights(self):
+ def N_nights(self) -> int:
  """ Number of individual nights """
+ if self.mtime.size == 0:
+ return 0
  return binRV(self.mtime, None, None, binning_bins=True).size - 1
 
  @property
@@ -185,15 +236,26 @@ class RV:
  return {inst: getattr(self, inst).N_nights for inst in self.instruments}
 
  @property
- def mtime(self):
- return self.time[self.mask]
+ def _NN_as_table(self) -> str:
+ table = ''
+ table += ' | '.join(self.instruments) + '\n'
+ table += ' | '.join([i*'-' for i in map(len, self.instruments)]) + '\n'
+ table += ' | '.join(map(str, self.NN.values())) + '\n'
+ return table
 
  @property
- def mvrad(self):
+ def mtime(self) -> np.ndarray:
+ """ Masked array of times """
+ return self.time[self.mask]
+
+ @property
+ def mvrad(self) -> np.ndarray:
+ """ Masked array of radial velocities """
  return self.vrad[self.mask]
 
  @property
- def msvrad(self):
+ def msvrad(self) -> np.ndarray:
+ """ Masked array of radial velocity uncertainties """
  return self.svrad[self.mask]
 
  @property
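The new `_NN_as_table` property renders the per-instrument counts from `NN` as a small Markdown-style table. For example (hypothetical counts):

```python
print(s.NN)            # e.g. {'ESPRESSO': 42, 'HARPS15': 117}
print(s._NN_as_table)  # the same counts as a header row, separator row, and value row
```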
@@ -201,7 +263,7 @@ class RV:
  return np.concatenate([[i] * n for i, n in self.NN.items()])
 
  @property
- def rms(self):
+ def rms(self) -> float:
  """ Weighted rms of the (masked) radial velocities """
  if self.mask.sum() == 0: # only one point
  return np.nan
@@ -210,7 +272,7 @@ class RV:
 
  @property
  def sigma(self):
- """ Average error bar """
+ """ Average radial velocity uncertainty """
  if self.mask.sum() == 0: # only one point
  return np.nan
  else:
@@ -227,11 +289,12 @@ class RV:
  return np.argsort(self.mtime)
 
  @property
- def _tt(self):
+ def _tt(self) -> np.ndarray:
  return np.linspace(self.mtime.min(), self.mtime.max(), 20*self.N)
 
  @classmethod
  def from_dace_data(cls, star, inst, pipe, mode, data, **kwargs):
+ verbose = kwargs.pop('verbose', False)
  s = cls(star, **kwargs)
  #
  ind = np.argsort(data['rjd'])
@@ -269,7 +332,8 @@ class RV:
  # mask out drs_qc = False
  if not s.drs_qc.all():
  n = (~s.drs_qc).sum()
- logger.warning(f'masking {n} points where DRS QC failed for {inst}')
+ if verbose:
+ logger.warning(f'masking {n} points where DRS QC failed for {inst}')
  s.mask &= s.drs_qc
 
  s.instruments = [inst]
@@ -302,7 +366,7 @@ class RV:
  return s
 
  @classmethod
- def from_snapshot(cls, file=None, star=None):
+ def from_snapshot(cls, file=None, star=None, verbose=True):
  import pickle
  from datetime import datetime
  if star is None:
@@ -316,21 +380,36 @@ class RV:
  star, timestamp = file.replace('.pkl', '').split('_')
 
  dt = datetime.fromtimestamp(float(timestamp))
- logger.info(f'Reading snapshot of {star} from {dt}')
+ if verbose:
+ logger.info(f'Reading snapshot of {star} from {dt}')
  return pickle.load(open(file, 'rb'))
 
  @classmethod
  def from_rdb(cls, files, star=None, instrument=None, units='ms', **kwargs):
+ """ Create an RV object from an rdb file or a list of rdb files
+
+ Args:
+ files (str, list):
+ File name or list of file names
+ star (str, optional):
+ Name of the star. If None, try to infer it from file name
+ instrument (str, list, optional):
+ Name of the instrument(s). If None, try to infer it from file name
+ units (str, optional):
+ Units of the radial velocities. Defaults to 'ms'.
+
+ Examples:
+ s = RV.from_rdb('star_HARPS.rdb')
+ """
  if isinstance(files, str):
  files = [files]
 
  if star is None:
- star_ = np.unique([os.path.splitext(f)[0].split('_')[0] for f in files])
+ star_ = np.unique([os.path.splitext(os.path.basename(f))[0].split('_')[0] for f in files])
  if star_.size == 1:
  logger.info(f'assuming star is {star_[0]}')
  star = star_[0]
 
-
  if instrument is None:
  instruments = np.array([os.path.splitext(f)[0].split('_')[1] for f in files])
  logger.info(f'assuming instruments: {instruments}')
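The expanded docstring documents how `from_rdb` infers the star and instrument names from file names of the form `star_instrument.rdb` when they are not given explicitly. A sketch with hypothetical file names:

```python
# star and instruments inferred from the file names
s = RV.from_rdb(['HD1_HARPS.rdb', 'HD1_ESPRESSO.rdb'])

# or given explicitly
s = RV.from_rdb('data.rdb', star='HD1', instrument='HARPS', units='ms')
```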
@@ -353,8 +432,24 @@ class RV:
  _s.svrad = data[2] * factor
 
  _quantities = []
+
  #! hack
- data = np.genfromtxt(f, names=True, dtype=None, comments='--', encoding=None)
+ with open(f) as ff:
+ header = ff.readline().strip()
+ if '\t' in header:
+ names = header.split('\t')
+ else:
+ names = header.split()
+
+ if len(names) > 3:
+ kw = dict(skip_header=0, comments='--', names=True, dtype=None, encoding=None)
+ if '\t' in header:
+ data = np.genfromtxt(f, **kw, delimiter='\t')
+ else:
+ data = np.genfromtxt(f, **kw)
+ # data.dtype.names = names
+ else:
+ data = np.array([], dtype=np.dtype([]))
 
  if 'fwhm' in data.dtype.fields:
  _s.fwhm = data['fwhm']
@@ -371,8 +466,10 @@ class RV:
 
  if 'rhk' in data.dtype.fields:
  _s.rhk = data['rhk']
- if 'srhk' in data.dtype.fields:
- _s.rhk_err = data['srhk']
+ _s.rhk_err = np.full_like(time, np.nan)
+ for possible_name in ['srhk', 'rhk_err']:
+ if possible_name in data.dtype.fields:
+ _s.rhk_err = data[possible_name]
  else:
  _s.rhk = np.zeros_like(time)
  _s.rhk_err = np.full_like(time, np.nan)
@@ -382,6 +479,23 @@ class RV:
 
  _s.bispan = np.zeros_like(time)
  _s.bispan_err = np.full_like(time, np.nan)
+
+ # other quantities, but all NaNs
+ for q in ['bispan', 'caindex', 'ccf_asym', 'contrast', 'haindex', 'naindex', 'sindex']:
+ setattr(_s, q, np.full_like(time, np.nan))
+ setattr(_s, q + '_err', np.full_like(time, np.nan))
+ _quantities.append(q)
+ _quantities.append(q + '_err')
+ for q in ['berv', 'texp']:
+ setattr(_s, q, np.full_like(time, np.nan))
+ _quantities.append(q)
+ for q in ['ccf_mask', 'date_night', 'prog_id', 'raw_file', 'pub_reference']:
+ setattr(_s, q, np.full(time.size, ''))
+ _quantities.append(q)
+ for q in ['drs_qc']:
+ setattr(_s, q, np.full(time.size, True))
+ _quantities.append(q)
+
  #! end hack
 
  _s.mask = np.ones_like(time, dtype=bool)
@@ -400,7 +514,45 @@ class RV:
 
  return s
 
- def _check_instrument(self, instrument, strict=False):
+ @classmethod
+ def from_ccf(cls, files, star=None, instrument=None, **kwargs):
+ """ Create an RV object from a CCF file or a list of CCF files """
+ try:
+ import iCCF
+ except ImportError:
+ logger.error('iCCF is not installed. Please install it with `pip install iCCF`')
+ return
+
+ if isinstance(files, str):
+ files = [files]
+
+ I = iCCF.from_file(files)
+
+ objects = np.unique([i.HDU[0].header['OBJECT'].replace(' ', '') for i in I])
+ if objects.size != 1:
+ logger.warning(f'found {objects.size} different stars in the CCF files, '
+ 'choosing the first one')
+ star = objects[0]
+
+ s = cls(star, _child=True)
+
+ # time, RVs, uncertainties
+ s.time = np.array([i.bjd for i in I])
+ s.vrad = np.array([i.RV*1e3 for i in I])
+ s.svrad = np.array([i.RVerror*1e3 for i in I])
+
+ s.fwhm = np.array([i.FWHM*1e3 for i in I])
+ s.fwhm_err = np.array([i.FWHMerror*1e3 for i in I])
+
+ # mask
+ s.mask = np.full_like(s.time, True, dtype=bool)
+
+ s.instruments = list(np.unique([i.instrument for i in I]))
+
+ return s
+
+
+ def _check_instrument(self, instrument, strict=False):# -> list | None:
  """
  Check if there are observations from `instrument`.
 
@@ -414,12 +566,25 @@ class RV:
  """
  if instrument is None:
  return self.instruments
- if not strict:
- if any([instrument in inst for inst in self.instruments]):
- return [inst for inst in self.instruments if instrument in inst]
- if instrument in self.instruments:
- return [instrument]
-
+
+ if isinstance(instrument, list):
+ if strict:
+ return [inst for inst in instrument if inst in self.instruments]
+ else:
+ r = []
+ for i in instrument:
+ if any([i in inst for inst in self.instruments]):
+ r += [inst for inst in self.instruments if i in inst]
+ return r
+
+ else:
+ if strict:
+ if instrument in self.instruments:
+ return [instrument]
+ else:
+ if any([instrument in inst for inst in self.instruments]):
+ return [inst for inst in self.instruments if instrument in inst]
+
 
  def _build_arrays(self):
  """ build all concatenated arrays of `self` from each of the `.inst`s """
@@ -464,11 +629,12 @@ class RV:
  setattr(self, q, arr)
 
 
- def download_ccf(self, instrument=None, limit=None, directory=None):
+ def download_ccf(self, instrument=None, index=None, limit=None, directory=None, **kwargs):
  """ Download CCFs from DACE
 
  Args:
  instrument (str): Specific instrument for which to download data
+ index (int): Specific index of point for which to download data (0-based)
  limit (int): Maximum number of files to download.
  directory (str): Directory where to store data.
  """
@@ -478,18 +644,27 @@ class RV:
  if instrument is None:
  files = [file for file in self.raw_file if file.endswith('.fits')]
  else:
- instrument = self._check_instrument(instrument)
+ strict = kwargs.pop('strict', False)
+ instrument = self._check_instrument(instrument, strict=strict)
  files = []
  for inst in instrument:
  files += list(getattr(self, inst).raw_file)
 
- do_download_ccf(files[:limit], directory)
+ if index is not None:
+ index = np.atleast_1d(index)
+ files = list(np.array(files)[index])
+
+ # remove empty strings
+ files = list(filter(None, files))
 
- def download_s1d(self, instrument=None, limit=None, directory=None):
+ do_download_filetype('CCF', files[:limit], directory, **kwargs)
+
+ def download_s1d(self, instrument=None, index=None, limit=None, directory=None, **kwargs):
  """ Download S1Ds from DACE
 
  Args:
  instrument (str): Specific instrument for which to download data
+ index (int): Specific index of point for which to download data (0-based)
  limit (int): Maximum number of files to download.
  directory (str): Directory where to store data.
  """
@@ -499,18 +674,27 @@ class RV:
  if instrument is None:
  files = [file for file in self.raw_file if file.endswith('.fits')]
  else:
- instrument = self._check_instrument(instrument)
+ strict = kwargs.pop('strict', False)
+ instrument = self._check_instrument(instrument, strict=strict)
  files = []
  for inst in instrument:
  files += list(getattr(self, inst).raw_file)
 
- do_download_s1d(files[:limit], directory)
+ if index is not None:
+ index = np.atleast_1d(index)
+ files = list(np.array(files)[index])
+
+ # remove empty strings
+ files = list(filter(None, files))
+
+ do_download_filetype('S1D', files[:limit], directory, **kwargs)
 
- def download_s2d(self, instrument=None, limit=None, directory=None):
+ def download_s2d(self, instrument=None, index=None, limit=None, directory=None, **kwargs):
  """ Download S2Ds from DACE
 
  Args:
  instrument (str): Specific instrument for which to download data
+ index (int): Specific index of point for which to download data (0-based)
  limit (int): Maximum number of files to download.
  directory (str): Directory where to store data.
  """
@@ -520,12 +704,20 @@ class RV:
  if instrument is None:
  files = [file for file in self.raw_file if file.endswith('.fits')]
  else:
- instrument = self._check_instrument(instrument)
+ strict = kwargs.pop('strict', False)
+ instrument = self._check_instrument(instrument, strict=strict)
  files = []
  for inst in instrument:
  files += list(getattr(self, inst).raw_file)
 
- extracted_files = do_download_s2d(files[:limit], directory)
+ if index is not None:
+ index = np.atleast_1d(index)
+ files = list(np.array(files)[index])
+
+ # remove empty strings
+ files = list(filter(None, files))
+
+ do_download_filetype('S2D', files[:limit], directory, **kwargs)
 
 
  from .plots import plot, plot_fwhm, plot_bis, plot_rhk, plot_quantity
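All three download methods now accept an `index` argument to fetch only specific (0-based) observations, and a `strict` keyword that is passed on to the instrument matching. A usage sketch, assuming hypothetical instrument names and an existing output directory:

```python
s.download_ccf(index=0, directory='products')                              # CCF of the first observation only
s.download_s1d(instrument='HARPS', limit=5, directory='products')          # up to 5 S1Ds from any HARPS subset
s.download_s2d(instrument='ESPRESSO', strict=True, directory='products')   # exact instrument match
```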
@@ -539,8 +731,10 @@ class RV:
  """ Remove all observations from one instrument
 
  Args:
- instrument (str): The instrument for which to remove observations.
- strict (bool): Whether to match `instrument` exactly
+ instrument (str or list):
+ The instrument(s) for which to remove observations.
+ strict (bool):
+ Whether to match (each) `instrument` exactly
 
  Note:
  A common name can be used to remove observations for several subsets
@@ -557,7 +751,7 @@ class RV:
  s.remove_instrument('HARPS03')
  ```
 
- will remove observations from the specific subset.
+ will only remove observations from the specific subset.
  """
  instruments = self._check_instrument(instrument, strict)
 
@@ -595,7 +789,9 @@ class RV:
  return self
 
  def remove_point(self, index):
- """ Remove individual observations at a given index (or indices)
+ """
+ Remove individual observations at a given index (or indices).
+ NOTE: Like Python, the index is 0-based.
 
  Args:
  index (int, list, ndarray):
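Since the index is 0-based, removing the first and third observations would look like this (sketch, assuming `s` is an RV object):

```python
s.remove_point(0)        # first observation
s.remove_point([0, 2])   # first and third observations
```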
@@ -621,6 +817,7 @@ class RV:
  return self
 
  def remove_non_public(self):
+ """ Remove non-public observations """
  if self.verbose:
  n = (~self.public).sum()
  logger.info(f'masking non-public observations ({n})')
@@ -635,6 +832,7 @@ class RV:
  self.remove_instrument(inst)
 
  def remove_prog_id(self, prog_id):
+ """ Remove observations from a given program ID """
  from glob import has_magic
  if has_magic(prog_id):
  from fnmatch import filter
@@ -652,12 +850,46 @@ class RV:
  if self.verbose:
  logger.warning(f'no observations for prog_id "{prog_id}"')
 
-
  def remove_after_bjd(self, bjd):
+ """ Remove observations after a given BJD """
  if (self.time > bjd).any():
  ind = np.where(self.time > bjd)[0]
  self.remove_point(ind)
 
+ def remove_before_bjd(self, bjd):
+ """ Remove observations before a given BJD """
+ if (self.time < bjd).any():
+ ind = np.where(self.time < bjd)[0]
+ self.remove_point(ind)
+
+ def choose_n_points(self, n, seed=None, instrument=None):
+ """ Randomly choose `n` observations and mask out the remaining ones
+
+ Args:
+ n (int):
+ Number of observations to keep.
+ seed (int, optional):
+ Random seed for reproducibility.
+ instrument (str or list, optional):
+ For which instrument to choose points (default is all).
+ """
+ instruments = self._check_instrument(instrument)
+ rng = np.random.default_rng(seed=seed)
+ for inst in instruments:
+ s = getattr(self, inst)
+ mask_for_this_inst = self.obs == self.instruments.index(inst) + 1
+ # only choose if there are more than n points
+ if self.mask[mask_for_this_inst].sum() > n:
+ if self.verbose:
+ logger.info(f'selecting {n} points from {inst}')
+ # indices of points for this instrument which are not masked already
+ available = np.where(self.mask & mask_for_this_inst)[0]
+ # choose n randomly
+ i = rng.choice(available, size=n, replace=False)
+ # mask the others out
+ self.mask[np.setdiff1d(available, i)] = False
+ self._propagate_mask_changes()
+
 
  def _propagate_mask_changes(self):
  """ link self.mask with each self.`instrument`.mask """
@@ -690,6 +922,11 @@ class RV:
  logger.error('no information from simbad, cannot remove secular acceleration')
  return
 
+ if self.simbad.plx_value is None:
+ if self.verbose:
+ logger.error('no parallax from simbad, cannot remove secular acceleration')
+ return
+
  #as_yr = units.arcsec / units.year
  mas_yr = units.milliarcsecond / units.year
  mas = units.milliarcsecond
@@ -719,6 +956,10 @@ class RV:
  continue
 
  s = getattr(self, inst)
+
+ if hasattr(s, '_did_secular_acceleration') and s._did_secular_acceleration:
+ continue
+
  s.vrad = s.vrad - sa * (s.time - epoch) / 365.25
 
  self._build_arrays()
@@ -727,7 +968,7 @@ class RV:
  if return_self:
  return self
 
- def sigmaclip(self, sigma=5):
+ def sigmaclip(self, sigma=5, instrument=None, strict=True):
  """ Sigma-clip RVs (per instrument!) """
  #from scipy.stats import sigmaclip as dosigmaclip
  from .stats import sigmaclip_median as dosigmaclip
@@ -735,7 +976,9 @@ class RV:
  if self._child or self._did_sigma_clip:
  return
 
- for inst in self.instruments:
+ instruments = self._check_instrument(instrument, strict)
+
+ for inst in instruments:
  m = self.instrument_array == inst
  result = dosigmaclip(self.vrad[m], low=sigma, high=sigma)
  n = self.vrad[m].size - result.clipped.size
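`sigmaclip` can now be restricted to one or more instruments; with `strict=True` (the default) the name must match exactly. A sketch of the alternatives:

```python
s.sigmaclip(sigma=3)                           # clip all instruments at 3 sigma
s.sigmaclip(sigma=5, instrument='HARPS03')     # clip only this instrument
s.sigmaclip(instrument='HARPS', strict=False)  # any instrument whose name contains 'HARPS'
```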
@@ -767,12 +1010,11 @@ class RV:
  if return_self:
  return self
 
- def clip_maxerror(self, maxerror:float, plot=False):
+ def clip_maxerror(self, maxerror:float):
  """ Mask out points with RV error larger than a given value
 
  Args:
  maxerror (float): Maximum error to keep.
- plot (bool): Whether to plot the masked points.
  """
  if self._child:
  return
@@ -828,6 +1070,11 @@ class RV:
  setattr(s, q, Q[s.mask][inds])
  continue
 
+ # treat ccf_mask specially, doing a 'unique' bin
+ if q == 'ccf_mask':
+ setattr(s, q, bin_ccf_mask(s.mtime, getattr(s, q)))
+ continue
+
  if Q.dtype != np.float64:
  bad_quantities.append(q)
  all_bad_quantities.append(q)
@@ -881,6 +1128,7 @@ class RV:
  return snew
 
  def nth_day_mean(self, n=1.0):
+ """ Calculate the n-th day rolling mean of the radial velocities """
  mask = np.abs(self.mtime[:, None] - self.mtime[None, :]) < n
  z = np.full((self.mtime.size, self.mtime.size), np.nan)
  z[mask] = np.repeat(self.mvrad[:, None], self.mtime.size, axis=1)[mask]
@@ -969,11 +1217,16 @@ class RV:
  self._build_arrays()
 
  def sort_instruments(self, by_first_observation=True, by_last_observation=False):
+ """ Sort instruments by first or last observation date.
+
+ Args:
+ by_first_observation (bool, optional):
+ Sort by first observation date.
+ by_last_observation (bool, optional):
+ Sort by last observation data.
+ """
  if by_last_observation:
  by_first_observation = False
- # if by_first_observation and by_last_observation:
- # logger.error("'by_first_observation' and 'by_last_observation' can't both be true")
- # return
  if by_first_observation:
  fun = lambda i: getattr(self, i).time.min()
  self.instruments = sorted(self.instruments, key=fun)
@@ -983,10 +1236,9 @@ class RV:
  self.instruments = sorted(self.instruments, key=fun)
  self._build_arrays()
 
- #
 
  def save(self, directory=None, instrument=None, full=False,
- save_nans=True):
+ save_masked=False, save_nans=True):
  """ Save the observations in .rdb files.
 
  Args:
@@ -1021,11 +1273,18 @@ class RV:
  continue
 
  if full:
- d = np.c_[
- _s.mtime, _s.mvrad, _s.msvrad,
- _s.fwhm[_s.mask], _s.fwhm_err[_s.mask],
- _s.rhk[_s.mask], _s.rhk_err[_s.mask],
- ]
+ if save_masked:
+ d = np.c_[
+ _s.time, _s.vrad, _s.svrad,
+ _s.fwhm, _s.fwhm_err,
+ _s.rhk, _s.rhk_err,
+ ]
+ else:
+ d = np.c_[
+ _s.mtime, _s.mvrad, _s.msvrad,
+ _s.fwhm[_s.mask], _s.fwhm_err[_s.mask],
+ _s.rhk[_s.mask], _s.rhk_err[_s.mask],
+ ]
  if not save_nans:
  if np.isnan(d).any():
  # remove observations where any of the indicators are # NaN
@@ -1037,7 +1296,10 @@ class RV:
  header = 'bjd\tvrad\tsvrad\tfwhm\tsfwhm\trhk\tsrhk\n'
  header += '---\t----\t-----\t----\t-----\t---\t----'
  else:
- d = np.c_[_s.mtime, _s.mvrad, _s.msvrad]
+ if save_masked:
+ d = np.c_[_s.time, _s.vrad, _s.svrad]
+ else:
+ d = np.c_[_s.mtime, _s.mvrad, _s.msvrad]
  header = 'bjd\tvrad\tsvrad\n---\t----\t-----'
 
  file = f'{star_name}_{inst}.rdb'
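The new `save_masked` option writes all points, including masked ones, instead of only the unmasked subset. A sketch with a hypothetical output directory:

```python
s.save(directory='rdb_files', full=True)                    # masked points excluded (default)
s.save(directory='rdb_files', full=True, save_masked=True)  # keep masked points in the output
```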
@@ -1052,6 +1314,7 @@ class RV:
  return files
 
  def checksum(self, write_to=None):
+ """ Calculate a hash based on the data """
  from hashlib import md5
  d = np.r_[self.time, self.vrad, self.svrad]
  H = md5(d.data.tobytes()).hexdigest()
@@ -1101,7 +1364,6 @@ class RV:
  logger.error(f"not all required files exist in {data_dir}")
  logger.error(f"missing {np.logical_not(exist).sum()} / {len(files)}")
 
- from distutils.util import strtobool
  go_on = input('continue? (y/N) ')
  if go_on == '' or not bool(strtobool(go_on)):
  return
@@ -1150,8 +1412,18 @@ class RV:
 
 
  #
+ @property
+ def HZ(self):
+ if not hasattr(self, 'star_mass'):
+ self.star_mass = float(input('stellar mass (Msun): '))
+ if not hasattr(self, 'lum'):
+ self.lum = float(input('luminosity (Lsun): '))
+ return getHZ_period(self.simbad.teff, self.star_mass, 1.0, self.lum)
+
+
  @property
  def planets(self):
+ """ Query the NASA Exoplanet Archive for any known planets """
  from .nasaexo_wrapper import Planets
  if not hasattr(self, '_planets'):
  self._planets = Planets(self)
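The new `HZ` property appears to estimate a habitable-zone period range from the Simbad effective temperature together with a stellar mass and luminosity (prompting for them interactively if they have not been set), and the `planets` property queries the NASA Exoplanet Archive. A sketch with hypothetical stellar parameters:

```python
s.star_mass = 0.8    # Msun (set beforehand to avoid the interactive prompt)
s.lum = 0.6          # Lsun
print(s.HZ)          # habitable-zone estimate from getHZ_period
print(s.planets)     # known planets from the NASA Exoplanet Archive
```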