braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/eeginception_mi.py

@@ -0,0 +1,379 @@
# Authors: Cedric Rommel <cedric.rommel@inria.fr>
#
# License: BSD (3-clause)

import torch
from einops.layers.torch import Rearrange
from torch import nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import Ensure4d


class EEGInceptionMI(EEGModuleMixin, nn.Module):
    r"""EEG Inception for Motor Imagery, as proposed in Zhang et al. (2021) [1]_

    :bdg-success:`Convolution`

    .. figure:: https://content.cld.iop.org/journals/1741-2552/18/4/046014/revision3/jneabed81f1_hr.jpg
       :align: center
       :alt: EEGInceptionMI Architecture

    The model is closely based on the original InceptionNet for computer
    vision. Its main goal is to extract features at several scales in
    parallel. The network consists of two blocks of 3 inception modules
    each, with a skip connection around each block.

    The model is fully described in [1]_.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors; it was reimplemented based on the
    paper [1]_ alone.

    Parameters
    ----------
    input_window_seconds : float, optional
        Size of the input, in seconds. Set to 4.5 s as in [1]_ for dataset
        BCI IV 2a.
    sfreq : float, optional
        EEG sampling frequency in Hz. Defaults to 250 Hz as in [1]_ for
        dataset BCI IV 2a.
    n_convs : int, optional
        Number of parallel convolutional layers per inception module.
        Defaults to 5 as in [1]_ for dataset BCI IV 2a.
    n_filters : int, optional
        Number of convolutional filters for all layers of this type. Set to
        48 as in [1]_ for dataset BCI IV 2a.
    kernel_unit_s : float, optional
        Size in seconds of the basic 1D convolutional kernel used in
        inception modules. Each convolutional layer in such a module has a
        kernel of increasing size, an odd multiple of this value (e.g. 0.1,
        0.3, 0.5, 0.7, 0.9 s for ``n_convs=5``). Defaults to 0.1 s.
    activation : nn.Module class, optional
        Activation function class. Defaults to ``nn.ReLU``.

    References
    ----------
    .. [1] Zhang, C., Kim, Y. K., & Eskandarian, A. (2021).
        EEG-inception: an accurate and robust end-to-end neural network
        for EEG-based motor imagery classification.
        Journal of Neural Engineering, 18(4), 046014.
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        input_window_seconds=None,
        sfreq=250,
        n_convs: int = 5,
        n_filters: int = 48,
        kernel_unit_s: float = 0.1,
        activation: type[nn.Module] = nn.ReLU,
        chs_info=None,
        n_times=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.n_convs = n_convs
        self.n_filters = n_filters
        self.kernel_unit_s = kernel_unit_s
        self.activation = activation

        self.ensuredims = Ensure4d()
        self.dimshuffle = Rearrange("batch C T 1 -> batch C 1 T")

        self.mapping = {
            "fc.weight": "final_layer.fc.weight",
            "fc.bias": "final_layer.fc.bias",
        }

        # ======== Inception branches ========================

        self.initial_inception_module = _InceptionModuleMI(
            in_channels=self.n_chans,
            n_filters=self.n_filters,
            n_convs=self.n_convs,
            kernel_unit_s=self.kernel_unit_s,
            sfreq=self.sfreq,
            activation=self.activation,
        )

        intermediate_in_channels = (self.n_convs + 1) * self.n_filters

        self.intermediate_inception_modules_1 = nn.ModuleList(
            [
                _InceptionModuleMI(
                    in_channels=intermediate_in_channels,
                    n_filters=self.n_filters,
                    n_convs=self.n_convs,
                    kernel_unit_s=self.kernel_unit_s,
                    sfreq=self.sfreq,
                    activation=self.activation,
                )
                for _ in range(2)
            ]
        )

        self.residual_block_1 = _ResidualModuleMI(
            in_channels=self.n_chans,
            n_filters=intermediate_in_channels,
            activation=self.activation,
        )

        self.intermediate_inception_modules_2 = nn.ModuleList(
            [
                _InceptionModuleMI(
                    in_channels=intermediate_in_channels,
                    n_filters=self.n_filters,
                    n_convs=self.n_convs,
                    kernel_unit_s=self.kernel_unit_s,
                    sfreq=self.sfreq,
                    activation=self.activation,
                )
                for _ in range(3)
            ]
        )

        self.residual_block_2 = _ResidualModuleMI(
            in_channels=intermediate_in_channels,
            n_filters=intermediate_in_channels,
            activation=self.activation,
        )

        # XXX The paper mentions a final average pooling but does not indicate
        # the kernel size... The only info available is Figure 1, showing a
        # final AveragePooling layer, and Table 3, indicating that the spatial
        # and channel dimensions are unchanged by this layer... This could
        # indicate a stride=1 as for the MaxPooling layers. However, when we
        # look at the number of parameters of the linear layer following the
        # average pooling, we see a small number of parameters, potentially
        # indicating that the whole time dimension is averaged at this stage
        # for each channel. We follow this last hypothesis here to comply with
        # the number of parameters reported in the paper.
        self.ave_pooling = nn.AvgPool2d(
            kernel_size=(1, self.n_times),
        )

        self.flat = nn.Flatten()

        module = nn.Sequential()
        module.add_module(
            "fc",
            nn.Linear(
                in_features=intermediate_in_channels,
                out_features=self.n_outputs,
                bias=True,
            ),
        )
        module.add_module("out_fun", nn.Identity())
        self.final_layer = module

    def forward(
        self,
        X: torch.Tensor,
    ) -> torch.Tensor:
        X = self.ensuredims(X)
        X = self.dimshuffle(X)

        res1 = self.residual_block_1(X)

        out = self.initial_inception_module(X)
        for layer in self.intermediate_inception_modules_1:
            out = layer(out)

        out = out + res1

        res2 = self.residual_block_2(out)

        for layer in self.intermediate_inception_modules_2:
            out = layer(out)

        out = res2 + out

        out = self.ave_pooling(out)
        out = self.flat(out)
        out = self.final_layer(out)
        return out


class _InceptionModuleMI(nn.Module):
    r"""
    Inception module.

    This module implements an inception-like architecture that processes input
    feature maps through multiple convolutional paths with different kernel
    sizes, allowing the network to capture features at multiple scales.
    It includes bottleneck layers, convolutional layers, and pooling layers,
    followed by batch normalization and an activation function.

    Parameters
    ----------
    in_channels : int
        Number of input channels in the input tensor.
    n_filters : int
        Number of filters (output channels) for each convolutional layer.
    n_convs : int
        Number of convolutional layers in the module (excluding the bottleneck
        and pooling paths).
    kernel_unit_s : float, optional
        Base size (in seconds) for the convolutional kernels. The actual kernel
        size is computed as ``(2 * n_units + 1) * kernel_unit``, where
        ``n_units`` ranges from 0 to ``n_convs - 1``. Default is 0.1 seconds.
    sfreq : float, optional
        Sampling frequency of the input data, used to convert kernel sizes from
        seconds to samples. Default is 250 Hz.
    activation : nn.Module class, optional
        Activation function class to apply after batch normalization. Should be
        a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
        Default is ``nn.ReLU``.
    """

    def __init__(
        self,
        in_channels,
        n_filters,
        n_convs,
        kernel_unit_s=0.1,
        sfreq=250,
        activation: type[nn.Module] = nn.ReLU,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.n_filters = n_filters
        self.n_convs = n_convs
        self.kernel_unit_s = kernel_unit_s
        self.sfreq = sfreq

        self.bottleneck = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.n_filters,
            kernel_size=1,
            bias=True,
        )

        kernel_unit = int(self.kernel_unit_s * self.sfreq)

        # XXX Maxpooling is usually used to reduce spatial resolution, with a
        # stride equal to the kernel size... But it seems the authors use
        # stride=1 in their paper according to the output shapes from Table 3,
        # although this is not clearly specified in the paper text.
        self.pooling = nn.MaxPool2d(
            kernel_size=(1, kernel_unit),
            stride=1,
            padding=(0, kernel_unit // 2),
        )

        self.pooling_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.n_filters,
            kernel_size=1,
            bias=True,
        )

        self.conv_list = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=self.n_filters,
                    out_channels=self.n_filters,
                    kernel_size=(1, (n_units * 2 + 1) * kernel_unit),
                    padding="same",
                    bias=True,
                )
                for n_units in range(self.n_convs)
            ]
        )

        self.bn = nn.BatchNorm2d(self.n_filters * (self.n_convs + 1))

        self.activation = activation()

    def forward(
        self,
        X: torch.Tensor,
    ) -> torch.Tensor:
        X1 = self.bottleneck(X)

        X1 = [conv(X1) for conv in self.conv_list]

        X2 = self.pooling(X)
        X2 = self.pooling_conv(X2)
        # Get the target length from one of the conv branches
        target_len = X1[0].shape[-1]

        # Crop the pooling output if its length does not match
        if X2.shape[-1] != target_len:
            X2 = X2[..., :target_len]

        out = torch.cat(X1 + [X2], 1)

        out = self.bn(out)
        return self.activation(out)


class _ResidualModuleMI(nn.Module):
    r"""
    Residual module.

    This module performs a 1x1 convolution followed by batch normalization
    and an activation function. It processes input feature maps into
    transformed output feature maps and is typically used in residual
    connections within neural network architectures.

    Parameters
    ----------
    in_channels : int
        Number of input channels in the input tensor.
    n_filters : int
        Number of filters (output channels) for the convolutional layer.
    activation : nn.Module class, optional
        Activation function class to apply after batch normalization. Should
        be a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
        Default is ``nn.ReLU``.
    """

    def __init__(self, in_channels, n_filters, activation: type[nn.Module] = nn.ReLU):
        super().__init__()
        self.in_channels = in_channels
        self.n_filters = n_filters
        self.activation = activation()

        self.bn = nn.BatchNorm2d(self.n_filters)
        self.conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.n_filters,
            kernel_size=1,
            bias=True,
        )

    def forward(
        self,
        X: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the residual module.

        Parameters
        ----------
        X : torch.Tensor
            Input tensor of shape ``(batch_size, n_channels, 1, n_times)``.

        Returns
        -------
        torch.Tensor
            Output tensor after convolution, batch normalization, and
            activation function.
        """
        out = self.conv(X)
        out = self.bn(out)
        return self.activation(out)
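
As a quick orientation (not part of the package diff), here is a minimal usage sketch for the model added above. It assumes the class is re-exported as ``braindecode.models.EEGInceptionMI``, consistent with ``braindecode/models/__init__.py`` in the listing, and uses the docstring's BCI IV 2a defaults (4.5 s windows at 250 Hz; that dataset has 22 EEG channels and 4 classes). With ``kernel_unit_s=0.1`` and ``sfreq=250``, the inception kernels cover 25, 75, 125, 175 and 225 samples.

    import torch
    from braindecode.models import EEGInceptionMI  # assumed re-export

    # BCI IV 2a-style configuration: 22 channels, 4 classes, 4.5 s at 250 Hz
    model = EEGInceptionMI(
        n_chans=22,
        n_outputs=4,
        input_window_seconds=4.5,
        sfreq=250,
    )

    x = torch.randn(8, 22, 1125)  # (batch, n_chans, n_times); 1125 = 4.5 * 250
    with torch.no_grad():
        logits = model(x)  # expected shape: torch.Size([8, 4])
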
braindecode/models/eegitnet.py

@@ -0,0 +1,302 @@
# Authors: Ghaith Bouallegue <ghaithbouallegue@gmail.com>
#
# License: BSD-3
from einops.layers.torch import Rearrange
from torch import nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock


class EEGITNet(EEGModuleMixin, nn.Sequential):
    r"""EEG-ITNet from Salami et al. (2022) [Salami2022]_

    :bdg-success:`Convolution` :bdg-secondary:`Recurrent`

    .. figure:: https://braindecode.org/dev/_static/model/eegitnet.jpg
       :align: center
       :alt: EEG-ITNet Architecture

    EEG-ITNet: an explainable Inception Temporal Convolutional Network for
    motor imagery classification, from Salami et al. (2022).

    See [Salami2022]_ for details.

    Code adapted from https://github.com/abbassalami/eeg-itnet

    Parameters
    ----------
    n_filters_time : int, optional
        Number of temporal filters in the first inception branch. The second
        and third branches use ``2 *`` and ``4 *`` this number of filters,
        respectively. Default is 2.
    kernel_length : int, optional
        Kernel length for inception branches. Determines the temporal
        receptive field. Default is 16.
    pool_kernel : int, optional
        Pooling kernel size for the average pooling layer. Default is 4.
    tcn_in_channel : int, optional
        Number of input channels for Temporal Convolutional (TC) blocks.
        Default is 14.
    tcn_kernel_size : int, optional
        Kernel size for the TC blocks. Determines the temporal receptive
        field. Default is 4.
    tcn_padding : int, optional
        Padding size for the TC blocks to maintain the input dimensions.
        Default is 3.
    tcn_dilatation : int, optional
        Dilation rate for the first TC block. Subsequent blocks have dilation
        rates multiplied by powers of 2. Default is 1.
    drop_prob : float, optional
        Dropout probability applied after certain layers to prevent
        overfitting. Default is 0.4.
    activation : nn.Module class, optional
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors; it was reimplemented from the paper and
    the authors' implementation.

    References
    ----------
    .. [Salami2022] A. Salami, J. Andreu-Perez and H. Gillmeister, "EEG-ITNet:
        An Explainable Inception Temporal Convolutional Network for motor
        imagery classification," in IEEE Access,
        doi: 10.1109/ACCESS.2022.3161489.
    """

    def __init__(
        self,
        # Braindecode parameters
        n_outputs=None,
        n_chans=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        n_filters_time: int = 2,
        kernel_length: int = 16,
        pool_kernel: int = 4,
        tcn_in_channel: int = 14,
        tcn_kernel_size: int = 4,
        tcn_padding: int = 3,
        drop_prob: float = 0.4,
        tcn_dilatation: int = 1,
        activation: type[nn.Module] = nn.ELU,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        self.mapping = {
            "classification.1.weight": "final_layer.clf.weight",
            "classification.1.bias": "final_layer.clf.bias",
        }

        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # ======== Handling EEG input ========================
        self.add_module(
            "input_preprocess",
            nn.Sequential(Ensure4d(), Rearrange("ba ch t 1 -> ba 1 ch t")),
        )
        # ======== Inception branches ========================
        block11 = self._get_inception_branch(
            in_channels=self.n_chans,
            out_channels=n_filters_time,
            kernel_length=kernel_length,
            activation=activation,
        )
        block12 = self._get_inception_branch(
            in_channels=self.n_chans,
            out_channels=n_filters_time * 2,
            kernel_length=kernel_length * 2,
            activation=activation,
        )
        block13 = self._get_inception_branch(
            in_channels=self.n_chans,
            out_channels=n_filters_time * 4,
            kernel_length=kernel_length * 4,
            activation=activation,
        )
        self.add_module("inception_block", InceptionBlock((block11, block12, block13)))
        self.add_module(
            "pooling",
            nn.Sequential(
                nn.AvgPool2d(kernel_size=(1, pool_kernel)), nn.Dropout(drop_prob)
            ),
        )
        # =========== TC blocks =====================
        self.add_module(
            "TC_block1",
            _TCBlock(
                in_ch=tcn_in_channel,
                kernel_length=tcn_kernel_size,
                dilatation=tcn_dilatation,
                padding=tcn_padding,
                drop_prob=drop_prob,
                activation=activation,
            ),
        )
        # ================================
        self.add_module(
            "TC_block2",
            _TCBlock(
                in_ch=tcn_in_channel,
                kernel_length=tcn_kernel_size,
                dilatation=tcn_dilatation * 2,
                padding=tcn_padding * 2,
                drop_prob=drop_prob,
                activation=activation,
            ),
        )
        # ================================
        self.add_module(
            "TC_block3",
            _TCBlock(
                in_ch=tcn_in_channel,
                kernel_length=tcn_kernel_size,
                dilatation=tcn_dilatation * 4,
                padding=tcn_padding * 4,
                drop_prob=drop_prob,
                activation=activation,
            ),
        )
        # ================================
        self.add_module(
            "TC_block4",
            _TCBlock(
                in_ch=tcn_in_channel,
                kernel_length=tcn_kernel_size,
                dilatation=tcn_dilatation * 8,
                padding=tcn_padding * 8,
                drop_prob=drop_prob,
                activation=activation,
            ),
        )

        # ============= Dimensionality reduction ===================
        self.add_module(
            "dim_reduction",
            nn.Sequential(
                nn.Conv2d(tcn_in_channel, tcn_in_channel * 2, kernel_size=(1, 1)),
                nn.BatchNorm2d(tcn_in_channel * 2),
                activation(),
                nn.AvgPool2d((1, tcn_kernel_size)),
                nn.Dropout(drop_prob),
            ),
        )
        # ============== Classifier ==================
        # Moved flatten to another layer
        self.add_module("flatten", nn.Flatten())

        num_features = self.get_output_shape()[-1]

        self.add_module("final_layer", nn.Linear(num_features, self.n_outputs))

    @staticmethod
    def _get_inception_branch(
        in_channels,
        out_channels,
        kernel_length,
        depth_multiplier=1,
        activation: type[nn.Module] = nn.ELU,
    ):
        return nn.Sequential(
            nn.Conv2d(
                1,
                out_channels,
                kernel_size=(1, kernel_length),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            DepthwiseConv2d(
                out_channels,
                kernel_size=(in_channels, 1),
                depth_multiplier=depth_multiplier,
                bias=False,
                padding="valid",
            ),
            nn.BatchNorm2d(out_channels),
            activation(),
        )


class _TCBlock(nn.Module):
    r"""
    Temporal Convolutional (TC) block.

    This module applies two depthwise separable convolutions with dilation and
    a residual connection, as commonly used in temporal convolutional networks
    to capture long-range dependencies in time-series data.

    Parameters
    ----------
    in_ch : int
        Number of input channels.
    kernel_length : int
        Length of the convolutional kernels.
    dilatation : int
        Dilation rate for the convolutions.
    padding : int
        Amount of causal (left) padding to add to the input.
    drop_prob : float, optional
        Dropout probability. Default is 0.4.
    activation : nn.Module class, optional
        Activation function class to use. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    """

    def __init__(
        self,
        in_ch,
        kernel_length,
        dilatation,
        padding,
        drop_prob=0.4,
        activation: type[nn.Module] = nn.ELU,
    ):
        super().__init__()
        self.pad = padding
        self.tc1 = nn.Sequential(
            DepthwiseConv2d(
                in_ch,
                kernel_size=(1, kernel_length),
                depth_multiplier=1,
                dilation=(1, dilatation),
                bias=False,
                padding="valid",
            ),
            nn.BatchNorm2d(in_ch),
            activation(),
            nn.Dropout(drop_prob),
        )

        self.tc2 = nn.Sequential(
            DepthwiseConv2d(
                in_ch,
                kernel_size=(1, kernel_length),
                depth_multiplier=1,
                dilation=(1, dilatation),
                bias=False,
                padding="valid",
            ),
            nn.BatchNorm2d(in_ch),
            activation(),
            nn.Dropout(drop_prob),
        )

    def forward(self, x):
        residual = x
        # Left-only (causal) padding along the time dimension
        paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
        x = nn.functional.pad(x, paddings)
        x = self.tc1(x)
        x = nn.functional.pad(x, paddings)
        x = self.tc2(x) + residual
        return x
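
Likewise, a minimal sketch for EEG-ITNet under the same assumptions (``braindecode.models.EEGITNet`` re-export, BCI IV 2a-like input). One consistency constraint worth noting from the code above: the default ``tcn_in_channel=14`` equals the concatenated inception output, ``n_filters_time * (1 + 2 + 4) = 14``, so changing ``n_filters_time`` requires adjusting ``tcn_in_channel`` to match.

    import torch
    from braindecode.models import EEGITNet  # assumed re-export

    # 22 channels, 4 classes, 4.5 s windows at 250 Hz -> n_times = 1125
    model = EEGITNet(n_chans=22, n_outputs=4, n_times=1125)

    x = torch.randn(8, 22, 1125)  # (batch, n_chans, n_times)
    with torch.no_grad():
        logits = model(x)  # expected shape: torch.Size([8, 4])
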