arvi 0.1.16__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

arvi/__init__.py CHANGED
@@ -1,7 +1,16 @@
 __all__ = ['RV']
 
+from importlib.metadata import version, PackageNotFoundError
+
+from .config import config
 from .timeseries import RV
 
+try:
+    __version__ = version("arvi")
+except PackageNotFoundError:
+    # package is not installed
+    pass
+
 ## OLD
 # # the __getattr__ function is always called twice, so we need this
 # # to only build and return the RV object on the second time
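The version lookup added here is the standard `importlib.metadata` pattern: ask the installed distribution for its version, and fall back gracefully when the metadata is absent. A minimal standalone sketch of the same logic (the distribution name below is a placeholder, not part of arvi):

```python
# Sketch of the fallback pattern used above; "not-actually-installed"
# is a placeholder distribution name.
from importlib.metadata import version, PackageNotFoundError

try:
    __version__ = version("not-actually-installed")
except PackageNotFoundError:
    # package is not installed (e.g. running from a source checkout)
    __version__ = "unknown"

print(__version__)
```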
arvi/ariadne_wrapper.py CHANGED
@@ -2,6 +2,9 @@ import os
 import sys
 from matplotlib import pyplot as plt
 
+from .utils import stdout_disabled, all_logging_disabled
+from .setup_logger import logger
+
 try:
     from astroARIADNE.star import Star
     from astroARIADNE.fitter import Fitter
arvi/berv.py CHANGED
@@ -12,7 +12,7 @@ from astropy.timeseries import LombScargle
 from tqdm import tqdm
 
 from .setup_logger import logger
-from . import config
+from .config import config
 
 
 def correct_rvs(self, simple=False, H=None, save_files=False, plot=True):
@@ -349,7 +349,6 @@ def BERV(self, H=None, use_gaia_meassurements=False, plx=None,
     axs[1].plot(bjd, diff, 'k.', label=label)
     axs[1].axhline(np.mean(diff), ls='--', c='k', alpha=0.1)
 
-    from adjustText import adjust_text
     text = axs[1].text(bjd.max(), diff.min() + 0.1*diff.ptp(),
                        f'ptp: {diff.ptp()*1e2:.2f} cm/s',
                        ha='right', va='bottom', color='g', alpha=0.8)
arvi/config.py CHANGED
@@ -1,14 +1,36 @@
-# whether to return self from (some) RV methods
-return_self = False
 
-# whether to check internet connection before querying DACE
-check_internet = False
+def instancer(cls):
+    return cls()
 
-# make all DACE requests without using a .dacerc file
-request_as_public = False
+@instancer
+class config:
+    # configuration values
+    __conf = {
+        # whether to return self from (some) RV methods
+        'return_self': False,
+        # whether to adjust instrument means before gls by default
+        'adjust_means_gls': True,
+        # whether to check internet connection before querying DACE
+        'check_internet': False,
+        # make all DACE requests without using a .dacerc file
+        'request_as_public': False,
+        # username for DACE servers
+        'username': 'desousaj',
+        # debug
+        'debug': False,
+    }
+    # all, for now
+    __setters = list(__conf.keys())
 
-# whether to adjust instrument means before gls by default
-adjust_means_gls = True
+    def __getattr__(self, name):
+        if name in ('__custom_documentations__', ):
+            # return {'return_self': 'help!'}
+            return {}
 
-# debug
-debug = False
+        return self.__conf[name]
+
+    def __setattr__(self, name, value):
+        if name in config.__setters:
+            self.__conf[name] = value
+        else:
+            raise NameError(f"unknown configuration name '{name}'")
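With the `instancer` decorator, `config` is no longer a module of loose flags but a single instance whose attributes are validated: reads go through `__getattr__` into the private `__conf` dict, and writes go through `__setattr__`, which rejects unknown names. A sketch of the resulting behaviour, assuming arvi 0.1.19 is installed:

```python
from arvi.config import config

config.return_self = True      # known setting: stored in the config dict
print(config.check_internet)   # -> False (the default shown above)

try:
    config.not_a_setting = 1   # unknown setting: rejected
except NameError as err:
    print(err)                 # unknown configuration name 'not_a_setting'
```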
arvi/dace_wrapper.py CHANGED
@@ -10,8 +10,8 @@ from .utils import create_directory, all_logging_disabled, stdout_disabled, tqdm
 
 
 def load_spectroscopy() -> SpectroscopyClass:
-    from .config import request_as_public
-    if request_as_public:
+    from .config import config
+    if config.request_as_public:
         with all_logging_disabled():
             dace = DaceClass(dace_rc_config_path='none')
         return SpectroscopyClass(dace_instance=dace)
@@ -41,14 +41,18 @@ def get_arrays(result, latest_pipeline=True, ESPRESSO_mode='HR11', NIRPS_mode='H
 
     # select ESPRESSO mode, which is defined at the level of the pipeline
     if 'ESPRESSO' in inst:
-        if any(ESPRESSO_mode in pipe for pipe in pipelines):
+
+        find_mode = [ESPRESSO_mode in pipe for pipe in pipelines]
+        # the mode was not found
+        if not any(find_mode):
+            if len(pipelines) > 1 and verbose:
+                logger.warning(f'no observations for requested ESPRESSO mode ({ESPRESSO_mode})')
+        # the mode was found but do nothing if it's the only one
+        elif any(find_mode) and not all(find_mode):
             if verbose:
                 logger.info(f'selecting mode {ESPRESSO_mode} for ESPRESSO')
             i = [i for i, pipe in enumerate(pipelines) if ESPRESSO_mode in pipe][0]
             pipelines = [pipelines[i]]
-        else:
-            if len(pipelines) > 1 and verbose:
-                logger.warning(f'no observations for requested ESPRESSO mode ({ESPRESSO_mode})')
 
     if latest_pipeline:
         if verbose and len(pipelines) > 1:
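The rewritten branch classifies the pipelines with a single boolean list: if no pipeline matches the requested mode it warns, if some (but not all) match it keeps the first matching one, and if all match there is nothing to filter. A standalone sketch of the same `any()`/`all()` pattern, with invented pipeline names:

```python
# Invented pipeline names, for illustration only
pipelines = ['3.0.0 HR11', '3.0.0 HR21']
ESPRESSO_mode = 'HR11'

find_mode = [ESPRESSO_mode in pipe for pipe in pipelines]
if not any(find_mode):
    print(f'no observations for requested ESPRESSO mode ({ESPRESSO_mode})')
elif not all(find_mode):
    # keep only the first pipeline matching the requested mode
    pipelines = [pipelines[find_mode.index(True)]]

print(pipelines)  # -> ['3.0.0 HR11']
```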
@@ -59,6 +63,7 @@ def get_arrays(result, latest_pipeline=True, ESPRESSO_mode='HR11', NIRPS_mode='H
     for pipe in pipelines:
         modes = list(result[inst][pipe].keys())
 
+
         # select NIRPS mode, which is defined at the level of the mode
         if 'NIRPS' in inst:
             if NIRPS_mode in modes:
@@ -70,6 +75,19 @@ def get_arrays(result, latest_pipeline=True, ESPRESSO_mode='HR11', NIRPS_mode='H
             if verbose:
                 logger.warning(f'no observations for requested NIRPS mode ({NIRPS_mode})')
 
+        # HARPS15 observations should not be separated by 'mode' if some are
+        # done together with NIRPS
+        if 'HARPS15' in inst and 'HARPS+NIRPS' in modes:
+            m0 = modes[0]
+            data = {
+                k: np.concatenate([result[inst][pipe][m][k] for m in modes])
+                for k in result[inst][pipe][m0].keys()
+            }
+            arrays.append(
+                ((inst, pipe, m0), data)
+            )
+            continue
+
         for mode in modes:
             if 'rjd' not in result[inst][pipe][mode]:
                 logger.error(f"No 'rjd' key for {inst} - {pipe}")
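The new HARPS15 branch collapses all modes into a single record: for every key, the per-mode arrays are concatenated and the merged data is stored under the first mode name. The same dict-comprehension idiom on toy arrays:

```python
import numpy as np

# Toy stand-in for result[inst][pipe]: two modes sharing the same keys
per_mode = {
    'HARPS':       {'rjd': np.array([1.0, 2.0]), 'rv': np.array([10.0, 11.0])},
    'HARPS+NIRPS': {'rjd': np.array([3.0]),      'rv': np.array([12.0])},
}

modes = list(per_mode)
m0 = modes[0]
data = {k: np.concatenate([per_mode[m][k] for m in modes])
        for k in per_mode[m0].keys()}

print(data['rjd'])  # -> [1. 2. 3.]
```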
@@ -105,7 +123,7 @@ def get_observations_from_instrument(star, instrument, main_id=None):
 
     Spectroscopy = load_spectroscopy()
     filters = {
-        "ins_name": {"contains": [instrument]},
+        "ins_name": {"contains": [instrument]},
         "obj_id_daceid": {"contains": [dace_id]}
     }
     with stdout_disabled(), all_logging_disabled():
@@ -115,62 +133,70 @@ def get_observations_from_instrument(star, instrument, main_id=None):
         raise ValueError
 
     r = {}
+
     for inst in np.unique(result['ins_name']):
         mask1 = result['ins_name'] == inst
         r[inst] = {}
+
         for pipe in np.unique(result['ins_drs_version'][mask1]):
             mask2 = mask1 & (result['ins_drs_version'] == pipe)
-            ins_mode = np.unique(result['ins_mode'][mask2])[0]
-            _nan = np.full(mask2.sum(), np.nan)
-            r[inst][pipe] = {
-                ins_mode: {
-                    'texp': result['texp'][mask2],
-                    'bispan': result['spectro_ccf_bispan'][mask2],
-                    'bispan_err': result['spectro_ccf_bispan_err'][mask2],
-                    'drift_noise': result['spectro_cal_drift_noise'][mask2],
-                    'rjd': result['obj_date_bjd'][mask2],
+            r[inst][pipe] = {}
+
+            for ins_mode in np.unique(result['ins_mode'][mask2]):
+                mask3 = mask2 & (result['ins_mode'] == ins_mode)
+                _nan = np.full(mask3.sum(), np.nan)
+
+                r[inst][pipe][ins_mode] = {
+                    'texp': result['texp'][mask3],
+                    'bispan': result['spectro_ccf_bispan'][mask3],
+                    'bispan_err': result['spectro_ccf_bispan_err'][mask3],
+                    'drift_noise': result['spectro_cal_drift_noise'][mask3],
+                    'rjd': result['obj_date_bjd'][mask3],
                     'cal_therror': _nan,
-                    'fwhm': result['spectro_ccf_fwhm'][mask2],
-                    'fwhm_err': result['spectro_ccf_fwhm_err'][mask2],
-                    'rv': result['spectro_ccf_rv'][mask2],
-                    'rv_err': result['spectro_ccf_rv_err'][mask2],
-                    'berv': result['spectro_cal_berv'][mask2],
+                    'fwhm': result['spectro_ccf_fwhm'][mask3],
+                    'fwhm_err': result['spectro_ccf_fwhm_err'][mask3],
+                    'rv': result['spectro_ccf_rv'][mask3],
+                    'rv_err': result['spectro_ccf_rv_err'][mask3],
+                    'berv': result['spectro_cal_berv'][mask3],
                     'ccf_noise': _nan,
-                    'rhk': result['spectro_analysis_rhk'][mask2],
-                    'rhk_err': result['spectro_analysis_rhk_err'][mask2],
-                    'contrast': result['spectro_ccf_contrast'][mask2],
-                    'contrast_err': result['spectro_ccf_contrast_err'][mask2],
-                    'cal_thfile': result['spectro_cal_thfile'][mask2],
-                    'spectroFluxSn50': result['spectro_flux_sn50'][mask2],
-                    'protm08': result['spectro_analysis_protm08'][mask2],
-                    'protm08_err': result['spectro_analysis_protm08_err'][mask2],
-                    'caindex': result['spectro_analysis_ca'][mask2],
-                    'caindex_err': result['spectro_analysis_ca_err'][mask2],
-                    'pub_reference': result['pub_ref'][mask2],
-                    'drs_qc': result['spectro_drs_qc'][mask2],
-                    'haindex': result['spectro_analysis_halpha'][mask2],
-                    'haindex_err': result['spectro_analysis_halpha_err'][mask2],
-                    'protn84': result['spectro_analysis_protn84'][mask2],
-                    'protn84_err': result['spectro_analysis_protn84_err'][mask2],
-                    'naindex': result['spectro_analysis_na'][mask2],
-                    'naindex_err': result['spectro_analysis_na_err'][mask2],
+                    'rhk': result['spectro_analysis_rhk'][mask3],
+                    'rhk_err': result['spectro_analysis_rhk_err'][mask3],
+                    'contrast': result['spectro_ccf_contrast'][mask3],
+                    'contrast_err': result['spectro_ccf_contrast_err'][mask3],
+                    'cal_thfile': result['spectro_cal_thfile'][mask3],
+                    'spectroFluxSn50': result['spectro_flux_sn50'][mask3],
+                    'protm08': result['spectro_analysis_protm08'][mask3],
+                    'protm08_err': result['spectro_analysis_protm08_err'][mask3],
+                    'caindex': result['spectro_analysis_ca'][mask3],
+                    'caindex_err': result['spectro_analysis_ca_err'][mask3],
+                    'pub_reference': result['pub_ref'][mask3],
+                    'drs_qc': result['spectro_drs_qc'][mask3],
+                    'haindex': result['spectro_analysis_halpha'][mask3],
+                    'haindex_err': result['spectro_analysis_halpha_err'][mask3],
+                    'protn84': result['spectro_analysis_protn84'][mask3],
+                    'protn84_err': result['spectro_analysis_protn84_err'][mask3],
+                    'naindex': result['spectro_analysis_na'][mask3],
+                    'naindex_err': result['spectro_analysis_na_err'][mask3],
                     'snca2': _nan,
-                    'mask': result['spectro_ccf_mask'][mask2],
-                    'public': result['public'][mask2],
-                    'spectroFluxSn20': result['spectro_flux_sn20'][mask2],
-                    'sindex': result['spectro_analysis_smw'][mask2],
-                    'sindex_err': result['spectro_analysis_smw_err'][mask2],
+                    'mask': result['spectro_ccf_mask'][mask3],
+                    'public': result['public'][mask3],
+                    'spectroFluxSn20': result['spectro_flux_sn20'][mask3],
+                    'sindex': result['spectro_analysis_smw'][mask3],
+                    'sindex_err': result['spectro_analysis_smw_err'][mask3],
                     'drift_used': _nan,
-                    'ccf_asym': result['spectro_ccf_asym'][mask2],
-                    'ccf_asym_err': result['spectro_ccf_asym_err'][mask2],
-                    'date_night': result['date_night'][mask2],
-                    'raw_file': result['file_rootpath'][mask2],
-                    'prog_id': result['prog_id'][mask2],
-                    'th_ar': result['th_ar'][mask2],
-                    'th_ar1': result['th_ar1'][mask2],
-                    'th_ar2': result['th_ar2'][mask2],
+                    'ccf_asym': result['spectro_ccf_asym'][mask3],
+                    'ccf_asym_err': result['spectro_ccf_asym_err'][mask3],
+                    'date_night': result['date_night'][mask3],
+                    'raw_file': result['file_rootpath'][mask3],
+                    'prog_id': result['prog_id'][mask3],
+                    'th_ar': result['th_ar'][mask3],
+                    'th_ar1': result['th_ar1'][mask3],
+                    'th_ar2': result['th_ar2'][mask3],
                 }
-            }
+
+    # print(r.keys())
+    # print([r[k].keys() for k in r.keys()])
+    # print([r[k1][k2].keys() for k1 in r.keys() for k2 in r[k1].keys()])
     return r
 
 def get_observations(star, instrument=None, main_id=None, verbose=True):
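The refactor replaces the old assumption of a single `ins_mode` per pipeline with a third nested loop, so every (instrument, pipeline, mode) combination gets its own boolean mask and its own set of arrays. The grouping idiom, reduced to a toy table:

```python
import numpy as np

# Toy flat table mimicking the shape of the DACE result
result = {
    'ins_name':        np.array(['ESPRESSO', 'ESPRESSO', 'ESPRESSO']),
    'ins_drs_version': np.array(['3.0.0', '3.0.0', '3.0.0']),
    'ins_mode':        np.array(['HR11', 'HR21', 'HR11']),
    'rv':              np.array([1.0, 2.0, 3.0]),
}

r = {}
for inst in np.unique(result['ins_name']):
    mask1 = result['ins_name'] == inst
    r[inst] = {}
    for pipe in np.unique(result['ins_drs_version'][mask1]):
        mask2 = mask1 & (result['ins_drs_version'] == pipe)
        r[inst][pipe] = {}
        for ins_mode in np.unique(result['ins_mode'][mask2]):
            mask3 = mask2 & (result['ins_mode'] == ins_mode)
            r[inst][pipe][ins_mode] = {'rv': result['rv'][mask3]}

print(r['ESPRESSO']['3.0.0']['HR11']['rv'])  # -> [1. 3.]
```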
@@ -225,9 +251,9 @@ def get_observations(star, instrument=None, main_id=None, verbose=True):
     # (i.e. ensure that 3.x.x > 3.5)
     from re import match
     def cmp(a, b):
-        if a[0] in ('3.5', '3.5 EGGS') and match(r'3.\d.\d', b[0]):
+        if a[0] in ('3.5', '3.5 EGGS') or 'EGGS' in a[0] and match(r'3.\d.\d', b[0]):
             return -1
-        if b[0] in ('3.5', '3.5 EGGS') and match(r'3.\d.\d', a[0]):
+        if b[0] in ('3.5', '3.5 EGGS') or 'EGGS' in b[0] and match(r'3.\d.\d', a[0]):
             return 1
 
         if a[0] == b[0]:
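This `cmp` function is a classic old-style comparator, presumably passed through `functools.cmp_to_key` so that the legacy '3.5' and EGGS pipeline labels rank below any 3.x.x release. One reading note: Python parses `A or B and C` as `A or (B and C)`, since `and` binds tighter than `or`, so the `match` test only guards the widened EGGS clause. A self-contained sketch of the comparator pattern, with invented version strings:

```python
from functools import cmp_to_key
from re import match

def cmp(a, b):
    # legacy labels sort below 3.x.x releases;
    # note: 'A or B and C' parses as 'A or (B and C)'
    if a in ('3.5', '3.5 EGGS') or 'EGGS' in a and match(r'3.\d.\d', b):
        return -1
    if b in ('3.5', '3.5 EGGS') or 'EGGS' in b and match(r'3.\d.\d', a):
        return 1
    return (a > b) - (a < b)

print(sorted(['3.2.5', '3.5', '3.0.0'], key=cmp_to_key(cmp)))
# -> ['3.5', '3.0.0', '3.2.5']
```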
@@ -248,16 +274,20 @@ def get_observations(star, instrument=None, main_id=None, verbose=True):
     _inst = ''
     for inst in instruments:
         pipelines = list(new_result[inst].keys())
+        max_len = max([len(pipe) for pipe in pipelines])
         for pipe in pipelines:
+            last_pipe = pipe == pipelines[-1]
             modes = list(new_result[inst][pipe].keys())
             for mode in modes:
                 N = len(new_result[inst][pipe][mode]['rjd'])
                 # LOG
-                if inst == _inst:
-                    logger.info(f'{" ":>12s} └ {pipe} - {mode} ({N} observations)')
+                if inst == _inst and last_pipe:
+                    logger.info(f'{" ":>12s} └ {pipe:{max_len}s} - {mode} ({N} observations)')
+                elif inst == _inst:
+                    logger.info(f'{" ":>12s} ├ {pipe:{max_len}s} - {mode} ({N} observations)')
                 else:
-                    logger.info(f'{inst:>12s} ├ {pipe} - {mode} ({N} observations)')
-                    _inst = inst
+                    logger.info(f'{inst:>12s} ├ {pipe:{max_len}s} - {mode} ({N} observations)')
+                    _inst = inst
 
     return new_result
 
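The log lines now pad every pipeline name to the width of the longest one, so the `├`/`└` tree connectors line up across rows. The width field in this hunk did not survive extraction; `{pipe:{max_len}s}` is assumed above, since `max_len` is computed just before and otherwise unused. A sketch of the nested format-spec idiom, with invented names:

```python
# Invented instrument/pipeline names, for illustration only
entries = [('ESPRESSO', '3.0.0 HR11'), ('ESPRESSO', '2.3.5'), ('HARPS', '3.5 EGGS')]
max_len = max(len(pipe) for _, pipe in entries)

_inst = ''
for inst, pipe in entries:
    connector = '├' if inst != _inst else '└'
    name = inst if inst != _inst else ''
    # the inner {max_len} is substituted first, then used as the pad width
    print(f'{name:>12s} {connector} {pipe:{max_len}s} - ...')
    _inst = inst
```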
@@ -381,12 +411,12 @@ def do_download_filetype(type, raw_files, output_directory, clobber=False,
 
     if verbose:
         if chunk_size < n:
-            msg = f"Downloading {n} {type}s "
+            msg = f"downloading {n} {type}s "
             msg += f"(in chunks of {chunk_size}) "
             msg += f"into '{output_directory}'..."
             logger.info(msg)
         else:
-            msg = f"Downloading {n} {type}s into '{output_directory}'..."
+            msg = f"downloading {n} {type}s into '{output_directory}'..."
             logger.info(msg)
 
     iterator = [raw_files[i:i + chunk_size] for i in range(0, n, chunk_size)]
@@ -394,7 +424,7 @@ def do_download_filetype(type, raw_files, output_directory, clobber=False,
         download(files, type, output_directory)
         extract_fits(output_directory)
 
-    logger.info('Extracted .fits files')
+    logger.info('extracted .fits files')
 
 
 # def do_download_s1d(raw_files, output_directory, clobber=False, verbose=True):
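Downloads are batched with a slice-based list comprehension: step through the list `chunk_size` at a time and take a slice at each offset, the last chunk simply being shorter. The idiom on a toy list:

```python
raw_files = [f'file_{i:02d}.fits' for i in range(7)]  # invented file names
chunk_size = 3
n = len(raw_files)

iterator = [raw_files[i:i + chunk_size] for i in range(0, n, chunk_size)]
for files in iterator:
    print(len(files), files)   # chunks of 3, 3, then the 1 leftover
```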
arvi/gaia_wrapper.py CHANGED
@@ -22,11 +22,24 @@ CONTAINS(
 )=1
 """
 
+QUERY_ID = """
+SELECT TOP 20 gaia_source.designation,gaia_source.source_id,gaia_source.ra,gaia_source.dec,gaia_source.parallax,gaia_source.pmra,gaia_source.pmdec,gaia_source.ruwe,gaia_source.phot_g_mean_mag,gaia_source.bp_rp,gaia_source.radial_velocity,gaia_source.phot_variable_flag,gaia_source.non_single_star,gaia_source.has_xp_continuous,gaia_source.has_xp_sampled,gaia_source.has_rvs,gaia_source.has_epoch_photometry,gaia_source.has_epoch_rv,gaia_source.has_mcmc_gspphot,gaia_source.has_mcmc_msc,gaia_source.teff_gspphot,gaia_source.logg_gspphot,gaia_source.mh_gspphot,gaia_source.distance_gspphot,gaia_source.azero_gspphot,gaia_source.ag_gspphot,gaia_source.ebpminrp_gspphot
+FROM gaiadr3.gaia_source
+WHERE
+gaia_source.source_id = {id}
+"""
+
+translate = {
+    'Proxima': '5853498713190525696',
+    'LS II +14 13': '4318465066420528000',
+}
+
+
 def run_query(query):
     url = 'https://gea.esac.esa.int/tap-server/tap/sync'
     data = dict(query=query, request='doQuery', lang='ADQL', format='csv')
     try:
-        response = requests.post(url, data=data, timeout=5)
+        response = requests.post(url, data=data, timeout=2)
     except requests.ReadTimeout as err:
         raise IndexError(err)
     except requests.ConnectionError as err:
@@ -71,8 +84,13 @@ class gaia:
         args = dict(ra=ra, dec=dec, plx=plx, pmra=pmra, pmdec=pmdec, rv=rv)
 
         try:
-            table1 = run_query(query=QUERY.format(**args))
-            results = parse_csv(table1)[0]
+            if star in translate:
+                table = run_query(query=QUERY_ID.format(id=translate[star]))
+            elif hasattr(simbad, 'gaia_id'):
+                table = run_query(query=QUERY_ID.format(id=simbad.gaia_id))
+            else:
+                table = run_query(query=QUERY.format(**args))
+            results = parse_csv(table)[0]
         except IndexError:
             raise ValueError(f'Gaia query for {star} failed')
 
@@ -3,13 +3,99 @@ import numpy as np
 
 from .setup_logger import logger
 
+
+# HARPS fiber upgrade (28 May 2015)
+# https://www.eso.org/sci/facilities/lasilla/instruments/harps/news/harps_upgrade_2015.html
+HARPS_technical_intervention = 57170
+
+# ESPRESSO fiber link upgrade (1 July 2019)
+ESPRESSO_technical_intervention = 58665
+
+
+def divide_ESPRESSO(self):
+    """ Split ESPRESSO data into separate sub ESP18 and ESP19 subsets """
+    if self._check_instrument('ESPRESSO', strict=False) is None:
+        return
+    if 'ESPRESSO18' in self.instruments and 'ESPRESSO19' in self.instruments:
+        if self.verbose:
+            logger.info('ESPRESSO data seems to be split already, doing nothing')
+        return
+
+    from .timeseries import RV
+
+    before = self.time < ESPRESSO_technical_intervention
+    after = self.time >= ESPRESSO_technical_intervention
+    new_instruments = []
+
+
+    for inst, mask in zip(['ESPRESSO18', 'ESPRESSO19'], [before, after]):
+        if not mask.any():
+            continue
+
+        _s = RV.from_arrays(self.star, self.time[mask], self.vrad[mask], self.svrad[mask],
+                            inst=inst)
+        for q in self._quantities:
+            setattr(_s, q, getattr(self, q)[mask])
+        setattr(self, inst, _s)
+        _s._quantities = self._quantities
+        _s.mask = self.mask[mask]
+        new_instruments.append(inst)
+
+    delattr(self, 'ESPRESSO')
+    self.instruments = new_instruments
+    self._build_arrays()
+
+    if self.verbose:
+        logger.info(f'divided ESPRESSO into {self.instruments}')
+
+
+def divide_HARPS(self):
+    """ Split HARPS data into separate sub HARPS03 and HARPS15 subsets """
+    if self._check_instrument('HARPS', strict=False) is None:
+        return
+    if 'HARPS03' in self.instruments and 'HARPS15' in self.instruments:
+        if self.verbose:
+            logger.info('HARPS data seems to be split already, doing nothing')
+        return
+
+    from .timeseries import RV
+
+    new_instruments = []
+    before = self.time < HARPS_technical_intervention
+    if before.any():
+        new_instruments += ['HARPS03']
+
+    after = self.time >= HARPS_technical_intervention
+    if after.any():
+        new_instruments += ['HARPS15']
+
+    for inst, mask in zip(new_instruments, [before, after]):
+        _s = RV.from_arrays(self.star, self.time[mask], self.vrad[mask], self.svrad[mask],
+                            inst=inst)
+        for q in self._quantities:
+            setattr(_s, q, getattr(self, q)[mask])
+        setattr(self, inst, _s)
+        _s._quantities = self._quantities
+        _s.mask = self.mask[mask]
+
+    delattr(self, 'HARPS')
+    self.instruments = new_instruments
+    self._build_arrays()
+
+    if self.verbose:
+        logger.info(f'divided HARPS into {self.instruments}')
+
+
+
 # ESPRESSO ADC issues
 from .utils import ESPRESSO_ADC_issues
 
-def ADC_issues(self, plot=True, check_headers=False):
-    """ Identify and mask points affected by ADC issues (ESPRESSO).
+def ADC_issues(self, mask=True, plot=True, check_headers=False):
+    """ Identify and optionally mask points affected by ADC issues (ESPRESSO).
 
     Args:
+        mask (bool, optional):
+            Whether to mask out the points.
         plot (bool, optional):
             Whether to plot the masked points.
         check_headers (bool, optional):
@@ -17,11 +103,10 @@ def ADC_issues(self, plot=True, check_headers=False):
     """
     instruments = self._check_instrument('ESPRESSO')
 
-    if len(instruments) < 1:
+    if instruments is None:
         if self.verbose:
-            logger.error(f"no data from ESPRESSO")
-            logger.info(f'available: {self.instruments}')
-
+            logger.error(f"ADC_issues: no data from ESPRESSO")
+        return
 
     affected_file_roots = ESPRESSO_ADC_issues()
     file_roots = [os.path.basename(f).replace('.fits', '') for f in self.raw_file]
@@ -42,29 +127,31 @@ def ADC_issues(self, plot=True, check_headers=False):
             logger.info(f"there {'are'[:n^1]}{'is'[n^1:]} {n} frame{'s'[:n^1]} "
                         "affected by ADC issues")
 
-    self.mask[intersect] = False
-    self._propagate_mask_changes()
+    if mask:
+        self.mask[intersect] = False
+        self._propagate_mask_changes()
 
-    if plot:
-        self.plot(show_masked=True)
+        if plot:
+            self.plot(show_masked=True)
 
     return intersect
 # ESPRESSO cryostat issues
 from .utils import ESPRESSO_cryostat_issues
 
 
-def blue_cryostat_issues(self, plot=True):
+def blue_cryostat_issues(self, mask=True, plot=True):
     """ Identify and mask points affected by blue cryostat issues (ESPRESSO).
 
     Args:
+        mask (bool, optional): Whether to mask out the points.
         plot (bool, optional): Whether to plot the masked points.
     """
     instruments = self._check_instrument('ESPRESSO')
 
-    if len(instruments) < 1:
+    if instruments is None:
         if self.verbose:
-            logger.error(f"no data from ESPRESSO")
-            logger.info(f'available: {self.instruments}')
+            logger.error(f"blue_cryostat_issues: no data from ESPRESSO")
+        return
 
     affected_file_roots = ESPRESSO_cryostat_issues()
     file_roots = [os.path.basename(f).replace('.fits', '') for f in self.raw_file]
@@ -77,11 +164,12 @@ def blue_cryostat_issues(self, plot=True):
             logger.info(f"there {'are'[:n^1]}{'is'[n^1:]} {n} frame{'s'[:n^1]} "
                         "affected by blue cryostat issues")
 
-    self.mask[intersect] = False
-    self._propagate_mask_changes()
+    if mask:
+        self.mask[intersect] = False
+        self._propagate_mask_changes()
 
-    if plot:
-        self.plot(show_masked=True)
+        if plot:
+            self.plot(show_masked=True)
 
     return intersect
 
@@ -132,24 +220,27 @@ def qc_scired_issues(self, plot=False, **kwargs):
     return affected
 
 
-def known_issues(self, plot=False, **kwargs):
-    """ Identify and mask known instrumental issues (ADC and blue cryostat for ESPRESSO)
+def known_issues(self, mask=True, plot=False, **kwargs):
+    """ Identify and optionally mask known instrumental issues (ADC and blue cryostat for ESPRESSO)
 
     Args:
+        mask (bool, optional): Whether to mask out the points.
         plot (bool, optional): Whether to plot the masked points.
     """
     try:
-        adc = ADC_issues(self, plot, **kwargs)
-    except IndexError as e:
+        adc = ADC_issues(self, mask, plot, **kwargs)
+    except IndexError:
         # logger.error(e)
         logger.error('are the data binned? cannot proceed to mask these points...')
 
     try:
-        cryostat = blue_cryostat_issues(self, plot)
-    except IndexError as e:
+        cryostat = blue_cryostat_issues(self, mask, plot)
+    except IndexError:
         # logger.error(e)
         logger.error('are the data binned? cannot proceed to mask these points...')
 
+    if adc is None and cryostat is None:
+        return
     try:
         return adc | cryostat
     except UnboundLocalError:
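Both new `divide_*` helpers split a combined time series with two complementary boolean masks around a fixed technical-intervention epoch (the fiber upgrades on HARPS and ESPRESSO). The masking logic on toy arrays, using the epoch value from the diff:

```python
import numpy as np

ESPRESSO_technical_intervention = 58665  # intervention epoch from the diff (1 July 2019)

time = np.array([58000.0, 58500.0, 58700.0, 59000.0])
before = time < ESPRESSO_technical_intervention
after = time >= ESPRESSO_technical_intervention

for inst, mask in zip(['ESPRESSO18', 'ESPRESSO19'], [before, after]):
    if mask.any():
        print(inst, time[mask])  # each point lands in exactly one subset
```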
arvi/kima_wrapper.py ADDED
@@ -0,0 +1,74 @@
+import os
+import numpy as np
+
+from .setup_logger import logger
+
+try:
+    import kima
+    from kima.pykima.utils import chdir
+    from kima import distributions
+    from kima import RVData, RVmodel
+    kima_available = True
+except ImportError:
+    kima_available = False
+
+
+def try_to_guess_prior(model, prior):
+    if 'jitter' in prior:
+        return 'Jprior'
+    if 'vsys' in prior:
+        return 'Cprior'
+    return None
+
+
+def run_kima(self, run=False, load=False, run_directory=None, priors={}, **kwargs):
+    if not kima_available:
+        raise ImportError('kima not available, please install with `pip install kima`')
+
+    time = [getattr(self, inst).mtime for inst in self.instruments]
+    vrad = [getattr(self, inst).mvrad for inst in self.instruments]
+    err = [getattr(self, inst).msvrad for inst in self.instruments]
+    data = RVData(time, vrad, err, instruments=self.instruments)
+
+    fix = kwargs.pop('fix', False)
+    npmax = kwargs.pop('npmax', 1)
+    model = RVmodel(fix=fix, npmax=npmax, data=data)
+
+    model.trend = kwargs.pop('trend', False)
+    model.degree = kwargs.pop('degree', 0)
+
+    model.studentt = kwargs.pop('studentt', False)
+    model.enforce_stability = kwargs.pop('enforce_stability', False)
+    model.star_mass = kwargs.pop('star_mass', 1.0)
+
+    for k, v in priors.items():
+        try:
+            if 'conditional' in k:
+                setattr(model.conditional, k.replace('conditional.', ''), v)
+            else:
+                setattr(model, k, v)
+
+        except AttributeError:
+            msg = f'`RVmodel` has no attribute `{k}`, '
+            if guess := try_to_guess_prior(model, k):
+                msg += f'did you mean `{guess}`?'
+            logger.warning(msg)
+            return
+
+    if run:
+        if run_directory is None:
+            run_directory = os.getcwd()
+
+        # TODO: use signature of kima.run to pop the correct kwargs
+        # model_name = model.__class__.__name__
+        # model_name = f'kima.{model_name}.{model_name}'
+        # signature, defaults = [sig for sig in kima.run.__nb_signature__ if model_name in sig[0]]
+
+        with chdir(run_directory):
+            kima.run(model, **kwargs)
+
+    if load:
+        res = kima.load_results(model)
+        return data, model, res
+
+    return data, model
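`run_kima` builds a kima `RVData`/`RVmodel` pair from the masked per-instrument arrays, applies any priors passed by name, and optionally runs the sampler in a chosen directory. A hypothetical invocation, assuming kima is installed and that the function is exposed as a method on `RV` objects (as its `self` argument suggests); the star name and prior value are placeholders:

```python
# Hypothetical usage sketch: 'HD10180' is a placeholder target, and the
# prior name 'Cprior' follows the hints in try_to_guess_prior.
from arvi import RV
from kima import distributions

s = RV('HD10180')  # fetches data from DACE

# build the data and model without running the sampler
data, model = s.run_kima(
    npmax=2, trend=True, degree=1,
    priors={'Cprior': distributions.Uniform(-10, 10)},
)
```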