ml4gw 0.3.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ml4gw might be problematic.
- {ml4gw-0.3.0 → ml4gw-0.4.0}/PKG-INFO +14 -7
- {ml4gw-0.3.0 → ml4gw-0.4.0}/README.md +8 -3
- ml4gw-0.4.0/ml4gw/nn/norm.py +97 -0
- ml4gw-0.4.0/ml4gw/nn/resnet/__init__.py +2 -0
- ml4gw-0.4.0/ml4gw/nn/resnet/resnet_1d.py +413 -0
- ml4gw-0.4.0/ml4gw/nn/resnet/resnet_2d.py +413 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/transforms/__init__.py +1 -0
- ml4gw-0.4.0/ml4gw/transforms/spectrogram.py +162 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/transforms/whitening.py +1 -1
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/waveforms/phenom_d.py +1 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/pyproject.toml +15 -4
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/__init__.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/augmentations.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/dataloading/__init__.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/dataloading/chunked_dataset.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/dataloading/hdf5_dataset.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/dataloading/in_memory_dataset.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/distributions.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/gw.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/nn/__init__.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/nn/autoencoder/__init__.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/nn/autoencoder/base.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/nn/autoencoder/convolutional.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/nn/autoencoder/skip_connection.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/nn/autoencoder/utils.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/nn/streaming/__init__.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/nn/streaming/online_average.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/nn/streaming/snapshotter.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/spectral.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/transforms/pearson.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/transforms/scaler.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/transforms/snr_rescaler.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/transforms/spectral.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/transforms/transform.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/transforms/waveforms.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/types.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/utils/interferometer.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/utils/slicing.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/waveforms/__init__.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/waveforms/generator.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/waveforms/phenom_d_data.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/waveforms/sine_gaussian.py +0 -0
- {ml4gw-0.3.0 → ml4gw-0.4.0}/ml4gw/waveforms/taylorf2.py +0 -0
{ml4gw-0.3.0 → ml4gw-0.4.0}/PKG-INFO

````diff
@@ -1,17 +1,19 @@
 Metadata-Version: 2.1
 Name: ml4gw
-Version: 0.3.0
+Version: 0.4.0
 Summary: Tools for training torch models on gravitational wave data
 Author: Alec Gunny
 Author-email: alec.gunny@ligo.org
-Requires-Python: >=3.8,<
+Requires-Python: >=3.8,<3.12
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
-
-Requires-Dist: torch (>=
+Requires-Dist: torch (>=1.10,<2.0) ; python_version >= "3.8" and python_version < "3.11"
+Requires-Dist: torch (>=2.0,<3.0) ; python_version >= "3.11"
+Requires-Dist: torchaudio (>=0.13,<0.14) ; python_version >= "3.8" and python_version < "3.11"
+Requires-Dist: torchaudio (>=2.0,<3.0) ; python_version >= "3.11"
 Requires-Dist: torchtyping (>=0.1,<0.2)
 Description-Content-Type: text/markdown
 
@@ -38,8 +40,8 @@ pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/w
 
 ```toml
 [tool.poetry.dependencies]
-python = "^3.8" # python versions 3.8-3.
-ml4gw = "^0.
+python = "^3.8" # python versions 3.8-3.11 are supported
+ml4gw = "^0.3.0"
 ```
 
 To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.6, you would do something like:
@@ -47,7 +49,7 @@ To build against a specific PyTorch/CUDA combination, consult the PyTorch instal
 ```toml
 [tool.poetry.dependencies]
 python = "^3.8"
-ml4gw = "^0.
+ml4gw = "^0.3.0"
 torch = {version = "^1.12", source = "torch"}
 
 [[tool.poetry.source]]
@@ -57,6 +59,8 @@ secondary = true
 default = false
 ```
 
+Note: if you are building against CUDA 11.6 or 11.7, make sure that you are using python 3.8, 3.9, or 3.10. Python 3.11 is incompatible with `torchaudio` 0.13, and the following `torchaudio` version is incompatible with CUDA 11.7 and earlier.
+
 ## Use cases
 This library provided utilities for both data iteration and transformation via dataloaders defined in `ml4gw/dataloading` and transform layers exposed in `ml4gw/transforms`. Lower level functions and utilies are defined at the top level of the library and in the `utils` library.
 
@@ -146,3 +150,6 @@ We also strongly encourage ML users in the GW physics space to try their hand at
 For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu) .
 By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool which makes DL more accessible for gravitational wave physicists everywhere.
 
+## Funding
+We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
+
````
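The headline change above is the split torch/torchaudio requirement: the pins are now selected by interpreter version via PEP 508 environment markers. As a rough sketch of how those markers get evaluated (this uses the third-party `packaging` library, which is an assumption here, not an ml4gw dependency):

```python
# Sketch: how a resolver evaluates the new environment markers above.
# Assumes the third-party `packaging` library; not part of ml4gw.
from packaging.markers import Marker

marker = Marker('python_version >= "3.8" and python_version < "3.11"')

# torch 1.x / torchaudio 0.13 apply on python 3.8-3.10...
print(marker.evaluate({"python_version": "3.10"}))  # True
# ...while python 3.11 falls through to the torch/torchaudio 2.x pins
print(marker.evaluate({"python_version": "3.11"}))  # False
```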
{ml4gw-0.3.0 → ml4gw-0.4.0}/README.md

````diff
@@ -21,8 +21,8 @@ pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/w
 
 ```toml
 [tool.poetry.dependencies]
-python = "^3.8" # python versions 3.8-3.
-ml4gw = "^0.
+python = "^3.8" # python versions 3.8-3.11 are supported
+ml4gw = "^0.3.0"
 ```
 
 To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.6, you would do something like:
@@ -30,7 +30,7 @@ To build against a specific PyTorch/CUDA combination, consult the PyTorch instal
 ```toml
 [tool.poetry.dependencies]
 python = "^3.8"
-ml4gw = "^0.
+ml4gw = "^0.3.0"
 torch = {version = "^1.12", source = "torch"}
 
 [[tool.poetry.source]]
@@ -40,6 +40,8 @@ secondary = true
 default = false
 ```
 
+Note: if you are building against CUDA 11.6 or 11.7, make sure that you are using python 3.8, 3.9, or 3.10. Python 3.11 is incompatible with `torchaudio` 0.13, and the following `torchaudio` version is incompatible with CUDA 11.7 and earlier.
+
 ## Use cases
 This library provided utilities for both data iteration and transformation via dataloaders defined in `ml4gw/dataloading` and transform layers exposed in `ml4gw/transforms`. Lower level functions and utilies are defined at the top level of the library and in the `utils` library.
 
@@ -128,3 +130,6 @@ We encourage users who encounter these difficulties to file issues on GitHub, an
 We also strongly encourage ML users in the GW physics space to try their hand at working on these issues and joining on as collaborators!
 For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu) .
 By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool which makes DL more accessible for gravitational wave physicists everywhere.
+
+## Funding
+We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
````
ml4gw-0.4.0/ml4gw/nn/norm.py (new file)

```python
from typing import Callable, Optional

import torch

NormLayer = Callable[[int], torch.nn.Module]


class GroupNorm1D(torch.nn.Module):
    """
    Custom implementation of GroupNorm which is faster than the
    out-of-the-box PyTorch version at inference time.
    """

    def __init__(
        self,
        num_channels: int,
        num_groups: Optional[int] = None,
        eps: float = 1e-5,
    ):
        super().__init__()
        num_groups = num_groups or num_channels
        if num_channels % num_groups:
            raise ValueError("num_groups must be a factor of num_channels")

        self.num_channels = num_channels
        self.num_groups = num_groups
        self.channels_per_group = self.num_channels // self.num_groups
        self.eps = eps

        shape = (self.num_channels, 1)
        self.weight = torch.nn.Parameter(torch.ones(shape))
        self.bias = torch.nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        keepdims = self.num_groups == self.num_channels

        # compute group variance via the E[x**2] - E**2[x] trick
        mean = x.mean(-1, keepdims=keepdims)
        sq_mean = (x**2).mean(-1, keepdims=keepdims)

        # if we have groups, do some reshape magic
        # to calculate group level stats then
        # reshape back to full channel dimension
        if self.num_groups != self.num_channels:
            mean = torch.stack([mean, sq_mean], dim=1)
            mean = mean.reshape(
                -1, 2, self.num_groups, self.channels_per_group
            )
            mean = mean.mean(-1, keepdims=True)
            mean = mean.expand(-1, -1, -1, self.channels_per_group)
            mean = mean.reshape(-1, 2, self.num_channels, 1)
            mean, sq_mean = mean[:, 0], mean[:, 1]

        # roll the mean and variance into the
        # weight and bias so that we have to do
        # fewer computations along the full time axis
        std = (sq_mean - mean**2 + self.eps) ** 0.5
        scale = self.weight / std
        shift = self.bias - scale * mean
        return shift + x * scale


class GroupNorm1DGetter:
    """
    Utility for making a NormLayer Callable that maps from
    an integer number of channels to a torch Module. Useful
    for command-line parameterization with jsonargparse.
    """

    def __init__(self, groups: Optional[int] = None) -> None:
        self.groups = groups

    def __call__(self, num_channels: int) -> torch.nn.Module:
        if self.groups is None:
            num_groups = None
        else:
            num_groups = min(num_channels, self.groups)
        return GroupNorm1D(num_channels, num_groups)


# TODO: generalize the faster GroupNorm1D to 2D
class GroupNorm2DGetter:
    """
    Utility for making a NormLayer Callable that maps from
    an integer number of channels to a torch Module. Useful
    for command-line parameterization with jsonargparse.
    """

    def __init__(self, groups: Optional[int] = None) -> None:
        self.groups = groups

    def __call__(self, num_channels: int) -> torch.nn.Module:
        if self.groups is None:
            num_groups = num_channels
        else:
            num_groups = min(num_channels, self.groups)
        return torch.nn.GroupNorm(num_groups, num_channels)
```
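`GroupNorm1D` is pitched as a drop-in, inference-time-faster version of PyTorch's built-in GroupNorm, so a natural smoke test is to compare the two on random data. A minimal sketch, with hypothetical shapes (not from the package itself):

```python
import torch

from ml4gw.nn.norm import GroupNorm1D

# hypothetical batch: 8 examples, 32 channels, 1024 time samples
x = torch.randn(8, 32, 1024)

fast = GroupNorm1D(num_channels=32, num_groups=8)
reference = torch.nn.GroupNorm(num_groups=8, num_channels=32)

# with freshly initialized affine parameters (weight=1, bias=0),
# the two implementations should agree to floating-point tolerance
assert torch.allclose(fast(x), reference(x), atol=1e-4)
```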
ml4gw-0.4.0/ml4gw/nn/resnet/resnet_1d.py (new file)

```python
"""
In large part lifted from
https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
but with 1d convolutions and arbitrary kernel sizes, and a
default norm layer that makes more sense for most GW applications
where training-time statistics are entirely arbitrary due to
simulations.
"""

from typing import Callable, List, Literal, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ml4gw.nn.norm import GroupNorm1DGetter, NormLayer


def convN(
    in_planes: int,
    out_planes: int,
    kernel_size: int = 3,
    stride: int = 1,
    groups: int = 1,
    dilation: int = 1,
) -> nn.Conv1d:
    """1d convolution with padding"""
    if not kernel_size % 2:
        raise ValueError("Can't use even sized kernels")

    return nn.Conv1d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=dilation * int(kernel_size // 2),
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv1d:
    """Kernel-size 1 convolution"""
    return nn.Conv1d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
    )


class BasicBlock(nn.Module):
    """Defines the structure of the blocks used to build the ResNet"""

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm1d
        if groups != 1 or base_width != 64:
            raise ValueError(
                "BasicBlock only supports groups=1 and base_width=64"
            )
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )

        # Both self.conv1 and self.downsample layers
        # downsample the input when stride != 1
        self.conv1 = convN(inplanes, planes, kernel_size, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = convN(planes, planes, kernel_size)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    Bottleneck blocks implement one extra convolution
    compared to basic blocks. In these layers, the `planes`
    parameter is generally meant to _downsize_ the number
    of feature maps first, which then get expanded out to
    `planes * Bottleneck.expansion` feature maps at the
    output of the layer.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[NormLayer] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm1d

        width = int(planes * (base_width / 64.0)) * groups

        # conv1 does no downsampling, just reduces the number of
        # feature maps from inplanes to width (where width == planes)
        # if groups == 1 and base_width == 64
        self.conv1 = convN(inplanes, width, kernel_size)
        self.bn1 = norm_layer(width)

        # conv2 keeps the same number of feature maps,
        # but downsamples along the time axis if stride
        # or dilation > 1
        self.conv2 = convN(width, width, kernel_size, stride, groups, dilation)
        self.bn2 = norm_layer(width)

        # conv3 expands the feature maps back out to planes * expansion
        self.conv3 = conv1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet1D(nn.Module):
    """1D ResNet architecture

    Simple extension of ResNet to 1D convolutions with
    arbitrary kernel sizes to support the longer timeseries
    used in BBH detection.

    Args:
        in_channels:
            The number of channels in the input tensor.
        layers:
            A list representing the number of residual
            blocks to include in each "layer" of the
            network. Total layers (e.g. 50 in ResNet50)
            is `2 + sum(layers) * factor`, where factor
            is `2` for vanilla `ResNet` and `3` for
            `BottleneckResNet`.
        classes:
            The number of outputs of the final fully
            connected layer.
        kernel_size:
            The size of the convolutional kernel to
            use in all residual layers. _NOT_ the size
            of the input kernel to the network, which
            is determined at run-time.
        zero_init_residual:
            Flag indicating whether to initialize the
            weights of the batch-norm layer in each block
            to 0 so that residuals are initialized as
            identities. Can improve training results.
        groups:
            Number of convolutional groups to use in all
            layers. Grouped convolutions induce local
            connections between feature maps at subsequent
            layers rather than global. Generally won't
            need this to be >1, and will raise an error if
            >1 when using vanilla `ResNet`.
        width_per_group:
            Base width of each of the feature map groups,
            which is scaled up by the typical expansion
            factor at each layer of the network. Meaningless
            for vanilla `ResNet`.
        stride_type:
            Whether to achieve downsampling on the time axis
            by strided or dilated convolutions for each layer.
            If left as `None`, strided convolutions will be
            used at each layer. Otherwise, `stride_type` should
            be one element shorter than `layers` and indicate either
            `stride` or `dilation` for each layer after the first.
        norm_layer:
            Callable mapping a channel count to a normalization
            Module; defaults to `GroupNorm1DGetter()`.
    """

    block = BasicBlock

    def __init__(
        self,
        in_channels: int,
        layers: List[int],
        classes: int,
        kernel_size: int = 3,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
        norm_layer: Optional[NormLayer] = None,
    ) -> None:
        super().__init__()

        self.inplanes = 64
        self.dilation = 1

        # default to using InstanceNorm if no
        # norm layer is provided explicitly
        self._norm_layer = norm_layer or GroupNorm1DGetter()

        # TODO: should we support passing a single string
        # for simplicity here?
        if stride_type is None:
            # each element in the tuple indicates if we should replace
            # the stride with a dilated convolution instead
            stride_type = ["stride"] * (len(layers) - 1)
        if len(stride_type) != (len(layers) - 1):
            raise ValueError(
                "'stride_type' should be None or a "
                "{}-element tuple, got {}".format(len(layers) - 1, stride_type)
            )

        self.groups = groups
        self.base_width = width_per_group

        # start with a basic conv-bn-relu-maxpool block
        # to reduce the dimensionality before the heavy
        # lifting starts
        self.conv1 = nn.Conv1d(
            in_channels,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # now create layers of residual blocks where each
        # layer uses the same number of feature maps for
        # all its blocks (some power of 2 times 64).
        # Don't downsample along the time axis in the first
        # layer, but downsample in all the rest (either by
        # striding or dilating depending on the stride_type
        # argument)
        residual_layers = [self._make_layer(64, layers[0], kernel_size)]
        it = zip(layers[1:], stride_type)
        for i, (num_blocks, stride) in enumerate(it):
            block_size = 64 * 2 ** (i + 1)
            layer = self._make_layer(
                block_size,
                num_blocks,
                kernel_size,
                stride=2,
                stride_type=stride,
            )
            residual_layers.append(layer)
        self.residual_layers = nn.ModuleList(residual_layers)

        # Average pool over each feature map to create a
        # single value for each feature map that we'll use
        # in the fully connected head
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        # use a fully connected layer to map from the
        # feature maps to the binary output that we need
        self.fc = nn.Linear(block_size * self.block.expansion, classes)

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
        self,
        planes: int,
        blocks: int,
        kernel_size: int = 3,
        stride: int = 1,
        stride_type: Literal["stride", "dilation"] = "stride",
    ) -> nn.Sequential:
        block = self.block
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation

        if stride_type == "dilation":
            self.dilation *= stride
            stride = 1
        elif stride_type != "stride":
            raise ValueError(f"Unknown stride type {stride_type}")

        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                kernel_size,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    kernel_size,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for layer in self.residual_layers:
            x = layer(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


# TODO: implement as arg of ResNet instead?
class BottleneckResNet1D(ResNet1D):
    """A version of ResNet that uses bottleneck blocks"""

    block = Bottleneck
```