braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
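The moves in this list imply import-path changes for downstream code: braindecode/models/modules.py and braindecode/models/functions.py are deleted, replaced by the new braindecode/modules/ and braindecode/functional/ packages. A minimal sketch of the resulting migration, assuming the removed module exported the same names:

    # 0.8.1 (assumed old path, now deleted):
    #   from braindecode.models.modules import Ensure4d
    # 1.0.0 (this import appears verbatim in the diff below):
    from braindecode.modules import Ensure4d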
--- braindecode/models/eeginception_mi.py (0.8.1)
+++ braindecode/models/eeginception_mi.py (1.0.0)
@@ -3,15 +3,20 @@
 # License: BSD (3-clause)
 
 import torch
-from torch import nn
 from einops.layers.torch import Rearrange
+from torch import nn
 
-from .
-from .
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import Ensure4d
 
 
 class EEGInceptionMI(EEGModuleMixin, nn.Module):
-    """EEG Inception for Motor Imagery, as proposed in [1]_
+    """EEG Inception for Motor Imagery, as proposed in Zhang et al. (2021) [1]_
+
+    .. figure:: https://content.cld.iop.org/journals/1741-2552/18/4/046014/revision3/jneabed81f1_hr.jpg
+       :align: center
+       :alt: EEGInceptionMI Architecture
+
 
     The model is strongly based on the original InceptionNet for computer
     vision. The main goal is to extract features in parallel with different
@@ -43,47 +48,31 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
         Size in seconds of the basic 1D convolutional kernel used in inception
         modules. Each convolutional layer in such modules has kernels of
         increasing size, odd multiples of this value (e.g. 0.1, 0.3, 0.5, 0.7,
-        0.9 here for
+        0.9 here for ``n_convs=5``). Defaults to 0.1 s.
     activation: nn.Module
         Activation function. Defaults to ReLU activation.
-    in_channels : int
-        Alias for `n_chans`.
-    n_classes : int
-        Alias for `n_outputs`.
-    input_window_s : float, optional
-        Alias for `input_window_seconds`.
 
     References
     ----------
     .. [1] Zhang, C., Kim, Y. K., & Eskandarian, A. (2021).
-
-
-
+       EEG-inception: an accurate and robust end-to-end neural network
+       for EEG-based motor imagery classification.
+       Journal of Neural Engineering, 18(4), 046014.
     """
 
     def __init__(
-
-
-
-
-
-
-
-
-
-
-
-        in_channels=None,
-        n_classes=None,
-        input_window_s=None,
-        add_log_softmax=True,
+        self,
+        n_chans=None,
+        n_outputs=None,
+        input_window_seconds=None,
+        sfreq=250,
+        n_convs=5,
+        n_filters=48,
+        kernel_unit_s=0.1,
+        activation: nn.Module = nn.ReLU,
+        chs_info=None,
+        n_times=None,
     ):
-        n_chans, n_outputs, input_window_seconds, = deprecated_args(
-            self,
-            ('in_channels', 'n_chans', in_channels, n_chans),
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('input_window_s', 'input_window_seconds', input_window_s, input_window_seconds),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -91,10 +80,8 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del in_channels, n_classes, input_window_s
 
         self.n_convs = n_convs
         self.n_filters = n_filters
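The two hunks above remove the 0.8.1 deprecation shims: the in_channels, n_classes, and input_window_s aliases, the add_log_softmax flag, and the deprecated_args helper are gone. A sketch of a 1.0.0-style call; the parameter values are illustrative, not taken from the diff:

    from braindecode.models import EEGInceptionMI

    model = EEGInceptionMI(
        n_chans=22,                # 0.8.1 alias: in_channels
        n_outputs=4,               # 0.8.1 alias: n_classes
        input_window_seconds=4.5,  # 0.8.1 alias: input_window_s
        sfreq=250,
    )
    # With add_log_softmax removed, the head ends in nn.Identity() (see the
    # final-layer hunk further down), so outputs are raw logits; pair them
    # with torch.nn.CrossEntropyLoss rather than NLLLoss.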
@@ -105,8 +92,9 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
         self.dimshuffle = Rearrange("batch C T 1 -> batch C 1 T")
 
         self.mapping = {
-
-
+            "fc.weight": "final_layer.fc.weight",
+            "tc.bias": "final_layer.fc.bias",
+        }
 
         # ======== Inception branches ========================
 
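The self.mapping hunk above re-keys the final-layer parameters (the "tc.bias" key is reproduced as it appears in the release). A hedged sketch of how such a mapping can re-key a pre-1.0 checkpoint; old_state and model are hypothetical:

    mapping = {
        "fc.weight": "final_layer.fc.weight",
        "tc.bias": "final_layer.fc.bias",
    }
    # Rename matching keys, pass everything else through unchanged.
    new_state = {mapping.get(k, k): v for k, v in old_state.items()}
    model.load_state_dict(new_state)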
@@ -121,16 +109,19 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
 
         intermediate_in_channels = (self.n_convs + 1) * self.n_filters
 
-        self.intermediate_inception_modules_1 = nn.ModuleList(
-
-
-
-
-
-
-
-
-
+        self.intermediate_inception_modules_1 = nn.ModuleList(
+            [
+                _InceptionModuleMI(
+                    in_channels=intermediate_in_channels,
+                    n_filters=self.n_filters,
+                    n_convs=self.n_convs,
+                    kernel_unit_s=self.kernel_unit_s,
+                    sfreq=self.sfreq,
+                    activation=self.activation,
+                )
+                for _ in range(2)
+            ]
+        )
 
         self.residual_block_1 = _ResidualModuleMI(
             in_channels=self.n_chans,
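Channel bookkeeping for the hunk above: each _InceptionModuleMI concatenates n_convs convolutional paths plus one pooling path, each contributing n_filters channels (the BatchNorm2d(self.n_filters * (self.n_convs + 1)) lower in the diff matches this), so with the defaults:

    n_convs, n_filters = 5, 48
    intermediate_in_channels = (n_convs + 1) * n_filters  # 6 * 48 = 288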
@@ -138,16 +129,19 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
             activation=self.activation,
         )
 
-        self.intermediate_inception_modules_2 = nn.ModuleList(
-
-
-
-
-
-
-
-
-
+        self.intermediate_inception_modules_2 = nn.ModuleList(
+            [
+                _InceptionModuleMI(
+                    in_channels=intermediate_in_channels,
+                    n_filters=self.n_filters,
+                    n_convs=self.n_convs,
+                    kernel_unit_s=self.kernel_unit_s,
+                    sfreq=self.sfreq,
+                    activation=self.activation,
+                )
+                for _ in range(3)
+            ]
+        )
 
         self.residual_block_2 = _ResidualModuleMI(
             in_channels=intermediate_in_channels,
@@ -172,19 +166,20 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
         self.flat = nn.Flatten()
 
         module = nn.Sequential()
-        module.add_module(
-
-
-
-
-
-
-
+        module.add_module(
+            "fc",
+            nn.Linear(
+                in_features=intermediate_in_channels,
+                out_features=self.n_outputs,
+                bias=True,
+            ),
+        )
+        module.add_module("out_fun", nn.Identity())
         self.final_layer = module
 
     def forward(
-
-
+        self,
+        X: torch.Tensor,
     ) -> torch.Tensor:
         X = self.ensuredims(X)
         X = self.dimshuffle(X)
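The forward hunk above shows only the input plumbing: Ensure4d pads a (batch, channels, times) input with a trailing singleton dimension, and the Rearrange moves time to the last axis. A small shape check with made-up sizes (batch 8, 22 channels, 1000 samples), assuming Ensure4d unsqueezes trailing dims as its name suggests:

    import torch
    from einops.layers.torch import Rearrange

    X = torch.randn(8, 22, 1000)                    # (batch, C, T)
    X = X.unsqueeze(-1)                             # Ensure4d -> (batch, C, T, 1)
    X = Rearrange("batch C T 1 -> batch C 1 T")(X)  # (batch, C, 1, T)
    print(X.shape)                                  # torch.Size([8, 22, 1, 1000])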
@@ -211,14 +206,46 @@ class EEGInceptionMI(EEGModuleMixin, nn.Module):
 
 
 class _InceptionModuleMI(nn.Module):
+    """
+    Inception module.
+
+    This module implements an inception-like architecture that processes input
+    feature maps through multiple convolutional paths with different kernel
+    sizes, allowing the network to capture features at multiple scales.
+    It includes bottleneck layers, convolutional layers, and pooling layers,
+    followed by batch normalization and an activation function.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of input channels in the input tensor.
+    n_filters : int
+        Number of filters (output channels) for each convolutional layer.
+    n_convs : int
+        Number of convolutional layers in the module (excluding bottleneck
+        and pooling paths).
+    kernel_unit_s : float, optional
+        Base size (in seconds) for the convolutional kernels. The actual kernel
+        size is computed as ``(2 * n_units + 1) * kernel_unit``, where ``n_units``
+        ranges from 0 to ``n_convs - 1``. Default is 0.1 seconds.
+    sfreq : float, optional
+        Sampling frequency of the input data, used to convert kernel sizes from
+        seconds to samples. Default is 250 Hz.
+    activation : nn.Module class, optional
+        Activation function class to apply after batch normalization. Should be
+        a PyTorch activation module class like ``nn.ReLU`` or ``nn.ELU``.
+        Default is ``nn.ReLU``.
+
+    """
+
     def __init__(
-
-
-
-
-
-
-
+        self,
+        in_channels,
+        n_filters,
+        n_convs,
+        kernel_unit_s=0.1,
+        sfreq=250,
+        activation: nn.Module = nn.ReLU,
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -253,23 +280,26 @@ class _InceptionModuleMI(nn.Module):
             bias=True,
         )
 
-        self.conv_list = nn.ModuleList(
-
-
-
-
-
-
-
+        self.conv_list = nn.ModuleList(
+            [
+                nn.Conv2d(
+                    in_channels=self.n_filters,
+                    out_channels=self.n_filters,
+                    kernel_size=(1, (n_units * 2 + 1) * kernel_unit),
+                    padding="same",
+                    bias=True,
+                )
+                for n_units in range(self.n_convs)
+            ]
+        )
 
         self.bn = nn.BatchNorm2d(self.n_filters * (self.n_convs + 1))
 
-        self.activation = activation
+        self.activation = activation()
 
     def forward(
-
-
+        self,
+        X: torch.Tensor,
     ) -> torch.Tensor:
         X1 = self.bottleneck(X)
 
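A worked check of the kernel sizes built by conv_list above, assuming kernel_unit is derived as int(kernel_unit_s * sfreq) (the derivation sits outside the shown hunks) and using the defaults kernel_unit_s=0.1, sfreq=250, n_convs=5:

    kernel_unit_s, sfreq, n_convs = 0.1, 250, 5
    kernel_unit = int(kernel_unit_s * sfreq)  # 25 samples per unit
    sizes = [(2 * n + 1) * kernel_unit for n in range(n_convs)]
    print(sizes)  # [25, 75, 125, 175, 225], i.e. 0.1 s to 0.9 s at 250 Hz

This matches the docstring's odd multiples (0.1, 0.3, 0.5, 0.7, 0.9 s for n_convs=5).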
@@ -285,16 +315,31 @@ class _InceptionModuleMI(nn.Module):
 
 
 class _ResidualModuleMI(nn.Module):
-
-
-
-
-
-
+    """
+    Residual module.
+
+    This module performs a 1x1 convolution followed by batch normalization and an activation function.
+    It is designed to process input feature maps and produce transformed output feature maps, often used
+    in residual connections within neural network architectures.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of input channels in the input tensor.
+    n_filters : int
+        Number of filters (output channels) for the convolutional layer.
+    activation : nn.Module class, optional
+        Activation function class to apply after batch normalization. Should be a PyTorch
+        activation module class (e.g., ``nn.ReLU``, ``nn.ELU``). Default is ``nn.ReLU``.
+
+
+    """
+
+    def __init__(self, in_channels, n_filters, activation: nn.Module = nn.ReLU):
         super().__init__()
         self.in_channels = in_channels
         self.n_filters = n_filters
-        self.activation = activation
+        self.activation = activation()
 
         self.bn = nn.BatchNorm2d(self.n_filters)
         self.conv = nn.Conv2d(
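A minimal smoke test of _ResidualModuleMI as defined above; the sizes are illustrative, and the output shape assumes the 1x1 convolution described in the docstring (its kernel size is set in lines outside the shown hunks):

    import torch
    from torch import nn

    block = _ResidualModuleMI(in_channels=22, n_filters=48, activation=nn.ELU)
    out = block(torch.randn(8, 22, 1, 1000))  # (batch, C, 1, T)
    print(out.shape)                          # torch.Size([8, 48, 1, 1000])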
@@ -305,9 +350,22 @@ class _ResidualModuleMI(nn.Module):
         )
 
     def forward(
-
-
+        self,
+        X: torch.Tensor,
     ) -> torch.Tensor:
+        """
+        Forward pass of the residual module.
+
+        Parameters
+        ----------
+        X : torch.Tensor
+            Input tensor of shape (batch_size, ch_names, n_times).
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor after convolution, batch normalization, and activation function.
+        """
         out = self.conv(X)
         out = self.bn(out)
         return self.activation(out)