ml4gw 0.7.5__tar.gz → 0.7.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic.
- {ml4gw-0.7.5 → ml4gw-0.7.7}/PKG-INFO +6 -5
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/augmentations.py +4 -4
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/chunked_dataset.py +3 -3
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/hdf5_dataset.py +7 -10
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/in_memory_dataset.py +21 -21
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/distributions.py +20 -18
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/gw.py +60 -53
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/base.py +9 -9
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/convolutional.py +4 -4
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/resnet/resnet_1d.py +13 -13
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/resnet/resnet_2d.py +12 -12
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/streaming/online_average.py +1 -1
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/streaming/snapshotter.py +14 -14
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/spectral.py +48 -48
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/__init__.py +1 -1
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/iirfilter.py +3 -3
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/pearson.py +7 -8
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/qtransform.py +29 -34
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/scaler.py +4 -4
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/spectral.py +10 -10
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/spectrogram.py +12 -11
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/spline_interpolation.py +310 -146
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/transform.py +1 -1
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/whitening.py +36 -36
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/utils/slicing.py +40 -40
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/phenom_d.py +22 -66
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/phenom_p.py +9 -5
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/taylorf2.py +8 -7
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/conversion.py +2 -1
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/generator.py +33 -32
- ml4gw-0.7.7/ml4gw.egg-info/PKG-INFO +57 -0
- ml4gw-0.7.7/ml4gw.egg-info/SOURCES.txt +63 -0
- ml4gw-0.7.7/ml4gw.egg-info/dependency_links.txt +1 -0
- ml4gw-0.7.7/ml4gw.egg-info/requires.txt +5 -0
- ml4gw-0.7.7/ml4gw.egg-info/top_level.txt +1 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/pyproject.toml +12 -3
- ml4gw-0.7.7/setup.cfg +4 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/tests/test_distributions.py +43 -3
- {ml4gw-0.7.5 → ml4gw-0.7.7}/tests/test_gw.py +5 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/tests/test_spectral.py +7 -1
- ml4gw-0.7.5/.coverage +0 -0
- ml4gw-0.7.5/.gitattributes +0 -2
- ml4gw-0.7.5/.github/workflows/coverage.yaml +0 -31
- ml4gw-0.7.5/.github/workflows/docs.yaml +0 -29
- ml4gw-0.7.5/.github/workflows/pre-commit.yaml +0 -17
- ml4gw-0.7.5/.github/workflows/publish.yaml +0 -27
- ml4gw-0.7.5/.github/workflows/unit-tests.yaml +0 -80
- ml4gw-0.7.5/.gitignore +0 -3
- ml4gw-0.7.5/.pre-commit-config.yaml +0 -23
- ml4gw-0.7.5/.readthedocs.yaml +0 -36
- ml4gw-0.7.5/CITATION.cff +0 -37
- ml4gw-0.7.5/docs/Makefile +0 -20
- ml4gw-0.7.5/docs/conf.py +0 -65
- ml4gw-0.7.5/docs/index.rst +0 -50
- ml4gw-0.7.5/docs/installation.rst +0 -19
- ml4gw-0.7.5/docs/make.bat +0 -35
- ml4gw-0.7.5/docs/ml4gw.dataloading.rst +0 -37
- ml4gw-0.7.5/docs/ml4gw.nn.autoencoder.rst +0 -45
- ml4gw-0.7.5/docs/ml4gw.nn.resnet.rst +0 -29
- ml4gw-0.7.5/docs/ml4gw.nn.rst +0 -31
- ml4gw-0.7.5/docs/ml4gw.nn.streaming.rst +0 -29
- ml4gw-0.7.5/docs/ml4gw.rst +0 -64
- ml4gw-0.7.5/docs/ml4gw.transforms.rst +0 -77
- ml4gw-0.7.5/docs/ml4gw.waveforms.rst +0 -53
- ml4gw-0.7.5/docs/modules.rst +0 -7
- ml4gw-0.7.5/docs/requirements.txt +0 -3
- ml4gw-0.7.5/examples/README.md +0 -12
- ml4gw-0.7.5/examples/ml4gw_tutorial.ipynb +0 -1757
- ml4gw-0.7.5/examples/pyproject.toml +0 -22
- ml4gw-0.7.5/examples/uv.lock +0 -2960
- ml4gw-0.7.5/tests/conftest.py +0 -200
- ml4gw-0.7.5/tests/dataloading/test_chunked_dataset.py +0 -82
- ml4gw-0.7.5/tests/dataloading/test_hdf5_dataset.py +0 -188
- ml4gw-0.7.5/tests/dataloading/test_in_memory_dataset.py +0 -357
- ml4gw-0.7.5/tests/nn/resnet/test_resnet_1d.py +0 -137
- ml4gw-0.7.5/tests/nn/resnet/test_resnet_2d.py +0 -138
- ml4gw-0.7.5/tests/nn/streaming/test_online_average.py +0 -88
- ml4gw-0.7.5/tests/nn/streaming/test_snapshotter.py +0 -120
- ml4gw-0.7.5/tests/nn/test_norm.py +0 -75
- ml4gw-0.7.5/tests/transforms/test_iirfilter.py +0 -321
- ml4gw-0.7.5/tests/transforms/test_pearson.py +0 -81
- ml4gw-0.7.5/tests/transforms/test_qtransform.py +0 -184
- ml4gw-0.7.5/tests/transforms/test_scaler.py +0 -123
- ml4gw-0.7.5/tests/transforms/test_snr_rescaler.py +0 -86
- ml4gw-0.7.5/tests/transforms/test_spectral_transform.py +0 -290
- ml4gw-0.7.5/tests/transforms/test_spectrogram.py +0 -109
- ml4gw-0.7.5/tests/transforms/test_spline_interpolation.py +0 -101
- ml4gw-0.7.5/tests/transforms/test_waveforms.py +0 -101
- ml4gw-0.7.5/tests/transforms/test_whitening.py +0 -191
- ml4gw-0.7.5/tests/utils/test_slicing.py +0 -334
- ml4gw-0.7.5/tests/waveforms/adhoc/test_sine_gaussian.py +0 -100
- ml4gw-0.7.5/tests/waveforms/cbc/test_cbc_waveforms.py +0 -480
- ml4gw-0.7.5/tests/waveforms/cbc/test_utils.py +0 -115
- ml4gw-0.7.5/tests/waveforms/test_conversion.py +0 -65
- ml4gw-0.7.5/tests/waveforms/test_generator.py +0 -216
- ml4gw-0.7.5/uv.lock +0 -3344
- {ml4gw-0.7.5 → ml4gw-0.7.7}/LICENSE +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/README.md +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/__init__.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/constants.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/__init__.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/__init__.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/__init__.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/skip_connection.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/utils.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/norm.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/resnet/__init__.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/streaming/__init__.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/snr_rescaler.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/transforms/waveforms.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/types.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/utils/interferometer.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/__init__.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/adhoc/__init__.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/adhoc/ringdown.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/adhoc/sine_gaussian.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/__init__.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/coefficients.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/phenom_d_data.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/waveforms/cbc/utils.py +0 -0
- {ml4gw-0.7.5 → ml4gw-0.7.7}/tests/test_augmentations.py +0 -0
{ml4gw-0.7.5 → ml4gw-0.7.7}/PKG-INFO

@@ -1,9 +1,8 @@
 Metadata-Version: 2.4
 Name: ml4gw
-Version: 0.7.5
+Version: 0.7.7
 Summary: Tools for training torch models on gravitational wave data
 Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>
-License-File: LICENSE
 Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
@@ -11,12 +10,14 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Requires-Python: <3.13,>=3.9
+Description-Content-Type: text/markdown
+License-File: LICENSE
 Requires-Dist: jaxtyping<0.3,>=0.2
+Requires-Dist: torch~=2.0
+Requires-Dist: torchaudio~=2.0
 Requires-Dist: numpy<2.0.0
 Requires-Dist: scipy<1.15,>=1.9.0
-
-Requires-Dist: torch~=2.0
-Description-Content-Type: text/markdown
+Dynamic: license-file
 
 # ML4GW
 
 
{ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/augmentations.py

@@ -6,8 +6,8 @@ from torch import Tensor
 class SignalInverter(torch.nn.Module):
     """
     Takes a tensor of timeseries of arbitrary dimension
-    and randomly inverts
-    each timeseries with probability
+    and randomly inverts i.e. :math:`h(t) \\rightarrow -h(t)`
+    each timeseries with probability ``prob``.
 
     Args:
         prob:
@@ -29,8 +29,8 @@ class SignalInverter(torch.nn.Module):
 class SignalReverser(torch.nn.Module):
     """
     Takes a tensor of timeseries of arbitrary dimension
-    and randomly reverses
-    each timeseries with probability
+    and randomly reverses i.e., :math:`h(t) \\rightarrow h(-t)`.
+    each timeseries with probability ``prob``.
 
     Args:
         prob:
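
The two hunks above document the augmentations as sign and time flips. A minimal sketch of that behavior, assuming a batch x channel x time layout (the modules apply this inside ``forward``):

    import torch

    # Sketch: with probability ``prob``, SignalInverter maps
    # h(t) -> -h(t) and SignalReverser maps h(t) -> h(-t),
    # independently for each timeseries in the batch.
    prob = 0.5
    X = torch.randn(8, 2, 1024)  # batch x channels x time
    invert = torch.rand(X.shape[:-1]) < prob
    X[invert] *= -1                   # sign flip
    reverse = torch.rand(X.shape[:-1]) < prob
    X[reverse] = X[reverse].flip(-1)  # time flip
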
{ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/chunked_dataset.py

@@ -15,9 +15,9 @@ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
         chunk_it:
             Iterator which will produce chunks of timeseries
             data to sample windows from. Should have shape
-
-            to sample from,
-            and
+            ``(N, C, T)``, where ``N`` is the number of chunks
+            to sample from, ``C`` is the number of channels,
+            and ``T`` is the number of samples along the
             time dimension for each chunk.
         kernel_size:
             Size of windows to be sampled from each chunk.
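
The ``(N, C, T)`` convention the new docstring spells out can be pictured with a toy chunk; the shapes here are illustrative only:

    import torch

    # A chunk of N=4 timeseries with C=2 channels and T=4096 samples;
    # a training window is a random kernel_size-length slice in time.
    N, C, T, kernel_size = 4, 2, 4096, 1024
    chunk = torch.randn(N, C, T)
    start = torch.randint(0, T - kernel_size, (1,)).item()
    window = chunk[:, :, start : start + kernel_size]  # (N, C, kernel_size)
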
{ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/hdf5_dataset.py

@@ -17,8 +17,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
     Iterable dataset that samples and loads windows of
     timeseries data uniformly from a set of HDF5 files.
     It is _strongly_ recommended that these files have been
-    written using
-    (https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage).
+    written using `chunked storage <https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage>`_.
     This has shown to produce increases in read-time speeds
     of over an order of magnitude.
 
@@ -37,27 +36,25 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
             Number of windows to sample at each iteration.
         batches_per_epoch:
             Number of batches to generate during each call
-            to
+            to ``__iter__``.
         coincident:
             Whether windows for each channel in a given batch
             element should be sampled coincidentally, i.e.
             corresponding to the same time indices from the
             same files, or should be sampled independently.
             For the latter case, users can either specify
-
-            for each channel, or
+            ``False``, which will sample filenames independently
+            for each channel, or ``"files"``, which will sample
             windows independently within a given file for each
             channel. The latter setting limits the amount of
             entropy in the effective dataset, but can provide
             over 2x improvement in total throughput.
         num_files_per_batch:
             The number of unique files from which to sample
-            batch elements each epoch. If left as
+            batch elements each epoch. If left as ``None``,
             will use all available files. Useful when reading
             from many files is bottlenecking dataloading.
-
-
-    """
+    """ # noqa E501
 
     def __init__(
         self,
@@ -117,7 +114,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
         return self.batches_per_epoch
 
     def sample_fnames(self, size) -> np.ndarray:
-        # first, randomly select
+        # first, randomly select ``self.num_files_per_batch``
        # file indices based on their probabilities
        fname_indices = np.arange(len(self.fnames))
        fname_indices = np.random.choice(
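
Since the docstring now links to h5py's chunked-storage documentation, here is a minimal sketch of writing a file that way; the dataset name and chunk shape are assumptions for illustration, not values ml4gw prescribes:

    import h5py
    import numpy as np

    # Chunked storage lets h5py read kernel-sized windows without
    # scanning whole contiguous datasets, which is where the claimed
    # order-of-magnitude read speedup comes from.
    with h5py.File("background.h5", "w") as f:
        strain = np.random.randn(4096 * 128)
        f.create_dataset("H1", data=strain, chunks=(4096,))
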
{ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/dataloading/in_memory_dataset.py

@@ -20,56 +20,56 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
     Args:
         X:
             Timeseries data to be iterated through. Should have
-            shape
+            shape ``(num_channels, length * sample_rate)``. Windows
             will be sampled from the time (1st) dimension for all
             channels along the channel (0th) dimension.
         kernel_size:
-            The length of the windows to sample from
+            The length of the windows to sample from ``X`` in units
             of samples.
         y:
             Target timeseries to be iterated through. If specified,
             should be a single channel and have shape
-
-            sampled from
+            ``(length * sample_rate,)``. If left as ``None``, only windows
+            sampled from ``X`` will be returned during iteration.
             Otherwise, windows sampled from both arrays will be
             returned. Note that if sampling is performed non-coincidentally,
             there's no sensible way to align windows sampled from this
-            array with the windows sampled from
+            array with the windows sampled from ``X``, so this combination
             of arguments is not permitted.
         batch_size:
             Maximum number of windows to return at each iteration. Will
             be the length of the 0th dimension of the returned array(s).
-            If
-            of
+            If ``batches_per_epoch`` is specified, this will be the length
+            of **every** array returned during iteration. Otherwise, it's
             possible that the last array will be shorter due to the number
             of windows in the timeseries being a non-integer multiple of
-
+            ``batch_size``.
         stride:
             The resolution at which windows will be sampled from the
             specified timeseries, in units of samples. E.g. if
-
-            from an index of
+            ``stride=2``, the first sample of each window can only be
+            from an index of ``X`` which is a multiple of 2. Obviously,
             this reduces the number of windows which can be iterated
-            through by a factor of
+            through by a factor of ``stride``.
         batches_per_epoch:
             Number of batches of window to produce during iteration
-            before raising a
+            before raising a ``StopIteration``. Must be specified if
             performing non-coincident sampling. Otherwise, if left
-            as
+            as ``None``, windows will be sampled until the entire
             timeseries has been exhausted. Note that
-
+            ``batch_size * batches_per_epoch`` must be be small
             enough to be able to be fulfilled by the number of
-            windows in the timeseries, otherise a
+            windows in the timeseries, otherise a ``ValueError``
             will be raised.
         coincident:
-            Whether to sample windows from the channels of
+            Whether to sample windows from the channels of ``X``
             using the same indices or independently. Can't be
-
-
+            ``True`` if ``batches_per_epoch`` is ``None`` or ``y`` is
+            **not** ``None``.
         shuffle:
             Whether to sample windows from timeseries randomly
-            or in order along the time axis. If
-            and
+            or in order along the time axis. If ``coincident=False``
+            and ``shuffle=False``, channels will be iterated through
             with the index along the last channel moving fastest.
         device:
             Which device to host the timeseries arrays on
@@ -91,7 +91,7 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
 
         # make sure if we specified a target array that all other
         # other necessary conditions are met (it has the same
-        # length as
+        # length as ``X`` and we're sampling coincidentally)
         if y is not None and y.shape[-1] != X.shape[-1]:
             raise ValueError(
                 "Target timeseries must have same length as input"
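
A quick arithmetic check of the ``stride`` semantics documented above (illustrative numbers):

    # With stride=2, windows may only start at indices 0, 2, 4, ...,
    # so the count of available windows drops by a factor of stride.
    num_samples, kernel_size, stride = 16384, 2048, 2
    num_windows = (num_samples - kernel_size) // stride + 1  # 7169 vs 14337 at stride=1
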
{ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/distributions.py

@@ -1,7 +1,7 @@
 """
 Module containing callables classes for generating samples
 from specified distributions. Each callable should map from
-an integer
+an integer ``N`` to a 1D torch ``Tensor`` containing ``N`` samples
 from the corresponding distribution.
 """
 
@@ -22,8 +22,9 @@ _PLANCK18_OMEGA_M = 0.30966  # Matter density parameter
 class Cosine(dist.Distribution):
     """
     Cosine distribution based on
-    ``torch.distributions.TransformedDistribution
-
+    ``torch.distributions.TransformedDistribution``
+    (see `documentation <https://docs.pytorch.org/docs/stable/distributions.html#transformeddistribution>`_).
+    """ # noqa E501
 
     arg_constraints = {}
 
@@ -117,18 +118,17 @@ class LogNormal(dist.LogNormal):
 class PowerLaw(dist.TransformedDistribution):
     """
     Sample from a power law distribution,
-
-
+
+    .. math:: p(x) \\approx x^{\\alpha}.
 
     Index alpha cannot be 0, since it is equivalent to a Uniform distribution.
     This could be used, for example, as a universal distribution of
     signal-to-noise ratios (SNRs) from uniformly volume distributed
     sources
-    .. math::
 
-
+    .. math:: p(\\rho) = 3\;\\rho_0^3 / \\rho^4
 
-    where :math
+    where :math:`\\rho_0` is a representative minimum SNR
     considered for detection. See, for example,
     `Schutz (2011) <https://arxiv.org/abs/1102.5421>`_.
     Or, for example, ``index=2`` for uniform in Euclidean volume.
@@ -137,10 +137,10 @@ class PowerLaw(dist.TransformedDistribution):
     support = dist.constraints.nonnegative
 
     def __init__(
-        self, minimum: float, maximum: float, index:
+        self, minimum: float, maximum: float, index: float, validate_args=None
     ):
         if index == 0:
-            raise
+            raise ValueError("Index of 0 is the same as Uniform")
         elif index == -1:
             base_min = torch.as_tensor(minimum).log()
             base_max = torch.as_tensor(maximum).log()
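
The hunk above makes ``index`` an explicit float and rejects ``index=0``. A usage sketch based on the SNR example in the docstring (bounds are illustrative):

    from ml4gw.distributions import PowerLaw

    # p(rho) ~ rho^-4 above a representative minimum SNR of 8
    snr_dist = PowerLaw(minimum=8.0, maximum=100.0, index=-4)
    snrs = snr_dist.sample((1000,))  # 1000 draws
    # PowerLaw(minimum=8.0, maximum=100.0, index=0)  # would raise ValueError
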
@@ -185,14 +185,14 @@ class UniformComovingVolume(dist.Distribution):
     Sample either redshift, comoving distance, or luminosity distance
     such that they are uniform in comoving volume, assuming a flat
     lambda-CDM cosmology. Default H0 and Omega_M values match
-    astropy.cosmology.Planck18
+    `Planck18 parameters in Astropy <https://docs.astropy.org/en/latest/api/astropy.cosmology.realizations.Planck18.html>`_.
 
     Args:
         minimum: Minimum distance in the specified distance type
         maximum: Maximum distance in the specified distance type
         distance_type:
-            Type of distance to sample from. Can be
-
+            Type of distance to sample from. Can be ``redshift``,
+            ``comoving_distance``, or ``luminosity_distance``
         h0: Hubble constant in km/s/Mpc
         omega_m: Matter density parameter
         z_max: Maximum redshift for the grid
@@ -347,18 +347,20 @@ class UniformComovingVolume(dist.Distribution):
 
 class RateEvolution(UniformComovingVolume):
     """
-    Wrapper around
+    Wrapper around :meth:`~ml4gw.distributions.UniformComovingVolume` to allow for
     arbitrary rate evolution functions. E.g., if
-
+    ``rate_function = lambda z: 1 / (1 + z)``, then the distribution
     will sample values such that they occur uniform in
     source frame time.
 
     Args:
         rate_function: Callable that takes redshift as input
             and returns the rate evolution factor.
-        *args
-        constructor.
-
+        *args: Arguments passed to
+            :meth:`~ml4gw.distributions.UniformComovingVolume` constructor.
+        **kwargs: Keyword arguments passed to
+            :meth:`~ml4gw.distributions.UniformComovingVolume` constructor.
+    """ # noqa E501
 
     def __init__(
         self,
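
Following the docstring's example, a ``RateEvolution`` that samples uniformly in source-frame time might look like this; the distance bounds and type are illustrative assumptions, with the remaining keywords passed through to ``UniformComovingVolume``:

    from ml4gw.distributions import RateEvolution

    # Rate evolution factor 1/(1+z), per the docstring's example
    dist = RateEvolution(
        rate_function=lambda z: 1 / (1 + z),
        minimum=100.0,
        maximum=1000.0,
        distance_type="luminosity_distance",
    )
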
{ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/gw.py

@@ -2,13 +2,11 @@
 Tools for manipulating raw gravitational waveforms
 and projecting them onto interferometer responses.
 Much of the projection code is an extension of the
-implementation made available in
-
-
-
-
-https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py
-"""
+implementation made available in
+`bilby <https://arxiv.org/abs/1811.02042>`_.
+Specifically code from
+`this module <https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py>`_.
+""" # noqa E501
 
 from typing import List, Tuple, Union
 
@@ -134,6 +132,9 @@ def compute_antenna_responses(
     # shape: batch x num_polarizations x 3 x 3
     polarization = torch.stack(polarizations, axis=1)
 
+    # Ensure dtype consistency before einsum
+    detector_tensors = detector_tensors.to(polarization.dtype)
+
     # compute the weight of each interferometer's response
     # to each polarization: batch x polarizations x ifos
     return torch.einsum("...jk,ijk->...i", polarization, detector_tensors)
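
The added cast matters because ``torch.einsum`` requires its operands to share a dtype: float64 detector tensors against float32 polarizations would raise a ``RuntimeError``. A minimal reproduction of the pattern:

    import torch

    polarization = torch.randn(8, 2, 3, 3)                       # float32
    detector_tensors = torch.randn(2, 3, 3, dtype=torch.float64)
    detector_tensors = detector_tensors.to(polarization.dtype)   # the fix
    weights = torch.einsum("...jk,ijk->...i", polarization, detector_tensors)
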
@@ -194,7 +195,7 @@ def compute_observed_strain(
     **polarizations: Float[Tensor, "batch time"],
 ) -> WaveformTensor:
     """
-    Compute the strain timeseries
+    Compute the strain timeseries :math:`h(t)` observed by a network
     of interferometers from the given polarization timeseries
     corresponding to gravitational waveforms from sources with
     the indicated sky parameters.
@@ -222,13 +223,13 @@ def compute_observed_strain(
             between the waveform observed at the geocenter and
             the one observed at the detector site. To avoid
             adding any delay between the two, reset your coordinates
-            such that the desired interferometer is at
+            such that the desired interferometer is at ``(0., 0., 0.)``.
         sample_rate:
             Rate at which the polarization timeseries have been sampled
         polarziations:
             Timeseries for each waveform polarization which
             contributes to the interferometer response. Allowed
-            polarizations are
+            polarizations are ``cross``, ``plus``, and ``breathing``.
     Returns:
         Tensor representing the observed strain at each
         interferometer for each waveform.
@@ -236,13 +237,15 @@ def compute_observed_strain(
 
     # TODO: just use theta as the input parameter?
     # note that ** syntax is ordered, so we're safe
-    # to be lazy and use
+    # to be lazy and use ``list`` for the keys and values
     theta = torch.pi / 2 - dec
     antenna_responses = compute_antenna_responses(
         theta, psi, phi, detector_tensors, list(polarizations)
     )
 
     polarizations = torch.stack(list(polarizations.values()), axis=1)
+    # Ensure dtype consistency before einsum
+    antenna_responses = antenna_responses.to(polarizations.dtype)
     waveforms = torch.einsum(
         "...pi,...pt->...it", antenna_responses, polarizations
     )
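
The projection einsum in the hunk above contracts antenna responses of shape ``(batch, polarizations, ifos)`` against polarization timeseries of shape ``(batch, polarizations, time)``; a shape-only sketch:

    import torch

    antenna_responses = torch.randn(8, 2, 3)   # batch x pols x ifos
    polarizations = torch.randn(8, 2, 4096)    # batch x pols x time
    waveforms = torch.einsum(
        "...pi,...pt->...it", antenna_responses, polarizations
    )  # batch x ifos x time
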
@@ -286,26 +289,28 @@ def compute_ifo_snr(
     highpass: Union[float, Float[Tensor, " frequency"], None] = None,
     lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
 ) -> Float[Tensor, "batch num_ifos"]:
-
+    """Compute the SNRs of a batch of interferometer responses
 
     Compute the signal to noise ratio (SNR) of individual
     interferometer responses to gravitational waveforms with
     respect to a background PSD for each interferometer. The
-    SNR of the
+    SNR of the :math:`i` th waveform at the :math:`j` th interferometer
     is computed as:
 
-
-    4 \int_{f_{\text{min}}}^{f_{\text{max}}}
-    \frac{\tilde{h_{ij}}(f)\tilde{h_{ij}}^*(f)}
-    {S_n^{(j)}(f)}df$$
+    .. math::
 
-
-
-
-
-
-
-
+        \\rho_{ij} =
+        4 \\int_{f_{\\text{min}}}^{f_{\\text{max}}}
+        \\frac{\\tilde{h_{ij}}(f)\\tilde{h_{ij}}^*(f)}
+        {S_n^{(j)}(f)}df
+
+    Where :math:`f_{\\text{min}}` is a minimum frequency denoted
+    by ``highpass``, :math:`f_{\\text{max}}` is the maximum frequency
+    denoted by ``lowpass``, which defaults to the Nyquist frequency
+    dictated by ``sample_rate``; :math:`\\tilde{h}_{ij}` and :math:`\\tilde{h}_{ij}^*`
+    indicate the fourier transform of the :math:`i` th waveform at
+    the :math:`j` th inteferometer and its complex conjugate, respectively;
+    and :math:`S_n^{(j)}` is the backround PSD at the :math:`j` th interferometer.
 
     Args:
         responses:
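
A discretized version of the integral in the reconstructed docstring, matching the ``df``-weighted sum the function uses further down; shapes and PSD values are illustrative:

    import torch

    sample_rate = 2048.0
    responses = torch.randn(8, 2, 4096)         # batch x ifos x time
    psd = torch.rand(2, 4096 // 2 + 1) + 1e-2   # one-sided PSD per ifo
    htilde = torch.fft.rfft(responses) / sample_rate
    integrand = 4 * htilde.abs() ** 2 / psd
    df = sample_rate / responses.shape[-1]
    snr = (integrand.sum(dim=-1) * df) ** 0.5   # batch x ifos
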
@@ -314,12 +319,12 @@ def compute_ifo_snr(
         psd:
             The one-sided power spectral density of the background
             noise at each interferometer to which a response
-            in
-
-            channel of _every_ batch element in
+            in ``responses`` has been calculated. If 2D, each row of
+            ``psd`` will be assumed to be the background PSD for each
+            channel of _every_ batch element in ``responses``. If 3D,
             this should contain a background PSD for each channel
-            of each element in
-            two dimensions of
+            of each element in ``responses``, and therefore the first
+            two dimensions of ``psd`` and ``responses`` should match.
         sample_rate:
             The frequency at which the waveform responses timeseries
             have been sampled. Upon fourier transforming, should
@@ -329,18 +334,18 @@ def compute_ifo_snr(
             If a tensor is provided, it will be assumed to be a
             pre-computed mask used to 0-out low frequency components.
             If a float, it will be used to compute such a mask. If
-            left as
+            left as ``None``, all frequencies up to ``lowpass``
             will contribute to the SNR calculation.
         lowpass:
             The maximum frequency below which to compute the SNR.
             If a tensor is provided, it will be assumed to be a
             pre-computed mask used to 0-out high frequency components.
             If a float, it will be used to compute such a mask. If
-            left as
+            left as ``None``, all frequencies from ``highpass`` up to
             the Nyquist freqyency will contribute to the SNR calculation.
     Returns:
         Batch of SNRs computed for each interferometer
-    """
+    """ # noqa E501
 
     # TODO: should we do windowing here?
     # compute frequency power, upsampling precision so that
@@ -388,10 +393,10 @@ def compute_ifo_snr(
     # that the user specify the sample rate by taking the
     # fft as-is (without dividing by sample rate) and then
     # taking the mean here (or taking the sum and dividing
-    # by the sum of
+    # by the sum of ``highpass`` if it's a mask). If we want
     # to allow the user to pass a float for highpass, we'll
     # need the sample rate to compute the mask, but if we
-    # replace this with a
+    # replace this with a ``mask`` argument instead we're in
     # the clear
     df = sample_rate / responses.shape[-1]
     integrated = integrand.sum(axis=-1) * df
@@ -408,15 +413,17 @@ def compute_network_snr(
     highpass: Union[float, Float[Tensor, " frequency"], None] = None,
     lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
 ) -> BatchTensor:
-
+    """
     Compute the total SNR from a gravitational waveform
     from a network of interferometers. The total SNR for
-    the
+    the :math:`i` th waveform is computed as
+
+    .. math::
 
-
+        \\rho_i = \\sqrt{\\sum_{j}^{N}\\rho_{ij}^2}
 
-    where
-    the
+    where :math:`\\rho_{ij}` is the SNR for the :math:`i` th waveform at
+    the :math:`j` th interferometer in the network and :math:`N` is
     the total number of interferometers.
 
     Args:
@@ -426,12 +433,12 @@ def compute_network_snr(
         backgrounds:
             The one-sided power spectral density of the background
             noise at each interferometer to which a response
-            in
-
-            channel of
+            in ``responses`` has been calculated. If 2D, each row of
+            ``psd`` will be assumed to be the background PSD for each
+            channel of **every** batch element in ``responses``. If 3D,
             this should contain a background PSD for each channel
-            of each element in
-            two dimensions of
+            of each element in ``responses``, and therefore the first
+            two dimensions of ``psd`` and ``responses`` should match.
         sample_rate:
             The frequency at which the waveform responses timeseries
             have been sampled. Upon fourier transforming, should
@@ -441,14 +448,14 @@ def compute_network_snr(
             If a tensor is provided, it will be assumed to be a
             pre-computed mask used to 0-out low frequency components.
             If a float, it will be used to compute such a mask. If
-            left as
+            left as ``None``, all frequencies up to ``sample_rate / 2``
             will contribute to the SNR calculation.
         lowpass:
             The maximum frequency below which to compute the SNR.
             If a tensor is provided, it will be assumed to be a
             pre-computed mask used to 0-out high frequency components.
             If a float, it will be used to compute such a mask. If
-            left as
+            left as ``None``, all frequencies from ``highpass`` up to
             the Nyquist freqyency will contribute to the SNR calculation.
     Returns:
         Batch of SNRs for each waveform across the interferometer network
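
The quadrature sum in the new docstring, checked numerically:

    import torch

    ifo_snrs = torch.tensor([[10.0, 7.0], [5.0, 12.0]])  # batch x ifos
    network_snr = (ifo_snrs**2).sum(dim=-1).sqrt()
    # tensor([12.2066, 13.0000])
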
@@ -478,12 +485,12 @@ def reweight_snrs(
         psd:
             The one-sided power spectral density of the background
             noise at each interferometer to which a response
-            in
-
-            channel of
+            in ``responses`` has been calculated. If 2D, each row of
+            ``psd`` will be assumed to be the background PSD for each
+            channel of **every** batch element in ``responses``. If 3D,
             this should contain a background PSD for each channel
-            of each element in
-            two dimensions of
+            of each element in ``responses``, and therefore the first
+            two dimensions of ``psd`` and ``responses`` should match.
         sample_rate:
             The frequency at which the waveform responses timeseries
             have been sampled. Upon fourier transforming, should
@@ -493,14 +500,14 @@ def reweight_snrs(
             If a tensor is provided, it will be assumed to be a
             pre-computed mask used to 0-out low frequency components.
             If a float, it will be used to compute such a mask. If
-            left as
+            left as ``None``, all frequencies up to ``sample_rate / 2``
             will contribute to the SNR calculation.
         lowpass:
             The maximum frequency below which to compute the SNR.
             If a tensor is provided, it will be assumed to be a
             pre-computed mask used to 0-out high frequency components.
             If a float, it will be used to compute such a mask. If
-            left as
+            left as ``None``, all frequencies from ``highpass`` up to
             the Nyquist freqyency will contribute to the SNR calculation.
     Returns:
         Rescaled interferometer responses
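
Because SNR is linear in waveform amplitude, rescaling responses to hit target network SNRs reduces to an amplitude scale. A sketch of that idea only; the library's implementation recomputes SNRs against the supplied PSDs internally:

    import torch

    responses = torch.randn(2, 3, 4096)   # batch x ifos x time
    current = torch.tensor([9.0, 15.0])   # current network SNRs (assumed known)
    target = torch.tensor([12.0, 20.0])
    rescaled = responses * (target / current).view(-1, 1, 1)
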
{ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/base.py

@@ -12,18 +12,18 @@ class Autoencoder(torch.nn.Module):
     Base autoencoder class that defines some of the
     basic methods and functionality. Autoencoders are
     defined here as a set of sequential blocks that
-    have an
-    data to the autoencoder, and a
-    acts on the encoded vector generated by the
-    method.
+    have an ``encode`` method, which acts on the input
+    data to the autoencoder, and a ``decode`` method, which
+    acts on the encoded vector generated by the ``encode``
+    method. ``forward`` just runs these steps one after the
     other. Although it isn't explicitly enforced, a good
-    rule of thumb is that the ouput of a block's
+    rule of thumb is that the ouput of a block's ``decode``
     method should have the same shape as the _input_ of its
-
+    ``encode`` method.
 
-    Accepts a
-    combine information from the input of one block's
-    layer with the output to its
+    Accepts a ``skip_connection`` argument that defines how to
+    combine information from the input of one block's ``encode``
+    layer with the output to its ``decode`` layer. See ``skip_connections.py``
     for more info about what these classes are expected to contain
     and how they operate.
     """
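
The encode/decode contract the docstring describes, as a minimal standalone block; a sketch, not ml4gw's actual base class:

    import torch

    class Block(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.enc = torch.nn.Linear(128, 32)
            self.dec = torch.nn.Linear(32, 128)

        def encode(self, x):
            return self.enc(x)

        def decode(self, z):
            # decode's output shape matches encode's input shape
            return self.dec(z)

        def forward(self, x):
            # ``forward`` just chains the two steps
            return self.decode(self.encode(x))
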
{ml4gw-0.7.5 → ml4gw-0.7.7}/ml4gw/nn/autoencoder/convolutional.py

@@ -83,11 +83,11 @@ class ConvolutionalAutoencoder(Autoencoder):
     match the shape of the input to its corresponding
     encoder layer, except for the last decoder which
     can have an arbitrary number of channels specified
-    by
+    by ``decode_channels``.
 
-    All layers also share the same
+    All layers also share the same ``activation`` except
     for the last decoder layer, which can have an
-    arbitrary
+    arbitrary ``output_activation``.
     """
 
     def __init__(
@@ -115,7 +115,7 @@ class ConvolutionalAutoencoder(Autoencoder):
         # All intermediate layers should decode to
         # the same number of channels. The last decoder
         # should decode to whatever number of channels
-        # was specified, even if it's
+        # was specified, even if it's ``None`` (in which
         # case it will just be in_channels anyway)
         decode = in_channels if i else decode_channels