ml4gw 0.7.5__tar.gz → 0.7.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

Files changed (121) hide show
  1. {ml4gw-0.7.5 → ml4gw-0.7.7}/PKG-INFO +6 -5
  2. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/augmentations.py +4 -4
  3. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/chunked_dataset.py +3 -3
  4. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/hdf5_dataset.py +7 -10
  5. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/in_memory_dataset.py +21 -21
  6. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/distributions.py +20 -18
  7. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/gw.py +60 -53
  8. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/base.py +9 -9
  9. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/convolutional.py +4 -4
  10. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/resnet/resnet_1d.py +13 -13
  11. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/resnet/resnet_2d.py +12 -12
  12. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/streaming/online_average.py +1 -1
  13. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/streaming/snapshotter.py +14 -14
  14. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/spectral.py +48 -48
  15. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/__init__.py +1 -1
  16. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/iirfilter.py +3 -3
  17. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/pearson.py +7 -8
  18. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/qtransform.py +29 -34
  19. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/scaler.py +4 -4
  20. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/spectral.py +10 -10
  21. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/spectrogram.py +12 -11
  22. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/spline_interpolation.py +310 -146
  23. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/transform.py +1 -1
  24. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/whitening.py +36 -36
  25. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/utils/slicing.py +40 -40
  26. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/phenom_d.py +22 -66
  27. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/phenom_p.py +9 -5
  28. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/taylorf2.py +8 -7
  29. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/conversion.py +2 -1
  30. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/generator.py +33 -32
  31. ml4gw-0.7.7/ml4gw.egg-info/PKG-INFO +57 -0
  32. ml4gw-0.7.7/ml4gw.egg-info/SOURCES.txt +63 -0
  33. ml4gw-0.7.7/ml4gw.egg-info/dependency_links.txt +1 -0
  34. ml4gw-0.7.7/ml4gw.egg-info/requires.txt +5 -0
  35. ml4gw-0.7.7/ml4gw.egg-info/top_level.txt +1 -0
  36. {ml4gw-0.7.5 → ml4gw-0.7.7}/pyproject.toml +12 -3
  37. ml4gw-0.7.7/setup.cfg +4 -0
  38. {ml4gw-0.7.5 → ml4gw-0.7.7}/tests/test_distributions.py +43 -3
  39. {ml4gw-0.7.5 → ml4gw-0.7.7}/tests/test_gw.py +5 -0
  40. {ml4gw-0.7.5 → ml4gw-0.7.7}/tests/test_spectral.py +7 -1
  41. ml4gw-0.7.5/.coverage +0 -0
  42. ml4gw-0.7.5/.gitattributes +0 -2
  43. ml4gw-0.7.5/.github/workflows/coverage.yaml +0 -31
  44. ml4gw-0.7.5/.github/workflows/docs.yaml +0 -29
  45. ml4gw-0.7.5/.github/workflows/pre-commit.yaml +0 -17
  46. ml4gw-0.7.5/.github/workflows/publish.yaml +0 -27
  47. ml4gw-0.7.5/.github/workflows/unit-tests.yaml +0 -80
  48. ml4gw-0.7.5/.gitignore +0 -3
  49. ml4gw-0.7.5/.pre-commit-config.yaml +0 -23
  50. ml4gw-0.7.5/.readthedocs.yaml +0 -36
  51. ml4gw-0.7.5/CITATION.cff +0 -37
  52. ml4gw-0.7.5/docs/Makefile +0 -20
  53. ml4gw-0.7.5/docs/conf.py +0 -65
  54. ml4gw-0.7.5/docs/index.rst +0 -50
  55. ml4gw-0.7.5/docs/installation.rst +0 -19
  56. ml4gw-0.7.5/docs/make.bat +0 -35
  57. ml4gw-0.7.5/docs/ml4gw.dataloading.rst +0 -37
  58. ml4gw-0.7.5/docs/ml4gw.nn.autoencoder.rst +0 -45
  59. ml4gw-0.7.5/docs/ml4gw.nn.resnet.rst +0 -29
  60. ml4gw-0.7.5/docs/ml4gw.nn.rst +0 -31
  61. ml4gw-0.7.5/docs/ml4gw.nn.streaming.rst +0 -29
  62. ml4gw-0.7.5/docs/ml4gw.rst +0 -64
  63. ml4gw-0.7.5/docs/ml4gw.transforms.rst +0 -77
  64. ml4gw-0.7.5/docs/ml4gw.waveforms.rst +0 -53
  65. ml4gw-0.7.5/docs/modules.rst +0 -7
  66. ml4gw-0.7.5/docs/requirements.txt +0 -3
  67. ml4gw-0.7.5/examples/README.md +0 -12
  68. ml4gw-0.7.5/examples/ml4gw_tutorial.ipynb +0 -1757
  69. ml4gw-0.7.5/examples/pyproject.toml +0 -22
  70. ml4gw-0.7.5/examples/uv.lock +0 -2960
  71. ml4gw-0.7.5/tests/conftest.py +0 -200
  72. ml4gw-0.7.5/tests/dataloading/test_chunked_dataset.py +0 -82
  73. ml4gw-0.7.5/tests/dataloading/test_hdf5_dataset.py +0 -188
  74. ml4gw-0.7.5/tests/dataloading/test_in_memory_dataset.py +0 -357
  75. ml4gw-0.7.5/tests/nn/resnet/test_resnet_1d.py +0 -137
  76. ml4gw-0.7.5/tests/nn/resnet/test_resnet_2d.py +0 -138
  77. ml4gw-0.7.5/tests/nn/streaming/test_online_average.py +0 -88
  78. ml4gw-0.7.5/tests/nn/streaming/test_snapshotter.py +0 -120
  79. ml4gw-0.7.5/tests/nn/test_norm.py +0 -75
  80. ml4gw-0.7.5/tests/transforms/test_iirfilter.py +0 -321
  81. ml4gw-0.7.5/tests/transforms/test_pearson.py +0 -81
  82. ml4gw-0.7.5/tests/transforms/test_qtransform.py +0 -184
  83. ml4gw-0.7.5/tests/transforms/test_scaler.py +0 -123
  84. ml4gw-0.7.5/tests/transforms/test_snr_rescaler.py +0 -86
  85. ml4gw-0.7.5/tests/transforms/test_spectral_transform.py +0 -290
  86. ml4gw-0.7.5/tests/transforms/test_spectrogram.py +0 -109
  87. ml4gw-0.7.5/tests/transforms/test_spline_interpolation.py +0 -101
  88. ml4gw-0.7.5/tests/transforms/test_waveforms.py +0 -101
  89. ml4gw-0.7.5/tests/transforms/test_whitening.py +0 -191
  90. ml4gw-0.7.5/tests/utils/test_slicing.py +0 -334
  91. ml4gw-0.7.5/tests/waveforms/adhoc/test_sine_gaussian.py +0 -100
  92. ml4gw-0.7.5/tests/waveforms/cbc/test_cbc_waveforms.py +0 -480
  93. ml4gw-0.7.5/tests/waveforms/cbc/test_utils.py +0 -115
  94. ml4gw-0.7.5/tests/waveforms/test_conversion.py +0 -65
  95. ml4gw-0.7.5/tests/waveforms/test_generator.py +0 -216
  96. ml4gw-0.7.5/uv.lock +0 -3344
  97. {ml4gw-0.7.5 → ml4gw-0.7.7}/LICENSE +0 -0
  98. {ml4gw-0.7.5 → ml4gw-0.7.7}/README.md +0 -0
  99. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/__init__.py +0 -0
  100. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/constants.py +0 -0
  101. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/__init__.py +0 -0
  102. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/__init__.py +0 -0
  103. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/__init__.py +0 -0
  104. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/skip_connection.py +0 -0
  105. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/utils.py +0 -0
  106. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/norm.py +0 -0
  107. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/resnet/__init__.py +0 -0
  108. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/streaming/__init__.py +0 -0
  109. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/snr_rescaler.py +0 -0
  110. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/waveforms.py +0 -0
  111. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/types.py +0 -0
  112. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/utils/interferometer.py +0 -0
  113. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/__init__.py +0 -0
  114. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/adhoc/__init__.py +0 -0
  115. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/adhoc/ringdown.py +0 -0
  116. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/adhoc/sine_gaussian.py +0 -0
  117. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/__init__.py +0 -0
  118. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/coefficients.py +0 -0
  119. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/phenom_d_data.py +0 -0
  120. {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/utils.py +0 -0
  121. {ml4gw-0.7.5 → ml4gw-0.7.7}/tests/test_augmentations.py +0 -0
@@ -1,9 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ml4gw
3
- Version: 0.7.5
3
+ Version: 0.7.7
4
4
  Summary: Tools for training torch models on gravitational wave data
5
5
  Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>
6
- License-File: LICENSE
7
6
  Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
8
7
  Classifier: Programming Language :: Python :: 3.9
9
8
  Classifier: Programming Language :: Python :: 3.10
@@ -11,12 +10,14 @@ Classifier: Programming Language :: Python :: 3.11
11
10
  Classifier: Programming Language :: Python :: 3.12
12
11
  Classifier: Programming Language :: Python :: 3.13
13
12
  Requires-Python: <3.13,>=3.9
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
14
15
  Requires-Dist: jaxtyping<0.3,>=0.2
16
+ Requires-Dist: torch~=2.0
17
+ Requires-Dist: torchaudio~=2.0
15
18
  Requires-Dist: numpy<2.0.0
16
19
  Requires-Dist: scipy<1.15,>=1.9.0
17
- Requires-Dist: torchaudio~=2.0
18
- Requires-Dist: torch~=2.0
19
- Description-Content-Type: text/markdown
20
+ Dynamic: license-file
20
21
 
21
22
  # ML4GW
22
23
  ![PyPI - Version](https://img.shields.io/pypi/v/ml4gw)
@@ -6,8 +6,8 @@ from torch import Tensor
6
6
  class SignalInverter(torch.nn.Module):
7
7
  """
8
8
  Takes a tensor of timeseries of arbitrary dimension
9
- and randomly inverts (i.e. h(t) -> -h(t))
10
- each timeseries with probability `prob`.
9
+ and randomly inverts i.e. :math:`h(t) \\rightarrow -h(t)`
10
+ each timeseries with probability ``prob``.
11
11
 
12
12
  Args:
13
13
  prob:
@@ -29,8 +29,8 @@ class SignalInverter(torch.nn.Module):
29
29
  class SignalReverser(torch.nn.Module):
30
30
  """
31
31
  Takes a tensor of timeseries of arbitrary dimension
32
- and randomly reverses (i.e. h(t) -> h(-t))
33
- each timeseries with probability `prob`.
32
+ and randomly reverses i.e., :math:`h(t) \\rightarrow h(-t)`.
33
+ each timeseries with probability ``prob``.
34
34
 
35
35
  Args:
36
36
  prob:
@@ -15,9 +15,9 @@ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
15
15
  chunk_it:
16
16
  Iterator which will produce chunks of timeseries
17
17
  data to sample windows from. Should have shape
18
- `(N, C, T)`, where `N` is the number of chunks
19
- to sample from, `C` is the number of channels,
20
- and `T` is the number of samples along the
18
+ ``(N, C, T)``, where ``N`` is the number of chunks
19
+ to sample from, ``C`` is the number of channels,
20
+ and ``T`` is the number of samples along the
21
21
  time dimension for each chunk.
22
22
  kernel_size:
23
23
  Size of windows to be sampled from each chunk.
@@ -17,8 +17,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
17
17
  Iterable dataset that samples and loads windows of
18
18
  timeseries data uniformly from a set of HDF5 files.
19
19
  It is _strongly_ recommended that these files have been
20
- written using [chunked storage]
21
- (https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage).
20
+ written using `chunked storage <https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage>`_.
22
21
  This has shown to produce increases in read-time speeds
23
22
  of over an order of magnitude.
24
23
 
@@ -37,27 +36,25 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
37
36
  Number of windows to sample at each iteration.
38
37
  batches_per_epoch:
39
38
  Number of batches to generate during each call
40
- to `__iter__`.
39
+ to ``__iter__``.
41
40
  coincident:
42
41
  Whether windows for each channel in a given batch
43
42
  element should be sampled coincidentally, i.e.
44
43
  corresponding to the same time indices from the
45
44
  same files, or should be sampled independently.
46
45
  For the latter case, users can either specify
47
- `False`, which will sample filenames independently
48
- for each channel, or `"files"`, which will sample
46
+ ``False``, which will sample filenames independently
47
+ for each channel, or ``"files"``, which will sample
49
48
  windows independently within a given file for each
50
49
  channel. The latter setting limits the amount of
51
50
  entropy in the effective dataset, but can provide
52
51
  over 2x improvement in total throughput.
53
52
  num_files_per_batch:
54
53
  The number of unique files from which to sample
55
- batch elements each epoch. If left as `None`,
54
+ batch elements each epoch. If left as ``None``,
56
55
  will use all available files. Useful when reading
57
56
  from many files is bottlenecking dataloading.
58
-
59
-
60
- """
57
+ """ # noqa E501
61
58
 
62
59
  def __init__(
63
60
  self,
@@ -117,7 +114,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
117
114
  return self.batches_per_epoch
118
115
 
119
116
  def sample_fnames(self, size) -> np.ndarray:
120
- # first, randomly select `self.num_files_per_batch`
117
+ # first, randomly select ``self.num_files_per_batch``
121
118
  # file indices based on their probabilities
122
119
  fname_indices = np.arange(len(self.fnames))
123
120
  fname_indices = np.random.choice(
@@ -20,56 +20,56 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
20
20
  Args:
21
21
  X:
22
22
  Timeseries data to be iterated through. Should have
23
- shape `(num_channels, length * sample_rate)`. Windows
23
+ shape ``(num_channels, length * sample_rate)``. Windows
24
24
  will be sampled from the time (1st) dimension for all
25
25
  channels along the channel (0th) dimension.
26
26
  kernel_size:
27
- The length of the windows to sample from `X` in units
27
+ The length of the windows to sample from ``X`` in units
28
28
  of samples.
29
29
  y:
30
30
  Target timeseries to be iterated through. If specified,
31
31
  should be a single channel and have shape
32
- `(length * sample_rate,)`. If left as `None`, only windows
33
- sampled from `X` will be returned during iteration.
32
+ ``(length * sample_rate,)``. If left as ``None``, only windows
33
+ sampled from ``X`` will be returned during iteration.
34
34
  Otherwise, windows sampled from both arrays will be
35
35
  returned. Note that if sampling is performed non-coincidentally,
36
36
  there's no sensible way to align windows sampled from this
37
- array with the windows sampled from `X`, so this combination
37
+ array with the windows sampled from ``X``, so this combination
38
38
  of arguments is not permitted.
39
39
  batch_size:
40
40
  Maximum number of windows to return at each iteration. Will
41
41
  be the length of the 0th dimension of the returned array(s).
42
- If `batches_per_epoch` is specified, this will be the length
43
- of _every_ array returned during iteration. Otherwise, it's
42
+ If ``batches_per_epoch`` is specified, this will be the length
43
+ of **every** array returned during iteration. Otherwise, it's
44
44
  possible that the last array will be shorter due to the number
45
45
  of windows in the timeseries being a non-integer multiple of
46
- `batch_size`.
46
+ ``batch_size``.
47
47
  stride:
48
48
  The resolution at which windows will be sampled from the
49
49
  specified timeseries, in units of samples. E.g. if
50
- `stride=2`, the first sample of each window can only be
51
- from an index of `X` which is a multiple of 2. Obviously,
50
+ ``stride=2``, the first sample of each window can only be
51
+ from an index of ``X`` which is a multiple of 2. Obviously,
52
52
  this reduces the number of windows which can be iterated
53
- through by a factor of `stride`.
53
+ through by a factor of ``stride``.
54
54
  batches_per_epoch:
55
55
  Number of batches of window to produce during iteration
56
- before raising a `StopIteration`. Must be specified if
56
+ before raising a ``StopIteration``. Must be specified if
57
57
  performing non-coincident sampling. Otherwise, if left
58
- as `None`, windows will be sampled until the entire
58
+ as ``None``, windows will be sampled until the entire
59
59
  timeseries has been exhausted. Note that
60
- `batch_size * batches_per_epoch` must be be small
60
+ ``batch_size * batches_per_epoch`` must be be small
61
61
  enough to be able to be fulfilled by the number of
62
- windows in the timeseries, otherise a `ValueError`
62
+ windows in the timeseries, otherise a ``ValueError``
63
63
  will be raised.
64
64
  coincident:
65
- Whether to sample windows from the channels of `X`
65
+ Whether to sample windows from the channels of ``X``
66
66
  using the same indices or independently. Can't be
67
- `True` if `batches_per_epoch` is `None` or `y` is
68
- _not_ `None`.
67
+ ``True`` if ``batches_per_epoch`` is ``None`` or ``y`` is
68
+ **not** ``None``.
69
69
  shuffle:
70
70
  Whether to sample windows from timeseries randomly
71
- or in order along the time axis. If `coincident=False`
72
- and `shuffle=False`, channels will be iterated through
71
+ or in order along the time axis. If ``coincident=False``
72
+ and ``shuffle=False``, channels will be iterated through
73
73
  with the index along the last channel moving fastest.
74
74
  device:
75
75
  Which device to host the timeseries arrays on
@@ -91,7 +91,7 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
91
91
 
92
92
  # make sure if we specified a target array that all other
93
93
  # other necessary conditions are met (it has the same
94
- # length as `X` and we're sampling coincidentally)
94
+ # length as ``X`` and we're sampling coincidentally)
95
95
  if y is not None and y.shape[-1] != X.shape[-1]:
96
96
  raise ValueError(
97
97
  "Target timeseries must have same length as input"
@@ -1,7 +1,7 @@
1
1
  """
2
2
  Module containing callables classes for generating samples
3
3
  from specified distributions. Each callable should map from
4
- an integer `N` to a 1D torch `Tensor` containing `N` samples
4
+ an integer ``N`` to a 1D torch ``Tensor`` containing ``N`` samples
5
5
  from the corresponding distribution.
6
6
  """
7
7
 
@@ -22,8 +22,9 @@ _PLANCK18_OMEGA_M = 0.30966 # Matter density parameter
22
22
  class Cosine(dist.Distribution):
23
23
  """
24
24
  Cosine distribution based on
25
- ``torch.distributions.TransformedDistribution``.
26
- """
25
+ ``torch.distributions.TransformedDistribution``
26
+ (see `documentation <https://docs.pytorch.org/docs/stable/distributions.html#transformeddistribution>`_).
27
+ """ # noqa E501
27
28
 
28
29
  arg_constraints = {}
29
30
 
@@ -117,18 +118,17 @@ class LogNormal(dist.LogNormal):
117
118
  class PowerLaw(dist.TransformedDistribution):
118
119
  """
119
120
  Sample from a power law distribution,
120
- .. math::
121
- p(x) \approx x^{\alpha}.
121
+
122
+ .. math:: p(x) \\approx x^{\\alpha}.
122
123
 
123
124
  Index alpha cannot be 0, since it is equivalent to a Uniform distribution.
124
125
  This could be used, for example, as a universal distribution of
125
126
  signal-to-noise ratios (SNRs) from uniformly volume distributed
126
127
  sources
127
- .. math::
128
128
 
129
- p(\rho) = 3*\rho_0^3 / \rho^4
129
+ .. math:: p(\\rho) = 3\;\\rho_0^3 / \\rho^4
130
130
 
131
- where :math:`\rho_0` is a representative minimum SNR
131
+ where :math:`\\rho_0` is a representative minimum SNR
132
132
  considered for detection. See, for example,
133
133
  `Schutz (2011) <https://arxiv.org/abs/1102.5421>`_.
134
134
  Or, for example, ``index=2`` for uniform in Euclidean volume.
@@ -137,10 +137,10 @@ class PowerLaw(dist.TransformedDistribution):
137
137
  support = dist.constraints.nonnegative
138
138
 
139
139
  def __init__(
140
- self, minimum: float, maximum: float, index: int, validate_args=None
140
+ self, minimum: float, maximum: float, index: float, validate_args=None
141
141
  ):
142
142
  if index == 0:
143
- raise RuntimeError("Index of 0 is the same as Uniform")
143
+ raise ValueError("Index of 0 is the same as Uniform")
144
144
  elif index == -1:
145
145
  base_min = torch.as_tensor(minimum).log()
146
146
  base_max = torch.as_tensor(maximum).log()
@@ -185,14 +185,14 @@ class UniformComovingVolume(dist.Distribution):
185
185
  Sample either redshift, comoving distance, or luminosity distance
186
186
  such that they are uniform in comoving volume, assuming a flat
187
187
  lambda-CDM cosmology. Default H0 and Omega_M values match
188
- astropy.cosmology.Planck18
188
+ `Planck18 parameters in Astropy <https://docs.astropy.org/en/latest/api/astropy.cosmology.realizations.Planck18.html>`_.
189
189
 
190
190
  Args:
191
191
  minimum: Minimum distance in the specified distance type
192
192
  maximum: Maximum distance in the specified distance type
193
193
  distance_type:
194
- Type of distance to sample from. Can be 'redshift',
195
- 'comoving_distance', or 'luminosity_distance'
194
+ Type of distance to sample from. Can be ``redshift``,
195
+ ``comoving_distance``, or ``luminosity_distance``
196
196
  h0: Hubble constant in km/s/Mpc
197
197
  omega_m: Matter density parameter
198
198
  z_max: Maximum redshift for the grid
@@ -347,18 +347,20 @@ class UniformComovingVolume(dist.Distribution):
347
347
 
348
348
  class RateEvolution(UniformComovingVolume):
349
349
  """
350
- Wrapper around `UniformComovingVolume` to allow for
350
+ Wrapper around :meth:`~ml4gw.distributions.UniformComovingVolume` to allow for
351
351
  arbitrary rate evolution functions. E.g., if
352
- `rate_function = 1 / (1 + z)`, then the distribution
352
+ ``rate_function = lambda z: 1 / (1 + z)``, then the distribution
353
353
  will sample values such that they occur uniform in
354
354
  source frame time.
355
355
 
356
356
  Args:
357
357
  rate_function: Callable that takes redshift as input
358
358
  and returns the rate evolution factor.
359
- *args, **kwargs: Arguments passed to `UniformComovingVolume`
360
- constructor.
361
- """
359
+ *args: Arguments passed to
360
+ :meth:`~ml4gw.distributions.UniformComovingVolume` constructor.
361
+ **kwargs: Keyword arguments passed to
362
+ :meth:`~ml4gw.distributions.UniformComovingVolume` constructor.
363
+ """ # noqa E501
362
364
 
363
365
  def __init__(
364
366
  self,
@@ -2,13 +2,11 @@
2
2
  Tools for manipulating raw gravitational waveforms
3
3
  and projecting them onto interferometer responses.
4
4
  Much of the projection code is an extension of the
5
- implementation made available in bilby:
6
-
7
- https://arxiv.org/abs/1811.02042
8
-
9
- Specifically the code here:
10
- https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py
11
- """
5
+ implementation made available in
6
+ `bilby <https://arxiv.org/abs/1811.02042>`_.
7
+ Specifically code from
8
+ `this module <https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py>`_.
9
+ """ # noqa E501
12
10
 
13
11
  from typing import List, Tuple, Union
14
12
 
@@ -134,6 +132,9 @@ def compute_antenna_responses(
134
132
  # shape: batch x num_polarizations x 3 x 3
135
133
  polarization = torch.stack(polarizations, axis=1)
136
134
 
135
+ # Ensure dtype consistency before einsum
136
+ detector_tensors = detector_tensors.to(polarization.dtype)
137
+
137
138
  # compute the weight of each interferometer's response
138
139
  # to each polarization: batch x polarizations x ifos
139
140
  return torch.einsum("...jk,ijk->...i", polarization, detector_tensors)
@@ -194,7 +195,7 @@ def compute_observed_strain(
194
195
  **polarizations: Float[Tensor, "batch time"],
195
196
  ) -> WaveformTensor:
196
197
  """
197
- Compute the strain timeseries $h(t)$ observed by a network
198
+ Compute the strain timeseries :math:`h(t)` observed by a network
198
199
  of interferometers from the given polarization timeseries
199
200
  corresponding to gravitational waveforms from sources with
200
201
  the indicated sky parameters.
@@ -222,13 +223,13 @@ def compute_observed_strain(
222
223
  between the waveform observed at the geocenter and
223
224
  the one observed at the detector site. To avoid
224
225
  adding any delay between the two, reset your coordinates
225
- such that the desired interferometer is at `(0., 0., 0.)`.
226
+ such that the desired interferometer is at ``(0., 0., 0.)``.
226
227
  sample_rate:
227
228
  Rate at which the polarization timeseries have been sampled
228
229
  polarziations:
229
230
  Timeseries for each waveform polarization which
230
231
  contributes to the interferometer response. Allowed
231
- polarizations are `cross`, `plus`, and `breathing`.
232
+ polarizations are ``cross``, ``plus``, and ``breathing``.
232
233
  Returns:
233
234
  Tensor representing the observed strain at each
234
235
  interferometer for each waveform.
@@ -236,13 +237,15 @@ def compute_observed_strain(
236
237
 
237
238
  # TODO: just use theta as the input parameter?
238
239
  # note that ** syntax is ordered, so we're safe
239
- # to be lazy and use `list` for the keys and values
240
+ # to be lazy and use ``list`` for the keys and values
240
241
  theta = torch.pi / 2 - dec
241
242
  antenna_responses = compute_antenna_responses(
242
243
  theta, psi, phi, detector_tensors, list(polarizations)
243
244
  )
244
245
 
245
246
  polarizations = torch.stack(list(polarizations.values()), axis=1)
247
+ # Ensure dtype consistency before einsum
248
+ antenna_responses = antenna_responses.to(polarizations.dtype)
246
249
  waveforms = torch.einsum(
247
250
  "...pi,...pt->...it", antenna_responses, polarizations
248
251
  )
@@ -286,26 +289,28 @@ def compute_ifo_snr(
286
289
  highpass: Union[float, Float[Tensor, " frequency"], None] = None,
287
290
  lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
288
291
  ) -> Float[Tensor, "batch num_ifos"]:
289
- r"""Compute the SNRs of a batch of interferometer responses
292
+ """Compute the SNRs of a batch of interferometer responses
290
293
 
291
294
  Compute the signal to noise ratio (SNR) of individual
292
295
  interferometer responses to gravitational waveforms with
293
296
  respect to a background PSD for each interferometer. The
294
- SNR of the $i$th waveform at the $j$th interferometer
297
+ SNR of the :math:`i` th waveform at the :math:`j` th interferometer
295
298
  is computed as:
296
299
 
297
- $$\rho_{ij} =
298
- 4 \int_{f_{\text{min}}}^{f_{\text{max}}}
299
- \frac{\tilde{h_{ij}}(f)\tilde{h_{ij}}^*(f)}
300
- {S_n^{(j)}(f)}df$$
300
+ .. math::
301
301
 
302
- Where $f_{\text{min}}$ is a minimum frequency denoted
303
- by `highpass`, `f_{\text{max}}` is the maximum frequency
304
- denoted by `lowpass`, which defaults to the Nyquist frequency
305
- dictated by `sample_rate`; `\tilde{h_{ij}}` and `\tilde{h_{ij}}*`
306
- indicate the fourier transform of the $i$th waveform at
307
- the $j$th inteferometer and its complex conjugate, respectively;
308
- and $S_n^{(j)}$ is the backround PSD at the $j$th interferometer.
302
+ \\rho_{ij} =
303
+ 4 \\int_{f_{\\text{min}}}^{f_{\\text{max}}}
304
+ \\frac{\\tilde{h_{ij}}(f)\\tilde{h_{ij}}^*(f)}
305
+ {S_n^{(j)}(f)}df
306
+
307
+ Where :math:`f_{\\text{min}}` is a minimum frequency denoted
308
+ by ``highpass``, :math:`f_{\\text{max}}` is the maximum frequency
309
+ denoted by ``lowpass``, which defaults to the Nyquist frequency
310
+ dictated by ``sample_rate``; :math:`\\tilde{h}_{ij}` and :math:`\\tilde{h}_{ij}^*`
311
+ indicate the fourier transform of the :math:`i` th waveform at
312
+ the :math:`j` th inteferometer and its complex conjugate, respectively;
313
+ and :math:`S_n^{(j)}` is the backround PSD at the :math:`j` th interferometer.
309
314
 
310
315
  Args:
311
316
  responses:
@@ -314,12 +319,12 @@ def compute_ifo_snr(
314
319
  psd:
315
320
  The one-sided power spectral density of the background
316
321
  noise at each interferometer to which a response
317
- in `responses` has been calculated. If 2D, each row of
318
- `psd` will be assumed to be the background PSD for each
319
- channel of _every_ batch element in `responses`. If 3D,
322
+ in ``responses`` has been calculated. If 2D, each row of
323
+ ``psd`` will be assumed to be the background PSD for each
324
+ channel of _every_ batch element in ``responses``. If 3D,
320
325
  this should contain a background PSD for each channel
321
- of each element in `responses`, and therefore the first
322
- two dimensions of `psd` and `responses` should match.
326
+ of each element in ``responses``, and therefore the first
327
+ two dimensions of ``psd`` and ``responses`` should match.
323
328
  sample_rate:
324
329
  The frequency at which the waveform responses timeseries
325
330
  have been sampled. Upon fourier transforming, should
@@ -329,18 +334,18 @@ def compute_ifo_snr(
329
334
  If a tensor is provided, it will be assumed to be a
330
335
  pre-computed mask used to 0-out low frequency components.
331
336
  If a float, it will be used to compute such a mask. If
332
- left as `None`, all frequencies up to `lowpass`
337
+ left as ``None``, all frequencies up to ``lowpass``
333
338
  will contribute to the SNR calculation.
334
339
  lowpass:
335
340
  The maximum frequency below which to compute the SNR.
336
341
  If a tensor is provided, it will be assumed to be a
337
342
  pre-computed mask used to 0-out high frequency components.
338
343
  If a float, it will be used to compute such a mask. If
339
- left as `None`, all frequencies from `highpass` up to
344
+ left as ``None``, all frequencies from ``highpass`` up to
340
345
  the Nyquist freqyency will contribute to the SNR calculation.
341
346
  Returns:
342
347
  Batch of SNRs computed for each interferometer
343
- """
348
+ """ # noqa E501
344
349
 
345
350
  # TODO: should we do windowing here?
346
351
  # compute frequency power, upsampling precision so that
@@ -388,10 +393,10 @@ def compute_ifo_snr(
388
393
  # that the user specify the sample rate by taking the
389
394
  # fft as-is (without dividing by sample rate) and then
390
395
  # taking the mean here (or taking the sum and dividing
391
- # by the sum of `highpass` if it's a mask). If we want
396
+ # by the sum of ``highpass`` if it's a mask). If we want
392
397
  # to allow the user to pass a float for highpass, we'll
393
398
  # need the sample rate to compute the mask, but if we
394
- # replace this with a `mask` argument instead we're in
399
+ # replace this with a ``mask`` argument instead we're in
395
400
  # the clear
396
401
  df = sample_rate / responses.shape[-1]
397
402
  integrated = integrand.sum(axis=-1) * df
@@ -408,15 +413,17 @@ def compute_network_snr(
408
413
  highpass: Union[float, Float[Tensor, " frequency"], None] = None,
409
414
  lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
410
415
  ) -> BatchTensor:
411
- r"""
416
+ """
412
417
  Compute the total SNR from a gravitational waveform
413
418
  from a network of interferometers. The total SNR for
414
- the $i$th waveform is computed as
419
+ the :math:`i` th waveform is computed as
420
+
421
+ .. math::
415
422
 
416
- $$\rho_i = \sqrt{\sum_{j}^{N}\rho_{ij}^2}$$
423
+ \\rho_i = \\sqrt{\\sum_{j}^{N}\\rho_{ij}^2}
417
424
 
418
- where \rho_{ij} is the SNR for the $i$th waveform at
419
- the $j$th interferometer in the network and $N$ is
425
+ where :math:`\\rho_{ij}` is the SNR for the :math:`i` th waveform at
426
+ the :math:`j` th interferometer in the network and :math:`N` is
420
427
  the total number of interferometers.
421
428
 
422
429
  Args:
@@ -426,12 +433,12 @@ def compute_network_snr(
426
433
  backgrounds:
427
434
  The one-sided power spectral density of the background
428
435
  noise at each interferometer to which a response
429
- in `responses` has been calculated. If 2D, each row of
430
- `psd` will be assumed to be the background PSD for each
431
- channel of _every_ batch element in `responses`. If 3D,
436
+ in ``responses`` has been calculated. If 2D, each row of
437
+ ``psd`` will be assumed to be the background PSD for each
438
+ channel of **every** batch element in ``responses``. If 3D,
432
439
  this should contain a background PSD for each channel
433
- of each element in `responses`, and therefore the first
434
- two dimensions of `psd` and `responses` should match.
440
+ of each element in ``responses``, and therefore the first
441
+ two dimensions of ``psd`` and ``responses`` should match.
435
442
  sample_rate:
436
443
  The frequency at which the waveform responses timeseries
437
444
  have been sampled. Upon fourier transforming, should
@@ -441,14 +448,14 @@ def compute_network_snr(
441
448
  If a tensor is provided, it will be assumed to be a
442
449
  pre-computed mask used to 0-out low frequency components.
443
450
  If a float, it will be used to compute such a mask. If
444
- left as `None`, all frequencies up to `sample_rate / 2`
451
+ left as ``None``, all frequencies up to ``sample_rate / 2``
445
452
  will contribute to the SNR calculation.
446
453
  lowpass:
447
454
  The maximum frequency below which to compute the SNR.
448
455
  If a tensor is provided, it will be assumed to be a
449
456
  pre-computed mask used to 0-out high frequency components.
450
457
  If a float, it will be used to compute such a mask. If
451
- left as `None`, all frequencies from `highpass` up to
458
+ left as ``None``, all frequencies from ``highpass`` up to
452
459
  the Nyquist frequency will contribute to the SNR calculation.
453
460
  Returns:
454
461
  Batch of SNRs for each waveform across the interferometer network
@@ -478,12 +485,12 @@ def reweight_snrs(
478
485
  psd:
479
486
  The one-sided power spectral density of the background
480
487
  noise at each interferometer to which a response
481
- in `responses` has been calculated. If 2D, each row of
482
- `psd` will be assumed to be the background PSD for each
483
- channel of _every_ batch element in `responses`. If 3D,
488
+ in ``responses`` has been calculated. If 2D, each row of
489
+ ``psd`` will be assumed to be the background PSD for each
490
+ channel of **every** batch element in ``responses``. If 3D,
484
491
  this should contain a background PSD for each channel
485
- of each element in `responses`, and therefore the first
486
- two dimensions of `psd` and `responses` should match.
492
+ of each element in ``responses``, and therefore the first
493
+ two dimensions of ``psd`` and ``responses`` should match.
487
494
  sample_rate:
488
495
  The frequency at which the waveform responses timeseries
489
496
  have been sampled. Upon fourier transforming, should
@@ -493,14 +500,14 @@ def reweight_snrs(
493
500
  If a tensor is provided, it will be assumed to be a
494
501
  pre-computed mask used to 0-out low frequency components.
495
502
  If a float, it will be used to compute such a mask. If
496
- left as `None`, all frequencies up to `sample_rate / 2`
503
+ left as ``None``, all frequencies up to ``sample_rate / 2``
497
504
  will contribute to the SNR calculation.
498
505
  lowpass:
499
506
  The maximum frequency below which to compute the SNR.
500
507
  If a tensor is provided, it will be assumed to be a
501
508
  pre-computed mask used to 0-out high frequency components.
502
509
  If a float, it will be used to compute such a mask. If
503
- left as `None`, all frequencies from `highpass` up to
510
+ left as ``None``, all frequencies from ``highpass`` up to
504
511
  the Nyquist frequency will contribute to the SNR calculation.
505
512
  Returns:
506
513
  Rescaled interferometer responses
@@ -12,18 +12,18 @@ class Autoencoder(torch.nn.Module):
12
12
  Base autoencoder class that defines some of the
13
13
  basic methods and functionality. Autoencoders are
14
14
  defined here as a set of sequential blocks that
15
- have an `encode` method, which acts on the input
16
- data to the autoencoder, and a `decode` method, which
17
- acts on the encoded vector generated by the `encode`
18
- method. `forward` just runs these steps one after the
15
+ have an ``encode`` method, which acts on the input
16
+ data to the autoencoder, and a ``decode`` method, which
17
+ acts on the encoded vector generated by the ``encode``
18
+ method. ``forward`` just runs these steps one after the
19
19
  other. Although it isn't explicitly enforced, a good
20
- rule of thumb is that the ouput of a block's `decode`
20
+ rule of thumb is that the output of a block's ``decode``
21
21
  method should have the same shape as the *input* of its
22
- `encode` method.
22
+ ``encode`` method.
23
23
 
24
- Accepts a `skip_connection` argument that defines how to
25
- combine information from the input of one block's `encode`
26
- layer with the output to its `decode`layer. See `skip_connections.py`
24
+ Accepts a ``skip_connection`` argument that defines how to
25
+ combine information from the input of one block's ``encode``
26
+ layer with the output to its ``decode`` layer. See ``skip_connections.py``
27
27
  for more info about what these classes are expected to contain
28
28
  and how they operate.
29
29
  """
@@ -83,11 +83,11 @@ class ConvolutionalAutoencoder(Autoencoder):
83
83
  match the shape of the input to its corresponding
84
84
  encoder layer, except for the last decoder which
85
85
  can have an arbitrary number of channels specified
86
- by `decode_channels`.
86
+ by ``decode_channels``.
87
87
 
88
- All layers also share the same `activation` except
88
+ All layers also share the same ``activation`` except
89
89
  for the last decoder layer, which can have an
90
- arbitrary `output_activation`.
90
+ arbitrary ``output_activation``.
91
91
  """
92
92
 
93
93
  def __init__(
@@ -115,7 +115,7 @@ class ConvolutionalAutoencoder(Autoencoder):
115
115
  # All intermediate layers should decode to
116
116
  # the same number of channels. The last decoder
117
117
  # should decode to whatever number of channels
118
- # was specified, even if it's `None` (in which
118
+ # was specified, even if it's ``None`` (in which
119
119
  # case it will just be in_channels anyway)
120
120
  decode = in_channels if i else decode_channels
121
121