ml4gw 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/augmentations.py +8 -2
- ml4gw/constants.py +45 -0
- ml4gw/dataloading/chunked_dataset.py +4 -2
- ml4gw/dataloading/hdf5_dataset.py +1 -1
- ml4gw/dataloading/in_memory_dataset.py +8 -4
- ml4gw/distributions.py +18 -12
- ml4gw/gw.py +21 -27
- ml4gw/nn/autoencoder/base.py +11 -6
- ml4gw/nn/autoencoder/convolutional.py +7 -4
- ml4gw/nn/autoencoder/skip_connection.py +7 -6
- ml4gw/nn/autoencoder/utils.py +2 -1
- ml4gw/nn/norm.py +11 -1
- ml4gw/nn/streaming/online_average.py +7 -5
- ml4gw/nn/streaming/snapshotter.py +7 -5
- ml4gw/spectral.py +40 -36
- ml4gw/transforms/pearson.py +7 -3
- ml4gw/transforms/qtransform.py +20 -14
- ml4gw/transforms/scaler.py +6 -2
- ml4gw/transforms/snr_rescaler.py +6 -5
- ml4gw/transforms/spectral.py +25 -6
- ml4gw/transforms/spectrogram.py +7 -1
- ml4gw/transforms/transform.py +4 -3
- ml4gw/transforms/waveforms.py +10 -7
- ml4gw/transforms/whitening.py +12 -4
- ml4gw/types.py +25 -10
- ml4gw/utils/interferometer.py +7 -1
- ml4gw/utils/slicing.py +24 -16
- ml4gw/waveforms/__init__.py +2 -0
- ml4gw/waveforms/generator.py +9 -5
- ml4gw/waveforms/phenom_d.py +1338 -1256
- ml4gw/waveforms/phenom_p.py +796 -0
- ml4gw/waveforms/ringdown.py +109 -0
- ml4gw/waveforms/sine_gaussian.py +10 -11
- ml4gw/waveforms/taylorf2.py +304 -279
- {ml4gw-0.4.2.dist-info → ml4gw-0.5.1.dist-info}/METADATA +5 -3
- ml4gw-0.5.1.dist-info/RECORD +47 -0
- ml4gw-0.4.2.dist-info/RECORD +0 -44
- {ml4gw-0.4.2.dist-info → ml4gw-0.5.1.dist-info}/WHEEL +0 -0
ml4gw/spectral.py
CHANGED
|
@@ -12,14 +12,18 @@ https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.csd.html
|
|
|
12
12
|
from typing import Optional, Union
|
|
13
13
|
|
|
14
14
|
import torch
|
|
15
|
-
from
|
|
15
|
+
from jaxtyping import Float
|
|
16
|
+
from torch import Tensor
|
|
16
17
|
|
|
17
|
-
from ml4gw import
|
|
18
|
+
from ml4gw.types import (
|
|
19
|
+
FrequencySeries1to3d,
|
|
20
|
+
PSDTensor,
|
|
21
|
+
TimeSeries1to3d,
|
|
22
|
+
WaveformTensor,
|
|
23
|
+
)
|
|
18
24
|
|
|
19
|
-
time = None
|
|
20
25
|
|
|
21
|
-
|
|
22
|
-
def median(x, axis):
|
|
26
|
+
def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
|
|
23
27
|
"""
|
|
24
28
|
Implements a median calculation that matches numpy's
|
|
25
29
|
behavior for an even number of elements and includes
|
|
@@ -33,7 +37,7 @@ def median(x, axis):
|
|
|
33
37
|
|
|
34
38
|
|
|
35
39
|
def _validate_shapes(
|
|
36
|
-
x:
|
|
40
|
+
x: Tensor, nperseg: int, y: Optional[Tensor] = None
|
|
37
41
|
) -> None:
|
|
38
42
|
if x.shape[-1] < nperseg:
|
|
39
43
|
raise ValueError(
|
|
@@ -83,14 +87,14 @@ def _validate_shapes(
|
|
|
83
87
|
|
|
84
88
|
|
|
85
89
|
def fast_spectral_density(
|
|
86
|
-
x:
|
|
90
|
+
x: TimeSeries1to3d,
|
|
87
91
|
nperseg: int,
|
|
88
92
|
nstride: int,
|
|
89
|
-
window:
|
|
90
|
-
scale:
|
|
93
|
+
window: Float[Tensor, " {nperseg//2+1}"],
|
|
94
|
+
scale: float,
|
|
91
95
|
average: str = "median",
|
|
92
|
-
y: Optional[
|
|
93
|
-
) ->
|
|
96
|
+
y: Optional[TimeSeries1to3d] = None,
|
|
97
|
+
) -> FrequencySeries1to3d:
|
|
94
98
|
"""
|
|
95
99
|
Compute the power spectral density of a multichannel
|
|
96
100
|
timeseries or a batch of multichannel timeseries, or
|
|
@@ -107,9 +111,9 @@ def fast_spectral_density(
|
|
|
107
111
|
The timeseries tensor whose power spectral density
|
|
108
112
|
to compute, or for cross spectral density the
|
|
109
113
|
timeseries whose fft will be conjugated. Can have
|
|
110
|
-
shape
|
|
111
|
-
`(
|
|
112
|
-
|
|
114
|
+
shape `(batch_size, num_channels, length * sample_rate)`,
|
|
115
|
+
`(num_channels, length * sample_rate)`, or
|
|
116
|
+
`(length * sample_rate)`.
|
|
113
117
|
nperseg:
|
|
114
118
|
Number of samples included in each FFT window
|
|
115
119
|
nstride:
|
|
@@ -150,7 +154,7 @@ def fast_spectral_density(
|
|
|
150
154
|
channel in `x` across _all_ of `x`'s batch elements.
|
|
151
155
|
Returns:
|
|
152
156
|
Tensor of power spectral densities of `x` or its cross spectral
|
|
153
|
-
|
|
157
|
+
density with the timeseries in `y`.
|
|
154
158
|
"""
|
|
155
159
|
|
|
156
160
|
_validate_shapes(x, nperseg, y)
|
|
@@ -240,17 +244,16 @@ def fast_spectral_density(
|
|
|
240
244
|
|
|
241
245
|
|
|
242
246
|
def spectral_density(
|
|
243
|
-
x:
|
|
247
|
+
x: TimeSeries1to3d,
|
|
244
248
|
nperseg: int,
|
|
245
249
|
nstride: int,
|
|
246
|
-
window:
|
|
247
|
-
scale:
|
|
250
|
+
window: Float[Tensor, " {nperseg//2+1}"],
|
|
251
|
+
scale: float,
|
|
248
252
|
average: str = "median",
|
|
249
|
-
) ->
|
|
253
|
+
) -> FrequencySeries1to3d:
|
|
250
254
|
"""
|
|
251
255
|
Compute the power spectral density of a multichannel
|
|
252
|
-
timeseries or a batch of multichannel timeseries
|
|
253
|
-
the cross power spectral density of two such timeseries.
|
|
256
|
+
timeseries or a batch of multichannel timeseries.
|
|
254
257
|
This implementation is exact for all frequency bins, but
|
|
255
258
|
slower than the fast implementation.
|
|
256
259
|
|
|
@@ -259,9 +262,9 @@ def spectral_density(
|
|
|
259
262
|
The timeseries tensor whose power spectral density
|
|
260
263
|
to compute, or for cross spectral density the
|
|
261
264
|
timeseries whose fft will be conjugated. Can have
|
|
262
|
-
shape
|
|
263
|
-
`(
|
|
264
|
-
|
|
265
|
+
shape `(batch_size, num_channels, length * sample_rate)`,
|
|
266
|
+
`(num_channels, length * sample_rate)`, or
|
|
267
|
+
`(length * sample_rate)`.
|
|
265
268
|
nperseg:
|
|
266
269
|
Number of samples included in each FFT window
|
|
267
270
|
nstride:
|
|
@@ -336,11 +339,11 @@ def spectral_density(
|
|
|
336
339
|
|
|
337
340
|
|
|
338
341
|
def truncate_inverse_power_spectrum(
|
|
339
|
-
psd:
|
|
340
|
-
fduration: Union[
|
|
342
|
+
psd: PSDTensor,
|
|
343
|
+
fduration: Union[Float[Tensor, " time"], float],
|
|
341
344
|
sample_rate: float,
|
|
342
345
|
highpass: Optional[float] = None,
|
|
343
|
-
) ->
|
|
346
|
+
) -> PSDTensor:
|
|
344
347
|
"""
|
|
345
348
|
Truncate the length of the time domain response
|
|
346
349
|
of a whitening filter built using the specified
|
|
@@ -399,7 +402,7 @@ def truncate_inverse_power_spectrum(
|
|
|
399
402
|
q = torch.fft.irfft(inv_asd, n=N, norm="forward", dim=-1)
|
|
400
403
|
|
|
401
404
|
# taper the edges of the TD filter
|
|
402
|
-
if isinstance(fduration,
|
|
405
|
+
if isinstance(fduration, Tensor):
|
|
403
406
|
pad = fduration.size(-1) // 2
|
|
404
407
|
window = fduration
|
|
405
408
|
else:
|
|
@@ -422,8 +425,8 @@ def truncate_inverse_power_spectrum(
|
|
|
422
425
|
|
|
423
426
|
|
|
424
427
|
def normalize_by_psd(
|
|
425
|
-
X:
|
|
426
|
-
psd:
|
|
428
|
+
X: WaveformTensor,
|
|
429
|
+
psd: PSDTensor,
|
|
427
430
|
sample_rate: float,
|
|
428
431
|
pad: int,
|
|
429
432
|
):
|
|
@@ -447,12 +450,12 @@ def normalize_by_psd(
|
|
|
447
450
|
|
|
448
451
|
|
|
449
452
|
def whiten(
|
|
450
|
-
X:
|
|
451
|
-
psd:
|
|
452
|
-
fduration: Union[
|
|
453
|
+
X: WaveformTensor,
|
|
454
|
+
psd: PSDTensor,
|
|
455
|
+
fduration: Union[Float[Tensor, " time"], float],
|
|
453
456
|
sample_rate: float,
|
|
454
457
|
highpass: Optional[float] = None,
|
|
455
|
-
) ->
|
|
458
|
+
) -> WaveformTensor:
|
|
456
459
|
"""
|
|
457
460
|
Whiten a batch of timeseries using the specified
|
|
458
461
|
background one-sided power spectral densities (PSDs),
|
|
@@ -460,7 +463,8 @@ def whiten(
|
|
|
460
463
|
`fduration` and possibly to highpass filter.
|
|
461
464
|
|
|
462
465
|
Args:
|
|
463
|
-
X:
|
|
466
|
+
X:
|
|
467
|
+
batch of multichannel timeseries to whiten
|
|
464
468
|
psd:
|
|
465
469
|
PSDs use to whiten the data. The frequency
|
|
466
470
|
response of the whitening filter will be roughly
|
|
@@ -496,7 +500,7 @@ def whiten(
|
|
|
496
500
|
|
|
497
501
|
# figure out how much data we'll need to slice
|
|
498
502
|
# off after whitening
|
|
499
|
-
if isinstance(fduration,
|
|
503
|
+
if isinstance(fduration, Tensor):
|
|
500
504
|
pad = fduration.size(-1) // 2
|
|
501
505
|
else:
|
|
502
506
|
pad = int(fduration * sample_rate / 2)
|
ml4gw/transforms/pearson.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from jaxtyping import Float
|
|
3
|
+
from torch import Tensor
|
|
2
4
|
|
|
5
|
+
from ml4gw.types import TimeSeries1to3d
|
|
3
6
|
from ml4gw.utils.slicing import unfold_windows
|
|
4
7
|
|
|
5
8
|
|
|
@@ -40,7 +43,7 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
|
40
43
|
super().__init__()
|
|
41
44
|
self.max_shift = max_shift
|
|
42
45
|
|
|
43
|
-
def _shape_checks(self, x:
|
|
46
|
+
def _shape_checks(self, x: TimeSeries1to3d, y: TimeSeries1to3d):
|
|
44
47
|
if x.ndim > 3:
|
|
45
48
|
raise ValueError(
|
|
46
49
|
"Tensor x can only have up to 3 dimensions "
|
|
@@ -61,8 +64,9 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
|
61
64
|
)
|
|
62
65
|
)
|
|
63
66
|
|
|
64
|
-
|
|
65
|
-
|
|
67
|
+
def forward(
|
|
68
|
+
self, x: TimeSeries1to3d, y: TimeSeries1to3d
|
|
69
|
+
) -> Float[Tensor, "windows ..."]:
|
|
66
70
|
self._shape_checks(x, y)
|
|
67
71
|
dim = x.size(-1)
|
|
68
72
|
|
ml4gw/transforms/qtransform.py
CHANGED
|
@@ -3,6 +3,10 @@ from typing import List, Optional, Tuple
|
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
import torch.nn.functional as F
|
|
6
|
+
from jaxtyping import Float, Int
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
|
|
6
10
|
|
|
7
11
|
"""
|
|
8
12
|
All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
|
|
@@ -44,7 +48,7 @@ class QTile(torch.nn.Module):
|
|
|
44
48
|
duration: float,
|
|
45
49
|
sample_rate: float,
|
|
46
50
|
mismatch: float,
|
|
47
|
-
):
|
|
51
|
+
) -> None:
|
|
48
52
|
super().__init__()
|
|
49
53
|
self.mismatch = mismatch
|
|
50
54
|
self.q = q
|
|
@@ -63,18 +67,18 @@ class QTile(torch.nn.Module):
|
|
|
63
67
|
self.register_buffer("indices", self.get_data_indices())
|
|
64
68
|
self.register_buffer("window", self.get_window())
|
|
65
69
|
|
|
66
|
-
def ntiles(self):
|
|
70
|
+
def ntiles(self) -> int:
|
|
67
71
|
"""
|
|
68
72
|
Number of tiles in this frequency row
|
|
69
73
|
"""
|
|
70
74
|
tcum_mismatch = self.duration * 2 * torch.pi * self.frequency / self.q
|
|
71
75
|
return int(2 ** torch.ceil(torch.log2(tcum_mismatch / self.deltam)))
|
|
72
76
|
|
|
73
|
-
def _get_indices(self):
|
|
77
|
+
def _get_indices(self) -> Int[Tensor, " windowsize"]:
|
|
74
78
|
half = int((self.windowsize - 1) / 2)
|
|
75
79
|
return torch.arange(-half, half + 1)
|
|
76
80
|
|
|
77
|
-
def get_window(self):
|
|
81
|
+
def get_window(self) -> Float[Tensor, " windowsize"]:
|
|
78
82
|
"""
|
|
79
83
|
Generate the bi-square window for this row
|
|
80
84
|
"""
|
|
@@ -87,7 +91,7 @@ class QTile(torch.nn.Module):
|
|
|
87
91
|
)
|
|
88
92
|
return torch.Tensor((1 - xfrequencies**2) ** 2 * norm)
|
|
89
93
|
|
|
90
|
-
def get_data_indices(self):
|
|
94
|
+
def get_data_indices(self) -> Int[Tensor, " windowsize"]:
|
|
91
95
|
"""
|
|
92
96
|
Get the index array of relevant frequencies for this row
|
|
93
97
|
"""
|
|
@@ -95,7 +99,9 @@ class QTile(torch.nn.Module):
|
|
|
95
99
|
self._get_indices() + 1 + self.frequency * self.duration,
|
|
96
100
|
).type(torch.long)
|
|
97
101
|
|
|
98
|
-
def forward(
|
|
102
|
+
def forward(
|
|
103
|
+
self, fseries: FrequencySeries1to3d, norm: str = "median"
|
|
104
|
+
) -> TimeSeries1to3d:
|
|
99
105
|
"""
|
|
100
106
|
Compute the transform for this row
|
|
101
107
|
|
|
@@ -176,7 +182,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
176
182
|
q: float = 12,
|
|
177
183
|
frange: List[float] = [0, torch.inf],
|
|
178
184
|
mismatch: float = 0.2,
|
|
179
|
-
):
|
|
185
|
+
) -> None:
|
|
180
186
|
super().__init__()
|
|
181
187
|
self.q = q
|
|
182
188
|
self.spectrogram_shape = spectrogram_shape
|
|
@@ -198,7 +204,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
198
204
|
)
|
|
199
205
|
self.qtiles = None
|
|
200
206
|
|
|
201
|
-
def get_freqs(self):
|
|
207
|
+
def get_freqs(self) -> Float[Tensor, " nfreq"]:
|
|
202
208
|
"""
|
|
203
209
|
Calculate the frequencies that will be used in this transform.
|
|
204
210
|
For each frequency, a `QTile` is created.
|
|
@@ -262,7 +268,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
262
268
|
if dimension == "batch":
|
|
263
269
|
return torch.max(max_across_ft, dim=-1).values
|
|
264
270
|
|
|
265
|
-
def compute_qtiles(self, X:
|
|
271
|
+
def compute_qtiles(self, X: TimeSeries1to3d, norm: str = "median") -> None:
|
|
266
272
|
"""
|
|
267
273
|
Take the FFT of the input timeseries and calculate the transform
|
|
268
274
|
for each `QTile`
|
|
@@ -272,7 +278,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
272
278
|
X[..., 1:] *= 2
|
|
273
279
|
self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]
|
|
274
280
|
|
|
275
|
-
def interpolate(self, num_f_bins: int, num_t_bins: int):
|
|
281
|
+
def interpolate(self, num_f_bins: int, num_t_bins: int) -> TimeSeries3d:
|
|
276
282
|
"""
|
|
277
283
|
Interpolate each `QTile` to the specified number of time and
|
|
278
284
|
frequency bins. Note that PyTorch does not have the same
|
|
@@ -299,7 +305,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
299
305
|
|
|
300
306
|
def forward(
|
|
301
307
|
self,
|
|
302
|
-
X:
|
|
308
|
+
X: TimeSeries1to3d,
|
|
303
309
|
norm: str = "median",
|
|
304
310
|
spectrogram_shape: Optional[Tuple[int, int]] = None,
|
|
305
311
|
):
|
|
@@ -371,7 +377,7 @@ class QScan(torch.nn.Module):
|
|
|
371
377
|
qrange: List[float] = [4, 64],
|
|
372
378
|
frange: List[float] = [0, torch.inf],
|
|
373
379
|
mismatch: float = 0.2,
|
|
374
|
-
):
|
|
380
|
+
) -> None:
|
|
375
381
|
super().__init__()
|
|
376
382
|
self.qrange = qrange
|
|
377
383
|
self.mismatch = mismatch
|
|
@@ -397,7 +403,7 @@ class QScan(torch.nn.Module):
|
|
|
397
403
|
]
|
|
398
404
|
)
|
|
399
405
|
|
|
400
|
-
def get_qs(self):
|
|
406
|
+
def get_qs(self) -> List[float]:
|
|
401
407
|
"""
|
|
402
408
|
Determine the values of Q to try for the set of Q-transforms
|
|
403
409
|
"""
|
|
@@ -413,7 +419,7 @@ class QScan(torch.nn.Module):
|
|
|
413
419
|
|
|
414
420
|
def forward(
|
|
415
421
|
self,
|
|
416
|
-
X:
|
|
422
|
+
X: TimeSeries1to3d,
|
|
417
423
|
fsearch_range: List[float] = None,
|
|
418
424
|
norm: str = "median",
|
|
419
425
|
spectrogram_shape: Optional[Tuple[int, int]] = None,
|
ml4gw/transforms/scaler.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
from ml4gw.transforms.transform import FittableTransform
|
|
6
8
|
|
|
@@ -34,7 +36,7 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
34
36
|
self.register_buffer("mean", mean)
|
|
35
37
|
self.register_buffer("std", std)
|
|
36
38
|
|
|
37
|
-
def fit(self, X:
|
|
39
|
+
def fit(self, X: Float[Tensor, "... time"]) -> None:
|
|
38
40
|
"""Fit the scaling parameters to a timeseries
|
|
39
41
|
|
|
40
42
|
Computes the channel-wise mean and standard deviation
|
|
@@ -60,7 +62,9 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
60
62
|
|
|
61
63
|
super().build(mean=mean, std=std)
|
|
62
64
|
|
|
63
|
-
def forward(
|
|
65
|
+
def forward(
|
|
66
|
+
self, X: Float[Tensor, "... time"], reverse: bool = False
|
|
67
|
+
) -> Float[Tensor, "... time"]:
|
|
64
68
|
if not reverse:
|
|
65
69
|
return (X - self.mean) / self.std
|
|
66
70
|
else:
|
ml4gw/transforms/snr_rescaler.py
CHANGED
|
@@ -2,8 +2,9 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
|
-
from ml4gw import
|
|
5
|
+
from ml4gw.gw import compute_network_snr
|
|
6
6
|
from ml4gw.transforms.transform import FittableSpectralTransform
|
|
7
|
+
from ml4gw.types import BatchTensor, TimeSeries2d, WaveformTensor
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class SnrRescaler(FittableSpectralTransform):
|
|
@@ -34,7 +35,7 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
34
35
|
|
|
35
36
|
def fit(
|
|
36
37
|
self,
|
|
37
|
-
*background:
|
|
38
|
+
*background: TimeSeries2d,
|
|
38
39
|
fftlength: Optional[float] = None,
|
|
39
40
|
overlap: Optional[float] = None,
|
|
40
41
|
):
|
|
@@ -58,10 +59,10 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
58
59
|
|
|
59
60
|
def forward(
|
|
60
61
|
self,
|
|
61
|
-
responses:
|
|
62
|
-
target_snrs: Optional[
|
|
62
|
+
responses: WaveformTensor,
|
|
63
|
+
target_snrs: Optional[BatchTensor] = None,
|
|
63
64
|
):
|
|
64
|
-
snrs =
|
|
65
|
+
snrs = compute_network_snr(
|
|
65
66
|
responses, self.background, self.sample_rate, self.mask
|
|
66
67
|
)
|
|
67
68
|
if target_snrs is None:
|
ml4gw/transforms/spectral.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
from ml4gw.spectral import fast_spectral_density, spectral_density
|
|
8
|
+
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
|
|
6
9
|
|
|
7
10
|
|
|
8
11
|
class SpectralDensity(torch.nn.Module):
|
|
@@ -34,6 +37,10 @@ class SpectralDensity(torch.nn.Module):
|
|
|
34
37
|
average:
|
|
35
38
|
Aggregation method to use for combining windowed FFTs.
|
|
36
39
|
Allowed values are `"mean"` and `"median"`.
|
|
40
|
+
window:
|
|
41
|
+
Window array to multiply by each FFT window before
|
|
42
|
+
FFT computation. Should have length `nperseg`.
|
|
43
|
+
Defaults to a hanning window.
|
|
37
44
|
fast:
|
|
38
45
|
Whether to use a faster spectral density computation that
|
|
39
46
|
support cross spectral density, or a slower one which does
|
|
@@ -47,6 +54,9 @@ class SpectralDensity(torch.nn.Module):
|
|
|
47
54
|
fftlength: float,
|
|
48
55
|
overlap: Optional[float] = None,
|
|
49
56
|
average: str = "mean",
|
|
57
|
+
window: Optional[
|
|
58
|
+
Float[Tensor, " {int(fftlength*sample_rate)}"]
|
|
59
|
+
] = None,
|
|
50
60
|
fast: bool = False,
|
|
51
61
|
) -> None:
|
|
52
62
|
if overlap is None:
|
|
@@ -63,11 +73,18 @@ class SpectralDensity(torch.nn.Module):
|
|
|
63
73
|
self.nperseg = int(fftlength * sample_rate)
|
|
64
74
|
self.nstride = self.nperseg - int(overlap * sample_rate)
|
|
65
75
|
|
|
66
|
-
#
|
|
67
|
-
#
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
76
|
+
# if no window is provided, default to a hanning window;
|
|
77
|
+
# validate that window is correct size
|
|
78
|
+
if window is None:
|
|
79
|
+
window = torch.hann_window(self.nperseg)
|
|
80
|
+
|
|
81
|
+
if window.size(0) != self.nperseg:
|
|
82
|
+
raise ValueError(
|
|
83
|
+
"Window must have length {} got {}".format(
|
|
84
|
+
self.nperseg, window.size(0)
|
|
85
|
+
)
|
|
86
|
+
)
|
|
87
|
+
self.register_buffer("window", window)
|
|
71
88
|
|
|
72
89
|
# scale corresponds to "density" normalization, worth
|
|
73
90
|
# considering adding this as a kwarg and changing this calc
|
|
@@ -81,7 +98,9 @@ class SpectralDensity(torch.nn.Module):
|
|
|
81
98
|
self.average = average
|
|
82
99
|
self.fast = fast
|
|
83
100
|
|
|
84
|
-
def forward(
|
|
101
|
+
def forward(
|
|
102
|
+
self, x: TimeSeries1to3d, y: Optional[TimeSeries1to3d] = None
|
|
103
|
+
) -> FrequencySeries1to3d:
|
|
85
104
|
if self.fast:
|
|
86
105
|
return fast_spectral_density(
|
|
87
106
|
x,
|
ml4gw/transforms/spectrogram.py
CHANGED
|
@@ -3,8 +3,12 @@ from typing import Dict, List
|
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
import torch.nn.functional as F
|
|
6
|
+
from jaxtyping import Float
|
|
7
|
+
from torch import Tensor
|
|
6
8
|
from torchaudio.transforms import Spectrogram
|
|
7
9
|
|
|
10
|
+
from ml4gw.types import TimeSeries3d
|
|
11
|
+
|
|
8
12
|
|
|
9
13
|
class MultiResolutionSpectrogram(torch.nn.Module):
|
|
10
14
|
"""
|
|
@@ -122,7 +126,9 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
122
126
|
|
|
123
127
|
return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]
|
|
124
128
|
|
|
125
|
-
def forward(
|
|
129
|
+
def forward(
|
|
130
|
+
self, X: TimeSeries3d
|
|
131
|
+
) -> Float[Tensor, "batch channel frequency time"]:
|
|
126
132
|
"""
|
|
127
133
|
Calculate spectrograms of the input tensor and
|
|
128
134
|
combine them into a single spectrogram
|
ml4gw/transforms/transform.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import Optional
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
5
|
from ml4gw.spectral import spectral_density
|
|
6
|
+
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class FittableTransform(torch.nn.Module):
|
|
@@ -43,12 +44,12 @@ class FittableTransform(torch.nn.Module):
|
|
|
43
44
|
class FittableSpectralTransform(FittableTransform):
|
|
44
45
|
def normalize_psd(
|
|
45
46
|
self,
|
|
46
|
-
x,
|
|
47
|
+
x: TimeSeries1to3d,
|
|
47
48
|
sample_rate: float,
|
|
48
49
|
num_freqs: int,
|
|
49
50
|
fftlength: Optional[float] = None,
|
|
50
51
|
overlap: Optional[float] = None,
|
|
51
|
-
):
|
|
52
|
+
) -> FrequencySeries1to3d:
|
|
52
53
|
# if we specified an FFT length, convert
|
|
53
54
|
# the (assumed) time-domain data to the
|
|
54
55
|
# frequency domain
|
|
@@ -68,7 +69,7 @@ class FittableSpectralTransform(FittableTransform):
|
|
|
68
69
|
scale=scale,
|
|
69
70
|
)
|
|
70
71
|
|
|
71
|
-
# add two dummy dimensions in case we need to
|
|
72
|
+
# add two dummy dimensions in case we need to interpolate
|
|
72
73
|
# the frequency dimension, since `interpolate` expects
|
|
73
74
|
# a (batch, channel, spatial) formatted tensor as input
|
|
74
75
|
x = x.view(1, 1, -1)
|
ml4gw/transforms/waveforms.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
from typing import List, Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
from ml4gw import gw
|
|
8
|
+
from ml4gw.types import BatchTensor
|
|
6
9
|
|
|
7
10
|
|
|
8
11
|
# TODO: should these live in ml4gw.waveforms submodule?
|
|
@@ -10,8 +13,8 @@ from ml4gw import gw
|
|
|
10
13
|
class WaveformSampler(torch.nn.Module):
|
|
11
14
|
def __init__(
|
|
12
15
|
self,
|
|
13
|
-
parameters: Optional[
|
|
14
|
-
**polarizations:
|
|
16
|
+
parameters: Optional[Float[Tensor, "batch num_params"]] = None,
|
|
17
|
+
**polarizations: Float[Tensor, "batch time"],
|
|
15
18
|
):
|
|
16
19
|
super().__init__()
|
|
17
20
|
# make sure we have the same number of waveforms
|
|
@@ -29,7 +32,7 @@ class WaveformSampler(torch.nn.Module):
|
|
|
29
32
|
elif num_waveforms is None:
|
|
30
33
|
num_waveforms = tensor.shape[0]
|
|
31
34
|
|
|
32
|
-
self.polarizations[polarization] =
|
|
35
|
+
self.polarizations[polarization] = Tensor(tensor)
|
|
33
36
|
|
|
34
37
|
if parameters is not None and len(parameters) != num_waveforms:
|
|
35
38
|
raise ValueError(
|
|
@@ -73,10 +76,10 @@ class WaveformProjector(torch.nn.Module):
|
|
|
73
76
|
|
|
74
77
|
def forward(
|
|
75
78
|
self,
|
|
76
|
-
dec:
|
|
77
|
-
psi:
|
|
78
|
-
phi:
|
|
79
|
-
**polarizations,
|
|
79
|
+
dec: BatchTensor,
|
|
80
|
+
psi: BatchTensor,
|
|
81
|
+
phi: BatchTensor,
|
|
82
|
+
**polarizations: Float[Tensor, "batch time"],
|
|
80
83
|
):
|
|
81
84
|
ifo_responses = gw.compute_observed_strain(
|
|
82
85
|
dec,
|
ml4gw/transforms/whitening.py
CHANGED
|
@@ -1,9 +1,15 @@
|
|
|
1
|
-
from typing import Optional
|
|
1
|
+
from typing import Optional, Union
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
5
|
from ml4gw import spectral
|
|
6
6
|
from ml4gw.transforms.transform import FittableSpectralTransform
|
|
7
|
+
from ml4gw.types import (
|
|
8
|
+
FrequencySeries1d,
|
|
9
|
+
FrequencySeries1to3d,
|
|
10
|
+
TimeSeries1d,
|
|
11
|
+
TimeSeries3d,
|
|
12
|
+
)
|
|
7
13
|
|
|
8
14
|
|
|
9
15
|
class Whiten(torch.nn.Module):
|
|
@@ -58,7 +64,9 @@ class Whiten(torch.nn.Module):
|
|
|
58
64
|
window = torch.hann_window(size, dtype=torch.float64)
|
|
59
65
|
self.register_buffer("window", window)
|
|
60
66
|
|
|
61
|
-
def forward(
|
|
67
|
+
def forward(
|
|
68
|
+
self, X: TimeSeries3d, psd: FrequencySeries1to3d
|
|
69
|
+
) -> TimeSeries3d:
|
|
62
70
|
"""
|
|
63
71
|
Whiten a batch of multichannel timeseries by a
|
|
64
72
|
background power spectral density.
|
|
@@ -142,7 +150,7 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
142
150
|
def fit(
|
|
143
151
|
self,
|
|
144
152
|
fduration: float,
|
|
145
|
-
*background:
|
|
153
|
+
*background: Union[TimeSeries1d, FrequencySeries1d],
|
|
146
154
|
fftlength: Optional[float] = None,
|
|
147
155
|
highpass: Optional[float] = None,
|
|
148
156
|
overlap: Optional[float] = None
|
|
@@ -224,7 +232,7 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
224
232
|
fduration = torch.Tensor([fduration])
|
|
225
233
|
self.build(psd=psd, fduration=fduration)
|
|
226
234
|
|
|
227
|
-
def forward(self, X:
|
|
235
|
+
def forward(self, X: TimeSeries3d) -> TimeSeries3d:
|
|
228
236
|
"""
|
|
229
237
|
Whiten the input timeseries tensor using the
|
|
230
238
|
PSD fit by the `.fit` method, which must be
|
ml4gw/types.py
CHANGED
|
@@ -1,10 +1,25 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Float
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
WaveformTensor = Float[Tensor, "batch num_ifos time"]
|
|
7
|
+
PSDTensor = Float[Tensor, "num_ifos frequency"]
|
|
8
|
+
BatchTensor = Float[Tensor, "batch"]
|
|
9
|
+
VectorGeometry = Float[Tensor, "batch space"]
|
|
10
|
+
TensorGeometry = Float[Tensor, "batch space space"]
|
|
11
|
+
NetworkVertices = Float[Tensor, "num_ifos 3"]
|
|
12
|
+
NetworkDetectorTensors = Float[Tensor, "num_ifos 3 3"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
TimeSeries1d = Float[Tensor, "time"]
|
|
16
|
+
TimeSeries2d = Float[TimeSeries1d, "channel"]
|
|
17
|
+
TimeSeries3d = Float[TimeSeries2d, "batch"]
|
|
18
|
+
TimeSeries1to3d = Union[TimeSeries1d, TimeSeries2d, TimeSeries3d]
|
|
19
|
+
|
|
20
|
+
FrequencySeries1d = Float[Tensor, "frequency"]
|
|
21
|
+
FrequencySeries2d = Float[FrequencySeries1d, "channel"]
|
|
22
|
+
FrequencySeries3d = Float[FrequencySeries2d, "batch"]
|
|
23
|
+
FrequencySeries1to3d = Union[
|
|
24
|
+
FrequencySeries1d, FrequencySeries2d, FrequencySeries3d
|
|
25
|
+
]
|
ml4gw/utils/interferometer.py
CHANGED
|
@@ -4,7 +4,7 @@ import torch
|
|
|
4
4
|
# based on values from
|
|
5
5
|
# https://lscsoft.docs.ligo.org/lalsuite/lal/_l_a_l_detectors_8h_source.html
|
|
6
6
|
class InterferometerGeometry:
|
|
7
|
-
def __init__(self, name: str):
|
|
7
|
+
def __init__(self, name: str) -> None:
|
|
8
8
|
if name == "H1":
|
|
9
9
|
self.x_arm = torch.Tensor(
|
|
10
10
|
(-0.22389266154, +0.79983062746, +0.55690487831)
|
|
@@ -35,6 +35,12 @@ class InterferometerGeometry:
|
|
|
35
35
|
self.vertex = torch.Tensor(
|
|
36
36
|
(4.54637409900e06, 8.42989697626e05, 4.37857696241e06)
|
|
37
37
|
)
|
|
38
|
+
elif name == "K1":
|
|
39
|
+
self.x_arm = torch.Tensor((-0.3759040, -0.8361583, 0.3994189))
|
|
40
|
+
self.y_arm = torch.Tensor((0.7164378, 0.01114076, 0.6975620))
|
|
41
|
+
self.vertex = torch.Tensor(
|
|
42
|
+
(-3777336.024, 3484898.411, 3765313.697)
|
|
43
|
+
)
|
|
38
44
|
else:
|
|
39
45
|
raise ValueError(
|
|
40
46
|
f"{name} is not recognized as an interferometer, "
|