ml4gw 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/augmentations.py +8 -2
- ml4gw/constants.py +45 -0
- ml4gw/dataloading/chunked_dataset.py +4 -2
- ml4gw/dataloading/hdf5_dataset.py +1 -1
- ml4gw/dataloading/in_memory_dataset.py +8 -4
- ml4gw/distributions.py +18 -12
- ml4gw/gw.py +21 -27
- ml4gw/nn/autoencoder/base.py +11 -6
- ml4gw/nn/autoencoder/convolutional.py +7 -4
- ml4gw/nn/autoencoder/skip_connection.py +7 -6
- ml4gw/nn/autoencoder/utils.py +2 -1
- ml4gw/nn/norm.py +11 -1
- ml4gw/nn/streaming/online_average.py +7 -5
- ml4gw/nn/streaming/snapshotter.py +7 -5
- ml4gw/spectral.py +40 -36
- ml4gw/transforms/pearson.py +7 -3
- ml4gw/transforms/qtransform.py +20 -14
- ml4gw/transforms/scaler.py +6 -2
- ml4gw/transforms/snr_rescaler.py +6 -5
- ml4gw/transforms/spectral.py +25 -6
- ml4gw/transforms/spectrogram.py +7 -1
- ml4gw/transforms/transform.py +4 -3
- ml4gw/transforms/waveforms.py +10 -7
- ml4gw/transforms/whitening.py +12 -4
- ml4gw/types.py +25 -10
- ml4gw/utils/interferometer.py +7 -1
- ml4gw/utils/slicing.py +24 -16
- ml4gw/waveforms/__init__.py +2 -0
- ml4gw/waveforms/generator.py +9 -5
- ml4gw/waveforms/phenom_d.py +1338 -1256
- ml4gw/waveforms/phenom_p.py +796 -0
- ml4gw/waveforms/ringdown.py +109 -0
- ml4gw/waveforms/sine_gaussian.py +10 -11
- ml4gw/waveforms/taylorf2.py +304 -279
- {ml4gw-0.4.2.dist-info → ml4gw-0.5.1.dist-info}/METADATA +5 -3
- ml4gw-0.5.1.dist-info/RECORD +47 -0
- ml4gw-0.4.2.dist-info/RECORD +0 -44
- {ml4gw-0.4.2.dist-info → ml4gw-0.5.1.dist-info}/WHEEL +0 -0
ml4gw/utils/slicing.py
CHANGED
|
@@ -1,25 +1,30 @@
|
|
|
1
1
|
from typing import Optional, Union
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float, Int64
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
from torch.nn.functional import unfold
|
|
5
|
-
from torchtyping import TensorType
|
|
6
7
|
|
|
7
|
-
|
|
8
|
-
|
|
8
|
+
from ml4gw.types import (
|
|
9
|
+
TimeSeries1d,
|
|
10
|
+
TimeSeries1to3d,
|
|
11
|
+
TimeSeries2d,
|
|
12
|
+
TimeSeries3d,
|
|
13
|
+
)
|
|
9
14
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
BatchTimeSeriesTensor = Union[
|
|
13
|
-
TensorType["batch", "time"], TensorType["batch", "channel", "time"]
|
|
14
|
-
]
|
|
15
|
+
BatchTimeSeriesTensor = Union[Float[Tensor, "batch time"], TimeSeries3d]
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
def unfold_windows(
|
|
18
|
-
x:
|
|
19
|
+
x: TimeSeries1to3d,
|
|
19
20
|
window_size: int,
|
|
20
21
|
stride: int,
|
|
21
22
|
drop_last: bool = True,
|
|
22
|
-
)
|
|
23
|
+
) -> Union[
|
|
24
|
+
Float[TimeSeries1d, " window"],
|
|
25
|
+
Float[TimeSeries2d, " window"],
|
|
26
|
+
Float[TimeSeries3d, " window"],
|
|
27
|
+
]:
|
|
23
28
|
"""Unfold a timeseries into windows
|
|
24
29
|
|
|
25
30
|
Args:
|
|
@@ -83,8 +88,8 @@ def unfold_windows(
|
|
|
83
88
|
|
|
84
89
|
|
|
85
90
|
def slice_kernels(
|
|
86
|
-
x:
|
|
87
|
-
idx:
|
|
91
|
+
x: TimeSeries1to3d,
|
|
92
|
+
idx: Int64[Tensor, "..."],
|
|
88
93
|
kernel_size: int,
|
|
89
94
|
) -> BatchTimeSeriesTensor:
|
|
90
95
|
"""Slice kernels from single or multichannel timeseries
|
|
@@ -96,7 +101,8 @@ def slice_kernels(
|
|
|
96
101
|
one more dimension than `x`.
|
|
97
102
|
|
|
98
103
|
Args:
|
|
99
|
-
x:
|
|
104
|
+
x:
|
|
105
|
+
The timeseries tensor to slice kernels from
|
|
100
106
|
idx:
|
|
101
107
|
The indices in `x` of the first sample of each
|
|
102
108
|
kernel. If `x` is 1D, `idx` must be 1D as well.
|
|
@@ -114,6 +120,7 @@ def slice_kernels(
|
|
|
114
120
|
coincidentally among the channels.
|
|
115
121
|
kernel_size:
|
|
116
122
|
The length of the kernels to slice from the timeseries
|
|
123
|
+
|
|
117
124
|
Returns:
|
|
118
125
|
A tensor of shape `(batch_size, kernel_size)` if `x` is
|
|
119
126
|
1D and `(batch_size, num_channels, kernel_size)` if `x`
|
|
@@ -225,7 +232,7 @@ def slice_kernels(
|
|
|
225
232
|
|
|
226
233
|
|
|
227
234
|
def sample_kernels(
|
|
228
|
-
X:
|
|
235
|
+
X: TimeSeries1to3d,
|
|
229
236
|
kernel_size: int,
|
|
230
237
|
N: Optional[int] = None,
|
|
231
238
|
max_center_offset: Optional[int] = None,
|
|
@@ -245,8 +252,9 @@ def sample_kernels(
|
|
|
245
252
|
either be `None` or be equal to `len(X)`.
|
|
246
253
|
|
|
247
254
|
Args:
|
|
248
|
-
X:
|
|
249
|
-
|
|
255
|
+
X:
|
|
256
|
+
The timeseries tensor from which to sample kernels
|
|
257
|
+
kernel_size: The size of the kernels to sample
|
|
250
258
|
N:
|
|
251
259
|
The number of kernels to sample. Can be left as
|
|
252
260
|
`None` if `X` is 3D, otherwise must be specified
|
ml4gw/waveforms/__init__.py
CHANGED
ml4gw/waveforms/generator.py
CHANGED
|
@@ -1,24 +1,26 @@
|
|
|
1
|
-
from typing import Callable
|
|
1
|
+
from typing import Callable, Dict, Tuple
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from jaxtyping import Float
|
|
5
|
+
from torch import Tensor
|
|
4
6
|
|
|
5
7
|
|
|
6
8
|
class ParameterSampler(torch.nn.Module):
|
|
7
|
-
def __init__(self, **parameters: Callable):
|
|
9
|
+
def __init__(self, **parameters: Callable) -> None:
|
|
8
10
|
super().__init__()
|
|
9
11
|
self.parameters = parameters
|
|
10
12
|
|
|
11
13
|
def forward(
|
|
12
14
|
self,
|
|
13
15
|
N: int,
|
|
14
|
-
):
|
|
16
|
+
) -> Dict[str, Float[Tensor, " {N}"]]:
|
|
15
17
|
return {k: v.sample((N,)) for k, v in self.parameters.items()}
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
class WaveformGenerator(torch.nn.Module):
|
|
19
21
|
def __init__(
|
|
20
22
|
self, waveform: Callable, parameter_sampler: ParameterSampler
|
|
21
|
-
):
|
|
23
|
+
) -> None:
|
|
22
24
|
"""
|
|
23
25
|
A torch module that generates waveforms from a given waveform function
|
|
24
26
|
and a parameter sampler.
|
|
@@ -34,6 +36,8 @@ class WaveformGenerator(torch.nn.Module):
|
|
|
34
36
|
self.waveform = waveform
|
|
35
37
|
self.parameter_sampler = parameter_sampler
|
|
36
38
|
|
|
37
|
-
def forward(
|
|
39
|
+
def forward(
|
|
40
|
+
self, N: int
|
|
41
|
+
) -> Tuple[Float[Tensor, "{N} samples"], Dict[str, Float[Tensor, " {N}"]]]:
|
|
38
42
|
parameters = self.parameter_sampler(N)
|
|
39
43
|
return self.waveform(**parameters), parameters
|