ml4gw-0.5.0-py3-none-any.whl → ml4gw-0.5.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw has been flagged as potentially problematic; consult the registry's page for this release for details.
- ml4gw/augmentations.py +8 -2
- ml4gw/dataloading/chunked_dataset.py +4 -2
- ml4gw/dataloading/hdf5_dataset.py +1 -1
- ml4gw/dataloading/in_memory_dataset.py +8 -4
- ml4gw/distributions.py +5 -3
- ml4gw/gw.py +21 -27
- ml4gw/nn/autoencoder/base.py +11 -6
- ml4gw/nn/autoencoder/convolutional.py +7 -4
- ml4gw/nn/autoencoder/skip_connection.py +7 -6
- ml4gw/nn/autoencoder/utils.py +2 -1
- ml4gw/nn/norm.py +5 -1
- ml4gw/nn/streaming/online_average.py +7 -5
- ml4gw/nn/streaming/snapshotter.py +7 -5
- ml4gw/spectral.py +40 -36
- ml4gw/transforms/pearson.py +7 -3
- ml4gw/transforms/qtransform.py +20 -14
- ml4gw/transforms/scaler.py +6 -2
- ml4gw/transforms/snr_rescaler.py +6 -5
- ml4gw/transforms/spectral.py +9 -2
- ml4gw/transforms/spectrogram.py +7 -1
- ml4gw/transforms/transform.py +4 -3
- ml4gw/transforms/waveforms.py +10 -7
- ml4gw/transforms/whitening.py +12 -4
- ml4gw/types.py +25 -10
- ml4gw/utils/interferometer.py +1 -1
- ml4gw/utils/slicing.py +24 -16
- ml4gw/waveforms/generator.py +9 -5
- ml4gw/waveforms/phenom_d.py +20 -18
- ml4gw/waveforms/phenom_p.py +77 -60
- ml4gw/waveforms/ringdown.py +8 -9
- ml4gw/waveforms/sine_gaussian.py +6 -6
- ml4gw/waveforms/taylorf2.py +33 -27
- {ml4gw-0.5.0.dist-info → ml4gw-0.5.1.dist-info}/METADATA +4 -3
- ml4gw-0.5.1.dist-info/RECORD +47 -0
- ml4gw-0.5.0.dist-info/RECORD +0 -47
- {ml4gw-0.5.0.dist-info → ml4gw-0.5.1.dist-info}/WHEEL +0 -0
ml4gw/augmentations.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from jaxtyping import Float
|
|
3
|
+
from torch import Tensor
|
|
2
4
|
|
|
3
5
|
|
|
4
6
|
class SignalInverter(torch.nn.Module):
|
|
@@ -16,7 +18,9 @@ class SignalInverter(torch.nn.Module):
|
|
|
16
18
|
super().__init__()
|
|
17
19
|
self.prob = prob
|
|
18
20
|
|
|
19
|
-
def forward(
|
|
21
|
+
def forward(
|
|
22
|
+
self, X: Float[Tensor, "*batch time"]
|
|
23
|
+
) -> Float[Tensor, "*batch time"]:
|
|
20
24
|
mask = torch.rand(size=X.shape[:-1]) < self.prob
|
|
21
25
|
X[mask] *= -1
|
|
22
26
|
return X
|
|
@@ -37,7 +41,9 @@ class SignalReverser(torch.nn.Module):
|
|
|
37
41
|
super().__init__()
|
|
38
42
|
self.prob = prob
|
|
39
43
|
|
|
40
|
-
def forward(
|
|
44
|
+
def forward(
|
|
45
|
+
self, X: Float[Tensor, "*batch time"]
|
|
46
|
+
) -> Float[Tensor, "*batch time"]:
|
|
41
47
|
mask = torch.rand(size=X.shape[:-1]) < self.prob
|
|
42
48
|
X[mask] = X[mask].flip(-1)
|
|
43
49
|
return X
|
ml4gw/dataloading/chunked_dataset.py
CHANGED
|
@@ -2,6 +2,8 @@ from collections.abc import Iterable
|
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
|
+
from ml4gw.types import WaveformTensor
|
|
6
|
+
|
|
5
7
|
|
|
6
8
|
class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
7
9
|
"""
|
|
@@ -55,10 +57,10 @@ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
55
57
|
self.coincident = coincident
|
|
56
58
|
self.device = device
|
|
57
59
|
|
|
58
|
-
def __len__(self):
|
|
60
|
+
def __len__(self) -> int:
|
|
59
61
|
return len(self.chunk_it) * self.batches_per_chunk
|
|
60
62
|
|
|
61
|
-
def __iter__(self):
|
|
63
|
+
def __iter__(self) -> WaveformTensor:
|
|
62
64
|
it = iter(self.chunk_it)
|
|
63
65
|
chunk = next(it)
|
|
64
66
|
num_chunks, num_channels, chunk_size = chunk.shape
|
ml4gw/dataloading/hdf5_dataset.py
CHANGED
|
@@ -161,7 +161,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
161
161
|
x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
|
|
162
162
|
return torch.Tensor(x)
|
|
163
163
|
|
|
164
|
-
def __iter__(self) ->
|
|
164
|
+
def __iter__(self) -> WaveformTensor:
|
|
165
165
|
worker_info = torch.utils.data.get_worker_info()
|
|
166
166
|
if worker_info is None:
|
|
167
167
|
num_batches = self.batches_per_epoch
|
ml4gw/dataloading/in_memory_dataset.py
CHANGED
|
@@ -2,8 +2,9 @@ import itertools
|
|
|
2
2
|
from typing import Optional, Tuple, Union
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
|
+
from jaxtyping import Float
|
|
6
|
+
from torch import Tensor
|
|
5
7
|
|
|
6
|
-
from ml4gw import types
|
|
7
8
|
from ml4gw.utils.slicing import slice_kernels
|
|
8
9
|
|
|
9
10
|
|
|
@@ -76,9 +77,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
|
|
|
76
77
|
|
|
77
78
|
def __init__(
|
|
78
79
|
self,
|
|
79
|
-
X:
|
|
80
|
+
X: Float[Tensor, "channels time"],
|
|
80
81
|
kernel_size: int,
|
|
81
|
-
y: Optional[
|
|
82
|
+
y: Optional[Float[Tensor, " time"]] = None,
|
|
82
83
|
batch_size: int = 32,
|
|
83
84
|
stride: int = 1,
|
|
84
85
|
batches_per_epoch: Optional[int] = None,
|
|
@@ -207,7 +208,10 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
|
|
|
207
208
|
|
|
208
209
|
def __iter__(
|
|
209
210
|
self,
|
|
210
|
-
) -> Union[
|
|
211
|
+
) -> Union[
|
|
212
|
+
Float[Tensor, "batch channel time"],
|
|
213
|
+
Tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]],
|
|
214
|
+
]:
|
|
211
215
|
|
|
212
216
|
indices = self.init_indices()
|
|
213
217
|
for i in range(len(self)):
|
ml4gw/distributions.py
CHANGED
|
@@ -9,6 +9,8 @@ from typing import Optional
|
|
|
9
9
|
|
|
10
10
|
import torch
|
|
11
11
|
import torch.distributions as dist
|
|
12
|
+
from jaxtyping import Float
|
|
13
|
+
from torch import Tensor
|
|
12
14
|
|
|
13
15
|
|
|
14
16
|
class Cosine(dist.Distribution):
|
|
@@ -31,11 +33,11 @@ class Cosine(dist.Distribution):
|
|
|
31
33
|
self.high = torch.as_tensor(high)
|
|
32
34
|
self.norm = 1 / (torch.sin(self.high) - torch.sin(self.low))
|
|
33
35
|
|
|
34
|
-
def rsample(self, sample_shape: torch.Size = torch.Size()) ->
|
|
36
|
+
def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
|
|
35
37
|
u = torch.rand(sample_shape, device=self.low.device)
|
|
36
38
|
return torch.arcsin(u / self.norm + torch.sin(self.low))
|
|
37
39
|
|
|
38
|
-
def log_prob(self, value):
|
|
40
|
+
def log_prob(self, value: float) -> Float[Tensor, ""]:
|
|
39
41
|
value = torch.as_tensor(value)
|
|
40
42
|
inside_range = (value >= self.low) & (value <= self.high)
|
|
41
43
|
return value.cos().log() * inside_range
|
|
@@ -164,7 +166,7 @@ class DeltaFunction(dist.Distribution):
|
|
|
164
166
|
super().__init__(batch_shape, validate_args=validate_args)
|
|
165
167
|
self.peak = torch.as_tensor(peak)
|
|
166
168
|
|
|
167
|
-
def rsample(self, sample_shape: torch.Size = torch.Size()) ->
|
|
169
|
+
def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
|
|
168
170
|
return self.peak * torch.ones(
|
|
169
171
|
sample_shape, device=self.peak.device, dtype=torch.float32
|
|
170
172
|
)
|
ml4gw/gw.py
CHANGED
|
@@ -13,27 +13,21 @@ https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py
|
|
|
13
13
|
from typing import List, Tuple, Union
|
|
14
14
|
|
|
15
15
|
import torch
|
|
16
|
-
from
|
|
16
|
+
from jaxtyping import Float
|
|
17
|
+
from torch import Tensor
|
|
17
18
|
|
|
19
|
+
from ml4gw.constants import C
|
|
18
20
|
from ml4gw.types import (
|
|
21
|
+
BatchTensor,
|
|
19
22
|
NetworkDetectorTensors,
|
|
20
23
|
NetworkVertices,
|
|
21
24
|
PSDTensor,
|
|
22
|
-
ScalarTensor,
|
|
23
25
|
TensorGeometry,
|
|
24
26
|
VectorGeometry,
|
|
25
27
|
WaveformTensor,
|
|
26
28
|
)
|
|
27
29
|
from ml4gw.utils.interferometer import InterferometerGeometry
|
|
28
30
|
|
|
29
|
-
SPEED_OF_LIGHT = 299792458.0 # m/s
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
# define some tensor shapes we'll reuse a bit
|
|
33
|
-
# up front. Need to assign these variables so
|
|
34
|
-
# that static linters don't give us name errors
|
|
35
|
-
batch = num_ifos = polarizations = time = frequency = space = None # noqa
|
|
36
|
-
|
|
37
31
|
|
|
38
32
|
def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
|
|
39
33
|
"""
|
|
@@ -62,12 +56,12 @@ polarization_funcs = {
|
|
|
62
56
|
|
|
63
57
|
|
|
64
58
|
def compute_antenna_responses(
|
|
65
|
-
theta:
|
|
66
|
-
psi:
|
|
67
|
-
phi:
|
|
59
|
+
theta: BatchTensor,
|
|
60
|
+
psi: BatchTensor,
|
|
61
|
+
phi: BatchTensor,
|
|
68
62
|
detector_tensors: NetworkDetectorTensors,
|
|
69
63
|
modes: List[str],
|
|
70
|
-
) ->
|
|
64
|
+
) -> Float[Tensor, "batch polarizations num_ifos"]:
|
|
71
65
|
"""
|
|
72
66
|
Compute the antenna pattern factors of a batch of
|
|
73
67
|
waveforms as a function of the sky parameters of
|
|
@@ -147,8 +141,8 @@ def compute_antenna_responses(
|
|
|
147
141
|
|
|
148
142
|
def shift_responses(
|
|
149
143
|
responses: WaveformTensor,
|
|
150
|
-
theta:
|
|
151
|
-
phi:
|
|
144
|
+
theta: BatchTensor,
|
|
145
|
+
phi: BatchTensor,
|
|
152
146
|
vertices: NetworkVertices,
|
|
153
147
|
sample_rate: float,
|
|
154
148
|
) -> WaveformTensor:
|
|
@@ -166,7 +160,7 @@ def shift_responses(
|
|
|
166
160
|
# Divide by c in the second line so that we only
|
|
167
161
|
# need to multiply the array by a single float
|
|
168
162
|
dt = -(omega * vertices).sum(axis=-1)
|
|
169
|
-
dt *= sample_rate /
|
|
163
|
+
dt *= sample_rate / C
|
|
170
164
|
dt = torch.trunc(dt).type(torch.int64)
|
|
171
165
|
|
|
172
166
|
# rolling by gathering implementation based on
|
|
@@ -191,13 +185,13 @@ def shift_responses(
|
|
|
191
185
|
|
|
192
186
|
|
|
193
187
|
def compute_observed_strain(
|
|
194
|
-
dec:
|
|
195
|
-
psi:
|
|
196
|
-
phi:
|
|
188
|
+
dec: BatchTensor,
|
|
189
|
+
psi: BatchTensor,
|
|
190
|
+
phi: BatchTensor,
|
|
197
191
|
detector_tensors: NetworkDetectorTensors,
|
|
198
192
|
detector_vertices: NetworkVertices,
|
|
199
193
|
sample_rate: float,
|
|
200
|
-
**polarizations:
|
|
194
|
+
**polarizations: Float[Tensor, "batch time"],
|
|
201
195
|
) -> WaveformTensor:
|
|
202
196
|
"""
|
|
203
197
|
Compute the strain timeseries $h(t)$ observed by a network
|
|
@@ -289,8 +283,8 @@ def compute_ifo_snr(
|
|
|
289
283
|
responses: WaveformTensor,
|
|
290
284
|
psd: PSDTensor,
|
|
291
285
|
sample_rate: float,
|
|
292
|
-
highpass: Union[float,
|
|
293
|
-
) ->
|
|
286
|
+
highpass: Union[float, Float[Tensor, " frequency"], None] = None,
|
|
287
|
+
) -> Float[Tensor, "batch num_ifos"]:
|
|
294
288
|
r"""Compute the SNRs of a batch of interferometer responses
|
|
295
289
|
|
|
296
290
|
Compute the signal to noise ratio (SNR) of individual
|
|
@@ -390,8 +384,8 @@ def compute_network_snr(
|
|
|
390
384
|
responses: WaveformTensor,
|
|
391
385
|
psd: PSDTensor,
|
|
392
386
|
sample_rate: float,
|
|
393
|
-
highpass: Union[float,
|
|
394
|
-
) ->
|
|
387
|
+
highpass: Union[float, Float[Tensor, " frequency"], None] = None,
|
|
388
|
+
) -> BatchTensor:
|
|
395
389
|
r"""
|
|
396
390
|
Compute the total SNR from a gravitational waveform
|
|
397
391
|
from a network of interferometers. The total SNR for
|
|
@@ -437,10 +431,10 @@ def compute_network_snr(
|
|
|
437
431
|
|
|
438
432
|
def reweight_snrs(
|
|
439
433
|
responses: WaveformTensor,
|
|
440
|
-
target_snrs: Union[float,
|
|
434
|
+
target_snrs: Union[float, BatchTensor],
|
|
441
435
|
psd: PSDTensor,
|
|
442
436
|
sample_rate: float,
|
|
443
|
-
highpass: Union[float,
|
|
437
|
+
highpass: Union[float, Float[Tensor, " frequency"], None] = None,
|
|
444
438
|
) -> WaveformTensor:
|
|
445
439
|
"""Scale interferometer responses such that they have a desired SNR
|
|
446
440
|
|
ml4gw/nn/autoencoder/base.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
|
-
from typing import Optional
|
|
2
|
+
from typing import Optional, Tuple, Union
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
|
+
from torch import Tensor
|
|
5
6
|
|
|
6
7
|
from ml4gw.nn.autoencoder.skip_connection import SkipConnection
|
|
7
8
|
|
|
@@ -27,12 +28,16 @@ class Autoencoder(torch.nn.Module):
|
|
|
27
28
|
and how they operate.
|
|
28
29
|
"""
|
|
29
30
|
|
|
30
|
-
def __init__(
|
|
31
|
+
def __init__(
|
|
32
|
+
self, skip_connection: Optional[SkipConnection] = None
|
|
33
|
+
) -> None:
|
|
31
34
|
super().__init__()
|
|
32
35
|
self.skip_connection = skip_connection
|
|
33
36
|
self.blocks = torch.nn.ModuleList()
|
|
34
37
|
|
|
35
|
-
def encode(
|
|
38
|
+
def encode(
|
|
39
|
+
self, *X: Tensor, return_states: bool = False
|
|
40
|
+
) -> Union[Tensor, Tuple[Tensor, Sequence]]:
|
|
36
41
|
states = []
|
|
37
42
|
for block in self.blocks:
|
|
38
43
|
if isinstance(X, tuple):
|
|
@@ -48,7 +53,7 @@ class Autoencoder(torch.nn.Module):
|
|
|
48
53
|
return X, states[:-1]
|
|
49
54
|
return X
|
|
50
55
|
|
|
51
|
-
def decode(self, *X, states: Optional[Sequence[
|
|
56
|
+
def decode(self, *X, states: Optional[Sequence[Tensor]] = None) -> Tensor:
|
|
52
57
|
if self.skip_connection is not None and states is None:
|
|
53
58
|
raise ValueError(
|
|
54
59
|
"Must pass intermediate states when autoencoder "
|
|
@@ -76,7 +81,7 @@ class Autoencoder(torch.nn.Module):
|
|
|
76
81
|
X = self.skip_connection(X, state)
|
|
77
82
|
return X
|
|
78
83
|
|
|
79
|
-
def forward(self, *X):
|
|
84
|
+
def forward(self, *X: Tensor) -> Tensor:
|
|
80
85
|
return_states = self.skip_connection is not None
|
|
81
86
|
X = self.encode(*X, return_states=return_states)
|
|
82
87
|
if return_states:
|
|
@@ -84,6 +89,6 @@ class Autoencoder(torch.nn.Module):
|
|
|
84
89
|
else:
|
|
85
90
|
states = None
|
|
86
91
|
|
|
87
|
-
if isinstance(X,
|
|
92
|
+
if isinstance(X, Tensor):
|
|
88
93
|
X = (X,)
|
|
89
94
|
return self.decode(*X, states=states)
|
ml4gw/nn/autoencoder/convolutional.py
CHANGED
|
@@ -2,6 +2,7 @@ from collections.abc import Callable, Sequence
|
|
|
2
2
|
from typing import Optional
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
|
+
from torch import Tensor
|
|
5
6
|
|
|
6
7
|
from ml4gw.nn.autoencoder.base import Autoencoder
|
|
7
8
|
from ml4gw.nn.autoencoder.skip_connection import SkipConnection
|
|
@@ -64,12 +65,12 @@ class ConvBlock(Autoencoder):
|
|
|
64
65
|
self.encode_norm = norm(out_channels)
|
|
65
66
|
self.decode_norm = norm(decode_channels)
|
|
66
67
|
|
|
67
|
-
def encode(self, X):
|
|
68
|
+
def encode(self, X: Tensor) -> Tensor:
|
|
68
69
|
X = self.encode_layer(X)
|
|
69
70
|
X = self.encode_norm(X)
|
|
70
71
|
return self.activation(X)
|
|
71
72
|
|
|
72
|
-
def decode(self, X):
|
|
73
|
+
def decode(self, X: Tensor) -> Tensor:
|
|
73
74
|
X = self.decode_layer(X)
|
|
74
75
|
X = self.decode_norm(X)
|
|
75
76
|
return self.output_activation(X)
|
|
@@ -144,13 +145,15 @@ class ConvolutionalAutoencoder(Autoencoder):
|
|
|
144
145
|
self.blocks.append(block)
|
|
145
146
|
in_channels = channels * groups
|
|
146
147
|
|
|
147
|
-
def decode(
|
|
148
|
+
def decode(
|
|
149
|
+
self, *X, states=None, input_size: Optional[int] = None
|
|
150
|
+
) -> Tensor:
|
|
148
151
|
X = super().decode(*X, states=states)
|
|
149
152
|
if input_size is not None:
|
|
150
153
|
return match_size(X, input_size)
|
|
151
154
|
return X
|
|
152
155
|
|
|
153
|
-
def forward(self, X):
|
|
156
|
+
def forward(self, X: Tensor) -> Tensor:
|
|
154
157
|
input_size = X.size(-1)
|
|
155
158
|
X = super().forward(X)
|
|
156
159
|
return match_size(X, input_size)
|
ml4gw/nn/autoencoder/skip_connection.py
CHANGED
|
@@ -1,31 +1,32 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
from torch import Tensor
|
|
2
3
|
|
|
3
4
|
from ml4gw.nn.autoencoder.utils import match_size
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
class SkipConnection(torch.nn.Module):
|
|
7
|
-
def forward(self, X:
|
|
8
|
+
def forward(self, X: Tensor, state: Tensor) -> Tensor:
|
|
8
9
|
return match_size(X, state.size(-1))
|
|
9
10
|
|
|
10
|
-
def get_out_channels(self, in_channels):
|
|
11
|
+
def get_out_channels(self, in_channels: int) -> int:
|
|
11
12
|
return in_channels
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
class AddSkipConnect(SkipConnection):
|
|
15
|
-
def forward(self, X, state):
|
|
16
|
+
def forward(self, X: Tensor, state: Tensor) -> Tensor:
|
|
16
17
|
X = super().forward(X, state)
|
|
17
18
|
return X + state
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
class ConcatSkipConnect(SkipConnection):
|
|
21
|
-
def __init__(self, groups: int = 1):
|
|
22
|
+
def __init__(self, groups: int = 1) -> None:
|
|
22
23
|
super().__init__()
|
|
23
24
|
self.groups = groups
|
|
24
25
|
|
|
25
|
-
def get_out_channels(self, in_channels):
|
|
26
|
+
def get_out_channels(self, in_channels: int) -> int:
|
|
26
27
|
return 2 * in_channels
|
|
27
28
|
|
|
28
|
-
def forward(self, X, state):
|
|
29
|
+
def forward(self, X: Tensor, state: Tensor) -> Tensor:
|
|
29
30
|
X = super().forward(X, state)
|
|
30
31
|
if self.groups == 1:
|
|
31
32
|
return torch.cat([X, state], dim=1)
|
ml4gw/nn/autoencoder/utils.py
CHANGED
ml4gw/nn/norm.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from typing import Callable, Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
NormLayer = Callable[[int], torch.nn.Module]
|
|
6
8
|
|
|
@@ -31,7 +33,9 @@ class GroupNorm1D(torch.nn.Module):
|
|
|
31
33
|
self.weight = torch.nn.Parameter(torch.ones(shape))
|
|
32
34
|
self.bias = torch.nn.Parameter(torch.zeros(shape))
|
|
33
35
|
|
|
34
|
-
def forward(
|
|
36
|
+
def forward(
|
|
37
|
+
self, x: Float[Tensor, "batch channel length"]
|
|
38
|
+
) -> Float[Tensor, "batch channel length"]:
|
|
35
39
|
if len(x.shape) != 3:
|
|
36
40
|
raise ValueError(
|
|
37
41
|
"GroupNorm1D requires 3-dimensional input, "
|
ml4gw/nn/streaming/online_average.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from typing import Optional, Tuple
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
from ml4gw.utils.slicing import unfold_windows
|
|
6
8
|
|
|
7
|
-
Tensor = torch.Tensor
|
|
8
|
-
|
|
9
9
|
|
|
10
10
|
class OnlineAverager(torch.nn.Module):
|
|
11
11
|
"""
|
|
@@ -70,12 +70,14 @@ class OnlineAverager(torch.nn.Module):
|
|
|
70
70
|
weights = unfold_windows(weights, weight_size, update_size)
|
|
71
71
|
self.register_buffer("weights", weights)
|
|
72
72
|
|
|
73
|
-
def get_initial_state(self):
|
|
73
|
+
def get_initial_state(self) -> Float[Tensor, "channel time"]:
|
|
74
74
|
return torch.zeros((self.num_channels, self.state_size))
|
|
75
75
|
|
|
76
76
|
def forward(
|
|
77
|
-
self,
|
|
78
|
-
|
|
77
|
+
self,
|
|
78
|
+
update: Float[Tensor, "batch channel time1"],
|
|
79
|
+
state: Optional[Float[Tensor, "channel time2"]] = None,
|
|
80
|
+
) -> Tuple[Float[Tensor, "channel time3"], Float[Tensor, "channel time4"]]:
|
|
79
81
|
if state is None:
|
|
80
82
|
state = self.get_initial_state()
|
|
81
83
|
|
ml4gw/nn/streaming/snapshotter.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from typing import Optional, Sequence, Tuple
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
from ml4gw.utils.slicing import unfold_windows
|
|
6
8
|
|
|
@@ -82,14 +84,14 @@ class Snapshotter(torch.nn.Module):
|
|
|
82
84
|
self.channels_per_snapshot = channels_per_snapshot
|
|
83
85
|
self.num_channels = num_channels
|
|
84
86
|
|
|
85
|
-
def get_initial_state(self):
|
|
87
|
+
def get_initial_state(self) -> Float[Tensor, "channel time"]:
|
|
86
88
|
return torch.zeros((self.num_channels, self.state_size))
|
|
87
89
|
|
|
88
|
-
# TODO: use torchtyping annotations to make
|
|
89
|
-
# clear what the expected shapes are
|
|
90
90
|
def forward(
|
|
91
|
-
self,
|
|
92
|
-
|
|
91
|
+
self,
|
|
92
|
+
update: Float[Tensor, "channel time1"],
|
|
93
|
+
snapshot: Optional[Float[Tensor, "channel time2"]] = None,
|
|
94
|
+
) -> Tuple[Tensor, ...]:
|
|
93
95
|
if snapshot is None:
|
|
94
96
|
snapshot = self.get_initial_state()
|
|
95
97
|
|
ml4gw/spectral.py
CHANGED
|
@@ -12,14 +12,18 @@ https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.csd.html
|
|
|
12
12
|
from typing import Optional, Union
|
|
13
13
|
|
|
14
14
|
import torch
|
|
15
|
-
from
|
|
15
|
+
from jaxtyping import Float
|
|
16
|
+
from torch import Tensor
|
|
16
17
|
|
|
17
|
-
from ml4gw import
|
|
18
|
+
from ml4gw.types import (
|
|
19
|
+
FrequencySeries1to3d,
|
|
20
|
+
PSDTensor,
|
|
21
|
+
TimeSeries1to3d,
|
|
22
|
+
WaveformTensor,
|
|
23
|
+
)
|
|
18
24
|
|
|
19
|
-
time = None
|
|
20
25
|
|
|
21
|
-
|
|
22
|
-
def median(x, axis):
|
|
26
|
+
def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
|
|
23
27
|
"""
|
|
24
28
|
Implements a median calculation that matches numpy's
|
|
25
29
|
behavior for an even number of elements and includes
|
|
@@ -33,7 +37,7 @@ def median(x, axis):
|
|
|
33
37
|
|
|
34
38
|
|
|
35
39
|
def _validate_shapes(
|
|
36
|
-
x:
|
|
40
|
+
x: Tensor, nperseg: int, y: Optional[Tensor] = None
|
|
37
41
|
) -> None:
|
|
38
42
|
if x.shape[-1] < nperseg:
|
|
39
43
|
raise ValueError(
|
|
@@ -83,14 +87,14 @@ def _validate_shapes(
|
|
|
83
87
|
|
|
84
88
|
|
|
85
89
|
def fast_spectral_density(
|
|
86
|
-
x:
|
|
90
|
+
x: TimeSeries1to3d,
|
|
87
91
|
nperseg: int,
|
|
88
92
|
nstride: int,
|
|
89
|
-
window:
|
|
90
|
-
scale:
|
|
93
|
+
window: Float[Tensor, " {nperseg//2+1}"],
|
|
94
|
+
scale: float,
|
|
91
95
|
average: str = "median",
|
|
92
|
-
y: Optional[
|
|
93
|
-
) ->
|
|
96
|
+
y: Optional[TimeSeries1to3d] = None,
|
|
97
|
+
) -> FrequencySeries1to3d:
|
|
94
98
|
"""
|
|
95
99
|
Compute the power spectral density of a multichannel
|
|
96
100
|
timeseries or a batch of multichannel timeseries, or
|
|
@@ -107,9 +111,9 @@ def fast_spectral_density(
|
|
|
107
111
|
The timeseries tensor whose power spectral density
|
|
108
112
|
to compute, or for cross spectral density the
|
|
109
113
|
timeseries whose fft will be conjugated. Can have
|
|
110
|
-
shape
|
|
111
|
-
`(
|
|
112
|
-
|
|
114
|
+
shape `(batch_size, num_channels, length * sample_rate)`,
|
|
115
|
+
`(num_channels, length * sample_rate)`, or
|
|
116
|
+
`(length * sample_rate)`.
|
|
113
117
|
nperseg:
|
|
114
118
|
Number of samples included in each FFT window
|
|
115
119
|
nstride:
|
|
@@ -150,7 +154,7 @@ def fast_spectral_density(
|
|
|
150
154
|
channel in `x` across _all_ of `x`'s batch elements.
|
|
151
155
|
Returns:
|
|
152
156
|
Tensor of power spectral densities of `x` or its cross spectral
|
|
153
|
-
|
|
157
|
+
density with the timeseries in `y`.
|
|
154
158
|
"""
|
|
155
159
|
|
|
156
160
|
_validate_shapes(x, nperseg, y)
|
|
@@ -240,17 +244,16 @@ def fast_spectral_density(
|
|
|
240
244
|
|
|
241
245
|
|
|
242
246
|
def spectral_density(
|
|
243
|
-
x:
|
|
247
|
+
x: TimeSeries1to3d,
|
|
244
248
|
nperseg: int,
|
|
245
249
|
nstride: int,
|
|
246
|
-
window:
|
|
247
|
-
scale:
|
|
250
|
+
window: Float[Tensor, " {nperseg//2+1}"],
|
|
251
|
+
scale: float,
|
|
248
252
|
average: str = "median",
|
|
249
|
-
) ->
|
|
253
|
+
) -> FrequencySeries1to3d:
|
|
250
254
|
"""
|
|
251
255
|
Compute the power spectral density of a multichannel
|
|
252
|
-
timeseries or a batch of multichannel timeseries
|
|
253
|
-
the cross power spectral density of two such timeseries.
|
|
256
|
+
timeseries or a batch of multichannel timeseries.
|
|
254
257
|
This implementation is exact for all frequency bins, but
|
|
255
258
|
slower than the fast implementation.
|
|
256
259
|
|
|
@@ -259,9 +262,9 @@ def spectral_density(
|
|
|
259
262
|
The timeseries tensor whose power spectral density
|
|
260
263
|
to compute, or for cross spectral density the
|
|
261
264
|
timeseries whose fft will be conjugated. Can have
|
|
262
|
-
shape
|
|
263
|
-
`(
|
|
264
|
-
|
|
265
|
+
shape `(batch_size, num_channels, length * sample_rate)`,
|
|
266
|
+
`(num_channels, length * sample_rate)`, or
|
|
267
|
+
`(length * sample_rate)`.
|
|
265
268
|
nperseg:
|
|
266
269
|
Number of samples included in each FFT window
|
|
267
270
|
nstride:
|
|
@@ -336,11 +339,11 @@ def spectral_density(
|
|
|
336
339
|
|
|
337
340
|
|
|
338
341
|
def truncate_inverse_power_spectrum(
|
|
339
|
-
psd:
|
|
340
|
-
fduration: Union[
|
|
342
|
+
psd: PSDTensor,
|
|
343
|
+
fduration: Union[Float[Tensor, " time"], float],
|
|
341
344
|
sample_rate: float,
|
|
342
345
|
highpass: Optional[float] = None,
|
|
343
|
-
) ->
|
|
346
|
+
) -> PSDTensor:
|
|
344
347
|
"""
|
|
345
348
|
Truncate the length of the time domain response
|
|
346
349
|
of a whitening filter built using the specified
|
|
@@ -399,7 +402,7 @@ def truncate_inverse_power_spectrum(
|
|
|
399
402
|
q = torch.fft.irfft(inv_asd, n=N, norm="forward", dim=-1)
|
|
400
403
|
|
|
401
404
|
# taper the edges of the TD filter
|
|
402
|
-
if isinstance(fduration,
|
|
405
|
+
if isinstance(fduration, Tensor):
|
|
403
406
|
pad = fduration.size(-1) // 2
|
|
404
407
|
window = fduration
|
|
405
408
|
else:
|
|
@@ -422,8 +425,8 @@ def truncate_inverse_power_spectrum(
|
|
|
422
425
|
|
|
423
426
|
|
|
424
427
|
def normalize_by_psd(
|
|
425
|
-
X:
|
|
426
|
-
psd:
|
|
428
|
+
X: WaveformTensor,
|
|
429
|
+
psd: PSDTensor,
|
|
427
430
|
sample_rate: float,
|
|
428
431
|
pad: int,
|
|
429
432
|
):
|
|
@@ -447,12 +450,12 @@ def normalize_by_psd(
|
|
|
447
450
|
|
|
448
451
|
|
|
449
452
|
def whiten(
|
|
450
|
-
X:
|
|
451
|
-
psd:
|
|
452
|
-
fduration: Union[
|
|
453
|
+
X: WaveformTensor,
|
|
454
|
+
psd: PSDTensor,
|
|
455
|
+
fduration: Union[Float[Tensor, " time"], float],
|
|
453
456
|
sample_rate: float,
|
|
454
457
|
highpass: Optional[float] = None,
|
|
455
|
-
) ->
|
|
458
|
+
) -> WaveformTensor:
|
|
456
459
|
"""
|
|
457
460
|
Whiten a batch of timeseries using the specified
|
|
458
461
|
background one-sided power spectral densities (PSDs),
|
|
@@ -460,7 +463,8 @@ def whiten(
|
|
|
460
463
|
`fduration` and possibly to highpass filter.
|
|
461
464
|
|
|
462
465
|
Args:
|
|
463
|
-
X:
|
|
466
|
+
X:
|
|
467
|
+
batch of multichannel timeseries to whiten
|
|
464
468
|
psd:
|
|
465
469
|
PSDs use to whiten the data. The frequency
|
|
466
470
|
response of the whitening filter will be roughly
|
|
@@ -496,7 +500,7 @@ def whiten(
|
|
|
496
500
|
|
|
497
501
|
# figure out how much data we'll need to slice
|
|
498
502
|
# off after whitening
|
|
499
|
-
if isinstance(fduration,
|
|
503
|
+
if isinstance(fduration, Tensor):
|
|
500
504
|
pad = fduration.size(-1) // 2
|
|
501
505
|
else:
|
|
502
506
|
pad = int(fduration * sample_rate / 2)
|