gwpy 3.0.7__py3-none-any.whl → 3.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gwpy might be problematic. Click here for more details.

Files changed (50) hide show
  1. gwpy/_version.py +2 -2
  2. gwpy/astro/range.py +3 -3
  3. gwpy/astro/tests/test_range.py +4 -5
  4. gwpy/cli/qtransform.py +1 -1
  5. gwpy/detector/tests/test_units.py +3 -0
  6. gwpy/detector/units.py +12 -3
  7. gwpy/frequencyseries/frequencyseries.py +13 -4
  8. gwpy/io/cache.py +24 -1
  9. gwpy/io/datafind.py +1 -0
  10. gwpy/io/ffldatafind.py +27 -16
  11. gwpy/io/tests/test_ffldatafind.py +10 -3
  12. gwpy/plot/gps.py +7 -2
  13. gwpy/plot/tests/test_gps.py +1 -0
  14. gwpy/plot/tests/test_tex.py +3 -0
  15. gwpy/plot/tex.py +8 -6
  16. gwpy/segments/flag.py +14 -3
  17. gwpy/segments/tests/test_flag.py +53 -0
  18. gwpy/signal/filter_design.py +4 -3
  19. gwpy/signal/spectral/_lal.py +1 -1
  20. gwpy/signal/tests/test_coherence.py +9 -9
  21. gwpy/signal/tests/test_filter_design.py +3 -3
  22. gwpy/spectrogram/tests/test_spectrogram.py +2 -2
  23. gwpy/testing/fixtures.py +2 -19
  24. gwpy/testing/utils.py +43 -19
  25. gwpy/time/_tconvert.py +19 -33
  26. gwpy/time/tests/test_time.py +4 -45
  27. gwpy/timeseries/core.py +4 -2
  28. gwpy/timeseries/io/cache.py +47 -29
  29. gwpy/timeseries/io/gwf/framecpp.py +11 -2
  30. gwpy/timeseries/io/gwf/lalframe.py +21 -0
  31. gwpy/timeseries/io/losc.py +8 -2
  32. gwpy/timeseries/tests/test_io_cache.py +74 -0
  33. gwpy/timeseries/tests/test_io_gwf_lalframe.py +17 -0
  34. gwpy/timeseries/tests/test_timeseries.py +78 -36
  35. gwpy/types/array.py +16 -3
  36. gwpy/types/index.py +7 -5
  37. gwpy/types/sliceutils.py +5 -1
  38. gwpy/types/tests/test_array2d.py +4 -0
  39. gwpy/types/tests/test_series.py +26 -0
  40. gwpy/utils/shell.py +2 -2
  41. gwpy/utils/sphinx/zenodo.py +118 -61
  42. gwpy/utils/tests/test_shell.py +2 -2
  43. gwpy/utils/tests/test_sphinx_zenodo.py +175 -0
  44. {gwpy-3.0.7.dist-info → gwpy-3.0.9.dist-info}/METADATA +8 -7
  45. {gwpy-3.0.7.dist-info → gwpy-3.0.9.dist-info}/RECORD +49 -48
  46. {gwpy-3.0.7.dist-info → gwpy-3.0.9.dist-info}/WHEEL +1 -1
  47. gwpy/utils/sphinx/epydoc.py +0 -104
  48. {gwpy-3.0.7.dist-info → gwpy-3.0.9.dist-info}/LICENSE +0 -0
  49. {gwpy-3.0.7.dist-info → gwpy-3.0.9.dist-info}/entry_points.txt +0 -0
  50. {gwpy-3.0.7.dist-info → gwpy-3.0.9.dist-info}/top_level.txt +0 -0
@@ -91,6 +91,7 @@ NDS2_GW150914_CHANNEL = "L1:DCS-CALIB_STRAIN_C02"
91
91
  GWOSC_GW150914_FRAMETYPE = "L1_LOSC_16_V1"
92
92
  GWOSC_GW150914 = 1126259462
93
93
  GWOSC_GW150914_SEGMENT = Segment(GWOSC_GW150914-2, GWOSC_GW150914+2)
94
+ GWOSC_GW150914_SEGMENT_32 = Segment(GWOSC_GW150914-16, GWOSC_GW150914+16)
94
95
  GWOSC_GW150914_DQ_BITS = {
95
96
  'hdf5': [
96
97
  'data present',
@@ -153,6 +154,22 @@ class TestTimeSeries(_TestTimeSeriesBase):
153
154
  sample_rate=16384,
154
155
  )
155
156
 
157
+ @pytest.fixture(scope="class")
158
+ @pytest_skip_network_error
159
+ def gw150914_h1_32(self):
160
+ return self.TEST_CLASS.fetch_open_data(
161
+ "H1",
162
+ *GWOSC_GW150914_SEGMENT_32,
163
+ )
164
+
165
+ @pytest.fixture(scope="class")
166
+ @pytest_skip_network_error
167
+ def gw150914_l1_32(self):
168
+ return self.TEST_CLASS.fetch_open_data(
169
+ "L1",
170
+ *GWOSC_GW150914_SEGMENT_32,
171
+ )
172
+
156
173
  # -- test class functionality ---------------
157
174
 
158
175
  def test_ligotimegps(self):
@@ -236,28 +253,29 @@ class TestTimeSeries(_TestTimeSeriesBase):
236
253
  exclude=['channel'])
237
254
 
238
255
  @pytest.mark.parametrize('api', GWF_APIS)
239
- def test_read_write_gwf_gps_errors(self, tmp_path, api):
256
+ def test_read_gwf_end_error(self, api):
257
+ """Test that reading past the end of available data fails.
258
+ """
240
259
  fmt = "gwf" if api is None else "gwf." + api
241
- array = self.create(name='TEST')
242
- tmp = tmp_path / "test.gwf"
243
- array.write(tmp, format=fmt)
244
-
245
- # check that reading past the end of the array fails
246
- with pytest.raises((ValueError, RuntimeError)):
260
+ with pytest.raises(ValueError):
247
261
  self.TEST_CLASS.read(
248
- tmp,
249
- array.name,
262
+ utils.TEST_GWF_FILE,
263
+ "L1:LDAS-STRAIN",
250
264
  format=fmt,
251
- start=array.span[1],
265
+ start=utils.TEST_GWF_SPAN[1],
252
266
  )
253
267
 
254
- # check that reading before the start of the array also fails
255
- with pytest.raises((ValueError, RuntimeError)):
268
+ @pytest.mark.parametrize('api', GWF_APIS)
269
+ def test_read_gwf_negative_duration_error(self, api):
270
+ """Test that reading a negative duration fails.
271
+ """
272
+ fmt = "gwf" if api is None else "gwf." + api
273
+ with pytest.raises(ValueError):
256
274
  self.TEST_CLASS.read(
257
- tmp,
258
- array.name,
275
+ utils.TEST_GWF_FILE,
276
+ "L1:LDAS-STRAIN",
259
277
  format=fmt,
260
- end=array.span[0]-1,
278
+ end=utils.TEST_GWF_SPAN[0]-1,
261
279
  )
262
280
 
263
281
  @pytest.mark.parametrize('api', GWF_APIS)
@@ -784,6 +802,22 @@ class TestTimeSeries(_TestTimeSeriesBase):
784
802
  assert fs.size == 129
785
803
  assert fs.dx == gw150914.sample_rate / 256
786
804
 
805
+ @pytest.mark.parametrize("data", [
806
+ [1., 0., -1., 0.],
807
+ [1., 2., 3., 2., 1., 0.],
808
+ numpy.arange(10),
809
+ numpy.random.random(100),
810
+ ])
811
+ def test_fft_ifft(self, data):
812
+ a = self.TEST_CLASS(data)
813
+ utils.assert_quantity_sub_equal(
814
+ a,
815
+ a.fft().ifft(),
816
+ almost_equal=True,
817
+ rtol=1e-7,
818
+ atol=1e-10,
819
+ )
820
+
787
821
  def test_average_fft(self, gw150914):
788
822
  # test all defaults
789
823
  fs = gw150914.average_fft()
@@ -1040,18 +1074,22 @@ class TestTimeSeries(_TestTimeSeriesBase):
1040
1074
  fgram = gw150914.fftgram(1)
1041
1075
  fs = int(gw150914.sample_rate.value)
1042
1076
  f, t, sxx = signal.spectrogram(
1043
- gw150914, fs,
1077
+ gw150914,
1078
+ fs,
1044
1079
  window='hann',
1045
1080
  nperseg=fs,
1046
1081
  mode='complex',
1047
1082
  )
1048
1083
  utils.assert_array_equal(gw150914.t0.value + t, fgram.xindex.value)
1049
1084
  utils.assert_array_equal(f, fgram.yindex.value)
1050
- utils.assert_array_equal(sxx.T, fgram)
1085
+ utils.assert_array_equal(sxx.T, fgram.value)
1051
1086
 
1087
+ def test_fftgram_overlap(self, gw150914):
1052
1088
  fgram = gw150914.fftgram(1, overlap=0.5)
1089
+ fs = int(gw150914.sample_rate.value)
1053
1090
  f, t, sxx = signal.spectrogram(
1054
- gw150914, fs,
1091
+ gw150914,
1092
+ fs,
1055
1093
  window='hann',
1056
1094
  nperseg=fs,
1057
1095
  noverlap=fs//2,
@@ -1059,7 +1097,7 @@ class TestTimeSeries(_TestTimeSeriesBase):
1059
1097
  )
1060
1098
  utils.assert_array_equal(gw150914.t0.value + t, fgram.xindex.value)
1061
1099
  utils.assert_array_equal(f, fgram.yindex.value)
1062
- utils.assert_array_equal(sxx.T, fgram)
1100
+ utils.assert_array_equal(sxx.T, fgram.value)
1063
1101
 
1064
1102
  def test_spectral_variance(self, gw150914):
1065
1103
  variance = gw150914.spectral_variance(.5, method="median")
@@ -1348,7 +1386,9 @@ class TestTimeSeries(_TestTimeSeriesBase):
1348
1386
 
1349
1387
  def test_convolve(self):
1350
1388
  data = self.TEST_CLASS(
1351
- signal.hann(1024), sample_rate=512, epoch=-1
1389
+ signal.get_window("hann", 1024),
1390
+ sample_rate=512,
1391
+ epoch=-1,
1352
1392
  )
1353
1393
  filt = numpy.array([1, 0])
1354
1394
 
@@ -1498,28 +1538,30 @@ class TestTimeSeries(_TestTimeSeriesBase):
1498
1538
  assert comp.name == '%s >= 2.0' % (array.name)
1499
1539
  assert (array == array).name == '{0} == {0}'.format(array.name)
1500
1540
 
1501
- @pytest_skip_network_error
1502
- def test_transfer_function(self):
1503
- tsh = TimeSeries.fetch_open_data('H1', 1126259446, 1126259478)
1504
- tsl = TimeSeries.fetch_open_data('L1', 1126259446, 1126259478)
1505
- tf = tsh.transfer_function(tsl, fftlength=1.0, overlap=0.5)
1541
+ def test_transfer_function(self, gw150914_h1_32, gw150914_l1_32):
1542
+ tf = gw150914_h1_32.transfer_function(
1543
+ gw150914_l1_32,
1544
+ fftlength=1.0,
1545
+ overlap=0.5,
1546
+ )
1506
1547
  assert tf.df == 1 * units.Hz
1507
1548
  assert tf.frequencies[abs(tf).argmax()] == 516 * units.Hz
1508
1549
 
1509
- @pytest_skip_network_error
1510
- def test_coherence(self):
1511
- tsh = TimeSeries.fetch_open_data('H1', 1126259446, 1126259478)
1512
- tsl = TimeSeries.fetch_open_data('L1', 1126259446, 1126259478)
1513
- coh = tsh.coherence(tsl, fftlength=1.0)
1550
+ def test_coherence(self, gw150914_h1_32, gw150914_l1_32):
1551
+ coh = gw150914_h1_32.coherence(
1552
+ gw150914_l1_32,
1553
+ fftlength=1.0,
1554
+ )
1514
1555
  assert coh.df == 1 * units.Hz
1515
1556
  assert coh.frequencies[coh.argmax()] == 60 * units.Hz
1516
1557
 
1517
- @pytest_skip_network_error
1518
- def test_coherence_spectrogram(self):
1519
- tsh = TimeSeries.fetch_open_data('H1', 1126259446, 1126259478)
1520
- tsl = TimeSeries.fetch_open_data('L1', 1126259446, 1126259478)
1521
- cohsg = tsh.coherence_spectrogram(tsl, 4, fftlength=1.0)
1522
- assert cohsg.t0 == tsh.t0
1558
+ def test_coherence_spectrogram(self, gw150914_h1_32, gw150914_l1_32):
1559
+ cohsg = gw150914_h1_32.coherence_spectrogram(
1560
+ gw150914_l1_32,
1561
+ 4,
1562
+ fftlength=1.0,
1563
+ )
1564
+ assert cohsg.t0 == gw150914_h1_32.t0
1523
1565
  assert cohsg.dt == 4 * units.second
1524
1566
  assert cohsg.df == 1 * units.Hz
1525
1567
  tmax, fmax = numpy.unravel_index(cohsg.argmax(), cohsg.shape)
gwpy/types/array.py CHANGED
@@ -34,6 +34,11 @@ from math import modf
34
34
  import numpy
35
35
 
36
36
  from astropy.units import Quantity
37
+ try:
38
+ from astropy.utils.compat.numpycompat import COPY_IF_NEEDED
39
+ except ImportError: # astropy < 6.1
40
+ from astropy.utils import minversion
41
+ COPY_IF_NEEDED = None if minversion(numpy, "2.0.0.dev") else False
37
42
 
38
43
  from ..detector import Channel
39
44
  from ..detector.units import parse_unit
@@ -119,8 +124,16 @@ class Array(Quantity):
119
124
  unit = parse_unit(unit, parse_strict='warn')
120
125
 
121
126
  # create new array
122
- new = super().__new__(cls, value, unit=unit, dtype=dtype, copy=False,
123
- order=order, subok=subok, ndmin=ndmin)
127
+ new = super().__new__(
128
+ cls,
129
+ value,
130
+ unit=unit,
131
+ dtype=dtype,
132
+ copy=COPY_IF_NEEDED,
133
+ order=order,
134
+ subok=subok,
135
+ ndmin=ndmin,
136
+ )
124
137
 
125
138
  # explicitly copy here to get ownership of the data,
126
139
  # see (astropy/astropy#7244)
@@ -397,7 +410,7 @@ class Array(Quantity):
397
410
  out = super().__array_ufunc__(function, method, *inputs, **kwargs)
398
411
  # if a ufunc returns a scalar, return a Quantity
399
412
  if not out.ndim:
400
- return Quantity(out, copy=False)
413
+ return Quantity(out, copy=COPY_IF_NEEDED)
401
414
  # otherwise return an array
402
415
  return out
403
416
 
gwpy/types/index.py CHANGED
@@ -23,6 +23,8 @@ import numpy
23
23
 
24
24
  from astropy.units import Quantity
25
25
 
26
+ from .array import COPY_IF_NEEDED
27
+
26
28
 
27
29
  class Index(Quantity):
28
30
  """1-D `~astropy.units.Quantity` array for indexing a `Series`
@@ -57,11 +59,11 @@ class Index(Quantity):
57
59
  """
58
60
  if dtype is None:
59
61
  dtype = max(
60
- numpy.array(start, subok=True, copy=False).dtype,
61
- numpy.array(step, subok=True, copy=False).dtype,
62
+ numpy.array(start, subok=True, copy=COPY_IF_NEEDED).dtype,
63
+ numpy.array(step, subok=True, copy=COPY_IF_NEEDED).dtype,
62
64
  )
63
- start = Quantity(start, dtype=dtype, copy=False)
64
- step = Quantity(step, dtype=dtype, copy=False).to(start.unit)
65
+ start = Quantity(start, dtype=dtype, copy=COPY_IF_NEEDED)
66
+ step = Quantity(step, dtype=dtype, copy=COPY_IF_NEEDED).to(start.unit)
65
67
  stop = start + step * num
66
68
  return cls(
67
69
  numpy.arange(
@@ -71,7 +73,7 @@ class Index(Quantity):
71
73
  dtype=dtype,
72
74
  )[:num],
73
75
  unit=start.unit,
74
- copy=False,
76
+ copy=COPY_IF_NEEDED,
75
77
  )
76
78
 
77
79
  @property
gwpy/types/sliceutils.py CHANGED
@@ -107,7 +107,11 @@ def null_slice(slice_):
107
107
  except TypeError:
108
108
  return False
109
109
 
110
- if isinstance(slice_, numpy.ndarray) and numpy.all(slice_):
110
+ if (
111
+ isinstance(slice_, numpy.ndarray)
112
+ and slice_.dtype == bool
113
+ and slice_.all()
114
+ ):
111
115
  return True
112
116
  if isinstance(slice_, slice) and slice_ in (
113
117
  slice(None, None, None), slice(0, None, 1)
@@ -250,3 +250,7 @@ class TestArray2D(_TestSeries):
250
250
  @pytest.mark.skip("not implemented for >1D arrays")
251
251
  def test_pad_asymmetric(self):
252
252
  return NotImplemented
253
+
254
+ @pytest.mark.skip("not applicable for >1D arrays")
255
+ def test_single_getitem_not_created(self):
256
+ return NotImplemented
@@ -163,6 +163,32 @@ class TestSeries(_TestArray):
163
163
  name=array.name, epoch=array.epoch, unit=array.unit),
164
164
  )
165
165
 
166
+ def test_getitem_index(self, array):
167
+ """Test that __getitem__ also applies to an xindex.
168
+
169
+ When subsetting a Series with an iterable of integer indices,
170
+ make sure that the xindex, if it exists, is also subsetted. Tests
171
+ regression against https://github.com/gwpy/gwpy/issues/1680.
172
+ """
173
+ array.xindex # create xindex
174
+ indices = numpy.array([0, 1, len(array)-1])
175
+ newarray = array[indices]
176
+
177
+ assert len(newarray) == 3
178
+ assert len(newarray) == len(newarray.value)
179
+ assert len(newarray.value) == len(newarray.xindex)
180
+
181
+ def test_single_getitem_not_created(self, array):
182
+ """Test that array[i] does not return an object with a new _xindex."""
183
+
184
+ # check that there is no xindex when a single value is accessed
185
+ with pytest.raises(AttributeError):
186
+ array[0].xindex
187
+
188
+ # we don't need this, we don't want it accidentally injected
189
+ with pytest.raises(AttributeError):
190
+ array[0]._xindex
191
+
166
192
  def test_empty_slice(self, array):
167
193
  """Check that we can slice a `Series` into nothing
168
194
 
gwpy/utils/shell.py CHANGED
@@ -20,7 +20,6 @@
20
20
  """
21
21
 
22
22
  import warnings
23
- from distutils.spawn import find_executable
24
23
  from subprocess import (Popen, PIPE, CalledProcessError)
25
24
 
26
25
  from .decorators import deprecated_function
@@ -52,7 +51,8 @@ def which(program):
52
51
  ValueError
53
52
  if not executable program is found
54
53
  """
55
- exe = find_executable(program)
54
+ from shutil import which
55
+ exe = which(program)
56
56
  if exe is None:
57
57
  raise ValueError("No executable '%s' found in PATH" % program)
58
58
  return exe
@@ -1,5 +1,5 @@
1
1
  # -*- coding: utf-8 -*-
2
- # Copyright (C) Duncan Macleod (2018-2020)
2
+ # Copyright (C) Cardiff University (2018-2023)
3
3
  #
4
4
  # This file is part of GWpy.
5
5
  #
@@ -17,89 +17,146 @@
17
17
  # along with GWpy. If not, see <http://www.gnu.org/licenses/>.
18
18
 
19
19
  import argparse
20
- import requests
21
20
  import sys
22
21
 
22
+ import requests
23
23
 
24
- def parse_command_line():
25
- parser = argparse.ArgumentParser(
26
- description=__doc__,
27
- formatter_class=argparse.ArgumentDefaultsHelpFormatter)
28
- parser.add_argument('id', type=int, help='Zenodo ID for package')
29
- parser.add_argument('-u', '--url', default='https://zenodo.org/',
30
- help='%(metavar)s to query')
31
- parser.add_argument('-n', '--hits', default=10, type=int,
32
- help='number of versions to display')
33
- parser.add_argument('-p', '--tag-prefix', default='v',
34
- help='prefix for git version tags')
35
- parser.add_argument('-o', '--output-file', help='output file path')
36
-
37
- if len(sys.argv) == 1: # print --help message for no arguments
38
- parser.print_help()
39
- sys.exit(0)
40
-
41
- return parser.parse_args()
24
+ DEFAULT_ZENODO_URL = "https://zenodo.org"
25
+ DEFAULT_HITS = 10
42
26
 
43
27
 
44
- def format_citations(zid, url='https://zenodo.org/', hits=10, tag_prefix='v'):
45
- """Query and format a citations page from Zenodo entries
28
+ def format_citations(
29
+ zid,
30
+ url=DEFAULT_ZENODO_URL,
31
+ hits=10,
32
+ tag_prefix="v",
33
+ ):
34
+ """Query and format a citations page from Zenodo entries.
46
35
 
47
36
  Parameters
48
37
  ----------
49
38
  zid : `int`, `str`
50
- the Zenodo ID of the target record
39
+ The Zenodo ID (``conceptrecid``) of the parent target record.
51
40
 
52
41
  url : `str`, optional
53
- the base URL of the Zenodo host, defaults to ``https://zenodo.org``
42
+ The base URL of the Zenodo host, defaults to ``https://zenodo.org``.
54
43
 
55
- hist : `int`, optional
56
- the maximum number of hits to show, default: ``10``
44
+ hits : `int`, optional
45
+ The maximum number of results to show, default: ``10``.
57
46
 
58
47
  tag_prefix : `str`, optional
59
- the prefix for git tags. This is removed to generate the section
60
- headers in the output RST
48
+ The prefix for git tags. This is removed to generate the section
49
+ headers in the output RST.
61
50
 
62
51
  Returns
63
52
  -------
64
53
  rst : `str`
65
- an RST-formatted string of DOI badges with URLs
54
+ An RST-formatted string of DOI badges with URLs.
66
55
  """
67
56
  # query for metadata
68
- url = ('{url}/api/records/?'
69
- 'page=1&'
70
- 'size={hits}&'
71
- 'q=conceptrecid:"{id}"&'
72
- 'sort=-version&'
73
- 'all_versions=True'.format(id=zid, url=url, hits=hits))
74
- resp = requests.get(url) # make the request
57
+ apiurl = f"{url.rstrip('/')}/api/records"
58
+ params = {
59
+ "q": f"conceptrecid:{zid}",
60
+ "allversions": True,
61
+ "sort": "version",
62
+ "page": 1,
63
+ "size": int(hits),
64
+ }
65
+ resp = requests.get(apiurl, params) # make the request
75
66
  resp.raise_for_status() # make sure it worked
76
- metadata = resp.json() # parse the response
67
+ records = resp.json() # parse the response
77
68
 
78
69
  lines = []
79
- for i, hit in enumerate(metadata['hits']['hits']):
80
- version = hit['metadata']['version'][len(tag_prefix):]
81
- lines.append('-' * len(version))
82
- lines.append(version)
83
- lines.append('-' * len(version))
84
- lines.append('')
85
- lines.append('.. image:: {badge}\n'
86
- ' :target: {doi}'.format(**hit['links']))
87
- if i < hits - 1:
88
- lines.append('')
89
-
90
- return '\n'.join(lines)
91
-
92
-
93
- if __name__ == '__main__':
94
- args = parse_command_line()
95
-
96
- if args.output_file:
97
- f = open(args.output_file, 'w')
98
- else:
70
+ for i, rec in enumerate(records["hits"]["hits"]):
71
+ # print RST-format header
72
+ version = str(rec['metadata']['version'])[len(tag_prefix):]
73
+ head = "-" * len(version)
74
+ lines.extend([
75
+ head,
76
+ version,
77
+ head,
78
+ "",
79
+ ])
80
+
81
+ # add DOI badge
82
+ badge = f"{url}/badge/doi/{rec['doi']}.svg"
83
+ lines.extend([
84
+ f".. image:: {badge}",
85
+ f" :alt: {rec['title']} Zenodo DOI badge",
86
+ f" :target: {rec['doi_url']}",
87
+ ])
88
+
89
+ # add break before next record
90
+ lines.append("")
91
+
92
+ return '\n'.join(lines).strip()
93
+
94
+
95
+ # -- command-line usage ---------------
96
+
97
+ def create_parser():
98
+ """Create an `argparse.ArgumentParser` for this tool.
99
+ """
100
+ parser = argparse.ArgumentParser(
101
+ description=__doc__,
102
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
103
+ )
104
+ parser.add_argument(
105
+ "id",
106
+ type=int,
107
+ help="Zenodo concept ID for package",
108
+ )
109
+ parser.add_argument(
110
+ "-u",
111
+ "--url",
112
+ default=DEFAULT_ZENODO_URL,
113
+ help="Base URL of API to query",
114
+ )
115
+ parser.add_argument(
116
+ "-n",
117
+ "--hits",
118
+ default=DEFAULT_HITS,
119
+ type=int,
120
+ help="Number of versions to display",
121
+ )
122
+ parser.add_argument(
123
+ "-p",
124
+ "--tag-prefix",
125
+ default="v",
126
+ help="Prefix for version tags",
127
+ )
128
+ parser.add_argument(
129
+ "-o",
130
+ "--output-file",
131
+ default="stdout",
132
+ help="Output file path",
133
+ )
134
+ return parser
135
+
136
+
137
+ def main(args=None):
138
+ """Run this tool as a command-line script.
139
+ """
140
+ # parse arguments
141
+ parser = create_parser()
142
+ opts = parser.parse_args(args=args)
143
+
144
+ # generate RST
145
+ citing = format_citations(
146
+ opts.id,
147
+ url=opts.url,
148
+ hits=opts.hits,
149
+ tag_prefix=opts.tag_prefix,
150
+ )
151
+
152
+ # print
153
+ if opts.output_file in {None, "stdout"}:
99
154
  f = sys.stdout
100
-
101
- citing = format_citations(args.id, url=args.url, hits=args.hits,
102
- tag_prefix=args.tag_prefix)
103
-
155
+ else:
156
+ f = open(opts.output_file, 'w')
104
157
  with f:
105
158
  print(citing, file=f)
159
+
160
+
161
+ if __name__ == '__main__':
162
+ sys.exit(main())
@@ -20,9 +20,9 @@
20
20
  """
21
21
 
22
22
  import platform
23
+ import shutil
23
24
  import subprocess
24
25
  import sys
25
- from distutils.spawn import find_executable
26
26
 
27
27
  import pytest
28
28
 
@@ -72,6 +72,6 @@ def test_shell_call_error_warn():
72
72
 
73
73
  def test_which():
74
74
  with pytest.warns(DeprecationWarning):
75
- assert shell.which('python') == find_executable('python')
75
+ assert shell.which('python') == shutil.which('python')
76
76
  with pytest.raises(ValueError), pytest.warns(DeprecationWarning):
77
77
  shell.which('gwpy-no-executable')