braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/attn_sleep.py
@@ -0,0 +1,549 @@
# Authors: Divyesh Narayanan <divyesh.narayanan@gmail.com>
#
# License: BSD (3-clause)

import math
import warnings
from copy import deepcopy

import torch
import torch.nn.functional as F
from mne.utils import deprecated
from torch import nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import CausalConv1d


class AttnSleep(EEGModuleMixin, nn.Module):
    r"""Sleep Staging Architecture from Eldele et al. (2021) [Eldele2021]_.

    :bdg-success:`Convolution` :bdg-info:`Attention/Transformer`

    .. figure:: https://raw.githubusercontent.com/emadeldeen24/AttnSleep/refs/heads/main/imgs/AttnSleep.png
       :align: center
       :alt: AttnSleep Architecture

    Attention-based neural network for sleep staging as described in [Eldele2021]_.
    The code for the paper and this model is also available at [1]_.
    Takes single-channel EEG as input.
    The feature extraction module is based on a multi-resolution convolutional
    neural network (MRCNN) and adaptive feature recalibration (AFR).
    The second module is the temporal context encoder (TCE), which leverages a
    multi-head attention mechanism to capture the temporal dependencies among
    the extracted features.

    .. warning::
        This model was designed for signals of 30 seconds sampled at 100 Hz or
        125 Hz, in which case the reference architecture from [1]_, validated
        on the SHHS dataset [2]_, is used. Using any other input is likely to
        make the model perform in unintended ways.

    Parameters
    ----------
    n_tce : int
        Number of TCE clones.
    d_model : int
        Input dimension for the TCE.
        Also the input dimension of the first FC layer in the feed-forward
        block and the output dimension of the second FC layer in the same block.
        Increase for higher sampling rate/signal length.
        It should be divisible by ``n_attn_heads``.
    d_ff : int
        Output dimension of the first FC layer in the feed-forward block and
        the input dimension of the second FC layer in the same block.
    n_attn_heads : int
        Number of attention heads. It should be a factor of ``d_model``.
    drop_prob : float
        Dropout rate in the position-wise feed-forward layer and the TCE layers.
    after_reduced_cnn_size : int
        Number of output channels produced by the convolution in the AFR module.
    return_feats : bool
        If True, return the features, i.e. the output of the feature extractor
        (before the final linear layer). If False, pass the features through
        the final linear layer.
    n_classes : int
        Alias for `n_outputs`.
    input_size_s : float
        Alias for `input_window_seconds`.
    activation : nn.Module, default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
    activation_mrcnn : nn.Module, default=nn.GELU
        Activation function class to apply in the multi-resolution CNN (MRCNN)
        module. Should be a PyTorch activation module class like ``nn.GELU``
        or ``nn.ReLU``. Default is ``nn.GELU``.

    References
    ----------
    .. [Eldele2021] E. Eldele et al., "An Attention-Based Deep Learning Approach for Sleep Stage
       Classification With Single-Channel EEG," in IEEE Transactions on Neural Systems and
       Rehabilitation Engineering, vol. 29, pp. 809-818, 2021, doi: 10.1109/TNSRE.2021.3076234.

    .. [1] https://github.com/emadeldeen24/AttnSleep

    .. [2] https://sleepdata.org/datasets/shhs
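
    Examples
    --------
    A minimal usage sketch (illustrative only, not taken from the paper),
    assuming the 30-second single-channel configuration at 100 Hz that
    matches the default ``d_model=80``:

    >>> import torch
    >>> model = AttnSleep(
    ...     n_outputs=5, n_chans=1, sfreq=100, input_window_seconds=30
    ... )
    >>> x = torch.randn(2, 1, 3000)  # (batch, n_chans, n_times)
    >>> model(x).shape
    torch.Size([2, 5])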
    """

    def __init__(
        self,
        sfreq=None,
        n_tce=2,
        d_model=80,
        d_ff=120,
        n_attn_heads=5,
        drop_prob=0.1,
        activation_mrcnn: type[nn.Module] = nn.GELU,
        activation: type[nn.Module] = nn.ReLU,
        input_window_seconds=None,
        n_outputs=None,
        after_reduced_cnn_size=30,
        return_feats=False,
        chs_info=None,
        n_chans=None,
        n_times=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.mapping = {
            "fc.weight": "final_layer.weight",
            "fc.bias": "final_layer.bias",
        }

        if not (
            (self.input_window_seconds == 30 and self.sfreq == 100 and d_model == 80)
            or (
                self.input_window_seconds == 30 and self.sfreq == 125 and d_model == 100
            )
        ):
            warnings.warn(
                "This model was originally designed for input windows of 30s "
                "at 100Hz with d_model=80, or at 125Hz with d_model=100. "
                "Any other configuration may cause errors or make the model "
                "perform in unintended ways.",
                UserWarning,
            )

        # the usual kernel size for the MRCNN, for sfreq 100
        kernel_size = 7

        if self.sfreq == 125:
            kernel_size = 6

        mrcnn = _MRCNN(
            after_reduced_cnn_size,
            kernel_size,
            activation=activation_mrcnn,
            activation_se=activation,
        )
        attn = _MultiHeadedAttention(n_attn_heads, d_model, after_reduced_cnn_size)
        ff = _PositionwiseFeedForward(d_model, d_ff, drop_prob, activation=activation)
        tce = _TCE(
            _EncoderLayer(
                d_model, deepcopy(attn), deepcopy(ff), after_reduced_cnn_size, drop_prob
            ),
            n_tce,
        )

        self.feature_extractor = nn.Sequential(mrcnn, tce)
        self.len_last_layer = self._len_last_layer(self.n_times)
        self.return_feats = return_feats

        # TODO: Add new way to handle return features
        """if return_feats:
            raise ValueError("return_feat == True is not accepted anymore")"""
        if not return_feats:
            self.final_layer = nn.Linear(
                d_model * after_reduced_cnn_size, self.n_outputs
            )

    def _len_last_layer(self, input_size):
        self.feature_extractor.eval()
        with torch.no_grad():
            # dummy forward pass (zeros) just to infer the flattened length
            out = self.feature_extractor(torch.zeros(1, 1, input_size))
        self.feature_extractor.train()
        return len(out.flatten())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Logits of shape (batch_size, n_outputs), or the flattened features
            if ``return_feats`` is True.
        """
        encoded_features = self.feature_extractor(x)
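        # Flatten the (batch, after_reduced_cnn_size, d_model) feature map
        # before the optional final classification layer.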
        encoded_features = encoded_features.contiguous().view(
            encoded_features.shape[0], -1
        )

        if self.return_feats:
            return encoded_features

        return self.final_layer(encoded_features)


class _SELayer(nn.Module):
    def __init__(self, channel, reduction=16, activation=nn.ReLU):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            activation(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the SE layer.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, channel, length).

        Returns
        -------
        torch.Tensor
            Output tensor after applying the SE recalibration.
        """
        b, c, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1)
        return x * y.expand_as(x)


class _SEBasicBlock(nn.Module):
    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        activation: type[nn.Module] = nn.ReLU,
        *,
        reduction=16,
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = activation(inplace=True)
        self.conv2 = nn.Conv1d(planes, planes, 1)
        self.bn2 = nn.BatchNorm1d(planes)
        self.se = _SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride
        self.features = nn.Sequential(
            self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.se
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the SE basic block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor after applying the SE recalibration.
        """
        residual = x
        out = self.features(x)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class _MRCNN(nn.Module):
    def __init__(
        self,
        after_reduced_cnn_size,
        kernel_size=7,
        activation: type[nn.Module] = nn.GELU,
        activation_se: type[nn.Module] = nn.ReLU,
    ):
        super().__init__()
        drate = 0.5
        self.GELU = activation()
        self.features1 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=50, stride=6, bias=False, padding=24),
            nn.BatchNorm1d(64),
            self.GELU,
            nn.MaxPool1d(kernel_size=8, stride=2, padding=4),
            nn.Dropout(drate),
            nn.Conv1d(64, 128, kernel_size=8, stride=1, bias=False, padding=4),
            nn.BatchNorm1d(128),
            self.GELU,
            nn.Conv1d(128, 128, kernel_size=8, stride=1, bias=False, padding=4),
            nn.BatchNorm1d(128),
            self.GELU,
            nn.MaxPool1d(kernel_size=4, stride=4, padding=2),
        )

        self.features2 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=400, stride=50, bias=False, padding=200),
            nn.BatchNorm1d(64),
            self.GELU,
            nn.MaxPool1d(kernel_size=4, stride=2, padding=2),
            nn.Dropout(drate),
            nn.Conv1d(
                64, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
            ),
            nn.BatchNorm1d(128),
            self.GELU,
            nn.Conv1d(
                128, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
            ),
            nn.BatchNorm1d(128),
            self.GELU,
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        )

        self.dropout = nn.Dropout(drate)
        self.inplanes = 128
        self.AFR = self._make_layer(
            _SEBasicBlock, after_reduced_cnn_size, 1, activate=activation_se
        )

    def _make_layer(
        self, block, planes, blocks, stride=1, activate: type[nn.Module] = nn.ReLU
    ):  # makes a residual SE block
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv1d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm1d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, activate=activate))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
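        # Two parallel branches over the raw signal: features1 uses small
        # kernels (fine temporal detail), features2 uses large kernels
        # (coarser, low-frequency content). Their outputs are concatenated
        # along the time axis before dropout and the AFR blocks.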
        x1 = self.features1(x)
        x2 = self.features2(x)
        x_concat = torch.cat((x1, x2), dim=2)
        x_concat = self.dropout(x_concat)
        x_concat = self.AFR(x_concat)
        return x_concat


##########################################################################################


def _attention(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention."""
    # d_k - dimension of the query and key vectors
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    p_attn = F.softmax(scores, dim=-1)  # attention weights
    output = torch.matmul(p_attn, value)  # (B, h, T, d_k)
    return output, p_attn


class _MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, after_reduced_cnn_size, dropout=0.1):
        """Take in model size and number of heads."""
        super().__init__()
        assert d_model % h == 0
        self.d_per_head = d_model // h
        self.h = h

        base_conv = CausalConv1d(
            in_channels=after_reduced_cnn_size,
            out_channels=after_reduced_cnn_size,
            kernel_size=7,
            stride=1,
        )
        self.convs = nn.ModuleList([deepcopy(base_conv) for _ in range(3)])

        self.linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        """Apply multi-head attention."""
        nbatches = query.size(0)

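        # The query was already passed through a causal conv in _EncoderLayer,
        # so only key and value go through self.convs here (convs[0] is
        # effectively unused). Each tensor is reshaped to
        # (batch, heads, time, d_per_head).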
        query = query.view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
        key = (
            self.convs[1](key)
            .view(nbatches, -1, self.h, self.d_per_head)
            .transpose(1, 2)
        )
        value = (
            self.convs[2](value)
            .view(nbatches, -1, self.h, self.d_per_head)
            .transpose(1, 2)
        )

        _, attn_weights = _attention(query, key, value)
        # apply dropout to the attention *weights*
        attn = self.dropout(attn_weights)
        # recompute the weighted sum with the dropped weights
        x = torch.matmul(attn, value)

        # keep the pre-dropout weights in case they are needed for inspection
        self.attn = attn_weights

        # merge heads and project
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_per_head)

        return self.linear(x)


class _ResidualLayerNormAttn(nn.Module):
    r"""A residual connection followed by a layer norm."""

    def __init__(self, size, dropout, fn_attn):
        super().__init__()
        self.norm = nn.LayerNorm(size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        self.fn_attn = fn_attn

    def forward(
        self,
        x: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Apply a residual connection to any sublayer with the same size."""
        x_norm = self.norm(x)

        out = self.fn_attn(x_norm, key, value)

        return x + self.dropout(out)


class _ResidualLayerNormFF(nn.Module):
    r"""A residual connection followed by a layer norm."""

    def __init__(self, size, dropout, fn_ff):
        super().__init__()
        self.norm = nn.LayerNorm(size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        self.fn_ff = fn_ff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply a residual connection to any sublayer with the same size."""
        x_norm = self.norm(x)

        out = self.fn_ff(x_norm)

        return x + self.dropout(out)


class _TCE(nn.Module):
    r"""
    Transformer encoder.

    A stack of ``n`` identical layers.
    """

    def __init__(self, layer, n):
        super().__init__()

        self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])

        self.norm = nn.LayerNorm(layer.size, eps=1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
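        # Pre-norm residual layers, followed by a final LayerNorm.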
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


class _EncoderLayer(nn.Module):
    r"""
    An encoder layer.

    Made up of self-attention and a feed-forward layer.
    Each of these sublayers has a residual connection and layer norm,
    implemented by ``_ResidualLayerNormAttn`` and ``_ResidualLayerNormFF``.
    """

    def __init__(self, size, self_attn, feed_forward, after_reduced_cnn_size, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward

        self.residual_self_attn = _ResidualLayerNormAttn(
            size=size,
            dropout=dropout,
            fn_attn=self_attn,
        )
        self.residual_ff = _ResidualLayerNormFF(
            size=size,
            dropout=dropout,
            fn_ff=feed_forward,
        )

        self.conv = CausalConv1d(
            in_channels=after_reduced_cnn_size,
            out_channels=after_reduced_cnn_size,
            kernel_size=7,
            stride=1,
            dilation=1,
        )

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        """Apply the encoder layer."""
        query = self.conv(x_in)
        # Encoder self-attention
        x = self.residual_self_attn(query, x_in, x_in)
        x_ff = self.residual_ff(x)
        return x_ff


class _PositionwiseFeedForward(nn.Module):
    r"""Positionwise feed-forward network."""

    def __init__(
        self, d_model, d_ff, dropout=0.1, activation: type[nn.Module] = nn.ReLU
    ):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activate = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Implements FFN equation."""
        return self.w_2(self.dropout(self.activate(self.w_1(x))))


@deprecated(
    "`SleepStagerEldele2021` was renamed to `AttnSleep` in v1.12 to follow the "
    "original authors' naming; this alias will be removed in v1.14."
)
class SleepStagerEldele2021(AttnSleep):
    r"""Deprecated alias for :class:`AttnSleep`."""

    pass