braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
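The listing also reflects a reorganization of the package: the reusable layers that lived in braindecode/models/modules.py (removed below) now sit in the new braindecode/modules package, and shared helper functions in the new braindecode/functional package, which is why the model diffs that follow rewrite their imports. A minimal import-migration sketch, assuming braindecode 1.0.0 is installed; the 0.8.1 path is recalled from that release and shown only as a comment:

# braindecode 0.8.1 (module removed in 1.0.0):
# from braindecode.models.modules import Ensure4d

# braindecode 1.0.0 (new top-level package, used by the deep4.py diff below):
from braindecode.modules import Ensure4d, SqueezeFinalOutput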
braindecode/models/ctnet.py
ADDED
@@ -0,0 +1,450 @@
"""
CTNet: a convolutional transformer network for EEG-based motor imagery
classification from Wei Zhao et al. (2024).
"""

# Authors: Wei Zhao <zhaowei701@163.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
# License: MIT

from __future__ import annotations

import math

import torch
from einops.layers.torch import Rearrange
from mne.utils import warn
from torch import Tensor, nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import (
    FeedForwardBlock,
    MultiHeadAttention,
)


class CTNet(EEGModuleMixin, nn.Module):
    """CTNet from Zhao, W et al (2024) [ctnet]_.

    A Convolutional Transformer Network for EEG-Based Motor Imagery Classification

    .. figure:: https://raw.githubusercontent.com/snailpt/CTNet/main/architecture.png
       :align: center
       :alt: CTNet Architecture

    CTNet is an end-to-end neural network architecture designed for classifying motor imagery (MI) tasks from EEG signals.
    The model combines convolutional neural networks (CNNs) with a Transformer encoder to capture both local and global temporal dependencies in the EEG data.

    The architecture consists of three main components:

    1. **Convolutional Module**:
        - Apply EEGNetV4 to perform some feature extraction, denoted here as
          _PatchEmbeddingEEGNet module.

    2. **Transformer Encoder Module**:
        - Utilizes multi-head self-attention mechanisms as EEGConformer but
          with residual blocks.

    3. **Classifier Module**:
        - Combines features from both the convolutional module
          and the Transformer encoder.
        - Flattens the combined features and applies dropout for regularization.
        - Uses a fully connected layer to produce the final classification output.

    Parameters
    ----------
    activation : nn.Module, default=nn.GELU
        Activation function to use in the network.
    heads : int, default=4
        Number of attention heads in the Transformer encoder.
    emb_size : int, default=40
        Embedding size (dimensionality) for the Transformer encoder.
    depth : int, default=6
        Number of encoder layers in the Transformer.
    n_filters_time : int, default=20
        Number of temporal filters in the first convolutional layer.
    kernel_size : int, default=64
        Kernel size for the temporal convolutional layer.
    depth_multiplier : int, default=2
        Multiplier for the number of depth-wise convolutional filters.
    pool_size_1 : int, default=8
        Pooling size for the first average pooling layer.
    pool_size_2 : int, default=8
        Pooling size for the second average pooling layer.
    drop_prob_cnn : float, default=0.3
        Dropout probability after convolutional layers.
    drop_prob_posi : float, default=0.1
        Dropout probability for the positional encoding in the Transformer.
    drop_prob_final : float, default=0.5
        Dropout probability before the final classification layer.

    Notes
    -----
    This implementation is adapted from the original CTNet source code
    [ctnetcode]_ to comply with Braindecode's model standards.

    References
    ----------
    .. [ctnet] Zhao, W., Jiang, X., Zhang, B., Xiao, S., & Weng, S. (2024).
       CTNet: a convolutional transformer network for EEG-based motor imagery
       classification. Scientific Reports, 14(1), 20237.
    .. [ctnetcode] Zhao, W., Jiang, X., Zhang, B., Xiao, S., & Weng, S. (2024).
       CTNet source code:
       https://github.com/snailpt/CTNet
    """

    def __init__(
        self,
        # Base arguments
        n_outputs=None,
        n_chans=None,
        sfreq=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        # Model specific arguments
        activation_patch: nn.Module = nn.ELU,
        activation_transformer: nn.Module = nn.GELU,
        drop_prob_cnn: float = 0.3,
        drop_prob_posi: float = 0.1,
        drop_prob_final: float = 0.5,
        # other parameters
        heads: int = 4,
        emb_size: int = 40,
        depth: int = 6,
        n_filters_time: int = 20,
        kernel_size: int = 64,
        depth_multiplier: int = 2,
        pool_size_1: int = 8,
        pool_size_2: int = 8,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.emb_size = emb_size
        self.activation_patch = activation_patch
        self.activation_transformer = activation_transformer

        self.n_filters_time = n_filters_time
        self.drop_prob_cnn = drop_prob_cnn
        self.pool_size_1 = pool_size_1
        self.pool_size_2 = pool_size_2
        self.depth_multiplier = depth_multiplier
        self.kernel_size = kernel_size
        self.drop_prob_posi = drop_prob_posi
        self.drop_prob_final = drop_prob_final

        # n_times - pool_size_1 / p
        sequence_length = math.floor(
            (
                math.floor((self.n_times - self.pool_size_1) / self.pool_size_1 + 1)
                - self.pool_size_2
            )
            / self.pool_size_2
            + 1
        )

        # Layers
        self.ensuredim = Rearrange("batch nchans time -> batch 1 nchans time")
        self.flatten = nn.Flatten()

        self.cnn = _PatchEmbeddingEEGNet(
            n_filters_time=self.n_filters_time,
            kernel_size=self.kernel_size,
            depth_multiplier=self.depth_multiplier,
            pool_size_1=self.pool_size_1,
            pool_size_2=self.pool_size_2,
            drop_prob=self.drop_prob_cnn,
            n_chans=self.n_chans,
            activation=self.activation_patch,
        )

        self.position = _PositionalEncoding(
            emb_size=emb_size,
            drop_prob=self.drop_prob_posi,
            n_times=self.n_times,
            pool_size=self.pool_size_1,
        )

        self.trans = _TransformerEncoder(
            heads, depth, emb_size, activation=self.activation_transformer
        )

        self.flatten_drop_layer = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=self.drop_prob_final),
        )

        self.final_layer = nn.Linear(
            in_features=emb_size * sequence_length, out_features=self.n_outputs
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the CTNet model.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, n_channels, n_times).

        Returns
        -------
        Tensor
            Output with shape (batch_size, n_outputs).
        """
        x = self.ensuredim(x)
        cnn = self.cnn(x)
        cnn = cnn * math.sqrt(self.emb_size)
        cnn = self.position(cnn)
        trans = self.trans(cnn)
        features = cnn + trans
        flatten_feature = self.flatten(features)
        out = self.final_layer(flatten_feature)
        return out


class _PatchEmbeddingEEGNet(nn.Module):
    def __init__(
        self,
        n_filters_time: int = 16,
        kernel_size: int = 64,
        depth_multiplier: int = 2,
        pool_size_1: int = 8,
        pool_size_2: int = 8,
        drop_prob: float = 0.3,
        n_chans: int = 22,
        activation: nn.Module = nn.ELU,
    ):
        super().__init__()
        n_filters_out = depth_multiplier * n_filters_time
        self.eegnet_module = nn.Sequential(
            # Temporal convolution
            nn.Conv2d(
                in_channels=1,
                out_channels=n_filters_time,
                kernel_size=(1, kernel_size),
                stride=(1, 1),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(n_filters_time),
            # Channel depth-wise convolution
            nn.Conv2d(
                in_channels=n_filters_time,
                out_channels=n_filters_out,
                kernel_size=(n_chans, 1),
                stride=(1, 1),
                groups=n_filters_time,
                padding="valid",
                bias=False,
            ),
            nn.BatchNorm2d(n_filters_out),
            activation(),
            # First average pooling
            nn.AvgPool2d(kernel_size=(1, pool_size_1)),
            nn.Dropout(drop_prob),
            # Spatial convolution
            nn.Conv2d(
                in_channels=n_filters_out,
                out_channels=n_filters_out,
                kernel_size=(1, 16),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(n_filters_out),
            activation(),
            # Second average pooling
            nn.AvgPool2d(kernel_size=(1, pool_size_2)),
            nn.Dropout(drop_prob),
        )

        self.projection = nn.Sequential(
            Rearrange("b e h w -> b (h w) e"),
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the Patch Embedding CNN.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, 1, n_channels, n_times).

        Returns
        -------
        Tensor
            Embedded patches of shape (batch_size, num_patches, embedding_dim).
        """
        x = self.eegnet_module(x)
        x = self.projection(x)
        return x


class _ResidualAdd(nn.Module):
    def __init__(self, module: nn.Module, emb_size: int, drop_p: float):
        super().__init__()
        self.module = module
        self.drop = nn.Dropout(drop_p)
        self.layernorm = nn.LayerNorm(emb_size)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass with residual connection.

        Parameters
        ----------
        x : Tensor
            Input tensor.
        **kwargs : Any
            Additional arguments.

        Returns
        -------
        Tensor
            Output tensor after applying residual connection.
        """
        res = self.module(x)
        out = self.layernorm(self.drop(res) + x)
        return out


class _TransformerEncoderBlock(nn.Module):
    def __init__(
        self,
        dim_feedforward: int,
        num_heads: int = 4,
        drop_prob: float = 0.5,
        forward_expansion: int = 4,
        forward_drop_p: float = 0.5,
        activation: nn.Module = nn.GELU,
    ):
        super().__init__()
        self.attention = _ResidualAdd(
            nn.Sequential(
                MultiHeadAttention(dim_feedforward, num_heads, drop_prob),
            ),
            dim_feedforward,
            drop_prob,
        )
        self.feed_forward = _ResidualAdd(
            nn.Sequential(
                FeedForwardBlock(
                    dim_feedforward,
                    expansion=forward_expansion,
                    drop_p=forward_drop_p,
                    activation=activation,
                ),
            ),
            dim_feedforward,
            drop_prob,
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the transformer encoder block.

        Parameters
        ----------
        x : Tensor
            Input tensor.
        **kwargs : Any
            Additional arguments.

        Returns
        -------
        Tensor
            Output tensor after transformer encoder block.
        """
        x = self.attention(x)
        x = self.feed_forward(x)
        return x


class _TransformerEncoder(nn.Module):
    def __init__(
        self,
        nheads: int,
        depth: int,
        dim_feedforward: int,
        activation: nn.Module = nn.GELU,
    ):
        super().__init__()
        self.layers = nn.Sequential(
            *[
                _TransformerEncoderBlock(
                    dim_feedforward=dim_feedforward,
                    num_heads=nheads,
                    activation=activation,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the transformer encoder.

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Output tensor after transformer encoder.
        """
        return self.layers(x)


class _PositionalEncoding(nn.Module):
    def __init__(
        self,
        n_times: int,
        emb_size: int,
        length: int = 100,
        drop_prob: float = 0.1,
        pool_size: int = 8,
    ):
        super().__init__()
        self.pool_size = pool_size
        self.n_times = n_times

        if int(n_times / (pool_size * pool_size)) > length:
            warn(
                "the temporal dimensional is too long for this default length. "
                "The length parameter will be automatically adjusted to "
                "avoid inference issues."
            )
            length = int(n_times / (pool_size * pool_size))

        self.dropout = nn.Dropout(drop_prob)
        self.encoding = nn.Parameter(torch.randn(1, length, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the positional encoding.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns
        -------
        Tensor
            Tensor with positional encoding added.
        """
        seq_length = x.size(1)
        encoding = self.encoding[:, :seq_length, :]
        x = x + encoding
        return self.dropout(x)
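A minimal usage sketch for the newly added model, assuming braindecode 1.0.0 with its torch and einops dependencies is installed; the import path and tensor shapes follow the file above:

import torch

from braindecode.models.ctnet import CTNet

# 22-channel motor-imagery windows of 1000 samples at 250 Hz, 4 classes.
model = CTNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)

x = torch.randn(8, 22, 1000)  # (batch_size, n_channels, n_times)
y = model(x)                  # (batch_size, n_outputs)
print(y.shape)                # torch.Size([8, 4])

With these defaults, the EEGNet-style patch embedding turns each 1000-sample window into 15 tokens of dimension 40 (depth_multiplier * n_filters_time), which is exactly what the final nn.Linear(emb_size * sequence_length, n_outputs) layer expects.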
braindecode/models/deep4.py
CHANGED
@@ -5,15 +5,22 @@
 from einops.layers.torch import Rearrange
 from torch import nn
 from torch.nn import init
-from torch.nn.functional import elu
 
-from .base import EEGModuleMixin
-from .
-
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import (
+    AvgPool2dWithConv,
+    CombinedConv,
+    Ensure4d,
+    SqueezeFinalOutput,
+)
 
 
 class Deep4Net(EEGModuleMixin, nn.Sequential):
-    """Deep ConvNet model from Schirrmeister et al 2017.
+    """Deep ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
+
+    .. figure:: https://onlinelibrary.wiley.com/cms/asset/fc200ccc-d8c4-45b4-8577-56ce4d15999a/hbm23730-fig-0001-m.jpg
+       :align: center
+       :alt: CTNet Architecture
 
     Model described in [Schirrmeister2017]_.
 
@@ -44,13 +51,13 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         Number of temporal filters in layer 4.
     filter_length_4: int
         Length of the temporal filter in layer 4.
-
+    activation_first_conv_nonlin: nn.Module, default is nn.ELU
         Non-linear activation function to be used after convolution in layer 1.
     first_pool_mode: str
         Pooling mode in layer 1. "max" or "mean".
     first_pool_nonlin: callable
         Non-linear activation function to be used after pooling in layer 1.
-
+    activation_later_conv_nonlin: nn.Module, default is nn.ELU
         Non-linear activation function to be used after convolution in later layers.
     later_pool_mode: str
         Pooling mode in later layers. "max" or "mean".
@@ -67,12 +74,6 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         Momentum for BatchNorm2d.
     stride_before_pool: bool
         Stride before pooling.
-    in_chans :
-        Alias for n_chans.
-    n_classes:
-        Alias for n_outputs.
-    input_window_samples :
-        Alias for n_times.
 
 
     References
@@ -87,47 +88,37 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
     """
 
     def __init__(
-    … (30 removed lines of the previous signature, not rendered in this diff view)
-        in_chans=None,
-        n_classes=None,
-        input_window_samples=None,
-        add_log_softmax=True,
+        self,
+        n_chans=None,
+        n_outputs=None,
+        n_times=None,
+        final_conv_length="auto",
+        n_filters_time=25,
+        n_filters_spat=25,
+        filter_time_length=10,
+        pool_time_length=3,
+        pool_time_stride=3,
+        n_filters_2=50,
+        filter_length_2=10,
+        n_filters_3=100,
+        filter_length_3=10,
+        n_filters_4=200,
+        filter_length_4=10,
+        activation_first_conv_nonlin: nn.Module = nn.ELU,
+        first_pool_mode="max",
+        first_pool_nonlin: nn.Module = nn.Identity,
+        activation_later_conv_nonlin: nn.Module = nn.ELU,
+        later_pool_mode="max",
+        later_pool_nonlin: nn.Module = nn.Identity,
+        drop_prob=0.5,
+        split_first_layer=True,
+        batch_norm=True,
+        batch_norm_alpha=0.1,
+        stride_before_pool=False,
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
     ):
-        n_chans, n_outputs, n_times = deprecated_args(
-            self,
-            ('in_chans', 'n_chans', in_chans, n_chans),
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('input_window_samples', 'n_times', input_window_samples, n_times),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -135,10 +126,9 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-
+
         if final_conv_length == "auto":
             assert self.n_times is not None
         self.final_conv_length = final_conv_length
@@ -153,10 +143,10 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         self.filter_length_3 = filter_length_3
         self.n_filters_4 = n_filters_4
         self.filter_length_4 = filter_length_4
-        self.first_nonlin =
+        self.first_nonlin = activation_first_conv_nonlin
         self.first_pool_mode = first_pool_mode
         self.first_pool_nonlin = first_pool_nonlin
-        self.later_conv_nonlin =
+        self.later_conv_nonlin = activation_later_conv_nonlin
         self.later_pool_mode = later_pool_mode
         self.later_pool_nonlin = later_pool_nonlin
         self.drop_prob = drop_prob
@@ -174,7 +164,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
             "conv_time.bias": "conv_time_spat.conv_time.bias",
             "conv_spat.bias": "conv_time_spat.conv_spat.bias",
             "conv_classifier.weight": "final_layer.conv_classifier.weight",
-            "conv_classifier.bias": "final_layer.conv_classifier.bias"
+            "conv_classifier.bias": "final_layer.conv_classifier.bias",
         }
 
         if self.stride_before_pool:
@@ -223,17 +213,17 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                 eps=1e-5,
             ),
         )
-        self.add_module("conv_nonlin",
+        self.add_module("conv_nonlin", self.first_nonlin())
         self.add_module(
            "pool",
            first_pool_class(
                kernel_size=(self.pool_time_length, 1), stride=(pool_stride, 1)
            ),
        )
-        self.add_module("pool_nonlin",
+        self.add_module("pool_nonlin", self.first_pool_nonlin())
 
         def add_conv_pool_block(
-
+            model, n_filters_before, n_filters, filter_length, block_nr
         ):
             suffix = "_{:d}".format(block_nr)
             self.add_module("drop" + suffix, nn.Dropout(p=self.drop_prob))
@@ -257,7 +247,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                     eps=1e-5,
                 ),
             )
-            self.add_module("nonlin" + suffix,
+            self.add_module("nonlin" + suffix, self.later_conv_nonlin())
 
             self.add_module(
                 "pool" + suffix,
@@ -266,7 +256,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                     stride=(pool_stride, 1),
                 ),
             )
-            self.add_module("pool_nonlin" + suffix,
+            self.add_module("pool_nonlin" + suffix, self.later_pool_nonlin())
 
         add_conv_pool_block(
             self, n_filters_conv, self.n_filters_2, self.filter_length_2, 2
@@ -286,17 +276,17 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         # Incorporating classification module and subsequent ones in one final layer
         module = nn.Sequential()
 
-        module.add_module(
-    … (8 removed lines, not rendered in this diff view)
+        module.add_module(
+            "conv_classifier",
+            nn.Conv2d(
+                self.n_filters_4,
+                self.n_outputs,
+                (self.final_conv_length, 1),
+                bias=True,
+            ),
+        )
 
-        module.add_module("squeeze",
+        module.add_module("squeeze", SqueezeFinalOutput())
 
         self.add_module("final_layer", module)
 
@@ -329,5 +319,4 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
         init.constant_(self.final_layer.conv_classifier.bias, 0)
 
-
-        self.eval()
+        self.train()
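For downstream code, the practical effect of this deep4.py diff is that the deprecated alias keywords (in_chans, n_classes, input_window_samples) and the add_log_softmax flag are removed. A minimal before/after sketch, assuming braindecode 1.0.0 is installed; the removed keywords appear only as comments:

import torch

from braindecode.models import Deep4Net

# braindecode 0.8.1 (keywords removed in 1.0.0):
# model = Deep4Net(in_chans=22, n_classes=4, input_window_samples=1000,
#                  final_conv_length="auto", add_log_softmax=True)

# braindecode 1.0.0: only the n_* names remain, and the model returns raw
# logits since there is no log-softmax option any more.
model = Deep4Net(n_chans=22, n_outputs=4, n_times=1000, final_conv_length="auto")

x = torch.randn(8, 22, 1000)  # (batch_size, n_channels, n_times)
logits = model(x)             # (batch_size, n_outputs)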