braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,473 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import Dict, Optional
|
|
7
|
+
|
|
8
|
+
from einops.layers.torch import Rearrange
|
|
9
|
+
from mne.utils import warn
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
from braindecode.functional import glorot_weight_zero_bias
|
|
13
|
+
from braindecode.models.base import EEGModuleMixin
|
|
14
|
+
from braindecode.modules import (
|
|
15
|
+
Conv2dWithConstraint,
|
|
16
|
+
Ensure4d,
|
|
17
|
+
Expression,
|
|
18
|
+
LinearWithConstraint,
|
|
19
|
+
SqueezeFinalOutput,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class EEGNetv4(EEGModuleMixin, nn.Sequential):
    """EEGNet v4 model from Lawhern et al. (2018) [EEGNet4]_.

    .. figure:: https://content.cld.iop.org/journals/1741-2552/15/5/056013/revision2/jneaace8cf01_hr.jpg
       :align: center
       :alt: EEGNet4 Architecture

    See details in [EEGNet4]_.

    Parameters
    ----------
    final_conv_length : int or "auto", default="auto"
        Length of the final convolution layer. If "auto", it is set to the
        temporal length of the feature map that reaches the final layer
        (this requires ``n_times`` to be known).
    pool_mode : {"mean", "max"}, default="mean"
        Pooling method to use in pooling layers.
    F1 : int, default=8
        Number of temporal filters in the first convolutional layer.
    D : int, default=2
        Depth multiplier for the depthwise convolution.
    F2 : int or None, default=None
        Number of pointwise filters in the separable convolution.
        If None, it is set to ``F1 * D``.
    depthwise_kernel_length : int, default=16
        Length of the depthwise convolution kernel in the separable convolution.
    pool1_kernel_size : int, default=4
        Kernel size of the first pooling layer.
    pool2_kernel_size : int, default=8
        Kernel size of the second pooling layer.
    kernel_length : int, default=64
        Length of the temporal convolution kernel.
    conv_spatial_max_norm : float, default=1
        Maximum norm constraint for the spatial (depthwise) convolution.
    activation : nn.Module, default=nn.ELU
        Non-linear activation function class to be used in the layers.
        It is instantiated without arguments.
    batch_norm_momentum : float, default=0.01
        Momentum of the batch norm layers.
    batch_norm_affine : bool, default=True
        If True, batch norm has learnable affine parameters.
    batch_norm_eps : float, default=1e-3
        Epsilon for numeric stability in batch norm layers.
    drop_prob : float, default=0.25
        Dropout probability.
    final_layer_with_constraint : bool, default=False
        If ``False``, uses a convolution-based classification layer
        (deprecated, a ``DeprecationWarning`` is emitted). If ``True``,
        apply a flattened linear layer with constraint on the weights norm
        as the final classification step.
    norm_rate : float, default=0.25
        Max-norm constraint value for the linear layer
        (used if ``final_layer_with_constraint=True``).

    Raises
    ------
    TypeError
        If any extra keyword argument other than the deprecated
        ``third_kernel_size`` is passed.
    KeyError
        If ``pool_mode`` is neither ``"max"`` nor ``"mean"``.

    Notes
    -----
    The deprecated keyword ``third_kernel_size`` is still accepted for
    backward compatibility; it triggers a warning and is otherwise ignored.

    References
    ----------
    .. [EEGNet4] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon, S. M.,
       Hung, C. P., & Lance, B. J. (2018). EEGNet: a compact convolutional
       neural network for EEG-based brain–computer interfaces. Journal of
       neural engineering, 15(5), 056013.
    """

    def __init__(
        self,
        # signal's parameters
        n_chans: Optional[int] = None,
        n_outputs: Optional[int] = None,
        n_times: Optional[int] = None,
        # model's parameters
        final_conv_length: str | int = "auto",
        pool_mode: str = "mean",
        F1: int = 8,
        D: int = 2,
        F2: Optional[int] = None,
        kernel_length: int = 64,
        *,
        depthwise_kernel_length: int = 16,
        pool1_kernel_size: int = 4,
        pool2_kernel_size: int = 8,
        conv_spatial_max_norm: float = 1,
        activation: type[nn.Module] = nn.ELU,
        batch_norm_momentum: float = 0.01,
        batch_norm_affine: bool = True,
        batch_norm_eps: float = 1e-3,
        drop_prob: float = 0.25,
        final_layer_with_constraint: bool = False,
        norm_rate: float = 0.25,
        # Other ways to construct the signal related parameters
        chs_info: Optional[list[Dict]] = None,
        input_window_seconds=None,
        sfreq=None,
        **kwargs,
    ):
        # The signal-related arguments are handed to the parent initialiser;
        # from here on they are only read back through ``self``.
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # Drop the locals so that the code below cannot accidentally use the
        # raw arguments instead of the attributes exposed on ``self``.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        if final_conv_length == "auto":
            assert self.n_times is not None

        if not final_layer_with_constraint:
            # The replacement is the very same parameter set to ``True``.
            warn(
                "Parameter 'final_layer_with_constraint=False' is deprecated and will be "
                "removed in a future release. Please use "
                "`final_layer_with_constraint=True`.",
                DeprecationWarning,
            )

        if "third_kernel_size" in kwargs:
            warn(
                "The parameter `third_kernel_size` is deprecated "
                "and will be removed in a future version.",
            )
        # ``third_kernel_size`` is the only tolerated extra keyword.
        unexpected_kwargs = set(kwargs) - {"third_kernel_size"}
        if unexpected_kwargs:
            raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}")

        self.final_conv_length = final_conv_length
        self.pool_mode = pool_mode
        self.F1 = F1
        self.D = D

        if F2 is None:
            F2 = self.F1 * self.D
        self.F2 = F2

        self.kernel_length = kernel_length
        self.depthwise_kernel_length = depthwise_kernel_length
        self.pool1_kernel_size = pool1_kernel_size
        self.pool2_kernel_size = pool2_kernel_size
        self.drop_prob = drop_prob
        self.activation = activation
        self.batch_norm_momentum = batch_norm_momentum
        self.batch_norm_affine = batch_norm_affine
        self.batch_norm_eps = batch_norm_eps
        self.conv_spatial_max_norm = conv_spatial_max_norm
        self.norm_rate = norm_rate

        # Old state-dict key -> current key, for load_state_dict
        # (presumably consumed by the parent mixin; verify there).
        # When all layers are standardized, add the old parameter names here.
        self.mapping = {
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
        n_spatial_filters = self.F1 * self.D

        def _batch_norm(num_features):
            # All batch norm layers of the model share the same settings.
            return nn.BatchNorm2d(
                num_features,
                momentum=self.batch_norm_momentum,
                affine=self.batch_norm_affine,
                eps=self.batch_norm_eps,
            )

        self.add_module("ensuredims", Ensure4d())

        # Move the trailing singleton axis in front so that Conv2d sees a
        # single input plane of shape (ch, t).
        self.add_module("dimshuffle", Rearrange("batch ch t 1 -> batch 1 ch t"))
        self.add_module(
            "conv_temporal",
            nn.Conv2d(
                1,
                self.F1,
                (1, self.kernel_length),
                bias=False,
                padding=(0, self.kernel_length // 2),
            ),
        )
        self.add_module("bnorm_temporal", _batch_norm(self.F1))
        # Depthwise (groups=F1) convolution spanning all n_chans rows.
        self.add_module(
            "conv_spatial",
            Conv2dWithConstraint(
                in_channels=self.F1,
                out_channels=n_spatial_filters,
                kernel_size=(self.n_chans, 1),
                max_norm=self.conv_spatial_max_norm,
                bias=False,
                groups=self.F1,
            ),
        )

        self.add_module("bnorm_1", _batch_norm(n_spatial_filters))
        self.add_module("elu_1", self.activation())

        self.add_module(
            "pool_1",
            pool_class(
                kernel_size=(1, self.pool1_kernel_size),
            ),
        )
        self.add_module("drop_1", nn.Dropout(p=self.drop_prob))

        # Separable convolution = depthwise conv followed by a 1x1 pointwise conv.
        # https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7
        self.add_module(
            "conv_separable_depth",
            nn.Conv2d(
                n_spatial_filters,
                n_spatial_filters,
                (1, self.depthwise_kernel_length),
                bias=False,
                groups=n_spatial_filters,
                padding=(0, self.depthwise_kernel_length // 2),
            ),
        )
        self.add_module(
            "conv_separable_point",
            nn.Conv2d(
                n_spatial_filters,
                self.F2,
                kernel_size=(1, 1),
                bias=False,
            ),
        )

        self.add_module("bnorm_2", _batch_norm(self.F2))
        self.add_module("elu_2", self.activation())
        self.add_module(
            "pool_2",
            pool_class(
                kernel_size=(1, self.pool2_kernel_size),
            ),
        )
        self.add_module("drop_2", nn.Dropout(p=self.drop_prob))

        # Shape of the feature map produced by the layers registered so far;
        # axis 2 is the remaining (virtual) channel axis, axis 3 is time.
        output_shape = self.get_output_shape()
        n_out_virtual_chans = output_shape[2]

        if self.final_conv_length == "auto":
            n_out_time = output_shape[3]
            self.final_conv_length = n_out_time

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()
        if not final_layer_with_constraint:
            module.add_module(
                "conv_classifier",
                nn.Conv2d(
                    self.F2,
                    self.n_outputs,
                    (n_out_virtual_chans, self.final_conv_length),
                    bias=True,
                ),
            )

            # Transpose back to the logic of braindecode,
            # so time in third dimension (axis=2)
            module.add_module(
                "permute_back",
                Rearrange("batch x y z -> batch x z y"),
            )

            module.add_module("squeeze", SqueezeFinalOutput())
        else:
            module.add_module("flatten", nn.Flatten())
            module.add_module(
                "linearconstraint",
                LinearWithConstraint(
                    in_features=self.F2 * self.final_conv_length,
                    out_features=self.n_outputs,
                    max_norm=norm_rate,
                ),
            )
        self.add_module("final_layer", module)

        glorot_weight_zero_bias(self)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class EEGNetv1(EEGModuleMixin, nn.Sequential):
    """EEGNet model from Lawhern et al. 2016 from [EEGNet]_.

    See details in [EEGNet]_.

    Parameters
    ----------
    final_conv_length : int or "auto", default="auto"
        Length of the final convolution layer. If "auto", it is set to the
        temporal length of the feature map that reaches the final layer
        (this requires ``n_times`` to be known).
    pool_mode : {"max", "mean"}, default="max"
        Pooling method to use in the pooling layers.
    second_kernel_size : tuple of int, default=(2, 32)
        Kernel size of the second convolution. The first axis is padded
        with ``second_kernel_size[0] // 2``; the second axis is not padded.
    third_kernel_size : tuple of int, default=(8, 4)
        Kernel size of the third convolution. The first axis is padded
        with ``third_kernel_size[0] // 2``; the second axis is not padded.
    drop_prob : float, default=0.25
        Dropout probability, shared by all dropout layers.
    activation: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Raises
    ------
    KeyError
        If ``pool_mode`` is neither ``"max"`` nor ``"mean"``.

    Notes
    -----
    This implementation is not guaranteed to be correct, has not been checked
    by original authors, only reimplemented from the paper description.

    This class is deprecated; instantiating it emits a ``DeprecationWarning``.
    Use ``braindecode.models.EEGNetv4`` instead.

    References
    ----------
    .. [EEGNet] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
       S. M., Hung, C. P., & Lance, B. J. (2016).
       EEGNet: A Compact Convolutional Network for EEG-based
       Brain-Computer Interfaces.
       arXiv preprint arXiv:1611.08024.
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        final_conv_length="auto",
        pool_mode="max",
        second_kernel_size=(2, 32),
        third_kernel_size=(8, 4),
        drop_prob=0.25,
        activation: type[nn.Module] = nn.ELU,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        # The signal-related arguments are handed to the parent initialiser;
        # from here on they are only read back through ``self``.
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # Drop the locals so that the code below cannot accidentally use the
        # raw arguments instead of the attributes exposed on ``self``.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        # The class is still shipped in the 1.x series, so the message must
        # not announce a removal "in release 1.0".
        warn(
            "The class EEGNetv1 is deprecated and will be removed in a "
            "future release of braindecode. Please use "
            "braindecode.models.EEGNetv4 instead in the future.",
            DeprecationWarning,
        )
        if final_conv_length == "auto":
            assert self.n_times is not None
        self.final_conv_length = final_conv_length
        self.pool_mode = pool_mode
        self.second_kernel_size = second_kernel_size
        self.third_kernel_size = third_kernel_size
        self.drop_prob = drop_prob
        # Old state-dict key -> current key, for load_state_dict
        # (presumably consumed by the parent mixin; verify there).
        # When all layers are standardized, add the old parameter names here.
        self.mapping = {
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]

        def _batch_norm(num_features):
            # All batch norm layers of this model use the same fixed settings.
            return nn.BatchNorm2d(num_features, momentum=0.01, affine=True, eps=1e-3)

        self.add_module("ensuredims", Ensure4d())
        n_filters_1 = 16
        # 1x1 convolution that mixes the n_chans input channels into
        # n_filters_1 virtual channels.
        self.add_module(
            "conv_1",
            nn.Conv2d(self.n_chans, n_filters_1, (1, 1), stride=1, bias=True),
        )
        self.add_module("bnorm_1", _batch_norm(n_filters_1))
        self.add_module("elu_1", activation())
        # transpose to examples x 1 x (virtual, not EEG) channels x time
        self.add_module("permute_1", Rearrange("batch x y z -> batch z x y"))

        self.add_module("drop_1", nn.Dropout(p=self.drop_prob))

        n_filters_2 = 4
        # keras pads unequal padding more in front, so padding
        # too large should be ok.
        # Not padding in time so that cropped training makes sense
        # https://stackoverflow.com/questions/43994604/padding-with-even-kernel-size-in-a-convolutional-layer-in-keras-theano

        self.add_module(
            "conv_2",
            nn.Conv2d(
                1,
                n_filters_2,
                self.second_kernel_size,
                stride=1,
                padding=(self.second_kernel_size[0] // 2, 0),
                bias=True,
            ),
        )
        self.add_module("bnorm_2", _batch_norm(n_filters_2))
        self.add_module("elu_2", activation())
        self.add_module("pool_2", pool_class(kernel_size=(2, 4), stride=(2, 4)))
        self.add_module("drop_2", nn.Dropout(p=self.drop_prob))

        n_filters_3 = 4
        self.add_module(
            "conv_3",
            nn.Conv2d(
                n_filters_2,
                n_filters_3,
                self.third_kernel_size,
                stride=1,
                padding=(self.third_kernel_size[0] // 2, 0),
                bias=True,
            ),
        )
        self.add_module("bnorm_3", _batch_norm(n_filters_3))
        self.add_module("elu_3", activation())
        self.add_module("pool_3", pool_class(kernel_size=(2, 4), stride=(2, 4)))
        self.add_module("drop_3", nn.Dropout(p=self.drop_prob))

        # Shape of the feature map produced by the layers registered so far;
        # axis 2 is the (virtual) channel axis, axis 3 is time.
        output_shape = self.get_output_shape()
        n_out_virtual_chans = output_shape[2]

        if self.final_conv_length == "auto":
            n_out_time = output_shape[3]
            self.final_conv_length = n_out_time

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()

        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                n_filters_3,
                self.n_outputs,
                (n_out_virtual_chans, self.final_conv_length),
                bias=True,
            ),
        )

        # Transpose back to the logic of braindecode,
        # so time in third dimension (axis=2)
        module.add_module(
            "permute_2",
            Rearrange("batch x y z -> batch x z y"),
        )

        module.add_module("squeeze", SqueezeFinalOutput())

        self.add_module("final_layer", module)

        glorot_weight_zero_bias(self)
|