ml4gw 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/augmentations.py +43 -0
- ml4gw/dataloading/__init__.py +2 -1
- ml4gw/dataloading/chunked_dataset.py +66 -212
- ml4gw/dataloading/hdf5_dataset.py +176 -0
- ml4gw/nn/__init__.py +0 -0
- ml4gw/nn/autoencoder/__init__.py +3 -0
- ml4gw/nn/autoencoder/base.py +89 -0
- ml4gw/nn/autoencoder/convolutional.py +156 -0
- ml4gw/nn/autoencoder/skip_connection.py +46 -0
- ml4gw/nn/autoencoder/utils.py +14 -0
- ml4gw/nn/norm.py +97 -0
- ml4gw/nn/resnet/__init__.py +2 -0
- ml4gw/nn/resnet/resnet_1d.py +413 -0
- ml4gw/nn/resnet/resnet_2d.py +413 -0
- ml4gw/nn/streaming/__init__.py +2 -0
- ml4gw/nn/streaming/online_average.py +121 -0
- ml4gw/nn/streaming/snapshotter.py +121 -0
- ml4gw/transforms/__init__.py +2 -0
- ml4gw/transforms/pearson.py +87 -0
- ml4gw/transforms/spectrogram.py +162 -0
- ml4gw/transforms/whitening.py +1 -1
- ml4gw/waveforms/__init__.py +2 -0
- ml4gw/waveforms/phenom_d.py +1359 -0
- ml4gw/waveforms/phenom_d_data.py +3026 -0
- ml4gw/waveforms/taylorf2.py +306 -0
- {ml4gw-0.2.0.dist-info → ml4gw-0.4.0.dist-info}/METADATA +14 -6
- ml4gw-0.4.0.dist-info/RECORD +43 -0
- {ml4gw-0.2.0.dist-info → ml4gw-0.4.0.dist-info}/WHEEL +1 -1
- ml4gw-0.2.0.dist-info/RECORD +0 -23
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from ml4gw.utils.slicing import unfold_windows
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ShiftedPearsonCorrelation(torch.nn.Module):
    """
    Compute the [Pearson correlation]
    (https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
    for two equal-length timeseries over a pre-defined number of time
    shifts in each direction. Useful for when you want a
    correlation, but not over every possible shift (i.e.
    a convolution).

    The number of dimensions of the second timeseries `y`
    passed at call time should always be less than or equal
    to the number of dimensions of the first timeseries `x`,
    and each dimension should match the corresponding one of
    `x` in reverse order (i.e. if `x` has shape `(B, C, T)`
    then `y` should either have shape `(T,)`, `(C, T)`, or
    `(B, C, T)`).

    Note that no windowing to either timeseries is applied
    at call time. Users should do any requisite windowing
    beforehand. `x` is zero-padded by `max_shift` samples on
    each side of its last dimension before the shifted windows
    are extracted, so windows near the extreme shifts contain
    padding zeros.

    If either de-meaned timeseries is identically zero (i.e.
    the input is constant along time), the normalization is
    zero and the corresponding output entries are NaN.

    TODOs:
    - Should we perform windowing?
    - Should we support stride > 1?

    Args:
        max_shift:
            The maximum number of 1-step time shifts in
            each direction over which to compute the
            Pearson coefficient. Output shape will then
            be `(2 * max_shift + 1, B, C)`.
    """

    def __init__(self, max_shift: int) -> None:
        super().__init__()
        self.max_shift = max_shift

    def _shape_checks(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """
        Validate that `x` and `y` have compatible shapes.

        Args:
            x: First timeseries, with at most 3 dimensions.
            y: Second timeseries, with no more dimensions than `x`
                and matching `x` along each of its trailing dimensions.

        Raises:
            ValueError:
                If `x` has more than 3 dimensions, if `y` has more
                dimensions than `x`, or if any trailing dimension
                of `y` differs in size from that of `x`.
        """
        if x.ndim > 3:
            raise ValueError(
                "Tensor x can only have up to 3 dimensions "
                f"to compute ShiftedPearsonCorrelation. Found {x.ndim}."
            )
        elif y.ndim > x.ndim:
            raise ValueError(
                "y may not have more dimensions than x for "
                "ShiftedPearsonCorrelation, but found shapes "
                f"{y.shape} and {x.shape}"
            )

        # compare dimensions from the last one backwards so that
        # a lower-rank y is matched against x's trailing dimensions
        for offset in range(1, y.ndim + 1):
            if y.size(-offset) != x.size(-offset):
                raise ValueError(
                    "x and y expected to have same size along "
                    f"last dimensions, but found shapes {x.shape} "
                    f"and {y.shape}"
                )

    # TODO: torchtyping annotate
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute the Pearson coefficient between `y` and each
        time-shifted window of `x`.

        Args:
            x: Timeseries whose shifted windows are correlated with `y`.
            y: Reference timeseries; see the class docstring for
                the shapes it may take relative to `x`.

        Returns:
            Tensor of correlation coefficients with the shift index
            as its leading dimension.

        Raises:
            ValueError: If the shapes of `x` and `y` are incompatible.
        """
        self._shape_checks(x, y)
        dim = x.size(-1)

        # pad x along time dimension so that it has shape
        # batch x channels x (time + 2 * max_shift)
        pad = (self.max_shift, self.max_shift)
        x = torch.nn.functional.pad(x, pad)

        # num_windows x batch x channels x time
        x = unfold_windows(x, dim, 1)

        # now compute the correlation between each window
        # of x and the single window of y. Start by de-meaning
        x = x - x.mean(-1, keepdim=True)
        y = y - y.mean(-1, keepdim=True)

        # apply formula and sum along time dimension to give final shape
        # num_windows x batch x channels
        corr = (x * y).sum(dim=-1)
        norm = (x**2).sum(dim=-1) * (y**2).sum(dim=-1)

        return corr / norm**0.5
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from typing import Dict, List
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from torchaudio.transforms import Spectrogram
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MultiResolutionSpectrogram(torch.nn.Module):
    """
    Create a batch of multi-resolution spectrograms
    from a batch of timeseries. Input is expected to
    have the shape `(B, C, T)`, where `B` is the batch
    size, `C` is the number of channels, and `T`
    is the number of time samples.

    For each timeseries, calculate multiple normalized
    spectrograms based on the `Spectrogram` `kwargs` given.
    Combine the spectrograms by taking the maximum value
    from the nearest time-frequency bin.

    If the largest number of time bins among the spectrograms
    is `N` and the largest number of frequency bins is `M`,
    the output will have dimensions `(B, C, M, N)`

    Args:
        kernel_length:
            The length in seconds of the time dimension
            of the tensor that will be turned into a
            spectrogram
        sample_rate:
            The sample rate of the timeseries in Hz
        kwargs:
            Arguments passed in kwargs will be used to create
            `torchaudio.transforms.Spectrogram`s. Each
            argument should be a list of values. Any list
            of length greater than 1 should be the same
            length

    Raises:
        ValueError:
            If any entry of `normalized` is falsy, or if the
            keyword argument lists have inconsistent lengths.
    """

    def __init__(
        self, kernel_length: float, sample_rate: float, **kwargs
    ) -> None:
        super().__init__()
        # Store the expected number of samples as an integer: this is
        # the same (truncated) length used for the dummy input below,
        # so the check in `forward` agrees with the shapes that the
        # padding and index buffers were computed for.
        self.kernel_size = int(kernel_length * sample_rate)

        # This method of combination makes sense only when
        # the spectrograms are normalized, so enforce this
        if "normalized" in kwargs:
            if not all(kwargs["normalized"]):
                raise ValueError(
                    "Received a value of False for 'normalized'. "
                    "This method of combination is sensible only for "
                    "normalized spectrograms."
                )
        else:
            kwargs["normalized"] = [True]
        self.kwargs = self._check_and_format_kwargs(kwargs)

        self.transforms = torch.nn.ModuleList(
            [Spectrogram(**k) for k in self.kwargs]
        )

        # run each transform once on a dummy timeseries to discover
        # the (num_freqs, num_times) shape of its output
        dummy_input = torch.ones(self.kernel_size)
        self.shapes = torch.tensor(
            [t(dummy_input).shape for t in self.transforms]
        )

        self.num_freqs = max([shape[0] for shape in self.shapes])
        self.num_times = max([shape[1] for shape in self.shapes])

        # every spectrogram is padded on the bottom/right only, up to
        # the largest frequency/time extent among the transforms
        left_pad = torch.zeros(len(self.transforms), dtype=torch.int)
        top_pad = torch.zeros(len(self.transforms), dtype=torch.int)
        bottom_pad = torch.tensor(
            [int(self.num_freqs - shape[0]) for shape in self.shapes]
        )
        right_pad = torch.tensor(
            [int(self.num_times - shape[1]) for shape in self.shapes]
        )
        self.register_buffer("left_pad", left_pad)
        self.register_buffer("top_pad", top_pad)
        self.register_buffer("bottom_pad", bottom_pad)
        self.register_buffer("right_pad", right_pad)

        # for each output frequency bin i and each transform, the
        # index of the transform's own bin that bin i maps onto
        freq_idxs = torch.tensor(
            [
                [int(i * shape[0] / self.num_freqs) for shape in self.shapes]
                for i in range(self.num_freqs)
            ]
        )
        # broadcast to (num_freqs, num_times, num_transforms)
        freq_idxs = freq_idxs.repeat(self.num_times, 1, 1).transpose(0, 1)

        # same mapping along the time axis
        time_idxs = torch.tensor(
            [
                [int(i * shape[1] / self.num_times) for shape in self.shapes]
                for i in range(self.num_times)
            ]
        )
        # broadcast to (num_freqs, num_times, num_transforms)
        time_idxs = time_idxs.repeat(self.num_freqs, 1, 1)

        self.register_buffer("freq_idxs", freq_idxs)
        self.register_buffer("time_idxs", time_idxs)

    def _check_and_format_kwargs(self, kwargs: Dict[str, List]) -> List:
        """
        Validate the per-argument value lists and convert them
        into one keyword-argument dictionary per spectrogram.

        Args:
            kwargs:
                Mapping from `Spectrogram` argument name to a list
                of values. Lists must either all share one length,
                or be of length one, in which case the single
                value is repeated for every spectrogram.

        Returns:
            A list of dictionaries, the i-th of which holds the i-th
            value of every argument.

        Raises:
            ValueError:
                If the lists have more than one distinct length
                other than one.
        """
        lengths = sorted({len(v) for v in kwargs.values()})

        if lengths[-1] > 3:
            warnings.warn(
                "Combining too many spectrograms can impede computation time. "
                "If performance is slower than desired, try reducing the "
                "number of spectrograms",
                RuntimeWarning,
            )

        if len(lengths) > 2 or (len(lengths) == 2 and lengths[0] != 1):
            raise ValueError(
                "Spectrogram keyword args should all have the same "
                f"length or be of length one. Got lengths {lengths}"
            )

        if len(lengths) == 2:
            # repeat the length-one lists so that every list has
            # the same length as the longest one
            size = lengths[1]
            kwargs = {k: v * (size // len(v)) for k, v in kwargs.items()}

        return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Calculate spectrograms of the input tensor and
        combine them into a single spectrogram

        Args:
            X:
                Batch of multichannel timeseries which will
                be used to calculate the multi-resolution
                spectrogram. Should have the shape
                `(B, C, T)`, where `B` is the batch size,
                `C` is the number of channels,
                and `T` is the number of time samples.

        Returns:
            The element-wise maximum over the remapped
            spectrograms of every resolution.

        Raises:
            ValueError:
                If the last dimension of `X` does not have
                `kernel_size` samples.
        """
        if X.shape[-1] != self.kernel_size:
            raise ValueError(
                "Expected time dimension to be "
                f"{self.kernel_size} samples long, got input with "
                f"{X.shape[-1]} samples"
            )

        spectrograms = [t(X) for t in self.transforms]

        # zero-pad every spectrogram to a common shape so
        # that they can be stacked into a single tensor
        padded_specs = []
        for spec, left, right, top, bottom in zip(
            spectrograms,
            self.left_pad,
            self.right_pad,
            self.top_pad,
            self.bottom_pad,
        ):
            padded_specs.append(F.pad(spec, (left, right, top, bottom)))

        padded_specs = torch.stack(padded_specs)

        # Indexing the last two dimensions with the index buffers
        # appends a trailing "transform" dimension. Taking the diagonal
        # over the leading (stacked) and trailing dimensions pairs each
        # spectrogram with its own index mapping.
        remapped_specs = padded_specs[..., self.freq_idxs, self.time_idxs]
        remapped_specs = torch.diagonal(remapped_specs, dim1=0, dim2=-1)

        return torch.max(remapped_specs, dim=-1)[0]
|
ml4gw/transforms/whitening.py
CHANGED
|
@@ -123,7 +123,7 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
123
123
|
num_channels: float,
|
|
124
124
|
kernel_length: float,
|
|
125
125
|
sample_rate: float,
|
|
126
|
-
dtype: torch.dtype = torch.
|
|
126
|
+
dtype: torch.dtype = torch.float64,
|
|
127
127
|
) -> None:
|
|
128
128
|
super().__init__()
|
|
129
129
|
self.num_channels = num_channels
|
ml4gw/waveforms/__init__.py
CHANGED