braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: the registry flags this release as potentially problematic.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
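The most consequential change for downstream code is the new `braindecode.modules` package (attention, convolution, filter, and other reusable blocks), which replaces the private helpers previously kept in `braindecode/models/modules.py` and `braindecode/models/functions.py`. As a minimal sketch of that migration, here is the `CausalConv1d` import the sleep-stager diff below switches to; the shapes are illustrative, and the left-padding behaviour is assumed to match the removed file-local `_CausalConv1d` helper:

```python
import torch

from braindecode.modules import CausalConv1d  # public in 1.0.0; replaces private copies

# Same call signature the updated SleepStagerEldele2021 uses internally.
conv = CausalConv1d(in_channels=30, out_channels=30, kernel_size=7, stride=1)

x = torch.randn(8, 30, 80)  # (batch, channels, time)
y = conv(x)
assert y.shape == x.shape  # causal conv pads on the left, so length is preserved
```

The rest of this page shows the diff for one representative file.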
braindecode/models/sleep_stager_eldele_2021.py

@@ -3,18 +3,24 @@
 # License: BSD (3-clause)
 
 import math
-import copy
-from copy import deepcopy
 import warnings
+from copy import deepcopy
+from typing import Callable, Optional
 
 import torch
-from torch import nn
 import torch.nn.functional as F
-from .base import EEGModuleMixin, deprecated_args
+from torch import nn
+
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import CausalConv1d
 
 
 class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
-    """Sleep Staging Architecture from Eldele et al 2021.
+    """Sleep Staging Architecture from Eldele et al. (2021) [Eldele2021]_.
+
+    .. figure:: https://raw.githubusercontent.com/emadeldeen24/AttnSleep/refs/heads/main/imgs/AttnSleep.png
+        :align: center
+        :alt: SleepStagerEldele2021 Architecture
 
     Attention based Neural Net for sleep staging as described in [Eldele2021]_.
     The code for the paper and this model is also available at [1]_.
@@ -43,7 +49,7 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         input dimension of the second FC layer in the same.
     n_attn_heads : int
         Number of attention heads. It should be a factor of d_model
-    dropout : float
+    drop_prob : float
         Dropout rate in the PositionWiseFeedforward layer and the TCE layers.
     after_reduced_cnn_size : int
         Number of output channels produced by the convolution in the AFR module.
@@ -55,6 +61,13 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         Alias for `n_outputs`.
     input_size_s : float
         Alias for `input_window_seconds`.
+    activation: nn.Module, default=nn.ReLU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+    activation_mrcnn: nn.Module, default=nn.ReLU
+        Activation function class to apply in the Mask R-CNN layer.
+        Should be a PyTorch activation module class like ``nn.ReLU`` or
+        ``nn.GELU``. Default is ``nn.GELU``.
 
     References
     ----------
@@ -68,28 +81,23 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
     """
 
     def __init__(
-            self,
-            sfreq=None,
-            n_tce=2,
-            d_model=80,
-            d_ff=120,
-            n_attn_heads=5,
-            dropout=0.1,
-            input_window_seconds=None,
-            n_outputs=None,
-            after_reduced_cnn_size=30,
-            return_feats=False,
-            chs_info=None,
-            n_chans=None,
-            n_times=None,
-            n_classes=None,
-            input_size_s=None,
+        self,
+        sfreq=None,
+        n_tce=2,
+        d_model=80,
+        d_ff=120,
+        n_attn_heads=5,
+        drop_prob=0.1,
+        activation_mrcnn: nn.Module = nn.GELU,
+        activation: nn.Module = nn.ReLU,
+        input_window_seconds=None,
+        n_outputs=None,
+        after_reduced_cnn_size=30,
+        return_feats=False,
+        chs_info=None,
+        n_chans=None,
+        n_times=None,
     ):
-        n_outputs, input_window_seconds, = deprecated_args(
-            self,
-            ("n_classes", "n_outputs", n_classes, n_outputs),
-            ("input_size_s", "input_window_seconds", input_size_s, input_window_seconds),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -99,19 +107,25 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
             sfreq=sfreq,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del n_classes, input_size_s
 
         self.mapping = {
             "fc.weight": "final_layer.weight",
-            "fc.bias": "final_layer.bias"
+            "fc.bias": "final_layer.bias",
         }
 
-        if not ((self.input_window_seconds == 30 and self.sfreq == 100 and d_model == 80) or
-                (self.input_window_seconds == 30 and self.sfreq == 125 and d_model == 100)):
-            warnings.warn("This model was designed originally for input windows of 30sec at 100Hz, "
-                          "with d_model at 80 or at 125Hz, with d_model at 100, to use anything "
-                          "other than this may cause errors or cause the model to perform in "
-                          "other ways than intended", UserWarning)
+        if not (
+            (self.input_window_seconds == 30 and self.sfreq == 100 and d_model == 80)
+            or (
+                self.input_window_seconds == 30 and self.sfreq == 125 and d_model == 100
+            )
+        ):
+            warnings.warn(
+                "This model was designed originally for input windows of 30sec at 100Hz, "
+                "with d_model at 80 or at 125Hz, with d_model at 100, to use anything "
+                "other than this may cause errors or cause the model to perform in "
+                "other ways than intended",
+                UserWarning,
+            )
 
         # the usual kernel size for the mrcnn, for sfreq 100
         kernel_size = 7
@@ -119,11 +133,20 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         if self.sfreq == 125:
             kernel_size = 6
 
-        mrcnn = _MRCNN(after_reduced_cnn_size, kernel_size)
+        mrcnn = _MRCNN(
+            after_reduced_cnn_size,
+            kernel_size,
+            activation=activation_mrcnn,
+            activation_se=activation,
+        )
         attn = _MultiHeadedAttention(n_attn_heads, d_model, after_reduced_cnn_size)
-        ff = _PositionwiseFeedForward(d_model, d_ff, dropout)
-        tce = _TCE(_EncoderLayer(d_model, deepcopy(attn), deepcopy(ff),
-                                 after_reduced_cnn_size, dropout), n_tce)
+        ff = _PositionwiseFeedForward(d_model, d_ff, drop_prob, activation=activation)
+        tce = _TCE(
+            _EncoderLayer(
+                d_model, deepcopy(attn), deepcopy(ff), after_reduced_cnn_size, drop_prob
+            ),
+            n_tce,
+        )
 
         self.feature_extractor = nn.Sequential(mrcnn, tce)
         self.len_last_layer = self._len_last_layer(self.n_times)
@@ -133,7 +156,9 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         """if return_feats:
             raise ValueError("return_feat == True is not accepted anymore")"""
         if not return_feats:
-            self.final_layer = nn.Linear(d_model * after_reduced_cnn_size, self.n_outputs)
+            self.final_layer = nn.Linear(
+                d_model * after_reduced_cnn_size, self.n_outputs
+            )
 
     def _len_last_layer(self, input_size):
         self.feature_extractor.eval()
@@ -142,7 +167,7 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         self.feature_extractor.train()
         return len(out.flatten())
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass.
 
@@ -153,27 +178,41 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         """
 
         encoded_features = self.feature_extractor(x)
-        encoded_features = encoded_features.contiguous().view(encoded_features.shape[0], -1)
+        encoded_features = encoded_features.contiguous().view(
+            encoded_features.shape[0], -1
+        )
 
         if self.return_feats:
             return encoded_features
-
-        final_output = self.final_layer(encoded_features)
-        return final_output
+
+        return self.final_layer(encoded_features)
 
 
 class _SELayer(nn.Module):
-    def __init__(self, channel, reduction=16):
+    def __init__(self, channel, reduction=16, activation=nn.ReLU):
         super(_SELayer, self).__init__()
         self.avg_pool = nn.AdaptiveAvgPool1d(1)
         self.fc = nn.Sequential(
             nn.Linear(channel, channel // reduction, bias=False),
-            nn.ReLU(inplace=True),
+            activation(inplace=True),
             nn.Linear(channel // reduction, channel, bias=False),
-            nn.Sigmoid()
+            nn.Sigmoid(),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the SE layer.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, channel, length).
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor after applying the SE recalibration.
+        """
         b, c, _ = x.size()
         y = self.avg_pool(x).view(b, c)
         y = self.fc(y).view(b, c, 1)
@@ -183,22 +222,43 @@ class _SELayer(nn.Module):
 class _SEBasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None,
-                 *,
-                 reduction=16):
+    def __init__(
+        self,
+        inplanes,
+        planes,
+        stride=1,
+        downsample=None,
+        activation: nn.Module = nn.ReLU,
+        *,
+        reduction=16,
+    ):
         super(_SEBasicBlock, self).__init__()
         self.conv1 = nn.Conv1d(inplanes, planes, stride)
         self.bn1 = nn.BatchNorm1d(planes)
-        self.relu = nn.ReLU(inplace=True)
+        self.relu = activation(inplace=True)
         self.conv2 = nn.Conv1d(planes, planes, 1)
         self.bn2 = nn.BatchNorm1d(planes)
         self.se = _SELayer(planes, reduction)
         self.downsample = downsample
         self.stride = stride
-        self.features = nn.Sequential(self.conv1, self.bn1, self.relu,
-                                      self.conv2, self.bn2, self.se)
+        self.features = nn.Sequential(
+            self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.se
+        )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the SE layer.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, n_chans, n_times).
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor after applying the SE recalibration.
+        """
         residual = x
         out = self.features(x)
 
@@ -212,26 +272,29 @@ class _SEBasicBlock(nn.Module):
 
 
 class _MRCNN(nn.Module):
-    def __init__(self, after_reduced_cnn_size, kernel_size=7):
+    def __init__(
+        self,
+        after_reduced_cnn_size,
+        kernel_size=7,
+        activation: nn.Module = nn.GELU,
+        activation_se: nn.Module = nn.ReLU,
+    ):
         super(_MRCNN, self).__init__()
         drate = 0.5
-        self.GELU = nn.GELU()
+        self.GELU = activation()
         self.features1 = nn.Sequential(
             nn.Conv1d(1, 64, kernel_size=50, stride=6, bias=False, padding=24),
             nn.BatchNorm1d(64),
             self.GELU,
             nn.MaxPool1d(kernel_size=8, stride=2, padding=4),
             nn.Dropout(drate),
-
             nn.Conv1d(64, 128, kernel_size=8, stride=1, bias=False, padding=4),
             nn.BatchNorm1d(128),
             self.GELU,
-
             nn.Conv1d(128, 128, kernel_size=8, stride=1, bias=False, padding=4),
             nn.BatchNorm1d(128),
             self.GELU,
-
-            nn.MaxPool1d(kernel_size=4, stride=4, padding=2)
+            nn.MaxPool1d(kernel_size=4, stride=4, padding=2),
         )
 
         self.features2 = nn.Sequential(
@@ -240,28 +303,38 @@ class _MRCNN(nn.Module):
             self.GELU,
             nn.MaxPool1d(kernel_size=4, stride=2, padding=2),
             nn.Dropout(drate),
-
-            nn.Conv1d(64, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3),
+            nn.Conv1d(
+                64, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
+            ),
             nn.BatchNorm1d(128),
             self.GELU,
-
-            nn.Conv1d(128, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3),
+            nn.Conv1d(
+                128, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
+            ),
             nn.BatchNorm1d(128),
             self.GELU,
-
-            nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
+            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
         )
 
         self.dropout = nn.Dropout(drate)
         self.inplanes = 128
-        self.AFR = self._make_layer(_SEBasicBlock, after_reduced_cnn_size, 1)
+        self.AFR = self._make_layer(
+            _SEBasicBlock, after_reduced_cnn_size, 1, activate=activation_se
+        )
 
-    def _make_layer(self, block, planes, blocks, stride=1):  # makes residual SE block
+    def _make_layer(
+        self, block, planes, blocks, stride=1, activate: nn.Module = nn.ReLU
+    ):  # makes residual SE block
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
-                nn.Conv1d(self.inplanes, planes * block.expansion,
-                          kernel_size=1, stride=stride, bias=False),
+                nn.Conv1d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
                 nn.BatchNorm1d(planes * block.expansion),
             )
 
@@ -269,11 +342,11 @@ class _MRCNN(nn.Module):
         layers.append(block(self.inplanes, planes, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
+            layers.append(block(self.inplanes, planes, activate=activate))
 
         return nn.Sequential(*layers)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x1 = self.features1(x)
         x2 = self.features2(x)
         x_concat = torch.cat((x1, x2), dim=2)
@@ -285,93 +358,107 @@ class _MRCNN(nn.Module):
 ##########################################################################################
 
 
-def _attention(query, key, value, dropout=None):
+def _attention(
+    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Implementation of Scaled dot product attention"""
     # d_k - dimension of the query and key vectors
    d_k = query.size(-1)
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
-
-    p_attn = F.softmax(scores, dim=-1)
-    if dropout is not None:
-        p_attn = dropout(p_attn)
-    return torch.matmul(p_attn, value), p_attn
-
-
-class _CausalConv1d(torch.nn.Conv1d):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 dilation=1,
-                 groups=1,
-                 bias=True):
-        self.__padding = (kernel_size - 1) * dilation
-
-        super(_CausalConv1d, self).__init__(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=self.__padding,
-            dilation=dilation,
-            groups=groups,
-            bias=bias)
-
-    def forward(self, input):
-        result = super(_CausalConv1d, self).forward(input)
-        if self.__padding != 0:
-            return result[:, :, :-self.__padding]
-        return result
+    p_attn = F.softmax(scores, dim=-1)  # attention weights
+    output = torch.matmul(p_attn, value)  # (B, h, T, d_k)
+    return output, p_attn
 
 
 class _MultiHeadedAttention(nn.Module):
     def __init__(self, h, d_model, after_reduced_cnn_size, dropout=0.1):
         """Take in model size and number of heads."""
-        super(_MultiHeadedAttention, self).__init__()
+        super().__init__()
         assert d_model % h == 0
         self.d_per_head = d_model // h
         self.h = h
 
-        self.convs = _clones(_CausalConv1d(after_reduced_cnn_size, after_reduced_cnn_size,
-                                           kernel_size=7, stride=1), 3)
+        base_conv = CausalConv1d(
+            in_channels=after_reduced_cnn_size,
+            out_channels=after_reduced_cnn_size,
+            kernel_size=7,
+            stride=1,
+        )
+        self.convs = nn.ModuleList([deepcopy(base_conv) for _ in range(3)])
+
         self.linear = nn.Linear(d_model, d_model)
         self.dropout = nn.Dropout(p=dropout)
 
-    def forward(self, query, key, value):
+    def forward(self, query, key, value: torch.Tensor) -> torch.Tensor:
         """Implements Multi-head attention"""
         nbatches = query.size(0)
 
         query = query.view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
-        key = self.convs[1](key).view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
-        value = self.convs[2](value).view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
+        key = (
+            self.convs[1](key)
+            .view(nbatches, -1, self.h, self.d_per_head)
+            .transpose(1, 2)
+        )
+        value = (
+            self.convs[2](value)
+            .view(nbatches, -1, self.h, self.d_per_head)
+            .transpose(1, 2)
+        )
+
+        x_raw, attn_weights = _attention(query, key, value)
+        # apply dropout to the *weights*
+        attn = self.dropout(attn_weights)
+        # recompute the weighted sum with dropped weights
+        x = torch.matmul(attn, value)
 
-        x, self.attn = _attention(query, key, value, dropout=self.dropout)
+        # stash the pre-dropout weights if you need them
+        self.attn = attn_weights
 
-        x = x.transpose(1, 2).contiguous() \
-            .view(nbatches, -1, self.h * self.d_per_head)
+        # merge heads and project
+        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_per_head)
 
         return self.linear(x)
 
 
-class _SublayerOutput(nn.Module):
+class _ResidualLayerNormAttn(nn.Module):
     """
     A residual connection followed by a layer norm.
     """
 
-    def __init__(self, size, dropout):
-        super(_SublayerOutput, self).__init__()
+    def __init__(self, size, dropout, fn_attn):
+        super().__init__()
+        self.norm = nn.LayerNorm(size, eps=1e-6)
+        self.dropout = nn.Dropout(dropout)
+        self.fn_attn = fn_attn
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply residual connection to any sublayer with the same size."""
+        x_norm = self.norm(x)
+
+        out = self.fn_attn(x_norm, key, value)
+
+        return x + self.dropout(out)
+
+
+class _ResidualLayerNormFF(nn.Module):
+    def __init__(self, size, dropout, fn_ff):
+        super().__init__()
         self.norm = nn.LayerNorm(size, eps=1e-6)
         self.dropout = nn.Dropout(dropout)
+        self.fn_ff = fn_ff
 
-    def forward(self, x, sublayer):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply residual connection to any sublayer with the same size."""
-        return x + self.dropout(sublayer(self.norm(x)))
+        x_norm = self.norm(x)
 
+        out = self.fn_ff(x_norm)
 
-def _clones(module, n):
-    """Produce n identical layers."""
-    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
+        return x + self.dropout(out)
 
 
 class _TCE(nn.Module):
@@ -381,11 +468,13 @@ class _TCE(nn.Module):
     """
 
     def __init__(self, layer, n):
-        super(_TCE, self).__init__()
-        self.layers = _clones(layer, n)
+        super().__init__()
+
+        self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])
+
         self.norm = nn.LayerNorm(layer.size, eps=1e-6)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         for layer in self.layers:
             x = layer(x)
         return self.norm(x)
@@ -395,35 +484,53 @@ class _EncoderLayer(nn.Module):
     """
     An encoder layer
     Made up of self-attention and a feed forward layer.
-    Each of these sublayers have residual and layer norm, implemented by _SublayerOutput.
+    Each of these sublayers have residual and layer norm, implemented by _ResidualLayerNorm.
     """
 
     def __init__(self, size, self_attn, feed_forward, after_reduced_cnn_size, dropout):
-        super(_EncoderLayer, self).__init__()
+        super().__init__()
+        self.size = size
         self.self_attn = self_attn
         self.feed_forward = feed_forward
-        self.sublayer_output = _clones(_SublayerOutput(size, dropout), 2)
-        self.size = size
-        self.conv = _CausalConv1d(after_reduced_cnn_size, after_reduced_cnn_size, kernel_size=7,
-                                  stride=1, dilation=1)
 
-    def forward(self, x_in):
+        self.residual_self_attn = _ResidualLayerNormAttn(
+            size=size,
+            dropout=dropout,
+            fn_attn=self_attn,
+        )
+        self.residual_ff = _ResidualLayerNormFF(
+            size=size,
+            dropout=dropout,
+            fn_ff=feed_forward,
+        )
+
+        self.conv = CausalConv1d(
+            in_channels=after_reduced_cnn_size,
+            out_channels=after_reduced_cnn_size,
+            kernel_size=7,
+            stride=1,
+            dilation=1,
+        )
+
+    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
         """Transformer Encoder"""
         query = self.conv(x_in)
         # Encoder self-attention
-        x = self.sublayer_output[0](query, lambda x: self.self_attn(query, x_in, x_in))
-        return self.sublayer_output[1](x, self.feed_forward)
+        x = self.residual_self_attn(query, x_in, x_in)
+        x_ff = self.residual_ff(x)
+        return x_ff
 
 
 class _PositionwiseFeedForward(nn.Module):
     """Positionwise feed-forward network."""
 
-    def __init__(self, d_model, d_ff, dropout=0.1):
-        super(_PositionwiseFeedForward, self).__init__()
+    def __init__(self, d_model, d_ff, dropout=0.1, activation: nn.Module = nn.ReLU):
+        super().__init__()
         self.w_1 = nn.Linear(d_model, d_ff)
         self.w_2 = nn.Linear(d_ff, d_model)
         self.dropout = nn.Dropout(dropout)
+        self.activate = activation()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Implements FFN equation."""
-        return self.w_2(self.dropout(F.relu(self.w_1(x))))
+        return self.w_2(self.dropout(self.activate(self.w_1(x))))