ml4gw 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic.
- ml4gw/nn/norm.py +97 -0
- ml4gw/nn/resnet/__init__.py +2 -0
- ml4gw/nn/resnet/resnet_1d.py +413 -0
- ml4gw/nn/resnet/resnet_2d.py +413 -0
- ml4gw/transforms/__init__.py +1 -0
- ml4gw/transforms/spectrogram.py +162 -0
- ml4gw/transforms/whitening.py +1 -1
- ml4gw/waveforms/phenom_d.py +1 -0
- {ml4gw-0.3.0.dist-info → ml4gw-0.4.1.dist-info}/METADATA +12 -7
- {ml4gw-0.3.0.dist-info → ml4gw-0.4.1.dist-info}/RECORD +11 -6
- {ml4gw-0.3.0.dist-info → ml4gw-0.4.1.dist-info}/WHEEL +1 -1
ml4gw/nn/norm.py
ADDED
@@ -0,0 +1,97 @@
from typing import Callable, Optional

import torch

NormLayer = Callable[[int], torch.nn.Module]


class GroupNorm1D(torch.nn.Module):
    """
    Custom implementation of GroupNorm which is faster than the
    out-of-the-box PyTorch version at inference time.
    """

    def __init__(
        self,
        num_channels: int,
        num_groups: Optional[int] = None,
        eps: float = 1e-5,
    ):
        super().__init__()
        num_groups = num_groups or num_channels
        if num_channels % num_groups:
            raise ValueError("num_groups must be a factor of num_channels")

        self.num_channels = num_channels
        self.num_groups = num_groups
        self.channels_per_group = self.num_channels // self.num_groups
        self.eps = eps

        shape = (self.num_channels, 1)
        self.weight = torch.nn.Parameter(torch.ones(shape))
        self.bias = torch.nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        keepdims = self.num_groups == self.num_channels

        # compute group variance via the E[x**2] - E**2[x] trick
        mean = x.mean(-1, keepdims=keepdims)
        sq_mean = (x**2).mean(-1, keepdims=keepdims)

        # if we have groups, do some reshape magic
        # to calculate group level stats then
        # reshape back to full channel dimension
        if self.num_groups != self.num_channels:
            mean = torch.stack([mean, sq_mean], dim=1)
            mean = mean.reshape(
                -1, 2, self.num_groups, self.channels_per_group
            )
            mean = mean.mean(-1, keepdims=True)
            mean = mean.expand(-1, -1, -1, self.channels_per_group)
            mean = mean.reshape(-1, 2, self.num_channels, 1)
            mean, sq_mean = mean[:, 0], mean[:, 1]

        # roll the mean and variance into the
        # weight and bias so that we have to do
        # fewer computations along the full time axis
        std = (sq_mean - mean**2 + self.eps) ** 0.5
        scale = self.weight / std
        shift = self.bias - scale * mean
        return shift + x * scale


class GroupNorm1DGetter:
    """
    Utility for making a NormLayer Callable that maps from
    an integer number of channels to a torch Module. Useful
    for command-line parameterization with jsonargparse.
    """

    def __init__(self, groups: Optional[int] = None) -> None:
        self.groups = groups

    def __call__(self, num_channels: int) -> torch.nn.Module:
        if self.groups is None:
            num_groups = None
        else:
            num_groups = min(num_channels, self.groups)
        return GroupNorm1D(num_channels, num_groups)


# TODO: generalize the faster 1D GroupNorm to 2D
class GroupNorm2DGetter:
    """
    Utility for making a NormLayer Callable that maps from
    an integer number of channels to a torch Module. Useful
    for command-line parameterization with jsonargparse.
    """

    def __init__(self, groups: Optional[int] = None) -> None:
        self.groups = groups

    def __call__(self, num_channels: int) -> torch.nn.Module:
        if self.groups is None:
            num_groups = num_channels
        else:
            num_groups = min(num_channels, self.groups)
        return torch.nn.GroupNorm(num_groups, num_channels)
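As a sanity check on the docstring's claim, here is a quick sketch (not part of the diff) comparing the new layer against torch.nn.GroupNorm on a (batch, channels, time) input; the two should agree to within numerical precision, since both normalize over (channels_per_group, time) groups with the same eps and affine initialization:

import torch
from ml4gw.nn.norm import GroupNorm1D

x = torch.randn(8, 16, 2048)  # (batch, channels, time)
fast = GroupNorm1D(num_channels=16, num_groups=4)
ref = torch.nn.GroupNorm(num_groups=4, num_channels=16)
# both modules initialize weight=1 and bias=0, so outputs should match
print(torch.allclose(fast(x), ref(x), atol=1e-4))  # expect True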
ml4gw/nn/resnet/resnet_1d.py
ADDED
@@ -0,0 +1,413 @@
"""
In large part lifted from
https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
but with 1d convolutions and arbitrary kernel sizes, and a
default norm layer that makes more sense for most GW applications
where training-time statistics are entirely arbitrary due to
simulations.
"""

from typing import Callable, List, Literal, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ml4gw.nn.norm import GroupNorm1DGetter, NormLayer


def convN(
    in_planes: int,
    out_planes: int,
    kernel_size: int = 3,
    stride: int = 1,
    groups: int = 1,
    dilation: int = 1,
) -> nn.Conv1d:
    """1d convolution with padding"""
    if not kernel_size % 2:
        raise ValueError("Can't use even sized kernels")

    return nn.Conv1d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=dilation * int(kernel_size // 2),
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv1d:
    """Kernel-size 1 convolution"""
    return nn.Conv1d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
    )


class BasicBlock(nn.Module):
    """Defines the structure of the blocks used to build the ResNet"""

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:

        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm1d
        if groups != 1 or base_width != 64:
            raise ValueError(
                "BasicBlock only supports groups=1 and base_width=64"
            )
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )

        # Both self.conv1 and self.downsample layers
        # downsample the input when stride != 1
        self.conv1 = convN(inplanes, planes, kernel_size, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = convN(planes, planes, kernel_size)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    Bottleneck blocks implement one extra convolution
    compared to basic blocks. In these layers, the `planes`
    parameter is generally meant to _downsize_ the number
    of feature maps first, which then get expanded out to
    `planes * Bottleneck.expansion` feature maps at the
    output of the layer.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[NormLayer] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm1d

        width = int(planes * (base_width / 64.0)) * groups

        # conv1 does no downsampling, just reduces the number of
        # feature maps from inplanes to width (where width == planes)
        # if groups == 1 and base_width == 64
        self.conv1 = convN(inplanes, width, kernel_size)
        self.bn1 = norm_layer(width)

        # conv2 keeps the same number of feature maps,
        # but downsamples along the time axis if stride
        # or dilation > 1
        self.conv2 = convN(width, width, kernel_size, stride, groups, dilation)
        self.bn2 = norm_layer(width)

        # conv3 expands the feature maps back out to planes * expansion
        self.conv3 = conv1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet1D(nn.Module):
    """1D ResNet architecture

    Simple extension of ResNet to 1D convolutions with
    arbitrary kernel sizes to support the longer timeseries
    used in BBH detection.

    Args:
        in_channels:
            The number of channels in the input tensor.
        layers:
            A list representing the number of residual
            blocks to include in each "layer" of the
            network. Total layers (e.g. 50 in ResNet50)
            is `2 + sum(layers) * factor`, where factor
            is `2` for vanilla `ResNet` and `3` for
            `BottleneckResNet`.
        kernel_size:
            The size of the convolutional kernel to
            use in all residual layers. _NOT_ the size
            of the input kernel to the network, which
            is determined at run-time.
        zero_init_residual:
            Flag indicating whether to initialize the
            weights of the batch-norm layer in each block
            to 0 so that residuals are initialized as
            identities. Can improve training results.
        groups:
            Number of convolutional groups to use in all
            layers. Grouped convolutions induce local
            connections between feature maps at subsequent
            layers rather than global. Generally won't
            need this to be >1, and will raise an error if
            >1 when using vanilla `ResNet`.
        width_per_group:
            Base width of each of the feature map groups,
            which is scaled up by the typical expansion
            factor at each layer of the network. Meaningless
            for vanilla `ResNet`.
        stride_type:
            Whether to achieve downsampling on the time axis
            by strided or dilated convolutions for each layer.
            If left as `None`, strided convolutions will be
            used at each layer. Otherwise, `stride_type` should
            be one element shorter than `layers` and indicate either
            `stride` or `dilation` for each layer after the first.
    """

    block = BasicBlock

    def __init__(
        self,
        in_channels: int,
        layers: List[int],
        classes: int,
        kernel_size: int = 3,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
        norm_layer: Optional[NormLayer] = None,
    ) -> None:
        super().__init__()

        self.inplanes = 64
        self.dilation = 1

        # default to using InstanceNorm if no
        # norm layer is provided explicitly
        self._norm_layer = norm_layer or GroupNorm1DGetter()

        # TODO: should we support passing a single string
        # for simplicity here?
        if stride_type is None:
            # each element in the tuple indicates if we should replace
            # the stride with a dilated convolution instead
            stride_type = ["stride"] * (len(layers) - 1)
        if len(stride_type) != (len(layers) - 1):
            raise ValueError(
                "'stride_type' should be None or a "
                "{}-element tuple, got {}".format(len(layers) - 1, stride_type)
            )

        self.groups = groups
        self.base_width = width_per_group

        # start with a basic conv-bn-relu-maxpool block
        # to reduce the dimensionality before the heavy
        # lifting starts
        self.conv1 = nn.Conv1d(
            in_channels,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # now create layers of residual blocks where each
        # layer uses the same number of feature maps for
        # all its blocks (some power of 2 times 64).
        # Don't downsample along the time axis in the first
        # layer, but downsample in all the rest (either by
        # striding or dilating depending on the stride_type
        # argument)
        residual_layers = [self._make_layer(64, layers[0], kernel_size)]
        it = zip(layers[1:], stride_type)
        for i, (num_blocks, stride) in enumerate(it):
            block_size = 64 * 2 ** (i + 1)
            layer = self._make_layer(
                block_size,
                num_blocks,
                kernel_size,
                stride=2,
                stride_type=stride,
            )
            residual_layers.append(layer)
        self.residual_layers = nn.ModuleList(residual_layers)

        # Average pool over each feature map to create a
        # single value for each feature map that we'll use
        # in the fully connected head
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        # use a fully connected layer to map from the
        # feature maps to the binary output that we need
        self.fc = nn.Linear(block_size * self.block.expansion, classes)

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
        self,
        planes: int,
        blocks: int,
        kernel_size: int = 3,
        stride: int = 1,
        stride_type: Literal["stride", "dilation"] = "stride",
    ) -> nn.Sequential:
        block = self.block
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation

        if stride_type == "dilation":
            self.dilation *= stride
            stride = 1
        elif stride_type != "stride":
            raise ValueError(f"Unknown stride type {stride_type}")

        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                kernel_size,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    kernel_size,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for layer in self.residual_layers:
            x = layer(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


# TODO: implement as arg of ResNet instead?
class BottleneckResNet1D(ResNet1D):
    """A version of ResNet that uses bottleneck blocks"""

    block = Bottleneck
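A short usage sketch for the new 1D ResNet (not part of the diff; the channel count, layer counts, and input length here are illustrative only):

import torch
from ml4gw.nn.resnet.resnet_1d import ResNet1D

model = ResNet1D(
    in_channels=2,        # e.g. one channel per interferometer
    layers=[2, 2, 2, 2],  # ResNet18-style: 2 + sum(layers) * 2 = 18
    classes=1,            # single detection logit
    kernel_size=7,        # odd kernels only; even sizes raise ValueError
)
x = torch.randn(4, 2, 4096)  # (batch, channels, time)
print(model(x).shape)        # torch.Size([4, 1])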
ml4gw/nn/resnet/resnet_2d.py
ADDED
@@ -0,0 +1,413 @@
"""
In large part lifted from
https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
but with arbitrary kernel sizes
"""

from typing import Callable, List, Literal, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ml4gw.nn.norm import GroupNorm2DGetter, NormLayer


def convN(
    in_planes: int,
    out_planes: int,
    kernel_size: int = 3,
    stride: int = 1,
    groups: int = 1,
    dilation: int = 1,
) -> nn.Conv2d:
    """2d convolution with padding"""
    if not kernel_size % 2:
        raise ValueError("Can't use even sized kernels")

    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=dilation * int(kernel_size // 2),
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Kernel-size 1 convolution"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
    )


class BasicBlock(nn.Module):
    """Defines the structure of the blocks used to build the ResNet"""

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:

        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                "BasicBlock only supports groups=1 and base_width=64"
            )
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )

        # Both self.conv1 and self.downsample layers
        # downsample the input when stride != 1
        self.conv1 = convN(inplanes, planes, kernel_size, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = convN(planes, planes, kernel_size)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    Bottleneck blocks implement one extra convolution
    compared to basic blocks. In these layers, the `planes`
    parameter is generally meant to _downsize_ the number
    of feature maps first, which then get expanded out to
    `planes * Bottleneck.expansion` feature maps at the
    output of the layer.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        width = int(planes * (base_width / 64.0)) * groups

        # conv1 does no downsampling, just reduces the number of
        # feature maps from inplanes to width (where width == planes)
        # if groups == 1 and base_width == 64
        self.conv1 = convN(inplanes, width, kernel_size)
        self.bn1 = norm_layer(width)

        # conv2 keeps the same number of feature maps,
        # but downsamples along the time axis if stride
        # or dilation > 1
        self.conv2 = convN(width, width, kernel_size, stride, groups, dilation)
        self.bn2 = norm_layer(width)

        # conv3 expands the feature maps back out to planes * expansion
        self.conv3 = conv1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet2D(nn.Module):
    """2D ResNet architecture

    Simple extension of ResNet with arbitrary kernel sizes
    to support the longer timeseries used in BBH detection.

    Args:
        in_channels:
            The number of channels in the input tensor.
        layers:
            A list representing the number of residual
            blocks to include in each "layer" of the
            network. Total layers (e.g. 50 in ResNet50)
            is `2 + sum(layers) * factor`, where factor
            is `2` for vanilla `ResNet` and `3` for
            `BottleneckResNet`.
        kernel_size:
            The size of the convolutional kernel to
            use in all residual layers. _NOT_ the size
            of the input kernel to the network, which
            is determined at run-time.
        zero_init_residual:
            Flag indicating whether to initialize the
            weights of the batch-norm layer in each block
            to 0 so that residuals are initialized as
            identities. Can improve training results.
        groups:
            Number of convolutional groups to use in all
            layers. Grouped convolutions induce local
            connections between feature maps at subsequent
            layers rather than global. Generally won't
            need this to be >1, and will raise an error if
            >1 when using vanilla `ResNet`.
        width_per_group:
            Base width of each of the feature map groups,
            which is scaled up by the typical expansion
            factor at each layer of the network. Meaningless
            for vanilla `ResNet`.
        stride_type:
            Whether to achieve downsampling on the time axis
            by strided or dilated convolutions for each layer.
            If left as `None`, strided convolutions will be
            used at each layer. Otherwise, `stride_type` should
            be one element shorter than `layers` and indicate either
            `stride` or `dilation` for each layer after the first.
        norm_groups:
            The number of groups to use in GroupNorm layers
            throughout the model. If left as `-1`, the number
            of groups will be equal to the number of channels,
            making this equivalent to LayerNorm
    """

    block = BasicBlock

    def __init__(
        self,
        in_channels: int,
        layers: List[int],
        classes: int,
        kernel_size: int = 3,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
        norm_layer: Optional[NormLayer] = None,
    ) -> None:
        super().__init__()
        # default to using InstanceNorm if no
        # norm layer is provided explicitly
        self._norm_layer = norm_layer or GroupNorm2DGetter()

        self.inplanes = 64
        self.dilation = 1

        # TODO: should we support passing a single string
        # for simplicity here?
        if stride_type is None:
            # each element in the tuple indicates if we should replace
            # the stride with a dilated convolution instead
            stride_type = ["stride"] * (len(layers) - 1)
        if len(stride_type) != (len(layers) - 1):
            raise ValueError(
                "'stride_type' should be None or a "
                "{}-element tuple, got {}".format(len(layers) - 1, stride_type)
            )

        self.groups = groups
        self.base_width = width_per_group

        # start with a basic conv-bn-relu-maxpool block
        # to reduce the dimensionality before the heavy
        # lifting starts
        self.conv1 = nn.Conv2d(
            in_channels,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # now create layers of residual blocks where each
        # layer uses the same number of feature maps for
        # all its blocks (some power of 2 times 64).
        # Don't downsample along the time axis in the first
        # layer, but downsample in all the rest (either by
        # striding or dilating depending on the stride_type
        # argument)
        residual_layers = [self._make_layer(64, layers[0], kernel_size)]
        it = zip(layers[1:], stride_type)
        for i, (num_blocks, stride) in enumerate(it):
            block_size = 64 * 2 ** (i + 1)
            layer = self._make_layer(
                block_size,
                num_blocks,
                kernel_size,
                stride=2,
                stride_type=stride,
            )
            residual_layers.append(layer)
        self.residual_layers = nn.ModuleList(residual_layers)

        # Average pool over each feature map to create a
        # single value for each feature map that we'll use
        # in the fully connected head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # use a fully connected layer to map from the
        # feature maps to the binary output that we need
        self.fc = nn.Linear(block_size * self.block.expansion, classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
        self,
        planes: int,
        blocks: int,
        kernel_size: int = 3,
        stride: int = 1,
        stride_type: Literal["stride", "dilation"] = "stride",
    ) -> nn.Sequential:
        block = self.block
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation

        if stride_type == "dilation":
            self.dilation *= stride
            stride = 1
        elif stride_type != "stride":
            raise ValueError(f"Unknown stride type {stride_type}")

        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                kernel_size,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    kernel_size,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for layer in self.residual_layers:
            x = layer(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


# TODO: implement as arg of ResNet instead?
class BottleneckResNet2D(ResNet2D):
    """A version of ResNet that uses bottleneck blocks"""

    block = Bottleneck
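The 2D variant mirrors the 1D API. A minimal sketch (not part of the diff; shapes and layer counts are illustrative), e.g. for spectrogram-like inputs:

import torch
from ml4gw.nn.resnet.resnet_2d import ResNet2D

model = ResNet2D(in_channels=2, layers=[2, 2, 2], classes=1)
x = torch.randn(4, 2, 64, 128)  # (batch, channels, freq, time)
print(model(x).shape)           # torch.Size([4, 1])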
ml4gw/transforms/__init__.py
CHANGED
@@ -2,5 +2,6 @@ from .pearson import ShiftedPearsonCorrelation
 from .scaler import ChannelWiseScaler
 from .snr_rescaler import SnrRescaler
 from .spectral import SpectralDensity
+from .spectrogram import MultiResolutionSpectrogram
 from .waveforms import WaveformProjector, WaveformSampler
 from .whitening import FixedWhiten, Whiten
ml4gw/transforms/spectrogram.py
ADDED
@@ -0,0 +1,162 @@
import warnings
from typing import Dict, List

import torch
import torch.nn.functional as F
from torchaudio.transforms import Spectrogram


class MultiResolutionSpectrogram(torch.nn.Module):
    """
    Create a batch of multi-resolution spectrograms
    from a batch of timeseries. Input is expected to
    have the shape `(B, C, T)`, where `B` is the number
    of batches, `C` is the number of channels, and `T`
    is the number of time samples.

    For each timeseries, calculate multiple normalized
    spectrograms based on the `Spectrogram` `kwargs` given.
    Combine the spectrograms by taking the maximum value
    from the nearest time-frequency bin.

    If the largest number of time bins among the spectrograms
    is `N` and the largest number of frequency bins is `M`,
    the output will have dimensions `(B, C, M, N)`

    Args:
        kernel_length:
            The length in seconds of the time dimension
            of the tensor that will be turned into a
            spectrogram
        sample_rate:
            The sample rate of the timeseries in Hz
        kwargs:
            Arguments passed in kwargs will be used to create
            `torchaudio.transforms.Spectrogram`s. Each
            argument should be a list of values. Any list
            of length greater than 1 should be the same
            length
    """

    def __init__(
        self, kernel_length: float, sample_rate: float, **kwargs
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_length * sample_rate
        # This method of combination makes sense only when
        # the spectrograms are normalized, so enforce this
        if "normalized" in kwargs.keys():
            if not all(kwargs["normalized"]):
                raise ValueError(
                    "Received a value of False for 'normalized'. "
                    "This method of combination is sensible only for "
                    "normalized spectrograms."
                )
        else:
            kwargs["normalized"] = [True]
        self.kwargs = self._check_and_format_kwargs(kwargs)

        self.transforms = torch.nn.ModuleList(
            [Spectrogram(**k) for k in self.kwargs]
        )

        dummy_input = torch.ones(int(kernel_length * sample_rate))
        self.shapes = torch.tensor(
            [t(dummy_input).shape for t in self.transforms]
        )

        self.num_freqs = max([shape[0] for shape in self.shapes])
        self.num_times = max([shape[1] for shape in self.shapes])

        left_pad = torch.zeros(len(self.transforms), dtype=torch.int)
        top_pad = torch.zeros(len(self.transforms), dtype=torch.int)
        bottom_pad = torch.tensor(
            [int(self.num_freqs - shape[0]) for shape in self.shapes]
        )
        right_pad = torch.tensor(
            [int(self.num_times - shape[1]) for shape in self.shapes]
        )
        self.register_buffer("left_pad", left_pad)
        self.register_buffer("top_pad", top_pad)
        self.register_buffer("bottom_pad", bottom_pad)
        self.register_buffer("right_pad", right_pad)

        freq_idxs = torch.tensor(
            [
                [int(i * shape[0] / self.num_freqs) for shape in self.shapes]
                for i in range(self.num_freqs)
            ]
        )
        freq_idxs = freq_idxs.repeat(self.num_times, 1, 1).transpose(0, 1)
        time_idxs = torch.tensor(
            [
                [int(i * shape[1] / self.num_times) for shape in self.shapes]
                for i in range(self.num_times)
            ]
        )
        time_idxs = time_idxs.repeat(self.num_freqs, 1, 1)

        self.register_buffer("freq_idxs", freq_idxs)
        self.register_buffer("time_idxs", time_idxs)

    def _check_and_format_kwargs(self, kwargs: Dict[str, List]) -> List:
        lengths = sorted(set([len(v) for v in kwargs.values()]))

        if lengths[-1] > 3:
            warnings.warn(
                "Combining too many spectrograms can impede computation time. "
                "If performance is slower than desired, try reducing the "
                "number of spectrograms",
                RuntimeWarning,
            )

        if len(lengths) > 2 or (len(lengths) == 2 and lengths[0] != 1):
            raise ValueError(
                "Spectrogram keyword args should all have the same "
                f"length or be of length one. Got lengths {lengths}"
            )

        if len(lengths) == 2:
            size = lengths[1]
            kwargs = {k: v * int(size / len(v)) for k, v in kwargs.items()}

        return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Calculate spectrograms of the input tensor and
        combine them into a single spectrogram

        Args:
            X:
                Batch of multichannel timeseries which will
                be used to calculate the multi-resolution
                spectrogram. Should have the shape
                `(B, C, T)`, where `B` is the number of
                batches, `C` is the number of channels,
                and `T` is the number of time samples.
        """
        if X.shape[-1] != self.kernel_size:
            raise ValueError(
                "Expected time dimension to be "
                f"{self.kernel_size} samples long, got input with "
                f"{X.shape[-1]} samples"
            )

        spectrograms = [t(X) for t in self.transforms]

        padded_specs = []
        for spec, left, right, top, bottom in zip(
            spectrograms,
            self.left_pad,
            self.right_pad,
            self.top_pad,
            self.bottom_pad,
        ):
            padded_specs.append(F.pad(spec, (left, right, top, bottom)))

        padded_specs = torch.stack(padded_specs)
        remapped_specs = padded_specs[..., self.freq_idxs, self.time_idxs]
        remapped_specs = torch.diagonal(remapped_specs, dim1=0, dim2=-1)

        return torch.max(remapped_specs, axis=-1)[0]
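A usage sketch for the new transform (not part of the diff; the STFT parameters are illustrative). Each keyword list spawns one torchaudio Spectrogram, and length-1 lists are broadcast against the longest list:

import torch
from ml4gw.transforms import MultiResolutionSpectrogram

spectrogram = MultiResolutionSpectrogram(
    kernel_length=2,        # seconds
    sample_rate=2048,       # Hz
    n_fft=[128, 256, 512],  # one Spectrogram per listed value
)
X = torch.randn(8, 2, 4096)  # (batch, channels, 2 s at 2048 Hz)
print(spectrogram(X).shape)  # (8, 2, max freq bins, max time bins)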
ml4gw/transforms/whitening.py
CHANGED
@@ -123,7 +123,7 @@ class FixedWhiten(FittableSpectralTransform):
         num_channels: float,
         kernel_length: float,
         sample_rate: float,
-        dtype: torch.dtype = torch.
+        dtype: torch.dtype = torch.float64,
     ) -> None:
         super().__init__()
         self.num_channels = num_channels
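The only change shown is the default dtype, now torch.float64 (the removed line is truncated in this diff view). A minimal instantiation sketch (argument values illustrative; the signature is taken from the hunk above):

from ml4gw.transforms import FixedWhiten

# dtype is left at its new default of torch.float64
whitener = FixedWhiten(num_channels=2, kernel_length=1, sample_rate=2048)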
ml4gw/waveforms/phenom_d.py
CHANGED
@@ -477,6 +477,7 @@ def rho3_fun(eta, eta2, xi):
 
 def FinalSpin0815(eta, eta2, chi1, chi2):
     Seta = torch.sqrt(1.0 - 4.0 * eta)
+    Seta = torch.nan_to_num(Seta)  # avoid nan around eta = 0.25
     m1 = 0.5 * (1.0 + Seta)
     m2 = 0.5 * (1.0 - Seta)
     m1s = m1 * m1
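Why the one-line guard matters: at the equal-mass limit eta = 0.25, Seta should be exactly zero, but floating-point round-off can push 1 - 4 * eta slightly negative, so the square root returns nan. A standalone illustration (not part of the package):

import torch

eta = torch.tensor(0.2500000000000001, dtype=torch.float64)
seta = torch.sqrt(1.0 - 4.0 * eta)
print(seta)                    # tensor(nan, dtype=torch.float64)
print(torch.nan_to_num(seta))  # tensor(0., dtype=torch.float64)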
{ml4gw-0.3.0.dist-info → ml4gw-0.4.1.dist-info}/METADATA
CHANGED
@@ -1,17 +1,17 @@
 Metadata-Version: 2.1
 Name: ml4gw
-Version: 0.
+Version: 0.4.1
 Summary: Tools for training torch models on gravitational wave data
 Author: Alec Gunny
 Author-email: alec.gunny@ligo.org
-Requires-Python: >=3.8,<
+Requires-Python: >=3.8,<3.12
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
-
-Requires-Dist:
+Requires-Dist: torch (>=2.0,<3.0)
+Requires-Dist: torchaudio (>=2.0,<3.0)
 Requires-Dist: torchtyping (>=0.1,<0.2)
 Description-Content-Type: text/markdown
 
@@ -38,8 +38,8 @@ pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/w
 
 ```toml
 [tool.poetry.dependencies]
-python = "^3.8" # python versions 3.8-3.
-ml4gw = "^0.
+python = "^3.8" # python versions 3.8-3.11 are supported
+ml4gw = "^0.3.0"
 ```
 
 To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.6, you would do something like:
@@ -47,7 +47,7 @@ To build against a specific PyTorch/CUDA combination, consult the PyTorch instal
 ```toml
 [tool.poetry.dependencies]
 python = "^3.8"
-ml4gw = "^0.
+ml4gw = "^0.3.0"
 torch = {version = "^1.12", source = "torch"}
 
 [[tool.poetry.source]]
@@ -57,6 +57,8 @@ secondary = true
 default = false
 ```
 
+Note: if you are building against CUDA 11.6 or 11.7, make sure that you are using python 3.8, 3.9, or 3.10. Python 3.11 is incompatible with `torchaudio` 0.13, and the following `torchaudio` version is incompatible with CUDA 11.7 and earlier.
+
 ## Use cases
 This library provided utilities for both data iteration and transformation via dataloaders defined in `ml4gw/dataloading` and transform layers exposed in `ml4gw/transforms`. Lower level functions and utilies are defined at the top level of the library and in the `utils` library.
 
@@ -146,3 +148,6 @@ We also strongly encourage ML users in the GW physics space to try their hand at
 For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu) .
 By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool which makes DL more accessible for gravitational wave physicists everywhere.
 
+## Funding
+We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
+
{ml4gw-0.3.0.dist-info → ml4gw-0.4.1.dist-info}/RECORD
CHANGED
@@ -12,27 +12,32 @@ ml4gw/nn/autoencoder/base.py,sha256=PLr26Cn5DHmgDYX1qj4idfrLehHVeiJqer065ea8_QM,
 ml4gw/nn/autoencoder/convolutional.py,sha256=JTMpTJVdFju9HPPAh9UDdXG1MsFbADrqUIKM8_xg74E,5316
 ml4gw/nn/autoencoder/skip_connection.py,sha256=bOKBLzMqZDh9w8s9G5U93LCESjTSFUHzQGo0hLDOeSk,1304
 ml4gw/nn/autoencoder/utils.py,sha256=whTnWPvdKuVDlxg52azJeM1d9YjiYFWoqIOzJVDGups,326
+ml4gw/nn/norm.py,sha256=9IHZTCCp4zgP7EaGpw1FpAm7o0EU5zu-LYFHKfuLzzw,3250
+ml4gw/nn/resnet/__init__.py,sha256=vBI0IftVP_EYAeDlqomtkGqUYE-RE_S4WNioUhniw9s,64
+ml4gw/nn/resnet/resnet_1d.py,sha256=IQ-EIIzAXd-NWuLwt7JTXLWg5bO3FGJpuFAZwZ78jaI,13218
+ml4gw/nn/resnet/resnet_2d.py,sha256=aK4I0FOZk62JxnYFz0t1O0s5s7J7yRNYSM1flRypvVc,13301
 ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
 ml4gw/nn/streaming/online_average.py,sha256=T-wWw7eEufbUVPRNnLAXIq0cedAyJWEE9tdZ6CTi3cs,4561
 ml4gw/nn/streaming/snapshotter.py,sha256=-l_YsWby7ZnEzGIAlLAV2mtR0daLMtLCxovtt4OI3Z0,4432
 ml4gw/spectral.py,sha256=5GfKAV_1vw5yyzTD2u_myjT5jIlAyAHDX6TXj9ynL_o,19021
-ml4gw/transforms/__init__.py,sha256=
+ml4gw/transforms/__init__.py,sha256=t6ZJcq23apqDKhLGM-U5l_bqxJcXFj3riY6cTGY47Gc,314
 ml4gw/transforms/pearson.py,sha256=bJ77lO4wBY6y1R1aESN_bcUEMbc55hWCIaCBdbIj4CY,3133
 ml4gw/transforms/scaler.py,sha256=5VGov0M80NZostRzccViC3HNftx4ZVu0kOKTDmiLrR4,2327
 ml4gw/transforms/snr_rescaler.py,sha256=ocYr6UjpHW7t5TvruV7fyY8KuuDfGOJyvxEulmiFA6o,2275
 ml4gw/transforms/spectral.py,sha256=Vba9199z_ZaxsHWxdpgHB3U216rmGoSyehtvM3R9Z7A,3771
+ml4gw/transforms/spectrogram.py,sha256=R3O8eUB6NHdBFx89v8e_WdJIvXl4qwVeGWZnPyLhHHQ,6024
 ml4gw/transforms/transform.py,sha256=jEr9OFj4u7Wjeh_rpRq90jMpK_TfzcIelbBmt30DxQU,2408
 ml4gw/transforms/waveforms.py,sha256=iyEDSRqK_1zZrxxJenJFbwGUWqbE-alVTXhvjaGl1ww,3060
-ml4gw/transforms/whitening.py,sha256=
+ml4gw/transforms/whitening.py,sha256=TmvFCCeTOcSEWo5Pt_JQRJ23X5byiJ91q5jHgBRy0rc,9428
 ml4gw/types.py,sha256=XbxunX8zRF95Fp1mZ9jEbixb63bwDQMoayRMMxT9Lzo,429
 ml4gw/utils/interferometer.py,sha256=w_0WkboCJZMKAg-4lhiNGOOkNogAghpT96I0TE5aJ1g,1519
 ml4gw/utils/slicing.py,sha256=Cbwcpk_0hsfN4zczFVM2YbDRjeirA7jFvApM4Jy0U8s,13535
 ml4gw/waveforms/__init__.py,sha256=zjqOKNY4z1A5iPhWTxyhnkLh2robB-obPTtaK-pDUoU,104
 ml4gw/waveforms/generator.py,sha256=4Z6vUEuI84t__3t0DDnXlOyB8R96ynf8xFvtwCGu9JA,1057
-ml4gw/waveforms/phenom_d.py,sha256=
+ml4gw/waveforms/phenom_d.py,sha256=pxHk7paW5709Ak29m_DYeQ8kiMLC8wrUnM13flUU36o,38419
 ml4gw/waveforms/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9n25he9U,53447
 ml4gw/waveforms/sine_gaussian.py,sha256=WZ6KiVEFSjB9Tv5otJbvI_Yr3341th1Noec_LB9kPOE,3577
 ml4gw/waveforms/taylorf2.py,sha256=x3drvKUMarWI9xHUzMRQhVp1Hh7X-j5WC2bdsbEiVfk,8482
-ml4gw-0.
-ml4gw-0.
-ml4gw-0.
+ml4gw-0.4.1.dist-info/METADATA,sha256=KI3VTUKW8-DASUnqtZpFeviTVoWwqq0I75tdElQMWBo,5706
+ml4gw-0.4.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ml4gw-0.4.1.dist-info/RECORD,,