arvi 0.1.8__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arvi might be problematic; see the package registry page for more details.

arvi/plots.py CHANGED
@@ -1,7 +1,10 @@
1
1
  import os
2
2
  from functools import partial, partialmethod
3
+ from itertools import cycle
3
4
 
5
+ import matplotlib.collections
4
6
  import numpy as np
7
+ import matplotlib
5
8
  import matplotlib.pyplot as plt
6
9
  from matplotlib.collections import LineCollection
7
10
  import mplcursors
@@ -13,8 +16,8 @@ from . import config
13
16
 
14
17
 
15
18
  def plot(self, ax=None, show_masked=False, instrument=None, time_offset=0,
16
- remove_50000=False, tooltips=True, label=None, N_in_label=False,
17
- versus_n=False, show_histogram=False, **kwargs):
19
+ remove_50000=False, tooltips=False, label=None, N_in_label=False,
20
+ versus_n=False, show_histogram=False, bw=False, **kwargs):
18
21
  """ Plot the RVs
19
22
 
20
23
  Args:
@@ -37,6 +40,8 @@ def plot(self, ax=None, show_masked=False, instrument=None, time_offset=0,
37
40
  show_histogram (bool, optional)
38
41
  Whether to show a panel with the RV histograms (per intrument).
39
42
  Defaults to False.
43
+ bw (bool, optional):
44
+ Adapt plot to black and white. Defaults to False.
40
45
 
41
46
  Returns:
42
47
  Figure: the figure
@@ -57,7 +62,7 @@ def plot(self, ax=None, show_masked=False, instrument=None, time_offset=0,
57
62
  ax, axh = ax
58
63
  fig = ax.figure
59
64
 
60
- kwargs.setdefault('marker', 'o')
65
+
61
66
  kwargs.setdefault('ls', '')
62
67
  kwargs.setdefault('capsize', 0)
63
68
  kwargs.setdefault('ms', 4)
@@ -65,10 +70,22 @@ def plot(self, ax=None, show_masked=False, instrument=None, time_offset=0,
65
70
  if remove_50000:
66
71
  time_offset = 50000
67
72
 
68
- instruments = self._check_instrument(instrument)
73
+ strict = kwargs.pop('strict', False)
74
+ instruments = self._check_instrument(instrument, strict=strict)
75
+
76
+ if bw:
77
+ markers = cycle(('o', 'P', 's', '^', '*'))
78
+ else:
79
+ markers = cycle(('o',) * len(instruments))
80
+
81
+ try:
82
+ zorders = cycle(-np.argsort([getattr(self, i).error for i in instruments])[::-1])
83
+ except AttributeError:
84
+ zorders = cycle([1] * len(instruments))
69
85
 
70
86
  cursors = {}
71
- for inst in instruments:
87
+ containers = {}
88
+ for _i, inst in enumerate(instruments):
72
89
  s = self if self._child else getattr(self, inst)
73
90
  if s.mask.sum() == 0:
74
91
  continue
@@ -80,37 +97,42 @@ def plot(self, ax=None, show_masked=False, instrument=None, time_offset=0,
80
97
  p = p.replace('_', '.')
81
98
  _label = f'{i}-{p}'
82
99
  else:
83
- _label = label
100
+ if isinstance(label, list):
101
+ _label = label[_i]
102
+ else:
103
+ _label = label
84
104
 
85
105
  if versus_n:
86
- container = ax.errorbar(np.arange(1, s.mtime.size + 1),
87
- s.mvrad, s.msvrad, label=_label, picker=True, **kwargs)
106
+ container = ax.errorbar(np.arange(1, s.mtime.size + 1), s.mvrad, s.msvrad,
107
+ label=_label, picker=True, marker=next(markers), zorder=next(zorders),
108
+ **kwargs)
88
109
  else:
89
- container = ax.errorbar(s.mtime - time_offset,
90
- s.mvrad, s.msvrad, label=_label, picker=True, **kwargs)
110
+ container = ax.errorbar(s.mtime - time_offset, s.mvrad, s.msvrad,
111
+ label=_label, picker=True, marker=next(markers), zorder=next(zorders),
112
+ **kwargs)
113
+
114
+ containers[inst] = list(container)
91
115
 
92
116
  if show_histogram:
93
117
  kw = dict(histtype='step', bins='doane', orientation='horizontal')
94
118
  hlabel = f'{s.mvrad.std():.2f} {self.units}'
95
119
  axh.hist(s.mvrad, **kw, label=hlabel)
96
120
 
97
- if tooltips:
98
- cursors[inst] = crsr = mplcursors.cursor(container, multiple=False)
99
-
100
- @crsr.connect("add")
101
- def _(sel):
102
- inst = sel.artist.get_label()
103
- _s = getattr(self, inst)
104
- vrad, svrad = _s.vrad[sel.index], _s.svrad[sel.index]
105
- sel.annotation.get_bbox_patch().set(fc="white")
106
- text = f'{inst}\n'
107
- text += f'BJD: {sel.target[0]:9.5f}\n'
108
- text += f'RV: {vrad:.3f} ± {svrad:.3f}'
109
- if fig.canvas.manager.toolmanager.get_tool('infotool').toggled:
110
- text += '\n\n'
111
- text += f'date: {_s.date_night[sel.index]}\n'
112
- text += f'mask: {_s.ccf_mask[sel.index]}'
113
- sel.annotation.set_text(text)
121
+ # cursors[inst] = crsr = mplcursors.cursor(container, multiple=False)
122
+ # @crsr.connect("add")
123
+ # def _(sel):
124
+ # inst = sel.artist.get_label()
125
+ # _s = getattr(self, inst)
126
+ # vrad, svrad = _s.vrad[sel.index], _s.svrad[sel.index]
127
+ # sel.annotation.get_bbox_patch().set(fc="white")
128
+ # text = f'{inst}\n'
129
+ # text += f'BJD: {sel.target[0]:9.5f}\n'
130
+ # text += f'RV: {vrad:.3f} ± {svrad:.3f}'
131
+ # # if fig.canvas.manager.toolmanager.get_tool('infotool').toggled:
132
+ # # text += '\n\n'
133
+ # # text += f'date: {_s.date_night[sel.index]}\n'
134
+ # # text += f'mask: {_s.ccf_mask[sel.index]}'
135
+ # sel.annotation.set_text(text)
114
136
 
115
137
  if show_masked:
116
138
  if versus_n:
@@ -128,7 +150,62 @@ def plot(self, ax=None, show_masked=False, instrument=None, time_offset=0,
128
150
  ax.errorbar(self.time[~self.mask] - time_offset, self.vrad[~self.mask], self.svrad[~self.mask],
129
151
  label='masked', fmt='x', ms=10, color='k', zorder=-2)
130
152
 
131
- ax.legend()
153
+ leg = ax.legend()
154
+ handles, labels = ax.get_legend_handles_labels()
155
+ for text in leg.get_texts():
156
+ text.set_picker(True)
157
+
158
+ def on_pick_legend(event):
159
+ artist = event.artist
160
+ if isinstance(artist, matplotlib.text.Text):
161
+ try:
162
+ h = handles[labels.index(artist.get_text())]
163
+ alpha_text = {None:0.2, 1.0: 0.2, 0.2:1.0}[artist.get_alpha()]
164
+ alpha_point = {None: 0.0, 1.0: 0.0, 0.2: 1.0}[artist.get_alpha()]
165
+ h[0].set_alpha(alpha_point)
166
+ h[2][0].set_alpha(alpha_point)
167
+ artist.set_alpha(alpha_text)
168
+ fig.canvas.draw()
169
+ except ValueError:
170
+ pass
171
+ plt.connect('pick_event', on_pick_legend)
172
+
173
+ if tooltips:
174
+ annotations = []
175
+ def on_pick_point(event):
176
+ print('annotations:', annotations)
177
+ for text in annotations:
178
+ text.remove()
179
+ annotations.remove(text)
180
+
181
+ artist = event.artist
182
+ if isinstance(artist, (matplotlib.lines.Line2D, matplotlib.collections.LineCollection)):
183
+ print(event.ind, artist)
184
+ if isinstance(artist, matplotlib.lines.Line2D):
185
+ matching_instrument = [k for k, v in containers.items() if artist in v]
186
+ print(matching_instrument)
187
+ if len(matching_instrument) == 0:
188
+ return
189
+ inst = matching_instrument[0]
190
+ _s = getattr(self, inst)
191
+ ind = event.ind[0]
192
+ # print(_s.mtime[ind], _s.mvrad[ind], _s.msvrad[ind])
193
+
194
+ text = f'{inst}\n'
195
+ text += f'{_s.mtime[ind]:9.5f}\n'
196
+ text += f'RV: {_s.mvrad[ind]:.1f} ± {_s.msvrad[ind]:.1f}'
197
+
198
+ annotations.append(
199
+ ax.annotate(text, (_s.mtime[ind], _s.mvrad[ind]), xycoords='data',
200
+ xytext=(5, 10), textcoords='offset points', fontsize=9,
201
+ bbox={'boxstyle': 'round', 'fc': 'w'}, arrowprops=dict(arrowstyle="-"))
202
+ )
203
+ # ax.annotate(f'{inst}', (0.5, 0.5), xycoords=artist, ha='center', va='center')
204
+ fig.canvas.draw()
205
+ # print(event.ind, artist.get_label())
206
+ plt.connect('pick_event', on_pick_point)
207
+
208
+
132
209
  if show_histogram:
133
210
  axh.legend()
134
211
 
@@ -209,7 +286,11 @@ def plot_quantity(self, quantity, ax=None, show_masked=False, instrument=None,
209
286
  label = f'{i}-{p}'
210
287
 
211
288
  y = getattr(s, quantity)
212
- ye = getattr(s, quantity + '_err')
289
+ try:
290
+ ye = getattr(s, quantity + '_err')
291
+ except AttributeError:
292
+ ye = np.zeros_like(y)
293
+
213
294
 
214
295
  if np.isnan(y).all() or np.isnan(ye).all():
215
296
  lines, *_ = ax.errorbar([], [], [],
@@ -256,6 +337,21 @@ plot_rhk = partialmethod(plot_quantity, quantity='rhk')
256
337
 
257
338
 
258
339
  def gls(self, ax=None, label=None, fap=True, picker=True, instrument=None, **kwargs):
340
+ """
341
+ Calculate and plot the Generalised Lomb-Scargle periodogram of the radial
342
+ velocities.
343
+
344
+ Args:
345
+ ax (matplotlib.axes.Axes):
346
+ The matplotlib axes to plot on. If None, a new figure will be
347
+ created.
348
+ label (str):
349
+ The label to use for the plot.
350
+ fap (bool):
351
+ Whether to show the false alarm probability.
352
+ instrument (str or list):
353
+ Which instruments' data to include in the periodogram.
354
+ """
259
355
  if self.N == 0:
260
356
  if self.verbose:
261
357
  logger.error('no data to compute gls')
@@ -267,7 +363,8 @@ def gls(self, ax=None, label=None, fap=True, picker=True, instrument=None, **kwa
267
363
  fig = ax.figure
268
364
 
269
365
  if instrument is not None:
270
- instrument = self._check_instrument(instrument)
366
+ strict = kwargs.pop('strict', False)
367
+ instrument = self._check_instrument(instrument, strict=strict)
271
368
  if instrument is not None:
272
369
  instrument_mask = np.isin(self.instrument_array, instrument)
273
370
  t = self.time[instrument_mask & self.mask]
@@ -298,6 +395,31 @@ def gls(self, ax=None, label=None, fap=True, picker=True, instrument=None, **kwa
298
395
  if label is not None:
299
396
  ax.legend()
300
397
 
398
+ if ax.get_legend() is not None:
399
+ leg = ax.get_legend()
400
+ for text in leg.get_texts():
401
+ text.set_picker(True)
402
+
403
+ def on_pick_legend(event):
404
+ handles, labels = ax.get_legend_handles_labels()
405
+ artist = event.artist
406
+ if isinstance(artist, matplotlib.text.Text):
407
+ # print('handles:', handles)
408
+ # print('labels:', labels)
409
+ # print(artist.get_text())
410
+ try:
411
+ h = handles[labels.index(artist.get_text())]
412
+ alpha_text = {None:0.2, 1.0: 0.2, 0.2:1.0}[artist.get_alpha()]
413
+ alpha_point = {None: 0.0, 1.0: 0.0, 0.2: 1.0}[artist.get_alpha()]
414
+ h.set_alpha(alpha_point)
415
+ artist.set_alpha(alpha_text)
416
+ fig.canvas.draw()
417
+ except ValueError:
418
+ pass
419
+
420
+ if 'pick_event' not in fig.canvas.callbacks.callbacks:
421
+ plt.connect('pick_event', on_pick_legend)
422
+
301
423
 
302
424
  if config.return_self:
303
425
  return self
@@ -347,3 +469,41 @@ def gls_quantity(self, quantity, ax=None, fap=True, picker=True):
347
469
  gls_fwhm = partialmethod(gls_quantity, quantity='fwhm')
348
470
  gls_bis = partialmethod(gls_quantity, quantity='bispan')
349
471
  gls_rhk = partialmethod(gls_quantity, quantity='rhk')
472
+
473
+
474
def histogram_svrad(self, ax=None, instrument=None, label=None):
    """ Plot a histogram of the radial velocity uncertainties.

    One step histogram (40 bins) of `msvrad` is drawn per instrument.
    Instruments whose observations are all masked are skipped, as in `plot`.

    Args:
        ax (matplotlib.axes.Axes):
            The matplotlib axes to plot on. If None, a new figure will be
            created.
        instrument (str or list):
            Which instruments' data to include in the histogram.
        label (str or list):
            The label to use for the plot. If a list, it is indexed by the
            position of each instrument in the selected instruments (same
            convention as `plot`). If None, a label is derived from the
            instrument name.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, constrained_layout=True)

    instruments = self._check_instrument(instrument)

    kw = dict(bins=40, histtype='step', density=False, lw=2)
    plotted = False

    for i, inst in enumerate(instruments):
        s = self if self._child else getattr(self, inst)
        # nothing to histogram when every observation is masked
        if s.mask.sum() == 0:
            continue

        if label is None:
            _label = inst
            if not self.only_latest_pipeline:
                # turn a name like 'INST_1_2' into 'INST-1.2'; names without
                # an underscore are used unchanged instead of raising
                name, sep, pipeline = inst.partition('_')
                if sep:
                    _label = f"{name}-{pipeline.replace('_', '.')}"
        elif isinstance(label, list):
            _label = label[i]
        else:
            _label = label

        ax.hist(s.msvrad, label=_label, **kw)
        plotted = True

    # build the legend once, and only if something was actually drawn
    if plotted:
        ax.legend()
    ax.set(xlabel=f'RV uncertainty [{self.units}]', ylabel='Number')
509
+
arvi/simbad_wrapper.py CHANGED
@@ -1,9 +1,12 @@
1
+ import os
1
2
  from dataclasses import dataclass, field
2
3
  import requests
3
4
 
4
5
  from astropy.coordinates import SkyCoord
5
6
  import pysweetcat
6
7
 
8
# The 'data' directory next to this module; e.g. KOBE-translate.csv is looked up here.
DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')
7
10
 
8
11
  QUERY = """
9
12
  SELECT basic.OID,
@@ -40,12 +43,13 @@ WHERE id = '{star}';
40
43
 
41
44
def run_query(query, timeout=10):
    """ Run an ADQL query against the SIMBAD TAP synchronous endpoint.

    Args:
        query (str): The ADQL query to execute.
        timeout (float, optional): Seconds to wait for the server, passed to
            `requests.post`. Defaults to 10.

    Returns:
        str: The decoded body of the response (requested as text/plain).

    Raises:
        IndexError: If the request times out while reading or the connection
            fails. The network error is re-raised as IndexError to keep the
            previous behaviour (presumably callers catch IndexError; TODO
            confirm), with the original exception chained as the cause.
    """
    url = 'http://simbad.u-strasbg.fr/simbad/sim-tap/sync'
    data = dict(query=query, request='doQuery', lang='ADQL',
                format='text/plain', phase='run')
    try:
        response = requests.post(url, data=data, timeout=timeout)
    except (requests.ReadTimeout, requests.ConnectionError) as err:
        raise IndexError(err) from err
    return response.content.decode()
50
54
 
51
55
  def parse_table(table, cols=None, values=None):
@@ -95,6 +99,16 @@ class simbad:
95
99
  """
96
100
  self.star = star
97
101
 
102
+ if 'kobe' in self.star.lower():
103
+ fname = os.path.join(DATA_PATH, 'KOBE-translate.csv')
104
+ kobe_translate = {}
105
+ if os.path.exists(fname):
106
+ with open(fname) as f:
107
+ for line in f.readlines():
108
+ kobe_id, catname = line.strip().split(',')
109
+ kobe_translate[kobe_id] = catname
110
+ self.star = star = kobe_translate[self.star]
111
+
98
112
  # oid = run_query(query=OID_QUERY.format(star=star))
99
113
  # self.oid = str(oid.split()[-1])
100
114
 
@@ -122,6 +136,9 @@ class simbad:
122
136
 
123
137
  self.coords = SkyCoord(self.ra, self.dec, unit='deg')
124
138
 
139
+ if self.plx_value == '':
140
+ self.plx_value = None
141
+
125
142
  try:
126
143
  swc_data = pysweetcat.get_data()
127
144
  data = swc_data.find(star)
arvi/stats.py CHANGED
@@ -11,17 +11,27 @@ def wmean(a, e):
11
11
  a (array): Array containing data
12
12
  e (array): Uncertainties on `a`
13
13
  """
14
+ if (e == 0).any():
15
+ raise ZeroDivisionError
16
+ if (e < 0).any():
17
+ raise ValueError
18
+ if (a.shape != e.shape):
19
+ raise ValueError
14
20
  return np.average(a, weights=1 / e**2)
15
21
 
16
def rms(a, ignore_nans=False):
    """ Root mean square of array `a`

    Args:
        a (array): Array containing data (anything accepted by
            `np.asanyarray`, e.g. a list or an ndarray)
        ignore_nans (bool, optional): If True, NaN values are discarded before
            computing the rms. Defaults to False, in which case any NaN in `a`
            makes the result NaN.

    Returns:
        float: sqrt(mean(a**2)), or NaN when there are no values to average.
    """
    # asanyarray (not asarray) so ndarray subclasses such as masked arrays
    # keep their behaviour
    a = np.asanyarray(a)
    if ignore_nans:
        a = a[~np.isnan(a)]
    # an empty input would give NaN anyway; return it directly to avoid
    # numpy's "Mean of empty slice" RuntimeWarning
    if a.size == 0:
        return np.nan
    return np.sqrt((a**2).mean())
23
33
 
24
- def wrms(a, e):
34
+ def wrms(a, e, ignore_nans=False):
25
35
  """ Weighted root mean square of array `a`, with uncertanty given by `e`.
26
36
  The weighted rms is calculated using the weighted mean, where the weights
27
37
  are equal to 1/e**2.
@@ -30,6 +40,16 @@ def wrms(a, e):
30
40
  a (array): Array containing data
31
41
  e (array): Uncertainties on `a`
32
42
  """
43
+ if ignore_nans:
44
+ nans = np.logical_or(np.isnan(a), np.isnan(e))
45
+ a = a[~nans]
46
+ e = e[~nans]
47
+ if (e == 0).any():
48
+ raise ZeroDivisionError('uncertainty cannot be zero')
49
+ if (e < 0).any():
50
+ raise ValueError('uncertainty cannot be negative')
51
+ if (a.shape != e.shape):
52
+ raise ValueError('arrays must have the same shape')
33
53
  w = 1 / e**2
34
54
  return np.sqrt(np.sum(w * (a - np.average(a, weights=w))**2) / sum(w))
35
55
 
@@ -38,6 +58,16 @@ def sigmaclip_median(a, low=4.0, high=4.0):
38
58
  """
39
59
  Same as scipy.stats.sigmaclip but using the median and median absolute
40
60
  deviation instead of the mean and standard deviation.
61
+
62
+ Args:
63
+ a (array): Array containing data
64
+ low (float): Number of MAD to use for the lower clipping limit
65
+ high (float): Number of MAD to use for the upper clipping limit
66
+ Returns:
67
+ SigmaclipResult: Object with the following attributes:
68
+ - `clipped`: Masked array of data
69
+ - `lower`: Lower clipping limit
70
+ - `upper`: Upper clipping limit
41
71
  """
42
72
  c = np.asarray(a).ravel()
43
73
  delta = 1