arvi: arvi-0.1.10-py3-none-any.whl → arvi-0.1.12-py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in their respective public registries.

Potentially problematic release.


This version of arvi might be problematic. (The original page linked to further details here; that link is not preserved in this copy.)

arvi/programs.py CHANGED
@@ -1,5 +1,7 @@
1
1
  import os
2
- import concurrent.futures
2
+ import multiprocessing
3
+ from functools import partial
4
+ from itertools import chain
3
5
  from collections import namedtuple
4
6
  from tqdm import tqdm
5
7
  # import numpy as np
@@ -13,7 +15,8 @@ path = os.path.join(os.path.dirname(__file__), 'data')
13
15
 
14
16
 
15
17
  def get_star(star, instrument=None):
16
- return RV(star, verbose=False, instrument=instrument, _raise_on_error=False)
18
+ return RV(star, instrument=instrument,
19
+ _raise_on_error=False, verbose=False, load_extra_data=False)
17
20
 
18
21
 
19
22
  class LazyRV:
@@ -22,6 +25,7 @@ class LazyRV:
22
25
  if isinstance(self.stars, str):
23
26
  self.stars = [self.stars]
24
27
  self.instrument = instrument
28
+ self._saved = None
25
29
 
26
30
  @property
27
31
  def N(self):
@@ -31,30 +35,49 @@ class LazyRV:
31
35
  return f"RV({self.N} stars)"
32
36
 
33
37
  def _get(self):
34
- result = []
35
- # use a with statement to ensure threads are cleaned up promptly
36
- with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
37
- star_to_RV = {
38
- pool.submit(get_star, star, self.instrument): star
39
- for star in self.stars
40
- }
38
+ if self.N > 10:
39
+ # logger.info('Querying DACE...')
40
+ _get_star = partial(get_star, instrument=self.instrument)
41
+ with multiprocessing.Pool() as pool:
42
+ result = list(tqdm(pool.imap(_get_star, self.stars),
43
+ total=self.N, unit='star', desc='Querying DACE'))
44
+ # result = pool.map(get_star, self.stars)
45
+ else:
46
+ result = []
41
47
  logger.info('Querying DACE...')
42
- pbar = tqdm(concurrent.futures.as_completed(star_to_RV),
43
- total=self.N, unit='star')
44
- for future in pbar:
45
- star = star_to_RV[future]
48
+ pbar = tqdm(self.stars, total=self.N, unit='star')
49
+ for star in pbar:
46
50
  pbar.set_description(star)
47
- try:
48
- result.append(future.result())
49
- except Exception:
50
- print(f'{star} generated an exception')
51
+ result.append(get_star(star, self.instrument))
52
+
51
53
  return result
52
54
 
55
+ # # use a with statement to ensure threads are cleaned up promptly
56
+ # with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
57
+ # star_to_RV = {
58
+ # pool.submit(get_star, star, self.instrument): star
59
+ # for star in self.stars
60
+ # }
61
+ # logger.info('Querying DACE...')
62
+ # pbar = tqdm(concurrent.futures.as_completed(star_to_RV),
63
+ # total=self.N, unit='star')
64
+ # for future in pbar:
65
+ # star = star_to_RV[future]
66
+ # pbar.set_description(star)
67
+ # try:
68
+ # result.append(future.result())
69
+ # except ValueError:
70
+ # print(f'{star} generated an exception')
71
+ # result.append(None)
72
+ # return result
73
+
53
74
  def __iter__(self):
54
75
  return self._get()
55
76
 
56
77
  def __call__(self):
57
- return self._get()
78
+ if not self._saved:
79
+ self._saved = self._get()
80
+ return self._saved
58
81
 
59
82
 
60
83
  # sorted by spectral type
arvi/simbad_wrapper.py CHANGED
@@ -1,9 +1,12 @@
1
+ import os
1
2
  from dataclasses import dataclass, field
2
3
  import requests
3
4
 
4
5
  from astropy.coordinates import SkyCoord
5
6
  import pysweetcat
6
7
 
8
+ DATA_PATH = os.path.dirname(__file__)
9
+ DATA_PATH = os.path.join(DATA_PATH, 'data')
7
10
 
8
11
  QUERY = """
9
12
  SELECT basic.OID,
@@ -45,6 +48,8 @@ def run_query(query):
45
48
  response = requests.post(url, data=data, timeout=10)
46
49
  except requests.ReadTimeout as err:
47
50
  raise IndexError(err)
51
+ except requests.ConnectionError as err:
52
+ raise IndexError(err)
48
53
  return response.content.decode()
49
54
 
50
55
  def parse_table(table, cols=None, values=None):
@@ -80,7 +85,8 @@ class simbad:
80
85
  dec (float): declination
81
86
  coords (SkyCoord): coordinates as a SkyCoord object
82
87
  main_id (str): main identifier
83
- plx_value (float): parallax
88
+ gaia_id (int): Gaia DR3 identifier
89
+ plx (float): parallax
84
90
  rvz_radvel (float): radial velocity
85
91
  sp_type (str): spectral type
86
92
  B (float): B magnitude
@@ -94,6 +100,16 @@ class simbad:
94
100
  """
95
101
  self.star = star
96
102
 
103
+ if 'kobe' in self.star.lower():
104
+ fname = os.path.join(DATA_PATH, 'KOBE-translate.csv')
105
+ kobe_translate = {}
106
+ if os.path.exists(fname):
107
+ with open(fname) as f:
108
+ for line in f.readlines():
109
+ kobe_id, catname = line.strip().split(',')
110
+ kobe_translate[kobe_id] = catname
111
+ self.star = star = kobe_translate[self.star]
112
+
97
113
  # oid = run_query(query=OID_QUERY.format(star=star))
98
114
  # self.oid = str(oid.split()[-1])
99
115
 
@@ -110,6 +126,8 @@ class simbad:
110
126
  except IndexError:
111
127
  raise ValueError(f'simbad query for {star} failed')
112
128
 
129
+ self.gaia_id = int([i for i in self.ids if 'Gaia DR3' in i][0].split('Gaia DR3')[-1])
130
+
113
131
  for col, val in zip(cols, values):
114
132
  if col == 'oid':
115
133
  setattr(self, col, str(val))
@@ -123,6 +141,9 @@ class simbad:
123
141
 
124
142
  if self.plx_value == '':
125
143
  self.plx_value = None
144
+
145
+ self.plx = self._plx_value = self.plx_value
146
+ del self.plx_value
126
147
 
127
148
  try:
128
149
  swc_data = pysweetcat.get_data()
arvi/stats.py CHANGED
@@ -19,15 +19,19 @@ def wmean(a, e):
19
19
  raise ValueError
20
20
  return np.average(a, weights=1 / e**2)
21
21
 
22
- def rms(a):
22
+ def rms(a, ignore_nans=False):
23
23
  """ Root mean square of array `a`
24
24
 
25
25
  Args:
26
26
  a (array): Array containing data
27
27
  """
28
+ if ignore_nans:
29
+ a = a[~np.isnan(a)]
30
+ if len(a) == 0:
31
+ return np.nan
28
32
  return np.sqrt((a**2).mean())
29
33
 
30
- def wrms(a, e):
34
+ def wrms(a, e, ignore_nans=False):
31
35
  """ Weighted root mean square of array `a`, with uncertanty given by `e`.
32
36
  The weighted rms is calculated using the weighted mean, where the weights
33
37
  are equal to 1/e**2.
@@ -36,6 +40,16 @@ def wrms(a, e):
36
40
  a (array): Array containing data
37
41
  e (array): Uncertainties on `a`
38
42
  """
43
+ if ignore_nans:
44
+ nans = np.logical_or(np.isnan(a), np.isnan(e))
45
+ a = a[~nans]
46
+ e = e[~nans]
47
+ if (e == 0).any():
48
+ raise ZeroDivisionError('uncertainty cannot be zero')
49
+ if (e < 0).any():
50
+ raise ValueError('uncertainty cannot be negative')
51
+ if (a.shape != e.shape):
52
+ raise ValueError('arrays must have the same shape')
39
53
  w = 1 / e**2
40
54
  return np.sqrt(np.sum(w * (a - np.average(a, weights=w))**2) / sum(w))
41
55