ml4gw 0.7.7__py3-none-any.whl → 0.7.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4gw/augmentations.py +5 -0
- ml4gw/dataloading/__init__.py +5 -0
- ml4gw/dataloading/chunked_dataset.py +2 -4
- ml4gw/dataloading/hdf5_dataset.py +12 -10
- ml4gw/dataloading/in_memory_dataset.py +12 -12
- ml4gw/distributions.py +2 -2
- ml4gw/gw.py +18 -21
- ml4gw/nn/__init__.py +6 -0
- ml4gw/nn/autoencoder/base.py +5 -9
- ml4gw/nn/autoencoder/convolutional.py +7 -10
- ml4gw/nn/autoencoder/skip_connection.py +3 -5
- ml4gw/nn/norm.py +4 -4
- ml4gw/nn/resnet/resnet_1d.py +12 -13
- ml4gw/nn/resnet/resnet_2d.py +13 -14
- ml4gw/nn/streaming/online_average.py +3 -5
- ml4gw/nn/streaming/snapshotter.py +10 -14
- ml4gw/spectral.py +20 -23
- ml4gw/transforms/__init__.py +6 -0
- ml4gw/transforms/decimator.py +183 -0
- ml4gw/transforms/iirfilter.py +3 -5
- ml4gw/transforms/pearson.py +3 -4
- ml4gw/transforms/qtransform.py +10 -11
- ml4gw/transforms/scaler.py +3 -5
- ml4gw/transforms/snr_rescaler.py +7 -11
- ml4gw/transforms/spectral.py +6 -13
- ml4gw/transforms/spectrogram.py +6 -3
- ml4gw/transforms/spline_interpolation.py +7 -9
- ml4gw/transforms/transform.py +4 -6
- ml4gw/transforms/waveforms.py +8 -15
- ml4gw/transforms/whitening.py +11 -16
- ml4gw/types.py +8 -5
- ml4gw/utils/interferometer.py +20 -3
- ml4gw/utils/slicing.py +26 -30
- ml4gw/waveforms/__init__.py +6 -0
- ml4gw/waveforms/cbc/phenom_p.py +7 -9
- ml4gw/waveforms/conversion.py +2 -4
- ml4gw/waveforms/generator.py +3 -3
- {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/METADATA +28 -8
- ml4gw-0.7.8.dist-info/RECORD +57 -0
- ml4gw-0.7.7.dist-info/RECORD +0 -56
- {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/WHEEL +0 -0
- {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/licenses/LICENSE +0 -0
- {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/top_level.txt +0 -0
ml4gw/spectral.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
|
+
This module provides functions for calculation of spectral densities
|
|
3
|
+
and for whitening.
|
|
4
|
+
|
|
2
5
|
Several implementation details are derived from the scipy csd and welch
|
|
3
6
|
implementations. For more info, see
|
|
4
7
|
|
|
@@ -9,8 +12,6 @@ and
|
|
|
9
12
|
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.csd.html
|
|
10
13
|
"""
|
|
11
14
|
|
|
12
|
-
from typing import Optional, Union
|
|
13
|
-
|
|
14
15
|
import torch
|
|
15
16
|
from jaxtyping import Float
|
|
16
17
|
from torch import Tensor
|
|
@@ -36,13 +37,11 @@ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
|
|
|
36
37
|
return torch.quantile(x, q=0.5, axis=axis) / bias
|
|
37
38
|
|
|
38
39
|
|
|
39
|
-
def _validate_shapes(
|
|
40
|
-
x: Tensor, nperseg: int, y: Optional[Tensor] = None
|
|
41
|
-
) -> None:
|
|
40
|
+
def _validate_shapes(x: Tensor, nperseg: int, y: Tensor | None = None) -> None:
|
|
42
41
|
if x.shape[-1] < nperseg:
|
|
43
42
|
raise ValueError(
|
|
44
|
-
"Number of samples {} in input x is insufficient "
|
|
45
|
-
"for number of fft samples {}"
|
|
43
|
+
f"Number of samples {x.shape[-1]} in input x is insufficient "
|
|
44
|
+
f"for number of fft samples {nperseg}"
|
|
46
45
|
)
|
|
47
46
|
elif x.ndim > 3:
|
|
48
47
|
raise ValueError(
|
|
@@ -59,30 +58,30 @@ def _validate_shapes(
|
|
|
59
58
|
if x.shape[-1] != y.shape[-1]:
|
|
60
59
|
raise ValueError(
|
|
61
60
|
"Time dimensions of x and y tensors must "
|
|
62
|
-
"be the same, found {
|
|
61
|
+
f"be the same, found {x.shape[-1]} and {y.shape[-1]}"
|
|
63
62
|
)
|
|
64
63
|
elif x.ndim == 1 and not y.ndim == 1:
|
|
65
64
|
raise ValueError(
|
|
66
65
|
"Can't compute cross spectral density of "
|
|
67
|
-
"1D tensor x with {}D tensor y"
|
|
66
|
+
f"1D tensor x with {y.ndim}D tensor y"
|
|
68
67
|
)
|
|
69
68
|
elif x.ndim > 1 and y.ndim == x.ndim:
|
|
70
69
|
if not y.shape == x.shape:
|
|
71
70
|
raise ValueError(
|
|
72
71
|
"If x and y tensors have the same number "
|
|
73
72
|
"of dimensions, shapes must fully match. "
|
|
74
|
-
"Found shapes {} and {
|
|
73
|
+
f"Found shapes {x.shape} and {y.shape}"
|
|
75
74
|
)
|
|
76
75
|
elif x.ndim > 1 and y.ndim != (x.ndim - 1):
|
|
77
76
|
raise ValueError(
|
|
78
77
|
"Can't compute cross spectral density of "
|
|
79
|
-
"tensors with shapes {} and {
|
|
78
|
+
f"tensors with shapes {x.shape} and {y.shape}"
|
|
80
79
|
)
|
|
81
80
|
elif x.ndim > 2 and y.shape[0] != x.shape[0]:
|
|
82
81
|
raise ValueError(
|
|
83
82
|
"If x is a 3D tensor and y is a 2D tensor, "
|
|
84
83
|
"0th batch dimensions must match, but found "
|
|
85
|
-
"values {
|
|
84
|
+
f"values {x.shape[0]} and {y.shape[0]}"
|
|
86
85
|
)
|
|
87
86
|
|
|
88
87
|
|
|
@@ -93,7 +92,7 @@ def fast_spectral_density(
|
|
|
93
92
|
window: Float[Tensor, " {nperseg//2+1}"],
|
|
94
93
|
scale: float,
|
|
95
94
|
average: str = "median",
|
|
96
|
-
y:
|
|
95
|
+
y: TimeSeries1to3d | None = None,
|
|
97
96
|
) -> FrequencySeries1to3d:
|
|
98
97
|
"""
|
|
99
98
|
Compute the power spectral density of a multichannel
|
|
@@ -340,10 +339,10 @@ def spectral_density(
|
|
|
340
339
|
|
|
341
340
|
def truncate_inverse_power_spectrum(
|
|
342
341
|
psd: PSDTensor,
|
|
343
|
-
fduration:
|
|
342
|
+
fduration: Float[Tensor, " time"] | float,
|
|
344
343
|
sample_rate: float,
|
|
345
|
-
highpass:
|
|
346
|
-
lowpass:
|
|
344
|
+
highpass: float | None = None,
|
|
345
|
+
lowpass: float | None = None,
|
|
347
346
|
) -> PSDTensor:
|
|
348
347
|
"""
|
|
349
348
|
Truncate the length of the time domain response
|
|
@@ -460,10 +459,10 @@ def normalize_by_psd(
|
|
|
460
459
|
def whiten(
|
|
461
460
|
X: WaveformTensor,
|
|
462
461
|
psd: PSDTensor,
|
|
463
|
-
fduration:
|
|
462
|
+
fduration: Float[Tensor, " time"] | float,
|
|
464
463
|
sample_rate: float,
|
|
465
|
-
highpass:
|
|
466
|
-
lowpass:
|
|
464
|
+
highpass: float | None = None,
|
|
465
|
+
lowpass: float | None = None,
|
|
467
466
|
) -> WaveformTensor:
|
|
468
467
|
"""
|
|
469
468
|
Whiten a batch of timeseries using the specified
|
|
@@ -522,10 +521,8 @@ def whiten(
|
|
|
522
521
|
N = X.size(-1)
|
|
523
522
|
if N <= (2 * pad):
|
|
524
523
|
raise ValueError(
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
"padded samples {}"
|
|
528
|
-
).format(N, 2 * pad)
|
|
524
|
+
f"Not enough timeseries samples {N} for number of "
|
|
525
|
+
f"padded samples {2 * pad}"
|
|
529
526
|
)
|
|
530
527
|
|
|
531
528
|
# normalize the number of expected dimensions in the PSD
|
ml4gw/transforms/__init__.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains a variety of data transformation classes,
|
|
3
|
+
including objects to calculate spectral densities, whiten data,
|
|
4
|
+
and compute Q-transforms.
|
|
5
|
+
"""
|
|
6
|
+
|
|
1
7
|
from .iirfilter import IIRFilter
|
|
2
8
|
from .pearson import ShiftedPearsonCorrelation
|
|
3
9
|
from .qtransform import QScan, SingleQTransform
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Decimator(torch.nn.Module):
|
|
5
|
+
r"""
|
|
6
|
+
Downsample (decimate) a timeseries according to a user-defined schedule.
|
|
7
|
+
|
|
8
|
+
.. note::
|
|
9
|
+
|
|
10
|
+
This is a naive decimator that does not use any IIR/FIR filtering
|
|
11
|
+
and selects every M-th sample according to the schedule.
|
|
12
|
+
|
|
13
|
+
The schedule specifies which segments of the input to keep and at what
|
|
14
|
+
sampling rate. Each row of the schedule has the form:
|
|
15
|
+
|
|
16
|
+
`[start_time, end_time, target_sample_rate]`
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
sample_rate (int):
|
|
20
|
+
Sampling rate (Hz) of the input timeseries.
|
|
21
|
+
schedule (torch.Tensor):
|
|
22
|
+
Tensor of shape `(N, 3)` defining start time, end time,
|
|
23
|
+
and target sample rate for each segment.
|
|
24
|
+
|
|
25
|
+
Shape:
|
|
26
|
+
- Input: `(B, C, T)` where
|
|
27
|
+
- B = batch size
|
|
28
|
+
- C = channels
|
|
29
|
+
- T = number of timesteps
|
|
30
|
+
(must equal schedule duration × sample_rate)
|
|
31
|
+
- Output:
|
|
32
|
+
- If ``split=False`` → `(B, C, T')` where `T'` is total
|
|
33
|
+
number of decimated samples across all segments.
|
|
34
|
+
- If ``split=True`` → list of tensors, one per segment.
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
torch.Tensor or List[torch.Tensor]:
|
|
38
|
+
The decimated timeseries, or list of decimated segments if
|
|
39
|
+
``split=True``.
|
|
40
|
+
|
|
41
|
+
Example:
|
|
42
|
+
.. code-block:: python
|
|
43
|
+
|
|
44
|
+
>>> import torch
|
|
45
|
+
>>> from ml4gw.transforms.decimator import Decimator
|
|
46
|
+
|
|
47
|
+
>>> sample_rate = 2048
|
|
48
|
+
>>> X_duration = 60
|
|
49
|
+
|
|
50
|
+
>>> schedule = torch.tensor(
|
|
51
|
+
... [[0, 40, 256], [40, 58, 512], [58, 60, 2048]],
|
|
52
|
+
... dtype=torch.int,
|
|
53
|
+
... )
|
|
54
|
+
|
|
55
|
+
>>> decimator = Decimator(sample_rate=sample_rate,
|
|
56
|
+
... schedule=schedule)
|
|
57
|
+
>>> X = torch.randn(1, 1, sample_rate * X_duration)
|
|
58
|
+
>>> X_dec = decimator(X)
|
|
59
|
+
>>> X_seg = decimator(X, split=True)
|
|
60
|
+
|
|
61
|
+
>>> print("Original shape:", X.shape)
|
|
62
|
+
Original shape: torch.Size([1, 1, 122880])
|
|
63
|
+
>>> print("Decimated shape:", X_dec.shape)
|
|
64
|
+
Decimated shape: torch.Size([1, 1, 23552])
|
|
65
|
+
>>> for i, seg in enumerate(X_seg):
|
|
66
|
+
... print(f"Segment {i} shape:", seg.shape)
|
|
67
|
+
Segment 0 shape: torch.Size([1, 1, 10240])
|
|
68
|
+
Segment 1 shape: torch.Size([1, 1, 9216])
|
|
69
|
+
Segment 2 shape: torch.Size([1, 1, 4096])
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
def __init__(
|
|
73
|
+
self,
|
|
74
|
+
sample_rate: int = None,
|
|
75
|
+
schedule: torch.Tensor = None,
|
|
76
|
+
) -> None:
|
|
77
|
+
super().__init__()
|
|
78
|
+
self.sample_rate = sample_rate
|
|
79
|
+
self.schedule = schedule
|
|
80
|
+
|
|
81
|
+
self._validate_inputs()
|
|
82
|
+
idx = self.build_variable_indices()
|
|
83
|
+
self.register_buffer("idx", idx)
|
|
84
|
+
|
|
85
|
+
self.expected_len = int(
|
|
86
|
+
(self.schedule[:, 1][-1] - self.schedule[:, 0][0])
|
|
87
|
+
* self.sample_rate
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
def _validate_inputs(self) -> None:
|
|
91
|
+
r"""
|
|
92
|
+
Validate the schedule and sample_rate.
|
|
93
|
+
"""
|
|
94
|
+
if self.schedule.ndim != 2 or self.schedule.shape[1] != 3:
|
|
95
|
+
raise ValueError(
|
|
96
|
+
f"Schedule must be of shape (N, 3), got {self.schedule.shape}"
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
if not torch.all(self.schedule[:, 1] > self.schedule[:, 0]):
|
|
100
|
+
raise ValueError(
|
|
101
|
+
"Each schedule segment must have end_time > start_time"
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
if torch.any(self.sample_rate % self.schedule[:, 2].long() != 0):
|
|
105
|
+
raise ValueError(
|
|
106
|
+
f"Sample rate {self.sample_rate} must be divisible by all "
|
|
107
|
+
f"target rates {self.schedule[:, 2].tolist()}"
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
def build_variable_indices(self) -> torch.Tensor:
|
|
111
|
+
r"""
|
|
112
|
+
Compute the time indices to keep based on the schedule.
|
|
113
|
+
|
|
114
|
+
Returns:
|
|
115
|
+
torch.Tensor:
|
|
116
|
+
1D tensor of indices used to decimate the input.
|
|
117
|
+
"""
|
|
118
|
+
idx = torch.tensor([], dtype=torch.long)
|
|
119
|
+
|
|
120
|
+
for s in self.schedule:
|
|
121
|
+
if idx.size(0) == 0:
|
|
122
|
+
start = int(s[0] * self.sample_rate)
|
|
123
|
+
else:
|
|
124
|
+
start = int(idx[-1]) + int(idx[-1] - idx[-2])
|
|
125
|
+
stop = int(start + (s[1] - s[0]) * self.sample_rate)
|
|
126
|
+
step = int(self.sample_rate // s[2])
|
|
127
|
+
new_idx = torch.arange(start, stop, step, dtype=torch.long)
|
|
128
|
+
idx = torch.cat((idx, new_idx))
|
|
129
|
+
return idx
|
|
130
|
+
|
|
131
|
+
def split_by_schedule(self, X: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
|
132
|
+
r"""
|
|
133
|
+
Split a decimated timeseries into segments according to the schedule.
|
|
134
|
+
|
|
135
|
+
Args:
|
|
136
|
+
X (torch.Tensor):
|
|
137
|
+
Decimated input of shape `(B, C, T')`.
|
|
138
|
+
|
|
139
|
+
Returns:
|
|
140
|
+
tuple of torch.Tensor:
|
|
141
|
+
Each segment has shape :math:`(B, C, T_i)`
|
|
142
|
+
where :math:`T_i` is the length implied by
|
|
143
|
+
the corresponding schedule row.
|
|
144
|
+
"""
|
|
145
|
+
split_sizes = (
|
|
146
|
+
((self.schedule[:, 1] - self.schedule[:, 0]) * self.schedule[:, 2])
|
|
147
|
+
.long()
|
|
148
|
+
.tolist()
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
return torch.split(X, split_sizes, dim=-1)
|
|
152
|
+
|
|
153
|
+
def forward(
|
|
154
|
+
self,
|
|
155
|
+
X: torch.Tensor,
|
|
156
|
+
split: bool = False,
|
|
157
|
+
) -> torch.Tensor | list[torch.Tensor]:
|
|
158
|
+
r"""
|
|
159
|
+
Apply decimation to the input timeseries.
|
|
160
|
+
|
|
161
|
+
Args:
|
|
162
|
+
X (torch.Tensor):
|
|
163
|
+
Input tensor of shape `(B, C, T)`, where `T` must equal
|
|
164
|
+
schedule duration × sample_rate.
|
|
165
|
+
split (bool, optional):
|
|
166
|
+
If True, return a list of segments instead of a single
|
|
167
|
+
concatenated tensor. Default: False.
|
|
168
|
+
|
|
169
|
+
Returns:
|
|
170
|
+
torch.Tensor or List[torch.Tensor]:
|
|
171
|
+
Decimated timeseries, or list of decimated segments.
|
|
172
|
+
"""
|
|
173
|
+
if X.shape[-1] != self.expected_len:
|
|
174
|
+
raise ValueError(
|
|
175
|
+
f"X length {X.shape[-1]} does not match "
|
|
176
|
+
f"expected schedule duration {self.expected_len}"
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
X_dec = X.index_select(dim=-1, index=self.idx)
|
|
180
|
+
|
|
181
|
+
if split:
|
|
182
|
+
X_dec = self.split_by_schedule(X_dec)
|
|
183
|
+
return X_dec
|
ml4gw/transforms/iirfilter.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Union
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
from scipy.signal import iirfilter
|
|
5
3
|
from torchaudio.functional import filtfilt
|
|
@@ -55,9 +53,9 @@ class IIRFilter(torch.nn.Module):
|
|
|
55
53
|
def __init__(
|
|
56
54
|
self,
|
|
57
55
|
N: int,
|
|
58
|
-
Wn:
|
|
59
|
-
rs:
|
|
60
|
-
rp:
|
|
56
|
+
Wn: float | torch.Tensor,
|
|
57
|
+
rs: None | float | torch.Tensor = None,
|
|
58
|
+
rp: None | float | torch.Tensor = None,
|
|
61
59
|
btype="band",
|
|
62
60
|
analog=False,
|
|
63
61
|
ftype="butter",
|
ml4gw/transforms/pearson.py
CHANGED
|
@@ -52,15 +52,14 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
|
52
52
|
raise ValueError(
|
|
53
53
|
"y may not have more dimensions that x for "
|
|
54
54
|
"ShiftedPearsonCorrelation, but found shapes "
|
|
55
|
-
"{} and {
|
|
55
|
+
f"{y.shape} and {x.shape}"
|
|
56
56
|
)
|
|
57
57
|
for dim in range(y.ndim):
|
|
58
58
|
if y.size(-dim - 1) != x.size(-dim - 1):
|
|
59
59
|
raise ValueError(
|
|
60
60
|
"x and y expected to have same size along "
|
|
61
|
-
"last dimensions, but found shapes {} and
|
|
62
|
-
|
|
63
|
-
)
|
|
61
|
+
f"last dimensions, but found shapes {x.shape} and "
|
|
62
|
+
f"{y.shape}"
|
|
64
63
|
)
|
|
65
64
|
|
|
66
65
|
def forward(
|
ml4gw/transforms/qtransform.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import math
|
|
2
2
|
import warnings
|
|
3
|
-
from typing import List, Tuple
|
|
4
3
|
|
|
5
4
|
import torch
|
|
6
5
|
import torch.nn.functional as F
|
|
@@ -146,7 +145,7 @@ class QTile(torch.nn.Module):
|
|
|
146
145
|
means = torch.mean(energy, dim=-1, keepdim=True)
|
|
147
146
|
energy /= means
|
|
148
147
|
else:
|
|
149
|
-
raise ValueError("Invalid normalisation
|
|
148
|
+
raise ValueError(f"Invalid normalisation {norm}")
|
|
150
149
|
energy = energy.type(torch.float32)
|
|
151
150
|
return energy
|
|
152
151
|
|
|
@@ -194,9 +193,9 @@ class SingleQTransform(torch.nn.Module):
|
|
|
194
193
|
self,
|
|
195
194
|
duration: float,
|
|
196
195
|
sample_rate: float,
|
|
197
|
-
spectrogram_shape:
|
|
196
|
+
spectrogram_shape: tuple[int, int],
|
|
198
197
|
q: float = 12,
|
|
199
|
-
frange:
|
|
198
|
+
frange: list[float] = None,
|
|
200
199
|
mismatch: float = 0.2,
|
|
201
200
|
interpolation_method: str = "bicubic",
|
|
202
201
|
) -> None:
|
|
@@ -303,7 +302,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
303
302
|
return torch.unique(freqs)
|
|
304
303
|
|
|
305
304
|
def get_max_energy(
|
|
306
|
-
self, fsearch_range:
|
|
305
|
+
self, fsearch_range: list[float] = None, dimension: str = "both"
|
|
307
306
|
):
|
|
308
307
|
"""
|
|
309
308
|
Gets the maximum energy value among the QTiles. The maximum can
|
|
@@ -375,7 +374,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
375
374
|
[
|
|
376
375
|
qtile_interpolator(qtile)
|
|
377
376
|
for qtile, qtile_interpolator in zip(
|
|
378
|
-
qtiles, self.qtile_interpolators
|
|
377
|
+
qtiles, self.qtile_interpolators, strict=True
|
|
379
378
|
)
|
|
380
379
|
],
|
|
381
380
|
dim=-2,
|
|
@@ -459,9 +458,9 @@ class QScan(torch.nn.Module):
|
|
|
459
458
|
self,
|
|
460
459
|
duration: float,
|
|
461
460
|
sample_rate: float,
|
|
462
|
-
spectrogram_shape:
|
|
463
|
-
qrange:
|
|
464
|
-
frange:
|
|
461
|
+
spectrogram_shape: tuple[int, int],
|
|
462
|
+
qrange: list[float] = None,
|
|
463
|
+
frange: list[float] = None,
|
|
465
464
|
interpolation_method="bicubic",
|
|
466
465
|
mismatch: float = 0.2,
|
|
467
466
|
) -> None:
|
|
@@ -500,7 +499,7 @@ class QScan(torch.nn.Module):
|
|
|
500
499
|
]
|
|
501
500
|
)
|
|
502
501
|
|
|
503
|
-
def get_qs(self) ->
|
|
502
|
+
def get_qs(self) -> list[float]:
|
|
504
503
|
"""
|
|
505
504
|
Determine the values of Q to try for the set of Q-transforms
|
|
506
505
|
"""
|
|
@@ -518,7 +517,7 @@ class QScan(torch.nn.Module):
|
|
|
518
517
|
def forward(
|
|
519
518
|
self,
|
|
520
519
|
X: TimeSeries1to3d,
|
|
521
|
-
fsearch_range:
|
|
520
|
+
fsearch_range: list[float] = None,
|
|
522
521
|
norm: str = "median",
|
|
523
522
|
):
|
|
524
523
|
"""
|
ml4gw/transforms/scaler.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
from jaxtyping import Float
|
|
5
3
|
from torch import Tensor
|
|
@@ -24,7 +22,7 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
24
22
|
to be 1D (single channel).
|
|
25
23
|
"""
|
|
26
24
|
|
|
27
|
-
def __init__(self, num_channels:
|
|
25
|
+
def __init__(self, num_channels: int | None = None) -> None:
|
|
28
26
|
super().__init__()
|
|
29
27
|
|
|
30
28
|
shape = (num_channels or 1,)
|
|
@@ -37,7 +35,7 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
37
35
|
self.register_buffer("std", std)
|
|
38
36
|
|
|
39
37
|
def fit(
|
|
40
|
-
self, X: Float[Tensor, "... time"], std_reg:
|
|
38
|
+
self, X: Float[Tensor, "... time"], std_reg: float | None = 0.0
|
|
41
39
|
) -> None:
|
|
42
40
|
"""Fit the scaling parameters to a timeseries
|
|
43
41
|
|
|
@@ -59,7 +57,7 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
59
57
|
else:
|
|
60
58
|
raise ValueError(
|
|
61
59
|
"Can't fit channel wise mean and standard deviation "
|
|
62
|
-
"from tensor of shape {
|
|
60
|
+
f"from tensor of shape {X.shape}"
|
|
63
61
|
)
|
|
64
62
|
std += std_reg * torch.ones_like(std)
|
|
65
63
|
super().build(mean=mean, std=std)
|
ml4gw/transforms/snr_rescaler.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
|
|
5
3
|
from ..gw import compute_network_snr
|
|
@@ -13,8 +11,8 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
13
11
|
num_channels: int,
|
|
14
12
|
sample_rate: float,
|
|
15
13
|
waveform_duration: float,
|
|
16
|
-
highpass:
|
|
17
|
-
lowpass:
|
|
14
|
+
highpass: float | None = None,
|
|
15
|
+
lowpass: float | None = None,
|
|
18
16
|
dtype: torch.dtype = torch.float32,
|
|
19
17
|
) -> None:
|
|
20
18
|
super().__init__()
|
|
@@ -45,15 +43,13 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
45
43
|
def fit(
|
|
46
44
|
self,
|
|
47
45
|
*background: TimeSeries2d,
|
|
48
|
-
fftlength:
|
|
49
|
-
overlap:
|
|
46
|
+
fftlength: float | None = None,
|
|
47
|
+
overlap: float | None = None,
|
|
50
48
|
):
|
|
51
49
|
if len(background) != self.num_channels:
|
|
52
50
|
raise ValueError(
|
|
53
|
-
"Expected to fit whitening transform on {}
|
|
54
|
-
"timeseries, but was passed {}"
|
|
55
|
-
self.num_channels, len(background)
|
|
56
|
-
)
|
|
51
|
+
f"Expected to fit whitening transform on {self.num_channels} "
|
|
52
|
+
f"background timeseries, but was passed {len(background)}"
|
|
57
53
|
)
|
|
58
54
|
|
|
59
55
|
num_freqs = self.background.size(1)
|
|
@@ -69,7 +65,7 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
69
65
|
def forward(
|
|
70
66
|
self,
|
|
71
67
|
responses: WaveformTensor,
|
|
72
|
-
target_snrs:
|
|
68
|
+
target_snrs: BatchTensor | None = None,
|
|
73
69
|
):
|
|
74
70
|
snrs = compute_network_snr(
|
|
75
71
|
responses,
|
ml4gw/transforms/spectral.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
from jaxtyping import Float
|
|
5
3
|
from torch import Tensor
|
|
@@ -52,20 +50,17 @@ class SpectralDensity(torch.nn.Module):
|
|
|
52
50
|
self,
|
|
53
51
|
sample_rate: float,
|
|
54
52
|
fftlength: float,
|
|
55
|
-
overlap:
|
|
53
|
+
overlap: float | None = None,
|
|
56
54
|
average: str = "mean",
|
|
57
|
-
window:
|
|
58
|
-
Float[Tensor, " {int(fftlength*sample_rate)}"]
|
|
59
|
-
] = None,
|
|
55
|
+
window: Float[Tensor, " {int(fftlength*sample_rate)}"] | None = None,
|
|
60
56
|
fast: bool = False,
|
|
61
57
|
) -> None:
|
|
62
58
|
if overlap is None:
|
|
63
59
|
overlap = fftlength / 2
|
|
64
60
|
elif overlap >= fftlength:
|
|
65
61
|
raise ValueError(
|
|
66
|
-
"Can't have overlap {} longer than fftlength
|
|
67
|
-
|
|
68
|
-
)
|
|
62
|
+
f"Can't have overlap {overlap} longer than fftlength "
|
|
63
|
+
f"{fftlength}"
|
|
69
64
|
)
|
|
70
65
|
|
|
71
66
|
super().__init__()
|
|
@@ -80,9 +75,7 @@ class SpectralDensity(torch.nn.Module):
|
|
|
80
75
|
|
|
81
76
|
if window.size(0) != self.nperseg:
|
|
82
77
|
raise ValueError(
|
|
83
|
-
"Window must have length {} got {}"
|
|
84
|
-
self.nperseg, window.size(0)
|
|
85
|
-
)
|
|
78
|
+
f"Window must have length {self.nperseg} got {window.size(0)}"
|
|
86
79
|
)
|
|
87
80
|
self.register_buffer("window", window)
|
|
88
81
|
|
|
@@ -99,7 +92,7 @@ class SpectralDensity(torch.nn.Module):
|
|
|
99
92
|
self.fast = fast
|
|
100
93
|
|
|
101
94
|
def forward(
|
|
102
|
-
self, x: TimeSeries1to3d, y:
|
|
95
|
+
self, x: TimeSeries1to3d, y: TimeSeries1to3d | None = None
|
|
103
96
|
) -> FrequencySeries1to3d:
|
|
104
97
|
if self.fast:
|
|
105
98
|
return fast_spectral_density(
|
ml4gw/transforms/spectrogram.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import warnings
|
|
2
|
-
from typing import Dict, List
|
|
3
2
|
|
|
4
3
|
import torch
|
|
5
4
|
import torch.nn.functional as F
|
|
@@ -104,7 +103,7 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
104
103
|
self.register_buffer("freq_idxs", freq_idxs)
|
|
105
104
|
self.register_buffer("time_idxs", time_idxs)
|
|
106
105
|
|
|
107
|
-
def _check_and_format_kwargs(self, kwargs:
|
|
106
|
+
def _check_and_format_kwargs(self, kwargs: dict[str, list]) -> list:
|
|
108
107
|
lengths = sorted(len(v) for v in kwargs.values())
|
|
109
108
|
lengths = list(set(lengths))
|
|
110
109
|
|
|
@@ -127,7 +126,10 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
127
126
|
size = lengths[1]
|
|
128
127
|
kwargs = {k: v * int(size / len(v)) for k, v in kwargs.items()}
|
|
129
128
|
|
|
130
|
-
return [
|
|
129
|
+
return [
|
|
130
|
+
dict(zip(kwargs, col, strict=True))
|
|
131
|
+
for col in zip(*kwargs.values(), strict=True)
|
|
132
|
+
]
|
|
131
133
|
|
|
132
134
|
def forward(
|
|
133
135
|
self, X: TimeSeries3d
|
|
@@ -161,6 +163,7 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
161
163
|
self.right_pad,
|
|
162
164
|
self.top_pad,
|
|
163
165
|
self.bottom_pad,
|
|
166
|
+
strict=True,
|
|
164
167
|
):
|
|
165
168
|
padded_specs.append(F.pad(spec, (left, right, top, bottom)))
|
|
166
169
|
|
|
@@ -2,8 +2,6 @@
|
|
|
2
2
|
Adaptation of code from https://github.com/dottormale/Qtransform_torch/
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from typing import Optional, Tuple
|
|
6
|
-
|
|
7
5
|
import torch
|
|
8
6
|
from torch import Tensor
|
|
9
7
|
|
|
@@ -50,7 +48,7 @@ class SplineInterpolateBase(torch.nn.Module):
|
|
|
50
48
|
t: Tensor,
|
|
51
49
|
d: int,
|
|
52
50
|
m: int,
|
|
53
|
-
) ->
|
|
51
|
+
) -> tuple[Tensor, Tensor]:
|
|
54
52
|
"""
|
|
55
53
|
Compute the L and R values for B-spline basis functions.
|
|
56
54
|
L and R are respectively the first and second coefficient multiplying
|
|
@@ -208,7 +206,7 @@ class SplineInterpolate1D(SplineInterpolateBase):
|
|
|
208
206
|
x_in: Tensor,
|
|
209
207
|
kx: int = 3,
|
|
210
208
|
sx: float = 0.0,
|
|
211
|
-
x_out:
|
|
209
|
+
x_out: Tensor | None = None,
|
|
212
210
|
):
|
|
213
211
|
super().__init__()
|
|
214
212
|
|
|
@@ -284,7 +282,7 @@ class SplineInterpolate1D(SplineInterpolateBase):
|
|
|
284
282
|
def forward(
|
|
285
283
|
self,
|
|
286
284
|
Z: Tensor,
|
|
287
|
-
x_out:
|
|
285
|
+
x_out: Tensor | None = None,
|
|
288
286
|
) -> Tensor:
|
|
289
287
|
"""
|
|
290
288
|
Compute the interpolated data
|
|
@@ -377,8 +375,8 @@ class SplineInterpolate2D(SplineInterpolateBase):
|
|
|
377
375
|
ky: int = 3,
|
|
378
376
|
sx: float = 0.0,
|
|
379
377
|
sy: float = 0.0,
|
|
380
|
-
x_out:
|
|
381
|
-
y_out:
|
|
378
|
+
x_out: Tensor | None = None,
|
|
379
|
+
y_out: Tensor | None = None,
|
|
382
380
|
):
|
|
383
381
|
super().__init__()
|
|
384
382
|
|
|
@@ -483,8 +481,8 @@ class SplineInterpolate2D(SplineInterpolateBase):
|
|
|
483
481
|
def forward(
|
|
484
482
|
self,
|
|
485
483
|
Z: Tensor,
|
|
486
|
-
x_out:
|
|
487
|
-
y_out:
|
|
484
|
+
x_out: Tensor | None = None,
|
|
485
|
+
y_out: Tensor | None = None,
|
|
488
486
|
) -> Tensor:
|
|
489
487
|
"""
|
|
490
488
|
Compute the interpolated data
|
ml4gw/transforms/transform.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
|
|
5
3
|
from ..spectral import spectral_density
|
|
@@ -20,8 +18,8 @@ class FittableTransform(torch.nn.Module):
|
|
|
20
18
|
def _check_built(self):
|
|
21
19
|
if not self.built:
|
|
22
20
|
raise ValueError(
|
|
23
|
-
"Must fit parameters of {} transform
|
|
24
|
-
"before calling forward step"
|
|
21
|
+
f"Must fit parameters of {self.__class__.__name__} transform "
|
|
22
|
+
"to data before calling forward step"
|
|
25
23
|
)
|
|
26
24
|
|
|
27
25
|
def __call__(self, *args, **kwargs):
|
|
@@ -47,8 +45,8 @@ class FittableSpectralTransform(FittableTransform):
|
|
|
47
45
|
x: TimeSeries1to3d,
|
|
48
46
|
sample_rate: float,
|
|
49
47
|
num_freqs: int,
|
|
50
|
-
fftlength:
|
|
51
|
-
overlap:
|
|
48
|
+
fftlength: float | None = None,
|
|
49
|
+
overlap: float | None = None,
|
|
52
50
|
) -> FrequencySeries1to3d:
|
|
53
51
|
# if we specified an FFT length, convert
|
|
54
52
|
# the (assumed) time-domain data to the
|