braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -3,18 +3,23 @@
|
|
|
3
3
|
# License: BSD (3-clause)
|
|
4
4
|
|
|
5
5
|
import math
|
|
6
|
-
import copy
|
|
7
|
-
from copy import deepcopy
|
|
8
6
|
import warnings
|
|
7
|
+
from copy import deepcopy
|
|
9
8
|
|
|
10
9
|
import torch
|
|
11
|
-
from torch import nn
|
|
12
10
|
import torch.nn.functional as F
|
|
13
|
-
from
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
from braindecode.models.base import EEGModuleMixin
|
|
14
|
+
from braindecode.modules import CausalConv1d
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
|
|
17
|
-
"""Sleep Staging Architecture from Eldele et al 2021.
|
|
18
|
+
"""Sleep Staging Architecture from Eldele et al. (2021) [Eldele2021]_.
|
|
19
|
+
|
|
20
|
+
.. figure:: https://raw.githubusercontent.com/emadeldeen24/AttnSleep/refs/heads/main/imgs/AttnSleep.png
|
|
21
|
+
:align: center
|
|
22
|
+
:alt: SleepStagerEldele2021 Architecture
|
|
18
23
|
|
|
19
24
|
Attention based Neural Net for sleep staging as described in [Eldele2021]_.
|
|
20
25
|
The code for the paper and this model is also available at [1]_.
|
|
@@ -43,7 +48,7 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
|
|
|
43
48
|
input dimension of the second FC layer in the same.
|
|
44
49
|
n_attn_heads : int
|
|
45
50
|
Number of attention heads. It should be a factor of d_model
|
|
46
|
-
|
|
51
|
+
drop_prob : float
|
|
47
52
|
Dropout rate in the PositionWiseFeedforward layer and the TCE layers.
|
|
48
53
|
after_reduced_cnn_size : int
|
|
49
54
|
Number of output channels produced by the convolution in the AFR module.
|
|
@@ -55,6 +60,13 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
|
|
|
55
60
|
Alias for `n_outputs`.
|
|
56
61
|
input_size_s : float
|
|
57
62
|
Alias for `input_window_seconds`.
|
|
63
|
+
activation : nn.Module, default=nn.ReLU
|
|
64
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
65
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
|
|
66
|
+
activation_mrcnn : nn.Module, default=nn.GELU
|
|
67
|
+
Activation function class to apply in the multi-resolution CNN (MRCNN) module.
|
|
68
|
+
Should be a PyTorch activation module class like ``nn.ReLU`` or
|
|
69
|
+
``nn.GELU``. Default is ``nn.GELU``.
|
|
58
70
|
|
|
59
71
|
References
|
|
60
72
|
----------
|
|
@@ -68,28 +80,23 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
|
|
|
68
80
|
"""
|
|
69
81
|
|
|
70
82
|
def __init__(
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
83
|
+
self,
|
|
84
|
+
sfreq=None,
|
|
85
|
+
n_tce=2,
|
|
86
|
+
d_model=80,
|
|
87
|
+
d_ff=120,
|
|
88
|
+
n_attn_heads=5,
|
|
89
|
+
drop_prob=0.1,
|
|
90
|
+
activation_mrcnn: nn.Module = nn.GELU,
|
|
91
|
+
activation: nn.Module = nn.ReLU,
|
|
92
|
+
input_window_seconds=None,
|
|
93
|
+
n_outputs=None,
|
|
94
|
+
after_reduced_cnn_size=30,
|
|
95
|
+
return_feats=False,
|
|
96
|
+
chs_info=None,
|
|
97
|
+
n_chans=None,
|
|
98
|
+
n_times=None,
|
|
87
99
|
):
|
|
88
|
-
n_outputs, input_window_seconds, = deprecated_args(
|
|
89
|
-
self,
|
|
90
|
-
("n_classes", "n_outputs", n_classes, n_outputs),
|
|
91
|
-
("input_size_s", "input_window_seconds", input_size_s, input_window_seconds),
|
|
92
|
-
)
|
|
93
100
|
super().__init__(
|
|
94
101
|
n_outputs=n_outputs,
|
|
95
102
|
n_chans=n_chans,
|
|
@@ -99,19 +106,25 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
|
|
|
99
106
|
sfreq=sfreq,
|
|
100
107
|
)
|
|
101
108
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
102
|
-
del n_classes, input_size_s
|
|
103
109
|
|
|
104
110
|
self.mapping = {
|
|
105
111
|
"fc.weight": "final_layer.weight",
|
|
106
|
-
"fc.bias": "final_layer.bias"
|
|
112
|
+
"fc.bias": "final_layer.bias",
|
|
107
113
|
}
|
|
108
114
|
|
|
109
|
-
if not (
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
+
if not (
|
|
116
|
+
(self.input_window_seconds == 30 and self.sfreq == 100 and d_model == 80)
|
|
117
|
+
or (
|
|
118
|
+
self.input_window_seconds == 30 and self.sfreq == 125 and d_model == 100
|
|
119
|
+
)
|
|
120
|
+
):
|
|
121
|
+
warnings.warn(
|
|
122
|
+
"This model was designed originally for input windows of 30sec at 100Hz, "
|
|
123
|
+
"with d_model at 80 or at 125Hz, with d_model at 100, to use anything "
|
|
124
|
+
"other than this may cause errors or cause the model to perform in "
|
|
125
|
+
"other ways than intended",
|
|
126
|
+
UserWarning,
|
|
127
|
+
)
|
|
115
128
|
|
|
116
129
|
# the usual kernel size for the mrcnn, for sfreq 100
|
|
117
130
|
kernel_size = 7
|
|
@@ -119,11 +132,20 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
|
|
|
119
132
|
if self.sfreq == 125:
|
|
120
133
|
kernel_size = 6
|
|
121
134
|
|
|
122
|
-
mrcnn = _MRCNN(
|
|
135
|
+
mrcnn = _MRCNN(
|
|
136
|
+
after_reduced_cnn_size,
|
|
137
|
+
kernel_size,
|
|
138
|
+
activation=activation_mrcnn,
|
|
139
|
+
activation_se=activation,
|
|
140
|
+
)
|
|
123
141
|
attn = _MultiHeadedAttention(n_attn_heads, d_model, after_reduced_cnn_size)
|
|
124
|
-
ff = _PositionwiseFeedForward(d_model, d_ff,
|
|
125
|
-
tce = _TCE(
|
|
126
|
-
|
|
142
|
+
ff = _PositionwiseFeedForward(d_model, d_ff, drop_prob, activation=activation)
|
|
143
|
+
tce = _TCE(
|
|
144
|
+
_EncoderLayer(
|
|
145
|
+
d_model, deepcopy(attn), deepcopy(ff), after_reduced_cnn_size, drop_prob
|
|
146
|
+
),
|
|
147
|
+
n_tce,
|
|
148
|
+
)
|
|
127
149
|
|
|
128
150
|
self.feature_extractor = nn.Sequential(mrcnn, tce)
|
|
129
151
|
self.len_last_layer = self._len_last_layer(self.n_times)
|
|
@@ -133,7 +155,9 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
|
|
|
133
155
|
"""if return_feats:
|
|
134
156
|
raise ValueError("return_feat == True is not accepted anymore")"""
|
|
135
157
|
if not return_feats:
|
|
136
|
-
self.final_layer = nn.Linear(
|
|
158
|
+
self.final_layer = nn.Linear(
|
|
159
|
+
d_model * after_reduced_cnn_size, self.n_outputs
|
|
160
|
+
)
|
|
137
161
|
|
|
138
162
|
def _len_last_layer(self, input_size):
|
|
139
163
|
self.feature_extractor.eval()
|
|
@@ -142,7 +166,7 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
|
|
|
142
166
|
self.feature_extractor.train()
|
|
143
167
|
return len(out.flatten())
|
|
144
168
|
|
|
145
|
-
def forward(self, x):
|
|
169
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
146
170
|
"""
|
|
147
171
|
Forward pass.
|
|
148
172
|
|
|
@@ -153,27 +177,41 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
|
|
|
153
177
|
"""
|
|
154
178
|
|
|
155
179
|
encoded_features = self.feature_extractor(x)
|
|
156
|
-
encoded_features = encoded_features.contiguous().view(
|
|
180
|
+
encoded_features = encoded_features.contiguous().view(
|
|
181
|
+
encoded_features.shape[0], -1
|
|
182
|
+
)
|
|
157
183
|
|
|
158
184
|
if self.return_feats:
|
|
159
185
|
return encoded_features
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
return final_output
|
|
186
|
+
|
|
187
|
+
return self.final_layer(encoded_features)
|
|
163
188
|
|
|
164
189
|
|
|
165
190
|
class _SELayer(nn.Module):
|
|
166
|
-
def __init__(self, channel, reduction=16):
|
|
191
|
+
def __init__(self, channel, reduction=16, activation=nn.ReLU):
|
|
167
192
|
super(_SELayer, self).__init__()
|
|
168
193
|
self.avg_pool = nn.AdaptiveAvgPool1d(1)
|
|
169
194
|
self.fc = nn.Sequential(
|
|
170
195
|
nn.Linear(channel, channel // reduction, bias=False),
|
|
171
|
-
|
|
196
|
+
activation(inplace=True),
|
|
172
197
|
nn.Linear(channel // reduction, channel, bias=False),
|
|
173
|
-
nn.Sigmoid()
|
|
198
|
+
nn.Sigmoid(),
|
|
174
199
|
)
|
|
175
200
|
|
|
176
|
-
def forward(self, x):
|
|
201
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
202
|
+
"""
|
|
203
|
+
Forward pass of the SE layer.
|
|
204
|
+
|
|
205
|
+
Parameters
|
|
206
|
+
----------
|
|
207
|
+
x : torch.Tensor
|
|
208
|
+
Input tensor of shape (batch_size, channel, length).
|
|
209
|
+
|
|
210
|
+
Returns
|
|
211
|
+
-------
|
|
212
|
+
torch.Tensor
|
|
213
|
+
Output tensor after applying the SE recalibration.
|
|
214
|
+
"""
|
|
177
215
|
b, c, _ = x.size()
|
|
178
216
|
y = self.avg_pool(x).view(b, c)
|
|
179
217
|
y = self.fc(y).view(b, c, 1)
|
|
@@ -183,22 +221,43 @@ class _SELayer(nn.Module):
|
|
|
183
221
|
class _SEBasicBlock(nn.Module):
|
|
184
222
|
expansion = 1
|
|
185
223
|
|
|
186
|
-
def __init__(
|
|
187
|
-
|
|
188
|
-
|
|
224
|
+
def __init__(
|
|
225
|
+
self,
|
|
226
|
+
inplanes,
|
|
227
|
+
planes,
|
|
228
|
+
stride=1,
|
|
229
|
+
downsample=None,
|
|
230
|
+
activation: nn.Module = nn.ReLU,
|
|
231
|
+
*,
|
|
232
|
+
reduction=16,
|
|
233
|
+
):
|
|
189
234
|
super(_SEBasicBlock, self).__init__()
|
|
190
235
|
self.conv1 = nn.Conv1d(inplanes, planes, stride)
|
|
191
236
|
self.bn1 = nn.BatchNorm1d(planes)
|
|
192
|
-
self.relu =
|
|
237
|
+
self.relu = activation(inplace=True)
|
|
193
238
|
self.conv2 = nn.Conv1d(planes, planes, 1)
|
|
194
239
|
self.bn2 = nn.BatchNorm1d(planes)
|
|
195
240
|
self.se = _SELayer(planes, reduction)
|
|
196
241
|
self.downsample = downsample
|
|
197
242
|
self.stride = stride
|
|
198
|
-
self.features = nn.Sequential(
|
|
199
|
-
|
|
243
|
+
self.features = nn.Sequential(
|
|
244
|
+
self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.se
|
|
245
|
+
)
|
|
200
246
|
|
|
201
|
-
def forward(self, x):
|
|
247
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
248
|
+
"""
|
|
249
|
+
Forward pass of the SE basic block.
|
|
250
|
+
|
|
251
|
+
Parameters
|
|
252
|
+
----------
|
|
253
|
+
x : torch.Tensor
|
|
254
|
+
Input tensor of shape (batch_size, inplanes, length).
|
|
255
|
+
|
|
256
|
+
Returns
|
|
257
|
+
-------
|
|
258
|
+
torch.Tensor
|
|
259
|
+
Output tensor after applying the SE recalibration.
|
|
260
|
+
"""
|
|
202
261
|
residual = x
|
|
203
262
|
out = self.features(x)
|
|
204
263
|
|
|
@@ -212,26 +271,29 @@ class _SEBasicBlock(nn.Module):
|
|
|
212
271
|
|
|
213
272
|
|
|
214
273
|
class _MRCNN(nn.Module):
|
|
215
|
-
def __init__(
|
|
274
|
+
def __init__(
|
|
275
|
+
self,
|
|
276
|
+
after_reduced_cnn_size,
|
|
277
|
+
kernel_size=7,
|
|
278
|
+
activation: nn.Module = nn.GELU,
|
|
279
|
+
activation_se: nn.Module = nn.ReLU,
|
|
280
|
+
):
|
|
216
281
|
super(_MRCNN, self).__init__()
|
|
217
282
|
drate = 0.5
|
|
218
|
-
self.GELU =
|
|
283
|
+
self.GELU = activation()
|
|
219
284
|
self.features1 = nn.Sequential(
|
|
220
285
|
nn.Conv1d(1, 64, kernel_size=50, stride=6, bias=False, padding=24),
|
|
221
286
|
nn.BatchNorm1d(64),
|
|
222
287
|
self.GELU,
|
|
223
288
|
nn.MaxPool1d(kernel_size=8, stride=2, padding=4),
|
|
224
289
|
nn.Dropout(drate),
|
|
225
|
-
|
|
226
290
|
nn.Conv1d(64, 128, kernel_size=8, stride=1, bias=False, padding=4),
|
|
227
291
|
nn.BatchNorm1d(128),
|
|
228
292
|
self.GELU,
|
|
229
|
-
|
|
230
293
|
nn.Conv1d(128, 128, kernel_size=8, stride=1, bias=False, padding=4),
|
|
231
294
|
nn.BatchNorm1d(128),
|
|
232
295
|
self.GELU,
|
|
233
|
-
|
|
234
|
-
nn.MaxPool1d(kernel_size=4, stride=4, padding=2)
|
|
296
|
+
nn.MaxPool1d(kernel_size=4, stride=4, padding=2),
|
|
235
297
|
)
|
|
236
298
|
|
|
237
299
|
self.features2 = nn.Sequential(
|
|
@@ -240,28 +302,38 @@ class _MRCNN(nn.Module):
|
|
|
240
302
|
self.GELU,
|
|
241
303
|
nn.MaxPool1d(kernel_size=4, stride=2, padding=2),
|
|
242
304
|
nn.Dropout(drate),
|
|
243
|
-
|
|
244
|
-
|
|
305
|
+
nn.Conv1d(
|
|
306
|
+
64, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
|
|
307
|
+
),
|
|
245
308
|
nn.BatchNorm1d(128),
|
|
246
309
|
self.GELU,
|
|
247
|
-
|
|
248
|
-
|
|
310
|
+
nn.Conv1d(
|
|
311
|
+
128, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
|
|
312
|
+
),
|
|
249
313
|
nn.BatchNorm1d(128),
|
|
250
314
|
self.GELU,
|
|
251
|
-
|
|
252
|
-
nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
|
|
315
|
+
nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
|
|
253
316
|
)
|
|
254
317
|
|
|
255
318
|
self.dropout = nn.Dropout(drate)
|
|
256
319
|
self.inplanes = 128
|
|
257
|
-
self.AFR = self._make_layer(
|
|
320
|
+
self.AFR = self._make_layer(
|
|
321
|
+
_SEBasicBlock, after_reduced_cnn_size, 1, activate=activation_se
|
|
322
|
+
)
|
|
258
323
|
|
|
259
|
-
def _make_layer(
|
|
324
|
+
def _make_layer(
|
|
325
|
+
self, block, planes, blocks, stride=1, activate: nn.Module = nn.ReLU
|
|
326
|
+
): # makes residual SE block
|
|
260
327
|
downsample = None
|
|
261
328
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
|
262
329
|
downsample = nn.Sequential(
|
|
263
|
-
nn.Conv1d(
|
|
264
|
-
|
|
330
|
+
nn.Conv1d(
|
|
331
|
+
self.inplanes,
|
|
332
|
+
planes * block.expansion,
|
|
333
|
+
kernel_size=1,
|
|
334
|
+
stride=stride,
|
|
335
|
+
bias=False,
|
|
336
|
+
),
|
|
265
337
|
nn.BatchNorm1d(planes * block.expansion),
|
|
266
338
|
)
|
|
267
339
|
|
|
@@ -269,11 +341,11 @@ class _MRCNN(nn.Module):
|
|
|
269
341
|
layers.append(block(self.inplanes, planes, stride, downsample))
|
|
270
342
|
self.inplanes = planes * block.expansion
|
|
271
343
|
for i in range(1, blocks):
|
|
272
|
-
layers.append(block(self.inplanes, planes))
|
|
344
|
+
layers.append(block(self.inplanes, planes, activate=activate))
|
|
273
345
|
|
|
274
346
|
return nn.Sequential(*layers)
|
|
275
347
|
|
|
276
|
-
def forward(self, x):
|
|
348
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
277
349
|
x1 = self.features1(x)
|
|
278
350
|
x2 = self.features2(x)
|
|
279
351
|
x_concat = torch.cat((x1, x2), dim=2)
|
|
@@ -285,93 +357,107 @@ class _MRCNN(nn.Module):
|
|
|
285
357
|
##########################################################################################
|
|
286
358
|
|
|
287
359
|
|
|
288
|
-
def _attention(
|
|
360
|
+
def _attention(
|
|
361
|
+
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
|
362
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
289
363
|
"""Implementation of Scaled dot product attention"""
|
|
290
364
|
# d_k - dimension of the query and key vectors
|
|
291
365
|
d_k = query.size(-1)
|
|
292
366
|
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
p_attn = dropout(p_attn)
|
|
297
|
-
return torch.matmul(p_attn, value), p_attn
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
class _CausalConv1d(torch.nn.Conv1d):
|
|
301
|
-
def __init__(self,
|
|
302
|
-
in_channels,
|
|
303
|
-
out_channels,
|
|
304
|
-
kernel_size,
|
|
305
|
-
stride=1,
|
|
306
|
-
dilation=1,
|
|
307
|
-
groups=1,
|
|
308
|
-
bias=True):
|
|
309
|
-
self.__padding = (kernel_size - 1) * dilation
|
|
310
|
-
|
|
311
|
-
super(_CausalConv1d, self).__init__(
|
|
312
|
-
in_channels,
|
|
313
|
-
out_channels,
|
|
314
|
-
kernel_size=kernel_size,
|
|
315
|
-
stride=stride,
|
|
316
|
-
padding=self.__padding,
|
|
317
|
-
dilation=dilation,
|
|
318
|
-
groups=groups,
|
|
319
|
-
bias=bias)
|
|
320
|
-
|
|
321
|
-
def forward(self, input):
|
|
322
|
-
result = super(_CausalConv1d, self).forward(input)
|
|
323
|
-
if self.__padding != 0:
|
|
324
|
-
return result[:, :, :-self.__padding]
|
|
325
|
-
return result
|
|
367
|
+
p_attn = F.softmax(scores, dim=-1) # attention weights
|
|
368
|
+
output = torch.matmul(p_attn, value) # (B, h, T, d_k)
|
|
369
|
+
return output, p_attn
|
|
326
370
|
|
|
327
371
|
|
|
328
372
|
class _MultiHeadedAttention(nn.Module):
|
|
329
373
|
def __init__(self, h, d_model, after_reduced_cnn_size, dropout=0.1):
|
|
330
374
|
"""Take in model size and number of heads."""
|
|
331
|
-
super(
|
|
375
|
+
super().__init__()
|
|
332
376
|
assert d_model % h == 0
|
|
333
377
|
self.d_per_head = d_model // h
|
|
334
378
|
self.h = h
|
|
335
379
|
|
|
336
|
-
|
|
337
|
-
|
|
380
|
+
base_conv = CausalConv1d(
|
|
381
|
+
in_channels=after_reduced_cnn_size,
|
|
382
|
+
out_channels=after_reduced_cnn_size,
|
|
383
|
+
kernel_size=7,
|
|
384
|
+
stride=1,
|
|
385
|
+
)
|
|
386
|
+
self.convs = nn.ModuleList([deepcopy(base_conv) for _ in range(3)])
|
|
387
|
+
|
|
338
388
|
self.linear = nn.Linear(d_model, d_model)
|
|
339
389
|
self.dropout = nn.Dropout(p=dropout)
|
|
340
390
|
|
|
341
|
-
def forward(self, query, key, value):
|
|
391
|
+
def forward(self, query, key, value: torch.Tensor) -> torch.Tensor:
|
|
342
392
|
"""Implements Multi-head attention"""
|
|
343
393
|
nbatches = query.size(0)
|
|
344
394
|
|
|
345
395
|
query = query.view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
|
|
346
|
-
key =
|
|
347
|
-
|
|
396
|
+
key = (
|
|
397
|
+
self.convs[1](key)
|
|
398
|
+
.view(nbatches, -1, self.h, self.d_per_head)
|
|
399
|
+
.transpose(1, 2)
|
|
400
|
+
)
|
|
401
|
+
value = (
|
|
402
|
+
self.convs[2](value)
|
|
403
|
+
.view(nbatches, -1, self.h, self.d_per_head)
|
|
404
|
+
.transpose(1, 2)
|
|
405
|
+
)
|
|
406
|
+
|
|
407
|
+
x_raw, attn_weights = _attention(query, key, value)
|
|
408
|
+
# apply dropout to the *weights*
|
|
409
|
+
attn = self.dropout(attn_weights)
|
|
410
|
+
# recompute the weighted sum with dropped weights
|
|
411
|
+
x = torch.matmul(attn, value)
|
|
348
412
|
|
|
349
|
-
|
|
413
|
+
# stash the pre-dropout weights if you need them
|
|
414
|
+
self.attn = attn_weights
|
|
350
415
|
|
|
351
|
-
|
|
352
|
-
|
|
416
|
+
# merge heads and project
|
|
417
|
+
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_per_head)
|
|
353
418
|
|
|
354
419
|
return self.linear(x)
|
|
355
420
|
|
|
356
421
|
|
|
357
|
-
class
|
|
422
|
+
class _ResidualLayerNormAttn(nn.Module):
|
|
358
423
|
"""
|
|
359
424
|
Layer norm applied to the input, then the attention sublayer, wrapped in a residual connection.
|
|
360
425
|
"""
|
|
361
426
|
|
|
362
|
-
def __init__(self, size, dropout):
|
|
363
|
-
super(
|
|
427
|
+
def __init__(self, size, dropout, fn_attn):
|
|
428
|
+
super().__init__()
|
|
429
|
+
self.norm = nn.LayerNorm(size, eps=1e-6)
|
|
430
|
+
self.dropout = nn.Dropout(dropout)
|
|
431
|
+
self.fn_attn = fn_attn
|
|
432
|
+
|
|
433
|
+
def forward(
|
|
434
|
+
self,
|
|
435
|
+
x: torch.Tensor,
|
|
436
|
+
key: torch.Tensor,
|
|
437
|
+
value: torch.Tensor,
|
|
438
|
+
) -> torch.Tensor:
|
|
439
|
+
"""Apply residual connection to any sublayer with the same size."""
|
|
440
|
+
x_norm = self.norm(x)
|
|
441
|
+
|
|
442
|
+
out = self.fn_attn(x_norm, key, value)
|
|
443
|
+
|
|
444
|
+
return x + self.dropout(out)
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
class _ResidualLayerNormFF(nn.Module):
|
|
448
|
+
def __init__(self, size, dropout, fn_ff):
|
|
449
|
+
super().__init__()
|
|
364
450
|
self.norm = nn.LayerNorm(size, eps=1e-6)
|
|
365
451
|
self.dropout = nn.Dropout(dropout)
|
|
452
|
+
self.fn_ff = fn_ff
|
|
366
453
|
|
|
367
|
-
def forward(self, x
|
|
454
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
368
455
|
"""Apply residual connection to any sublayer with the same size."""
|
|
369
|
-
|
|
456
|
+
x_norm = self.norm(x)
|
|
370
457
|
|
|
458
|
+
out = self.fn_ff(x_norm)
|
|
371
459
|
|
|
372
|
-
|
|
373
|
-
"""Produce n identical layers."""
|
|
374
|
-
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
|
|
460
|
+
return x + self.dropout(out)
|
|
375
461
|
|
|
376
462
|
|
|
377
463
|
class _TCE(nn.Module):
|
|
@@ -381,11 +467,13 @@ class _TCE(nn.Module):
|
|
|
381
467
|
"""
|
|
382
468
|
|
|
383
469
|
def __init__(self, layer, n):
|
|
384
|
-
super(
|
|
385
|
-
|
|
470
|
+
super().__init__()
|
|
471
|
+
|
|
472
|
+
self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])
|
|
473
|
+
|
|
386
474
|
self.norm = nn.LayerNorm(layer.size, eps=1e-6)
|
|
387
475
|
|
|
388
|
-
def forward(self, x):
|
|
476
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
389
477
|
for layer in self.layers:
|
|
390
478
|
x = layer(x)
|
|
391
479
|
return self.norm(x)
|
|
@@ -395,35 +483,53 @@ class _EncoderLayer(nn.Module):
|
|
|
395
483
|
"""
|
|
396
484
|
An encoder layer
|
|
397
485
|
Made up of self-attention and a feed forward layer.
|
|
398
|
-
Each of these sublayers have residual and layer norm, implemented by
|
|
486
|
+
Each of these sublayers has a residual connection and layer norm, implemented by _ResidualLayerNormAttn and _ResidualLayerNormFF.
|
|
399
487
|
"""
|
|
400
488
|
|
|
401
489
|
def __init__(self, size, self_attn, feed_forward, after_reduced_cnn_size, dropout):
|
|
402
|
-
super(
|
|
490
|
+
super().__init__()
|
|
491
|
+
self.size = size
|
|
403
492
|
self.self_attn = self_attn
|
|
404
493
|
self.feed_forward = feed_forward
|
|
405
|
-
self.sublayer_output = _clones(_SublayerOutput(size, dropout), 2)
|
|
406
|
-
self.size = size
|
|
407
|
-
self.conv = _CausalConv1d(after_reduced_cnn_size, after_reduced_cnn_size, kernel_size=7,
|
|
408
|
-
stride=1, dilation=1)
|
|
409
494
|
|
|
410
|
-
|
|
495
|
+
self.residual_self_attn = _ResidualLayerNormAttn(
|
|
496
|
+
size=size,
|
|
497
|
+
dropout=dropout,
|
|
498
|
+
fn_attn=self_attn,
|
|
499
|
+
)
|
|
500
|
+
self.residual_ff = _ResidualLayerNormFF(
|
|
501
|
+
size=size,
|
|
502
|
+
dropout=dropout,
|
|
503
|
+
fn_ff=feed_forward,
|
|
504
|
+
)
|
|
505
|
+
|
|
506
|
+
self.conv = CausalConv1d(
|
|
507
|
+
in_channels=after_reduced_cnn_size,
|
|
508
|
+
out_channels=after_reduced_cnn_size,
|
|
509
|
+
kernel_size=7,
|
|
510
|
+
stride=1,
|
|
511
|
+
dilation=1,
|
|
512
|
+
)
|
|
513
|
+
|
|
514
|
+
def forward(self, x_in: torch.Tensor) -> torch.Tensor:
|
|
411
515
|
"""Transformer Encoder"""
|
|
412
516
|
query = self.conv(x_in)
|
|
413
517
|
# Encoder self-attention
|
|
414
|
-
x = self.
|
|
415
|
-
|
|
518
|
+
x = self.residual_self_attn(query, x_in, x_in)
|
|
519
|
+
x_ff = self.residual_ff(x)
|
|
520
|
+
return x_ff
|
|
416
521
|
|
|
417
522
|
|
|
418
523
|
class _PositionwiseFeedForward(nn.Module):
|
|
419
524
|
"""Positionwise feed-forward network."""
|
|
420
525
|
|
|
421
|
-
def __init__(self, d_model, d_ff, dropout=0.1):
|
|
422
|
-
super(
|
|
526
|
+
def __init__(self, d_model, d_ff, dropout=0.1, activation: nn.Module = nn.ReLU):
|
|
527
|
+
super().__init__()
|
|
423
528
|
self.w_1 = nn.Linear(d_model, d_ff)
|
|
424
529
|
self.w_2 = nn.Linear(d_ff, d_model)
|
|
425
530
|
self.dropout = nn.Dropout(dropout)
|
|
531
|
+
self.activate = activation()
|
|
426
532
|
|
|
427
|
-
def forward(self, x):
|
|
533
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
428
534
|
"""Implements FFN equation."""
|
|
429
|
-
return self.w_2(self.dropout(
|
|
535
|
+
return self.w_2(self.dropout(self.activate(self.w_1(x))))
|