ml4gw 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,413 @@
+"""
+In large part lifted from
+https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
+but with 1d convolutions and arbitrary kernel sizes, and a
+default norm layer that makes more sense for most GW applications
+where training-time statistics are entirely arbitrary due to
+simulations.
+"""
+
+from typing import Callable, List, Literal, Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from ml4gw.nn.norm import GroupNorm1DGetter, NormLayer
+
+
+def convN(
+    in_planes: int,
+    out_planes: int,
+    kernel_size: int = 3,
+    stride: int = 1,
+    groups: int = 1,
+    dilation: int = 1,
+) -> nn.Conv1d:
+    """1d convolution with padding"""
+    if not kernel_size % 2:
+        raise ValueError("Can't use even sized kernels")
+
+    return nn.Conv1d(
+        in_planes,
+        out_planes,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=dilation * int(kernel_size // 2),
+        groups=groups,
+        bias=False,
+        dilation=dilation,
+    )
+
+
+def conv1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv1d:
+    """Kernel-size 1 convolution"""
+    return nn.Conv1d(
+        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
+    )
+
+
+class BasicBlock(nn.Module):
+    """Defines the structure of the blocks used to build the ResNet"""
+
+    expansion: int = 1
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm1d
+        if groups != 1 or base_width != 64:
+            raise ValueError(
+                "BasicBlock only supports groups=1 and base_width=64"
+            )
+        if dilation > 1:
+            raise NotImplementedError(
+                "Dilation > 1 not supported in BasicBlock"
+            )
+
+        # Both self.conv1 and self.downsample layers
+        # downsample the input when stride != 1
+        self.conv1 = convN(inplanes, planes, kernel_size, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = convN(planes, planes, kernel_size)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    """
+    Bottleneck blocks implement one extra convolution
+    compared to basic blocks. In these layers, the `planes`
+    parameter is generally meant to _downsize_ the number
+    of feature maps first, which then get expanded out to
+    `planes * Bottleneck.expansion` feature maps at the
+    output of the layer.
+    """
+
+    expansion: int = 4
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[NormLayer] = None,
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm1d
+
+        width = int(planes * (base_width / 64.0)) * groups
+
+        # conv1 does no downsampling, just reduces the number of
+        # feature maps from inplanes to width (where width == planes
+        # if groups == 1 and base_width == 64)
+        self.conv1 = convN(inplanes, width, kernel_size)
+        self.bn1 = norm_layer(width)
+
+        # conv2 keeps the same number of feature maps,
+        # but downsamples along the time axis if stride
+        # or dilation > 1
+        self.conv2 = convN(width, width, kernel_size, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+
+        # conv3 expands the feature maps back out to planes * expansion
+        self.conv3 = conv1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet1D(nn.Module):
+    """1D ResNet architecture
+
+    Simple extension of ResNet to 1D convolutions with
+    arbitrary kernel sizes to support the longer timeseries
+    used in BBH detection.
+
+    Args:
+        in_channels:
+            The number of channels in the input tensor.
+        layers:
+            A list representing the number of residual
+            blocks to include in each "layer" of the
+            network. Total layers (e.g. 50 in ResNet50)
+            is `2 + sum(layers) * factor`, where factor
+            is `2` for vanilla `ResNet` and `3` for
+            `BottleneckResNet`.
+        classes:
+            The number of output classes, i.e. the size
+            of the network's final linear layer.
+        kernel_size:
+            The size of the convolutional kernel to
+            use in all residual layers. _NOT_ the size
+            of the input kernel to the network, which
+            is determined at run-time.
+        zero_init_residual:
+            Flag indicating whether to initialize the
+            weights of the batch-norm layer in each block
+            to 0 so that residuals are initialized as
+            identities. Can improve training results.
+        groups:
+            Number of convolutional groups to use in all
+            layers. Grouped convolutions induce local
+            connections between feature maps at subsequent
+            layers rather than global. Generally won't
+            need this to be >1, and will raise an error if
+            >1 when using vanilla `ResNet`.
+        width_per_group:
+            Base width of each of the feature map groups,
+            which is scaled up by the typical expansion
+            factor at each layer of the network. Meaningless
+            for vanilla `ResNet`.
+        stride_type:
+            Whether to achieve downsampling on the time axis
+            by strided or dilated convolutions for each layer.
+            If left as `None`, strided convolutions will be
+            used at each layer. Otherwise, `stride_type` should
+            be one element shorter than `layers` and indicate either
+            `stride` or `dilation` for each layer after the first.
+    """
+
+    block = BasicBlock
+
+    def __init__(
+        self,
+        in_channels: int,
+        layers: List[int],
+        classes: int,
+        kernel_size: int = 3,
+        zero_init_residual: bool = False,
+        groups: int = 1,
+        width_per_group: int = 64,
+        stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
+        norm_layer: Optional[NormLayer] = None,
+    ) -> None:
+        super().__init__()
+
+        self.inplanes = 64
+        self.dilation = 1
+
+        # default to using InstanceNorm if no
+        # norm layer is provided explicitly
+        self._norm_layer = norm_layer or GroupNorm1DGetter()
+
+        # TODO: should we support passing a single string
+        # for simplicity here?
+        if stride_type is None:
+            # each element in the tuple indicates if we should replace
+            # the stride with a dilated convolution instead
+            stride_type = ["stride"] * (len(layers) - 1)
+        if len(stride_type) != (len(layers) - 1):
+            raise ValueError(
+                "'stride_type' should be None or a "
+                "{}-element tuple, got {}".format(len(layers) - 1, stride_type)
+            )
+
+        self.groups = groups
+        self.base_width = width_per_group
+
+        # start with a basic conv-bn-relu-maxpool block
+        # to reduce the dimensionality before the heavy
+        # lifting starts
+        self.conv1 = nn.Conv1d(
+            in_channels,
+            self.inplanes,
+            kernel_size=7,
+            stride=2,
+            padding=3,
+            bias=False,
+        )
+        self.bn1 = self._norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+
+        # now create layers of residual blocks where each
+        # layer uses the same number of feature maps for
+        # all its blocks (some power of 2 times 64).
+        # Don't downsample along the time axis in the first
+        # layer, but downsample in all the rest (either by
+        # striding or dilating depending on the stride_type
+        # argument)
+        # track the width of the most recent layer so the
+        # fully connected head below is sized correctly even
+        # when len(layers) == 1
+        block_size = 64
+        residual_layers = [self._make_layer(block_size, layers[0], kernel_size)]
+        it = zip(layers[1:], stride_type)
+        for i, (num_blocks, stride) in enumerate(it):
+            block_size = 64 * 2 ** (i + 1)
+            layer = self._make_layer(
+                block_size,
+                num_blocks,
+                kernel_size,
+                stride=2,
+                stride_type=stride,
+            )
+            residual_layers.append(layer)
+        self.residual_layers = nn.ModuleList(residual_layers)
+
+        # Average pool over each feature map to create a
+        # single value for each feature map that we'll use
+        # in the fully connected head
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+
+        # use a fully connected layer to map from the pooled
+        # feature maps to the desired number of output classes
+        self.fc = nn.Linear(block_size * self.block.expansion, classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(
+                    m.weight, mode="fan_out", nonlinearity="relu"
+                )
+            elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros,
+        # and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to
+        # https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(
+        self,
+        planes: int,
+        blocks: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        stride_type: Literal["stride", "dilation"] = "stride",
+    ) -> nn.Sequential:
+        block = self.block
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+
+        if stride_type == "dilation":
+            self.dilation *= stride
+            stride = 1
+        elif stride_type != "stride":
+            raise ValueError(f"Unknown stride type {stride_type}")
+
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(
+            block(
+                self.inplanes,
+                planes,
+                kernel_size,
+                stride,
+                downsample,
+                self.groups,
+                self.base_width,
+                previous_dilation,
+                norm_layer,
+            )
+        )
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    kernel_size,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                    norm_layer=norm_layer,
+                )
+            )
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        for layer in self.residual_layers:
+            x = layer(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+# TODO: implement as arg of ResNet instead?
+class BottleneckResNet1D(ResNet1D):
+    """A version of ResNet that uses bottleneck blocks"""
+
+    block = Bottleneck
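
For orientation, here is a minimal usage sketch based only on the signatures visible in the diff above. It is illustrative and not part of the package contents: the import path, channel count, layer counts, kernel size, and input shape are all assumptions, and the layer-count arithmetic in the comments just applies the `2 + sum(layers) * factor` formula from the class docstring.

    import torch

    # module path assumed for illustration; check the package layout
    from ml4gw.nn.resnet import BottleneckResNet1D, ResNet1D

    # A ResNet-18-style network: 2 + sum([2, 2, 2, 2]) * 2 = 18 layers
    model = ResNet1D(
        in_channels=2,        # e.g. two interferometer channels
        layers=[2, 2, 2, 2],
        classes=1,            # single detection logit
        kernel_size=7,
    )

    # A ResNet-50-style bottleneck network (2 + 16 * 3 = 50 layers)
    # that dilates instead of strides in its last two residual layers
    dilated = BottleneckResNet1D(
        in_channels=2,
        layers=[3, 4, 6, 3],
        classes=1,
        stride_type=["stride", "dilation", "dilation"],
    )

    x = torch.randn(8, 2, 4096)  # (batch, channels, time)
    print(model(x).shape)        # torch.Size([8, 1])
    print(dilated(x).shape)      # torch.Size([8, 1])

Note that the dilation-based downsampling is only shown with the bottleneck variant here, since `BasicBlock` raises `NotImplementedError` for dilation > 1.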