arvi 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arvi might be problematic. Click here for more details.

arvi/timeseries.py CHANGED
@@ -11,15 +11,16 @@ import numpy as np
11
11
  from astropy import units
12
12
 
13
13
  from .setup_logger import logger
14
- from .config import return_self, check_internet
14
+ from .config import return_self, check_internet, debug
15
15
  from .translations import translate
16
- from .dace_wrapper import do_download_filetype, get_observations, get_arrays
16
+ from .dace_wrapper import do_download_filetype, do_symlink_filetype, get_observations, get_arrays
17
17
  from .simbad_wrapper import simbad
18
+ from .gaia_wrapper import gaia
18
19
  from .extra_data import get_extra_data
19
20
  from .stats import wmean, wrms
20
21
  from .binning import bin_ccf_mask, binRV
21
22
  from .HZ import getHZ_period
22
- from .utils import strtobool, there_is_internet
23
+ from .utils import strtobool, there_is_internet, timer
23
24
 
24
25
 
25
26
  @dataclass
@@ -49,7 +50,7 @@ class RV:
49
50
  do_sigma_clip: bool = field(init=True, repr=False, default=False)
50
51
  do_adjust_means: bool = field(init=True, repr=False, default=True)
51
52
  only_latest_pipeline: bool = field(init=True, repr=False, default=True)
52
- load_extra_data: Union[bool, str] = field(init=True, repr=False, default=True)
53
+ load_extra_data: Union[bool, str] = field(init=True, repr=False, default=False)
53
54
  #
54
55
  _child: bool = field(init=True, repr=False, default=False)
55
56
  _did_secular_acceleration: bool = field(init=False, repr=False, default=False)
@@ -72,6 +73,7 @@ class RV:
72
73
  if not self._child:
73
74
  if check_internet and not there_is_internet():
74
75
  raise ConnectionError('There is no internet connection?')
76
+
75
77
  # complicated way to query Simbad with self.__star__ or, if that
76
78
  # fails, try after removing a trailing 'A'
77
79
  for target in (self.__star__, self.__star__.replace('A', '')):
@@ -84,12 +86,26 @@ class RV:
84
86
  if self.verbose:
85
87
  logger.error(f'simbad query for {self.__star__} failed')
86
88
 
89
+ # complicated way to query Gaia with self.__star__ or, if that
90
+ # fails, try after removing a trailing 'A'
91
+ for target in (self.__star__, self.__star__.replace('A', '')):
92
+ try:
93
+ self.gaia = gaia(target)
94
+ break
95
+ except ValueError:
96
+ continue
97
+ else:
98
+ if self.verbose:
99
+ logger.error(f'Gaia query for {self.__star__} failed')
100
+
87
101
  # query DACE
88
102
  if self.verbose:
89
103
  logger.info(f'querying DACE for {self.__star__}...')
90
104
  try:
91
- self.dace_result = get_observations(self.__star__, self.instrument,
92
- verbose=self.verbose)
105
+ with timer():
106
+ self.dace_result = get_observations(self.__star__, self.instrument,
107
+ main_id=self.simbad.main_id,
108
+ verbose=self.verbose)
93
109
  except ValueError as e:
94
110
  # querying DACE failed, should we raise an error?
95
111
  if self._raise_on_error:
@@ -243,12 +259,16 @@ class RV:
243
259
  table += ' | '.join(map(str, self.NN.values())) + '\n'
244
260
  return table
245
261
 
262
+ @property
263
+ def point(self):
264
+ return [(t.round(4), v.round(4), sv.round(4)) for t, v, sv in zip(self.time, self.vrad, self.svrad)]
265
+
246
266
  @property
247
267
  def mtime(self) -> np.ndarray:
248
268
  """ Masked array of times """
249
269
  return self.time[self.mask]
250
270
 
251
- @property
271
+ @property
252
272
  def mvrad(self) -> np.ndarray:
253
273
  """ Masked array of radial velocities """
254
274
  return self.vrad[self.mask]
@@ -374,7 +394,7 @@ class RV:
374
394
  star, timestamp = file.replace('.pkl', '').split('_')
375
395
  else:
376
396
  try:
377
- file = sorted(glob(f'{star}_*.pkl'))[-1]
397
+ file = sorted(glob(f'{star}_*.*.pkl'))[-1]
378
398
  except IndexError:
379
399
  raise ValueError(f'cannot find any file matching {star}_*.pkl')
380
400
  star, timestamp = file.replace('.pkl', '').split('_')
@@ -409,7 +429,7 @@ class RV:
409
429
  if star_.size == 1:
410
430
  logger.info(f'assuming star is {star_[0]}')
411
431
  star = star_[0]
412
-
432
+
413
433
  if instrument is None:
414
434
  instruments = np.array([os.path.splitext(f)[0].split('_')[1] for f in files])
415
435
  logger.info(f'assuming instruments: {instruments}')
@@ -495,7 +515,7 @@ class RV:
495
515
  for q in ['drs_qc']:
496
516
  setattr(_s, q, np.full(time.size, True))
497
517
  _quantities.append(q)
498
-
518
+
499
519
  #! end hack
500
520
 
501
521
  _s.mask = np.ones_like(time, dtype=bool)
@@ -522,10 +542,10 @@ class RV:
522
542
  except ImportError:
523
543
  logger.error('iCCF is not installed. Please install it with `pip install iCCF`')
524
544
  return
525
-
545
+
526
546
  if isinstance(files, str):
527
547
  files = [files]
528
-
548
+
529
549
  I = iCCF.from_file(files)
530
550
 
531
551
  objects = np.unique([i.HDU[0].header['OBJECT'].replace(' ', '') for i in I])
@@ -552,7 +572,7 @@ class RV:
552
572
  return s
553
573
 
554
574
 
555
- def _check_instrument(self, instrument, strict=False):# -> list | None:
575
+ def _check_instrument(self, instrument, strict=False, log=False):# -> list | None:
556
576
  """
557
577
  Check if there are observations from `instrument`.
558
578
 
@@ -585,6 +605,11 @@ class RV:
585
605
  if any([instrument in inst for inst in self.instruments]):
586
606
  return [inst for inst in self.instruments if instrument in inst]
587
607
 
608
+ if log:
609
+ logger.error(f"No data from instrument '{instrument}'")
610
+ logger.info(f'available: {self.instruments}')
611
+ return
612
+
588
613
 
589
614
  def _build_arrays(self):
590
615
  """ build all concatenated arrays of `self` from each of the `.inst`s """
@@ -629,7 +654,8 @@ class RV:
629
654
  setattr(self, q, arr)
630
655
 
631
656
 
632
- def download_ccf(self, instrument=None, index=None, limit=None, directory=None, **kwargs):
657
+ def download_ccf(self, instrument=None, index=None, limit=None,
658
+ directory=None, symlink=False, **kwargs):
633
659
  """ Download CCFs from DACE
634
660
 
635
661
  Args:
@@ -657,9 +683,15 @@ class RV:
657
683
  # remove empty strings
658
684
  files = list(filter(None, files))
659
685
 
660
- do_download_filetype('CCF', files[:limit], directory, **kwargs)
686
+ if symlink:
687
+ if 'top_level' not in kwargs:
688
+ logger.warning('may need to provide `top_level` in kwargs to find file')
689
+ do_symlink_filetype('CCF', files[:limit], directory, **kwargs)
690
+ else:
691
+ do_download_filetype('CCF', files[:limit], directory, verbose=self.verbose, **kwargs)
661
692
 
662
- def download_s1d(self, instrument=None, index=None, limit=None, directory=None, **kwargs):
693
+ def download_s1d(self, instrument=None, index=None, limit=None,
694
+ directory=None, symlink=False, **kwargs):
663
695
  """ Download S1Ds from DACE
664
696
 
665
697
  Args:
@@ -687,9 +719,15 @@ class RV:
687
719
  # remove empty strings
688
720
  files = list(filter(None, files))
689
721
 
690
- do_download_filetype('S1D', files[:limit], directory, **kwargs)
722
+ if symlink:
723
+ if 'top_level' not in kwargs:
724
+ logger.warning('may need to provide `top_level` in kwargs to find file')
725
+ do_symlink_filetype('S1D', files[:limit], directory, **kwargs)
726
+ else:
727
+ do_download_filetype('S1D', files[:limit], directory, verbose=self.verbose, **kwargs)
691
728
 
692
- def download_s2d(self, instrument=None, index=None, limit=None, directory=None, **kwargs):
729
+ def download_s2d(self, instrument=None, index=None, limit=None,
730
+ directory=None, symlink=False, **kwargs):
693
731
  """ Download S2Ds from DACE
694
732
 
695
733
  Args:
@@ -717,11 +755,16 @@ class RV:
717
755
  # remove empty strings
718
756
  files = list(filter(None, files))
719
757
 
720
- do_download_filetype('S2D', files[:limit], directory, **kwargs)
758
+ if symlink:
759
+ if 'top_level' not in kwargs:
760
+ logger.warning('may need to provide `top_level` in kwargs to find file')
761
+ do_symlink_filetype('S2D', files[:limit], directory, **kwargs)
762
+ else:
763
+ do_download_filetype('S2D', files[:limit], directory, verbose=self.verbose, **kwargs)
721
764
 
722
765
 
723
- from .plots import plot, plot_fwhm, plot_bis, plot_rhk, plot_quantity
724
- from .plots import gls, gls_fwhm, gls_bis, gls_rhk
766
+ from .plots import plot, plot_fwhm, plot_bis, plot_rhk, plot_berv, plot_quantity
767
+ from .plots import gls, gls_fwhm, gls_bis, gls_rhk, window_function
725
768
  from .reports import report
726
769
 
727
770
  from .instrument_specific import known_issues
@@ -729,13 +772,13 @@ class RV:
729
772
 
730
773
  def remove_instrument(self, instrument, strict=False):
731
774
  """ Remove all observations from one instrument
732
-
775
+
733
776
  Args:
734
777
  instrument (str or list):
735
778
  The instrument(s) for which to remove observations.
736
779
  strict (bool):
737
780
  Whether to match (each) `instrument` exactly
738
-
781
+
739
782
  Note:
740
783
  A common name can be used to remove observations for several subsets
741
784
  of a given instrument. For example
@@ -788,11 +831,24 @@ class RV:
788
831
  if return_self:
789
832
  return self
790
833
 
834
+ def remove_condition(self, condition):
835
+ """ Remove all observations that satisfy a condition
836
+
837
+ Args:
838
+ condition (np.ndarray):
839
+ Boolean array of the same length as the observations
840
+ """
841
+ if self.verbose:
842
+ inst = np.unique(self.instrument_array[condition])
843
+ logger.info(f"Removing {condition.sum()} points from instruments {inst}")
844
+ self.mask = self.mask & ~condition
845
+ self._propagate_mask_changes()
846
+
791
847
  def remove_point(self, index):
792
848
  """
793
849
  Remove individual observations at a given index (or indices).
794
850
  NOTE: Like Python, the index is 0-based.
795
-
851
+
796
852
  Args:
797
853
  index (int, list, ndarray):
798
854
  Single index, list, or array of indices to remove.
@@ -899,45 +955,76 @@ class RV:
899
955
  n_before = (self.obs < self.obs[m]).sum()
900
956
  getattr(self, inst).mask[m - n_before] = False
901
957
 
902
- def secular_acceleration(self, epoch=55500, plot=False):
958
+ def secular_acceleration(self, epoch=None, just_compute=False, force_simbad=False):
903
959
  """
904
960
  Remove secular acceleration from RVs
905
961
 
906
962
  Args:
907
- epoch (float):
963
+ epoch (float, optional):
908
964
  The reference epoch (DACE uses 55500, 31/10/2010)
909
965
  instruments (bool or collection of str):
910
- Only remove secular acceleration for some instruments, or for all
966
+ Only remove secular acceleration for some instruments, or for all
911
967
  if `instruments=True`
912
- plot (bool):
913
- Show a plot of the RVs with the secular acceleration
914
968
  """
915
- if self._did_secular_acceleration: # don't do it twice
969
+ if self._did_secular_acceleration and not just_compute: # don't do it twice
916
970
  return
917
-
971
+
972
+ #as_yr = units.arcsec / units.year
973
+ mas_yr = units.milliarcsecond / units.year
974
+ mas = units.milliarcsecond
975
+
918
976
  try:
919
- self.simbad
920
- except AttributeError:
977
+ if force_simbad:
978
+ raise AttributeError
979
+
980
+ self.gaia
981
+ self.gaia.plx
982
+
921
983
  if self.verbose:
922
- logger.error('no information from simbad, cannot remove secular acceleration')
923
- return
984
+ logger.info('using Gaia information to remove secular acceleration')
985
+
986
+ if epoch is None:
987
+ # Gaia DR3 epoch (astropy.time.Time('J2016.0', format='jyear_str').jd)
988
+ epoch = 57389.0
989
+
990
+ π = self.gaia.plx * mas
991
+ d = π.to(units.pc, equivalencies=units.parallax())
992
+ μα = self.gaia.pmra * mas_yr
993
+ μδ = self.gaia.pmdec * mas_yr
994
+ μ = μα**2 + μδ**2
995
+ sa = (μ * d).to(units.m / units.second / units.year,
996
+ equivalencies=units.dimensionless_angles())
997
+
998
+ except AttributeError:
999
+ try:
1000
+ self.simbad
1001
+ except AttributeError:
1002
+ if self.verbose:
1003
+ logger.error('no information from simbad, cannot remove secular acceleration')
1004
+ return
1005
+
1006
+ if self.simbad.plx is None:
1007
+ if self.verbose:
1008
+ logger.error('no parallax from simbad, cannot remove secular acceleration')
1009
+ return
924
1010
 
925
- if self.simbad.plx_value is None:
926
1011
  if self.verbose:
927
- logger.error('no parallax from simbad, cannot remove secular acceleration')
928
- return
1012
+ logger.info('using Simbad information to remove secular acceleration')
929
1013
 
930
- #as_yr = units.arcsec / units.year
931
- mas_yr = units.milliarcsecond / units.year
932
- mas = units.milliarcsecond
1014
+ if epoch is None:
1015
+ epoch = 55500
1016
+
1017
+ π = self.simbad.plx * mas
1018
+ d = π.to(units.pc, equivalencies=units.parallax())
1019
+ μα = self.simbad.pmra * mas_yr
1020
+ μδ = self.simbad.pmdec * mas_yr
1021
+ μ = μα**2 + μδ**2
1022
+ sa = (μ * d).to(units.m / units.second / units.year,
1023
+ equivalencies=units.dimensionless_angles())
1024
+
1025
+ if just_compute:
1026
+ return sa
933
1027
 
934
- π = self.simbad.plx_value * mas
935
- d = π.to(units.pc, equivalencies=units.parallax())
936
- μα = self.simbad.pmra * mas_yr
937
- μδ = self.simbad.pmdec * mas_yr
938
- μ = μα**2 + μδ**2
939
- sa = (μ * d).to(units.m / units.second / units.year,
940
- equivalencies=units.dimensionless_angles())
941
1028
  sa = sa.value
942
1029
 
943
1030
  if self.verbose:
@@ -961,12 +1048,41 @@ class RV:
961
1048
  continue
962
1049
 
963
1050
  s.vrad = s.vrad - sa * (s.time - epoch) / 365.25
964
-
1051
+
965
1052
  self._build_arrays()
966
1053
 
967
1054
  self._did_secular_acceleration = True
1055
+ self._did_secular_acceleration_epoch = epoch
1056
+ self._did_secular_acceleration_simbad = force_simbad
1057
+
968
1058
  if return_self:
969
1059
  return self
1060
+
1061
+ def _undo_secular_acceleration(self):
1062
+ if self._did_secular_acceleration:
1063
+ _old_verbose = self.verbose
1064
+ self.verbose = False
1065
+ sa = self.secular_acceleration(just_compute=True,
1066
+ force_simbad=self._did_secular_acceleration_simbad)
1067
+ self.verbose = _old_verbose
1068
+ sa = sa.value
1069
+
1070
+ if self._child:
1071
+ self.vrad = self.vrad + sa * (self.time - self._did_secular_acceleration_epoch) / 365.25
1072
+ else:
1073
+ for inst in self.instruments:
1074
+ if 'HIRES' in inst: # never remove it from HIRES...
1075
+ continue
1076
+ if 'NIRPS' in inst: # never remove it from NIRPS...
1077
+ continue
1078
+
1079
+ s = getattr(self, inst)
1080
+
1081
+ s.vrad = s.vrad + sa * (s.time - self._did_secular_acceleration_epoch) / 365.25
1082
+
1083
+ self._build_arrays()
1084
+
1085
+ self._did_secular_acceleration = False
970
1086
 
971
1087
  def sigmaclip(self, sigma=5, instrument=None, strict=True):
972
1088
  """ Sigma-clip RVs (per instrument!) """
@@ -1012,7 +1128,7 @@ class RV:
1012
1128
 
1013
1129
  def clip_maxerror(self, maxerror:float):
1014
1130
  """ Mask out points with RV error larger than a given value
1015
-
1131
+
1016
1132
  Args:
1017
1133
  maxerror (float): Maximum error to keep.
1018
1134
  """
@@ -1038,10 +1154,10 @@ class RV:
1038
1154
 
1039
1155
  WARNING: This creates and returns a new object and does not modify self.
1040
1156
  """
1041
-
1157
+
1042
1158
  # create copy of self to be returned
1043
1159
  snew = deepcopy(self)
1044
-
1160
+
1045
1161
  all_bad_quantities = []
1046
1162
 
1047
1163
  for inst in snew.instruments:
@@ -1050,7 +1166,7 @@ class RV:
1050
1166
  # only one observation?
1051
1167
  if s.N == 1:
1052
1168
  continue
1053
-
1169
+
1054
1170
  # are all observations masked?
1055
1171
  if s.mtime.size == 0:
1056
1172
  continue
@@ -1101,7 +1217,7 @@ class RV:
1101
1217
  with warnings.catch_warnings():
1102
1218
  warnings.filterwarnings('ignore', category=RuntimeWarning)
1103
1219
  try:
1104
- _, yb = binRV(s.mtime, Q[s.mask],
1220
+ _, yb = binRV(s.mtime, Q[s.mask],
1105
1221
  stat=np.nanmean, tstat=np.nanmean)
1106
1222
  setattr(s, q, yb)
1107
1223
  except TypeError:
@@ -1116,7 +1232,7 @@ class RV:
1116
1232
 
1117
1233
  s.time = tb
1118
1234
  s.mask = np.full(tb.shape, True)
1119
-
1235
+
1120
1236
  if snew.verbose and len(all_bad_quantities) > 0:
1121
1237
  logger.warning('\nnew object will not have these non-float quantities')
1122
1238
 
@@ -1162,6 +1278,11 @@ class RV:
1162
1278
  for inst in self.instruments:
1163
1279
  s = getattr(self, inst)
1164
1280
 
1281
+ if s.mtime.size == 0:
1282
+ if self.verbose:
1283
+ logger.info(f'all observations of {inst} are masked')
1284
+ continue
1285
+
1165
1286
  if s.N == 1:
1166
1287
  if self.verbose:
1167
1288
  msg = (f'only 1 observation for {inst}, '
@@ -1173,28 +1294,37 @@ class RV:
1173
1294
 
1174
1295
  s.rv_mean = wmean(s.mvrad, s.msvrad)
1175
1296
  s.vrad -= s.rv_mean
1297
+
1176
1298
  if self.verbose:
1177
1299
  logger.info(f'subtracted weighted average from {inst:10s}: ({s.rv_mean:.3f} {self.units})')
1300
+
1178
1301
  if just_rv:
1179
1302
  continue
1180
- # log_msg = 'same for '
1303
+
1181
1304
  for i, other in enumerate(others):
1182
1305
  y, ye = getattr(s, other), getattr(s, other + '_err')
1183
1306
  m = wmean(y[s.mask], ye[s.mask])
1184
1307
  setattr(s, f'{other}_mean', m)
1185
1308
  setattr(s, other, getattr(s, other) - m)
1186
- # log_msg += other
1187
- # if i < len(others) - 1:
1188
- # log_msg += ', '
1189
-
1190
- # if self.verbose:
1191
- # logger.info(log_msg)
1192
1309
 
1193
1310
  self._build_arrays()
1194
1311
  self._did_adjust_means = True
1195
1312
  if return_self:
1196
1313
  return self
1197
1314
 
1315
+ def add_to_vrad(self, values):
1316
+ """ Add an array of values to the RVs of all instruments """
1317
+ if values.size != self.vrad.size:
1318
+ raise ValueError(f"incompatible sizes: len(values) must equal self.N, got {values.size} != {self.vrad.size}")
1319
+
1320
+ for inst in self.instruments:
1321
+ s = getattr(self, inst)
1322
+ mask = self.instrument_array == inst
1323
+ s.vrad += values[mask]
1324
+
1325
+ self._build_arrays()
1326
+
1327
+
1198
1328
  def put_at_systemic_velocity(self):
1199
1329
  """
1200
1330
  For instruments in which mean(RV) < ptp(RV), "move" RVs to the systemic
@@ -1237,7 +1367,7 @@ class RV:
1237
1367
  self._build_arrays()
1238
1368
 
1239
1369
 
1240
- def save(self, directory=None, instrument=None, full=False,
1370
+ def save(self, directory=None, instrument=None, full=False, postfix=None,
1241
1371
  save_masked=False, save_nans=True):
1242
1372
  """ Save the observations in .rdb files.
1243
1373
 
@@ -1246,9 +1376,10 @@ class RV:
1246
1376
  Directory where to save the .rdb files.
1247
1377
  instrument (str, optional):
1248
1378
  Instrument for which to save observations.
1249
- full (bool, optional):
1250
- Whether to save just RVs and errors (False) or more indicators
1251
- (True).
1379
+ full (bool, optional):
1380
+ Save just RVs and errors (False) or more indicators (True).
1381
+ postfix (str, optional):
1382
+ Postfix to add to the filenames ([star]_[instrument]_[postfix].rdb).
1252
1383
  save_nans (bool, optional)
1253
1384
  Whether to save NaN values in the indicators, if they exist. If
1254
1385
  False, the full observation is not saved.
@@ -1301,8 +1432,11 @@ class RV:
1301
1432
  else:
1302
1433
  d = np.c_[_s.mtime, _s.mvrad, _s.msvrad]
1303
1434
  header = 'bjd\tvrad\tsvrad\n---\t----\t-----'
1304
-
1435
+
1305
1436
  file = f'{star_name}_{inst}.rdb'
1437
+ if postfix is not None:
1438
+ file = f'{star_name}_{inst}_{postfix}.rdb'
1439
+
1306
1440
  files.append(file)
1307
1441
  file = os.path.join(directory, file)
1308
1442
 
@@ -1310,7 +1444,7 @@ class RV:
1310
1444
 
1311
1445
  if self.verbose:
1312
1446
  logger.info(f'saving to {file}')
1313
-
1447
+
1314
1448
  return files
1315
1449
 
1316
1450
  def checksum(self, write_to=None):
@@ -1325,7 +1459,7 @@ class RV:
1325
1459
 
1326
1460
 
1327
1461
  #
1328
- def run_lbl(self, instrument=None, data_dir=None,
1462
+ def run_lbl(self, instrument=None, data_dir=None,
1329
1463
  skysub=False, tell=False, limit=None, **kwargs):
1330
1464
  from .lbl_wrapper import run_lbl, NIRPS_create_telluric_corrected_S2D
1331
1465
 
@@ -1339,7 +1473,7 @@ class RV:
1339
1473
  logger.error(f"No data from instrument '{instrument}'")
1340
1474
  logger.info(f'available: {self.instruments}')
1341
1475
  return
1342
-
1476
+
1343
1477
  if isinstance(instrument, str):
1344
1478
  instruments = [instrument]
1345
1479
  else:
@@ -1394,7 +1528,7 @@ class RV:
1394
1528
  logger.error(f"No data from instrument '{instrument}'")
1395
1529
  logger.info(f'available: {self.instruments}')
1396
1530
  return
1397
-
1531
+
1398
1532
  if isinstance(instrument, str):
1399
1533
  instruments = [instrument]
1400
1534
  else:
@@ -1430,22 +1564,50 @@ class RV:
1430
1564
  return self._planets
1431
1565
 
1432
1566
 
1433
- def fit_sine(t, y, yerr, period='gls', fix_period=False):
1567
+ def fit_sine(t, y, yerr=None, period='gls', fix_period=False):
1568
+ """ Fit a sine curve of the form y = A * sin(2π * t / P + φ) + c
1569
+
1570
+ Args:
1571
+ t (ndarray):
1572
+ Time array
1573
+ y (ndarray):
1574
+ Array of observed values
1575
+ yerr (ndarray, optional):
1576
+ Array of uncertainties. Defaults to None.
1577
+ period (str or float, optional):
1578
+ Initial guess for period or 'gls' to get it from Lomb-Scargle
1579
+ periodogram. Defaults to 'gls'.
1580
+ fix_period (bool, optional):
1581
+ Whether to fix the period. Defaults to False.
1582
+
1583
+ Returns:
1584
+ p (ndarray):
1585
+ Best-fit parameters [A, P, φ, c] or [A, φ, c]
1586
+ f (callable):
1587
+ Function that returns the best-fit sine curve for input times
1588
+ """
1434
1589
  from scipy.optimize import leastsq
1435
1590
  if period == 'gls':
1436
1591
  from astropy.timeseries import LombScargle
1437
1592
  gls = LombScargle(t, y, yerr)
1438
1593
  freq, power = gls.autopower()
1439
1594
  period = 1 / freq[power.argmax()]
1440
-
1441
- if fix_period and period is None:
1442
- logger.error('period is fixed but no value provided')
1443
- return
1444
-
1445
- def sine(t, p):
1446
- return p[0] * np.sin(2 * np.pi * t / p[1] + p[2]) + p[3]
1447
-
1448
- p0 = [y.ptp(), period, 0.0, 0.0]
1449
- xbest, _ = leastsq(lambda p, t, y, ye: (sine(t, p) - y) / ye,
1450
- p0, args=(t, y, yerr))
1451
- return xbest, partial(sine, p=xbest)
1595
+ else:
1596
+ period = float(period)
1597
+
1598
+ if yerr is None:
1599
+ yerr = np.ones_like(y)
1600
+
1601
+ if fix_period:
1602
+ def sine(t, p):
1603
+ return p[0] * np.sin(2 * np.pi * t / period + p[1]) + p[2]
1604
+ f = lambda p, t, y, ye: (sine(t, p) - y) / ye
1605
+ p0 = [y.ptp(), 0.0, 0.0]
1606
+ else:
1607
+ def sine(t, p):
1608
+ return p[0] * np.sin(2 * np.pi * t / p[1] + p[2]) + p[3]
1609
+ f = lambda p, t, y, ye: (sine(t, p) - y) / ye
1610
+ p0 = [y.ptp(), period, 0.0, 0.0]
1611
+
1612
+ xbest, _ = leastsq(f, p0, args=(t, y, yerr))
1613
+ return xbest, partial(sine, p=xbest)
arvi/translations.py CHANGED
@@ -1,3 +1,5 @@
1
+ import re
2
+
1
3
  STARS = {
2
4
  'Barnard': 'GJ699',
3
5
  "Barnard's": 'GJ699',
@@ -5,6 +7,15 @@ STARS = {
5
7
 
6
8
 
7
9
  def translate(star):
10
+ # known translations
8
11
  if star in STARS:
9
12
  return STARS[star]
13
+
14
+ # regex translations
15
+ NGC_match = re.match(r'NGC([\s\d]+)No([\s\d]+)', star)
16
+ if NGC_match:
17
+ cluster = NGC_match.group(1).replace(' ', '')
18
+ target = NGC_match.group(2).replace(' ', '')
19
+ return f'Cl* NGC {cluster} MMU {target}'
20
+
10
21
  return star
arvi/utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import time
2
3
  from contextlib import contextmanager
3
4
  try:
4
5
  from unittest.mock import patch
@@ -19,6 +20,9 @@ except ImportError:
19
20
  tqdm = lambda x, *args, **kwargs: x
20
21
  trange = lambda *args, **kwargs: range(*args, **kwargs)
21
22
 
23
+ from .setup_logger import logger
24
+ from . import config
25
+
22
26
 
23
27
  def create_directory(directory):
24
28
  """ Create a directory if it does not exist """
@@ -61,6 +65,23 @@ def all_logging_disabled():
61
65
  finally:
62
66
  logging.disable(previous_level)
63
67
 
68
+
69
+ @contextmanager
70
+ def timer():
71
+ """ A simple context manager to time a block of code """
72
+ if not config.debug:
73
+ yield
74
+ return
75
+
76
+ logger.debug(f'starting timer')
77
+ start = time.time()
78
+ try:
79
+ yield
80
+ finally:
81
+ end = time.time()
82
+ logger.debug(f'elapsed time: {end - start:.2f} seconds')
83
+
84
+
64
85
  def strtobool(val):
65
86
  """Convert a string representation of truth to true (1) or false (0).
66
87