ml4gw 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4gw/augmentations.py +5 -0
- ml4gw/dataloading/__init__.py +5 -0
- ml4gw/dataloading/chunked_dataset.py +2 -4
- ml4gw/dataloading/hdf5_dataset.py +12 -10
- ml4gw/dataloading/in_memory_dataset.py +12 -12
- ml4gw/distributions.py +3 -3
- ml4gw/gw.py +18 -21
- ml4gw/nn/__init__.py +6 -0
- ml4gw/nn/autoencoder/base.py +5 -9
- ml4gw/nn/autoencoder/convolutional.py +7 -10
- ml4gw/nn/autoencoder/skip_connection.py +3 -5
- ml4gw/nn/norm.py +4 -4
- ml4gw/nn/resnet/resnet_1d.py +12 -13
- ml4gw/nn/resnet/resnet_2d.py +13 -14
- ml4gw/nn/streaming/online_average.py +3 -5
- ml4gw/nn/streaming/snapshotter.py +10 -14
- ml4gw/spectral.py +20 -23
- ml4gw/transforms/__init__.py +7 -1
- ml4gw/transforms/decimator.py +183 -0
- ml4gw/transforms/iirfilter.py +3 -5
- ml4gw/transforms/pearson.py +3 -4
- ml4gw/transforms/qtransform.py +20 -26
- ml4gw/transforms/scaler.py +3 -5
- ml4gw/transforms/snr_rescaler.py +7 -11
- ml4gw/transforms/spectral.py +6 -13
- ml4gw/transforms/spectrogram.py +6 -3
- ml4gw/transforms/spline_interpolation.py +312 -143
- ml4gw/transforms/transform.py +4 -6
- ml4gw/transforms/waveforms.py +8 -15
- ml4gw/transforms/whitening.py +11 -16
- ml4gw/types.py +8 -5
- ml4gw/utils/interferometer.py +20 -3
- ml4gw/utils/slicing.py +26 -30
- ml4gw/waveforms/__init__.py +6 -0
- ml4gw/waveforms/cbc/phenom_p.py +7 -9
- ml4gw/waveforms/conversion.py +2 -4
- ml4gw/waveforms/generator.py +3 -3
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/METADATA +33 -12
- ml4gw-0.7.8.dist-info/RECORD +57 -0
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/WHEEL +2 -1
- ml4gw-0.7.8.dist-info/top_level.txt +1 -0
- ml4gw-0.7.6.dist-info/RECORD +0 -55
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from
|
|
1
|
+
from collections.abc import Sequence
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
from jaxtyping import Float
|
|
@@ -58,15 +58,13 @@ class Snapshotter(torch.nn.Module):
|
|
|
58
58
|
snapshot_size: int,
|
|
59
59
|
stride_size: int,
|
|
60
60
|
batch_size: int,
|
|
61
|
-
channels_per_snapshot:
|
|
61
|
+
channels_per_snapshot: Sequence[int] | None = None,
|
|
62
62
|
) -> None:
|
|
63
63
|
super().__init__()
|
|
64
64
|
if stride_size >= snapshot_size:
|
|
65
65
|
raise ValueError(
|
|
66
|
-
"Snapshotter can't accommodate stride {} "
|
|
67
|
-
"which is greater than snapshot size {}"
|
|
68
|
-
stride_size, snapshot_size
|
|
69
|
-
)
|
|
66
|
+
f"Snapshotter can't accommodate stride {stride_size} "
|
|
67
|
+
f"which is greater than snapshot size {snapshot_size}"
|
|
70
68
|
)
|
|
71
69
|
|
|
72
70
|
self.snapshot_size = snapshot_size
|
|
@@ -77,9 +75,8 @@ class Snapshotter(torch.nn.Module):
|
|
|
77
75
|
if channels_per_snapshot is not None:
|
|
78
76
|
if sum(channels_per_snapshot) != num_channels:
|
|
79
77
|
raise ValueError(
|
|
80
|
-
"Can't break {} channels into
|
|
81
|
-
|
|
82
|
-
)
|
|
78
|
+
f"Can't break {num_channels} channels into "
|
|
79
|
+
f"{channels_per_snapshot}"
|
|
83
80
|
)
|
|
84
81
|
self.channels_per_snapshot = channels_per_snapshot
|
|
85
82
|
self.num_channels = num_channels
|
|
@@ -90,8 +87,8 @@ class Snapshotter(torch.nn.Module):
|
|
|
90
87
|
def forward(
|
|
91
88
|
self,
|
|
92
89
|
update: Float[Tensor, "channel time1"],
|
|
93
|
-
snapshot:
|
|
94
|
-
) ->
|
|
90
|
+
snapshot: Float[Tensor, "channel time2"] | None = None,
|
|
91
|
+
) -> tuple[Tensor, ...]:
|
|
95
92
|
if snapshot is None:
|
|
96
93
|
snapshot = self.get_initial_state()
|
|
97
94
|
|
|
@@ -108,9 +105,8 @@ class Snapshotter(torch.nn.Module):
|
|
|
108
105
|
if self.channels_per_snapshot is not None:
|
|
109
106
|
if snapshots.size(1) != self.num_channels:
|
|
110
107
|
raise ValueError(
|
|
111
|
-
"Expected {} channels, found
|
|
112
|
-
|
|
113
|
-
)
|
|
108
|
+
f"Expected {self.num_channels} channels, found "
|
|
109
|
+
f"{snapshots.size(1)}"
|
|
114
110
|
)
|
|
115
111
|
snapshots = torch.split(
|
|
116
112
|
snapshots, self.channels_per_snapshot, dim=1
|
ml4gw/spectral.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
|
+
This module provides functions for calculation of spectral densities
|
|
3
|
+
and for whitening.
|
|
4
|
+
|
|
2
5
|
Several implementation details are derived from the scipy csd and welch
|
|
3
6
|
implementations. For more info, see
|
|
4
7
|
|
|
@@ -9,8 +12,6 @@ and
|
|
|
9
12
|
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.csd.html
|
|
10
13
|
"""
|
|
11
14
|
|
|
12
|
-
from typing import Optional, Union
|
|
13
|
-
|
|
14
15
|
import torch
|
|
15
16
|
from jaxtyping import Float
|
|
16
17
|
from torch import Tensor
|
|
@@ -36,13 +37,11 @@ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
|
|
|
36
37
|
return torch.quantile(x, q=0.5, axis=axis) / bias
|
|
37
38
|
|
|
38
39
|
|
|
39
|
-
def _validate_shapes(
|
|
40
|
-
x: Tensor, nperseg: int, y: Optional[Tensor] = None
|
|
41
|
-
) -> None:
|
|
40
|
+
def _validate_shapes(x: Tensor, nperseg: int, y: Tensor | None = None) -> None:
|
|
42
41
|
if x.shape[-1] < nperseg:
|
|
43
42
|
raise ValueError(
|
|
44
|
-
"Number of samples {} in input x is insufficient "
|
|
45
|
-
"for number of fft samples {}"
|
|
43
|
+
f"Number of samples {x.shape[-1]} in input x is insufficient "
|
|
44
|
+
f"for number of fft samples {nperseg}"
|
|
46
45
|
)
|
|
47
46
|
elif x.ndim > 3:
|
|
48
47
|
raise ValueError(
|
|
@@ -59,30 +58,30 @@ def _validate_shapes(
|
|
|
59
58
|
if x.shape[-1] != y.shape[-1]:
|
|
60
59
|
raise ValueError(
|
|
61
60
|
"Time dimensions of x and y tensors must "
|
|
62
|
-
"be the same, found {
|
|
61
|
+
f"be the same, found {x.shape[-1]} and {y.shape[-1]}"
|
|
63
62
|
)
|
|
64
63
|
elif x.ndim == 1 and not y.ndim == 1:
|
|
65
64
|
raise ValueError(
|
|
66
65
|
"Can't compute cross spectral density of "
|
|
67
|
-
"1D tensor x with {}D tensor y"
|
|
66
|
+
f"1D tensor x with {y.ndim}D tensor y"
|
|
68
67
|
)
|
|
69
68
|
elif x.ndim > 1 and y.ndim == x.ndim:
|
|
70
69
|
if not y.shape == x.shape:
|
|
71
70
|
raise ValueError(
|
|
72
71
|
"If x and y tensors have the same number "
|
|
73
72
|
"of dimensions, shapes must fully match. "
|
|
74
|
-
"Found shapes {} and {
|
|
73
|
+
f"Found shapes {x.shape} and {y.shape}"
|
|
75
74
|
)
|
|
76
75
|
elif x.ndim > 1 and y.ndim != (x.ndim - 1):
|
|
77
76
|
raise ValueError(
|
|
78
77
|
"Can't compute cross spectral density of "
|
|
79
|
-
"tensors with shapes {} and {
|
|
78
|
+
f"tensors with shapes {x.shape} and {y.shape}"
|
|
80
79
|
)
|
|
81
80
|
elif x.ndim > 2 and y.shape[0] != x.shape[0]:
|
|
82
81
|
raise ValueError(
|
|
83
82
|
"If x is a 3D tensor and y is a 2D tensor, "
|
|
84
83
|
"0th batch dimensions must match, but found "
|
|
85
|
-
"values {
|
|
84
|
+
f"values {x.shape[0]} and {y.shape[0]}"
|
|
86
85
|
)
|
|
87
86
|
|
|
88
87
|
|
|
@@ -93,7 +92,7 @@ def fast_spectral_density(
|
|
|
93
92
|
window: Float[Tensor, " {nperseg//2+1}"],
|
|
94
93
|
scale: float,
|
|
95
94
|
average: str = "median",
|
|
96
|
-
y:
|
|
95
|
+
y: TimeSeries1to3d | None = None,
|
|
97
96
|
) -> FrequencySeries1to3d:
|
|
98
97
|
"""
|
|
99
98
|
Compute the power spectral density of a multichannel
|
|
@@ -340,10 +339,10 @@ def spectral_density(
|
|
|
340
339
|
|
|
341
340
|
def truncate_inverse_power_spectrum(
|
|
342
341
|
psd: PSDTensor,
|
|
343
|
-
fduration:
|
|
342
|
+
fduration: Float[Tensor, " time"] | float,
|
|
344
343
|
sample_rate: float,
|
|
345
|
-
highpass:
|
|
346
|
-
lowpass:
|
|
344
|
+
highpass: float | None = None,
|
|
345
|
+
lowpass: float | None = None,
|
|
347
346
|
) -> PSDTensor:
|
|
348
347
|
"""
|
|
349
348
|
Truncate the length of the time domain response
|
|
@@ -460,10 +459,10 @@ def normalize_by_psd(
|
|
|
460
459
|
def whiten(
|
|
461
460
|
X: WaveformTensor,
|
|
462
461
|
psd: PSDTensor,
|
|
463
|
-
fduration:
|
|
462
|
+
fduration: Float[Tensor, " time"] | float,
|
|
464
463
|
sample_rate: float,
|
|
465
|
-
highpass:
|
|
466
|
-
lowpass:
|
|
464
|
+
highpass: float | None = None,
|
|
465
|
+
lowpass: float | None = None,
|
|
467
466
|
) -> WaveformTensor:
|
|
468
467
|
"""
|
|
469
468
|
Whiten a batch of timeseries using the specified
|
|
@@ -522,10 +521,8 @@ def whiten(
|
|
|
522
521
|
N = X.size(-1)
|
|
523
522
|
if N <= (2 * pad):
|
|
524
523
|
raise ValueError(
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
"padded samples {}"
|
|
528
|
-
).format(N, 2 * pad)
|
|
524
|
+
f"Not enough timeseries samples {N} for number of "
|
|
525
|
+
f"padded samples {2 * pad}"
|
|
529
526
|
)
|
|
530
527
|
|
|
531
528
|
# normalize the number of expected dimensions in the PSD
|
ml4gw/transforms/__init__.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains a variety of data transformation classes,
|
|
3
|
+
including objects to calculate spectral densities, whiten data,
|
|
4
|
+
and compute Q-transforms.
|
|
5
|
+
"""
|
|
6
|
+
|
|
1
7
|
from .iirfilter import IIRFilter
|
|
2
8
|
from .pearson import ShiftedPearsonCorrelation
|
|
3
9
|
from .qtransform import QScan, SingleQTransform
|
|
@@ -5,6 +11,6 @@ from .scaler import ChannelWiseScaler
|
|
|
5
11
|
from .snr_rescaler import SnrRescaler
|
|
6
12
|
from .spectral import SpectralDensity
|
|
7
13
|
from .spectrogram import MultiResolutionSpectrogram
|
|
8
|
-
from .spline_interpolation import
|
|
14
|
+
from .spline_interpolation import SplineInterpolate1D, SplineInterpolate2D
|
|
9
15
|
from .waveforms import WaveformProjector, WaveformSampler
|
|
10
16
|
from .whitening import FixedWhiten, Whiten
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Decimator(torch.nn.Module):
|
|
5
|
+
r"""
|
|
6
|
+
Downsample (decimate) a timeseries according to a user-defined schedule.
|
|
7
|
+
|
|
8
|
+
.. note::
|
|
9
|
+
|
|
10
|
+
This is a naive decimator that does not use any IIR/FIR filtering
|
|
11
|
+
and selects every M-th sample according to the schedule.
|
|
12
|
+
|
|
13
|
+
The schedule specifies which segments of the input to keep and at what
|
|
14
|
+
sampling rate. Each row of the schedule has the form:
|
|
15
|
+
|
|
16
|
+
`[start_time, end_time, target_sample_rate]`
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
sample_rate (int):
|
|
20
|
+
Sampling rate (Hz) of the input timeseries.
|
|
21
|
+
schedule (torch.Tensor):
|
|
22
|
+
Tensor of shape `(N, 3)` defining start time, end time,
|
|
23
|
+
and target sample rate for each segment.
|
|
24
|
+
|
|
25
|
+
Shape:
|
|
26
|
+
- Input: `(B, C, T)` where
|
|
27
|
+
- B = batch size
|
|
28
|
+
- C = channels
|
|
29
|
+
- T = number of timesteps
|
|
30
|
+
(must equal schedule duration × sample_rate)
|
|
31
|
+
- Output:
|
|
32
|
+
- If ``split=False`` → `(B, C, T')` where `T'` is total
|
|
33
|
+
number of decimated samples across all segments.
|
|
34
|
+
- If ``split=True`` → list of tensors, one per segment.
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
torch.Tensor or List[torch.Tensor]:
|
|
38
|
+
The decimated timeseries, or list of decimated segments if
|
|
39
|
+
``split=True``.
|
|
40
|
+
|
|
41
|
+
Example:
|
|
42
|
+
.. code-block:: python
|
|
43
|
+
|
|
44
|
+
>>> import torch
|
|
45
|
+
>>> from ml4gw.transforms.decimator import Decimator
|
|
46
|
+
|
|
47
|
+
>>> sample_rate = 2048
|
|
48
|
+
>>> X_duration = 60
|
|
49
|
+
|
|
50
|
+
>>> schedule = torch.tensor(
|
|
51
|
+
... [[0, 40, 256], [40, 58, 512], [58, 60, 2048]],
|
|
52
|
+
... dtype=torch.int,
|
|
53
|
+
... )
|
|
54
|
+
|
|
55
|
+
>>> decimator = Decimator(sample_rate=sample_rate,
|
|
56
|
+
... schedule=schedule)
|
|
57
|
+
>>> X = torch.randn(1, 1, sample_rate * X_duration)
|
|
58
|
+
>>> X_dec = decimator(X)
|
|
59
|
+
>>> X_seg = decimator(X, split=True)
|
|
60
|
+
|
|
61
|
+
>>> print("Original shape:", X.shape)
|
|
62
|
+
Original shape: torch.Size([1, 1, 122880])
|
|
63
|
+
>>> print("Decimated shape:", X_dec.shape)
|
|
64
|
+
Decimated shape: torch.Size([1, 1, 23552])
|
|
65
|
+
>>> for i, seg in enumerate(X_seg):
|
|
66
|
+
... print(f"Segment {i} shape:", seg.shape)
|
|
67
|
+
Segment 0 shape: torch.Size([1, 1, 10240])
|
|
68
|
+
Segment 1 shape: torch.Size([1, 1, 9216])
|
|
69
|
+
Segment 2 shape: torch.Size([1, 1, 4096])
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
def __init__(
|
|
73
|
+
self,
|
|
74
|
+
sample_rate: int = None,
|
|
75
|
+
schedule: torch.Tensor = None,
|
|
76
|
+
) -> None:
|
|
77
|
+
super().__init__()
|
|
78
|
+
self.sample_rate = sample_rate
|
|
79
|
+
self.schedule = schedule
|
|
80
|
+
|
|
81
|
+
self._validate_inputs()
|
|
82
|
+
idx = self.build_variable_indices()
|
|
83
|
+
self.register_buffer("idx", idx)
|
|
84
|
+
|
|
85
|
+
self.expected_len = int(
|
|
86
|
+
(self.schedule[:, 1][-1] - self.schedule[:, 0][0])
|
|
87
|
+
* self.sample_rate
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
def _validate_inputs(self) -> None:
|
|
91
|
+
r"""
|
|
92
|
+
Validate the schedule and sample_rate.
|
|
93
|
+
"""
|
|
94
|
+
if self.schedule.ndim != 2 or self.schedule.shape[1] != 3:
|
|
95
|
+
raise ValueError(
|
|
96
|
+
f"Schedule must be of shape (N, 3), got {self.schedule.shape}"
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
if not torch.all(self.schedule[:, 1] > self.schedule[:, 0]):
|
|
100
|
+
raise ValueError(
|
|
101
|
+
"Each schedule segment must have end_time > start_time"
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
if torch.any(self.sample_rate % self.schedule[:, 2].long() != 0):
|
|
105
|
+
raise ValueError(
|
|
106
|
+
f"Sample rate {self.sample_rate} must be divisible by all "
|
|
107
|
+
f"target rates {self.schedule[:, 2].tolist()}"
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
def build_variable_indices(self) -> torch.Tensor:
|
|
111
|
+
r"""
|
|
112
|
+
Compute the time indices to keep based on the schedule.
|
|
113
|
+
|
|
114
|
+
Returns:
|
|
115
|
+
torch.Tensor:
|
|
116
|
+
1D tensor of indices used to decimate the input.
|
|
117
|
+
"""
|
|
118
|
+
idx = torch.tensor([], dtype=torch.long)
|
|
119
|
+
|
|
120
|
+
for s in self.schedule:
|
|
121
|
+
if idx.size(0) == 0:
|
|
122
|
+
start = int(s[0] * self.sample_rate)
|
|
123
|
+
else:
|
|
124
|
+
start = int(idx[-1]) + int(idx[-1] - idx[-2])
|
|
125
|
+
stop = int(start + (s[1] - s[0]) * self.sample_rate)
|
|
126
|
+
step = int(self.sample_rate // s[2])
|
|
127
|
+
new_idx = torch.arange(start, stop, step, dtype=torch.long)
|
|
128
|
+
idx = torch.cat((idx, new_idx))
|
|
129
|
+
return idx
|
|
130
|
+
|
|
131
|
+
def split_by_schedule(self, X: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
|
132
|
+
r"""
|
|
133
|
+
Split a decimated timeseries into segments according to the schedule.
|
|
134
|
+
|
|
135
|
+
Args:
|
|
136
|
+
X (torch.Tensor):
|
|
137
|
+
Decimated input of shape `(B, C, T')`.
|
|
138
|
+
|
|
139
|
+
Returns:
|
|
140
|
+
tuple of torch.Tensor:
|
|
141
|
+
Each segment has shape :math:`(B, C, T_i)`
|
|
142
|
+
where :math:`T_i` is the length implied by
|
|
143
|
+
the corresponding schedule row.
|
|
144
|
+
"""
|
|
145
|
+
split_sizes = (
|
|
146
|
+
((self.schedule[:, 1] - self.schedule[:, 0]) * self.schedule[:, 2])
|
|
147
|
+
.long()
|
|
148
|
+
.tolist()
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
return torch.split(X, split_sizes, dim=-1)
|
|
152
|
+
|
|
153
|
+
def forward(
|
|
154
|
+
self,
|
|
155
|
+
X: torch.Tensor,
|
|
156
|
+
split: bool = False,
|
|
157
|
+
) -> torch.Tensor | list[torch.Tensor]:
|
|
158
|
+
r"""
|
|
159
|
+
Apply decimation to the input timeseries.
|
|
160
|
+
|
|
161
|
+
Args:
|
|
162
|
+
X (torch.Tensor):
|
|
163
|
+
Input tensor of shape `(B, C, T)`, where `T` must equal
|
|
164
|
+
schedule duration × sample_rate.
|
|
165
|
+
split (bool, optional):
|
|
166
|
+
If True, return a list of segments instead of a single
|
|
167
|
+
concatenated tensor. Default: False.
|
|
168
|
+
|
|
169
|
+
Returns:
|
|
170
|
+
torch.Tensor or List[torch.Tensor]:
|
|
171
|
+
Decimated timeseries, or list of decimated segments.
|
|
172
|
+
"""
|
|
173
|
+
if X.shape[-1] != self.expected_len:
|
|
174
|
+
raise ValueError(
|
|
175
|
+
f"X length {X.shape[-1]} does not match "
|
|
176
|
+
f"expected schedule duration {self.expected_len}"
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
X_dec = X.index_select(dim=-1, index=self.idx)
|
|
180
|
+
|
|
181
|
+
if split:
|
|
182
|
+
X_dec = self.split_by_schedule(X_dec)
|
|
183
|
+
return X_dec
|
ml4gw/transforms/iirfilter.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Union
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
from scipy.signal import iirfilter
|
|
5
3
|
from torchaudio.functional import filtfilt
|
|
@@ -55,9 +53,9 @@ class IIRFilter(torch.nn.Module):
|
|
|
55
53
|
def __init__(
|
|
56
54
|
self,
|
|
57
55
|
N: int,
|
|
58
|
-
Wn:
|
|
59
|
-
rs:
|
|
60
|
-
rp:
|
|
56
|
+
Wn: float | torch.Tensor,
|
|
57
|
+
rs: None | float | torch.Tensor = None,
|
|
58
|
+
rp: None | float | torch.Tensor = None,
|
|
61
59
|
btype="band",
|
|
62
60
|
analog=False,
|
|
63
61
|
ftype="butter",
|
ml4gw/transforms/pearson.py
CHANGED
|
@@ -52,15 +52,14 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
|
52
52
|
raise ValueError(
|
|
53
53
|
"y may not have more dimensions that x for "
|
|
54
54
|
"ShiftedPearsonCorrelation, but found shapes "
|
|
55
|
-
"{} and {
|
|
55
|
+
f"{y.shape} and {x.shape}"
|
|
56
56
|
)
|
|
57
57
|
for dim in range(y.ndim):
|
|
58
58
|
if y.size(-dim - 1) != x.size(-dim - 1):
|
|
59
59
|
raise ValueError(
|
|
60
60
|
"x and y expected to have same size along "
|
|
61
|
-
"last dimensions, but found shapes {} and
|
|
62
|
-
|
|
63
|
-
)
|
|
61
|
+
f"last dimensions, but found shapes {x.shape} and "
|
|
62
|
+
f"{y.shape}"
|
|
64
63
|
)
|
|
65
64
|
|
|
66
65
|
def forward(
|
ml4gw/transforms/qtransform.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import math
|
|
2
2
|
import warnings
|
|
3
|
-
from typing import List, Tuple
|
|
4
3
|
|
|
5
4
|
import torch
|
|
6
5
|
import torch.nn.functional as F
|
|
@@ -8,7 +7,7 @@ from jaxtyping import Float, Int
|
|
|
8
7
|
from torch import Tensor
|
|
9
8
|
|
|
10
9
|
from ..types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
|
|
11
|
-
from .spline_interpolation import
|
|
10
|
+
from .spline_interpolation import SplineInterpolate1D
|
|
12
11
|
|
|
13
12
|
"""
|
|
14
13
|
All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
|
|
@@ -146,7 +145,7 @@ class QTile(torch.nn.Module):
|
|
|
146
145
|
means = torch.mean(energy, dim=-1, keepdim=True)
|
|
147
146
|
energy /= means
|
|
148
147
|
else:
|
|
149
|
-
raise ValueError("Invalid normalisation
|
|
148
|
+
raise ValueError(f"Invalid normalisation {norm}")
|
|
150
149
|
energy = energy.type(torch.float32)
|
|
151
150
|
return energy
|
|
152
151
|
|
|
@@ -194,9 +193,9 @@ class SingleQTransform(torch.nn.Module):
|
|
|
194
193
|
self,
|
|
195
194
|
duration: float,
|
|
196
195
|
sample_rate: float,
|
|
197
|
-
spectrogram_shape:
|
|
196
|
+
spectrogram_shape: tuple[int, int],
|
|
198
197
|
q: float = 12,
|
|
199
|
-
frange:
|
|
198
|
+
frange: list[float] = None,
|
|
200
199
|
mismatch: float = 0.2,
|
|
201
200
|
interpolation_method: str = "bicubic",
|
|
202
201
|
) -> None:
|
|
@@ -260,18 +259,15 @@ class SingleQTransform(torch.nn.Module):
|
|
|
260
259
|
)
|
|
261
260
|
self.qtile_interpolators = torch.nn.ModuleList(
|
|
262
261
|
[
|
|
263
|
-
|
|
262
|
+
SplineInterpolate1D(
|
|
264
263
|
kx=3,
|
|
265
264
|
x_in=torch.arange(0, self.duration, self.duration / tiles),
|
|
266
|
-
y_in=torch.arange(len(idx)),
|
|
267
265
|
x_out=t_out,
|
|
268
|
-
y_out=torch.arange(len(idx)),
|
|
269
266
|
)
|
|
270
|
-
for tiles
|
|
267
|
+
for tiles in unique_ntiles
|
|
271
268
|
]
|
|
272
269
|
)
|
|
273
270
|
|
|
274
|
-
t_in = t_out
|
|
275
271
|
f_in = self.freqs
|
|
276
272
|
f_out = torch.logspace(
|
|
277
273
|
math.log10(self.frange[0]),
|
|
@@ -279,13 +275,10 @@ class SingleQTransform(torch.nn.Module):
|
|
|
279
275
|
self.spectrogram_shape[0],
|
|
280
276
|
)
|
|
281
277
|
|
|
282
|
-
self.interpolator =
|
|
278
|
+
self.interpolator = SplineInterpolate1D(
|
|
283
279
|
kx=3,
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
y_in=f_in,
|
|
287
|
-
x_out=t_out,
|
|
288
|
-
y_out=f_out,
|
|
280
|
+
x_in=f_in,
|
|
281
|
+
x_out=f_out,
|
|
289
282
|
)
|
|
290
283
|
|
|
291
284
|
def get_freqs(self) -> Float[Tensor, " nfreq"]:
|
|
@@ -309,7 +302,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
309
302
|
return torch.unique(freqs)
|
|
310
303
|
|
|
311
304
|
def get_max_energy(
|
|
312
|
-
self, fsearch_range:
|
|
305
|
+
self, fsearch_range: list[float] = None, dimension: str = "both"
|
|
313
306
|
):
|
|
314
307
|
"""
|
|
315
308
|
Gets the maximum energy value among the QTiles. The maximum can
|
|
@@ -379,14 +372,15 @@ class SingleQTransform(torch.nn.Module):
|
|
|
379
372
|
]
|
|
380
373
|
time_interped = torch.cat(
|
|
381
374
|
[
|
|
382
|
-
|
|
383
|
-
for qtile,
|
|
384
|
-
qtiles, self.qtile_interpolators
|
|
375
|
+
qtile_interpolator(qtile)
|
|
376
|
+
for qtile, qtile_interpolator in zip(
|
|
377
|
+
qtiles, self.qtile_interpolators, strict=True
|
|
385
378
|
)
|
|
386
379
|
],
|
|
387
380
|
dim=-2,
|
|
388
381
|
)
|
|
389
|
-
|
|
382
|
+
# Transpose because the final dimension gets interpolated
|
|
383
|
+
return self.interpolator(time_interped.mT).mT
|
|
390
384
|
num_f_bins, num_t_bins = self.spectrogram_shape
|
|
391
385
|
resampled = [
|
|
392
386
|
F.interpolate(
|
|
@@ -464,9 +458,9 @@ class QScan(torch.nn.Module):
|
|
|
464
458
|
self,
|
|
465
459
|
duration: float,
|
|
466
460
|
sample_rate: float,
|
|
467
|
-
spectrogram_shape:
|
|
468
|
-
qrange:
|
|
469
|
-
frange:
|
|
461
|
+
spectrogram_shape: tuple[int, int],
|
|
462
|
+
qrange: list[float] = None,
|
|
463
|
+
frange: list[float] = None,
|
|
470
464
|
interpolation_method="bicubic",
|
|
471
465
|
mismatch: float = 0.2,
|
|
472
466
|
) -> None:
|
|
@@ -505,7 +499,7 @@ class QScan(torch.nn.Module):
|
|
|
505
499
|
]
|
|
506
500
|
)
|
|
507
501
|
|
|
508
|
-
def get_qs(self) ->
|
|
502
|
+
def get_qs(self) -> list[float]:
|
|
509
503
|
"""
|
|
510
504
|
Determine the values of Q to try for the set of Q-transforms
|
|
511
505
|
"""
|
|
@@ -523,7 +517,7 @@ class QScan(torch.nn.Module):
|
|
|
523
517
|
def forward(
|
|
524
518
|
self,
|
|
525
519
|
X: TimeSeries1to3d,
|
|
526
|
-
fsearch_range:
|
|
520
|
+
fsearch_range: list[float] = None,
|
|
527
521
|
norm: str = "median",
|
|
528
522
|
):
|
|
529
523
|
"""
|
ml4gw/transforms/scaler.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
from jaxtyping import Float
|
|
5
3
|
from torch import Tensor
|
|
@@ -24,7 +22,7 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
24
22
|
to be 1D (single channel).
|
|
25
23
|
"""
|
|
26
24
|
|
|
27
|
-
def __init__(self, num_channels:
|
|
25
|
+
def __init__(self, num_channels: int | None = None) -> None:
|
|
28
26
|
super().__init__()
|
|
29
27
|
|
|
30
28
|
shape = (num_channels or 1,)
|
|
@@ -37,7 +35,7 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
37
35
|
self.register_buffer("std", std)
|
|
38
36
|
|
|
39
37
|
def fit(
|
|
40
|
-
self, X: Float[Tensor, "... time"], std_reg:
|
|
38
|
+
self, X: Float[Tensor, "... time"], std_reg: float | None = 0.0
|
|
41
39
|
) -> None:
|
|
42
40
|
"""Fit the scaling parameters to a timeseries
|
|
43
41
|
|
|
@@ -59,7 +57,7 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
59
57
|
else:
|
|
60
58
|
raise ValueError(
|
|
61
59
|
"Can't fit channel wise mean and standard deviation "
|
|
62
|
-
"from tensor of shape {
|
|
60
|
+
f"from tensor of shape {X.shape}"
|
|
63
61
|
)
|
|
64
62
|
std += std_reg * torch.ones_like(std)
|
|
65
63
|
super().build(mean=mean, std=std)
|
ml4gw/transforms/snr_rescaler.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
|
|
5
3
|
from ..gw import compute_network_snr
|
|
@@ -13,8 +11,8 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
13
11
|
num_channels: int,
|
|
14
12
|
sample_rate: float,
|
|
15
13
|
waveform_duration: float,
|
|
16
|
-
highpass:
|
|
17
|
-
lowpass:
|
|
14
|
+
highpass: float | None = None,
|
|
15
|
+
lowpass: float | None = None,
|
|
18
16
|
dtype: torch.dtype = torch.float32,
|
|
19
17
|
) -> None:
|
|
20
18
|
super().__init__()
|
|
@@ -45,15 +43,13 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
45
43
|
def fit(
|
|
46
44
|
self,
|
|
47
45
|
*background: TimeSeries2d,
|
|
48
|
-
fftlength:
|
|
49
|
-
overlap:
|
|
46
|
+
fftlength: float | None = None,
|
|
47
|
+
overlap: float | None = None,
|
|
50
48
|
):
|
|
51
49
|
if len(background) != self.num_channels:
|
|
52
50
|
raise ValueError(
|
|
53
|
-
"Expected to fit whitening transform on {}
|
|
54
|
-
"timeseries, but was passed {}"
|
|
55
|
-
self.num_channels, len(background)
|
|
56
|
-
)
|
|
51
|
+
f"Expected to fit whitening transform on {self.num_channels} "
|
|
52
|
+
f"background timeseries, but was passed {len(background)}"
|
|
57
53
|
)
|
|
58
54
|
|
|
59
55
|
num_freqs = self.background.size(1)
|
|
@@ -69,7 +65,7 @@ class SnrRescaler(FittableSpectralTransform):
|
|
|
69
65
|
def forward(
|
|
70
66
|
self,
|
|
71
67
|
responses: WaveformTensor,
|
|
72
|
-
target_snrs:
|
|
68
|
+
target_snrs: BatchTensor | None = None,
|
|
73
69
|
):
|
|
74
70
|
snrs = compute_network_snr(
|
|
75
71
|
responses,
|
ml4gw/transforms/spectral.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
from jaxtyping import Float
|
|
5
3
|
from torch import Tensor
|
|
@@ -52,20 +50,17 @@ class SpectralDensity(torch.nn.Module):
|
|
|
52
50
|
self,
|
|
53
51
|
sample_rate: float,
|
|
54
52
|
fftlength: float,
|
|
55
|
-
overlap:
|
|
53
|
+
overlap: float | None = None,
|
|
56
54
|
average: str = "mean",
|
|
57
|
-
window:
|
|
58
|
-
Float[Tensor, " {int(fftlength*sample_rate)}"]
|
|
59
|
-
] = None,
|
|
55
|
+
window: Float[Tensor, " {int(fftlength*sample_rate)}"] | None = None,
|
|
60
56
|
fast: bool = False,
|
|
61
57
|
) -> None:
|
|
62
58
|
if overlap is None:
|
|
63
59
|
overlap = fftlength / 2
|
|
64
60
|
elif overlap >= fftlength:
|
|
65
61
|
raise ValueError(
|
|
66
|
-
"Can't have overlap {} longer than fftlength
|
|
67
|
-
|
|
68
|
-
)
|
|
62
|
+
f"Can't have overlap {overlap} longer than fftlength "
|
|
63
|
+
f"{fftlength}"
|
|
69
64
|
)
|
|
70
65
|
|
|
71
66
|
super().__init__()
|
|
@@ -80,9 +75,7 @@ class SpectralDensity(torch.nn.Module):
|
|
|
80
75
|
|
|
81
76
|
if window.size(0) != self.nperseg:
|
|
82
77
|
raise ValueError(
|
|
83
|
-
"Window must have length {} got {}"
|
|
84
|
-
self.nperseg, window.size(0)
|
|
85
|
-
)
|
|
78
|
+
f"Window must have length {self.nperseg} got {window.size(0)}"
|
|
86
79
|
)
|
|
87
80
|
self.register_buffer("window", window)
|
|
88
81
|
|
|
@@ -99,7 +92,7 @@ class SpectralDensity(torch.nn.Module):
|
|
|
99
92
|
self.fast = fast
|
|
100
93
|
|
|
101
94
|
def forward(
|
|
102
|
-
self, x: TimeSeries1to3d, y:
|
|
95
|
+
self, x: TimeSeries1to3d, y: TimeSeries1to3d | None = None
|
|
103
96
|
) -> FrequencySeries1to3d:
|
|
104
97
|
if self.fast:
|
|
105
98
|
return fast_spectral_density(
|