ml4gw-0.5.0-py3-none-any.whl → ml4gw-0.5.1-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/augmentations.py +8 -2
- ml4gw/dataloading/chunked_dataset.py +4 -2
- ml4gw/dataloading/hdf5_dataset.py +1 -1
- ml4gw/dataloading/in_memory_dataset.py +8 -4
- ml4gw/distributions.py +5 -3
- ml4gw/gw.py +21 -27
- ml4gw/nn/autoencoder/base.py +11 -6
- ml4gw/nn/autoencoder/convolutional.py +7 -4
- ml4gw/nn/autoencoder/skip_connection.py +7 -6
- ml4gw/nn/autoencoder/utils.py +2 -1
- ml4gw/nn/norm.py +5 -1
- ml4gw/nn/streaming/online_average.py +7 -5
- ml4gw/nn/streaming/snapshotter.py +7 -5
- ml4gw/spectral.py +40 -36
- ml4gw/transforms/pearson.py +7 -3
- ml4gw/transforms/qtransform.py +20 -14
- ml4gw/transforms/scaler.py +6 -2
- ml4gw/transforms/snr_rescaler.py +6 -5
- ml4gw/transforms/spectral.py +9 -2
- ml4gw/transforms/spectrogram.py +7 -1
- ml4gw/transforms/transform.py +4 -3
- ml4gw/transforms/waveforms.py +10 -7
- ml4gw/transforms/whitening.py +12 -4
- ml4gw/types.py +25 -10
- ml4gw/utils/interferometer.py +1 -1
- ml4gw/utils/slicing.py +24 -16
- ml4gw/waveforms/generator.py +9 -5
- ml4gw/waveforms/phenom_d.py +20 -18
- ml4gw/waveforms/phenom_p.py +77 -60
- ml4gw/waveforms/ringdown.py +8 -9
- ml4gw/waveforms/sine_gaussian.py +6 -6
- ml4gw/waveforms/taylorf2.py +33 -27
- {ml4gw-0.5.0.dist-info → ml4gw-0.5.1.dist-info}/METADATA +4 -3
- ml4gw-0.5.1.dist-info/RECORD +47 -0
- ml4gw-0.5.0.dist-info/RECORD +0 -47
- {ml4gw-0.5.0.dist-info → ml4gw-0.5.1.dist-info}/WHEEL +0 -0
ml4gw/transforms/pearson.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from jaxtyping import Float
|
|
3
|
+
from torch import Tensor
|
|
2
4
|
|
|
5
|
+
from ml4gw.types import TimeSeries1to3d
|
|
3
6
|
from ml4gw.utils.slicing import unfold_windows
|
|
4
7
|
|
|
5
8
|
|
|
@@ -40,7 +43,7 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
|
40
43
|
super().__init__()
|
|
41
44
|
self.max_shift = max_shift
|
|
42
45
|
|
|
43
|
-
def _shape_checks(self, x:
|
|
46
|
+
def _shape_checks(self, x: TimeSeries1to3d, y: TimeSeries1to3d):
|
|
44
47
|
if x.ndim > 3:
|
|
45
48
|
raise ValueError(
|
|
46
49
|
"Tensor x can only have up to 3 dimensions "
|
|
@@ -61,8 +64,9 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
|
61
64
|
)
|
|
62
65
|
)
|
|
63
66
|
|
|
64
|
-
|
|
65
|
-
|
|
67
|
+
def forward(
|
|
68
|
+
self, x: TimeSeries1to3d, y: TimeSeries1to3d
|
|
69
|
+
) -> Float[Tensor, "windows ..."]:
|
|
66
70
|
self._shape_checks(x, y)
|
|
67
71
|
dim = x.size(-1)
|
|
68
72
|
|
ml4gw/transforms/qtransform.py
CHANGED
|
@@ -3,6 +3,10 @@ from typing import List, Optional, Tuple
|
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
import torch.nn.functional as F
|
|
6
|
+
from jaxtyping import Float, Int
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
|
|
6
10
|
|
|
7
11
|
"""
|
|
8
12
|
All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
|
|
@@ -44,7 +48,7 @@ class QTile(torch.nn.Module):
|
|
|
44
48
|
duration: float,
|
|
45
49
|
sample_rate: float,
|
|
46
50
|
mismatch: float,
|
|
47
|
-
):
|
|
51
|
+
) -> None:
|
|
48
52
|
super().__init__()
|
|
49
53
|
self.mismatch = mismatch
|
|
50
54
|
self.q = q
|
|
@@ -63,18 +67,18 @@ class QTile(torch.nn.Module):
|
|
|
63
67
|
self.register_buffer("indices", self.get_data_indices())
|
|
64
68
|
self.register_buffer("window", self.get_window())
|
|
65
69
|
|
|
66
|
-
def ntiles(self):
|
|
70
|
+
def ntiles(self) -> int:
|
|
67
71
|
"""
|
|
68
72
|
Number of tiles in this frequency row
|
|
69
73
|
"""
|
|
70
74
|
tcum_mismatch = self.duration * 2 * torch.pi * self.frequency / self.q
|
|
71
75
|
return int(2 ** torch.ceil(torch.log2(tcum_mismatch / self.deltam)))
|
|
72
76
|
|
|
73
|
-
def _get_indices(self):
|
|
77
|
+
def _get_indices(self) -> Int[Tensor, " windowsize"]:
|
|
74
78
|
half = int((self.windowsize - 1) / 2)
|
|
75
79
|
return torch.arange(-half, half + 1)
|
|
76
80
|
|
|
77
|
-
def get_window(self):
|
|
81
|
+
def get_window(self) -> Float[Tensor, " windowsize"]:
|
|
78
82
|
"""
|
|
79
83
|
Generate the bi-square window for this row
|
|
80
84
|
"""
|
|
@@ -87,7 +91,7 @@ class QTile(torch.nn.Module):
|
|
|
87
91
|
)
|
|
88
92
|
return torch.Tensor((1 - xfrequencies**2) ** 2 * norm)
|
|
89
93
|
|
|
90
|
-
def get_data_indices(self):
|
|
94
|
+
def get_data_indices(self) -> Int[Tensor, " windowsize"]:
|
|
91
95
|
"""
|
|
92
96
|
Get the index array of relevant frequencies for this row
|
|
93
97
|
"""
|
|
@@ -95,7 +99,9 @@ class QTile(torch.nn.Module):
|
|
|
95
99
|
self._get_indices() + 1 + self.frequency * self.duration,
|
|
96
100
|
).type(torch.long)
|
|
97
101
|
|
|
98
|
-
def forward(
|
|
102
|
+
def forward(
|
|
103
|
+
self, fseries: FrequencySeries1to3d, norm: str = "median"
|
|
104
|
+
) -> TimeSeries1to3d:
|
|
99
105
|
"""
|
|
100
106
|
Compute the transform for this row
|
|
101
107
|
|
|
@@ -176,7 +182,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
176
182
|
q: float = 12,
|
|
177
183
|
frange: List[float] = [0, torch.inf],
|
|
178
184
|
mismatch: float = 0.2,
|
|
179
|
-
):
|
|
185
|
+
) -> None:
|
|
180
186
|
super().__init__()
|
|
181
187
|
self.q = q
|
|
182
188
|
self.spectrogram_shape = spectrogram_shape
|
|
@@ -198,7 +204,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
198
204
|
)
|
|
199
205
|
self.qtiles = None
|
|
200
206
|
|
|
201
|
-
def get_freqs(self):
|
|
207
|
+
def get_freqs(self) -> Float[Tensor, " nfreq"]:
|
|
202
208
|
"""
|
|
203
209
|
Calculate the frequencies that will be used in this transform.
|
|
204
210
|
For each frequency, a `QTile` is created.
|
|
@@ -262,7 +268,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
262
268
|
if dimension == "batch":
|
|
263
269
|
return torch.max(max_across_ft, dim=-1).values
|
|
264
270
|
|
|
265
|
-
def compute_qtiles(self, X:
|
|
271
|
+
def compute_qtiles(self, X: TimeSeries1to3d, norm: str = "median") -> None:
|
|
266
272
|
"""
|
|
267
273
|
Take the FFT of the input timeseries and calculate the transform
|
|
268
274
|
for each `QTile`
|
|
@@ -272,7 +278,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
272
278
|
X[..., 1:] *= 2
|
|
273
279
|
self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]
|
|
274
280
|
|
|
275
|
-
def interpolate(self, num_f_bins: int, num_t_bins: int):
|
|
281
|
+
def interpolate(self, num_f_bins: int, num_t_bins: int) -> TimeSeries3d:
|
|
276
282
|
"""
|
|
277
283
|
Interpolate each `QTile` to the specified number of time and
|
|
278
284
|
frequency bins. Note that PyTorch does not have the same
|
|
@@ -299,7 +305,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
299
305
|
|
|
300
306
|
def forward(
|
|
301
307
|
self,
|
|
302
|
-
X:
|
|
308
|
+
X: TimeSeries1to3d,
|
|
303
309
|
norm: str = "median",
|
|
304
310
|
spectrogram_shape: Optional[Tuple[int, int]] = None,
|
|
305
311
|
):
|
|
@@ -371,7 +377,7 @@ class QScan(torch.nn.Module):
|
|
|
371
377
|
qrange: List[float] = [4, 64],
|
|
372
378
|
frange: List[float] = [0, torch.inf],
|
|
373
379
|
mismatch: float = 0.2,
|
|
374
|
-
):
|
|
380
|
+
) -> None:
|
|
375
381
|
super().__init__()
|
|
376
382
|
self.qrange = qrange
|
|
377
383
|
self.mismatch = mismatch
|
|
@@ -397,7 +403,7 @@ class QScan(torch.nn.Module):
|
|
|
397
403
|
]
|
|
398
404
|
)
|
|
399
405
|
|
|
400
|
-
def get_qs(self):
|
|
406
|
+
def get_qs(self) -> List[float]:
|
|
401
407
|
"""
|
|
402
408
|
Determine the values of Q to try for the set of Q-transforms
|
|
403
409
|
"""
|
|
@@ -413,7 +419,7 @@ class QScan(torch.nn.Module):
|
|
|
413
419
|
|
|
414
420
|
def forward(
|
|
415
421
|
self,
|
|
416
|
-
X:
|
|
422
|
+
X: TimeSeries1to3d,
|
|
417
423
|
fsearch_range: List[float] = None,
|
|
418
424
|
norm: str = "median",
|
|
419
425
|
spectrogram_shape: Optional[Tuple[int, int]] = None,
|
ml4gw/transforms/scaler.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
from ml4gw.transforms.transform import FittableTransform
|
|
6
8
|
|
|
@@ -34,7 +36,7 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
34
36
|
self.register_buffer("mean", mean)
|
|
35
37
|
self.register_buffer("std", std)
|
|
36
38
|
|
|
37
|
-
def fit(self, X:
|
|
39
|
+
def fit(self, X: Float[Tensor, "... time"]) -> None:
|
|
38
40
|
"""Fit the scaling parameters to a timeseries
|
|
39
41
|
|
|
40
42
|
Computes the channel-wise mean and standard deviation
|
|
@@ -60,7 +62,9 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
60
62
|
|
|
61
63
|
super().build(mean=mean, std=std)
|
|
62
64
|
|
|
63
|
-
def forward(
|
|
65
|
+
def forward(
|
|
66
|
+
self, X: Float[Tensor, "... time"], reverse: bool = False
|
|
67
|
+
) -> Float[Tensor, "... time"]:
|
|
64
68
|
if not reverse:
|
|
65
69
|
return (X - self.mean) / self.std
|
|
66
70
|
else:
|
ml4gw/transforms/snr_rescaler.py
CHANGED
|
@@ -2,8 +2,9 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
|
-
from ml4gw import
|
|
5
|
+
from ml4gw.gw import compute_network_snr
|
|
6
6
|
from ml4gw.transforms.transform import FittableSpectralTransform
|
|
7
|
+
from ml4gw.types import BatchTensor, TimeSeries2d, WaveformTensor
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class SnrRescaler(FittableSpectralTransform):
|
|
@@ -34,7 +35,7 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
34
35
|
|
|
35
36
|
def fit(
|
|
36
37
|
self,
|
|
37
|
-
*background:
|
|
38
|
+
*background: TimeSeries2d,
|
|
38
39
|
fftlength: Optional[float] = None,
|
|
39
40
|
overlap: Optional[float] = None,
|
|
40
41
|
):
|
|
@@ -58,10 +59,10 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
58
59
|
|
|
59
60
|
def forward(
|
|
60
61
|
self,
|
|
61
|
-
responses:
|
|
62
|
-
target_snrs: Optional[
|
|
62
|
+
responses: WaveformTensor,
|
|
63
|
+
target_snrs: Optional[BatchTensor] = None,
|
|
63
64
|
):
|
|
64
|
-
snrs =
|
|
65
|
+
snrs = compute_network_snr(
|
|
65
66
|
responses, self.background, self.sample_rate, self.mask
|
|
66
67
|
)
|
|
67
68
|
if target_snrs is None:
|
ml4gw/transforms/spectral.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
from ml4gw.spectral import fast_spectral_density, spectral_density
|
|
8
|
+
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
|
|
6
9
|
|
|
7
10
|
|
|
8
11
|
class SpectralDensity(torch.nn.Module):
|
|
@@ -51,7 +54,9 @@ class SpectralDensity(torch.nn.Module):
|
|
|
51
54
|
fftlength: float,
|
|
52
55
|
overlap: Optional[float] = None,
|
|
53
56
|
average: str = "mean",
|
|
54
|
-
window: Optional[
|
|
57
|
+
window: Optional[
|
|
58
|
+
Float[Tensor, " {int(fftlength*sample_rate)}"]
|
|
59
|
+
] = None,
|
|
55
60
|
fast: bool = False,
|
|
56
61
|
) -> None:
|
|
57
62
|
if overlap is None:
|
|
@@ -93,7 +98,9 @@ class SpectralDensity(torch.nn.Module):
|
|
|
93
98
|
self.average = average
|
|
94
99
|
self.fast = fast
|
|
95
100
|
|
|
96
|
-
def forward(
|
|
101
|
+
def forward(
|
|
102
|
+
self, x: TimeSeries1to3d, y: Optional[TimeSeries1to3d] = None
|
|
103
|
+
) -> FrequencySeries1to3d:
|
|
97
104
|
if self.fast:
|
|
98
105
|
return fast_spectral_density(
|
|
99
106
|
x,
|
ml4gw/transforms/spectrogram.py
CHANGED
|
@@ -3,8 +3,12 @@ from typing import Dict, List
|
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
import torch.nn.functional as F
|
|
6
|
+
from jaxtyping import Float
|
|
7
|
+
from torch import Tensor
|
|
6
8
|
from torchaudio.transforms import Spectrogram
|
|
7
9
|
|
|
10
|
+
from ml4gw.types import TimeSeries3d
|
|
11
|
+
|
|
8
12
|
|
|
9
13
|
class MultiResolutionSpectrogram(torch.nn.Module):
|
|
10
14
|
"""
|
|
@@ -122,7 +126,9 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
122
126
|
|
|
123
127
|
return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]
|
|
124
128
|
|
|
125
|
-
def forward(
|
|
129
|
+
def forward(
|
|
130
|
+
self, X: TimeSeries3d
|
|
131
|
+
) -> Float[Tensor, "batch channel frequency time"]:
|
|
126
132
|
"""
|
|
127
133
|
Calculate spectrograms of the input tensor and
|
|
128
134
|
combine them into a single spectrogram
|
ml4gw/transforms/transform.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import Optional
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
5
|
from ml4gw.spectral import spectral_density
|
|
6
|
+
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class FittableTransform(torch.nn.Module):
|
|
@@ -43,12 +44,12 @@ class FittableTransform(torch.nn.Module):
|
|
|
43
44
|
class FittableSpectralTransform(FittableTransform):
|
|
44
45
|
def normalize_psd(
|
|
45
46
|
self,
|
|
46
|
-
x,
|
|
47
|
+
x: TimeSeries1to3d,
|
|
47
48
|
sample_rate: float,
|
|
48
49
|
num_freqs: int,
|
|
49
50
|
fftlength: Optional[float] = None,
|
|
50
51
|
overlap: Optional[float] = None,
|
|
51
|
-
):
|
|
52
|
+
) -> FrequencySeries1to3d:
|
|
52
53
|
# if we specified an FFT length, convert
|
|
53
54
|
# the (assumed) time-domain data to the
|
|
54
55
|
# frequency domain
|
|
@@ -68,7 +69,7 @@ class FittableSpectralTransform(FittableTransform):
|
|
|
68
69
|
scale=scale,
|
|
69
70
|
)
|
|
70
71
|
|
|
71
|
-
# add two dummy dimensions in case we need to
|
|
72
|
+
# add two dummy dimensions in case we need to interpolate
|
|
72
73
|
# the frequency dimension, since `interpolate` expects
|
|
73
74
|
# a (batch, channel, spatial) formatted tensor as input
|
|
74
75
|
x = x.view(1, 1, -1)
|
ml4gw/transforms/waveforms.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
from typing import List, Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
from ml4gw import gw
|
|
8
|
+
from ml4gw.types import BatchTensor
|
|
6
9
|
|
|
7
10
|
|
|
8
11
|
# TODO: should these live in ml4gw.waveforms submodule?
|
|
@@ -10,8 +13,8 @@ from ml4gw import gw
|
|
|
10
13
|
class WaveformSampler(torch.nn.Module):
|
|
11
14
|
def __init__(
|
|
12
15
|
self,
|
|
13
|
-
parameters: Optional[
|
|
14
|
-
**polarizations:
|
|
16
|
+
parameters: Optional[Float[Tensor, "batch num_params"]] = None,
|
|
17
|
+
**polarizations: Float[Tensor, "batch time"],
|
|
15
18
|
):
|
|
16
19
|
super().__init__()
|
|
17
20
|
# make sure we have the same number of waveforms
|
|
@@ -29,7 +32,7 @@ class WaveformSampler(torch.nn.Module):
|
|
|
29
32
|
elif num_waveforms is None:
|
|
30
33
|
num_waveforms = tensor.shape[0]
|
|
31
34
|
|
|
32
|
-
self.polarizations[polarization] =
|
|
35
|
+
self.polarizations[polarization] = Tensor(tensor)
|
|
33
36
|
|
|
34
37
|
if parameters is not None and len(parameters) != num_waveforms:
|
|
35
38
|
raise ValueError(
|
|
@@ -73,10 +76,10 @@ class WaveformProjector(torch.nn.Module):
|
|
|
73
76
|
|
|
74
77
|
def forward(
|
|
75
78
|
self,
|
|
76
|
-
dec:
|
|
77
|
-
psi:
|
|
78
|
-
phi:
|
|
79
|
-
**polarizations,
|
|
79
|
+
dec: BatchTensor,
|
|
80
|
+
psi: BatchTensor,
|
|
81
|
+
phi: BatchTensor,
|
|
82
|
+
**polarizations: Float[Tensor, "batch time"],
|
|
80
83
|
):
|
|
81
84
|
ifo_responses = gw.compute_observed_strain(
|
|
82
85
|
dec,
|
ml4gw/transforms/whitening.py
CHANGED
|
@@ -1,9 +1,15 @@
|
|
|
1
|
-
from typing import Optional
|
|
1
|
+
from typing import Optional, Union
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
5
|
from ml4gw import spectral
|
|
6
6
|
from ml4gw.transforms.transform import FittableSpectralTransform
|
|
7
|
+
from ml4gw.types import (
|
|
8
|
+
FrequencySeries1d,
|
|
9
|
+
FrequencySeries1to3d,
|
|
10
|
+
TimeSeries1d,
|
|
11
|
+
TimeSeries3d,
|
|
12
|
+
)
|
|
7
13
|
|
|
8
14
|
|
|
9
15
|
class Whiten(torch.nn.Module):
|
|
@@ -58,7 +64,9 @@ class Whiten(torch.nn.Module):
|
|
|
58
64
|
window = torch.hann_window(size, dtype=torch.float64)
|
|
59
65
|
self.register_buffer("window", window)
|
|
60
66
|
|
|
61
|
-
def forward(
|
|
67
|
+
def forward(
|
|
68
|
+
self, X: TimeSeries3d, psd: FrequencySeries1to3d
|
|
69
|
+
) -> TimeSeries3d:
|
|
62
70
|
"""
|
|
63
71
|
Whiten a batch of multichannel timeseries by a
|
|
64
72
|
background power spectral density.
|
|
@@ -142,7 +150,7 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
142
150
|
def fit(
|
|
143
151
|
self,
|
|
144
152
|
fduration: float,
|
|
145
|
-
*background:
|
|
153
|
+
*background: Union[TimeSeries1d, FrequencySeries1d],
|
|
146
154
|
fftlength: Optional[float] = None,
|
|
147
155
|
highpass: Optional[float] = None,
|
|
148
156
|
overlap: Optional[float] = None
|
|
@@ -224,7 +232,7 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
224
232
|
fduration = torch.Tensor([fduration])
|
|
225
233
|
self.build(psd=psd, fduration=fduration)
|
|
226
234
|
|
|
227
|
-
def forward(self, X:
|
|
235
|
+
def forward(self, X: TimeSeries3d) -> TimeSeries3d:
|
|
228
236
|
"""
|
|
229
237
|
Whiten the input timeseries tensor using the
|
|
230
238
|
PSD fit by the `.fit` method, which must be
|
ml4gw/types.py
CHANGED
|
@@ -1,10 +1,25 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
WaveformTensor = Float[Tensor, "batch num_ifos time"]
|
|
7
|
+
PSDTensor = Float[Tensor, "num_ifos frequency"]
|
|
8
|
+
BatchTensor = Float[Tensor, "batch"]
|
|
9
|
+
VectorGeometry = Float[Tensor, "batch space"]
|
|
10
|
+
TensorGeometry = Float[Tensor, "batch space space"]
|
|
11
|
+
NetworkVertices = Float[Tensor, "num_ifos 3"]
|
|
12
|
+
NetworkDetectorTensors = Float[Tensor, "num_ifos 3 3"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
TimeSeries1d = Float[Tensor, "time"]
|
|
16
|
+
TimeSeries2d = Float[TimeSeries1d, "channel"]
|
|
17
|
+
TimeSeries3d = Float[TimeSeries2d, "batch"]
|
|
18
|
+
TimeSeries1to3d = Union[TimeSeries1d, TimeSeries2d, TimeSeries3d]
|
|
19
|
+
|
|
20
|
+
FrequencySeries1d = Float[Tensor, "frequency"]
|
|
21
|
+
FrequencySeries2d = Float[FrequencySeries1d, "channel"]
|
|
22
|
+
FrequencySeries3d = Float[FrequencySeries2d, "batch"]
|
|
23
|
+
FrequencySeries1to3d = Union[
|
|
24
|
+
FrequencySeries1d, FrequencySeries2d, FrequencySeries3d
|
|
25
|
+
]
|
ml4gw/utils/interferometer.py
CHANGED
|
@@ -4,7 +4,7 @@ import torch
|
|
|
4
4
|
# based on values from
|
|
5
5
|
# https://lscsoft.docs.ligo.org/lalsuite/lal/_l_a_l_detectors_8h_source.html
|
|
6
6
|
class InterferometerGeometry:
|
|
7
|
-
def __init__(self, name: str):
|
|
7
|
+
def __init__(self, name: str) -> None:
|
|
8
8
|
if name == "H1":
|
|
9
9
|
self.x_arm = torch.Tensor(
|
|
10
10
|
(-0.22389266154, +0.79983062746, +0.55690487831)
|
ml4gw/utils/slicing.py
CHANGED
|
@@ -1,25 +1,30 @@
|
|
|
1
1
|
from typing import Optional, Union
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float, Int64
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
from torch.nn.functional import unfold
|
|
5
|
-
from torchtyping import TensorType
|
|
6
7
|
|
|
7
|
-
|
|
8
|
-
|
|
8
|
+
from ml4gw.types import (
|
|
9
|
+
TimeSeries1d,
|
|
10
|
+
TimeSeries1to3d,
|
|
11
|
+
TimeSeries2d,
|
|
12
|
+
TimeSeries3d,
|
|
13
|
+
)
|
|
9
14
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
BatchTimeSeriesTensor = Union[
|
|
13
|
-
TensorType["batch", "time"], TensorType["batch", "channel", "time"]
|
|
14
|
-
]
|
|
15
|
+
BatchTimeSeriesTensor = Union[Float[Tensor, "batch time"], TimeSeries3d]
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
def unfold_windows(
|
|
18
|
-
x:
|
|
19
|
+
x: TimeSeries1to3d,
|
|
19
20
|
window_size: int,
|
|
20
21
|
stride: int,
|
|
21
22
|
drop_last: bool = True,
|
|
22
|
-
)
|
|
23
|
+
) -> Union[
|
|
24
|
+
Float[TimeSeries1d, " window"],
|
|
25
|
+
Float[TimeSeries2d, " window"],
|
|
26
|
+
Float[TimeSeries3d, " window"],
|
|
27
|
+
]:
|
|
23
28
|
"""Unfold a timeseries into windows
|
|
24
29
|
|
|
25
30
|
Args:
|
|
@@ -83,8 +88,8 @@ def unfold_windows(
|
|
|
83
88
|
|
|
84
89
|
|
|
85
90
|
def slice_kernels(
|
|
86
|
-
x:
|
|
87
|
-
idx:
|
|
91
|
+
x: TimeSeries1to3d,
|
|
92
|
+
idx: Int64[Tensor, "..."],
|
|
88
93
|
kernel_size: int,
|
|
89
94
|
) -> BatchTimeSeriesTensor:
|
|
90
95
|
"""Slice kernels from single or multichannel timeseries
|
|
@@ -96,7 +101,8 @@ def slice_kernels(
|
|
|
96
101
|
one more dimension than `x`.
|
|
97
102
|
|
|
98
103
|
Args:
|
|
99
|
-
x:
|
|
104
|
+
x:
|
|
105
|
+
The timeseries tensor to slice kernels from
|
|
100
106
|
idx:
|
|
101
107
|
The indices in `x` of the first sample of each
|
|
102
108
|
kernel. If `x` is 1D, `idx` must be 1D as well.
|
|
@@ -114,6 +120,7 @@ def slice_kernels(
|
|
|
114
120
|
coincidentally among the channels.
|
|
115
121
|
kernel_size:
|
|
116
122
|
The length of the kernels to slice from the timeseries
|
|
123
|
+
|
|
117
124
|
Returns:
|
|
118
125
|
A tensor of shape `(batch_size, kernel_size)` if `x` is
|
|
119
126
|
1D and `(batch_size, num_channels, kernel_size)` if `x`
|
|
@@ -225,7 +232,7 @@ def slice_kernels(
|
|
|
225
232
|
|
|
226
233
|
|
|
227
234
|
def sample_kernels(
|
|
228
|
-
X:
|
|
235
|
+
X: TimeSeries1to3d,
|
|
229
236
|
kernel_size: int,
|
|
230
237
|
N: Optional[int] = None,
|
|
231
238
|
max_center_offset: Optional[int] = None,
|
|
@@ -245,8 +252,9 @@ def sample_kernels(
|
|
|
245
252
|
either be `None` or be equal to `len(X)`.
|
|
246
253
|
|
|
247
254
|
Args:
|
|
248
|
-
X:
|
|
249
|
-
|
|
255
|
+
X:
|
|
256
|
+
The timeseries tensor from which to sample kernels
|
|
257
|
+
kernel_size: The size of the kernels to sample
|
|
250
258
|
N:
|
|
251
259
|
The number of kernels to sample. Can be left as
|
|
252
260
|
`None` if `X` is 3D, otherwise must be specified
|
ml4gw/waveforms/generator.py
CHANGED
|
@@ -1,24 +1,26 @@
|
|
|
1
|
-
from typing import Callable
|
|
1
|
+
from typing import Callable, Dict, Tuple
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
|
|
6
8
|
class ParameterSampler(torch.nn.Module):
|
|
7
|
-
def __init__(self, **parameters: Callable):
|
|
9
|
+
def __init__(self, **parameters: Callable) -> None:
|
|
8
10
|
super().__init__()
|
|
9
11
|
self.parameters = parameters
|
|
10
12
|
|
|
11
13
|
def forward(
|
|
12
14
|
self,
|
|
13
15
|
N: int,
|
|
14
|
-
):
|
|
16
|
+
) -> Dict[str, Float[Tensor, " {N}"]]:
|
|
15
17
|
return {k: v.sample((N,)) for k, v in self.parameters.items()}
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
class WaveformGenerator(torch.nn.Module):
|
|
19
21
|
def __init__(
|
|
20
22
|
self, waveform: Callable, parameter_sampler: ParameterSampler
|
|
21
|
-
):
|
|
23
|
+
) -> None:
|
|
22
24
|
"""
|
|
23
25
|
A torch module that generates waveforms from a given waveform function
|
|
24
26
|
and a parameter sampler.
|
|
@@ -34,6 +36,8 @@ class WaveformGenerator(torch.nn.Module):
|
|
|
34
36
|
self.waveform = waveform
|
|
35
37
|
self.parameter_sampler = parameter_sampler
|
|
36
38
|
|
|
37
|
-
def forward(
|
|
39
|
+
def forward(
|
|
40
|
+
self, N: int
|
|
41
|
+
) -> Tuple[Float[Tensor, "{N} samples"], Dict[str, Float[Tensor, " {N}"]]]:
|
|
38
42
|
parameters = self.parameter_sampler(N)
|
|
39
43
|
return self.waveform(**parameters), parameters
|
ml4gw/waveforms/phenom_d.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
from
|
|
2
|
+
from jaxtyping import Float
|
|
3
|
+
|
|
4
|
+
from ml4gw.constants import MTSUN_SI, PI
|
|
5
|
+
from ml4gw.types import BatchTensor, FrequencySeries1d
|
|
3
6
|
|
|
4
|
-
from ..constants import MTSUN_SI, PI
|
|
5
7
|
from .phenom_d_data import QNMData_a, QNMData_fdamp, QNMData_fring
|
|
6
8
|
from .taylorf2 import TaylorF2
|
|
7
9
|
|
|
@@ -15,14 +17,14 @@ class IMRPhenomD(TaylorF2):
|
|
|
15
17
|
|
|
16
18
|
def forward(
|
|
17
19
|
self,
|
|
18
|
-
f:
|
|
19
|
-
chirp_mass:
|
|
20
|
-
mass_ratio:
|
|
21
|
-
chi1:
|
|
22
|
-
chi2:
|
|
23
|
-
distance:
|
|
24
|
-
phic:
|
|
25
|
-
inclination:
|
|
20
|
+
f: FrequencySeries1d,
|
|
21
|
+
chirp_mass: BatchTensor,
|
|
22
|
+
mass_ratio: BatchTensor,
|
|
23
|
+
chi1: BatchTensor,
|
|
24
|
+
chi2: BatchTensor,
|
|
25
|
+
distance: BatchTensor,
|
|
26
|
+
phic: BatchTensor,
|
|
27
|
+
inclination: BatchTensor,
|
|
26
28
|
f_ref: float,
|
|
27
29
|
):
|
|
28
30
|
"""
|
|
@@ -76,15 +78,15 @@ class IMRPhenomD(TaylorF2):
|
|
|
76
78
|
|
|
77
79
|
def phenom_d_htilde(
|
|
78
80
|
self,
|
|
79
|
-
f:
|
|
80
|
-
chirp_mass:
|
|
81
|
-
mass_ratio:
|
|
82
|
-
chi1:
|
|
83
|
-
chi2:
|
|
84
|
-
distance:
|
|
85
|
-
phic:
|
|
81
|
+
f: FrequencySeries1d,
|
|
82
|
+
chirp_mass: BatchTensor,
|
|
83
|
+
mass_ratio: BatchTensor,
|
|
84
|
+
chi1: BatchTensor,
|
|
85
|
+
chi2: BatchTensor,
|
|
86
|
+
distance: BatchTensor,
|
|
87
|
+
phic: BatchTensor,
|
|
86
88
|
f_ref: float,
|
|
87
|
-
):
|
|
89
|
+
) -> Float[FrequencySeries1d, " batch"]:
|
|
88
90
|
total_mass = chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio**0.6
|
|
89
91
|
mass_1 = total_mass / (1 + mass_ratio)
|
|
90
92
|
mass_2 = mass_1 * mass_ratio
|