ml4gw 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4gw/augmentations.py +43 -0
- ml4gw/dataloading/__init__.py +2 -1
- ml4gw/dataloading/chunked_dataset.py +66 -212
- ml4gw/dataloading/hdf5_dataset.py +176 -0
- ml4gw/nn/__init__.py +0 -0
- ml4gw/nn/autoencoder/__init__.py +3 -0
- ml4gw/nn/autoencoder/base.py +89 -0
- ml4gw/nn/autoencoder/convolutional.py +156 -0
- ml4gw/nn/autoencoder/skip_connection.py +46 -0
- ml4gw/nn/autoencoder/utils.py +14 -0
- ml4gw/nn/norm.py +97 -0
- ml4gw/nn/resnet/__init__.py +2 -0
- ml4gw/nn/resnet/resnet_1d.py +413 -0
- ml4gw/nn/resnet/resnet_2d.py +413 -0
- ml4gw/nn/streaming/__init__.py +2 -0
- ml4gw/nn/streaming/online_average.py +121 -0
- ml4gw/nn/streaming/snapshotter.py +121 -0
- ml4gw/transforms/__init__.py +2 -0
- ml4gw/transforms/pearson.py +87 -0
- ml4gw/transforms/spectrogram.py +162 -0
- ml4gw/transforms/whitening.py +1 -1
- ml4gw/waveforms/__init__.py +2 -0
- ml4gw/waveforms/phenom_d.py +1359 -0
- ml4gw/waveforms/phenom_d_data.py +3026 -0
- ml4gw/waveforms/taylorf2.py +306 -0
- {ml4gw-0.2.0.dist-info → ml4gw-0.4.0.dist-info}/METADATA +14 -6
- ml4gw-0.4.0.dist-info/RECORD +43 -0
- {ml4gw-0.2.0.dist-info → ml4gw-0.4.0.dist-info}/WHEEL +1 -1
- ml4gw-0.2.0.dist-info/RECORD +0 -23
ml4gw/nn/autoencoder/convolutional.py ADDED
@@ -0,0 +1,156 @@
+from collections.abc import Callable, Sequence
+from typing import Optional
+
+import torch
+
+from ml4gw.nn.autoencoder.base import Autoencoder
+from ml4gw.nn.autoencoder.skip_connection import SkipConnection
+from ml4gw.nn.autoencoder.utils import match_size
+
+Module = Callable[[...], torch.nn.Module]
+
+
+class ConvBlock(Autoencoder):
+    def __init__(
+        self,
+        in_channels: int,
+        encode_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        groups: int = 1,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        norm: Module = torch.nn.BatchNorm1d,
+        decode_channels: Optional[int] = None,
+        output_activation: Optional[torch.nn.Module] = None,
+        skip_connection: Optional[SkipConnection] = None,
+    ) -> None:
+        super().__init__(skip_connection=None)
+
+        self.kernel_size = kernel_size
+        self.padding = int((kernel_size - 1) // 2)
+        self.stride = stride
+
+        out_channels = encode_channels * groups
+        self.encode_layer = torch.nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=self.padding,
+            bias=False,
+            groups=groups,
+        )
+
+        decode_channels = decode_channels or in_channels
+        in_channels = encode_channels * groups
+        if skip_connection is not None:
+            in_channels = skip_connection.get_out_channels(in_channels)
+        self.decode_layer = torch.nn.ConvTranspose1d(
+            in_channels,
+            decode_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=self.padding,
+            bias=False,
+            groups=groups,
+        )
+
+        self.activation = activation
+        if output_activation is not None:
+            self.output_activation = output_activation
+        else:
+            self.output_activation = activation
+
+        self.encode_norm = norm(out_channels)
+        self.decode_norm = norm(decode_channels)
+
+    def encode(self, X):
+        X = self.encode_layer(X)
+        X = self.encode_norm(X)
+        return self.activation(X)
+
+    def decode(self, X):
+        X = self.decode_layer(X)
+        X = self.decode_norm(X)
+        return self.output_activation(X)
+
+
+class ConvolutionalAutoencoder(Autoencoder):
+    """
+    Build a stack of convolutional autoencoder layer
+    blocks. The output of each decoder layer will
+    match the shape of the input to its corresponding
+    encoder layer, except for the last decoder which
+    can have an arbitrary number of channels specified
+    by `decode_channels`.
+
+    All layers also share the same `activation` except
+    for the last decoder layer, which can have an
+    arbitrary `output_activation`.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        encode_channels: Sequence[int],
+        kernel_size: int,
+        stride: int = 1,
+        groups: int = 1,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        output_activation: Optional[torch.nn.Module] = None,
+        norm: Module = torch.nn.BatchNorm1d,
+        decode_channels: Optional[int] = None,
+        skip_connection: Optional[SkipConnection] = None,
+    ) -> None:
+        # TODO: how to do this dynamically? Maybe the base
+        # architecture looks for overlapping arguments between
+        # this and the skip connection class and then provides them?
+        # if skip_connection is not None:
+        #     skip_connection = skip_connection(groups)
+        super().__init__(skip_connection=skip_connection)
+
+        output_activation = output_activation or activation
+        for i, channels in enumerate(encode_channels):
+            # All intermediate layers should decode to
+            # the same number of channels. The last decoder
+            # should decode to whatever number of channels
+            # was specified, even if it's `None` (in which
+            # case it will just be in_channels anyway)
+            decode = in_channels if i else decode_channels
+
+            # don't have the middle layer skip to itself
+            # TODO: wait I don't think this makes sense.
+            # j = len(encode_channels) - 1 - i
+            # connect = skip_connection if j else None
+            connect = skip_connection
+
+            # all intermediate layers should use the same
+            # activation. Only the last decoder should have
+            # a potentially different activation
+            out_act = None if i else output_activation
+
+            block = ConvBlock(
+                in_channels,
+                channels,
+                kernel_size,
+                stride,
+                groups,
+                activation=activation,
+                norm=norm,
+                decode_channels=decode,
+                skip_connection=connect,
+                output_activation=out_act,
+            )
+            self.blocks.append(block)
+            in_channels = channels * groups
+
+    def decode(self, *X, states=None, input_size: Optional[int] = None):
+        X = super().decode(*X, states=states)
+        if input_size is not None:
+            return match_size(X, input_size)
+        return X
+
+    def forward(self, X):
+        input_size = X.size(-1)
+        X = super().forward(X)
+        return match_size(X, input_size)
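The `Autoencoder` base class that `ConvBlock` and `ConvolutionalAutoencoder` extend lives in ml4gw/nn/autoencoder/base.py (added elsewhere in this release) and is not shown here, so the usage sketch below rests on assumptions drawn only from the code above: that the base class maintains the `self.blocks` list and chains the blocks' `encode`/`decode` methods in `forward`. Channel counts, lengths, and the import path are illustrative, not taken from the package's documentation.

import torch
from ml4gw.nn.autoencoder import ConvolutionalAutoencoder  # import path assumed

# two interferometer channels in, three encoder stages
model = ConvolutionalAutoencoder(
    in_channels=2,
    encode_channels=[8, 16, 32],
    kernel_size=7,
    stride=2,
)
X = torch.randn(4, 2, 2048)  # (batch, channels, time)
out = model(X)
# forward() pads or crops the decoded output back to the input length
assert out.size(-1) == X.size(-1)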
ml4gw/nn/autoencoder/skip_connection.py ADDED
@@ -0,0 +1,46 @@
+import torch
+
+from ml4gw.nn.autoencoder.utils import match_size
+
+
+class SkipConnection(torch.nn.Module):
+    def forward(self, X: torch.Tensor, state: torch.Tensor):
+        return match_size(X, state.size(-1))
+
+    def get_out_channels(self, in_channels):
+        return in_channels
+
+
+class AddSkipConnect(SkipConnection):
+    def forward(self, X, state):
+        X = super().forward(X, state)
+        return X + state
+
+
+class ConcatSkipConnect(SkipConnection):
+    def __init__(self, groups: int = 1):
+        super().__init__()
+        self.groups = groups
+
+    def get_out_channels(self, in_channels):
+        return 2 * in_channels
+
+    def forward(self, X, state):
+        X = super().forward(X, state)
+        if self.groups == 1:
+            return torch.cat([X, state], dim=1)
+
+        num_channels = X.size(1)
+        rem = num_channels % self.groups
+        if rem:
+            raise ValueError(
+                "Number of channels in input tensor {} cannot "
+                "be divided evenly into {} groups".format(
+                    num_channels, self.groups
+                )
+            )
+
+        X = torch.split(X, self.groups, dim=1)
+        state = torch.split(state, self.groups, dim=1)
+        frags = [i for j in zip(X, state) for i in j]
+        return torch.cat(frags, dim=1)
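To make the grouped branch of `ConcatSkipConnect.forward` concrete, here is a minimal sketch of the same computation with plain tensor ops on invented tensors. Note that `torch.split(X, self.groups, dim=1)` splits into chunks *of size* `self.groups`, so the result interleaves size-`groups` channel fragments of `X` and `state`.

import torch

X = torch.arange(8.0).reshape(1, 4, 2)           # 4 channels, 2 samples
state = 10 + torch.arange(8.0).reshape(1, 4, 2)
groups = 2

# mirror of the grouped branch above
Xs = torch.split(X, groups, dim=1)               # two fragments of 2 channels each
Ss = torch.split(state, groups, dim=1)
frags = [f for pair in zip(Xs, Ss) for f in pair]
out = torch.cat(frags, dim=1)

# resulting channel order: X0, X1, S0, S1, X2, X3, S2, S3
assert out.size(1) == 2 * X.size(1)              # consistent with get_out_channels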
ml4gw/nn/autoencoder/utils.py ADDED
@@ -0,0 +1,14 @@
+import torch
+
+
+def match_size(X: torch.Tensor, target_size: int):
+    diff = target_size - X.size(-1)
+    left = int(diff // 2)
+    right = diff - left
+
+    if diff > 0:
+        return torch.nn.functional.pad(X, (left, right))
+    elif diff < 0:
+        right = -right or None
+        return X[:, :, -left:right]
+    return X
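A quick worked example of the padding branch of `match_size`: growing a length-100 tensor to 105 gives `diff = 5`, split as `left = 2` and `right = 3`. (For `diff < 0`, `left` and `right` come out negative and are used as slice bounds; the `or None` guard handles the case where no right-side trim is needed.)

import torch
from ml4gw.nn.autoencoder.utils import match_size

X = torch.ones(1, 2, 100)
Y = match_size(X, 105)
assert Y.shape == (1, 2, 105)
# diff = 5: two zeros padded on the left, three on the right
assert Y[..., :2].sum() == 0 and Y[..., -3:].sum() == 0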
ml4gw/nn/norm.py ADDED
@@ -0,0 +1,97 @@
+from typing import Callable, Optional
+
+import torch
+
+NormLayer = Callable[[int], torch.nn.Module]
+
+
+class GroupNorm1D(torch.nn.Module):
+    """
+    Custom implementation of GroupNorm which is faster than the
+    out-of-the-box PyTorch version at inference time.
+    """
+
+    def __init__(
+        self,
+        num_channels: int,
+        num_groups: Optional[int] = None,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        num_groups = num_groups or num_channels
+        if num_channels % num_groups:
+            raise ValueError("num_groups must be a factor of num_channels")
+
+        self.num_channels = num_channels
+        self.num_groups = num_groups
+        self.channels_per_group = self.num_channels // self.num_groups
+        self.eps = eps
+
+        shape = (self.num_channels, 1)
+        self.weight = torch.nn.Parameter(torch.ones(shape))
+        self.bias = torch.nn.Parameter(torch.zeros(shape))
+
+    def forward(self, x):
+        keepdims = self.num_groups == self.num_channels
+
+        # compute group variance via the E[x**2] - E**2[x] trick
+        mean = x.mean(-1, keepdims=keepdims)
+        sq_mean = (x**2).mean(-1, keepdims=keepdims)
+
+        # if we have groups, do some reshape magic
+        # to calculate group level stats then
+        # reshape back to full channel dimension
+        if self.num_groups != self.num_channels:
+            mean = torch.stack([mean, sq_mean], dim=1)
+            mean = mean.reshape(
+                -1, 2, self.num_groups, self.channels_per_group
+            )
+            mean = mean.mean(-1, keepdims=True)
+            mean = mean.expand(-1, -1, -1, self.channels_per_group)
+            mean = mean.reshape(-1, 2, self.num_channels, 1)
+            mean, sq_mean = mean[:, 0], mean[:, 1]
+
+        # roll the mean and variance into the
+        # weight and bias so that we have to do
+        # fewer computations along the full time axis
+        std = (sq_mean - mean**2 + self.eps) ** 0.5
+        scale = self.weight / std
+        shift = self.bias - scale * mean
+        return shift + x * scale
+
+
+class GroupNorm1DGetter:
+    """
+    Utility for making a NormLayer Callable that maps from
+    an integer number of channels to a torch Module. Useful
+    for command-line parameterization with jsonargparse.
+    """
+
+    def __init__(self, groups: Optional[int] = None) -> None:
+        self.groups = groups
+
+    def __call__(self, num_channels: int) -> torch.nn.Module:
+        if self.groups is None:
+            num_groups = None
+        else:
+            num_groups = min(num_channels, self.groups)
+        return GroupNorm1D(num_channels, num_groups)
+
+
+# TODO generalize faster 1dDGroupNorm to 2D
+class GroupNorm2DGetter:
+    """
+    Utility for making a NormLayer Callable that maps from
+    an integer number of channels to a torch Module. Useful
+    for command-line parameterization with jsonargparse.
+    """
+
+    def __init__(self, groups: Optional[int] = None) -> None:
+        self.groups = groups
+
+    def __call__(self, num_channels: int) -> torch.nn.Module:
+        if self.groups is None:
+            num_groups = num_channels
+        else:
+            num_groups = min(num_channels, self.groups)
+        return torch.nn.GroupNorm(num_groups, num_channels)
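Because `GroupNorm1D` builds its group statistics from per-channel time means (every channel contributes the same number of samples), it should agree with `torch.nn.GroupNorm` at default affine initialization up to floating-point error. The check below is a sanity-check sketch written for this review, not a test shipped with the package; the trailing lines show how a `NormLayer` factory plugs into, e.g., the `norm` argument of `ConvBlock`.

import torch
from ml4gw.nn.norm import GroupNorm1D, GroupNorm1DGetter

x = torch.randn(8, 6, 256)
ours = GroupNorm1D(num_channels=6, num_groups=3)
ref = torch.nn.GroupNorm(num_groups=3, num_channels=6)
torch.testing.assert_close(ours(x), ref(x), rtol=1e-4, atol=1e-5)

# NormLayer factory: maps a channel count to a module, capped at 3 groups
make_norm = GroupNorm1DGetter(groups=3)
assert isinstance(make_norm(6), GroupNorm1D)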