braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/braindecode/models/eegnex.py
@@ -0,0 +1,354 @@

# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)
from __future__ import annotations

import math

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import Conv2dWithConstraint, LinearWithConstraint


class EEGNeX(EEGModuleMixin, nn.Module):
    r"""EEGNeX model from Chen et al (2024) [eegnex]_.

    :bdg-success:`Convolution`

    .. figure:: https://braindecode.org/dev/_static/model/eegnex.jpg
       :align: center
       :alt: EEGNeX Architecture
       :width: 620px

    .. rubric:: Architectural Overview

    EEGNeX is a **purely convolutional** architecture that refines the
    EEGNet-style stem and deepens the temporal stack with **dilated temporal
    convolutions**. The end-to-end flow is:

    - (i) **Block-1/2**: two temporal convolutions ``(1 x L)`` with BN refine a
      learned FIR-like *temporal filter bank* (no pooling yet);
    - (ii) **Block-3**: depthwise **spatial** convolution across electrodes
      ``(n_chans x 1)`` with max-norm constraint, followed by
      ELU → AvgPool (time) → Dropout;
    - (iii) **Block-4/5**: two additional **temporal** convolutions with
      increasing **dilation** to expand the receptive field; the last block
      applies ELU → AvgPool → Dropout → Flatten;
    - (iv) **Classifier**: a max-norm–constrained linear layer.

    The published work positions EEGNeX as a compact, conv-only alternative
    that consistently outperforms prior baselines across MOABB-style
    benchmarks, with the popular “EEGNeX-8,32” shorthand denoting *8 temporal
    filters* and *kernel length 32*.

    .. rubric:: Macro Components

    - **Block-1 / Block-2 — Temporal filter (learned).**

      - *Operations.*

        - :class:`torch.nn.Conv2d` with kernels ``(1, L)``
        - :class:`torch.nn.BatchNorm2d` (no nonlinearity until Block-3,
          mirroring a linear FIR analysis stage).

      These layers set up frequency-selective detectors before spatial mixing.

      - *Interpretability.* Kernels can be inspected as FIR filters; two
        stacked temporal convs allow longer effective kernels without
        parameter blow-up.

    - **Block-3 — Spatial projection + condensation.**

      - *Operations.*

        - :class:`braindecode.modules.Conv2dWithConstraint` with kernel
          ``(n_chans, 1)`` and ``groups = filter_2`` (depthwise across filters)
        - :class:`torch.nn.BatchNorm2d`
        - :class:`torch.nn.ELU`
        - :class:`torch.nn.AvgPool2d` (time)
        - :class:`torch.nn.Dropout`.

      **Role**: Learns per-filter spatial patterns over the **full montage**
      while temporal pooling stabilizes and compresses features; max-norm
      encourages well-behaved spatial weights similar to EEGNet practice.

    - **Block-4 / Block-5 — Dilated temporal integration.**

      - *Operations.*

        - :class:`torch.nn.Conv2d` with kernels ``(1, k)`` and **dilations**
          (e.g., 2 then 4)
        - :class:`torch.nn.BatchNorm2d`
        - :class:`torch.nn.ELU`
        - :class:`torch.nn.AvgPool2d` (time)
        - :class:`torch.nn.Dropout`
        - :class:`torch.nn.Flatten`.

      **Role**: Expands the temporal receptive field efficiently to capture
      rhythms and long-range context after condensation.

    - **Final Classifier — Max-norm linear.**

      - *Operations.*

        - :class:`braindecode.modules.LinearWithConstraint` maps the flattened
          vector to the target classes; the max-norm constraint regularizes
          the readout.

    .. rubric:: Convolutional Details

    - **Temporal (where time-domain patterns are learned).**
      Blocks 1-2 learn the primary filter bank (oscillations/transients),
      while Blocks 4-5 use **dilation** to integrate over longer horizons
      without extra pooling. The final AvgPool in Block-5 sets the output
      token rate and helps noise suppression.

    - **Spatial (how electrodes are processed).**
      A *single* depthwise spatial conv (Block-3) spans the entire electrode
      set (kernel ``(n_chans, 1)``), producing per-temporal-filter
      topographies; no cross-filter mixing occurs at this stage, aiding
      interpretability.

    - **Spectral (how frequency content is captured).**
      Frequency selectivity emerges from the learned temporal kernels;
      dilation broadens effective bandwidth coverage by composing multiple
      scales.

    .. rubric:: Additional Mechanisms

    - **EEGNeX-8,32 naming.** “8,32” indicates *8 temporal filters* and
      *kernel length 32*, reflecting the paper's ablation path from
      EEGNet-8,2 toward thicker temporal kernels and a deeper conv stack.
    - **Max-norm constraints.** Spatial (Block-3) and final linear layers use
      max-norm regularization—standard in EEG CNNs—to reduce overfitting and
      encourage stable spatial patterns.

    .. rubric:: Usage and Configuration

    - **Kernel schedule.** Start with the canonical **EEGNeX-8,32**
      (``filter_1=8``, ``kernel_block_1_2=32``) and keep the **Block-3** depth
      multiplier modest (e.g., 2) to match the paper's “pure conv” profile.
    - **Pooling vs. dilation.** Use pooling in Blocks 3 and 5 to control
      compute and variance; increase dilations (Blocks 4-5) to widen temporal
      context when windows are short.
    - **Regularization.** Combine dropout (Blocks 3 & 5) with max-norm on
      spatial and classifier layers; prefer ELU activations for stable
      training on small EEG datasets.

    - The braindecode implementation follows the paper's conv-only design with
      five blocks and reproduces the depthwise spatial step and dilated
      temporal stack. See the class reference for exact kernel sizes,
      dilations, and pooling defaults. You can check the original
      implementation at [EEGNexCode]_.

    .. versionadded:: 1.1

    Parameters
    ----------
    activation : nn.Module, optional
        Activation function to use. Default is ``nn.ELU``.
    depth_multiplier : int, optional
        Depth multiplier for the depthwise convolution. Default is 2.
    filter_1 : int, optional
        Number of filters in the first convolutional layer. Default is 8.
    filter_2 : int, optional
        Number of filters in the second convolutional layer. Default is 32.
    drop_prob : float, optional
        Dropout rate. Default is 0.5.
    kernel_block_1_2 : int, optional
        Temporal kernel length for blocks 1 and 2. Default is 64.
    kernel_block_4 : int, optional
        Temporal kernel length for block 4. Default is 16.
    dilation_block_4 : int, optional
        Temporal dilation rate for block 4. Default is 2.
    avg_pool_block4 : int, optional
        Temporal pooling size following the spatial convolution (applied at
        the end of block 3). Default is 4.
    kernel_block_5 : int, optional
        Temporal kernel length for block 5. Default is 16.
    dilation_block_5 : int, optional
        Temporal dilation rate for block 5. Default is 4.
    avg_pool_block5 : int, optional
        Temporal pooling size for block 5. Default is 8.
    max_norm_conv : float, optional
        Max-norm constraint on the depthwise spatial convolution.
        Default is 1.0.
    max_norm_linear : float, optional
        Max-norm constraint on the final linear layer. Default is 0.25.

    References
    ----------
    .. [eegnex] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
       Toward reliable signals decoding for electroencephalogram: A benchmark
       study to EEGNeX. Biomedical Signal Processing and Control, 87, 105475.
    .. [EEGNexCode] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
       Toward reliable signals decoding for electroencephalogram: A benchmark
       study to EEGNeX. https://github.com/chenxiachan/EEGNeX
    """

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        activation: type[nn.Module] = nn.ELU,
        depth_multiplier: int = 2,
        filter_1: int = 8,
        filter_2: int = 32,
        drop_prob: float = 0.5,
        kernel_block_1_2: int = 64,
        kernel_block_4: int = 16,
        dilation_block_4: int = 2,
        avg_pool_block4: int = 4,
        kernel_block_5: int = 16,
        dilation_block_5: int = 4,
        avg_pool_block5: int = 8,
        max_norm_conv: float = 1.0,
        max_norm_linear: float = 0.25,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.depth_multiplier = depth_multiplier
        self.filter_1 = filter_1
        self.filter_2 = filter_2
        self.filter_3 = self.filter_2 * self.depth_multiplier
        self.drop_prob = drop_prob
        self.activation = activation
        self.kernel_block_1_2 = (1, kernel_block_1_2)
        self.kernel_block_4 = (1, kernel_block_4)
        self.dilation_block_4 = (1, dilation_block_4)
        self.avg_pool_block4 = (1, avg_pool_block4)
        self.kernel_block_5 = (1, kernel_block_5)
        self.dilation_block_5 = (1, dilation_block_5)
        self.avg_pool_block5 = (1, avg_pool_block5)

        # final layers output
        self.in_features = self._calculate_output_length()

        # Following paper nomenclature
        self.block_1 = nn.Sequential(
            Rearrange("batch ch time -> batch 1 ch time"),
            nn.Conv2d(
                in_channels=1,
                out_channels=self.filter_1,
                kernel_size=self.kernel_block_1_2,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_1),
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_1,
                out_channels=self.filter_2,
                kernel_size=self.kernel_block_1_2,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_2),
        )

        self.block_3 = nn.Sequential(
            Conv2dWithConstraint(
                in_channels=self.filter_2,
                out_channels=self.filter_3,
                max_norm=max_norm_conv,
                kernel_size=(self.n_chans, 1),
                groups=self.filter_2,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_3),
            self.activation(),
            nn.AvgPool2d(
                kernel_size=self.avg_pool_block4,
                padding=(0, 1),
            ),
            nn.Dropout(p=self.drop_prob),
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_3,
                out_channels=self.filter_2,
                kernel_size=self.kernel_block_4,
                dilation=self.dilation_block_4,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_2),
        )

        self.block_5 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_2,
                out_channels=self.filter_1,
                kernel_size=self.kernel_block_5,
                dilation=self.dilation_block_5,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_1),
            self.activation(),
            nn.AvgPool2d(
                kernel_size=self.avg_pool_block5,
                padding=(0, 1),
            ),
            nn.Dropout(p=self.drop_prob),
            nn.Flatten(),
        )

        self.final_layer = LinearWithConstraint(
            in_features=self.in_features,
            out_features=self.n_outputs,
            max_norm=max_norm_linear,
        )
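
    # A minimal sketch of the max-norm trick behind Conv2dWithConstraint and
    # LinearWithConstraint above (the exact hook point is braindecode's
    # implementation detail): each filter's weight row is renormalized onto
    # an L2 ball before the layer is applied, e.g.
    #
    #   >>> weight = torch.randn(4, 248)  # assumed readout weight shape
    #   >>> capped = torch.renorm(weight, p=2, dim=0, maxnorm=0.25)
    #   >>> bool((capped.norm(p=2, dim=1) <= 0.25 + 1e-6).all())
    #   True
    #
    # Rows already inside the ball are left untouched; only oversized rows
    # are rescaled, giving a hard cap rather than a soft penalty.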

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the EEGNeX model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # x shape: (batch_size, n_chans, n_times)
        x = self.block_1(x)
        # (batch_size, filter_1, n_chans, n_times)
        x = self.block_2(x)
        # (batch_size, filter_2, n_chans, n_times)
        x = self.block_3(x)
        # (batch_size, filter_3, 1, ~n_times//4)
        x = self.block_4(x)
        # (batch_size, filter_2, 1, ~n_times//4)
        x = self.block_5(x)
        # (batch_size, filter_1 * (~n_times//32))
        x = self.final_layer(x)

        return x
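
    # Receptive-field arithmetic for the dilated blocks (sketch, using the
    # defaults): a conv with temporal kernel k and dilation d spans
    # (k - 1) * d + 1 samples, so block 4 (k=16, d=2) covers 31 samples and
    # block 5 (k=16, d=4) covers 61 samples of the already-pooled sequence:
    #
    #   >>> k = 16
    #   >>> [(k - 1) * d + 1 for d in (2, 4)]
    #   [31, 61]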

    def _calculate_output_length(self) -> int:
        # Pooling kernel sizes for the time dimension
        p4 = self.avg_pool_block4[1]
        p5 = self.avg_pool_block5[1]

        # Padding for the time dimension (from padding=(0, 1))
        pad4 = 1
        pad5 = 1

        # Stride is assumed to be equal to kernel size (p4 and p5)

        # Calculate time dimension after block 3 pooling
        # Formula: floor((L_in + 2*padding - kernel_size) / stride) + 1
        T3 = math.floor((self.n_times + 2 * pad4 - p4) / p4) + 1

        # Calculate time dimension after block 5 pooling
        T5 = math.floor((T3 + 2 * pad5 - p5) / p5) + 1

        # Calculate final flattened features (channels * 1 * time_dim).
        # The spatial dimension is reduced to 1 after block 3's depthwise
        # conv, and filter_1 is the number of channels before flatten.
        final_in_features = self.filter_1 * T5
        return final_in_features
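As a quick sanity check on `_calculate_output_length`, a short sketch evaluating the same pooling formula with the default pooling sizes and an assumed input length of 1000 samples:

    import math

    n_times, p4, p5, pad = 1000, 4, 8, 1
    T3 = math.floor((n_times + 2 * pad - p4) / p4) + 1  # 250
    T5 = math.floor((T3 + 2 * pad - p5) / p5) + 1       # 31
    in_features = 8 * T5                                # filter_1 * T5 = 248

This matches the `in_features` the constructor computes before building `final_layer`.
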
--- /dev/null
+++ b/braindecode/models/eegsimpleconv.py
@@ -0,0 +1,201 @@

"""
EEG-SimpleConv is a 1D Convolutional Neural Network from Yassine El Ouahidi et al. (2023).

Originally designed for motor imagery decoding from EEG signals, the model
offers competitive performance with low latency and is mainly composed of
1D convolutional layers.
"""

# Authors: Yassine El Ouahidi <eloua.yas@gmail.com>
#
# License: BSD-3

import torch
from torch import nn
from torchaudio.transforms import Resample

from braindecode.models.base import EEGModuleMixin


class EEGSimpleConv(EEGModuleMixin, torch.nn.Module):
    r"""EEGSimpleConv from Ouahidi, YE et al (2023) [Yassine2023]_.

    :bdg-success:`Convolution`

    .. figure:: https://raw.githubusercontent.com/elouayas/EEGSimpleConv/refs/heads/main/architecture.png
       :align: center
       :alt: EEGSimpleConv Architecture

    EEGSimpleConv is a 1D Convolutional Neural Network originally designed
    for decoding motor imagery from EEG signals. The model aims to have a
    very simple and straightforward architecture that allows a low latency,
    while still achieving very competitive performance.

    EEG-SimpleConv starts with a 1D convolutional layer, where each EEG channel
    enters a separate 1D convolutional channel. This is followed by a series of
    blocks of two 1D convolutional layers. Between the two convolutional layers
    of each block is a max pooling layer, which downsamples the data by a factor
    of 2. Each convolution is followed by a batch normalisation layer and a ReLU
    activation function. Finally, a global average pooling (in the time domain)
    is performed to obtain a single value per feature map, which is then fed
    into a linear layer to obtain the final classification prediction output.

    The paper and original code with more details about the methodological
    choices are available at [Yassine2023]_ and [Yassine2023Code]_.

    The input should be a three-dimensional tensor representing the EEG
    signals, of shape ``(batch_size, n_channels, n_timesteps)``.

    Notes
    -----
    The authors recommend using the default parameters for MI decoding.
    Please refer to the original paper and code for more details.

    Recommended ranges for the hyperparameters, depending on the evaluation
    paradigm:

    ===============  ==============  =============
    Parameter        Within-Subject  Cross-Subject
    ===============  ==============  =============
    feature_maps     [64-144]        [64-144]
    n_convs          1               [2-4]
    resampling_freq  [70-100]        [50-80]
    kernel_size      [12-17]         [5-8]
    ===============  ==============  =============

    An extensive ablation study is included in the paper to understand the
    effect of each parameter on the model performance.

    .. versionadded:: 0.9

    Parameters
    ----------
    feature_maps : int
        Number of feature maps at the first convolution; the width of the
        model.
    n_convs : int
        Number of blocks of convolutions (2 convolutions per block); the
        depth of the model.
    resampling_freq : int
        Resampling frequency.
    kernel_size : int
        Size of the convolution kernels.
    return_feature : bool
        If True, the forward pass returns the time-pooled feature vector
        instead of the classifier output.
    activation : nn.Module, default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.

    References
    ----------
    .. [Yassine2023] Yassine El Ouahidi, V. Gripon, B. Pasdeloup,
       G. Bouallegue, N. Farrugia, G. Lioi, 2023. A Strong and Simple Deep
       Learning Baseline for BCI Motor Imagery Decoding. arXiv preprint.
       https://arxiv.org/abs/2309.07159
    .. [Yassine2023Code] Yassine El Ouahidi, V. Gripon, B. Pasdeloup,
       G. Bouallegue, N. Farrugia, G. Lioi, 2023. A Strong and Simple Deep
       Learning Baseline for BCI Motor Imagery Decoding. GitHub repository.
       https://github.com/elouayas/EEGSimpleConv
    """

    def __init__(
        self,
        # Base arguments
        n_outputs=None,
        n_chans=None,
        sfreq=None,
        # Model specific arguments
        feature_maps=128,
        n_convs=2,
        resampling_freq=80,
        kernel_size=8,
        return_feature=False,
        activation: type[nn.Module] = nn.ReLU,
        # Other ways to initialize the model
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        self.return_feature = return_feature
        self.resample = (
            Resample(orig_freq=int(self.sfreq), new_freq=int(resampling_freq))
            if self.sfreq != resampling_freq
            else torch.nn.Identity()
        )

        self.conv = torch.nn.Conv1d(
            self.n_chans,
            feature_maps,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )
        self.bn = torch.nn.BatchNorm1d(feature_maps)
        self.blocks = []
        new_feature_maps = feature_maps
        old_feature_maps = feature_maps
        for i in range(n_convs):
            if i > 0:
                # 1.414 ~ sqrt(2): widening the channels as the temporal
                # resolution halves keeps the FLOPs roughly constant.
                new_feature_maps = int(1.414 * new_feature_maps)
            self.blocks.append(
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        old_feature_maps,
                        new_feature_maps,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    torch.nn.BatchNorm1d(new_feature_maps),
                    # The original guard (`i > 0 - 1`) is always true, so every
                    # block downsamples the time dimension by a factor of 2.
                    torch.nn.MaxPool1d(2),
                    activation(),
                    torch.nn.Conv1d(
                        new_feature_maps,
                        new_feature_maps,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    torch.nn.BatchNorm1d(new_feature_maps),
                    activation(),
                )
            )
            old_feature_maps = new_feature_maps
        self.blocks = torch.nn.ModuleList(self.blocks)
        self.final_layer = torch.nn.Linear(old_feature_maps, self.n_outputs)
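
    # Width schedule sketch: with the default feature_maps=128 and an assumed
    # n_convs=3, the per-block widths after the first block grow as
    #
    #   >>> int(1.414 * 128), int(1.414 * int(1.414 * 128))
    #   (180, 254)
    #
    # i.e. roughly sqrt(2) wider per factor-2 temporal downsampling.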

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Feature tensor of shape (batch_size, n_feature_maps) if
            ``return_feature`` is True, otherwise output tensor of shape
            (batch_size, n_outputs).
        """
        x_rs = self.resample(x.contiguous())
        feat = torch.relu(self.bn(self.conv(x_rs)))
        for seq in self.blocks:
            feat = seq(feat)
        # Global average pooling over the time dimension
        feat = feat.mean(dim=2)
        if self.return_feature:
            return feat
        else:
            return self.final_layer(feat)
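
Finally, a minimal end-to-end sketch of the class above (assumed sizes: 22 channels at 250 Hz, 4-second windows, 4 classes); with the default `resampling_freq=80`, the input is resampled from 250 Hz to 80 Hz before the convolutional stack:

    import torch

    model = EEGSimpleConv(n_chans=22, n_outputs=4, sfreq=250)
    x = torch.randn(8, 22, 1000)   # (batch, n_chans, n_times) at 250 Hz
    out = model(x)                 # internally resampled to 80 Hz
    print(out.shape)               # torch.Size([8, 4])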