braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/eegnet.py
CHANGED
|
@@ -1,85 +1,111 @@
|
|
|
1
1
|
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
2
|
#
|
|
3
3
|
# License: BSD (3-clause)
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import Dict, Optional
|
|
4
7
|
|
|
5
|
-
import torch
|
|
6
8
|
from einops.layers.torch import Rearrange
|
|
9
|
+
from mne.utils import warn
|
|
7
10
|
from torch import nn
|
|
8
|
-
from torch.nn.functional import elu
|
|
9
|
-
|
|
10
|
-
from .base import EEGModuleMixin, deprecated_args
|
|
11
|
-
from .functions import squeeze_final_output
|
|
12
|
-
from .modules import Ensure4d, Expression
|
|
13
|
-
|
|
14
11
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
return super(Conv2dWithConstraint, self).forward(x)
|
|
12
|
+
from braindecode.functional import glorot_weight_zero_bias
|
|
13
|
+
from braindecode.models.base import EEGModuleMixin
|
|
14
|
+
from braindecode.modules import (
|
|
15
|
+
Conv2dWithConstraint,
|
|
16
|
+
Ensure4d,
|
|
17
|
+
Expression,
|
|
18
|
+
LinearWithConstraint,
|
|
19
|
+
SqueezeFinalOutput,
|
|
20
|
+
)
|
|
25
21
|
|
|
26
22
|
|
|
27
23
|
class EEGNetv4(EEGModuleMixin, nn.Sequential):
|
|
28
|
-
"""EEGNet v4 model from Lawhern et al 2018.
|
|
24
|
+
"""EEGNet v4 model from Lawhern et al. (2018) [EEGNet4]_.
|
|
25
|
+
|
|
26
|
+
.. figure:: https://content.cld.iop.org/journals/1741-2552/15/5/056013/revision2/jneaace8cf01_hr.jpg
|
|
27
|
+
:align: center
|
|
28
|
+
:alt: EEGNet4 Architecture
|
|
29
29
|
|
|
30
30
|
See details in [EEGNet4]_.
|
|
31
31
|
|
|
32
32
|
Parameters
|
|
33
33
|
----------
|
|
34
|
-
final_conv_length : int
|
|
35
|
-
If
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
34
|
+
final_conv_length : int or "auto", default="auto"
|
|
35
|
+
Length of the final convolution layer. If "auto", it is set based on n_times.
|
|
36
|
+
pool_mode : {"mean", "max"}, default="mean"
|
|
37
|
+
Pooling method to use in pooling layers.
|
|
38
|
+
F1 : int, default=8
|
|
39
|
+
Number of temporal filters in the first convolutional layer.
|
|
40
|
+
D : int, default=2
|
|
41
|
+
Depth multiplier for the depthwise convolution.
|
|
42
|
+
F2 : int or None, default=None
|
|
43
|
+
Number of pointwise filters in the separable convolution. Usually set to ``F1 * D``.
|
|
44
|
+
depthwise_kernel_length : int, default=16
|
|
45
|
+
Length of the depthwise convolution kernel in the separable convolution.
|
|
46
|
+
pool1_kernel_size : int, default=4
|
|
47
|
+
Kernel size of the first pooling layer.
|
|
48
|
+
pool2_kernel_size : int, default=8
|
|
49
|
+
Kernel size of the second pooling layer.
|
|
50
|
+
kernel_length : int, default=64
|
|
51
|
+
Length of the temporal convolution kernel.
|
|
52
|
+
conv_spatial_max_norm : float, default=1
|
|
53
|
+
Maximum norm constraint for the spatial (depthwise) convolution.
|
|
54
|
+
activation : nn.Module, default=nn.ELU
|
|
55
|
+
Non-linear activation function to be used in the layers.
|
|
56
|
+
batch_norm_momentum : float, default=0.01
|
|
57
|
+
Momentum for instance normalization in batch norm layers.
|
|
58
|
+
batch_norm_affine : bool, default=True
|
|
59
|
+
If True, batch norm has learnable affine parameters.
|
|
60
|
+
batch_norm_eps : float, default=1e-3
|
|
61
|
+
Epsilon for numeric stability in batch norm layers.
|
|
62
|
+
drop_prob : float, default=0.25
|
|
63
|
+
Dropout probability.
|
|
64
|
+
final_layer_with_constraint : bool, default=False
|
|
65
|
+
If ``False``, uses a convolution-based classification layer. If ``True``,
|
|
66
|
+
apply a flattened linear layer with constraint on the weights norm as the final classification step.
|
|
67
|
+
norm_rate : float, default=0.25
|
|
68
|
+
Max-norm constraint value for the linear layer (used if ``final_layer_conv=False``).
|
|
47
69
|
|
|
48
70
|
References
|
|
49
71
|
----------
|
|
50
|
-
.. [EEGNet4] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
arXiv preprint arXiv:1611.08024.
|
|
72
|
+
.. [EEGNet4] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon, S. M.,
|
|
73
|
+
Hung, C. P., & Lance, B. J. (2018). EEGNet: a compact convolutional
|
|
74
|
+
neural network for EEG-based brain–computer interfaces. Journal of
|
|
75
|
+
neural engineering, 15(5), 056013.
|
|
55
76
|
"""
|
|
56
77
|
|
|
57
78
|
def __init__(
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
79
|
+
self,
|
|
80
|
+
# signal's parameters
|
|
81
|
+
n_chans: Optional[int] = None,
|
|
82
|
+
n_outputs: Optional[int] = None,
|
|
83
|
+
n_times: Optional[int] = None,
|
|
84
|
+
# model's parameters
|
|
85
|
+
final_conv_length: str | int = "auto",
|
|
86
|
+
pool_mode: str = "mean",
|
|
87
|
+
F1: int = 8,
|
|
88
|
+
D: int = 2,
|
|
89
|
+
F2: Optional[int | None] = None,
|
|
90
|
+
kernel_length: int = 64,
|
|
91
|
+
*,
|
|
92
|
+
depthwise_kernel_length: int = 16,
|
|
93
|
+
pool1_kernel_size: int = 4,
|
|
94
|
+
pool2_kernel_size: int = 8,
|
|
95
|
+
conv_spatial_max_norm: int = 1,
|
|
96
|
+
activation: nn.Module = nn.ELU,
|
|
97
|
+
batch_norm_momentum: float = 0.01,
|
|
98
|
+
batch_norm_affine: bool = True,
|
|
99
|
+
batch_norm_eps: float = 1e-3,
|
|
100
|
+
drop_prob: float = 0.25,
|
|
101
|
+
final_layer_with_constraint: bool = False,
|
|
102
|
+
norm_rate: float = 0.25,
|
|
103
|
+
# Other ways to construct the signal related parameters
|
|
104
|
+
chs_info: Optional[list[Dict]] = None,
|
|
105
|
+
input_window_seconds=None,
|
|
106
|
+
sfreq=None,
|
|
107
|
+
**kwargs,
|
|
76
108
|
):
|
|
77
|
-
n_chans, n_outputs, n_times = deprecated_args(
|
|
78
|
-
self,
|
|
79
|
-
("in_chans", "n_chans", in_chans, n_chans),
|
|
80
|
-
("n_classes", "n_outputs", n_classes, n_outputs),
|
|
81
|
-
("input_window_samples", "n_times", input_window_samples, n_times),
|
|
82
|
-
)
|
|
83
109
|
super().__init__(
|
|
84
110
|
n_outputs=n_outputs,
|
|
85
111
|
n_chans=n_chans,
|
|
@@ -89,68 +115,106 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
|
|
|
89
115
|
sfreq=sfreq,
|
|
90
116
|
)
|
|
91
117
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
92
|
-
del in_chans, n_classes, input_window_samples
|
|
93
118
|
if final_conv_length == "auto":
|
|
94
119
|
assert self.n_times is not None
|
|
120
|
+
|
|
121
|
+
if not final_layer_with_constraint:
|
|
122
|
+
warn(
|
|
123
|
+
"Parameter 'final_layer_with_constraint=False' is deprecated and will be "
|
|
124
|
+
"removed in a future release. Please use `final_layer_linear=True`.",
|
|
125
|
+
DeprecationWarning,
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
if "third_kernel_size" in kwargs:
|
|
129
|
+
warn(
|
|
130
|
+
"The parameter `third_kernel_size` is deprecated "
|
|
131
|
+
"and will be removed in a future version.",
|
|
132
|
+
)
|
|
133
|
+
unexpected_kwargs = set(kwargs) - {"third_kernel_size"}
|
|
134
|
+
if unexpected_kwargs:
|
|
135
|
+
raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}")
|
|
136
|
+
|
|
95
137
|
self.final_conv_length = final_conv_length
|
|
96
138
|
self.pool_mode = pool_mode
|
|
97
139
|
self.F1 = F1
|
|
98
140
|
self.D = D
|
|
141
|
+
|
|
142
|
+
if F2 is None:
|
|
143
|
+
F2 = self.F1 * self.D
|
|
99
144
|
self.F2 = F2
|
|
145
|
+
|
|
100
146
|
self.kernel_length = kernel_length
|
|
101
|
-
self.
|
|
147
|
+
self.depthwise_kernel_length = depthwise_kernel_length
|
|
148
|
+
self.pool1_kernel_size = pool1_kernel_size
|
|
149
|
+
self.pool2_kernel_size = pool2_kernel_size
|
|
102
150
|
self.drop_prob = drop_prob
|
|
151
|
+
self.activation = activation
|
|
152
|
+
self.batch_norm_momentum = batch_norm_momentum
|
|
153
|
+
self.batch_norm_affine = batch_norm_affine
|
|
154
|
+
self.batch_norm_eps = batch_norm_eps
|
|
155
|
+
self.conv_spatial_max_norm = conv_spatial_max_norm
|
|
156
|
+
self.norm_rate = norm_rate
|
|
157
|
+
|
|
103
158
|
# For the load_state_dict
|
|
104
159
|
# When padronize all layers,
|
|
105
160
|
# add the old's parameters here
|
|
106
161
|
self.mapping = {
|
|
107
162
|
"conv_classifier.weight": "final_layer.conv_classifier.weight",
|
|
108
|
-
"conv_classifier.bias": "final_layer.conv_classifier.bias"
|
|
163
|
+
"conv_classifier.bias": "final_layer.conv_classifier.bias",
|
|
109
164
|
}
|
|
110
165
|
|
|
111
166
|
pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
|
|
112
167
|
self.add_module("ensuredims", Ensure4d())
|
|
113
168
|
|
|
114
|
-
self.add_module("dimshuffle",
|
|
115
|
-
Rearrange("batch ch t 1 -> batch 1 ch t"))
|
|
169
|
+
self.add_module("dimshuffle", Rearrange("batch ch t 1 -> batch 1 ch t"))
|
|
116
170
|
self.add_module(
|
|
117
171
|
"conv_temporal",
|
|
118
172
|
nn.Conv2d(
|
|
119
173
|
1,
|
|
120
174
|
self.F1,
|
|
121
175
|
(1, self.kernel_length),
|
|
122
|
-
stride=1,
|
|
123
176
|
bias=False,
|
|
124
177
|
padding=(0, self.kernel_length // 2),
|
|
125
178
|
),
|
|
126
179
|
)
|
|
127
180
|
self.add_module(
|
|
128
181
|
"bnorm_temporal",
|
|
129
|
-
nn.BatchNorm2d(
|
|
182
|
+
nn.BatchNorm2d(
|
|
183
|
+
self.F1,
|
|
184
|
+
momentum=self.batch_norm_momentum,
|
|
185
|
+
affine=self.batch_norm_affine,
|
|
186
|
+
eps=self.batch_norm_eps,
|
|
187
|
+
),
|
|
130
188
|
)
|
|
131
189
|
self.add_module(
|
|
132
190
|
"conv_spatial",
|
|
133
191
|
Conv2dWithConstraint(
|
|
134
|
-
self.F1,
|
|
135
|
-
self.F1 * self.D,
|
|
136
|
-
(self.n_chans, 1),
|
|
137
|
-
max_norm=
|
|
138
|
-
stride=1,
|
|
192
|
+
in_channels=self.F1,
|
|
193
|
+
out_channels=self.F1 * self.D,
|
|
194
|
+
kernel_size=(self.n_chans, 1),
|
|
195
|
+
max_norm=self.conv_spatial_max_norm,
|
|
139
196
|
bias=False,
|
|
140
197
|
groups=self.F1,
|
|
141
|
-
padding=(0, 0),
|
|
142
198
|
),
|
|
143
199
|
)
|
|
144
200
|
|
|
145
201
|
self.add_module(
|
|
146
202
|
"bnorm_1",
|
|
147
203
|
nn.BatchNorm2d(
|
|
148
|
-
self.F1 * self.D,
|
|
204
|
+
self.F1 * self.D,
|
|
205
|
+
momentum=self.batch_norm_momentum,
|
|
206
|
+
affine=self.batch_norm_affine,
|
|
207
|
+
eps=self.batch_norm_eps,
|
|
149
208
|
),
|
|
150
209
|
)
|
|
151
|
-
self.add_module("elu_1",
|
|
210
|
+
self.add_module("elu_1", activation())
|
|
152
211
|
|
|
153
|
-
self.add_module(
|
|
212
|
+
self.add_module(
|
|
213
|
+
"pool_1",
|
|
214
|
+
pool_class(
|
|
215
|
+
kernel_size=(1, self.pool1_kernel_size),
|
|
216
|
+
),
|
|
217
|
+
)
|
|
154
218
|
self.add_module("drop_1", nn.Dropout(p=self.drop_prob))
|
|
155
219
|
|
|
156
220
|
# https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7
|
|
@@ -159,11 +223,10 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
|
|
|
159
223
|
nn.Conv2d(
|
|
160
224
|
self.F1 * self.D,
|
|
161
225
|
self.F1 * self.D,
|
|
162
|
-
(1,
|
|
163
|
-
stride=1,
|
|
226
|
+
(1, self.depthwise_kernel_length),
|
|
164
227
|
bias=False,
|
|
165
228
|
groups=self.F1 * self.D,
|
|
166
|
-
padding=(0,
|
|
229
|
+
padding=(0, self.depthwise_kernel_length // 2),
|
|
167
230
|
),
|
|
168
231
|
)
|
|
169
232
|
self.add_module(
|
|
@@ -171,19 +234,27 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
|
|
|
171
234
|
nn.Conv2d(
|
|
172
235
|
self.F1 * self.D,
|
|
173
236
|
self.F2,
|
|
174
|
-
(1, 1),
|
|
175
|
-
stride=1,
|
|
237
|
+
kernel_size=(1, 1),
|
|
176
238
|
bias=False,
|
|
177
|
-
padding=(0, 0),
|
|
178
239
|
),
|
|
179
240
|
)
|
|
180
241
|
|
|
181
242
|
self.add_module(
|
|
182
243
|
"bnorm_2",
|
|
183
|
-
nn.BatchNorm2d(
|
|
244
|
+
nn.BatchNorm2d(
|
|
245
|
+
self.F2,
|
|
246
|
+
momentum=self.batch_norm_momentum,
|
|
247
|
+
affine=self.batch_norm_affine,
|
|
248
|
+
eps=self.batch_norm_eps,
|
|
249
|
+
),
|
|
250
|
+
)
|
|
251
|
+
self.add_module("elu_2", self.activation())
|
|
252
|
+
self.add_module(
|
|
253
|
+
"pool_2",
|
|
254
|
+
pool_class(
|
|
255
|
+
kernel_size=(1, self.pool2_kernel_size),
|
|
256
|
+
),
|
|
184
257
|
)
|
|
185
|
-
self.add_module("elu_2", Expression(elu))
|
|
186
|
-
self.add_module("pool_2", pool_class(kernel_size=(1, 8), stride=(1, 8)))
|
|
187
258
|
self.add_module("drop_2", nn.Dropout(p=self.drop_prob))
|
|
188
259
|
|
|
189
260
|
output_shape = self.get_output_shape()
|
|
@@ -195,27 +266,42 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
|
|
|
195
266
|
|
|
196
267
|
# Incorporating classification module and subsequent ones in one final layer
|
|
197
268
|
module = nn.Sequential()
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
269
|
+
if not final_layer_with_constraint:
|
|
270
|
+
module.add_module(
|
|
271
|
+
"conv_classifier",
|
|
272
|
+
nn.Conv2d(
|
|
273
|
+
self.F2,
|
|
274
|
+
self.n_outputs,
|
|
275
|
+
(n_out_virtual_chans, self.final_conv_length),
|
|
276
|
+
bias=True,
|
|
277
|
+
),
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
# Transpose back to the logic of braindecode,
|
|
281
|
+
# so time in third dimension (axis=2)
|
|
282
|
+
module.add_module(
|
|
283
|
+
"permute_back",
|
|
284
|
+
Rearrange("batch x y z -> batch x z y"),
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
module.add_module("squeeze", SqueezeFinalOutput())
|
|
288
|
+
else:
|
|
289
|
+
module.add_module("flatten", nn.Flatten())
|
|
290
|
+
module.add_module(
|
|
291
|
+
"linearconstraint",
|
|
292
|
+
LinearWithConstraint(
|
|
293
|
+
in_features=self.F2 * self.final_conv_length,
|
|
294
|
+
out_features=self.n_outputs,
|
|
295
|
+
max_norm=norm_rate,
|
|
296
|
+
),
|
|
297
|
+
)
|
|
212
298
|
self.add_module("final_layer", module)
|
|
213
299
|
|
|
214
|
-
|
|
300
|
+
glorot_weight_zero_bias(self)
|
|
215
301
|
|
|
216
302
|
|
|
217
303
|
class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
218
|
-
"""EEGNet model from Lawhern et al. 2016.
|
|
304
|
+
"""EEGNet model from Lawhern et al. 2016 from [EEGNet]_.
|
|
219
305
|
|
|
220
306
|
See details in [EEGNet]_.
|
|
221
307
|
|
|
@@ -227,6 +313,9 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
227
313
|
Alias for n_outputs.
|
|
228
314
|
input_window_samples :
|
|
229
315
|
Alias for n_times.
|
|
316
|
+
activation: nn.Module, default=nn.ELU
|
|
317
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
318
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
230
319
|
|
|
231
320
|
Notes
|
|
232
321
|
-----
|
|
@@ -243,29 +332,20 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
243
332
|
"""
|
|
244
333
|
|
|
245
334
|
def __init__(
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
n_classes=None,
|
|
260
|
-
input_window_samples=None,
|
|
261
|
-
add_log_softmax=True,
|
|
335
|
+
self,
|
|
336
|
+
n_chans=None,
|
|
337
|
+
n_outputs=None,
|
|
338
|
+
n_times=None,
|
|
339
|
+
final_conv_length="auto",
|
|
340
|
+
pool_mode="max",
|
|
341
|
+
second_kernel_size=(2, 32),
|
|
342
|
+
third_kernel_size=(8, 4),
|
|
343
|
+
drop_prob=0.25,
|
|
344
|
+
activation: nn.Module = nn.ELU,
|
|
345
|
+
chs_info=None,
|
|
346
|
+
input_window_seconds=None,
|
|
347
|
+
sfreq=None,
|
|
262
348
|
):
|
|
263
|
-
n_chans, n_outputs, n_times = deprecated_args(
|
|
264
|
-
self,
|
|
265
|
-
("in_chans", "n_chans", in_chans, n_chans),
|
|
266
|
-
("n_classes", "n_outputs", n_classes, n_outputs),
|
|
267
|
-
("input_window_samples", "n_times", input_window_samples, n_times),
|
|
268
|
-
)
|
|
269
349
|
super().__init__(
|
|
270
350
|
n_outputs=n_outputs,
|
|
271
351
|
n_chans=n_chans,
|
|
@@ -273,10 +353,14 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
273
353
|
n_times=n_times,
|
|
274
354
|
input_window_seconds=input_window_seconds,
|
|
275
355
|
sfreq=sfreq,
|
|
276
|
-
add_log_softmax=add_log_softmax,
|
|
277
356
|
)
|
|
278
357
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
279
|
-
|
|
358
|
+
warn(
|
|
359
|
+
"The class EEGNetv1 is deprecated and will be removed in the "
|
|
360
|
+
"release 1.0 of braindecode. Please use "
|
|
361
|
+
"braindecode.models.EEGNetv4 instead in the future.",
|
|
362
|
+
DeprecationWarning,
|
|
363
|
+
)
|
|
280
364
|
if final_conv_length == "auto":
|
|
281
365
|
assert self.n_times is not None
|
|
282
366
|
self.final_conv_length = final_conv_length
|
|
@@ -289,7 +373,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
289
373
|
# add the old's parameters here
|
|
290
374
|
self.mapping = {
|
|
291
375
|
"conv_classifier.weight": "final_layer.conv_classifier.weight",
|
|
292
|
-
"conv_classifier.bias": "final_layer.conv_classifier.bias"
|
|
376
|
+
"conv_classifier.bias": "final_layer.conv_classifier.bias",
|
|
293
377
|
}
|
|
294
378
|
|
|
295
379
|
pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
|
|
@@ -303,11 +387,9 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
303
387
|
"bnorm_1",
|
|
304
388
|
nn.BatchNorm2d(n_filters_1, momentum=0.01, affine=True, eps=1e-3),
|
|
305
389
|
)
|
|
306
|
-
self.add_module("elu_1",
|
|
390
|
+
self.add_module("elu_1", activation())
|
|
307
391
|
# transpose to examples x 1 x (virtual, not EEG) channels x time
|
|
308
|
-
self.add_module(
|
|
309
|
-
"permute_1", Expression(lambda x: x.permute(0, 3, 1, 2))
|
|
310
|
-
)
|
|
392
|
+
self.add_module("permute_1", Rearrange("batch x y z -> batch z x y"))
|
|
311
393
|
|
|
312
394
|
self.add_module("drop_1", nn.Dropout(p=self.drop_prob))
|
|
313
395
|
|
|
@@ -332,7 +414,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
332
414
|
"bnorm_2",
|
|
333
415
|
nn.BatchNorm2d(n_filters_2, momentum=0.01, affine=True, eps=1e-3),
|
|
334
416
|
)
|
|
335
|
-
self.add_module("elu_2",
|
|
417
|
+
self.add_module("elu_2", activation())
|
|
336
418
|
self.add_module("pool_2", pool_class(kernel_size=(2, 4), stride=(2, 4)))
|
|
337
419
|
self.add_module("drop_2", nn.Dropout(p=self.drop_prob))
|
|
338
420
|
|
|
@@ -352,7 +434,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
352
434
|
"bnorm_3",
|
|
353
435
|
nn.BatchNorm2d(n_filters_3, momentum=0.01, affine=True, eps=1e-3),
|
|
354
436
|
)
|
|
355
|
-
self.add_module("elu_3",
|
|
437
|
+
self.add_module("elu_3", activation())
|
|
356
438
|
self.add_module("pool_3", pool_class(kernel_size=(2, 4), stride=(2, 4)))
|
|
357
439
|
self.add_module("drop_3", nn.Dropout(p=self.drop_prob))
|
|
358
440
|
|
|
@@ -366,40 +448,26 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
366
448
|
# Incorporating classification module and subsequent ones in one final layer
|
|
367
449
|
module = nn.Sequential()
|
|
368
450
|
|
|
369
|
-
module.add_module(
|
|
370
|
-
|
|
371
|
-
|
|
451
|
+
module.add_module(
|
|
452
|
+
"conv_classifier",
|
|
453
|
+
nn.Conv2d(
|
|
454
|
+
n_filters_3,
|
|
455
|
+
self.n_outputs,
|
|
456
|
+
(n_out_virtual_chans, self.final_conv_length),
|
|
457
|
+
bias=True,
|
|
458
|
+
),
|
|
459
|
+
)
|
|
372
460
|
|
|
373
|
-
if self.add_log_softmax:
|
|
374
|
-
module.add_module("softmax", nn.LogSoftmax(dim=1))
|
|
375
461
|
# Transpose back to the logic of braindecode,
|
|
376
462
|
|
|
377
463
|
# so time in third dimension (axis=2)
|
|
378
|
-
module.add_module(
|
|
464
|
+
module.add_module(
|
|
465
|
+
"permute_2",
|
|
466
|
+
Rearrange("batch x y z -> batch x z y"),
|
|
467
|
+
)
|
|
379
468
|
|
|
380
|
-
module.add_module("squeeze",
|
|
469
|
+
module.add_module("squeeze", SqueezeFinalOutput())
|
|
381
470
|
|
|
382
471
|
self.add_module("final_layer", module)
|
|
383
472
|
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
def _glorot_weight_zero_bias(model):
|
|
388
|
-
"""Initialize parameters of all modules by initializing weights with
|
|
389
|
-
glorot
|
|
390
|
-
uniform/xavier initialization, and setting biases to zero. Weights from
|
|
391
|
-
batch norm layers are set to 1.
|
|
392
|
-
|
|
393
|
-
Parameters
|
|
394
|
-
----------
|
|
395
|
-
model: Module
|
|
396
|
-
"""
|
|
397
|
-
for module in model.modules():
|
|
398
|
-
if hasattr(module, "weight"):
|
|
399
|
-
if "BatchNorm" not in module.__class__.__name__:
|
|
400
|
-
nn.init.xavier_uniform_(module.weight, gain=1)
|
|
401
|
-
else:
|
|
402
|
-
nn.init.constant_(module.weight, 1)
|
|
403
|
-
if hasattr(module, "bias"):
|
|
404
|
-
if module.bias is not None:
|
|
405
|
-
nn.init.constant_(module.bias, 0)
|
|
473
|
+
glorot_weight_zero_bias(self)
|