ml4gw 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4gw/augmentations.py +5 -0
- ml4gw/dataloading/__init__.py +5 -0
- ml4gw/dataloading/chunked_dataset.py +2 -4
- ml4gw/dataloading/hdf5_dataset.py +12 -10
- ml4gw/dataloading/in_memory_dataset.py +12 -12
- ml4gw/distributions.py +3 -3
- ml4gw/gw.py +18 -21
- ml4gw/nn/__init__.py +6 -0
- ml4gw/nn/autoencoder/base.py +5 -9
- ml4gw/nn/autoencoder/convolutional.py +7 -10
- ml4gw/nn/autoencoder/skip_connection.py +3 -5
- ml4gw/nn/norm.py +4 -4
- ml4gw/nn/resnet/resnet_1d.py +12 -13
- ml4gw/nn/resnet/resnet_2d.py +13 -14
- ml4gw/nn/streaming/online_average.py +3 -5
- ml4gw/nn/streaming/snapshotter.py +10 -14
- ml4gw/spectral.py +20 -23
- ml4gw/transforms/__init__.py +7 -1
- ml4gw/transforms/decimator.py +183 -0
- ml4gw/transforms/iirfilter.py +3 -5
- ml4gw/transforms/pearson.py +3 -4
- ml4gw/transforms/qtransform.py +20 -26
- ml4gw/transforms/scaler.py +3 -5
- ml4gw/transforms/snr_rescaler.py +7 -11
- ml4gw/transforms/spectral.py +6 -13
- ml4gw/transforms/spectrogram.py +6 -3
- ml4gw/transforms/spline_interpolation.py +312 -143
- ml4gw/transforms/transform.py +4 -6
- ml4gw/transforms/waveforms.py +8 -15
- ml4gw/transforms/whitening.py +11 -16
- ml4gw/types.py +8 -5
- ml4gw/utils/interferometer.py +20 -3
- ml4gw/utils/slicing.py +26 -30
- ml4gw/waveforms/__init__.py +6 -0
- ml4gw/waveforms/cbc/phenom_p.py +7 -9
- ml4gw/waveforms/conversion.py +2 -4
- ml4gw/waveforms/generator.py +3 -3
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/METADATA +33 -12
- ml4gw-0.7.8.dist-info/RECORD +57 -0
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/WHEEL +2 -1
- ml4gw-0.7.8.dist-info/top_level.txt +1 -0
- ml4gw-0.7.6.dist-info/RECORD +0 -55
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/licenses/LICENSE +0 -0
ml4gw/transforms/whitening.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional, Union
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
|
|
5
3
|
from .. import spectral
|
|
@@ -55,8 +53,8 @@ class Whiten(torch.nn.Module):
|
|
|
55
53
|
self,
|
|
56
54
|
fduration: float,
|
|
57
55
|
sample_rate: float,
|
|
58
|
-
highpass:
|
|
59
|
-
lowpass:
|
|
56
|
+
highpass: float | None = None,
|
|
57
|
+
lowpass: float | None = None,
|
|
60
58
|
) -> None:
|
|
61
59
|
super().__init__()
|
|
62
60
|
self.fduration = fduration
|
|
@@ -157,11 +155,11 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
157
155
|
def fit(
|
|
158
156
|
self,
|
|
159
157
|
fduration: float,
|
|
160
|
-
*background:
|
|
161
|
-
fftlength:
|
|
162
|
-
highpass:
|
|
163
|
-
lowpass:
|
|
164
|
-
overlap:
|
|
158
|
+
*background: TimeSeries1d | FrequencySeries1d,
|
|
159
|
+
fftlength: float | None = None,
|
|
160
|
+
highpass: float | None = None,
|
|
161
|
+
lowpass: float | None = None,
|
|
162
|
+
overlap: float | None = None,
|
|
165
163
|
) -> None:
|
|
166
164
|
"""
|
|
167
165
|
Compute the PSD of channel-wise background to
|
|
@@ -224,10 +222,8 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
224
222
|
"""
|
|
225
223
|
if len(background) != self.num_channels:
|
|
226
224
|
raise ValueError(
|
|
227
|
-
"Expected to fit whitening transform on {}
|
|
228
|
-
"timeseries, but was passed {}"
|
|
229
|
-
self.num_channels, len(background)
|
|
230
|
-
)
|
|
225
|
+
f"Expected to fit whitening transform on {self.num_channels} "
|
|
226
|
+
f"background timeseries, but was passed {len(background)}"
|
|
231
227
|
)
|
|
232
228
|
|
|
233
229
|
num_freqs = self.psd.size(-1)
|
|
@@ -257,9 +253,8 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
257
253
|
if X.size(-1) != expected_dim:
|
|
258
254
|
raise ValueError(
|
|
259
255
|
"Whitening transform expected a kernel length "
|
|
260
|
-
"of {}s, but was passed data
|
|
261
|
-
|
|
262
|
-
)
|
|
256
|
+
f"of {self.kernel_length}s, but was passed data "
|
|
257
|
+
f"of length {X.size(-1) / self.sample_rate}s"
|
|
263
258
|
)
|
|
264
259
|
|
|
265
260
|
pad = int(self.fduration.item() * self.sample_rate / 2)
|
ml4gw/types.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
This module defines common types used for type
|
|
3
|
+
annotation throughout the package
|
|
4
|
+
"""
|
|
2
5
|
|
|
3
6
|
from jaxtyping import Float
|
|
4
7
|
from torch import Tensor
|
|
@@ -15,11 +18,11 @@ NetworkDetectorTensors = Float[Tensor, "num_ifos 3 3"]
|
|
|
15
18
|
TimeSeries1d = Float[Tensor, "time"]
|
|
16
19
|
TimeSeries2d = Float[TimeSeries1d, "channel"]
|
|
17
20
|
TimeSeries3d = Float[TimeSeries2d, "batch"]
|
|
18
|
-
TimeSeries1to3d =
|
|
21
|
+
TimeSeries1to3d = TimeSeries1d | TimeSeries2d | TimeSeries3d
|
|
19
22
|
|
|
20
23
|
FrequencySeries1d = Float[Tensor, "frequency"]
|
|
21
24
|
FrequencySeries2d = Float[FrequencySeries1d, "channel"]
|
|
22
25
|
FrequencySeries3d = Float[FrequencySeries2d, "batch"]
|
|
23
|
-
FrequencySeries1to3d =
|
|
24
|
-
FrequencySeries1d
|
|
25
|
-
|
|
26
|
+
FrequencySeries1to3d = (
|
|
27
|
+
FrequencySeries1d | FrequencySeries2d | FrequencySeries3d
|
|
28
|
+
)
|
ml4gw/utils/interferometer.py
CHANGED
|
@@ -1,10 +1,27 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the interferometer geometry
|
|
3
|
+
used to calculate waveform projection.
|
|
4
|
+
|
|
5
|
+
Values taken from
|
|
6
|
+
https://lscsoft.docs.ligo.org/lalsuite/lal/_l_a_l_detectors_8h_source.html
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Literal
|
|
10
|
+
|
|
1
11
|
import torch
|
|
2
12
|
|
|
3
13
|
|
|
4
|
-
# based on values from
|
|
5
|
-
# https://lscsoft.docs.ligo.org/lalsuite/lal/_l_a_l_detectors_8h_source.html
|
|
6
14
|
class InterferometerGeometry:
|
|
7
|
-
|
|
15
|
+
"""
|
|
16
|
+
Contains geometric information for the LIGO, Virgo, and KAGRA
|
|
17
|
+
interferometers.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
name: The name of the interferometer. This should be either
|
|
21
|
+
'H1', 'L1', 'V1', or 'K1'.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
def __init__(self, name: Literal["H1", "L1", "V1", "K1"]) -> None:
|
|
8
25
|
if name == "H1":
|
|
9
26
|
self.x_arm = torch.Tensor(
|
|
10
27
|
(-0.22389266154, +0.79983062746, +0.55690487831)
|
ml4gw/utils/slicing.py
CHANGED
|
@@ -1,4 +1,9 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
This module contains functions for randomly sampling
|
|
3
|
+
windows of data from timeseries data, as well as for
|
|
4
|
+
unfolding timeseries data into potentially overlapping
|
|
5
|
+
windows.
|
|
6
|
+
"""
|
|
2
7
|
|
|
3
8
|
import torch
|
|
4
9
|
from jaxtyping import Float, Int64
|
|
@@ -7,7 +12,7 @@ from torch.nn.functional import unfold
|
|
|
7
12
|
|
|
8
13
|
from ..types import TimeSeries1d, TimeSeries1to3d, TimeSeries2d, TimeSeries3d
|
|
9
14
|
|
|
10
|
-
BatchTimeSeriesTensor =
|
|
15
|
+
BatchTimeSeriesTensor = Float[Tensor, "batch time"] | TimeSeries3d
|
|
11
16
|
|
|
12
17
|
|
|
13
18
|
def unfold_windows(
|
|
@@ -15,11 +20,11 @@ def unfold_windows(
|
|
|
15
20
|
window_size: int,
|
|
16
21
|
stride: int,
|
|
17
22
|
drop_last: bool = True,
|
|
18
|
-
) ->
|
|
19
|
-
Float[TimeSeries1d, " window"]
|
|
20
|
-
Float[TimeSeries2d, " window"]
|
|
21
|
-
Float[TimeSeries3d, " window"]
|
|
22
|
-
|
|
23
|
+
) -> (
|
|
24
|
+
Float[TimeSeries1d, " window"]
|
|
25
|
+
| Float[TimeSeries2d, " window"]
|
|
26
|
+
| Float[TimeSeries3d, " window"]
|
|
27
|
+
):
|
|
23
28
|
"""Unfold a timeseries into windows
|
|
24
29
|
|
|
25
30
|
Args:
|
|
@@ -171,10 +176,8 @@ def slice_kernels(
|
|
|
171
176
|
# to select _different_ kernels from each channel
|
|
172
177
|
if len(x) != idx.shape[1]:
|
|
173
178
|
raise ValueError(
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
"with shape {}"
|
|
177
|
-
).format(x.shape, idx.shape)
|
|
179
|
+
f"Can't slice array with shape {x.shape} with indices "
|
|
180
|
+
f"with shape {idx.shape}"
|
|
178
181
|
)
|
|
179
182
|
|
|
180
183
|
# batch_size x num_channels x kernel_size
|
|
@@ -204,8 +207,8 @@ def slice_kernels(
|
|
|
204
207
|
# of multichannel timeseries
|
|
205
208
|
if len(idx) != len(x):
|
|
206
209
|
raise ValueError(
|
|
207
|
-
"Can't slice kernels from batch of length {} "
|
|
208
|
-
"using indices of length {
|
|
210
|
+
f"Can't slice kernels from batch of length {len(x)} "
|
|
211
|
+
f"using indices of length {len(idx)}"
|
|
209
212
|
)
|
|
210
213
|
|
|
211
214
|
# batch_size x kernel_size
|
|
@@ -231,8 +234,8 @@ def slice_kernels(
|
|
|
231
234
|
def sample_kernels(
|
|
232
235
|
X: TimeSeries1to3d,
|
|
233
236
|
kernel_size: int,
|
|
234
|
-
N:
|
|
235
|
-
max_center_offset:
|
|
237
|
+
N: int | None = None,
|
|
238
|
+
max_center_offset: int | None = None,
|
|
236
239
|
coincident: bool = True,
|
|
237
240
|
) -> BatchTimeSeriesTensor:
|
|
238
241
|
"""Randomly sample kernels from a single or multichannel timeseries
|
|
@@ -286,9 +289,8 @@ def sample_kernels(
|
|
|
286
289
|
|
|
287
290
|
if X.shape[-1] < kernel_size:
|
|
288
291
|
raise ValueError(
|
|
289
|
-
"Can't sample kernels of size {} from tensor
|
|
290
|
-
|
|
291
|
-
)
|
|
292
|
+
f"Can't sample kernels of size {kernel_size} from tensor "
|
|
293
|
+
f"with shape {X.shape}"
|
|
292
294
|
)
|
|
293
295
|
elif X.ndim > 3:
|
|
294
296
|
raise ValueError(
|
|
@@ -300,10 +302,8 @@ def sample_kernels(
|
|
|
300
302
|
)
|
|
301
303
|
elif X.ndim == 3 and N is not None and N != len(X):
|
|
302
304
|
raise ValueError(
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
"batch dimension {}"
|
|
306
|
-
).format(N, len(X))
|
|
305
|
+
f"Can't sample {N} kernels from 3D tensor with "
|
|
306
|
+
f"batch dimension {len(X)}"
|
|
307
307
|
)
|
|
308
308
|
|
|
309
309
|
if X.ndim == 1:
|
|
@@ -334,20 +334,16 @@ def sample_kernels(
|
|
|
334
334
|
# the kernel length, we won't be able to sample
|
|
335
335
|
# any kernels at all
|
|
336
336
|
raise ValueError(
|
|
337
|
-
"Negative center offset value {} is too
|
|
338
|
-
"for requested kernel size {}"
|
|
339
|
-
max_center_offset, kernel_size
|
|
340
|
-
)
|
|
337
|
+
f"Negative center offset value {max_center_offset} is too "
|
|
338
|
+
f"large for requested kernel size {kernel_size}"
|
|
341
339
|
)
|
|
342
340
|
|
|
343
341
|
if min_val < 0:
|
|
344
342
|
# if kernel_size > center - max_center_offset,
|
|
345
343
|
# we may end up with negative indices
|
|
346
344
|
raise ValueError(
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
"offset value {}"
|
|
350
|
-
).format(kernel_size, max_center_offset)
|
|
345
|
+
f"Kernel size {kernel_size} is too large for requested center "
|
|
346
|
+
f"offset value {max_center_offset}"
|
|
351
347
|
)
|
|
352
348
|
|
|
353
349
|
if X.ndim == 3 or coincident:
|
ml4gw/waveforms/__init__.py
CHANGED
ml4gw/waveforms/cbc/phenom_p.py
CHANGED
|
@@ -3,8 +3,6 @@ Based on the JAX implementation of IMRPhenomPv2 from
|
|
|
3
3
|
https://github.com/tedwards2412/ripple/blob/main/src/ripplegw/waveforms/IMRPhenomPv2.py
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
-
from typing import Dict, Optional, Tuple
|
|
7
|
-
|
|
8
6
|
import torch
|
|
9
7
|
from jaxtyping import Float
|
|
10
8
|
from torch import Tensor
|
|
@@ -38,7 +36,7 @@ class IMRPhenomPv2(IMRPhenomD):
|
|
|
38
36
|
phic: BatchTensor,
|
|
39
37
|
inclination: BatchTensor,
|
|
40
38
|
f_ref: float,
|
|
41
|
-
tc:
|
|
39
|
+
tc: BatchTensor | None = None,
|
|
42
40
|
**kwargs,
|
|
43
41
|
):
|
|
44
42
|
"""
|
|
@@ -207,11 +205,11 @@ class IMRPhenomPv2(IMRPhenomD):
|
|
|
207
205
|
chi2_l: BatchTensor,
|
|
208
206
|
chip: BatchTensor,
|
|
209
207
|
M: BatchTensor,
|
|
210
|
-
angcoeffs:
|
|
208
|
+
angcoeffs: dict[str, BatchTensor],
|
|
211
209
|
Y2m: BatchTensor,
|
|
212
210
|
alphaoffset: BatchTensor,
|
|
213
211
|
epsilonoffset: BatchTensor,
|
|
214
|
-
) ->
|
|
212
|
+
) -> tuple[BatchTensor, BatchTensor]:
|
|
215
213
|
assert angcoeffs is not None
|
|
216
214
|
assert Y2m is not None
|
|
217
215
|
f = fHz * MTSUN_SI * M.unsqueeze(1) # Frequency in geometric units
|
|
@@ -461,7 +459,7 @@ class IMRPhenomPv2(IMRPhenomD):
|
|
|
461
459
|
s2x: BatchTensor,
|
|
462
460
|
s2y: BatchTensor,
|
|
463
461
|
s2z: BatchTensor,
|
|
464
|
-
) ->
|
|
462
|
+
) -> tuple[
|
|
465
463
|
BatchTensor,
|
|
466
464
|
BatchTensor,
|
|
467
465
|
BatchTensor,
|
|
@@ -634,7 +632,7 @@ class IMRPhenomPv2(IMRPhenomD):
|
|
|
634
632
|
SL: BatchTensor,
|
|
635
633
|
eta: BatchTensor,
|
|
636
634
|
Sp: BatchTensor,
|
|
637
|
-
) ->
|
|
635
|
+
) -> tuple[BatchTensor, BatchTensor]:
|
|
638
636
|
# We define the shorthand s := Sp / (L + SL)
|
|
639
637
|
L = self.L2PNR(v, eta)
|
|
640
638
|
s = (Sp / (L + SL)).mT
|
|
@@ -650,7 +648,7 @@ class IMRPhenomPv2(IMRPhenomD):
|
|
|
650
648
|
q: BatchTensor,
|
|
651
649
|
chil: BatchTensor,
|
|
652
650
|
chip: BatchTensor,
|
|
653
|
-
) ->
|
|
651
|
+
) -> dict[str, BatchTensor]:
|
|
654
652
|
m2 = q / (1.0 + q)
|
|
655
653
|
m1 = 1.0 / (1.0 + q)
|
|
656
654
|
dm = m1 - m2
|
|
@@ -796,7 +794,7 @@ class IMRPhenomPv2(IMRPhenomD):
|
|
|
796
794
|
|
|
797
795
|
def phP_get_fRD_fdamp(
|
|
798
796
|
self, m1, m2, chi1_l, chi2_l, chip
|
|
799
|
-
) ->
|
|
797
|
+
) -> tuple[BatchTensor, BatchTensor]:
|
|
800
798
|
# m1 > m2 should hold here
|
|
801
799
|
finspin = self.FinalSpin_inplane(m1, m2, chi1_l, chi2_l, chip)
|
|
802
800
|
m1_s = m1 * MTSUN_SI
|
ml4gw/waveforms/conversion.py
CHANGED
|
@@ -86,10 +86,8 @@ def bilby_spins_to_lalsim(
|
|
|
86
86
|
# check if f_ref is valid
|
|
87
87
|
if f_ref <= 0.0:
|
|
88
88
|
raise ValueError(
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
"Please pass in the starting GW frequency instead."
|
|
92
|
-
)
|
|
89
|
+
"f_ref <= 0 is invalid. "
|
|
90
|
+
"Please pass in the starting GW frequency instead."
|
|
93
91
|
)
|
|
94
92
|
|
|
95
93
|
# starting frame: LNhat is along the z-axis and the unit
|
ml4gw/waveforms/generator.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import math
|
|
2
|
-
from
|
|
2
|
+
from collections.abc import Callable
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import torch
|
|
@@ -115,7 +115,7 @@ class TimeDomainCBCWaveformGenerator(torch.nn.Module):
|
|
|
115
115
|
|
|
116
116
|
def generate_conditioned_fd_waveform(
|
|
117
117
|
self, **parameters: dict[str, BatchTensor]
|
|
118
|
-
) ->
|
|
118
|
+
) -> tuple[Float[Tensor, "{N} samples"], Float[Tensor, "{N} samples"]]:
|
|
119
119
|
"""
|
|
120
120
|
Generate a conditioned frequency domain waveform from a
|
|
121
121
|
frequency-domain approximant.
|
|
@@ -271,7 +271,7 @@ class TimeDomainCBCWaveformGenerator(torch.nn.Module):
|
|
|
271
271
|
def forward(
|
|
272
272
|
self,
|
|
273
273
|
**parameters,
|
|
274
|
-
) ->
|
|
274
|
+
) -> tuple[Float[Tensor, "{N} samples"], Float[Tensor, "{N} samples"]]:
|
|
275
275
|
"""
|
|
276
276
|
Generates a time-domain waveform from a frequency-domain approximant.
|
|
277
277
|
Conditioning is based on https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/python/lalsimulation/gwsignal/core/waveform_conditioning.py?ref_type=heads#L248
|
|
@@ -1,22 +1,24 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ml4gw
|
|
3
|
-
Version: 0.7.
|
|
3
|
+
Version: 0.7.8
|
|
4
4
|
Summary: Tools for training torch models on gravitational wave data
|
|
5
|
-
Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
5
|
+
Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>, Ravi Kumar <ravi.kumar@ligo.org>
|
|
6
|
+
Maintainer-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>
|
|
7
|
+
License-Expression: GPL-3.0-or-later
|
|
9
8
|
Classifier: Programming Language :: Python :: 3.10
|
|
10
9
|
Classifier: Programming Language :: Python :: 3.11
|
|
11
10
|
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
-
Classifier:
|
|
13
|
-
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering :: Astronomy
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Physics
|
|
13
|
+
Requires-Python: <3.13,>=3.10
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
14
16
|
Requires-Dist: jaxtyping<0.3,>=0.2
|
|
17
|
+
Requires-Dist: torch~=2.0
|
|
18
|
+
Requires-Dist: torchaudio~=2.0
|
|
15
19
|
Requires-Dist: numpy<2.0.0
|
|
16
20
|
Requires-Dist: scipy<1.15,>=1.9.0
|
|
17
|
-
|
|
18
|
-
Requires-Dist: torch~=2.0
|
|
19
|
-
Description-Content-Type: text/markdown
|
|
21
|
+
Dynamic: license-file
|
|
20
22
|
|
|
21
23
|
# ML4GW
|
|
22
24
|

|
|
@@ -29,7 +31,10 @@ Torch utilities for training neural networks in gravitational wave physics appli
|
|
|
29
31
|
|
|
30
32
|
## Documentation
|
|
31
33
|
Please visit our [documentation page](https://ml4gw.github.io/ml4gw/) to see descriptions and examples of the functions and modules available in `ml4gw`.
|
|
32
|
-
We also have an interactive Jupyter notebook
|
|
34
|
+
We also have an interactive Jupyter notebook demonstrating much of the core functionality available [here](https://github.com/ML4GW/ml4gw/blob/main/docs/tutorials/ml4gw_tutorial.ipynb).
|
|
35
|
+
To run this notebook, download it from the above link and follow the instructions within it to install the required packages.
|
|
36
|
+
See also the tutorial's [documentation page](https://ml4gw.github.io/ml4gw/tutorials/ml4gw_tutorial.html) if you would like to look
|
|
37
|
+
through it without running the code.
|
|
33
38
|
|
|
34
39
|
## Installation
|
|
35
40
|
### Pip installation
|
|
@@ -45,9 +50,25 @@ To build with a specific version of PyTorch/CUDA, please see the PyTorch install
|
|
|
45
50
|
pip install ml4gw torch==2.5.1 --extra-index-url=https://download.pytorch.org/whl/cu118
|
|
46
51
|
```
|
|
47
52
|
|
|
53
|
+
### uv installation
|
|
54
|
+
If you want to develop `ml4gw`, you can use [uv](https://docs.astral.sh/uv/getting-started/installation/) to install the project in editable mode.
|
|
55
|
+
For example, after cloning the repository, create a virtualenv using
|
|
56
|
+
```bash
|
|
57
|
+
uv venv --python=3.11
|
|
58
|
+
```
|
|
59
|
+
Then sync the dependencies from the [uv lock file](/uv.lock) using
|
|
60
|
+
```bash
|
|
61
|
+
uv sync --all-extras
|
|
62
|
+
```
|
|
63
|
+
Code changes can be tested using
|
|
64
|
+
```bash
|
|
65
|
+
uv run pytest
|
|
66
|
+
```
|
|
67
|
+
See [contribution guide](/CONTRIBUTING.md) for more details.
|
|
68
|
+
|
|
48
69
|
## Contributing
|
|
49
70
|
If you come across errors in the code, have difficulties using this software, or simply find that the current version doesn't cover your use case, please file an issue on our GitHub page, and we'll be happy to offer support.
|
|
50
|
-
|
|
71
|
+
If you want to add a feature, please refer to the [contribution guide](/CONTRIBUTING.md) for more details.
|
|
51
72
|
We also strongly encourage ML users in the GW physics space to try their hand at working on these issues and joining on as collaborators!
|
|
52
73
|
For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu).
|
|
53
74
|
By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool that makes deep learning more accessible for gravitational wave physicists everywhere.
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
ml4gw/__init__.py,sha256=81quoggCuIypZjZs3bbf1Ty70KHdva5RGEJxi0oC57E,25
|
|
2
|
+
ml4gw/augmentations.py,sha256=Jck8FVtKM18evUESIRrJC-WFk8amhTdqD4kM74HjlGI,1382
|
|
3
|
+
ml4gw/constants.py,sha256=RQPXwavlw_cWu3ByltvTejPsi6EWXHDJQ1HaV9iE3Lg,850
|
|
4
|
+
ml4gw/distributions.py,sha256=zyE0Sk1cpBqmJ3oumWnJ9FCcIszz9yQDguAvbcrKcn4,12828
|
|
5
|
+
ml4gw/gw.py,sha256=L4WgicskE_Nn6UQMDG1C0TNquXC7D3A84yd_I54e9y0,20181
|
|
6
|
+
ml4gw/spectral.py,sha256=lT08wBEfCOnmdYDf5mwUqqFF-rrf-FZFx6jmpLeh63U,19897
|
|
7
|
+
ml4gw/types.py,sha256=weerxUI-PEWWvzGtzwY_A9AH_m_I1p-rGWvKbiZle-0,919
|
|
8
|
+
ml4gw/dataloading/__init__.py,sha256=Qrfq3yPGMy8-1gAMyo81MuKGK5gwp7y6Gx8gd4suEjc,240
|
|
9
|
+
ml4gw/dataloading/chunked_dataset.py,sha256=CZ40OTjuRx_v0nzLrApybIj40zCkdhP5xsGBnqGArB8,5255
|
|
10
|
+
ml4gw/dataloading/hdf5_dataset.py,sha256=72RfdKXCJiZShbRD0nqbuQFDdnVitYtkF-G2LBjG_fc,7934
|
|
11
|
+
ml4gw/dataloading/in_memory_dataset.py,sha256=wbdWID76IqhD01u-XlD-QPRnWPSVmkSQTyHR2lyy_NA,9559
|
|
12
|
+
ml4gw/nn/__init__.py,sha256=Vn8CqewYAK6GD-gOvsR7TJK6I568soXLpp9Ga9sYqkY,188
|
|
13
|
+
ml4gw/nn/norm.py,sha256=zd8NcjrtqM4yFyHFmDkknuV623NA5Cj0o6jBdPv6xh0,3584
|
|
14
|
+
ml4gw/nn/autoencoder/__init__.py,sha256=ZaT1XhJTHpMuPQqu5E__Jezeh9uwtjcXlT7IZ18byq4,161
|
|
15
|
+
ml4gw/nn/autoencoder/base.py,sha256=rsb04p5m1MsDt8Z_CxfE4jVUHjW2N9iIjDWNJnwrIIk,3126
|
|
16
|
+
ml4gw/nn/autoencoder/convolutional.py,sha256=npdoiJD1jyHXJwyT-VEzDO5dgdqEVaaLcn6BOHB9u3A,5302
|
|
17
|
+
ml4gw/nn/autoencoder/skip_connection.py,sha256=jK8dYxSjkEhmwxYyOWjrzRTOO_uc8CH_EUJQk-BtSHU,1357
|
|
18
|
+
ml4gw/nn/autoencoder/utils.py,sha256=m_ivYGNwdrhA7cFxJVD4gqM8AHiWIGmlQI3pFNRklXQ,355
|
|
19
|
+
ml4gw/nn/resnet/__init__.py,sha256=vBI0IftVP_EYAeDlqomtkGqUYE-RE_S4WNioUhniw9s,64
|
|
20
|
+
ml4gw/nn/resnet/resnet_1d.py,sha256=uEa-Dz7MIe6udcRn4k1vLsiSiZMEhdBlypA4mN7KrGc,13221
|
|
21
|
+
ml4gw/nn/resnet/resnet_2d.py,sha256=YiHxP3cNIfjOrEmKSVqZYOUxoVnIkpDdwCW9VieNM7E,13319
|
|
22
|
+
ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
|
|
23
|
+
ml4gw/nn/streaming/online_average.py,sha256=22jQ_JbJTpusmqeGdo7Ta7lTsGoTBjYtKZnXzucW3wc,4676
|
|
24
|
+
ml4gw/nn/streaming/snapshotter.py,sha256=kH73np-LUGF0ZP-tkWY19TrCJa3m1RIvvZ-SmLA7YvM,4378
|
|
25
|
+
ml4gw/transforms/__init__.py,sha256=0pnZhiW1Yz9ozhMlmZrMyPjRPQVZ9hUn0nIPeIIRpAk,632
|
|
26
|
+
ml4gw/transforms/decimator.py,sha256=6iuRyordzteWp9fE27zdrEfJY_sTmIe5sTSfs4xTL1U,6137
|
|
27
|
+
ml4gw/transforms/iirfilter.py,sha256=T4qgrJeA3vPeVWyZ-bPBOxQkJL0yfaUVoU0MTtpwhvg,3152
|
|
28
|
+
ml4gw/transforms/pearson.py,sha256=buPEfjaPJkMtdnBP5tvFzRzId8DbThAfJcGzaYmANqc,3202
|
|
29
|
+
ml4gw/transforms/qtransform.py,sha256=a7hqZ4pq9J6pq8L3Dm4Dqyxz-QyaznoNZO4mTdb5apY,20616
|
|
30
|
+
ml4gw/transforms/scaler.py,sha256=IqMekxqoyjeYWIuDJynkAk_MLs6l6G_TIY0mUE9h5Vo,2475
|
|
31
|
+
ml4gw/transforms/snr_rescaler.py,sha256=32__g-OI_6TCY9J6h3sCM0tbyzBx5vdbLQQ2vvbQtLM,2637
|
|
32
|
+
ml4gw/transforms/spectral.py,sha256=CGZOA0D-Oa7x3bVtfZsrwFWqEKGisMx2gTtHO-5wLcQ,4313
|
|
33
|
+
ml4gw/transforms/spectrogram.py,sha256=TGG_fVng-Y569KsIQAaf5_WN-W4-6F89oQSyHFxVo_w,6424
|
|
34
|
+
ml4gw/transforms/spline_interpolation.py,sha256=QHBp5g_1_lOYmCFJqQyZAbPXBZaqzb5_LBWeLY6ppkI,18614
|
|
35
|
+
ml4gw/transforms/transform.py,sha256=_jAxsCnLmIo97g4b2J8WKS0Omy-yyOzNt6lFHEM5ESM,2463
|
|
36
|
+
ml4gw/transforms/waveforms.py,sha256=yFOzGlYjjM488oYxZLpikS9noZCnisVM0_gjgWoF-_E,3018
|
|
37
|
+
ml4gw/transforms/whitening.py,sha256=65B6V_ERh_Pc9zghRPlMQSacxYg2bPunTDTswAW7GOk,10265
|
|
38
|
+
ml4gw/utils/interferometer.py,sha256=lRGtMRFSco1mI1Y1O2kz4dRp5hmK5cp4nG7sYbAiYG4,2179
|
|
39
|
+
ml4gw/utils/slicing.py,sha256=kaO54GMUV8d7vIt3oodVZJ4jFwJMEK8aGknJ8UOJzRs,13667
|
|
40
|
+
ml4gw/waveforms/__init__.py,sha256=SxTc6rSkQfoOtEgNYvA-8tMJsQQQROTRRKaFDRQOmh4,172
|
|
41
|
+
ml4gw/waveforms/conversion.py,sha256=RsJwJ_aZfIYzpGifzdvrAq5rg7dQbgUJDyfZuX_9uYI,6912
|
|
42
|
+
ml4gw/waveforms/generator.py,sha256=P8pHOr5r-Egz5EMJzdBhRNn1NTBAvdqWh2BmJtoVxJw,12300
|
|
43
|
+
ml4gw/waveforms/adhoc/__init__.py,sha256=XVwP4t8TMUj87WY3yMGRTkXsv7_lVr1w8p8iKBW8iKE,71
|
|
44
|
+
ml4gw/waveforms/adhoc/ringdown.py,sha256=m8IBQTxKBBGFqBtWGEO4KG3DEYR8TTnNyGVdVLaMKa8,3316
|
|
45
|
+
ml4gw/waveforms/adhoc/sine_gaussian.py,sha256=-MtrI7ydwBTk4K0O4tdkC8-w5OifQszdnWN9__I4XzY,3569
|
|
46
|
+
ml4gw/waveforms/cbc/__init__.py,sha256=hGbPsFNAIveYJnff8qKY8RWeBPFtZoYcnGHxraPWtWI,99
|
|
47
|
+
ml4gw/waveforms/cbc/coefficients.py,sha256=PMr0IBALEQ38eAvZqYg-w_FE_sS1mH2FWr9soQ5MRfU,1106
|
|
48
|
+
ml4gw/waveforms/cbc/phenom_d.py,sha256=FS4XBbhCicqYqaZnb3itqBZrjFex6wNoFMEMfClsW68,46908
|
|
49
|
+
ml4gw/waveforms/cbc/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9n25he9U,53447
|
|
50
|
+
ml4gw/waveforms/cbc/phenom_p.py,sha256=m81Xt_zIffHiGlWgzf-AmI46mSn6CFZAX-6Fwwr5Tfk,27635
|
|
51
|
+
ml4gw/waveforms/cbc/taylorf2.py,sha256=emWbl3vjsCzBOooHOVO7pPlPcj05r4up6InlMkO5m_E,10422
|
|
52
|
+
ml4gw/waveforms/cbc/utils.py,sha256=LT1ky10_6ZrbwTcxIrWP1O75GUEuU5q2ZE2yYDhadQE,3037
|
|
53
|
+
ml4gw-0.7.8.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
54
|
+
ml4gw-0.7.8.dist-info/METADATA,sha256=ZURGSlU7-WODIwI7GckMtDhzWkFxxsJ31_F14Gk23UU,4282
|
|
55
|
+
ml4gw-0.7.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
56
|
+
ml4gw-0.7.8.dist-info/top_level.txt,sha256=JnWLyPXJ3_WUcjr6fRV0ZTXj8FR0x4vBzjkg-1bl2tw,6
|
|
57
|
+
ml4gw-0.7.8.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
ml4gw
|
ml4gw-0.7.6.dist-info/RECORD
DELETED
|
@@ -1,55 +0,0 @@
|
|
|
1
|
-
ml4gw/__init__.py,sha256=81quoggCuIypZjZs3bbf1Ty70KHdva5RGEJxi0oC57E,25
|
|
2
|
-
ml4gw/augmentations.py,sha256=4tSWO-I4Eg2QJWzdcLFg9QcOLlvRjNHvnjLCZS8K-Wc,1270
|
|
3
|
-
ml4gw/constants.py,sha256=RQPXwavlw_cWu3ByltvTejPsi6EWXHDJQ1HaV9iE3Lg,850
|
|
4
|
-
ml4gw/distributions.py,sha256=T6H1r5IMWyO38Uyb-BpmYx0AcokWN_ZJHGo-G_20m6w,12830
|
|
5
|
-
ml4gw/gw.py,sha256=bJ-GCZxanqrhbm373h9muOSZpam7wM-dJBZroy_pVNQ,20291
|
|
6
|
-
ml4gw/spectral.py,sha256=Mx_zRjZ9tD7N-wknv35oA3fk2X0rDJxQdQzRyuCFryw,19982
|
|
7
|
-
ml4gw/types.py,sha256=CcctqDcNajR7khGT6BD-WYsfRKpiP0udoSAB0k1qcFw,863
|
|
8
|
-
ml4gw/dataloading/__init__.py,sha256=EHBBqU7y2-Np5iQ_xyufxamUEM1pPEquqFo7oaJnaJE,149
|
|
9
|
-
ml4gw/dataloading/chunked_dataset.py,sha256=exvhC0zbEkd3SnDidClQRhxY713cY68wQmEQ__3vRLI,5316
|
|
10
|
-
ml4gw/dataloading/hdf5_dataset.py,sha256=pKf_0UmwZ3UPOCDrwCuFxpcLbMihU2AKpjT_igmv87k,7935
|
|
11
|
-
ml4gw/dataloading/in_memory_dataset.py,sha256=7eDHq365XXBy1NywU72FOdHxSksK7UZYHFc3kvhNp8c,9597
|
|
12
|
-
ml4gw/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
-
ml4gw/nn/norm.py,sha256=JIOMXQbUtoWlrhncGsqW6f1-DiGDx9zQH2O3CvQml3U,3594
|
|
14
|
-
ml4gw/nn/autoencoder/__init__.py,sha256=ZaT1XhJTHpMuPQqu5E__Jezeh9uwtjcXlT7IZ18byq4,161
|
|
15
|
-
ml4gw/nn/autoencoder/base.py,sha256=RfDEeD6Ni1EJtQni5JTrLOWPr-zzDSWztO-DijKCgLQ,3226
|
|
16
|
-
ml4gw/nn/autoencoder/convolutional.py,sha256=tSFIgIZ0XUvdUaPQ-vJ2AeJ9g5FrwcBy_Py5IP6ezRw,5365
|
|
17
|
-
ml4gw/nn/autoencoder/skip_connection.py,sha256=9PKoCCvCUj5di9tuFM0Cl1v6gtcOK1bDeE_fS_R__FE,1391
|
|
18
|
-
ml4gw/nn/autoencoder/utils.py,sha256=m_ivYGNwdrhA7cFxJVD4gqM8AHiWIGmlQI3pFNRklXQ,355
|
|
19
|
-
ml4gw/nn/resnet/__init__.py,sha256=vBI0IftVP_EYAeDlqomtkGqUYE-RE_S4WNioUhniw9s,64
|
|
20
|
-
ml4gw/nn/resnet/resnet_1d.py,sha256=dNWD8vGBiwF8qbQfjNEXDIEUgWRgrEoq__XUU-XN2uA,13268
|
|
21
|
-
ml4gw/nn/resnet/resnet_2d.py,sha256=MAbXtkSrP4aWGtY-QC8ox3-y5jDHJrzRPL5ryQ4RBvM,13367
|
|
22
|
-
ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
|
|
23
|
-
ml4gw/nn/streaming/online_average.py,sha256=YSFUHhwNfQjUJbzQCqaCVApSueswzYB4yel981Omiqw,4718
|
|
24
|
-
ml4gw/nn/streaming/snapshotter.py,sha256=vEQLFi-fEH-o7TO9SmYXy5whxFxXQBDeOQOFhSnofSg,4503
|
|
25
|
-
ml4gw/transforms/__init__.py,sha256=OaTQJD4GFkDkcxt0DIwt2AzeEcv9t21ciKXxQnqDiuI,447
|
|
26
|
-
ml4gw/transforms/iirfilter.py,sha256=HcdsjcSaSi2xe65ojxnaqeSdbYvSQVFIkHKon3nW238,3194
|
|
27
|
-
ml4gw/transforms/pearson.py,sha256=sFyHD6IdskbRS8V1fY0Kt9N8R2_EhnuL6UjFa6fnmTU,3244
|
|
28
|
-
ml4gw/transforms/qtransform.py,sha256=dXE3Genxgg3UdQ5dM-FfcvbX--UGpr0hjX9sO5tpM7k,20754
|
|
29
|
-
ml4gw/transforms/scaler.py,sha256=BKn4RQ_TNArdwPI_j5nAe7H2jOH_-MrZPsNByE-8Pl8,2518
|
|
30
|
-
ml4gw/transforms/snr_rescaler.py,sha256=lfuwdwMY117gB-emmn0_22gsK_A9xnkHJv2-76HFWc4,2728
|
|
31
|
-
ml4gw/transforms/spectral.py,sha256=ebAuPSdQqha6J3MMzxqJqR31XPKUDrSz3iJaHM3orpk,4449
|
|
32
|
-
ml4gw/transforms/spectrogram.py,sha256=NIyTD8kZRe8rjMUTy1_-wpFyvAswzTfYwD4TJJcPqgs,6369
|
|
33
|
-
ml4gw/transforms/spline_interpolation.py,sha256=iz6CkRzAYFSMjRTLFJAetE5FAI6WmrpfKzMPK4sueNQ,13320
|
|
34
|
-
ml4gw/transforms/transform.py,sha256=lpHQbM4PhdijvNBsZigPX-mS04aiVVq5q3HMfxvpFg0,2506
|
|
35
|
-
ml4gw/transforms/waveforms.py,sha256=koWOuHuUpQWmTT1yawSWa_MOuLfDBuugy91KIyuklOo,3189
|
|
36
|
-
ml4gw/transforms/whitening.py,sha256=UyFustRhu3zv0ynJBvvxekWA-YOMwEIOYDNpoD5r_qQ,10400
|
|
37
|
-
ml4gw/utils/interferometer.py,sha256=lRS0N3SwUTknhYXX57VACJ99jK1P9M19oUWN_i_nQN0,1814
|
|
38
|
-
ml4gw/utils/slicing.py,sha256=kQ0xIW5Ojko4uKS1VI5i7PMUk7Fk81dT6p6tuQ0nyBI,13763
|
|
39
|
-
ml4gw/waveforms/__init__.py,sha256=QVUzBx_y8A9_AsRuTJruPvL9mqGnBt11Iw1MOYjXyE4,40
|
|
40
|
-
ml4gw/waveforms/conversion.py,sha256=vF8u_4FWwXAXJEtWZ_N0GhbEnt6snsyW-9fasGLaCok,6948
|
|
41
|
-
ml4gw/waveforms/generator.py,sha256=Ml23ZoxyN4FDYty5the13rQ_HO4bnnvxInORKJqMkBk,12298
|
|
42
|
-
ml4gw/waveforms/adhoc/__init__.py,sha256=XVwP4t8TMUj87WY3yMGRTkXsv7_lVr1w8p8iKBW8iKE,71
|
|
43
|
-
ml4gw/waveforms/adhoc/ringdown.py,sha256=m8IBQTxKBBGFqBtWGEO4KG3DEYR8TTnNyGVdVLaMKa8,3316
|
|
44
|
-
ml4gw/waveforms/adhoc/sine_gaussian.py,sha256=-MtrI7ydwBTk4K0O4tdkC8-w5OifQszdnWN9__I4XzY,3569
|
|
45
|
-
ml4gw/waveforms/cbc/__init__.py,sha256=hGbPsFNAIveYJnff8qKY8RWeBPFtZoYcnGHxraPWtWI,99
|
|
46
|
-
ml4gw/waveforms/cbc/coefficients.py,sha256=PMr0IBALEQ38eAvZqYg-w_FE_sS1mH2FWr9soQ5MRfU,1106
|
|
47
|
-
ml4gw/waveforms/cbc/phenom_d.py,sha256=FS4XBbhCicqYqaZnb3itqBZrjFex6wNoFMEMfClsW68,46908
|
|
48
|
-
ml4gw/waveforms/cbc/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9n25he9U,53447
|
|
49
|
-
ml4gw/waveforms/cbc/phenom_p.py,sha256=tOUBoYfr0ub6OGRjDQbquGoW8AnThiGjJvbHhyGnAnk,27680
|
|
50
|
-
ml4gw/waveforms/cbc/taylorf2.py,sha256=emWbl3vjsCzBOooHOVO7pPlPcj05r4up6InlMkO5m_E,10422
|
|
51
|
-
ml4gw/waveforms/cbc/utils.py,sha256=LT1ky10_6ZrbwTcxIrWP1O75GUEuU5q2ZE2yYDhadQE,3037
|
|
52
|
-
ml4gw-0.7.6.dist-info/METADATA,sha256=dI3qI2Kk4p-XP_hPs7QWPfgjzRGQMcIva-ST6mBdA0A,3380
|
|
53
|
-
ml4gw-0.7.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
54
|
-
ml4gw-0.7.6.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
55
|
-
ml4gw-0.7.6.dist-info/RECORD,,
|
|
File without changes
|