ml4gw 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ml4gw might be problematic.

@@ -0,0 +1,413 @@
+ """
+ In large part lifted from
+ https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
+ but with arbitrary kernel sizes
+ """
+
+ from typing import Callable, List, Literal, Optional
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+ from ml4gw.nn.norm import GroupNorm2DGetter, NormLayer
+
+
+ def convN(
+     in_planes: int,
+     out_planes: int,
+     kernel_size: int = 3,
+     stride: int = 1,
+     groups: int = 1,
+     dilation: int = 1,
+ ) -> nn.Conv2d:
+     """2d convolution with padding"""
+     if not kernel_size % 2:
+         raise ValueError("Can't use even sized kernels")
+
+     return nn.Conv2d(
+         in_planes,
+         out_planes,
+         kernel_size=kernel_size,
+         stride=stride,
+         padding=dilation * int(kernel_size // 2),
+         groups=groups,
+         bias=False,
+         dilation=dilation,
+     )
+
+
+ def conv1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+     """Kernel-size 1 convolution"""
+     return nn.Conv2d(
+         in_planes, out_planes, kernel_size=1, stride=stride, bias=False
+     )
+
+
+ class BasicBlock(nn.Module):
+     """Defines the structure of the blocks used to build the ResNet"""
+
+     expansion: int = 1
+
+     def __init__(
+         self,
+         inplanes: int,
+         planes: int,
+         kernel_size: int = 3,
+         stride: int = 1,
+         downsample: Optional[nn.Module] = None,
+         groups: int = 1,
+         base_width: int = 64,
+         dilation: int = 1,
+         norm_layer: Optional[Callable[..., nn.Module]] = None,
+     ) -> None:
+
+         super().__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         if groups != 1 or base_width != 64:
+             raise ValueError(
+                 "BasicBlock only supports groups=1 and base_width=64"
+             )
+         if dilation > 1:
+             raise NotImplementedError(
+                 "Dilation > 1 not supported in BasicBlock"
+             )
+
+         # Both self.conv1 and self.downsample layers
+         # downsample the input when stride != 1
+         self.conv1 = convN(inplanes, planes, kernel_size, stride)
+         self.bn1 = norm_layer(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = convN(planes, planes, kernel_size)
+         self.bn2 = norm_layer(planes)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x: Tensor) -> Tensor:
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     """
+     Bottleneck blocks implement one extra convolution
+     compared to basic blocks. In these layers, the `planes`
+     parameter is generally meant to _downsize_ the number
+     of feature maps first, which then get expanded out to
+     `planes * Bottleneck.expansion` feature maps at the
+     output of the layer.
+     """
+
+     expansion: int = 4
+
+     def __init__(
+         self,
+         inplanes: int,
+         planes: int,
+         kernel_size: int = 3,
+         stride: int = 1,
+         downsample: Optional[nn.Module] = None,
+         groups: int = 1,
+         base_width: int = 64,
+         dilation: int = 1,
+         norm_layer: Optional[Callable[..., nn.Module]] = None,
+     ) -> None:
+         super().__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+
+         width = int(planes * (base_width / 64.0)) * groups
+
+         # conv1 does no downsampling, just reduces the number of
+         # feature maps from inplanes to width (where width == planes
+         # if groups == 1 and base_width == 64)
+         self.conv1 = convN(inplanes, width, kernel_size)
+         self.bn1 = norm_layer(width)
+
+         # conv2 keeps the same number of feature maps,
+         # but downsamples along the time axis if stride
+         # or dilation > 1
+         self.conv2 = convN(width, width, kernel_size, stride, groups, dilation)
+         self.bn2 = norm_layer(width)
+
+         # conv3 expands the feature maps back out to planes * expansion
+         self.conv3 = conv1(width, planes * self.expansion)
+         self.bn3 = norm_layer(planes * self.expansion)
+
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x: Tensor) -> Tensor:
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class ResNet2D(nn.Module):
+     """2D ResNet architecture
+
+     Simple extension of ResNet with arbitrary kernel sizes
+     to support the longer timeseries used in BBH detection.
+
+     Args:
+         in_channels:
+             The number of channels in the input tensor.
+         layers:
+             A list representing the number of residual
+             blocks to include in each "layer" of the
+             network. The total number of layers (e.g. 50
+             in ResNet50) is `2 + sum(layers) * factor`,
+             where factor is `2` for vanilla `ResNet` and
+             `3` for `BottleneckResNet`.
+         classes:
+             The number of output classes, i.e. the size of
+             the output of the final fully connected layer.
+         kernel_size:
+             The size of the convolutional kernel to
+             use in all residual layers. _NOT_ the size
+             of the input kernel to the network, which
+             is determined at run-time.
+         zero_init_residual:
+             Flag indicating whether to initialize the
+             weights of the batch-norm layer in each block
+             to 0 so that residuals are initialized as
+             identities. Can improve training results.
+         groups:
+             Number of convolutional groups to use in all
+             layers. Grouped convolutions induce local
+             connections between feature maps at subsequent
+             layers rather than global. Generally won't
+             need this to be >1, and will raise an error if
+             >1 when using vanilla `ResNet`.
+         width_per_group:
+             Base width of each of the feature map groups,
+             which is scaled up by the typical expansion
+             factor at each layer of the network. Meaningless
+             for vanilla `ResNet`.
+         stride_type:
+             Whether to achieve downsampling on the time axis
+             by strided or dilated convolutions for each layer.
+             If left as `None`, strided convolutions will be
+             used at each layer. Otherwise, `stride_type` should
+             be one element shorter than `layers` and indicate
+             either `stride` or `dilation` for each layer after
+             the first.
+         norm_layer:
+             A callable that returns the normalization layer
+             to use throughout the model. If left as `None`,
+             GroupNorm layers with one group per channel
+             (equivalent to instance normalization) will be used.
+     """
+
+     block = BasicBlock
+
+     def __init__(
+         self,
+         in_channels: int,
+         layers: List[int],
+         classes: int,
+         kernel_size: int = 3,
+         zero_init_residual: bool = False,
+         groups: int = 1,
+         width_per_group: int = 64,
+         stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
+         norm_layer: Optional[NormLayer] = None,
+     ) -> None:
+         super().__init__()
+         # default to using InstanceNorm if no
+         # norm layer is provided explicitly
+         self._norm_layer = norm_layer or GroupNorm2DGetter()
+
+         self.inplanes = 64
+         self.dilation = 1
+
+         # TODO: should we support passing a single string
+         # for simplicity here?
+         if stride_type is None:
+             # each element in the tuple indicates if we should replace
+             # the stride with a dilated convolution instead
+             stride_type = ["stride"] * (len(layers) - 1)
+         if len(stride_type) != (len(layers) - 1):
+             raise ValueError(
+                 "'stride_type' should be None or a "
+                 "{}-element tuple, got {}".format(len(layers) - 1, stride_type)
+             )
+
+         self.groups = groups
+         self.base_width = width_per_group
+
+         # start with a basic conv-bn-relu-maxpool block
+         # to reduce the dimensionality before the heavy
+         # lifting starts
+         self.conv1 = nn.Conv2d(
+             in_channels,
+             self.inplanes,
+             kernel_size=7,
+             stride=2,
+             padding=3,
+             bias=False,
+         )
+         self.bn1 = self._norm_layer(self.inplanes)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+         # now create layers of residual blocks where each
+         # layer uses the same number of feature maps for
+         # all its blocks (some power of 2 times 64).
+         # Don't downsample along the time axis in the first
+         # layer, but downsample in all the rest (either by
+         # striding or dilating depending on the stride_type
+         # argument)
+         block_size = 64  # width of the final layer, updated in the loop below
+         residual_layers = [self._make_layer(64, layers[0], kernel_size)]
+         it = zip(layers[1:], stride_type)
+         for i, (num_blocks, stride) in enumerate(it):
+             block_size = 64 * 2 ** (i + 1)
+             layer = self._make_layer(
+                 block_size,
+                 num_blocks,
+                 kernel_size,
+                 stride=2,
+                 stride_type=stride,
+             )
+             residual_layers.append(layer)
+         self.residual_layers = nn.ModuleList(residual_layers)
+
+         # Average pool over each feature map to create a
+         # single value for each feature map that we'll use
+         # in the fully connected head
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+         # use a fully connected layer to map from the
+         # feature maps to the output classes that we need
+         self.fc = nn.Linear(block_size * self.block.expansion, classes)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(
+                     m.weight, mode="fan_out", nonlinearity="relu"
+                 )
+             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+         # Zero-initialize the last BN in each residual branch,
+         # so that the residual branch starts with zeros,
+         # and each residual block behaves like an identity.
+         # This improves the model by 0.2~0.3% according to
+         # https://arxiv.org/abs/1706.02677
+         if zero_init_residual:
+             for m in self.modules():
+                 if isinstance(m, Bottleneck):
+                     nn.init.constant_(m.bn3.weight, 0)
+                 elif isinstance(m, BasicBlock):
+                     nn.init.constant_(m.bn2.weight, 0)
+
+     def _make_layer(
+         self,
+         planes: int,
+         blocks: int,
+         kernel_size: int = 3,
+         stride: int = 1,
+         stride_type: Literal["stride", "dilation"] = "stride",
+     ) -> nn.Sequential:
+         block = self.block
+         norm_layer = self._norm_layer
+         downsample = None
+         previous_dilation = self.dilation
+
+         if stride_type == "dilation":
+             self.dilation *= stride
+             stride = 1
+         elif stride_type != "stride":
+             raise ValueError(f"Unknown stride type {stride_type}")
+
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1(self.inplanes, planes * block.expansion, stride),
+                 norm_layer(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(
+             block(
+                 self.inplanes,
+                 planes,
+                 kernel_size,
+                 stride,
+                 downsample,
+                 self.groups,
+                 self.base_width,
+                 previous_dilation,
+                 norm_layer,
+             )
+         )
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(
+                 block(
+                     self.inplanes,
+                     planes,
+                     kernel_size,
+                     groups=self.groups,
+                     base_width=self.base_width,
+                     dilation=self.dilation,
+                     norm_layer=norm_layer,
+                 )
+             )
+
+         return nn.Sequential(*layers)
+
+     def _forward_impl(self, x: Tensor) -> Tensor:
+         # See note [TorchScript super()]
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         for layer in self.residual_layers:
+             x = layer(x)
+
+         x = self.avgpool(x)
+         x = torch.flatten(x, 1)
+         x = self.fc(x)
+
+         return x
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self._forward_impl(x)
+
+
+ # TODO: implement as arg of ResNet instead?
+ class BottleneckResNet2D(ResNet2D):
+     """A version of ResNet that uses bottleneck blocks"""
+
+     block = Bottleneck
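
As a quick orientation for the new 2D ResNet, here is a minimal usage sketch. The import path and all argument values are illustrative assumptions; only the `ResNet2D` signature and forward pass above come from the diff.

import torch

from ml4gw.nn.resnet import ResNet2D  # import path assumed, not documented here

# three layer-groups of two BasicBlocks each, feeding a single-logit head
model = ResNet2D(in_channels=2, layers=[2, 2, 2], classes=1, kernel_size=3)

# batch of 8 two-channel, spectrogram-like inputs
x = torch.randn(8, 2, 64, 128)
y = model(x)
print(y.shape)  # torch.Size([8, 1])
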
@@ -0,0 +1,2 @@
+ from .online_average import OnlineAverager
+ from .snapshotter import Snapshotter
@@ -0,0 +1,121 @@
+ from typing import Optional, Tuple
+
+ import torch
+
+ from ml4gw.utils.slicing import unfold_windows
+
+ Tensor = torch.Tensor
+
+
+ class OnlineAverager(torch.nn.Module):
+     """
+     Module for performing stateful online averaging of
+     batches of overlapping timeseries. At present, the
+     first `num_updates` predictions produced by this
+     model will underestimate the true average.
+
+     Args:
+         update_size:
+             The number of samples separating the timestamps
+             of subsequent inputs.
+         batch_size:
+             The number of batched inputs to expect at inference
+             time.
+         num_updates:
+             The number of steps over which to average predictions
+             before returning them.
+         num_channels:
+             The expected channel dimension of the input passed
+             to the module at inference time.
+         offset:
+             Number of samples to throw away from the front
+             edge of the kernel when averaging.
+     """
+
+     def __init__(
+         self,
+         update_size: int,
+         batch_size: int,
+         num_updates: int,
+         num_channels: int,
+         offset: Optional[int] = None,
+     ) -> None:
+         super().__init__()
+         self.update_size = update_size
+         self.num_updates = num_updates
+         self.batch_size = batch_size
+         self.num_channels = num_channels
+         self.offset = offset
+
+         # build a blank tensor into which we will embed
+         # the updated snapshot predictions at the
+         # appropriate time offset for in-batch averaging
+         self.batch_update_size = int(batch_size * update_size)
+         self.state_size = int((num_updates - 1) * update_size)
+         blank_size = self.batch_update_size + self.state_size
+         blank = torch.zeros((batch_size, num_channels, blank_size))
+         self.register_buffer("blank", blank)
+
+         # set up the indices at which the updated snapshots
+         # will be embedded into the blank tensor
+         idx = torch.arange(num_updates * update_size)
+         idx = torch.stack([idx + i * update_size for i in range(batch_size)])
+         idx = idx.view(batch_size, 1, -1).repeat(1, num_channels, 1)
+         self.register_buffer("idx", idx)
+
+         # normalization indices used to downweight the
+         # existing average at each in-batch aggregation
+         weights = torch.scatter(blank, -1, idx, 1).sum(0)
+         weight_size = int(num_updates * update_size)
+         weights = unfold_windows(weights, weight_size, update_size)
+         self.register_buffer("weights", weights)
+
+     def get_initial_state(self):
+         return torch.zeros((self.num_channels, self.state_size))
+
+     def forward(
+         self, update: torch.Tensor, state: Optional[torch.Tensor] = None
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         if state is None:
+             state = self.get_initial_state()
+
+         # slice off the steps from this update closest
+         # to the future that we'll actually use. Divide
+         # these values by the number of updates up-front
+         # for averaging purposes.
+         start = -self.num_updates * self.update_size
+         if self.offset is not None:
+             end = -self.offset
+             start += end
+         else:
+             end = None
+         x = update[:, :, start:end] / self.num_updates
+
+         # append zeros to the state into which we
+         # can insert our updates
+         state = torch.nn.functional.pad(state, (0, self.batch_update_size))
+
+         # window the existing snapshot into overlapping
+         # segments and average them with our new updates
+         windowed = unfold_windows(state, x.size(-1), self.update_size)
+         windowed /= self.weights
+         windowed += x
+
+         # embed these windowed averages into a blank
+         # array with offsets so that we can add the
+         # overlapping bits
+         padded = torch.scatter(self.blank, -1, self.idx, windowed)
+         new_state = padded.sum(axis=0)
+
+         if self.num_updates == 1:
+             # if we don't need stateful behavior,
+             # just return the "snapshot" as-is
+             output, new_state = new_state, self.get_initial_state()
+         else:
+             # otherwise split off the values that have finished
+             # averaging and are being returned from the ones that
+             # will comprise the snapshot at the next update
+             splits = [self.batch_size, self.num_updates - 1]
+             splits = [i * self.update_size for i in splits]
+             output, new_state = torch.split(new_state, splits, dim=-1)
+         return output, new_state
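
The shape bookkeeping in `OnlineAverager.forward` is easier to follow with concrete numbers. The sketch below runs one streaming step; the sizes and the import path are assumptions for illustration, with the output shapes following from the definitions above.

import torch

from ml4gw.nn.streaming import OnlineAverager  # import path assumed

averager = OnlineAverager(
    update_size=8, batch_size=4, num_updates=2, num_channels=1
)
state = averager.get_initial_state()

# a batch of 4 overlapping predictions, each 32 samples long; only the
# trailing num_updates * update_size = 16 samples of each are averaged
update = torch.randn(4, 1, 32)
output, state = averager(update, state)
print(output.shape, state.shape)  # torch.Size([1, 32]) torch.Size([1, 8])
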
@@ -0,0 +1,121 @@
+ from typing import Optional, Sequence, Tuple
+
+ import torch
+
+ from ml4gw.utils.slicing import unfold_windows
+
+
+ class Snapshotter(torch.nn.Module):
+     """
+     Model for converting streaming state updates into
+     a batch of overlapping snapshots of a multichannel
+     timeseries. Can support multiple timeseries in a
+     single state update via the `channels_per_snapshot`
+     kwarg.
+
+     Specifically, maps tensors of shape
+     `(num_channels, batch_size * stride_size)` to a tensor
+     of shape `(batch_size, num_channels, snapshot_size)`.
+     If `channels_per_snapshot` is specified, it will return
+     `len(channels_per_snapshot)` tensors of this shape,
+     with the channel dimension replaced by the corresponding
+     value of `channels_per_snapshot`. The last tensor returned
+     at call time will be the current state that can be passed
+     to the next `forward` call.
+
+     Args:
+         num_channels:
+             Number of channels in the timeseries. If
+             `channels_per_snapshot` is not `None`,
+             this should be equal to `sum(channels_per_snapshot)`.
+         snapshot_size:
+             The size of the output snapshot windows in
+             number of samples.
+         stride_size:
+             The number of samples in between each output
+             snapshot.
+         batch_size:
+             The number of snapshots to produce at each
+             update. The last dimension of the input
+             tensor should have size `batch_size * stride_size`.
+         channels_per_snapshot:
+             How to split up the channels in the timeseries
+             for different tensors. If left as `None`, all
+             the channels will be returned in a single tensor.
+             Otherwise, the channels will be split up into
+             `len(channels_per_snapshot)` tensors, with each
+             tensor's channel dimension being equal to the
+             corresponding value in `channels_per_snapshot`.
+             Therefore, if specified, these values should
+             add up to `num_channels`.
+     """
+
+     def __init__(
+         self,
+         num_channels: int,
+         snapshot_size: int,
+         stride_size: int,
+         batch_size: int,
+         channels_per_snapshot: Optional[Sequence[int]] = None,
+     ) -> None:
+         super().__init__()
+         if stride_size >= snapshot_size:
+             raise ValueError(
+                 "Snapshotter can't accommodate stride {} "
+                 "which is greater than or equal to snapshot size {}".format(
+                     stride_size, snapshot_size
+                 )
+             )
+
+         self.snapshot_size = snapshot_size
+         self.stride_size = stride_size
+         self.state_size = snapshot_size - stride_size
+         self.batch_size = batch_size
+
+         if channels_per_snapshot is not None:
+             if sum(channels_per_snapshot) != num_channels:
+                 raise ValueError(
+                     "Can't break {} channels into {}".format(
+                         num_channels, channels_per_snapshot
+                     )
+                 )
+         self.channels_per_snapshot = channels_per_snapshot
+         self.num_channels = num_channels
+
+     def get_initial_state(self):
+         return torch.zeros((self.num_channels, self.state_size))
+
+     # TODO: use torchtyping annotations to make
+     # clear what the expected shapes are
+     def forward(
+         self, update: torch.Tensor, snapshot: Optional[torch.Tensor] = None
+     ) -> Tuple[torch.Tensor, ...]:
+         if snapshot is None:
+             snapshot = self.get_initial_state()
+
+         # append new data to the snapshot
+         snapshot = torch.cat([snapshot, update], axis=-1)
+
+         if self.batch_size > 1:
+             snapshots = unfold_windows(
+                 snapshot, self.snapshot_size, self.stride_size
+             )
+         else:
+             snapshots = snapshot[None]
+
+         if self.channels_per_snapshot is not None:
+             if snapshots.size(1) != self.num_channels:
+                 raise ValueError(
+                     "Expected {} channels, found {}".format(
+                         self.num_channels, snapshots.size(1)
+                     )
+                 )
+             snapshots = torch.split(
+                 snapshots, self.channels_per_snapshot, dim=1
+             )
+         else:
+             snapshots = (snapshots,)
+
+         # keep only the latest snapshot as our state
+         snapshot = snapshot[:, -self.state_size :]
+         return tuple(snapshots) + (snapshot,)
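
And a matching sketch for `Snapshotter`, showing how one streaming update is turned into a batch of overlapping windows plus the carried-over state. Sizes and the import path are again assumptions for illustration; the expected shapes follow from the docstring above.

import torch

from ml4gw.nn.streaming import Snapshotter  # import path assumed

snapshotter = Snapshotter(
    num_channels=2, snapshot_size=128, stride_size=16, batch_size=4
)
state = snapshotter.get_initial_state()

# each update carries batch_size * stride_size = 64 new samples per channel
update = torch.randn(2, 64)
windows, state = snapshotter(update, state)
print(windows.shape, state.shape)  # torch.Size([4, 2, 128]) torch.Size([2, 112])
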
@@ -1,5 +1,7 @@
+ from .pearson import ShiftedPearsonCorrelation
  from .scaler import ChannelWiseScaler
  from .snr_rescaler import SnrRescaler
  from .spectral import SpectralDensity
+ from .spectrogram import MultiResolutionSpectrogram
  from .waveforms import WaveformProjector, WaveformSampler
  from .whitening import FixedWhiten, Whiten