ml4gw 0.3.0__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

Files changed (43) hide show
  1. {ml4gw-0.3.0 → ml4gw-0.4.1}/PKG-INFO +12 -7
  2. {ml4gw-0.3.0 → ml4gw-0.4.1}/README.md +8 -3
  3. ml4gw-0.4.1/ml4gw/nn/norm.py +97 -0
  4. ml4gw-0.4.1/ml4gw/nn/resnet/__init__.py +2 -0
  5. ml4gw-0.4.1/ml4gw/nn/resnet/resnet_1d.py +413 -0
  6. ml4gw-0.4.1/ml4gw/nn/resnet/resnet_2d.py +413 -0
  7. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/transforms/__init__.py +1 -0
  8. ml4gw-0.4.1/ml4gw/transforms/spectrogram.py +162 -0
  9. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/transforms/whitening.py +1 -1
  10. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/waveforms/phenom_d.py +1 -0
  11. {ml4gw-0.3.0 → ml4gw-0.4.1}/pyproject.toml +9 -4
  12. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/__init__.py +0 -0
  13. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/augmentations.py +0 -0
  14. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/dataloading/__init__.py +0 -0
  15. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/dataloading/chunked_dataset.py +0 -0
  16. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/dataloading/hdf5_dataset.py +0 -0
  17. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/dataloading/in_memory_dataset.py +0 -0
  18. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/distributions.py +0 -0
  19. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/gw.py +0 -0
  20. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/nn/__init__.py +0 -0
  21. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/nn/autoencoder/__init__.py +0 -0
  22. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/nn/autoencoder/base.py +0 -0
  23. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/nn/autoencoder/convolutional.py +0 -0
  24. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/nn/autoencoder/skip_connection.py +0 -0
  25. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/nn/autoencoder/utils.py +0 -0
  26. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/nn/streaming/__init__.py +0 -0
  27. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/nn/streaming/online_average.py +0 -0
  28. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/nn/streaming/snapshotter.py +0 -0
  29. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/spectral.py +0 -0
  30. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/transforms/pearson.py +0 -0
  31. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/transforms/scaler.py +0 -0
  32. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/transforms/snr_rescaler.py +0 -0
  33. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/transforms/spectral.py +0 -0
  34. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/transforms/transform.py +0 -0
  35. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/transforms/waveforms.py +0 -0
  36. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/types.py +0 -0
  37. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/utils/interferometer.py +0 -0
  38. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/utils/slicing.py +0 -0
  39. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/waveforms/__init__.py +0 -0
  40. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/waveforms/generator.py +0 -0
  41. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/waveforms/phenom_d_data.py +0 -0
  42. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/waveforms/sine_gaussian.py +0 -0
  43. {ml4gw-0.3.0 → ml4gw-0.4.1}/ml4gw/waveforms/taylorf2.py +0 -0
@@ -1,17 +1,17 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ml4gw
3
- Version: 0.3.0
3
+ Version: 0.4.1
4
4
  Summary: Tools for training torch models on gravitational wave data
5
5
  Author: Alec Gunny
6
6
  Author-email: alec.gunny@ligo.org
7
- Requires-Python: >=3.8,<4.0
7
+ Requires-Python: >=3.8,<3.12
8
8
  Classifier: Programming Language :: Python :: 3
9
9
  Classifier: Programming Language :: Python :: 3.8
10
10
  Classifier: Programming Language :: Python :: 3.9
11
11
  Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
13
- Classifier: Programming Language :: Python :: 3.12
14
- Requires-Dist: torch (>=1.10,<2.0)
13
+ Requires-Dist: torch (>=2.0,<3.0)
14
+ Requires-Dist: torchaudio (>=2.0,<3.0)
15
15
  Requires-Dist: torchtyping (>=0.1,<0.2)
16
16
  Description-Content-Type: text/markdown
17
17
 
@@ -38,8 +38,8 @@ pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/w
38
38
 
39
39
  ```toml
40
40
  [tool.poetry.dependencies]
41
- python = "^3.8" # python versions 3.8-3.10 are supported
42
- ml4gw = "^0.1.0"
41
+ python = "^3.8" # python versions 3.8-3.11 are supported
42
+ ml4gw = "^0.3.0"
43
43
  ```
44
44
 
45
45
  To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.6, you would do something like:
@@ -47,7 +47,7 @@ To build against a specific PyTorch/CUDA combination, consult the PyTorch instal
47
47
  ```toml
48
48
  [tool.poetry.dependencies]
49
49
  python = "^3.8"
50
- ml4gw = "^0.1.0"
50
+ ml4gw = "^0.3.0"
51
51
  torch = {version = "^1.12", source = "torch"}
52
52
 
53
53
  [[tool.poetry.source]]
@@ -57,6 +57,8 @@ secondary = true
57
57
  default = false
58
58
  ```
59
59
 
60
+ Note: if you are building against CUDA 11.6 or 11.7, make sure that you are using python 3.8, 3.9, or 3.10. Python 3.11 is incompatible with `torchaudio` 0.13, and the following `torchaudio` version is incompatible with CUDA 11.7 and earlier.
61
+
60
62
  ## Use cases
61
63
  This library provides utilities for both data iteration and transformation via dataloaders defined in `ml4gw/dataloading` and transform layers exposed in `ml4gw/transforms`. Lower-level functions and utilities are defined at the top level of the library and in the `utils` library.
62
64
 
@@ -146,3 +148,6 @@ We also strongly encourage ML users in the GW physics space to try their hand at
146
148
  For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu) .
147
149
  By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool which makes DL more accessible for gravitational wave physicists everywhere.
148
150
 
151
+ ## Funding
152
+ We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
153
+
@@ -21,8 +21,8 @@ pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/w
21
21
 
22
22
  ```toml
23
23
  [tool.poetry.dependencies]
24
- python = "^3.8" # python versions 3.8-3.10 are supported
25
- ml4gw = "^0.1.0"
24
+ python = "^3.8" # python versions 3.8-3.11 are supported
25
+ ml4gw = "^0.3.0"
26
26
  ```
27
27
 
28
28
  To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the `extra-index-url` via the `tool.poetry.source` table in your `pyproject.toml`. For example, to build against CUDA 11.6, you would do something like:
@@ -30,7 +30,7 @@ To build against a specific PyTorch/CUDA combination, consult the PyTorch instal
30
30
  ```toml
31
31
  [tool.poetry.dependencies]
32
32
  python = "^3.8"
33
- ml4gw = "^0.1.0"
33
+ ml4gw = "^0.3.0"
34
34
  torch = {version = "^1.12", source = "torch"}
35
35
 
36
36
  [[tool.poetry.source]]
@@ -40,6 +40,8 @@ secondary = true
40
40
  default = false
41
41
  ```
42
42
 
43
+ Note: if you are building against CUDA 11.6 or 11.7, make sure that you are using python 3.8, 3.9, or 3.10. Python 3.11 is incompatible with `torchaudio` 0.13, and the following `torchaudio` version is incompatible with CUDA 11.7 and earlier.
44
+
43
45
  ## Use cases
44
46
  This library provides utilities for both data iteration and transformation via dataloaders defined in `ml4gw/dataloading` and transform layers exposed in `ml4gw/transforms`. Lower-level functions and utilities are defined at the top level of the library and in the `utils` library.
45
47
 
@@ -128,3 +130,6 @@ We encourage users who encounter these difficulties to file issues on GitHub, an
128
130
  We also strongly encourage ML users in the GW physics space to try their hand at working on these issues and joining on as collaborators!
129
131
  For more information about how to get involved, feel free to reach out to [ml4gw@ligo.mit.edu](mailto:ml4gw@ligo.mit.edu) .
130
132
  By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool which makes DL more accessible for gravitational wave physicists everywhere.
133
+
134
+ ## Funding
135
+ We are grateful for the support of the U.S. National Science Foundation (NSF) Harnessing the Data Revolution (HDR) Institute for <a href="https://a3d3.ai">Accelerating AI Algorithms for Data Driven Discovery (A3D3)</a> under Cooperative Agreement No. <a href="https://www.nsf.gov/awardsearch/showAward?AWD_ID=2117997">PHY-2117997</a>.
@@ -0,0 +1,97 @@
1
+ from typing import Callable, Optional
2
+
3
+ import torch
4
+
5
# Signature shared by every normalization-layer factory in this package:
# it receives the number of channels to normalize and returns a freshly
# constructed torch Module for them (``torch.nn.BatchNorm1d`` and
# ``GroupNorm1DGetter`` instances both fit this shape).
NormLayer = Callable[[int], torch.nn.Module]
6
+
7
+
8
class GroupNorm1D(torch.nn.Module):
    """
    Custom implementation of GroupNorm which is faster than the
    out-of-the-box PyTorch version at inference time.

    First and second moments are computed along the last (time) axis
    of every channel. When ``num_groups < num_channels`` they are then
    averaged over the channels of each group. The resulting mean and
    standard deviation are folded into the learnable per-channel
    weight and bias, so only a single multiply-add is applied along
    the full time axis.

    Input is expected to have shape ``(batch, num_channels, time)``;
    the grouped code path reshapes with exactly that layout in mind.

    Args:
        num_channels:
            Number of channels in the input.
        num_groups:
            Number of groups to split the channels into. Must divide
            ``num_channels``. If ``None`` (or ``0``), every channel is
            its own group, i.e. instance normalization with a
            learnable affine transform.
        eps:
            Value added to the variance for numerical stability.

    Raises:
        ValueError:
            If ``num_groups`` does not divide ``num_channels``.
    """

    def __init__(
        self,
        num_channels: int,
        num_groups: Optional[int] = None,
        eps: float = 1e-5,
    ):
        super().__init__()
        num_groups = num_groups or num_channels
        if num_channels % num_groups:
            raise ValueError("num_groups must be a factor of num_channels")

        self.num_channels = num_channels
        self.num_groups = num_groups
        self.channels_per_group = self.num_channels // self.num_groups
        self.eps = eps

        # trailing singleton dim lets the affine params broadcast
        # over the time axis
        shape = (self.num_channels, 1)
        self.weight = torch.nn.Parameter(torch.ones(shape))
        self.bias = torch.nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        instance_norm = self.num_groups == self.num_channels

        # per-channel E[x] and E[x**2] along the time axis
        mean = x.mean(-1, keepdim=instance_norm)
        sq_mean = (x**2).mean(-1, keepdim=instance_norm)

        if not instance_norm:
            # average the per-channel moments within each group, then
            # broadcast the group statistics back to every channel
            stats = torch.stack([mean, sq_mean], dim=1)
            stats = stats.reshape(
                -1, 2, self.num_groups, self.channels_per_group
            )
            stats = stats.mean(-1, keepdim=True)
            stats = stats.expand(-1, -1, -1, self.channels_per_group)
            stats = stats.reshape(-1, 2, self.num_channels, 1)
            mean, sq_mean = stats[:, 0], stats[:, 1]

        # Var[x] = E[x**2] - E[x]**2 can come out slightly negative from
        # floating point cancellation when |mean| >> std; clamp it so
        # the square root can never produce NaNs.
        var = (sq_mean - mean**2).clamp(min=0)
        std = (var + self.eps) ** 0.5

        # roll the statistics into the affine parameters so that only
        # one multiply-add touches the full time axis
        scale = self.weight / std
        shift = self.bias - scale * mean
        return shift + x * scale
61
+
62
+
63
class GroupNorm1DGetter:
    """
    Factory producing ``GroupNorm1D`` layers with the ``NormLayer``
    call signature: calling an instance with a channel count returns a
    new normalization module for that many channels. Useful when the
    normalization has to be configured from the command line (e.g. via
    jsonargparse).

    Args:
        groups:
            Upper bound on the number of groups. A layer with fewer
            channels than ``groups`` gets one group per channel. If
            ``None``, ``GroupNorm1D`` is built with ``num_groups=None``,
            i.e. one group per channel for every layer.
    """

    def __init__(self, groups: Optional[int] = None) -> None:
        self.groups = groups

    def __call__(self, num_channels: int) -> torch.nn.Module:
        capped_groups = (
            None if self.groups is None else min(self.groups, num_channels)
        )
        return GroupNorm1D(num_channels, capped_groups)
79
+
80
+
81
# TODO: generalize the faster GroupNorm1D implementation to 2D
class GroupNorm2DGetter:
    """
    Factory producing ``torch.nn.GroupNorm`` layers with the
    ``NormLayer`` call signature: calling an instance with a channel
    count returns a new normalization module for that many channels.
    Useful when the normalization has to be configured from the command
    line (e.g. via jsonargparse).

    Args:
        groups:
            Upper bound on the number of groups. A layer with fewer
            channels than ``groups`` gets one group per channel. If
            ``None``, every layer uses one group per channel.
    """

    def __init__(self, groups: Optional[int] = None) -> None:
        self.groups = groups

    def __call__(self, num_channels: int) -> torch.nn.Module:
        if self.groups is not None:
            return torch.nn.GroupNorm(
                min(num_channels, self.groups), num_channels
            )
        return torch.nn.GroupNorm(num_channels, num_channels)
@@ -0,0 +1,2 @@
1
+ from .resnet_1d import ResNet1D
2
+ from .resnet_2d import ResNet2D
@@ -0,0 +1,413 @@
1
+ """
2
+ In large part lifted from
3
+ https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
4
+ but with 1d convolutions and arbitrary kernel sizes, and a
5
+ default norm layer that makes more sense for most GW applications
6
+ where training-time statistics are entirely arbitrary due to
7
+ simulations.
8
+ """
9
+
10
+ from typing import Callable, List, Literal, Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch import Tensor
15
+
16
+ from ml4gw.nn.norm import GroupNorm1DGetter, NormLayer
17
+
18
+
19
def convN(
    in_planes: int,
    out_planes: int,
    kernel_size: int = 3,
    stride: int = 1,
    groups: int = 1,
    dilation: int = 1,
) -> nn.Conv1d:
    """
    Build a bias-free 1D convolution padded so that, at unit stride,
    the output is as long as the input (for any dilation).

    Raises:
        ValueError:
            If ``kernel_size`` is even, since symmetric "same" padding
            is impossible in that case.
    """
    if kernel_size % 2 == 0:
        raise ValueError("Can't use even sized kernels")

    # half of the (undilated) receptive field on each side
    half_width = int(kernel_size // 2)
    return nn.Conv1d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=dilation * half_width,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
41
+
42
+
43
def conv1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv1d:
    """
    Build a bias-free pointwise (kernel size 1) 1D convolution. It
    remaps the channel count and, when ``stride > 1``, subsamples the
    time axis.
    """
    pointwise_kwargs = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv1d(in_planes, out_planes, **pointwise_kwargs)
48
+
49
+
50
class BasicBlock(nn.Module):
    """
    Residual block made of two convolutions with the same kernel size,
    each followed by a normalization layer, with a ReLU between them
    and another ReLU after the residual sum.

    Args:
        inplanes:
            Number of channels in the block input.
        planes:
            Number of channels produced by both convolutions. Since
            ``expansion`` is 1 this is also the block output width.
        kernel_size:
            Odd kernel size used by both convolutions.
        stride:
            Stride of the first convolution. When it shrinks the time
            axis (or ``inplanes != planes``), a matching ``downsample``
            must be supplied, otherwise the residual sum will have
            mismatched shapes.
        downsample:
            Optional module applied to the block input to form the
            shortcut branch.
        groups, base_width:
            Accepted for interface parity with ``Bottleneck``; only
            ``groups=1`` and ``base_width=64`` are supported.
        dilation:
            Only ``1`` is supported.
        norm_layer:
            Callable mapping a channel count to a normalization module.
            Defaults to ``nn.BatchNorm1d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm1d

        unsupported_width = groups != 1 or base_width != 64
        if unsupported_width:
            raise ValueError(
                "BasicBlock only supports groups=1 and base_width=64"
            )
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )

        # conv1 carries the stride, so it (together with the downsample
        # module on the shortcut) is where the time axis gets shortened
        self.conv1 = convN(inplanes, planes, kernel_size, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = convN(planes, planes, kernel_size)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
107
+
108
+
109
class Bottleneck(nn.Module):
    """
    Bottleneck residual block: three convolutions instead of the two
    used by ``BasicBlock``. The ``planes`` argument generally sets a
    reduced intermediate width; the final pointwise convolution expands
    the output to ``planes * Bottleneck.expansion`` channels.

    Args:
        inplanes:
            Number of channels in the block input.
        planes:
            Base width of the block; the output has
            ``planes * expansion`` channels.
        kernel_size:
            Odd kernel size of the first two convolutions.
        stride:
            Stride of the second convolution. When it shrinks the time
            axis (or the channel count changes), a matching
            ``downsample`` must be supplied for the shortcut branch.
        downsample:
            Optional module applied to the block input to form the
            shortcut branch.
        groups:
            Number of groups of the second convolution.
        base_width:
            Together with ``groups`` this scales the intermediate width
            to ``int(planes * base_width / 64) * groups``.
        dilation:
            Dilation of the second convolution.
        norm_layer:
            Callable mapping a channel count to a normalization module.
            Defaults to ``nn.BatchNorm1d``.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[NormLayer] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm1d

        # intermediate width; equals `planes` when groups == 1
        # and base_width == 64
        width = int(planes * (base_width / 64.0)) * groups
        out_planes = planes * self.expansion

        # conv1: change the channel count from inplanes to width,
        # leaving the time axis untouched
        self.conv1 = convN(inplanes, width, kernel_size)
        self.bn1 = norm_layer(width)

        # conv2: same channel count; carries the stride (shortening the
        # time axis when stride > 1), the grouping and the dilation
        self.conv2 = convN(width, width, kernel_size, stride, groups, dilation)
        self.bn2 = norm_layer(width)

        # conv3: pointwise expansion to planes * expansion channels
        self.conv3 = conv1(width, out_planes)
        self.bn3 = norm_layer(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
180
+
181
+
182
class ResNet1D(nn.Module):
    """1D ResNet architecture

    Simple extension of ResNet to 1D convolutions with
    arbitrary kernel sizes to support the longer timeseries
    used in BBH detection. The network maps tensors of shape
    `(batch, in_channels, time)` to tensors of shape
    `(batch, classes)`.

    Args:
        in_channels:
            The number of channels in input tensor.
        layers:
            A list representing the number of residual
            blocks to include in each "layer" of the
            network. Total layers (e.g. 50 in ResNet50)
            is `2 + sum(layers) * factor`, where factor
            is `2` for vanilla `ResNet` and `3` for
            `BottleneckResNet`. Must contain at least
            one entry.
        classes:
            The number of outputs produced by the final
            fully connected layer.
        kernel_size:
            The size of the convolutional kernel to
            use in all residual layers. _NOT_ the size
            of the input convolution of the network,
            whose kernel size is fixed at 7.
        zero_init_residual:
            Flag indicating whether to initialize the
            weights of the last norm layer in each block
            to 0 so that every residual branch initially
            outputs zeros. Can improve training results.
        groups:
            Number of convolutional groups to use in all
            layers. Grouped convolutions induce local
            connections between feature maps at subsequent
            layers rather than global. Generally won't
            need this to be >1, and `BasicBlock` raises
            an error if it is >1 when using vanilla `ResNet`.
        width_per_group:
            Base width of each of the feature map groups,
            which is scaled up by the typical expansion
            factor at each layer of the network. Must be
            left at 64 for vanilla `ResNet`, since
            `BasicBlock` raises an error otherwise.
        stride_type:
            Whether to achieve downsampling on the time axis
            by strided or dilated convolutions for each layer.
            If left as `None`, strided convolutions will be
            used at each layer. Otherwise, `stride_type` should
            be one element shorter than `layers` and indicate either
            `stride` or `dilation` for each layer after the first.
            Note that `BasicBlock` does not support dilation > 1,
            so `dilation` is generally only usable with the
            bottleneck variant.
        norm_layer:
            Callable mapping a number of channels to the
            normalization module used after every convolution.
            Defaults to `GroupNorm1DGetter()`, which gives every
            channel its own group (instance normalization).

    Raises:
        ValueError:
            If `layers` is empty, if `stride_type` has the wrong
            length, or if an entry of `stride_type` is neither
            `"stride"` nor `"dilation"`.
    """

    block = BasicBlock

    def __init__(
        self,
        in_channels: int,
        layers: List[int],
        classes: int,
        kernel_size: int = 3,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
        norm_layer: Optional[NormLayer] = None,
    ) -> None:
        super().__init__()

        if not layers:
            raise ValueError("'layers' must contain at least one entry")

        self.inplanes = 64
        self.dilation = 1

        # default to using InstanceNorm (one group per channel)
        # if no norm layer is provided explicitly
        self._norm_layer = norm_layer or GroupNorm1DGetter()

        # TODO: should we support passing a single string
        # for simplicity here?
        if stride_type is None:
            # each element indicates if we should replace
            # the stride with a dilated convolution instead
            stride_type = ["stride"] * (len(layers) - 1)
        if len(stride_type) != (len(layers) - 1):
            raise ValueError(
                "'stride_type' should be None or a "
                "{}-element tuple, got {}".format(len(layers) - 1, stride_type)
            )

        self.groups = groups
        self.base_width = width_per_group

        # start with a basic conv-bn-relu-maxpool block
        # to reduce the dimensionality before the heavy
        # lifting starts
        self.conv1 = nn.Conv1d(
            in_channels,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # now create layers of residual blocks where each
        # layer uses the same number of feature maps for
        # all its blocks (some power of 2 times 64).
        # Don't downsample along the time axis in the first
        # layer, but downsample in all the rest (either by
        # striding or dilating depending on the stride_type
        # argument)
        residual_layers = [self._make_layer(64, layers[0], kernel_size)]
        it = zip(layers[1:], stride_type)
        for i, (num_blocks, layer_stride_type) in enumerate(it):
            block_size = 64 * 2 ** (i + 1)
            layer = self._make_layer(
                block_size,
                num_blocks,
                kernel_size,
                stride=2,
                stride_type=layer_stride_type,
            )
            residual_layers.append(layer)
        self.residual_layers = nn.ModuleList(residual_layers)

        # Average pool over each feature map to create a
        # single value for each feature map that we'll use
        # in the fully connected head
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        # Map the pooled feature maps to `classes` outputs.
        # `_make_layer` leaves `self.inplanes` equal to the channel
        # count of the last residual layer (planes * expansion); using
        # it also covers a single-entry `layers`, for which the loop
        # above never runs.
        self.fc = nn.Linear(self.inplanes, classes)

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last norm layer in each residual branch,
        # so that the residual branch starts with zeros.
        # This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
        self,
        planes: int,
        blocks: int,
        kernel_size: int = 3,
        stride: int = 1,
        stride_type: Literal["stride", "dilation"] = "stride",
    ) -> nn.Sequential:
        """
        Build one layer of `blocks` residual blocks of width `planes`.

        With `stride_type == "dilation"` the requested `stride` is
        folded into the running dilation factor instead of striding.
        Only the first block receives the stride and the shortcut
        `downsample` module. Updates `self.inplanes` (and possibly
        `self.dilation`) for the next layer.

        Raises:
            ValueError: If `stride_type` is not "stride" or "dilation".
        """
        block = self.block
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation

        if stride_type == "dilation":
            self.dilation *= stride
            stride = 1
        elif stride_type != "stride":
            raise ValueError(f"Unknown stride type {stride_type}")

        # the shortcut needs its own projection whenever the block
        # changes the time resolution or the number of channels
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                kernel_size,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    kernel_size,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # input stem: conv -> norm -> relu -> maxpool
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for layer in self.residual_layers:
            x = layer(x)

        # (batch, channels, time) -> (batch, channels) -> (batch, classes)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
407
+
408
+
409
# TODO: implement as arg of ResNet instead?
class BottleneckResNet1D(ResNet1D):
    """
    ``ResNet1D`` variant assembled from ``Bottleneck`` blocks rather
    than ``BasicBlock``s. Each residual layer therefore outputs
    ``Bottleneck.expansion`` times as many channels as its base width.
    Constructor arguments are identical to those of ``ResNet1D``.
    """

    # the only difference from the parent class: which block to stack
    block = Bottleneck