braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import math
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
from einops.layers.torch import Rearrange
|
|
11
|
+
|
|
12
|
+
from braindecode.models.base import EEGModuleMixin
|
|
13
|
+
from braindecode.modules import Conv2dWithConstraint, LinearWithConstraint
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class EEGNeX(EEGModuleMixin, nn.Module):
    """EEGNeX model from Chen et al. (2024) [eegnex]_.

    .. figure:: https://braindecode.org/dev/_static/model/eegnex.jpg
       :align: center
       :alt: EEGNeX Architecture

    Parameters
    ----------
    activation : type[nn.Module], optional
        Activation function class to use. Default is ``nn.ELU``.
    depth_multiplier : int, optional
        Depth multiplier for the depthwise convolution. Default is 2.
    filter_1 : int, optional
        Number of filters in the first convolutional layer. Default is 8.
    filter_2 : int, optional
        Number of filters in the second convolutional layer. Default is 32.
    drop_prob : float, optional
        Dropout rate. Default is 0.5.
    kernel_block_1_2 : int, optional
        Temporal kernel length for blocks 1 and 2. Default is 64.
    kernel_block_4 : int, optional
        Temporal kernel length for block 4. Default is 16.
    dilation_block_4 : int, optional
        Temporal dilation rate for block 4. Default is 2.
    avg_pool_block4 : int, optional
        Temporal pooling size for block 3's average pooling. Default is 4.
    kernel_block_5 : int, optional
        Temporal kernel length for block 5. Default is 16.
    dilation_block_5 : int, optional
        Temporal dilation rate for block 5. Default is 4.
    avg_pool_block5 : int, optional
        Temporal pooling size for block 5's average pooling. Default is 8.
    max_norm_conv : float, optional
        Max-norm constraint for the depthwise convolution weights.
        Default is 1.0.
    max_norm_linear : float, optional
        Max-norm constraint for the final linear layer weights.
        Default is 0.25.

    Notes
    -----
    This implementation is not guaranteed to be correct, has not been checked
    by original authors, only reimplemented from the paper description and
    source code in tensorflow [EEGNexCode]_.

    References
    ----------
    .. [eegnex] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
       Toward reliable signals decoding for electroencephalogram: A benchmark
       study to EEGNeX. Biomedical Signal Processing and Control, 87, 105475.
    .. [EEGNexCode] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
       Toward reliable signals decoding for electroencephalogram: A benchmark
       study to EEGNeX. https://github.com/chenxiachan/EEGNeX
    """

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        activation: type[nn.Module] = nn.ELU,
        depth_multiplier: int = 2,
        filter_1: int = 8,
        filter_2: int = 32,
        drop_prob: float = 0.5,
        kernel_block_1_2: int = 64,
        kernel_block_4: int = 16,
        dilation_block_4: int = 2,
        avg_pool_block4: int = 4,
        kernel_block_5: int = 16,
        dilation_block_5: int = 4,
        avg_pool_block5: int = 8,
        max_norm_conv: float = 1.0,
        max_norm_linear: float = 0.25,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.depth_multiplier = depth_multiplier
        self.filter_1 = filter_1
        self.filter_2 = filter_2
        # Depthwise conv in block 3 expands filter_2 by depth_multiplier.
        self.filter_3 = self.filter_2 * self.depth_multiplier
        self.drop_prob = drop_prob
        self.activation = activation
        # All kernels/dilations/pools act only along the time axis, hence (1, k).
        self.kernel_block_1_2 = (1, kernel_block_1_2)
        self.kernel_block_4 = (1, kernel_block_4)
        self.dilation_block_4 = (1, dilation_block_4)
        self.avg_pool_block4 = (1, avg_pool_block4)
        self.kernel_block_5 = (1, kernel_block_5)
        self.dilation_block_5 = (1, dilation_block_5)
        self.avg_pool_block5 = (1, avg_pool_block5)

        # Flattened feature count entering the final linear layer.
        self.in_features = self._calculate_output_length()

        # Following paper nomenclature
        self.block_1 = nn.Sequential(
            Rearrange("batch ch time -> batch 1 ch time"),
            nn.Conv2d(
                in_channels=1,
                out_channels=self.filter_1,
                kernel_size=self.kernel_block_1_2,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_1),
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_1,
                out_channels=self.filter_2,
                kernel_size=self.kernel_block_1_2,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_2),
        )

        self.block_3 = nn.Sequential(
            # Depthwise spatial conv over all channels collapses the
            # channel axis to 1; weights are max-norm constrained.
            Conv2dWithConstraint(
                in_channels=self.filter_2,
                out_channels=self.filter_3,
                max_norm=max_norm_conv,
                kernel_size=(self.n_chans, 1),
                groups=self.filter_2,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_3),
            self.activation(),
            nn.AvgPool2d(
                kernel_size=self.avg_pool_block4,
                padding=(0, 1),
            ),
            nn.Dropout(p=self.drop_prob),
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_3,
                out_channels=self.filter_2,
                kernel_size=self.kernel_block_4,
                dilation=self.dilation_block_4,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_2),
        )

        self.block_5 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_2,
                out_channels=self.filter_1,
                kernel_size=self.kernel_block_5,
                dilation=self.dilation_block_5,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_1),
            self.activation(),
            nn.AvgPool2d(
                kernel_size=self.avg_pool_block5,
                padding=(0, 1),
            ),
            nn.Dropout(p=self.drop_prob),
            nn.Flatten(),
        )

        self.final_layer = LinearWithConstraint(
            in_features=self.in_features,
            out_features=self.n_outputs,
            max_norm=max_norm_linear,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the EEGNeX model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # x shape: (batch_size, n_chans, n_times)
        x = self.block_1(x)
        # (batch_size, filter_1, n_chans, n_times)
        x = self.block_2(x)
        # (batch_size, filter_2, n_chans, n_times)
        x = self.block_3(x)
        # (batch_size, filter_3, 1, T3) — channel axis collapsed, time pooled
        x = self.block_4(x)
        # (batch_size, filter_2, 1, T3)
        x = self.block_5(x)
        # (batch_size, filter_1 * T5) — flattened
        x = self.final_layer(x)

        return x

    def _calculate_output_length(self) -> int:
        """Return the flattened feature count entering the final layer.

        Mirrors the two ``nn.AvgPool2d`` layers (blocks 3 and 5) on the
        time axis; the channel axis is already reduced to 1 by block 3's
        depthwise convolution.
        """
        # Time-axis pooling kernel sizes; AvgPool2d stride defaults to
        # the kernel size.
        p4 = self.avg_pool_block4[1]
        p5 = self.avg_pool_block5[1]

        # Both pooling layers are built with padding=(0, 1) on the time axis.
        pad4 = 1
        pad5 = 1

        # Output-length formula: floor((L_in + 2*padding - kernel) / stride) + 1.
        # Integer floor division is exact here (no float rounding).
        t3 = (self.n_times + 2 * pad4 - p4) // p4 + 1
        t5 = (t3 + 2 * pad5 - p5) // p5 + 1

        # filter_1 output channels times the remaining time length.
        return self.filter_1 * t5
|