braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
"""
|
|
2
|
+
EEG-SimpleConv is a 1D Convolutional Neural Network from Yassine El Ouahidi et al. (2023).
|
|
3
|
+
|
|
4
|
+
It was originally designed for motor imagery decoding from EEG signals.
|
|
5
|
+
The model offers competitive performance with low latency, and is mainly composed of
|
|
6
|
+
1D convolutional layers.
|
|
7
|
+
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
# Authors: Yassine El Ouahidi <eloua.yas@gmail.com>
|
|
11
|
+
#
|
|
12
|
+
# License: BSD-3
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
from torch import nn
|
|
16
|
+
from torchaudio.transforms import Resample
|
|
17
|
+
|
|
18
|
+
from braindecode.models.base import EEGModuleMixin
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class EEGSimpleConv(EEGModuleMixin, torch.nn.Module):
    """EEGSimpleConv from Ouahidi, YE et al. (2023) [Yassine2023]_.

    .. figure:: https://raw.githubusercontent.com/elouayas/EEGSimpleConv/refs/heads/main/architecture.png
        :align: center
        :alt: EEGSimpleConv Architecture

    EEGSimpleConv is a 1D Convolutional Neural Network originally designed
    for decoding motor imagery from EEG signals. The model aims to have a
    very simple and straightforward architecture that allows a low latency,
    while still achieving very competitive performance.

    The input is first resampled to ``resampling_freq`` (no resampling is
    done when it already equals ``sfreq``). EEG-SimpleConv then applies a 1D
    convolution whose input channels are the EEG channels, followed by batch
    normalisation and a ReLU. This stem is followed by ``n_convs`` blocks of
    two 1D convolutional layers. Each block is
    ``conv -> batch norm -> max pool (factor 2) -> activation -> conv ->
    batch norm -> activation``, so every block halves the temporal
    resolution. From the second block on, the number of feature maps is
    multiplied by 1.414 (~sqrt(2), rounded down). Finally, a global average
    pooling (in the time domain) is performed to obtain a single value per
    feature map, which is then fed into a linear layer to obtain the final
    classification prediction output.

    The paper and original code with more details about the methodological
    choices are available at [Yassine2023]_ and [Yassine2023Code]_.

    The input should be a three-dimensional tensor representing the EEG
    signals, of shape ``(batch_size, n_channels, n_timesteps)``.

    .. versionadded:: 0.9

    Parameters
    ----------
    feature_maps : int, default=128
        Number of feature maps at the first convolution (width of the model).
    n_convs : int, default=2
        Number of blocks of convolutions (2 convolutions per block), i.e. the
        depth of the model.
    resampling_freq : int, default=80
        Frequency the input is resampled to before the first convolution,
        expressed in the same unit as ``sfreq``.
    kernel_size : int, default=8
        Size of the convolution kernels.
    return_feature : bool, default=False
        If True, ``forward`` returns the time-averaged features of the last
        block instead of the output of the final linear layer.
    activation : nn.Module, default=nn.ReLU
        Activation function class used inside the convolution blocks. Should
        be a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
        The stem (first convolution) always uses ReLU.

    Notes
    -----
    The authors recommend using the default parameters for MI decoding.
    Please refer to the original paper and code for more details.

    Recommended range for the choice of the hyperparameters, regarding the
    evaluation paradigm:

    ===============  ==============  =============
    Parameter        Within-Subject  Cross-Subject
    ===============  ==============  =============
    feature_maps     [64-144]        [64-144]
    n_convs          1               [2-4]
    resampling_freq  [70-100]        [50-80]
    kernel_size      [12-17]         [5-8]
    ===============  ==============  =============

    An intensive ablation study is included in the paper to understand the
    impact of each parameter on the model performance.

    References
    ----------
    .. [Yassine2023] Yassine El Ouahidi, V. Gripon, B. Pasdeloup, G. Bouallegue
        N. Farrugia, G. Lioi, 2023. A Strong and Simple Deep Learning Baseline for
        BCI Motor Imagery Decoding. Arxiv preprint. arxiv.org/abs/2309.07159
    .. [Yassine2023Code] Yassine El Ouahidi, V. Gripon, B. Pasdeloup, G. Bouallegue
        N. Farrugia, G. Lioi, 2023. A Strong and Simple Deep Learning Baseline for
        BCI Motor Imagery Decoding. GitHub repository.
        https://github.com/elouayas/EEGSimpleConv.
    """

    def __init__(
        self,
        # Base arguments
        n_outputs=None,
        n_chans=None,
        sfreq=None,
        # Model specific arguments
        feature_maps=128,
        n_convs=2,
        resampling_freq=80,
        kernel_size=8,
        return_feature=False,
        activation: nn.Module = nn.ReLU,
        # Other ways to initialize the model
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        self.return_feature = return_feature
        # Skip resampling entirely when the data is already at the target rate.
        self.resample = (
            Resample(orig_freq=int(self.sfreq), new_freq=int(resampling_freq))
            if self.sfreq != resampling_freq
            else torch.nn.Identity()
        )

        # Stem: EEG channels are the input channels of the first convolution.
        self.conv = torch.nn.Conv1d(
            self.n_chans,
            feature_maps,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )
        self.bn = torch.nn.BatchNorm1d(feature_maps)

        blocks = []
        new_feature_maps = feature_maps
        old_feature_maps = feature_maps
        for i in range(n_convs):
            if i > 0:
                # 1.414 = sqrt(2): widening by sqrt(2) while the max pooling
                # halves the time axis keeps the FLOPs per block roughly constant.
                new_feature_maps = int(1.414 * new_feature_maps)
            blocks.append(
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        old_feature_maps,
                        new_feature_maps,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    torch.nn.BatchNorm1d(new_feature_maps),
                    # Every block downsamples the time axis by a factor of 2.
                    torch.nn.MaxPool1d(2),
                    activation(),
                    torch.nn.Conv1d(
                        new_feature_maps,
                        new_feature_maps,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    torch.nn.BatchNorm1d(new_feature_maps),
                    activation(),
                )
            )
            old_feature_maps = new_feature_maps
        self.blocks = torch.nn.ModuleList(blocks)
        self.final_layer = torch.nn.Linear(old_feature_maps, self.n_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        x: PyTorch Tensor
            Input tensor of shape (batch_size, n_channels, n_times)

        Returns
        -------
        PyTorch Tensor
            Output tensor of shape (batch_size, n_outputs), or, when
            ``return_feature`` is True, the time-averaged features of shape
            (batch_size, n_features) where n_features is the width of the
            last block (``feature_maps`` if ``n_convs == 0``).
        """
        x_rs = self.resample(x.contiguous())
        # The stem always uses ReLU, independently of ``activation``.
        feat = torch.relu(self.bn(self.conv(x_rs)))
        for seq in self.blocks:
            feat = seq(feat)
        # Global average pooling over the time axis.
        feat = feat.mean(dim=2)
        if self.return_feature:
            return feat
        return self.final_layer(feat)
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
from einops.layers.torch import Rearrange
|
|
10
|
+
|
|
11
|
+
from braindecode.models.base import EEGModuleMixin
|
|
12
|
+
from braindecode.modules import Chomp1d, MaxNormLinear
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class EEGTCNet(EEGModuleMixin, nn.Module):
    """EEGTCNet model from Ingolfsson et al. (2020) [ingolfsson2020]_.

    .. figure:: https://braindecode.org/dev/_static/model/eegtcnet.jpg
        :align: center
        :alt: EEGTCNet Architecture

    Combining EEGNet and TCN blocks: an EEGNet-style convolutional front end
    produces ``filter_1 * depth_multiplier`` feature maps over a downsampled
    time axis. A residual temporal convolutional network (TCN) with
    ``filters`` channels processes that sequence. The classification layer
    is applied to the TCN output at the last time step.

    Parameters
    ----------
    activation : nn.Module, optional
        Activation function class (not an instance); it is instantiated
        inside the EEGNet and TCN blocks. Default is ``nn.ELU``.
    depth_multiplier : int, optional
        Depth multiplier for the depthwise convolution. The EEGNet block
        outputs ``filter_1 * depth_multiplier`` feature maps. Default is 2.
    filter_1 : int, optional
        Number of temporal filters in the first convolutional layer. Default is 8.
    kern_length : int, optional
        Length of the temporal kernel in the first convolutional layer. Default is 64.
    drop_prob : float, optional
        Dropout probability, used in both the EEGNet and the TCN blocks.
        Default is 0.5.
    depth : int, optional
        Number of residual blocks in the TCN. Default is 2.
    kernel_size : int, optional
        Size of the temporal convolutional kernel in the TCN. Default is 4.
    filters : int, optional
        Number of filters in the TCN convolutional layers. Default is 12.
    max_norm_const : float, optional
        Maximum L2-norm constraint imposed on weights of the last
        fully-connected layer. Default is 0.25.

    References
    ----------
    .. [ingolfsson2020] Ingolfsson, T. M., Hersche, M., Wang, X., Kobayashi, N.,
        Cavigelli, L., & Benini, L. (2020). EEG-TCNet: An accurate temporal
        convolutional network for embedded motor-imagery brain–machine interfaces.
        https://doi.org/10.48550/arXiv.2006.00622
    """

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        activation: nn.Module = nn.ELU,
        depth_multiplier: int = 2,
        filter_1: int = 8,
        kern_length: int = 64,
        drop_prob: float = 0.5,
        depth: int = 2,
        kernel_size: int = 4,
        filters: int = 12,
        max_norm_const: float = 0.25,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.activation = activation
        self.drop_prob = drop_prob
        self.depth_multiplier = depth_multiplier
        self.filter_1 = filter_1
        self.kern_length = kern_length
        self.depth = depth
        self.kernel_size = kernel_size
        self.filters = filters
        self.max_norm_const = max_norm_const
        # Number of feature maps produced by the EEGNet front end.
        self.filter_2 = self.filter_1 * self.depth_multiplier

        # The EEGNet block expects time on the 3rd axis and EEG channels last.
        self.arrange_dim_input = Rearrange(
            "batch nchans ntimes -> batch 1 ntimes nchans"
        )
        # EEGNet_TC Block
        self.eegnet_tc = _EEGNetTC(
            n_chans=self.n_chans,
            filter_1=self.filter_1,
            kern_length=self.kern_length,
            depth_multiplier=self.depth_multiplier,
            drop_prob=self.drop_prob,
            activation=self.activation,
        )
        # Drop the singleton channel axis and put features last for the TCN.
        self.arrange_dim_eegnet = Rearrange(
            "batch filter2 rtimes 1 -> batch rtimes filter2"
        )

        # TCN Block
        self.tcn_block = _TCNBlock(
            input_dimension=self.filter_2,
            depth=self.depth,
            kernel_size=self.kernel_size,
            filters=self.filters,
            drop_prob=self.drop_prob,
            activation=self.activation,
        )

        # Classification Block
        self.final_layer = MaxNormLinear(
            in_features=self.filters,
            out_features=self.n_outputs,
            max_norm_val=self.max_norm_const,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the EEGTCNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # x shape: (batch_size, n_chans, n_times)
        x = self.arrange_dim_input(x)  # (batch_size, 1, n_times, n_chans)
        x = self.eegnet_tc(x)  # (batch_size, F2, reduced_time, 1)

        x = self.arrange_dim_eegnet(x)  # (batch_size, reduced_time, F2)
        x = self.tcn_block(x)  # (batch_size, time_steps, filters)

        # Classify from the TCN output at the last time step only.
        x = x[:, -1, :]  # (batch_size, filters)

        x = self.final_layer(x)  # (batch_size, n_outputs)

        return x
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class _EEGNetTC(nn.Module):
    """EEGNet-style convolutional front end used by EEGTCNet.

    Compared with braindecode's EEGNetV4 implementation, this variant uses a
    different kernel layout and dimension order: time is on the third axis
    and EEG channels are on the last axis. It is therefore kept as a separate
    implementation for now. Whether it is still needed will be re-evaluated
    in a future issue.

    Processing stages, for an input of shape ``(batch_size, 1, n_times, n_chans)``:

    1. temporal convolution with kernel ``(kern_length, 1)``, batch norm,
       activation;
    2. depthwise convolution with kernel ``(1, n_chans)`` (``depth_multiplier``
       outputs per temporal filter), batch norm, activation, average pooling
       of 8 samples in time, dropout;
    3. separable convolution (depthwise ``(filter_2, 1)`` followed by a
       pointwise ``1x1``), batch norm, activation, average pooling of
       ``filter_1`` samples in time, dropout.

    Returns a tensor of shape ``(batch_size, filter_2, reduced_time, 1)``,
    with ``filter_2 = filter_1 * depth_multiplier``.

    Notes
    -----
    NOTE(review): the separable kernel length is tied to ``filter_2`` and the
    second pooling length to ``filter_1``. Changing those widths therefore
    also changes the temporal receptive field and the downsampling; confirm
    this coupling is intended. With even kernel lengths, ``padding = k // 2``
    yields one extra time sample rather than strictly "same" output.

    Parameters
    ----------
    n_chans : int
        Number of EEG channels.
    filter_1 : int
        Number of temporal filters in the first convolutional layer.
    kern_length : int
        Length of the temporal kernel in the first convolutional layer.
    depth_multiplier : int
        Depth multiplier for the depthwise convolution.
    drop_prob : float
        Dropout probability.
    activation : nn.Module
        Activation function class; a single instance is shared by all stages.
    """

    def __init__(
        self,
        n_chans: int,
        filter_1: int = 8,
        kern_length: int = 64,
        depth_multiplier: int = 2,
        drop_prob: float = 0.5,
        activation: nn.Module = nn.ELU,
    ):
        super().__init__()
        self.activation = activation()
        self.drop_prob = drop_prob
        self.n_chans = n_chans
        self.filter_1 = filter_1
        self.filter_2 = self.filter_1 * depth_multiplier
        f1, f2 = self.filter_1, self.filter_2

        # Stage 1: temporal convolution over the time axis only.
        self.conv1 = nn.Conv2d(
            1, f1, (kern_length, 1), padding=(kern_length // 2, 0), bias=False
        )
        self.bn1 = nn.BatchNorm2d(f1)

        # Stage 2: depthwise convolution spanning all EEG channels at once.
        self.depthwise_conv = nn.Conv2d(f1, f2, (1, n_chans), groups=f1, bias=False)
        self.bn2 = nn.BatchNorm2d(f2)
        self.pool1 = nn.AvgPool2d(kernel_size=(8, 1))
        self.drop1 = nn.Dropout(p=drop_prob)

        # Stage 3: separable convolution = per-map temporal conv + 1x1 mixing.
        self.separable_conv_depthwise = nn.Conv2d(
            f2, f2, (f2, 1), padding=(f2 // 2, 0), groups=f2, bias=False
        )
        self.separable_conv_pointwise = nn.Conv2d(f2, f2, (1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(f2)
        self.pool2 = nn.AvgPool2d(kernel_size=(f1, 1))
        self.drop2 = nn.Dropout(p=drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expected input: (batch_size, 1, n_times, n_chans).
        pipeline = (
            # stage 1
            self.conv1,
            self.bn1,
            self.activation,
            # stage 2
            self.depthwise_conv,
            self.bn2,
            self.activation,
            self.pool1,
            self.drop1,
            # stage 3
            self.separable_conv_depthwise,
            self.separable_conv_pointwise,
            self.bn3,
            self.activation,
            self.pool2,
            self.drop2,
        )
        for stage in pipeline:
            x = stage(x)
        return x  # (batch_size, F2, reduced_time, 1)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
class _TCNBlock(nn.Module):
    """Residual temporal convolutional network (TCN) used by EEGTCNet.

    This block differs in several ways from braindecode's generic Temporal
    Block implementation, so it is kept as a separate implementation for
    now. Whether it can be merged with the generic one will be re-evaluated
    in a future issue.

    Each of the ``depth`` residual blocks applies two dilated 1D convolutions
    (dilation ``2**i`` for block ``i``). Each convolution is padded by
    ``(kernel_size - 1) * dilation`` and then trimmed with ``Chomp1d``, and
    is followed by the activation and dropout. The block output is added to
    a residual and passed through the activation again. The first residual
    is the input itself, projected by a 1x1 convolution when
    ``input_dimension != filters``. Each later residual is the previous
    block's output.

    Parameters
    ----------
    input_dimension : int
        Number of input features per time step.
    depth : int
        Number of residual blocks; must be >= 1.
    kernel_size : int
        Kernel size of the dilated convolutions.
    filters : int
        Number of channels of every convolution (and of the output).
    drop_prob : float
        Dropout probability.
    activation : nn.Module
        Activation function class; a single instance is shared by all layers.

    Raises
    ------
    ValueError
        If ``depth`` is smaller than 1.
    """

    def __init__(
        self,
        input_dimension: int,
        depth: int,
        kernel_size: int,
        filters: int,
        drop_prob: float,
        activation: nn.Module = nn.ELU,
    ):
        super().__init__()
        # Without any residual block, forward would have no output to return.
        if depth < 1:
            raise ValueError(f"depth must be a positive integer, got {depth}.")
        self.activation = activation()
        self.drop_prob = drop_prob
        self.depth = depth
        self.filters = filters
        self.kernel_size = kernel_size

        self.layers = nn.ModuleList()
        # 1x1 projection so that the first residual has ``filters`` channels.
        self.downsample = (
            nn.Conv1d(input_dimension, filters, kernel_size=1, bias=False)
            if input_dimension != filters
            else None
        )

        for i in range(depth):
            dilation = 2**i
            padding = (kernel_size - 1) * dilation
            conv_block = nn.Sequential(
                nn.Conv1d(
                    in_channels=input_dimension if i == 0 else filters,
                    out_channels=filters,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=padding,
                    bias=False,
                ),
                self._make_chomp(padding),
                self.activation,
                nn.Dropout(self.drop_prob),
                nn.Conv1d(
                    in_channels=filters,
                    out_channels=filters,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=padding,
                    bias=False,
                ),
                self._make_chomp(padding),
                self.activation,
                nn.Dropout(self.drop_prob),
            )
            self.layers.append(conv_block)

    @staticmethod
    def _make_chomp(padding: int) -> nn.Module:
        """Return a module removing the ``padding`` extra samples (identity if none)."""
        # With kernel_size == 1 there is no padding to remove. Presumably
        # Chomp1d slices ``x[..., :-chomp_size]``, which for 0 would drop
        # every sample, so an Identity is used instead.
        if padding > 0:
            return Chomp1d(padding)
        return nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch_size, time_steps, input_dimension)
        x = x.permute(0, 2, 1)  # (batch_size, input_dimension, time_steps)

        res = x if self.downsample is None else self.downsample(x)
        out = x
        for layer in self.layers:
            out = self.activation(layer(out) + res)
            res = out  # the next block's residual is this block's output

        return out.permute(0, 2, 1)  # (batch_size, time_steps, filters)
|