ml4gw 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/augmentations.py +8 -2
- ml4gw/constants.py +10 -19
- ml4gw/dataloading/chunked_dataset.py +4 -2
- ml4gw/dataloading/hdf5_dataset.py +1 -1
- ml4gw/dataloading/in_memory_dataset.py +8 -4
- ml4gw/distributions.py +5 -3
- ml4gw/gw.py +21 -27
- ml4gw/nn/autoencoder/base.py +11 -6
- ml4gw/nn/autoencoder/convolutional.py +7 -4
- ml4gw/nn/autoencoder/skip_connection.py +7 -6
- ml4gw/nn/autoencoder/utils.py +2 -1
- ml4gw/nn/norm.py +5 -1
- ml4gw/nn/streaming/online_average.py +7 -5
- ml4gw/nn/streaming/snapshotter.py +7 -5
- ml4gw/spectral.py +41 -37
- ml4gw/transforms/__init__.py +1 -0
- ml4gw/transforms/pearson.py +7 -3
- ml4gw/transforms/qtransform.py +151 -53
- ml4gw/transforms/scaler.py +9 -3
- ml4gw/transforms/snr_rescaler.py +6 -5
- ml4gw/transforms/spectral.py +9 -2
- ml4gw/transforms/spectrogram.py +7 -1
- ml4gw/transforms/spline_interpolation.py +370 -0
- ml4gw/transforms/transform.py +4 -3
- ml4gw/transforms/waveforms.py +10 -7
- ml4gw/transforms/whitening.py +12 -4
- ml4gw/types.py +25 -10
- ml4gw/utils/interferometer.py +1 -1
- ml4gw/utils/slicing.py +24 -16
- ml4gw/waveforms/__init__.py +2 -5
- ml4gw/waveforms/adhoc/__init__.py +2 -0
- ml4gw/waveforms/{ringdown.py → adhoc/ringdown.py} +8 -9
- ml4gw/waveforms/{sine_gaussian.py → adhoc/sine_gaussian.py} +6 -6
- ml4gw/waveforms/cbc/__init__.py +3 -0
- ml4gw/waveforms/{phenom_d.py → cbc/phenom_d.py} +20 -18
- ml4gw/waveforms/{phenom_p.py → cbc/phenom_p.py} +106 -95
- ml4gw/waveforms/{taylorf2.py → cbc/taylorf2.py} +33 -27
- ml4gw/waveforms/conversion.py +187 -0
- ml4gw/waveforms/generator.py +9 -5
- {ml4gw-0.5.0.dist-info → ml4gw-0.6.0.dist-info}/METADATA +4 -3
- ml4gw-0.6.0.dist-info/RECORD +51 -0
- {ml4gw-0.5.0.dist-info → ml4gw-0.6.0.dist-info}/WHEEL +1 -1
- ml4gw-0.5.0.dist-info/RECORD +0 -47
- /ml4gw/waveforms/{phenom_d_data.py → cbc/phenom_d_data.py} +0 -0
ml4gw/spectral.py
CHANGED
|
@@ -12,14 +12,18 @@ https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.csd.html
|
|
|
12
12
|
from typing import Optional, Union
|
|
13
13
|
|
|
14
14
|
import torch
|
|
15
|
-
from
|
|
15
|
+
from jaxtyping import Float
|
|
16
|
+
from torch import Tensor
|
|
16
17
|
|
|
17
|
-
from ml4gw import
|
|
18
|
+
from ml4gw.types import (
|
|
19
|
+
FrequencySeries1to3d,
|
|
20
|
+
PSDTensor,
|
|
21
|
+
TimeSeries1to3d,
|
|
22
|
+
WaveformTensor,
|
|
23
|
+
)
|
|
18
24
|
|
|
19
|
-
time = None
|
|
20
25
|
|
|
21
|
-
|
|
22
|
-
def median(x, axis):
|
|
26
|
+
def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
|
|
23
27
|
"""
|
|
24
28
|
Implements a median calculation that matches numpy's
|
|
25
29
|
behavior for an even number of elements and includes
|
|
@@ -33,7 +37,7 @@ def median(x, axis):
|
|
|
33
37
|
|
|
34
38
|
|
|
35
39
|
def _validate_shapes(
|
|
36
|
-
x:
|
|
40
|
+
x: Tensor, nperseg: int, y: Optional[Tensor] = None
|
|
37
41
|
) -> None:
|
|
38
42
|
if x.shape[-1] < nperseg:
|
|
39
43
|
raise ValueError(
|
|
@@ -83,14 +87,14 @@ def _validate_shapes(
|
|
|
83
87
|
|
|
84
88
|
|
|
85
89
|
def fast_spectral_density(
|
|
86
|
-
x:
|
|
90
|
+
x: TimeSeries1to3d,
|
|
87
91
|
nperseg: int,
|
|
88
92
|
nstride: int,
|
|
89
|
-
window:
|
|
90
|
-
scale:
|
|
93
|
+
window: Float[Tensor, " {nperseg//2+1}"],
|
|
94
|
+
scale: float,
|
|
91
95
|
average: str = "median",
|
|
92
|
-
y: Optional[
|
|
93
|
-
) ->
|
|
96
|
+
y: Optional[TimeSeries1to3d] = None,
|
|
97
|
+
) -> FrequencySeries1to3d:
|
|
94
98
|
"""
|
|
95
99
|
Compute the power spectral density of a multichannel
|
|
96
100
|
timeseries or a batch of multichannel timeseries, or
|
|
@@ -107,9 +111,9 @@ def fast_spectral_density(
|
|
|
107
111
|
The timeseries tensor whose power spectral density
|
|
108
112
|
to compute, or for cross spectral density the
|
|
109
113
|
timeseries whose fft will be conjugated. Can have
|
|
110
|
-
shape
|
|
111
|
-
`(
|
|
112
|
-
|
|
114
|
+
shape `(batch_size, num_channels, length * sample_rate)`,
|
|
115
|
+
`(num_channels, length * sample_rate)`, or
|
|
116
|
+
`(length * sample_rate)`.
|
|
113
117
|
nperseg:
|
|
114
118
|
Number of samples included in each FFT window
|
|
115
119
|
nstride:
|
|
@@ -150,7 +154,7 @@ def fast_spectral_density(
|
|
|
150
154
|
channel in `x` across _all_ of `x`'s batch elements.
|
|
151
155
|
Returns:
|
|
152
156
|
Tensor of power spectral densities of `x` or its cross spectral
|
|
153
|
-
|
|
157
|
+
density with the timeseries in `y`.
|
|
154
158
|
"""
|
|
155
159
|
|
|
156
160
|
_validate_shapes(x, nperseg, y)
|
|
@@ -240,17 +244,16 @@ def fast_spectral_density(
|
|
|
240
244
|
|
|
241
245
|
|
|
242
246
|
def spectral_density(
|
|
243
|
-
x:
|
|
247
|
+
x: TimeSeries1to3d,
|
|
244
248
|
nperseg: int,
|
|
245
249
|
nstride: int,
|
|
246
|
-
window:
|
|
247
|
-
scale:
|
|
250
|
+
window: Float[Tensor, " {nperseg//2+1}"],
|
|
251
|
+
scale: float,
|
|
248
252
|
average: str = "median",
|
|
249
|
-
) ->
|
|
253
|
+
) -> FrequencySeries1to3d:
|
|
250
254
|
"""
|
|
251
255
|
Compute the power spectral density of a multichannel
|
|
252
|
-
timeseries or a batch of multichannel timeseries
|
|
253
|
-
the cross power spectral density of two such timeseries.
|
|
256
|
+
timeseries or a batch of multichannel timeseries.
|
|
254
257
|
This implementation is exact for all frequency bins, but
|
|
255
258
|
slower than the fast implementation.
|
|
256
259
|
|
|
@@ -259,9 +262,9 @@ def spectral_density(
|
|
|
259
262
|
The timeseries tensor whose power spectral density
|
|
260
263
|
to compute, or for cross spectral density the
|
|
261
264
|
timeseries whose fft will be conjugated. Can have
|
|
262
|
-
shape
|
|
263
|
-
`(
|
|
264
|
-
|
|
265
|
+
shape `(batch_size, num_channels, length * sample_rate)`,
|
|
266
|
+
`(num_channels, length * sample_rate)`, or
|
|
267
|
+
`(length * sample_rate)`.
|
|
265
268
|
nperseg:
|
|
266
269
|
Number of samples included in each FFT window
|
|
267
270
|
nstride:
|
|
@@ -336,11 +339,11 @@ def spectral_density(
|
|
|
336
339
|
|
|
337
340
|
|
|
338
341
|
def truncate_inverse_power_spectrum(
|
|
339
|
-
psd:
|
|
340
|
-
fduration: Union[
|
|
342
|
+
psd: PSDTensor,
|
|
343
|
+
fduration: Union[Float[Tensor, " time"], float],
|
|
341
344
|
sample_rate: float,
|
|
342
345
|
highpass: Optional[float] = None,
|
|
343
|
-
) ->
|
|
346
|
+
) -> PSDTensor:
|
|
344
347
|
"""
|
|
345
348
|
Truncate the length of the time domain response
|
|
346
349
|
of a whitening filter built using the specified
|
|
@@ -399,7 +402,7 @@ def truncate_inverse_power_spectrum(
|
|
|
399
402
|
q = torch.fft.irfft(inv_asd, n=N, norm="forward", dim=-1)
|
|
400
403
|
|
|
401
404
|
# taper the edges of the TD filter
|
|
402
|
-
if isinstance(fduration,
|
|
405
|
+
if isinstance(fduration, Tensor):
|
|
403
406
|
pad = fduration.size(-1) // 2
|
|
404
407
|
window = fduration
|
|
405
408
|
else:
|
|
@@ -422,8 +425,8 @@ def truncate_inverse_power_spectrum(
|
|
|
422
425
|
|
|
423
426
|
|
|
424
427
|
def normalize_by_psd(
|
|
425
|
-
X:
|
|
426
|
-
psd:
|
|
428
|
+
X: WaveformTensor,
|
|
429
|
+
psd: PSDTensor,
|
|
427
430
|
sample_rate: float,
|
|
428
431
|
pad: int,
|
|
429
432
|
):
|
|
@@ -438,7 +441,7 @@ def normalize_by_psd(
|
|
|
438
441
|
|
|
439
442
|
# convert back to the time domain and normalize
|
|
440
443
|
# TODO: what's this normalization factor?
|
|
441
|
-
X = torch.fft.irfft(X_tilde, norm="forward", dim=-1)
|
|
444
|
+
X = torch.fft.irfft(X_tilde, n=X.shape[-1], norm="forward", dim=-1)
|
|
442
445
|
X = X.float() / sample_rate**0.5
|
|
443
446
|
|
|
444
447
|
# slice off corrupted data at edges of kernel
|
|
@@ -447,12 +450,12 @@ def normalize_by_psd(
|
|
|
447
450
|
|
|
448
451
|
|
|
449
452
|
def whiten(
|
|
450
|
-
X:
|
|
451
|
-
psd:
|
|
452
|
-
fduration: Union[
|
|
453
|
+
X: WaveformTensor,
|
|
454
|
+
psd: PSDTensor,
|
|
455
|
+
fduration: Union[Float[Tensor, " time"], float],
|
|
453
456
|
sample_rate: float,
|
|
454
457
|
highpass: Optional[float] = None,
|
|
455
|
-
) ->
|
|
458
|
+
) -> WaveformTensor:
|
|
456
459
|
"""
|
|
457
460
|
Whiten a batch of timeseries using the specified
|
|
458
461
|
background one-sided power spectral densities (PSDs),
|
|
@@ -460,7 +463,8 @@ def whiten(
|
|
|
460
463
|
`fduration` and possibly to highpass filter.
|
|
461
464
|
|
|
462
465
|
Args:
|
|
463
|
-
X:
|
|
466
|
+
X:
|
|
467
|
+
batch of multichannel timeseries to whiten
|
|
464
468
|
psd:
|
|
465
469
|
PSDs use to whiten the data. The frequency
|
|
466
470
|
response of the whitening filter will be roughly
|
|
@@ -496,7 +500,7 @@ def whiten(
|
|
|
496
500
|
|
|
497
501
|
# figure out how much data we'll need to slice
|
|
498
502
|
# off after whitening
|
|
499
|
-
if isinstance(fduration,
|
|
503
|
+
if isinstance(fduration, Tensor):
|
|
500
504
|
pad = fduration.size(-1) // 2
|
|
501
505
|
else:
|
|
502
506
|
pad = int(fduration * sample_rate / 2)
|
ml4gw/transforms/__init__.py
CHANGED
|
@@ -4,5 +4,6 @@ from .scaler import ChannelWiseScaler
|
|
|
4
4
|
from .snr_rescaler import SnrRescaler
|
|
5
5
|
from .spectral import SpectralDensity
|
|
6
6
|
from .spectrogram import MultiResolutionSpectrogram
|
|
7
|
+
from .spline_interpolation import SplineInterpolate
|
|
7
8
|
from .waveforms import WaveformProjector, WaveformSampler
|
|
8
9
|
from .whitening import FixedWhiten, Whiten
|
ml4gw/transforms/pearson.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from jaxtyping import Float
|
|
3
|
+
from torch import Tensor
|
|
2
4
|
|
|
5
|
+
from ml4gw.types import TimeSeries1to3d
|
|
3
6
|
from ml4gw.utils.slicing import unfold_windows
|
|
4
7
|
|
|
5
8
|
|
|
@@ -40,7 +43,7 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
|
40
43
|
super().__init__()
|
|
41
44
|
self.max_shift = max_shift
|
|
42
45
|
|
|
43
|
-
def _shape_checks(self, x:
|
|
46
|
+
def _shape_checks(self, x: TimeSeries1to3d, y: TimeSeries1to3d):
|
|
44
47
|
if x.ndim > 3:
|
|
45
48
|
raise ValueError(
|
|
46
49
|
"Tensor x can only have up to 3 dimensions "
|
|
@@ -61,8 +64,9 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
|
61
64
|
)
|
|
62
65
|
)
|
|
63
66
|
|
|
64
|
-
|
|
65
|
-
|
|
67
|
+
def forward(
|
|
68
|
+
self, x: TimeSeries1to3d, y: TimeSeries1to3d
|
|
69
|
+
) -> Float[Tensor, "windows ..."]:
|
|
66
70
|
self._shape_checks(x, y)
|
|
67
71
|
dim = x.size(-1)
|
|
68
72
|
|
ml4gw/transforms/qtransform.py
CHANGED
|
@@ -1,8 +1,14 @@
|
|
|
1
1
|
import math
|
|
2
|
-
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import List, Tuple
|
|
3
4
|
|
|
4
5
|
import torch
|
|
5
6
|
import torch.nn.functional as F
|
|
7
|
+
from jaxtyping import Float, Int
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
|
|
10
|
+
from ml4gw.transforms.spline_interpolation import SplineInterpolate
|
|
11
|
+
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
|
|
6
12
|
|
|
7
13
|
"""
|
|
8
14
|
All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
|
|
@@ -34,7 +40,6 @@ class QTile(torch.nn.Module):
|
|
|
34
40
|
mismatch:
|
|
35
41
|
The maximum fractional mismatch between neighboring tiles
|
|
36
42
|
|
|
37
|
-
|
|
38
43
|
"""
|
|
39
44
|
|
|
40
45
|
def __init__(
|
|
@@ -44,7 +49,7 @@ class QTile(torch.nn.Module):
|
|
|
44
49
|
duration: float,
|
|
45
50
|
sample_rate: float,
|
|
46
51
|
mismatch: float,
|
|
47
|
-
):
|
|
52
|
+
) -> None:
|
|
48
53
|
super().__init__()
|
|
49
54
|
self.mismatch = mismatch
|
|
50
55
|
self.q = q
|
|
@@ -63,18 +68,18 @@ class QTile(torch.nn.Module):
|
|
|
63
68
|
self.register_buffer("indices", self.get_data_indices())
|
|
64
69
|
self.register_buffer("window", self.get_window())
|
|
65
70
|
|
|
66
|
-
def ntiles(self):
|
|
71
|
+
def ntiles(self) -> int:
|
|
67
72
|
"""
|
|
68
73
|
Number of tiles in this frequency row
|
|
69
74
|
"""
|
|
70
75
|
tcum_mismatch = self.duration * 2 * torch.pi * self.frequency / self.q
|
|
71
76
|
return int(2 ** torch.ceil(torch.log2(tcum_mismatch / self.deltam)))
|
|
72
77
|
|
|
73
|
-
def _get_indices(self):
|
|
78
|
+
def _get_indices(self) -> Int[Tensor, " windowsize"]:
|
|
74
79
|
half = int((self.windowsize - 1) / 2)
|
|
75
80
|
return torch.arange(-half, half + 1)
|
|
76
81
|
|
|
77
|
-
def get_window(self):
|
|
82
|
+
def get_window(self) -> Float[Tensor, " windowsize"]:
|
|
78
83
|
"""
|
|
79
84
|
Generate the bi-square window for this row
|
|
80
85
|
"""
|
|
@@ -87,7 +92,7 @@ class QTile(torch.nn.Module):
|
|
|
87
92
|
)
|
|
88
93
|
return torch.Tensor((1 - xfrequencies**2) ** 2 * norm)
|
|
89
94
|
|
|
90
|
-
def get_data_indices(self):
|
|
95
|
+
def get_data_indices(self) -> Int[Tensor, " windowsize"]:
|
|
91
96
|
"""
|
|
92
97
|
Get the index array of relevant frequencies for this row
|
|
93
98
|
"""
|
|
@@ -95,7 +100,11 @@ class QTile(torch.nn.Module):
|
|
|
95
100
|
self._get_indices() + 1 + self.frequency * self.duration,
|
|
96
101
|
).type(torch.long)
|
|
97
102
|
|
|
98
|
-
def forward(
|
|
103
|
+
def forward(
|
|
104
|
+
self,
|
|
105
|
+
fseries: FrequencySeries1to3d,
|
|
106
|
+
norm: str = "median",
|
|
107
|
+
) -> TimeSeries1to3d:
|
|
99
108
|
"""
|
|
100
109
|
Compute the transform for this row
|
|
101
110
|
|
|
@@ -138,7 +147,7 @@ class QTile(torch.nn.Module):
|
|
|
138
147
|
energy /= means
|
|
139
148
|
else:
|
|
140
149
|
raise ValueError("Invalid normalisation %r" % norm)
|
|
141
|
-
|
|
150
|
+
energy = energy.type(torch.float32)
|
|
142
151
|
return energy
|
|
143
152
|
|
|
144
153
|
|
|
@@ -166,6 +175,19 @@ class SingleQTransform(torch.nn.Module):
|
|
|
166
175
|
be chosen based on q, sample_rate, and duration
|
|
167
176
|
mismatch:
|
|
168
177
|
The maximum fractional mismatch between neighboring tiles
|
|
178
|
+
interpolation_method:
|
|
179
|
+
The method by which to interpolate each `QTile` to the specified
|
|
180
|
+
number of time and frequency bins. The acceptable values are
|
|
181
|
+
"bilinear", "bicubic", and "spline". The "bilinear" and "bicubic"
|
|
182
|
+
options will use PyTorch's built-in interpolation modes, while
|
|
183
|
+
"spline" will use the custom Torch-based implementation in
|
|
184
|
+
`ml4gw`, as PyTorch does not have spline-based intertpolation.
|
|
185
|
+
The "spline" mode is most similar to the results of GWpy's
|
|
186
|
+
Q-transform, which uses `scipy` to do spline interpolation.
|
|
187
|
+
However, it is also the slowest and most memory intensive due to
|
|
188
|
+
the matrix equation solving steps. Therefore, the default method
|
|
189
|
+
is "bicubic" as it produces the most similar results while
|
|
190
|
+
optimizing for computing performance.
|
|
169
191
|
"""
|
|
170
192
|
|
|
171
193
|
def __init__(
|
|
@@ -176,7 +198,8 @@ class SingleQTransform(torch.nn.Module):
|
|
|
176
198
|
q: float = 12,
|
|
177
199
|
frange: List[float] = [0, torch.inf],
|
|
178
200
|
mismatch: float = 0.2,
|
|
179
|
-
|
|
201
|
+
interpolation_method: str = "bicubic",
|
|
202
|
+
) -> None:
|
|
180
203
|
super().__init__()
|
|
181
204
|
self.q = q
|
|
182
205
|
self.spectrogram_shape = spectrogram_shape
|
|
@@ -184,21 +207,88 @@ class SingleQTransform(torch.nn.Module):
|
|
|
184
207
|
self.duration = duration
|
|
185
208
|
self.mismatch = mismatch
|
|
186
209
|
|
|
210
|
+
# If q is too large, the minimum of the frange computed
|
|
211
|
+
# below will be larger than the maximum
|
|
212
|
+
max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
|
|
213
|
+
if q >= max_q:
|
|
214
|
+
raise ValueError(
|
|
215
|
+
"The given q value is too large for the given duration and "
|
|
216
|
+
f"sample rate. The maximum allowable value is {max_q}"
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
if interpolation_method not in ["bilinear", "bicubic", "spline"]:
|
|
220
|
+
raise ValueError(
|
|
221
|
+
"Interpolation method must be either 'bilinear', 'bicubic', "
|
|
222
|
+
f"or 'spline'; got {interpolation_method}"
|
|
223
|
+
)
|
|
224
|
+
self.interpolation_method = interpolation_method
|
|
225
|
+
|
|
187
226
|
qprime = self.q / 11 ** (1 / 2.0)
|
|
188
227
|
if self.frange[0] <= 0: # set non-zero lower frequency
|
|
189
228
|
self.frange[0] = 50 * self.q / (2 * torch.pi * duration)
|
|
190
229
|
if math.isinf(self.frange[1]): # set non-infinite upper frequency
|
|
191
230
|
self.frange[1] = sample_rate / 2 / (1 + 1 / qprime)
|
|
231
|
+
|
|
192
232
|
self.freqs = self.get_freqs()
|
|
193
233
|
self.qtile_transforms = torch.nn.ModuleList(
|
|
194
234
|
[
|
|
195
|
-
QTile(
|
|
235
|
+
QTile(
|
|
236
|
+
q=self.q,
|
|
237
|
+
frequency=freq,
|
|
238
|
+
duration=self.duration,
|
|
239
|
+
sample_rate=sample_rate,
|
|
240
|
+
mismatch=self.mismatch,
|
|
241
|
+
)
|
|
196
242
|
for freq in self.freqs
|
|
197
243
|
]
|
|
198
244
|
)
|
|
199
245
|
self.qtiles = None
|
|
200
246
|
|
|
201
|
-
|
|
247
|
+
if self.interpolation_method == "spline":
|
|
248
|
+
self._set_up_spline_interp()
|
|
249
|
+
|
|
250
|
+
def _set_up_spline_interp(self):
|
|
251
|
+
ntiles = [qtile.ntiles() for qtile in self.qtile_transforms]
|
|
252
|
+
# For efficiency, we'll stack all qtiles of the same length before
|
|
253
|
+
# interpolating, so we need to figure out which those are
|
|
254
|
+
unique_ntiles = sorted(list(set(ntiles)))
|
|
255
|
+
idx = torch.arange(len(ntiles))
|
|
256
|
+
self.stack_idx = [idx[Tensor(ntiles) == n] for n in unique_ntiles]
|
|
257
|
+
|
|
258
|
+
t_out = torch.arange(
|
|
259
|
+
0, self.duration, self.duration / self.spectrogram_shape[1]
|
|
260
|
+
)
|
|
261
|
+
self.qtile_interpolators = torch.nn.ModuleList(
|
|
262
|
+
[
|
|
263
|
+
SplineInterpolate(
|
|
264
|
+
kx=3,
|
|
265
|
+
x_in=torch.arange(0, self.duration, self.duration / tiles),
|
|
266
|
+
y_in=torch.arange(len(idx)),
|
|
267
|
+
x_out=t_out,
|
|
268
|
+
y_out=torch.arange(len(idx)),
|
|
269
|
+
)
|
|
270
|
+
for tiles, idx in zip(unique_ntiles, self.stack_idx)
|
|
271
|
+
]
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
t_in = t_out
|
|
275
|
+
f_in = self.freqs
|
|
276
|
+
f_out = torch.logspace(
|
|
277
|
+
math.log10(self.frange[0]),
|
|
278
|
+
math.log10(self.frange[-1]),
|
|
279
|
+
self.spectrogram_shape[0],
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
self.interpolator = SplineInterpolate(
|
|
283
|
+
kx=3,
|
|
284
|
+
ky=3,
|
|
285
|
+
x_in=t_in,
|
|
286
|
+
y_in=f_in,
|
|
287
|
+
x_out=t_out,
|
|
288
|
+
y_out=f_out,
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
def get_freqs(self) -> Float[Tensor, " nfreq"]:
|
|
202
292
|
"""
|
|
203
293
|
Calculate the frequencies that will be used in this transform.
|
|
204
294
|
For each frequency, a `QTile` is created.
|
|
@@ -214,7 +304,8 @@ class SingleQTransform(torch.nn.Module):
|
|
|
214
304
|
|
|
215
305
|
freq_base = math.exp(2 / ((2 + self.q**2) ** (1 / 2.0)) * fstep)
|
|
216
306
|
freqs = torch.Tensor([freq_base ** (i + 0.5) for i in range(nfreq)])
|
|
217
|
-
freqs
|
|
307
|
+
# Cast freqs to float64 to avoid off-by-ones from rounding
|
|
308
|
+
freqs = (minf * freqs.double() // fstepmin) * fstepmin
|
|
218
309
|
return torch.unique(freqs)
|
|
219
310
|
|
|
220
311
|
def get_max_energy(
|
|
@@ -262,7 +353,11 @@ class SingleQTransform(torch.nn.Module):
|
|
|
262
353
|
if dimension == "batch":
|
|
263
354
|
return torch.max(max_across_ft, dim=-1).values
|
|
264
355
|
|
|
265
|
-
def compute_qtiles(
|
|
356
|
+
def compute_qtiles(
|
|
357
|
+
self,
|
|
358
|
+
X: TimeSeries1to3d,
|
|
359
|
+
norm: str = "median",
|
|
360
|
+
) -> None:
|
|
266
361
|
"""
|
|
267
362
|
Take the FFT of the input timeseries and calculate the transform
|
|
268
363
|
for each `QTile`
|
|
@@ -272,36 +367,47 @@ class SingleQTransform(torch.nn.Module):
|
|
|
272
367
|
X[..., 1:] *= 2
|
|
273
368
|
self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]
|
|
274
369
|
|
|
275
|
-
def interpolate(self
|
|
276
|
-
"""
|
|
277
|
-
Interpolate each `QTile` to the specified number of time and
|
|
278
|
-
frequency bins. Note that PyTorch does not have the same
|
|
279
|
-
interpolation methods that GWpy uses, and so the interpolated
|
|
280
|
-
spectrograms will be different even though the uninterpolated
|
|
281
|
-
values match. The `bicubic` interpolation method is used as
|
|
282
|
-
it seems to match GWpy most closely.
|
|
283
|
-
"""
|
|
370
|
+
def interpolate(self) -> TimeSeries3d:
|
|
284
371
|
if self.qtiles is None:
|
|
285
372
|
raise RuntimeError(
|
|
286
373
|
"Q-tiles must first be computed with .compute_qtiles()"
|
|
287
374
|
)
|
|
375
|
+
if self.interpolation_method == "spline":
|
|
376
|
+
qtiles = [
|
|
377
|
+
torch.stack([self.qtiles[i] for i in idx], dim=-2)
|
|
378
|
+
for idx in self.stack_idx
|
|
379
|
+
]
|
|
380
|
+
time_interped = torch.cat(
|
|
381
|
+
[
|
|
382
|
+
interpolator(qtile)
|
|
383
|
+
for qtile, interpolator in zip(
|
|
384
|
+
qtiles, self.qtile_interpolators
|
|
385
|
+
)
|
|
386
|
+
],
|
|
387
|
+
dim=-2,
|
|
388
|
+
)
|
|
389
|
+
return self.interpolator(time_interped)
|
|
390
|
+
num_f_bins, num_t_bins = self.spectrogram_shape
|
|
288
391
|
resampled = [
|
|
289
392
|
F.interpolate(
|
|
290
|
-
qtile[None],
|
|
393
|
+
qtile[None],
|
|
394
|
+
(qtile.shape[-2], num_t_bins),
|
|
395
|
+
mode=self.interpolation_method,
|
|
291
396
|
)
|
|
292
397
|
for qtile in self.qtiles
|
|
293
398
|
]
|
|
294
399
|
resampled = torch.stack(resampled, dim=-2)
|
|
295
400
|
resampled = F.interpolate(
|
|
296
|
-
resampled[0],
|
|
401
|
+
resampled[0],
|
|
402
|
+
(num_f_bins, num_t_bins),
|
|
403
|
+
mode=self.interpolation_method,
|
|
297
404
|
)
|
|
298
405
|
return torch.squeeze(resampled)
|
|
299
406
|
|
|
300
407
|
def forward(
|
|
301
408
|
self,
|
|
302
|
-
X:
|
|
409
|
+
X: TimeSeries1to3d,
|
|
303
410
|
norm: str = "median",
|
|
304
|
-
spectrogram_shape: Optional[Tuple[int, int]] = None,
|
|
305
411
|
):
|
|
306
412
|
"""
|
|
307
413
|
Compute the Q-tiles and interpolate
|
|
@@ -315,24 +421,15 @@ class SingleQTransform(torch.nn.Module):
|
|
|
315
421
|
three-dimensional, axes will be added during Q-tile
|
|
316
422
|
computation.
|
|
317
423
|
norm:
|
|
318
|
-
The method of
|
|
319
|
-
spectrogram_shape:
|
|
320
|
-
The shape of the interpolated spectrogram, specified as
|
|
321
|
-
`(num_f_bins, num_t_bins)`. Because the
|
|
322
|
-
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
323
|
-
interpolation is log-spaced as well. If not given, the shape
|
|
324
|
-
used to initialize the transform will be used.
|
|
424
|
+
The method of normalization used by each QTile
|
|
325
425
|
|
|
326
426
|
Returns:
|
|
327
427
|
The interpolated Q-transform for the batch of data. Output will
|
|
328
428
|
have one more dimension than the input
|
|
329
429
|
"""
|
|
330
430
|
|
|
331
|
-
if spectrogram_shape is None:
|
|
332
|
-
spectrogram_shape = self.spectrogram_shape
|
|
333
|
-
num_f_bins, num_t_bins = spectrogram_shape
|
|
334
431
|
self.compute_qtiles(X, norm)
|
|
335
|
-
return self.interpolate(
|
|
432
|
+
return self.interpolate()
|
|
336
433
|
|
|
337
434
|
|
|
338
435
|
class QScan(torch.nn.Module):
|
|
@@ -370,14 +467,22 @@ class QScan(torch.nn.Module):
|
|
|
370
467
|
spectrogram_shape: Tuple[int, int],
|
|
371
468
|
qrange: List[float] = [4, 64],
|
|
372
469
|
frange: List[float] = [0, torch.inf],
|
|
470
|
+
interpolation_method="bicubic",
|
|
373
471
|
mismatch: float = 0.2,
|
|
374
|
-
):
|
|
472
|
+
) -> None:
|
|
375
473
|
super().__init__()
|
|
376
474
|
self.qrange = qrange
|
|
377
475
|
self.mismatch = mismatch
|
|
378
|
-
self.qs = self.get_qs()
|
|
379
476
|
self.frange = frange
|
|
380
477
|
self.spectrogram_shape = spectrogram_shape
|
|
478
|
+
max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
|
|
479
|
+
self.qs = self.get_qs()
|
|
480
|
+
if self.qs[-1] >= max_q:
|
|
481
|
+
warnings.warn(
|
|
482
|
+
"Some Q values exceed the maximum allowable Q value of "
|
|
483
|
+
f"{max_q}. The list of Q values to be tested in this "
|
|
484
|
+
"scan will be truncated to avoid those values."
|
|
485
|
+
)
|
|
381
486
|
|
|
382
487
|
# Deliberately doing something different from GWpy here.
|
|
383
488
|
# Their final frange is the intersection of the frange
|
|
@@ -391,13 +496,15 @@ class QScan(torch.nn.Module):
|
|
|
391
496
|
spectrogram_shape=spectrogram_shape,
|
|
392
497
|
q=q,
|
|
393
498
|
frange=self.frange.copy(),
|
|
499
|
+
interpolation_method=interpolation_method,
|
|
394
500
|
mismatch=self.mismatch,
|
|
395
501
|
)
|
|
396
502
|
for q in self.qs
|
|
503
|
+
if q < max_q
|
|
397
504
|
]
|
|
398
505
|
)
|
|
399
506
|
|
|
400
|
-
def get_qs(self):
|
|
507
|
+
def get_qs(self) -> List[float]:
|
|
401
508
|
"""
|
|
402
509
|
Determine the values of Q to try for the set of Q-transforms
|
|
403
510
|
"""
|
|
@@ -409,14 +516,14 @@ class QScan(torch.nn.Module):
|
|
|
409
516
|
self.qrange[0] * math.exp(2 ** (1 / 2.0) * dq * (i + 0.5))
|
|
410
517
|
for i in range(nplanes)
|
|
411
518
|
]
|
|
519
|
+
|
|
412
520
|
return qs
|
|
413
521
|
|
|
414
522
|
def forward(
|
|
415
523
|
self,
|
|
416
|
-
X:
|
|
524
|
+
X: TimeSeries1to3d,
|
|
417
525
|
fsearch_range: List[float] = None,
|
|
418
526
|
norm: str = "median",
|
|
419
|
-
spectrogram_shape: Optional[Tuple[int, int]] = None,
|
|
420
527
|
):
|
|
421
528
|
"""
|
|
422
529
|
Compute the set of QTiles for each Q transform and determine which
|
|
@@ -436,12 +543,6 @@ class QScan(torch.nn.Module):
|
|
|
436
543
|
for the maximum energy
|
|
437
544
|
norm:
|
|
438
545
|
The method of interpolation used by each QTile
|
|
439
|
-
spectrogram_shape:
|
|
440
|
-
The shape of the interpolated spectrogram, specified as
|
|
441
|
-
`(num_f_bins, num_t_bins)`. Because the
|
|
442
|
-
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
443
|
-
interpolation is log-spaced as well. If not given, the shape
|
|
444
|
-
used to initialize the transform will be used.
|
|
445
546
|
|
|
446
547
|
Returns:
|
|
447
548
|
An interpolated Q-transform for the batch of data. Output will
|
|
@@ -457,7 +558,4 @@ class QScan(torch.nn.Module):
|
|
|
457
558
|
]
|
|
458
559
|
)
|
|
459
560
|
)
|
|
460
|
-
|
|
461
|
-
spectrogram_shape = self.spectrogram_shape
|
|
462
|
-
num_f_bins, num_t_bins = spectrogram_shape
|
|
463
|
-
return self.q_transforms[idx].interpolate(num_f_bins, num_t_bins)
|
|
561
|
+
return self.q_transforms[idx].interpolate()
|
ml4gw/transforms/scaler.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
from ml4gw.transforms.transform import FittableTransform
|
|
6
8
|
|
|
@@ -34,7 +36,9 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
34
36
|
self.register_buffer("mean", mean)
|
|
35
37
|
self.register_buffer("std", std)
|
|
36
38
|
|
|
37
|
-
def fit(
|
|
39
|
+
def fit(
|
|
40
|
+
self, X: Float[Tensor, "... time"], std_reg: Optional[float] = 0.0
|
|
41
|
+
) -> None:
|
|
38
42
|
"""Fit the scaling parameters to a timeseries
|
|
39
43
|
|
|
40
44
|
Computes the channel-wise mean and standard deviation
|
|
@@ -57,10 +61,12 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
57
61
|
"Can't fit channel wise mean and standard deviation "
|
|
58
62
|
"from tensor of shape {}".format(X.shape)
|
|
59
63
|
)
|
|
60
|
-
|
|
64
|
+
std += std_reg * torch.ones_like(std)
|
|
61
65
|
super().build(mean=mean, std=std)
|
|
62
66
|
|
|
63
|
-
def forward(
|
|
67
|
+
def forward(
|
|
68
|
+
self, X: Float[Tensor, "... time"], reverse: bool = False
|
|
69
|
+
) -> Float[Tensor, "... time"]:
|
|
64
70
|
if not reverse:
|
|
65
71
|
return (X - self.mean) / self.std
|
|
66
72
|
else:
|
ml4gw/transforms/snr_rescaler.py
CHANGED
|
@@ -2,8 +2,9 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
|
-
from ml4gw import
|
|
5
|
+
from ml4gw.gw import compute_network_snr
|
|
6
6
|
from ml4gw.transforms.transform import FittableSpectralTransform
|
|
7
|
+
from ml4gw.types import BatchTensor, TimeSeries2d, WaveformTensor
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class SnrRescaler(FittableSpectralTransform):
|
|
@@ -34,7 +35,7 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
34
35
|
|
|
35
36
|
def fit(
|
|
36
37
|
self,
|
|
37
|
-
*background:
|
|
38
|
+
*background: TimeSeries2d,
|
|
38
39
|
fftlength: Optional[float] = None,
|
|
39
40
|
overlap: Optional[float] = None,
|
|
40
41
|
):
|
|
@@ -58,10 +59,10 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
58
59
|
|
|
59
60
|
def forward(
|
|
60
61
|
self,
|
|
61
|
-
responses:
|
|
62
|
-
target_snrs: Optional[
|
|
62
|
+
responses: WaveformTensor,
|
|
63
|
+
target_snrs: Optional[BatchTensor] = None,
|
|
63
64
|
):
|
|
64
|
-
snrs =
|
|
65
|
+
snrs = compute_network_snr(
|
|
65
66
|
responses, self.background, self.sample_rate, self.mask
|
|
66
67
|
)
|
|
67
68
|
if target_snrs is None:
|