ml4gw 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- ml4gw/augmentations.py +43 -0
- ml4gw/dataloading/__init__.py +2 -1
- ml4gw/dataloading/chunked_dataset.py +66 -212
- ml4gw/dataloading/hdf5_dataset.py +176 -0
- ml4gw/nn/__init__.py +0 -0
- ml4gw/nn/autoencoder/__init__.py +3 -0
- ml4gw/nn/autoencoder/base.py +89 -0
- ml4gw/nn/autoencoder/convolutional.py +156 -0
- ml4gw/nn/autoencoder/skip_connection.py +46 -0
- ml4gw/nn/autoencoder/utils.py +14 -0
- ml4gw/nn/norm.py +97 -0
- ml4gw/nn/resnet/__init__.py +2 -0
- ml4gw/nn/resnet/resnet_1d.py +413 -0
- ml4gw/nn/resnet/resnet_2d.py +413 -0
- ml4gw/nn/streaming/__init__.py +2 -0
- ml4gw/nn/streaming/online_average.py +121 -0
- ml4gw/nn/streaming/snapshotter.py +121 -0
- ml4gw/transforms/__init__.py +2 -0
- ml4gw/transforms/pearson.py +87 -0
- ml4gw/transforms/spectrogram.py +162 -0
- ml4gw/transforms/whitening.py +1 -1
- ml4gw/waveforms/__init__.py +2 -0
- ml4gw/waveforms/phenom_d.py +1359 -0
- ml4gw/waveforms/phenom_d_data.py +3026 -0
- ml4gw/waveforms/taylorf2.py +306 -0
- {ml4gw-0.2.0.dist-info → ml4gw-0.4.0.dist-info}/METADATA +14 -6
- ml4gw-0.4.0.dist-info/RECORD +43 -0
- {ml4gw-0.2.0.dist-info → ml4gw-0.4.0.dist-info}/WHEEL +1 -1
- ml4gw-0.2.0.dist-info/RECORD +0 -23
ml4gw/nn/resnet/resnet_2d.py
ADDED
@@ -0,0 +1,413 @@
"""
In large part lifted from
https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
but with arbitrary kernel sizes
"""

from typing import Callable, List, Literal, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ml4gw.nn.norm import GroupNorm2DGetter, NormLayer


def convN(
    in_planes: int,
    out_planes: int,
    kernel_size: int = 3,
    stride: int = 1,
    groups: int = 1,
    dilation: int = 1,
) -> nn.Conv2d:
    """2d convolution with padding"""
    if not kernel_size % 2:
        raise ValueError("Can't use even sized kernels")

    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=dilation * int(kernel_size // 2),
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Kernel-size 1 convolution"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
    )


class BasicBlock(nn.Module):
    """Defines the structure of the blocks used to build the ResNet"""

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                "BasicBlock only supports groups=1 and base_width=64"
            )
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )

        # Both self.conv1 and self.downsample layers
        # downsample the input when stride != 1
        self.conv1 = convN(inplanes, planes, kernel_size, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = convN(planes, planes, kernel_size)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    Bottleneck blocks implement one extra convolution
    compared to basic blocks. In these layers, the `planes`
    parameter is generally meant to _downsize_ the number
    of feature maps first, which then get expanded out to
    `planes * Bottleneck.expansion` feature maps at the
    output of the layer.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        width = int(planes * (base_width / 64.0)) * groups

        # conv1 does no downsampling, just reduces the number of
        # feature maps from inplanes to width (where width == planes
        # if groups == 1 and base_width == 64)
        self.conv1 = convN(inplanes, width, kernel_size)
        self.bn1 = norm_layer(width)

        # conv2 keeps the same number of feature maps,
        # but downsamples along the time axis if stride
        # or dilation > 1
        self.conv2 = convN(width, width, kernel_size, stride, groups, dilation)
        self.bn2 = norm_layer(width)

        # conv3 expands the feature maps back out to planes * expansion
        self.conv3 = conv1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet2D(nn.Module):
    """2D ResNet architecture

    Simple extension of ResNet with arbitrary kernel sizes
    to support the longer timeseries used in BBH detection.

    Args:
        in_channels:
            The number of channels in the input tensor.
        layers:
            A list representing the number of residual
            blocks to include in each "layer" of the
            network. Total layers (e.g. 50 in ResNet50)
            is `2 + sum(layers) * factor`, where factor
            is `2` for vanilla `ResNet` and `3` for
            `BottleneckResNet`.
        kernel_size:
            The size of the convolutional kernel to
            use in all residual layers. _NOT_ the size
            of the input kernel to the network, which
            is determined at run-time.
        zero_init_residual:
            Flag indicating whether to initialize the
            weights of the batch-norm layer in each block
            to 0 so that residuals are initialized as
            identities. Can improve training results.
        groups:
            Number of convolutional groups to use in all
            layers. Grouped convolutions induce local
            connections between feature maps at subsequent
            layers rather than global. Generally won't
            need this to be >1, and will raise an error if
            >1 when using vanilla `ResNet`.
        width_per_group:
            Base width of each of the feature map groups,
            which is scaled up by the typical expansion
            factor at each layer of the network. Meaningless
            for vanilla `ResNet`.
        stride_type:
            Whether to achieve downsampling on the time axis
            by strided or dilated convolutions for each layer.
            If left as `None`, strided convolutions will be
            used at each layer. Otherwise, `stride_type` should
            be one element shorter than `layers` and indicate either
            `stride` or `dilation` for each layer after the first.
        norm_groups:
            The number of groups to use in GroupNorm layers
            throughout the model. If left as `-1`, the number
            of groups will be equal to the number of channels,
            making this equivalent to LayerNorm
    """

    block = BasicBlock

    def __init__(
        self,
        in_channels: int,
        layers: List[int],
        classes: int,
        kernel_size: int = 3,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
        norm_layer: Optional[NormLayer] = None,
    ) -> None:
        super().__init__()
        # default to using InstanceNorm if no
        # norm layer is provided explicitly
        self._norm_layer = norm_layer or GroupNorm2DGetter()

        self.inplanes = 64
        self.dilation = 1

        # TODO: should we support passing a single string
        # for simplicity here?
        if stride_type is None:
            # each element in the tuple indicates if we should replace
            # the stride with a dilated convolution instead
            stride_type = ["stride"] * (len(layers) - 1)
        if len(stride_type) != (len(layers) - 1):
            raise ValueError(
                "'stride_type' should be None or a "
                "{}-element tuple, got {}".format(len(layers) - 1, stride_type)
            )

        self.groups = groups
        self.base_width = width_per_group

        # start with a basic conv-bn-relu-maxpool block
        # to reduce the dimensionality before the heavy
        # lifting starts
        self.conv1 = nn.Conv2d(
            in_channels,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # now create layers of residual blocks where each
        # layer uses the same number of feature maps for
        # all its blocks (some power of 2 times 64).
        # Don't downsample along the time axis in the first
        # layer, but downsample in all the rest (either by
        # striding or dilating depending on the stride_type
        # argument)
        residual_layers = [self._make_layer(64, layers[0], kernel_size)]
        it = zip(layers[1:], stride_type)
        for i, (num_blocks, stride) in enumerate(it):
            block_size = 64 * 2 ** (i + 1)
            layer = self._make_layer(
                block_size,
                num_blocks,
                kernel_size,
                stride=2,
                stride_type=stride,
            )
            residual_layers.append(layer)
        self.residual_layers = nn.ModuleList(residual_layers)

        # Average pool over each feature map to create a
        # single value for each feature map that we'll use
        # in the fully connected head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # use a fully connected layer to map from the
        # feature maps to the binary output that we need
        self.fc = nn.Linear(block_size * self.block.expansion, classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
        self,
        planes: int,
        blocks: int,
        kernel_size: int = 3,
        stride: int = 1,
        stride_type: Literal["stride", "dilation"] = "stride",
    ) -> nn.Sequential:
        block = self.block
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation

        if stride_type == "dilation":
            self.dilation *= stride
            stride = 1
        elif stride_type != "stride":
            raise ValueError(f"Unknown stride type {stride_type}")

        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                kernel_size,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    kernel_size,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for layer in self.residual_layers:
            x = layer(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


# TODO: implement as arg of ResNet instead?
class BottleneckResNet2D(ResNet2D):
    """A version of ResNet that uses bottleneck blocks"""

    block = Bottleneck
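Below is a minimal usage sketch (not part of the diff) showing how the new ResNet2D module might be driven. The module path and constructor arguments come from the hunk above; the input shape and layer configuration are illustrative assumptions, and the expected output shape is inferred from the forward pass rather than taken from package documentation.

import torch

from ml4gw.nn.resnet.resnet_2d import ResNet2D

# Illustrative spectrogram-like input: batch of 8, 2 interferometer
# channels, 64 frequency bins, 128 time bins (shape is an assumption,
# not a value prescribed by the library)
x = torch.randn(8, 2, 64, 128)

# two residual layers with two BasicBlocks each, single output logit
model = ResNet2D(in_channels=2, layers=[2, 2], classes=1, kernel_size=3)

with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([8, 1])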
ml4gw/nn/streaming/online_average.py
ADDED
@@ -0,0 +1,121 @@
from typing import Optional, Tuple

import torch

from ml4gw.utils.slicing import unfold_windows

Tensor = torch.Tensor


class OnlineAverager(torch.nn.Module):
    """
    Module for performing stateful online averaging of
    batches of overlapping timeseries. At present, the
    first `num_updates` predictions produced by this
    model will underestimate the true average.

    Args:
        update_size:
            The number of samples separating the timestamps
            of subsequent inputs.
        batch_size:
            The number of batched inputs to expect at inference
            time.
        num_updates:
            The number of steps over which to average predictions
            before returning them.
        num_channels:
            The expected channel dimension of the input passed
            to the module at inference time.
        offset:
            Number of samples to throw away from the front
            edge of the kernel when averaging.
    """

    def __init__(
        self,
        update_size: int,
        batch_size: int,
        num_updates: int,
        num_channels: int,
        offset: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.update_size = update_size
        self.num_updates = num_updates
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.offset = offset

        # build a blank tensor into which we will embed
        # the updated snapshot predictions at the
        # appropriate time offset for in-batch averaging
        self.batch_update_size = int(batch_size * update_size)
        self.state_size = int((num_updates - 1) * update_size)
        blank_size = self.batch_update_size + self.state_size
        blank = torch.zeros((batch_size, num_channels, blank_size))
        self.register_buffer("blank", blank)

        # set up the indices at which the updated snapshots
        # will be embedded into the blank tensor
        idx = torch.arange(num_updates * update_size)
        idx = torch.stack([idx + i * update_size for i in range(batch_size)])
        idx = idx.view(batch_size, 1, -1).repeat(1, num_channels, 1)
        self.register_buffer("idx", idx)

        # normalization indices used to downweight the
        # existing average at each in-batch aggregation
        weights = torch.scatter(blank, -1, idx, 1).sum(0)
        weight_size = int(num_updates * update_size)
        weights = unfold_windows(weights, weight_size, update_size)
        self.register_buffer("weights", weights)

    def get_initial_state(self):
        return torch.zeros((self.num_channels, self.state_size))

    def forward(
        self, update: torch.Tensor, state: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if state is None:
            state = self.get_initial_state()

        # slice off the steps from this update closest
        # to the future that we'll actually use. Divide
        # these values by the number of updates up-front
        # for averaging purposes.
        start = -self.num_updates * self.update_size
        if self.offset is not None:
            end = -self.offset
            start += end
        else:
            end = None
        x = update[:, :, start:end] / self.num_updates

        # append zeros to the state into which we
        # can insert our updates
        state = torch.nn.functional.pad(state, (0, self.batch_update_size))

        # window the existing snapshot into overlapping
        # segments and average them with our new updates
        windowed = unfold_windows(state, x.size(-1), self.update_size)
        windowed /= self.weights
        windowed += x

        # embed these windowed averages into a blank
        # array with offsets so that we can add the
        # overlapping bits
        padded = torch.scatter(self.blank, -1, self.idx, windowed)
        new_state = padded.sum(axis=0)

        if self.num_updates == 1:
            # if we don't need stateful behavior,
            # just return the "snapshot" as-is
            output, new_state = new_state, self.get_initial_state()
        else:
            # otherwise split off the values that have finished
            # averaging and are being returned from the ones that
            # will comprise the snapshot at the next update
            splits = [self.batch_size, self.num_updates - 1]
            splits = [i * self.update_size for i in splits]
            output, new_state = torch.split(new_state, splits, dim=-1)
        return output, new_state
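A sketch (not part of the diff) of how the new OnlineAverager might be exercised. The sizes are made up, and the expected output shapes are inferred from the forward logic above, so treat them as assumptions rather than documented behavior.

import torch

from ml4gw.nn.streaming.online_average import OnlineAverager

averager = OnlineAverager(
    update_size=8, batch_size=4, num_updates=2, num_channels=1
)

# a batch of 4 overlapping network outputs, 1 channel, 32 samples each;
# the length just needs to cover num_updates * update_size samples
update = torch.randn(4, 1, 32)

state = averager.get_initial_state()
output, state = averager(update, state)

# inferred from the code above: the output spans batch_size * update_size
# samples, the carried state spans (num_updates - 1) * update_size
print(output.shape)  # expected: torch.Size([1, 32])
print(state.shape)   # expected: torch.Size([1, 8])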
ml4gw/nn/streaming/snapshotter.py
ADDED
@@ -0,0 +1,121 @@
from typing import Optional, Sequence, Tuple

import torch

from ml4gw.utils.slicing import unfold_windows


class Snapshotter(torch.nn.Module):
    """
    Model for converting streaming state updates into
    a batch of overlapping snapshots of a multichannel
    timeseries. Can support multiple timeseries in a
    single state update via the `channels_per_snapshot`
    kwarg.

    Specifically, maps tensors of shape
    `(num_channels, batch_size * stride_size)` to a tensor
    of shape `(batch_size, num_channels, snapshot_size)`.
    If `channels_per_snapshot` is specified, it will return
    `len(channels_per_snapshot)` tensors of this shape,
    with the channel dimension replaced by the corresponding
    value of `channels_per_snapshot`. The last tensor returned
    at call time will be the current state that can be passed
    to the next `forward` call.

    Args:
        num_channels:
            Number of channels in the timeseries. If
            `channels_per_snapshot` is not `None`,
            this should be equal to `sum(channels_per_snapshot)`.
        snapshot_size:
            The size of the output snapshot windows in
            number of samples
        stride_size:
            The number of samples in between each output
            snapshot
        batch_size:
            The number of snapshots to produce at each
            update. The last dimension of the input
            tensor should have size `batch_size * stride_size`.
        channels_per_snapshot:
            How to split up the channels in the timeseries
            for different tensors. If left as `None`, all
            the channels will be returned in a single tensor.
            Otherwise, the channels will be split up into
            `len(channels_per_snapshot)` tensors, with each
            tensor's channel dimension being equal to the
            corresponding value in `channels_per_snapshot`.
            Therefore, if specified, these values should
            add up to `num_channels`.
    """

    def __init__(
        self,
        num_channels: int,
        snapshot_size: int,
        stride_size: int,
        batch_size: int,
        channels_per_snapshot: Optional[Sequence[int]] = None,
    ) -> None:
        super().__init__()
        if stride_size >= snapshot_size:
            raise ValueError(
                "Snapshotter can't accommodate stride {} "
                "which is greater than snapshot size {}".format(
                    stride_size, snapshot_size
                )
            )

        self.snapshot_size = snapshot_size
        self.stride_size = stride_size
        self.state_size = snapshot_size - stride_size
        self.batch_size = batch_size

        if channels_per_snapshot is not None:
            if sum(channels_per_snapshot) != num_channels:
                raise ValueError(
                    "Can't break {} channels into {}".format(
                        num_channels, channels_per_snapshot
                    )
                )
        self.channels_per_snapshot = channels_per_snapshot
        self.num_channels = num_channels

    def get_initial_state(self):
        return torch.zeros((self.num_channels, self.state_size))

    # TODO: use torchtyping annotations to make
    # clear what the expected shapes are
    def forward(
        self, update: torch.Tensor, snapshot: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, ...]:
        if snapshot is None:
            snapshot = self.get_initial_state()

        # append new data to the snapshot
        snapshot = torch.cat([snapshot, update], axis=-1)

        if self.batch_size > 1:
            snapshots = unfold_windows(
                snapshot, self.snapshot_size, self.stride_size
            )
        else:
            snapshots = snapshot[None]

        if self.channels_per_snapshot is not None:
            if snapshots.size(1) != self.num_channels:
                raise ValueError(
                    "Expected {} channels, found {}".format(
                        self.num_channels, snapshots.size(1)
                    )
                )
            snapshots = torch.split(
                snapshots, self.channels_per_snapshot, dim=1
            )
        else:
            snapshots = (snapshots,)

        # keep only the latest snapshot as our state
        snapshot = snapshot[:, -self.state_size :]
        return tuple(snapshots) + (snapshot,)
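A hedged usage sketch (not part of the diff) for the new Snapshotter. The sizes are illustrative; the expected window shape follows the docstring's `(batch_size, num_channels, snapshot_size)` contract, and the state shape is inferred from `state_size = snapshot_size - stride_size`.

import torch

from ml4gw.nn.streaming.snapshotter import Snapshotter

snapshotter = Snapshotter(
    num_channels=2, snapshot_size=1024, stride_size=256, batch_size=4
)

# each update carries batch_size * stride_size new samples per channel
update = torch.randn(2, 4 * 256)

state = snapshotter.get_initial_state()
windows, state = snapshotter(update, state)

print(windows.shape)  # expected: torch.Size([4, 2, 1024])
print(state.shape)    # expected: torch.Size([2, 768]), carried to the next call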
ml4gw/transforms/__init__.py
CHANGED
@@ -1,5 +1,7 @@
+from .pearson import ShiftedPearsonCorrelation
 from .scaler import ChannelWiseScaler
 from .snr_rescaler import SnrRescaler
 from .spectral import SpectralDensity
+from .spectrogram import MultiResolutionSpectrogram
 from .waveforms import WaveformProjector, WaveformSampler
 from .whitening import FixedWhiten, Whiten