braindecode 0.8-py3-none-any.whl → 1.0.0-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Potentially problematic release: this version of braindecode might be problematic.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
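The listing above also shows the 1.0.0 reorganization: reusable layers moved into the new braindecode.modules subpackage and initialization helpers into braindecode.functional, and the model files diffed below import from them. A minimal sketch of that import surface, using only names that appear verbatim in the hunks below (the full export lists live in braindecode/modules/__init__.py and braindecode/functional/__init__.py in the wheel):

# Import layout used by the 1.0.0 model files shown in this diff.
from braindecode.functional import glorot_weight_zero_bias   # init helper
from braindecode.models.base import EEGModuleMixin           # shared mixin for all models
from braindecode.modules import (                            # reusable building blocks
    DepthwiseConv2d,
    Ensure4d,
    FeedForwardBlock,
    InceptionBlock,
    MultiHeadAttention,
)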
braindecode/models/eegconformer.py
@@ -0,0 +1,372 @@
+# Authors: Yonghao Song <eeyhsong@gmail.com>
+#
+# License: BSD (3-clause)
+import warnings
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import Tensor, nn
+
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import FeedForwardBlock, MultiHeadAttention
+
+
+class EEGConformer(EEGModuleMixin, nn.Module):
+    """EEG Conformer from Song et al. (2022) from [song2022]_.
+
+    .. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
+       :align: center
+       :alt: EEGConformer Architecture
+
+    Convolutional Transformer for EEG decoding.
+
+    The paper and original code with more details about the methodological
+    choices are available at the [song2022]_ and [ConformerCode]_.
+
+    This neural network architecture receives a traditional braindecode input.
+    The input shape should be three-dimensional matrix representing the EEG
+    signals.
+
+    `(batch_size, n_channels, n_timesteps)`.
+
+    The EEG Conformer architecture is composed of three modules:
+    - PatchEmbedding
+    - TransformerEncoder
+    - ClassificationHead
+
+    Notes
+    -----
+    The authors recommend using data augmentation before using Conformer,
+    e.g. segmentation and recombination,
+    Please refer to the original paper and code for more details.
+
+    The model was initially tuned on 4 seconds of 250 Hz data.
+    Please adjust the scale of the temporal convolutional layer,
+    and the pooling layer for better performance.
+
+    .. versionadded:: 0.8
+
+    We aggregate the parameters based on the parts of the models, or
+    when the parameters were used first, e.g. n_filters_time.
+
+    Parameters
+    ----------
+    n_filters_time: int
+        Number of temporal filters, defines also embedding size.
+    filter_time_length: int
+        Length of the temporal filter.
+    pool_time_length: int
+        Length of temporal pooling filter.
+    pool_time_stride: int
+        Length of stride between temporal pooling filters.
+    drop_prob: float
+        Dropout rate of the convolutional layer.
+    att_depth: int
+        Number of self-attention layers.
+    att_heads: int
+        Number of attention heads.
+    att_drop_prob: float
+        Dropout rate of the self-attention layer.
+    final_fc_length: int | str
+        The dimension of the fully connected layer.
+    return_features: bool
+        If True, the forward method returns the features before the
+        last classification layer. Defaults to False.
+    activation: nn.Module
+        Activation function as parameter. Default is nn.ELU
+    activation_transfor: nn.Module
+        Activation function as parameter, applied at the FeedForwardBlock module
+        inside the transformer. Default is nn.GeLU
+
+    References
+    ----------
+    .. [song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
+       conformer: Convolutional transformer for EEG decoding and visualization.
+       IEEE Transactions on Neural Systems and Rehabilitation Engineering,
+       31, pp.710-719. https://ieeexplore.ieee.org/document/9991178
+    .. [ConformerCode] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
+       conformer: Convolutional transformer for EEG decoding and visualization.
+       https://github.com/eeyhsong/EEG-Conformer.
+    """
+
+    def __init__(
+        self,
+        n_outputs=None,
+        n_chans=None,
+        n_filters_time=40,
+        filter_time_length=25,
+        pool_time_length=75,
+        pool_time_stride=15,
+        drop_prob=0.5,
+        att_depth=6,
+        att_heads=10,
+        att_drop_prob=0.5,
+        final_fc_length="auto",
+        return_features=False,
+        activation: nn.Module = nn.ELU,
+        activation_transfor: nn.Module = nn.GELU,
+        n_times=None,
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        self.mapping = {
+            "classification_head.fc.6.weight": "final_layer.final_layer.0.weight",
+            "classification_head.fc.6.bias": "final_layer.final_layer.0.bias",
+        }
+
+        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+        if not (self.n_chans <= 64):
+            warnings.warn(
+                "This model has only been tested on no more "
+                + "than 64 channels. no guarantee to work with "
+                + "more channels.",
+                UserWarning,
+            )
+
+        self.return_features = return_features
+
+        self.patch_embedding = _PatchEmbedding(
+            n_filters_time=n_filters_time,
+            filter_time_length=filter_time_length,
+            n_channels=self.n_chans,
+            pool_time_length=pool_time_length,
+            stride_avg_pool=pool_time_stride,
+            drop_prob=drop_prob,
+            activation=activation,
+        )
+
+        if final_fc_length == "auto":
+            assert self.n_times is not None
+            self.final_fc_length = self.get_fc_size()
+
+        self.transformer = _TransformerEncoder(
+            att_depth=att_depth,
+            emb_size=n_filters_time,
+            att_heads=att_heads,
+            att_drop=att_drop_prob,
+            activation=activation_transfor,
+        )
+
+        self.fc = _FullyConnected(
+            final_fc_length=self.final_fc_length, activation=activation
+        )
+
+        self.final_layer = nn.Linear(self.fc.hidden_channels, self.n_outputs)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = torch.unsqueeze(x, dim=1)  # add one extra dimension
+        x = self.patch_embedding(x)
+        feature = self.transformer(x)
+
+        if self.return_features:
+            return feature
+
+        x = self.fc(feature)
+        x = self.final_layer(x)
+        return x
+
+    def get_fc_size(self):
+        out = self.patch_embedding(torch.ones((1, 1, self.n_chans, self.n_times)))
+        size_embedding_1 = out.cpu().data.numpy().shape[1]
+        size_embedding_2 = out.cpu().data.numpy().shape[2]
+
+        return size_embedding_1 * size_embedding_2
+
+
+class _PatchEmbedding(nn.Module):
+    """Patch Embedding.
+
+    The authors used a convolution module to capture local features,
+    instead of position embedding.
+
+    Parameters
+    ----------
+    n_filters_time: int
+        Number of temporal filters, defines also embedding size.
+    filter_time_length: int
+        Length of the temporal filter.
+    n_channels: int
+        Number of channels to be used as number of spatial filters.
+    pool_time_length: int
+        Length of temporal poling filter.
+    stride_avg_pool: int
+        Length of stride between temporal pooling filters.
+    drop_prob: float
+        Dropout rate of the convolutional layer.
+
+    Returns
+    -------
+    x: torch.Tensor
+        The output tensor of the patch embedding layer.
+    """
+
+    def __init__(
+        self,
+        n_filters_time,
+        filter_time_length,
+        n_channels,
+        pool_time_length,
+        stride_avg_pool,
+        drop_prob,
+        activation: nn.Module = nn.ELU,
+    ):
+        super().__init__()
+
+        self.shallownet = nn.Sequential(
+            nn.Conv2d(1, n_filters_time, (1, filter_time_length), (1, 1)),
+            nn.Conv2d(n_filters_time, n_filters_time, (n_channels, 1), (1, 1)),
+            nn.BatchNorm2d(num_features=n_filters_time),
+            activation(),
+            nn.AvgPool2d(
+                kernel_size=(1, pool_time_length), stride=(1, stride_avg_pool)
+            ),
+            # pooling acts as slicing to obtain 'patch' along the
+            # time dimension as in ViT
+            nn.Dropout(p=drop_prob),
+        )
+
+        self.projection = nn.Sequential(
+            nn.Conv2d(
+                n_filters_time, n_filters_time, (1, 1), stride=(1, 1)
+            ),  # transpose, conv could enhance fiting ability slightly
+            Rearrange("b d_model 1 seq -> b seq d_model"),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.shallownet(x)
+        x = self.projection(x)
+        return x
+
+
+class _ResidualAdd(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        res = x
+        x = self.fn(x)
+        x += res
+        return x
+
+
+class _TransformerEncoderBlock(nn.Sequential):
+    def __init__(
+        self,
+        emb_size,
+        att_heads,
+        att_drop,
+        forward_expansion=4,
+        activation: nn.Module = nn.GELU,
+    ):
+        super().__init__(
+            _ResidualAdd(
+                nn.Sequential(
+                    nn.LayerNorm(emb_size),
+                    MultiHeadAttention(emb_size, att_heads, att_drop),
+                    nn.Dropout(att_drop),
+                )
+            ),
+            _ResidualAdd(
+                nn.Sequential(
+                    nn.LayerNorm(emb_size),
+                    FeedForwardBlock(
+                        emb_size,
+                        expansion=forward_expansion,
+                        drop_p=att_drop,
+                        activation=activation,
+                    ),
+                    nn.Dropout(att_drop),
+                )
+            ),
+        )
+
+
+class _TransformerEncoder(nn.Sequential):
+    """Transformer encoder module for the transformer encoder.
+
+    Similar to the layers used in ViT.
+
+    Parameters
+    ----------
+    att_depth : int
+        Number of transformer encoder blocks.
+    emb_size : int
+        Embedding size of the transformer encoder.
+    att_heads : int
+        Number of attention heads.
+    att_drop : float
+        Dropout probability for the attention layers.
+
+    """
+
+    def __init__(
+        self, att_depth, emb_size, att_heads, att_drop, activation: nn.Module = nn.GELU
+    ):
+        super().__init__(
+            *[
+                _TransformerEncoderBlock(
+                    emb_size, att_heads, att_drop, activation=activation
+                )
+                for _ in range(att_depth)
+            ]
+        )
+
+
+class _FullyConnected(nn.Module):
+    def __init__(
+        self,
+        final_fc_length,
+        drop_prob_1=0.5,
+        drop_prob_2=0.3,
+        out_channels=256,
+        hidden_channels=32,
+        activation: nn.Module = nn.ELU,
+    ):
+        """Fully-connected layer for the transformer encoder.
+
+        Parameters
+        ----------
+        final_fc_length : int
+            Length of the final fully connected layer.
+        n_classes : int
+            Number of classes for classification.
+        drop_prob_1 : float
+            Dropout probability for the first dropout layer.
+        drop_prob_2 : float
+            Dropout probability for the second dropout layer.
+        out_channels : int
+            Number of output channels for the first linear layer.
+        hidden_channels : int
+            Number of output channels for the second linear layer.
+        return_features : bool
+            Whether to return input features.
+        """
+
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.fc = nn.Sequential(
+            nn.Linear(final_fc_length, out_channels),
+            activation(),
+            nn.Dropout(drop_prob_1),
+            nn.Linear(out_channels, hidden_channels),
+            activation(),
+            nn.Dropout(drop_prob_2),
+        )
+
+    def forward(self, x):
+        x = x.contiguous().view(x.size(0), -1)
+        out = self.fc(x)
+        return out
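As the docstring above states, EEGConformer takes a three-dimensional input of shape (batch_size, n_channels, n_timesteps), and with final_fc_length="auto" it infers the flattened feature size from n_times via get_fc_size(). A minimal smoke-test sketch under those assumptions, using the 4 s / 250 Hz setting from the tuning note (this example is not part of the package):

import torch

from braindecode.models import EEGConformer

# 4 seconds of 250 Hz data over 22 channels, 4 output classes.
model = EEGConformer(n_outputs=4, n_chans=22, n_times=1000, sfreq=250)
x = torch.randn(8, 22, 1000)  # (batch_size, n_channels, n_timesteps)
logits = model(x)             # forward() unsqueezes to (8, 1, 22, 1000) internally
print(logits.shape)           # torch.Size([8, 4])

With final_fc_length="auto" the constructor runs a dummy tensor through _PatchEmbedding once, so n_times must be given (or derivable from input_window_seconds and sfreq).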
braindecode/models/eeginception_erp.py
@@ -0,0 +1,304 @@
+# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
+#          Cedric Rommel <cedric.rommel@inria.fr>
+#
+# License: BSD (3-clause)
+import math
+
+from einops.layers.torch import Rearrange
+from torch import nn
+
+from braindecode.functional import glorot_weight_zero_bias
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
+
+
+class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
+    """EEG Inception for ERP-based from Santamaria-Vazquez et al (2020) [santamaria2020]_.
+
+    .. figure:: https://braindecode.org/dev/_static/model/eeginceptionerp.jpg
+       :align: center
+       :alt: EEGInceptionERP Architecture
+
+    The code for the paper and this model is also available at [santamaria2020]_
+    and an adaptation for PyTorch [2]_.
+
+    The model is strongly based on the original InceptionNet for an image. The main goal is
+    to extract features in parallel with different scales. The authors extracted three scales
+    proportional to the window sample size. The network had three parts:
+    1-larger inception block largest, 2-smaller inception block followed by 3-bottleneck
+    for classification.
+
+    One advantage of the EEG-Inception block is that it allows a network
+    to learn simultaneous components of low and high frequency associated with the signal.
+    The winners of BEETL Competition/NeurIps 2021 used parts of the
+    model [beetl]_.
+
+    The model is fully described in [santamaria2020]_.
+
+    Notes
+    -----
+    This implementation is not guaranteed to be correct, has not been checked
+    by original authors, only reimplemented from the paper based on [2]_.
+
+    Parameters
+    ----------
+    n_times : int, optional
+        Size of the input, in number of samples. Set to 128 (1s) as in
+        [santamaria2020]_.
+    sfreq : float, optional
+        EEG sampling frequency. Defaults to 128 as in [santamaria2020]_.
+    drop_prob : float, optional
+        Dropout rate inside all the network. Defaults to 0.5 as in
+        [santamaria2020]_.
+    scales_samples_s: list(float), optional
+        Windows for inception block. Temporal scale (s) of the convolutions on
+        each Inception module. This parameter determines the kernel sizes of
+        the filters. Defaults to 0.5, 0.25, 0.125 seconds, as in
+        [santamaria2020]_.
+    n_filters : int, optional
+        Initial number of convolutional filters. Defaults to 8 as in
+        [santamaria2020]_.
+    activation: nn.Module, optional
+        Activation function. Defaults to ELU activation as in
+        [santamaria2020]_.
+    batch_norm_alpha: float, optional
+        Momentum for BatchNorm2d. Defaults to 0.01.
+    depth_multiplier: int, optional
+        Depth multiplier for the depthwise convolution. Defaults to 2 as in
+        [santamaria2020]_.
+    pooling_sizes: list(int), optional
+        Pooling sizes for the inception blocks. Defaults to 4, 2, 2 and 2, as
+        in [santamaria2020]_.
+
+
+    References
+    ----------
+    .. [santamaria2020] Santamaria-Vazquez, E., Martinez-Cagigal, V.,
+       Vaquerizo-Villar, F., & Hornero, R. (2020).
+       EEG-inception: A novel deep convolutional neural network for assistive
+       ERP-based brain-computer interfaces.
+       IEEE Transactions on Neural Systems and Rehabilitation Engineering , v. 28.
+       Online: http://dx.doi.org/10.1109/TNSRE.2020.3048106
+    .. [2] Grifcc. Implementation of the EEGInception in torch (2022).
+       Online: https://github.com/Grifcc/EEG/
+    .. [beetl] Wei, X., Faisal, A.A., Grosse-Wentrup, M., Gramfort, A., Chevallier, S.,
+       Jayaram, V., Jeunet, C., Bakas, S., Ludwig, S., Barmpas, K., Bahri, M., Panagakis,
+       Y., Laskaris, N., Adamos, D.A., Zafeiriou, S., Duong, W.C., Gordon, S.M.,
+       Lawhern, V.J., Śliwowski, M., Rouanne, V. & Tempczyk, P. (2022).
+       2021 BEETL Competition: Advancing Transfer Learning for Subject Independence &
+       Heterogeneous EEG Data Sets. Proceedings of the NeurIPS 2021 Competitions and
+       Demonstrations Track, in Proceedings of Machine Learning Research
+       176:205-219 Available from https://proceedings.mlr.press/v176/wei22a.html.
+
+    """
+
+    def __init__(
+        self,
+        n_chans=None,
+        n_outputs=None,
+        n_times=1000,
+        sfreq=128,
+        drop_prob=0.5,
+        scales_samples_s=(0.5, 0.25, 0.125),
+        n_filters=8,
+        activation: nn.Module = nn.ELU,
+        batch_norm_alpha=0.01,
+        depth_multiplier=2,
+        pooling_sizes=(4, 2, 2, 2),
+        chs_info=None,
+        input_window_seconds=None,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+        self.drop_prob = drop_prob
+        self.n_filters = n_filters
+        self.scales_samples_s = scales_samples_s
+        self.scales_samples = tuple(
+            int(size_s * self.sfreq) for size_s in self.scales_samples_s
+        )
+        self.activation = activation
+        self.alpha_momentum = batch_norm_alpha
+        self.depth_multiplier = depth_multiplier
+        self.pooling_sizes = pooling_sizes
+
+        self.mapping = {
+            "classification.1.weight": "final_layer.fc.weight",
+            "classification.1.bias": "final_layer.fc.bias",
+        }
+
+        self.add_module("ensuredims", Ensure4d())
+
+        self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 C T"))
+
+        # ======== Inception branches ========================
+        block11 = self._get_inception_branch_1(
+            in_channels=self.n_chans,
+            out_channels=self.n_filters,
+            kernel_length=self.scales_samples[0],
+            alpha_momentum=self.alpha_momentum,
+            activation=self.activation,
+            drop_prob=self.drop_prob,
+            depth_multiplier=self.depth_multiplier,
+        )
+        block12 = self._get_inception_branch_1(
+            in_channels=self.n_chans,
+            out_channels=self.n_filters,
+            kernel_length=self.scales_samples[1],
+            alpha_momentum=self.alpha_momentum,
+            activation=self.activation,
+            drop_prob=self.drop_prob,
+            depth_multiplier=self.depth_multiplier,
+        )
+        block13 = self._get_inception_branch_1(
+            in_channels=self.n_chans,
+            out_channels=self.n_filters,
+            kernel_length=self.scales_samples[2],
+            alpha_momentum=self.alpha_momentum,
+            activation=self.activation,
+            drop_prob=self.drop_prob,
+            depth_multiplier=self.depth_multiplier,
+        )
+
+        self.add_module(
+            "inception_block_1", InceptionBlock((block11, block12, block13))
+        )
+
+        self.add_module("avg_pool_1", nn.AvgPool2d((1, self.pooling_sizes[0])))
+
+        # ======== Inception branches ========================
+        n_concat_filters = len(self.scales_samples) * self.n_filters
+        n_concat_dw_filters = n_concat_filters * self.depth_multiplier
+        block21 = self._get_inception_branch_2(
+            in_channels=n_concat_dw_filters,
+            out_channels=self.n_filters,
+            kernel_length=self.scales_samples[0] // 4,
+            alpha_momentum=self.alpha_momentum,
+            activation=self.activation,
+            drop_prob=self.drop_prob,
+        )
+        block22 = self._get_inception_branch_2(
+            in_channels=n_concat_dw_filters,
+            out_channels=self.n_filters,
+            kernel_length=self.scales_samples[1] // 4,
+            alpha_momentum=self.alpha_momentum,
+            activation=self.activation,
+            drop_prob=self.drop_prob,
+        )
+        block23 = self._get_inception_branch_2(
+            in_channels=n_concat_dw_filters,
+            out_channels=self.n_filters,
+            kernel_length=self.scales_samples[2] // 4,
+            alpha_momentum=self.alpha_momentum,
+            activation=self.activation,
+            drop_prob=self.drop_prob,
+        )
+
+        self.add_module(
+            "inception_block_2", InceptionBlock((block21, block22, block23))
+        )
+
+        self.add_module("avg_pool_2", nn.AvgPool2d((1, self.pooling_sizes[1])))
+
+        self.add_module(
+            "final_block",
+            nn.Sequential(
+                nn.Conv2d(
+                    n_concat_filters,
+                    n_concat_filters // 2,
+                    (1, 8),
+                    padding="same",
+                    bias=False,
+                ),
+                nn.BatchNorm2d(n_concat_filters // 2, momentum=self.alpha_momentum),
+                activation(),
+                nn.Dropout(self.drop_prob),
+                nn.AvgPool2d((1, self.pooling_sizes[2])),
+                nn.Conv2d(
+                    n_concat_filters // 2,
+                    n_concat_filters // 4,
+                    (1, 4),
+                    padding="same",
+                    bias=False,
+                ),
+                nn.BatchNorm2d(n_concat_filters // 4, momentum=self.alpha_momentum),
+                activation(),
+                nn.Dropout(self.drop_prob),
+                nn.AvgPool2d((1, self.pooling_sizes[3])),
+            ),
+        )
+
+        spatial_dim_last_layer = self.n_times // math.prod(self.pooling_sizes)
+        n_channels_last_layer = self.n_filters * len(self.scales_samples) // 4
+
+        self.add_module("flat", nn.Flatten())
+
+        # Incorporating classification module and subsequent ones in one final layer
+        module = nn.Sequential()
+
+        module.add_module(
+            "fc",
+            nn.Linear(spatial_dim_last_layer * n_channels_last_layer, self.n_outputs),
+        )
+
+        module.add_module("identity", nn.Identity())
+
+        self.add_module("final_layer", module)
+
+        glorot_weight_zero_bias(self)
+
+    @staticmethod
+    def _get_inception_branch_1(
+        in_channels,
+        out_channels,
+        kernel_length,
+        alpha_momentum,
+        drop_prob,
+        activation,
+        depth_multiplier,
+    ):
+        return nn.Sequential(
+            nn.Conv2d(
+                1,
+                out_channels,
+                kernel_size=(1, kernel_length),
+                padding="same",
+                bias=True,
+            ),
+            nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
+            activation(),
+            nn.Dropout(drop_prob),
+            DepthwiseConv2d(
+                out_channels,
+                kernel_size=(in_channels, 1),
+                depth_multiplier=depth_multiplier,
+                bias=False,
+                padding="valid",
+            ),
+            nn.BatchNorm2d(depth_multiplier * out_channels, momentum=alpha_momentum),
+            activation(),
+            nn.Dropout(drop_prob),
+        )
+
+    @staticmethod
+    def _get_inception_branch_2(
+        in_channels, out_channels, kernel_length, alpha_momentum, drop_prob, activation
+    ):
+        return nn.Sequential(
+            nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=(1, kernel_length),
+                padding="same",
+                bias=False,
+            ),
+            nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
+            activation(),
+            nn.Dropout(drop_prob),
+        )
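Per the docstring above, the inception kernel lengths are derived as int(scale_s * sfreq), i.e. 64, 32 and 16 samples at the default sfreq=128; note the signature default n_times=1000 differs from the 128 samples described in the docstring, so n_times is worth setting explicitly. A minimal smoke-test sketch under those defaults (not an official example):

import torch

from braindecode.models import EEGInceptionERP

# 1 second of 128 Hz data over 8 channels, binary ERP classification.
model = EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128)
x = torch.randn(8, 8, 128)  # (batch_size, n_channels, n_times)
out = model(x)              # flattened dim: (128 // prod(4, 2, 2, 2)) * (8 * 3 // 4) = 24
print(out.shape)            # torch.Size([8, 2])

Because the final Linear layer consumes n_times // math.prod(pooling_sizes) spatial samples, n_times must be divisible enough to survive the four pooling stages (a factor of 32 with the default pooling_sizes).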