braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/models/eegnet.py
CHANGED
|
@@ -1,85 +1,110 @@
|
|
|
1
1
|
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
2
|
#
|
|
3
3
|
# License: BSD (3-clause)
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import Dict, Optional
|
|
4
7
|
|
|
5
|
-
import torch
|
|
6
8
|
from einops.layers.torch import Rearrange
|
|
9
|
+
from mne.utils import warn
|
|
7
10
|
from torch import nn
|
|
8
|
-
from torch.nn.functional import elu
|
|
9
|
-
|
|
10
|
-
from .base import EEGModuleMixin, deprecated_args
|
|
11
|
-
from .functions import squeeze_final_output
|
|
12
|
-
from .modules import Ensure4d, Expression
|
|
13
|
-
|
|
14
11
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
)
|
|
24
|
-
return super(Conv2dWithConstraint, self).forward(x)
|
|
12
|
+
from braindecode.functional import glorot_weight_zero_bias
|
|
13
|
+
from braindecode.models.base import EEGModuleMixin
|
|
14
|
+
from braindecode.modules import (
|
|
15
|
+
Conv2dWithConstraint,
|
|
16
|
+
Ensure4d,
|
|
17
|
+
LinearWithConstraint,
|
|
18
|
+
SqueezeFinalOutput,
|
|
19
|
+
)
|
|
25
20
|
|
|
26
21
|
|
|
27
22
|
class EEGNetv4(EEGModuleMixin, nn.Sequential):
|
|
28
|
-
"""EEGNet v4 model from Lawhern et al 2018.
|
|
23
|
+
"""EEGNet v4 model from Lawhern et al. (2018) [EEGNet4]_.
|
|
24
|
+
|
|
25
|
+
.. figure:: https://content.cld.iop.org/journals/1741-2552/15/5/056013/revision2/jneaace8cf01_hr.jpg
|
|
26
|
+
:align: center
|
|
27
|
+
:alt: EEGNet4 Architecture
|
|
29
28
|
|
|
30
29
|
See details in [EEGNet4]_.
|
|
31
30
|
|
|
32
31
|
Parameters
|
|
33
32
|
----------
|
|
34
|
-
final_conv_length : int
|
|
35
|
-
If
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
33
|
+
final_conv_length : int or "auto", default="auto"
|
|
34
|
+
Length of the final convolution layer. If "auto", it is set based on n_times.
|
|
35
|
+
pool_mode : {"mean", "max"}, default="mean"
|
|
36
|
+
Pooling method to use in pooling layers.
|
|
37
|
+
F1 : int, default=8
|
|
38
|
+
Number of temporal filters in the first convolutional layer.
|
|
39
|
+
D : int, default=2
|
|
40
|
+
Depth multiplier for the depthwise convolution.
|
|
41
|
+
F2 : int or None, default=None
|
|
42
|
+
Number of pointwise filters in the separable convolution. Usually set to ``F1 * D``.
|
|
43
|
+
depthwise_kernel_length : int, default=16
|
|
44
|
+
Length of the depthwise convolution kernel in the separable convolution.
|
|
45
|
+
pool1_kernel_size : int, default=4
|
|
46
|
+
Kernel size of the first pooling layer.
|
|
47
|
+
pool2_kernel_size : int, default=8
|
|
48
|
+
Kernel size of the second pooling layer.
|
|
49
|
+
kernel_length : int, default=64
|
|
50
|
+
Length of the temporal convolution kernel.
|
|
51
|
+
conv_spatial_max_norm : float, default=1
|
|
52
|
+
Maximum norm constraint for the spatial (depthwise) convolution.
|
|
53
|
+
activation : nn.Module, default=nn.ELU
|
|
54
|
+
Non-linear activation function to be used in the layers.
|
|
55
|
+
batch_norm_momentum : float, default=0.01
|
|
56
|
+
Momentum for instance normalization in batch norm layers.
|
|
57
|
+
batch_norm_affine : bool, default=True
|
|
58
|
+
If True, batch norm has learnable affine parameters.
|
|
59
|
+
batch_norm_eps : float, default=1e-3
|
|
60
|
+
Epsilon for numeric stability in batch norm layers.
|
|
61
|
+
drop_prob : float, default=0.25
|
|
62
|
+
Dropout probability.
|
|
63
|
+
final_layer_with_constraint : bool, default=False
|
|
64
|
+
If ``False``, uses a convolution-based classification layer. If ``True``,
|
|
65
|
+
apply a flattened linear layer with constraint on the weights norm as the final classification step.
|
|
66
|
+
norm_rate : float, default=0.25
|
|
67
|
+
Max-norm constraint value for the linear layer (used if ``final_layer_conv=False``).
|
|
47
68
|
|
|
48
69
|
References
|
|
49
70
|
----------
|
|
50
|
-
.. [EEGNet4] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
arXiv preprint arXiv:1611.08024.
|
|
71
|
+
.. [EEGNet4] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon, S. M.,
|
|
72
|
+
Hung, C. P., & Lance, B. J. (2018). EEGNet: a compact convolutional
|
|
73
|
+
neural network for EEG-based brain–computer interfaces. Journal of
|
|
74
|
+
neural engineering, 15(5), 056013.
|
|
55
75
|
"""
|
|
56
76
|
|
|
57
77
|
def __init__(
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
78
|
+
self,
|
|
79
|
+
# signal's parameters
|
|
80
|
+
n_chans: Optional[int] = None,
|
|
81
|
+
n_outputs: Optional[int] = None,
|
|
82
|
+
n_times: Optional[int] = None,
|
|
83
|
+
# model's parameters
|
|
84
|
+
final_conv_length: str | int = "auto",
|
|
85
|
+
pool_mode: str = "mean",
|
|
86
|
+
F1: int = 8,
|
|
87
|
+
D: int = 2,
|
|
88
|
+
F2: Optional[int | None] = None,
|
|
89
|
+
kernel_length: int = 64,
|
|
90
|
+
*,
|
|
91
|
+
depthwise_kernel_length: int = 16,
|
|
92
|
+
pool1_kernel_size: int = 4,
|
|
93
|
+
pool2_kernel_size: int = 8,
|
|
94
|
+
conv_spatial_max_norm: int = 1,
|
|
95
|
+
activation: nn.Module = nn.ELU,
|
|
96
|
+
batch_norm_momentum: float = 0.01,
|
|
97
|
+
batch_norm_affine: bool = True,
|
|
98
|
+
batch_norm_eps: float = 1e-3,
|
|
99
|
+
drop_prob: float = 0.25,
|
|
100
|
+
final_layer_with_constraint: bool = False,
|
|
101
|
+
norm_rate: float = 0.25,
|
|
102
|
+
# Other ways to construct the signal related parameters
|
|
103
|
+
chs_info: Optional[list[Dict]] = None,
|
|
104
|
+
input_window_seconds=None,
|
|
105
|
+
sfreq=None,
|
|
106
|
+
**kwargs,
|
|
76
107
|
):
|
|
77
|
-
n_chans, n_outputs, n_times = deprecated_args(
|
|
78
|
-
self,
|
|
79
|
-
("in_chans", "n_chans", in_chans, n_chans),
|
|
80
|
-
("n_classes", "n_outputs", n_classes, n_outputs),
|
|
81
|
-
("input_window_samples", "n_times", input_window_samples, n_times),
|
|
82
|
-
)
|
|
83
108
|
super().__init__(
|
|
84
109
|
n_outputs=n_outputs,
|
|
85
110
|
n_chans=n_chans,
|
|
@@ -89,68 +114,106 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
|
|
|
89
114
|
sfreq=sfreq,
|
|
90
115
|
)
|
|
91
116
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
92
|
-
del in_chans, n_classes, input_window_samples
|
|
93
117
|
if final_conv_length == "auto":
|
|
94
118
|
assert self.n_times is not None
|
|
119
|
+
|
|
120
|
+
if not final_layer_with_constraint:
|
|
121
|
+
warn(
|
|
122
|
+
"Parameter 'final_layer_with_constraint=False' is deprecated and will be "
|
|
123
|
+
"removed in a future release. Please use `final_layer_linear=True`.",
|
|
124
|
+
DeprecationWarning,
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
if "third_kernel_size" in kwargs:
|
|
128
|
+
warn(
|
|
129
|
+
"The parameter `third_kernel_size` is deprecated "
|
|
130
|
+
"and will be removed in a future version.",
|
|
131
|
+
)
|
|
132
|
+
unexpected_kwargs = set(kwargs) - {"third_kernel_size"}
|
|
133
|
+
if unexpected_kwargs:
|
|
134
|
+
raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}")
|
|
135
|
+
|
|
95
136
|
self.final_conv_length = final_conv_length
|
|
96
137
|
self.pool_mode = pool_mode
|
|
97
138
|
self.F1 = F1
|
|
98
139
|
self.D = D
|
|
140
|
+
|
|
141
|
+
if F2 is None:
|
|
142
|
+
F2 = self.F1 * self.D
|
|
99
143
|
self.F2 = F2
|
|
144
|
+
|
|
100
145
|
self.kernel_length = kernel_length
|
|
101
|
-
self.
|
|
146
|
+
self.depthwise_kernel_length = depthwise_kernel_length
|
|
147
|
+
self.pool1_kernel_size = pool1_kernel_size
|
|
148
|
+
self.pool2_kernel_size = pool2_kernel_size
|
|
102
149
|
self.drop_prob = drop_prob
|
|
150
|
+
self.activation = activation
|
|
151
|
+
self.batch_norm_momentum = batch_norm_momentum
|
|
152
|
+
self.batch_norm_affine = batch_norm_affine
|
|
153
|
+
self.batch_norm_eps = batch_norm_eps
|
|
154
|
+
self.conv_spatial_max_norm = conv_spatial_max_norm
|
|
155
|
+
self.norm_rate = norm_rate
|
|
156
|
+
|
|
103
157
|
# For the load_state_dict
|
|
104
158
|
# When padronize all layers,
|
|
105
159
|
# add the old's parameters here
|
|
106
160
|
self.mapping = {
|
|
107
161
|
"conv_classifier.weight": "final_layer.conv_classifier.weight",
|
|
108
|
-
"conv_classifier.bias": "final_layer.conv_classifier.bias"
|
|
162
|
+
"conv_classifier.bias": "final_layer.conv_classifier.bias",
|
|
109
163
|
}
|
|
110
164
|
|
|
111
165
|
pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
|
|
112
166
|
self.add_module("ensuredims", Ensure4d())
|
|
113
167
|
|
|
114
|
-
self.add_module("dimshuffle",
|
|
115
|
-
Rearrange("batch ch t 1 -> batch 1 ch t"))
|
|
168
|
+
self.add_module("dimshuffle", Rearrange("batch ch t 1 -> batch 1 ch t"))
|
|
116
169
|
self.add_module(
|
|
117
170
|
"conv_temporal",
|
|
118
171
|
nn.Conv2d(
|
|
119
172
|
1,
|
|
120
173
|
self.F1,
|
|
121
174
|
(1, self.kernel_length),
|
|
122
|
-
stride=1,
|
|
123
175
|
bias=False,
|
|
124
176
|
padding=(0, self.kernel_length // 2),
|
|
125
177
|
),
|
|
126
178
|
)
|
|
127
179
|
self.add_module(
|
|
128
180
|
"bnorm_temporal",
|
|
129
|
-
nn.BatchNorm2d(
|
|
181
|
+
nn.BatchNorm2d(
|
|
182
|
+
self.F1,
|
|
183
|
+
momentum=self.batch_norm_momentum,
|
|
184
|
+
affine=self.batch_norm_affine,
|
|
185
|
+
eps=self.batch_norm_eps,
|
|
186
|
+
),
|
|
130
187
|
)
|
|
131
188
|
self.add_module(
|
|
132
189
|
"conv_spatial",
|
|
133
190
|
Conv2dWithConstraint(
|
|
134
|
-
self.F1,
|
|
135
|
-
self.F1 * self.D,
|
|
136
|
-
(self.n_chans, 1),
|
|
137
|
-
max_norm=
|
|
138
|
-
stride=1,
|
|
191
|
+
in_channels=self.F1,
|
|
192
|
+
out_channels=self.F1 * self.D,
|
|
193
|
+
kernel_size=(self.n_chans, 1),
|
|
194
|
+
max_norm=self.conv_spatial_max_norm,
|
|
139
195
|
bias=False,
|
|
140
196
|
groups=self.F1,
|
|
141
|
-
padding=(0, 0),
|
|
142
197
|
),
|
|
143
198
|
)
|
|
144
199
|
|
|
145
200
|
self.add_module(
|
|
146
201
|
"bnorm_1",
|
|
147
202
|
nn.BatchNorm2d(
|
|
148
|
-
self.F1 * self.D,
|
|
203
|
+
self.F1 * self.D,
|
|
204
|
+
momentum=self.batch_norm_momentum,
|
|
205
|
+
affine=self.batch_norm_affine,
|
|
206
|
+
eps=self.batch_norm_eps,
|
|
149
207
|
),
|
|
150
208
|
)
|
|
151
|
-
self.add_module("elu_1",
|
|
209
|
+
self.add_module("elu_1", activation())
|
|
152
210
|
|
|
153
|
-
self.add_module(
|
|
211
|
+
self.add_module(
|
|
212
|
+
"pool_1",
|
|
213
|
+
pool_class(
|
|
214
|
+
kernel_size=(1, self.pool1_kernel_size),
|
|
215
|
+
),
|
|
216
|
+
)
|
|
154
217
|
self.add_module("drop_1", nn.Dropout(p=self.drop_prob))
|
|
155
218
|
|
|
156
219
|
# https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7
|
|
@@ -159,11 +222,10 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
|
|
|
159
222
|
nn.Conv2d(
|
|
160
223
|
self.F1 * self.D,
|
|
161
224
|
self.F1 * self.D,
|
|
162
|
-
(1,
|
|
163
|
-
stride=1,
|
|
225
|
+
(1, self.depthwise_kernel_length),
|
|
164
226
|
bias=False,
|
|
165
227
|
groups=self.F1 * self.D,
|
|
166
|
-
padding=(0,
|
|
228
|
+
padding=(0, self.depthwise_kernel_length // 2),
|
|
167
229
|
),
|
|
168
230
|
)
|
|
169
231
|
self.add_module(
|
|
@@ -171,19 +233,27 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
|
|
|
171
233
|
nn.Conv2d(
|
|
172
234
|
self.F1 * self.D,
|
|
173
235
|
self.F2,
|
|
174
|
-
(1, 1),
|
|
175
|
-
stride=1,
|
|
236
|
+
kernel_size=(1, 1),
|
|
176
237
|
bias=False,
|
|
177
|
-
padding=(0, 0),
|
|
178
238
|
),
|
|
179
239
|
)
|
|
180
240
|
|
|
181
241
|
self.add_module(
|
|
182
242
|
"bnorm_2",
|
|
183
|
-
nn.BatchNorm2d(
|
|
243
|
+
nn.BatchNorm2d(
|
|
244
|
+
self.F2,
|
|
245
|
+
momentum=self.batch_norm_momentum,
|
|
246
|
+
affine=self.batch_norm_affine,
|
|
247
|
+
eps=self.batch_norm_eps,
|
|
248
|
+
),
|
|
249
|
+
)
|
|
250
|
+
self.add_module("elu_2", self.activation())
|
|
251
|
+
self.add_module(
|
|
252
|
+
"pool_2",
|
|
253
|
+
pool_class(
|
|
254
|
+
kernel_size=(1, self.pool2_kernel_size),
|
|
255
|
+
),
|
|
184
256
|
)
|
|
185
|
-
self.add_module("elu_2", Expression(elu))
|
|
186
|
-
self.add_module("pool_2", pool_class(kernel_size=(1, 8), stride=(1, 8)))
|
|
187
257
|
self.add_module("drop_2", nn.Dropout(p=self.drop_prob))
|
|
188
258
|
|
|
189
259
|
output_shape = self.get_output_shape()
|
|
@@ -195,27 +265,42 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
|
|
|
195
265
|
|
|
196
266
|
# Incorporating classification module and subsequent ones in one final layer
|
|
197
267
|
module = nn.Sequential()
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
268
|
+
if not final_layer_with_constraint:
|
|
269
|
+
module.add_module(
|
|
270
|
+
"conv_classifier",
|
|
271
|
+
nn.Conv2d(
|
|
272
|
+
self.F2,
|
|
273
|
+
self.n_outputs,
|
|
274
|
+
(n_out_virtual_chans, self.final_conv_length),
|
|
275
|
+
bias=True,
|
|
276
|
+
),
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
# Transpose back to the logic of braindecode,
|
|
280
|
+
# so time in third dimension (axis=2)
|
|
281
|
+
module.add_module(
|
|
282
|
+
"permute_back",
|
|
283
|
+
Rearrange("batch x y z -> batch x z y"),
|
|
284
|
+
)
|
|
285
|
+
|
|
286
|
+
module.add_module("squeeze", SqueezeFinalOutput())
|
|
287
|
+
else:
|
|
288
|
+
module.add_module("flatten", nn.Flatten())
|
|
289
|
+
module.add_module(
|
|
290
|
+
"linearconstraint",
|
|
291
|
+
LinearWithConstraint(
|
|
292
|
+
in_features=self.F2 * self.final_conv_length,
|
|
293
|
+
out_features=self.n_outputs,
|
|
294
|
+
max_norm=norm_rate,
|
|
295
|
+
),
|
|
296
|
+
)
|
|
212
297
|
self.add_module("final_layer", module)
|
|
213
298
|
|
|
214
|
-
|
|
299
|
+
glorot_weight_zero_bias(self)
|
|
215
300
|
|
|
216
301
|
|
|
217
302
|
class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
218
|
-
"""EEGNet model from Lawhern et al. 2016.
|
|
303
|
+
"""EEGNet model from Lawhern et al. 2016 from [EEGNet]_.
|
|
219
304
|
|
|
220
305
|
See details in [EEGNet]_.
|
|
221
306
|
|
|
@@ -227,6 +312,9 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
227
312
|
Alias for n_outputs.
|
|
228
313
|
input_window_samples :
|
|
229
314
|
Alias for n_times.
|
|
315
|
+
activation: nn.Module, default=nn.ELU
|
|
316
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
317
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
230
318
|
|
|
231
319
|
Notes
|
|
232
320
|
-----
|
|
@@ -243,29 +331,20 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
243
331
|
"""
|
|
244
332
|
|
|
245
333
|
def __init__(
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
n_classes=None,
|
|
260
|
-
input_window_samples=None,
|
|
261
|
-
add_log_softmax=True,
|
|
334
|
+
self,
|
|
335
|
+
n_chans=None,
|
|
336
|
+
n_outputs=None,
|
|
337
|
+
n_times=None,
|
|
338
|
+
final_conv_length="auto",
|
|
339
|
+
pool_mode="max",
|
|
340
|
+
second_kernel_size=(2, 32),
|
|
341
|
+
third_kernel_size=(8, 4),
|
|
342
|
+
drop_prob=0.25,
|
|
343
|
+
activation: nn.Module = nn.ELU,
|
|
344
|
+
chs_info=None,
|
|
345
|
+
input_window_seconds=None,
|
|
346
|
+
sfreq=None,
|
|
262
347
|
):
|
|
263
|
-
n_chans, n_outputs, n_times = deprecated_args(
|
|
264
|
-
self,
|
|
265
|
-
("in_chans", "n_chans", in_chans, n_chans),
|
|
266
|
-
("n_classes", "n_outputs", n_classes, n_outputs),
|
|
267
|
-
("input_window_samples", "n_times", input_window_samples, n_times),
|
|
268
|
-
)
|
|
269
348
|
super().__init__(
|
|
270
349
|
n_outputs=n_outputs,
|
|
271
350
|
n_chans=n_chans,
|
|
@@ -273,10 +352,14 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
273
352
|
n_times=n_times,
|
|
274
353
|
input_window_seconds=input_window_seconds,
|
|
275
354
|
sfreq=sfreq,
|
|
276
|
-
add_log_softmax=add_log_softmax,
|
|
277
355
|
)
|
|
278
356
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
279
|
-
|
|
357
|
+
warn(
|
|
358
|
+
"The class EEGNetv1 is deprecated and will be removed in the "
|
|
359
|
+
"release 1.0 of braindecode. Please use "
|
|
360
|
+
"braindecode.models.EEGNetv4 instead in the future.",
|
|
361
|
+
DeprecationWarning,
|
|
362
|
+
)
|
|
280
363
|
if final_conv_length == "auto":
|
|
281
364
|
assert self.n_times is not None
|
|
282
365
|
self.final_conv_length = final_conv_length
|
|
@@ -289,7 +372,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
289
372
|
# add the old's parameters here
|
|
290
373
|
self.mapping = {
|
|
291
374
|
"conv_classifier.weight": "final_layer.conv_classifier.weight",
|
|
292
|
-
"conv_classifier.bias": "final_layer.conv_classifier.bias"
|
|
375
|
+
"conv_classifier.bias": "final_layer.conv_classifier.bias",
|
|
293
376
|
}
|
|
294
377
|
|
|
295
378
|
pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
|
|
@@ -303,11 +386,9 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
303
386
|
"bnorm_1",
|
|
304
387
|
nn.BatchNorm2d(n_filters_1, momentum=0.01, affine=True, eps=1e-3),
|
|
305
388
|
)
|
|
306
|
-
self.add_module("elu_1",
|
|
389
|
+
self.add_module("elu_1", activation())
|
|
307
390
|
# transpose to examples x 1 x (virtual, not EEG) channels x time
|
|
308
|
-
self.add_module(
|
|
309
|
-
"permute_1", Expression(lambda x: x.permute(0, 3, 1, 2))
|
|
310
|
-
)
|
|
391
|
+
self.add_module("permute_1", Rearrange("batch x y z -> batch z x y"))
|
|
311
392
|
|
|
312
393
|
self.add_module("drop_1", nn.Dropout(p=self.drop_prob))
|
|
313
394
|
|
|
@@ -332,7 +413,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
332
413
|
"bnorm_2",
|
|
333
414
|
nn.BatchNorm2d(n_filters_2, momentum=0.01, affine=True, eps=1e-3),
|
|
334
415
|
)
|
|
335
|
-
self.add_module("elu_2",
|
|
416
|
+
self.add_module("elu_2", activation())
|
|
336
417
|
self.add_module("pool_2", pool_class(kernel_size=(2, 4), stride=(2, 4)))
|
|
337
418
|
self.add_module("drop_2", nn.Dropout(p=self.drop_prob))
|
|
338
419
|
|
|
@@ -352,7 +433,7 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
352
433
|
"bnorm_3",
|
|
353
434
|
nn.BatchNorm2d(n_filters_3, momentum=0.01, affine=True, eps=1e-3),
|
|
354
435
|
)
|
|
355
|
-
self.add_module("elu_3",
|
|
436
|
+
self.add_module("elu_3", activation())
|
|
356
437
|
self.add_module("pool_3", pool_class(kernel_size=(2, 4), stride=(2, 4)))
|
|
357
438
|
self.add_module("drop_3", nn.Dropout(p=self.drop_prob))
|
|
358
439
|
|
|
@@ -366,40 +447,26 @@ class EEGNetv1(EEGModuleMixin, nn.Sequential):
|
|
|
366
447
|
# Incorporating classification module and subsequent ones in one final layer
|
|
367
448
|
module = nn.Sequential()
|
|
368
449
|
|
|
369
|
-
module.add_module(
|
|
370
|
-
|
|
371
|
-
|
|
450
|
+
module.add_module(
|
|
451
|
+
"conv_classifier",
|
|
452
|
+
nn.Conv2d(
|
|
453
|
+
n_filters_3,
|
|
454
|
+
self.n_outputs,
|
|
455
|
+
(n_out_virtual_chans, self.final_conv_length),
|
|
456
|
+
bias=True,
|
|
457
|
+
),
|
|
458
|
+
)
|
|
372
459
|
|
|
373
|
-
if self.add_log_softmax:
|
|
374
|
-
module.add_module("softmax", nn.LogSoftmax(dim=1))
|
|
375
460
|
# Transpose back to the logic of braindecode,
|
|
376
461
|
|
|
377
462
|
# so time in third dimension (axis=2)
|
|
378
|
-
module.add_module(
|
|
463
|
+
module.add_module(
|
|
464
|
+
"permute_2",
|
|
465
|
+
Rearrange("batch x y z -> batch x z y"),
|
|
466
|
+
)
|
|
379
467
|
|
|
380
|
-
module.add_module("squeeze",
|
|
468
|
+
module.add_module("squeeze", SqueezeFinalOutput())
|
|
381
469
|
|
|
382
470
|
self.add_module("final_layer", module)
|
|
383
471
|
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
def _glorot_weight_zero_bias(model):
|
|
388
|
-
"""Initialize parameters of all modules by initializing weights with
|
|
389
|
-
glorot
|
|
390
|
-
uniform/xavier initialization, and setting biases to zero. Weights from
|
|
391
|
-
batch norm layers are set to 1.
|
|
392
|
-
|
|
393
|
-
Parameters
|
|
394
|
-
----------
|
|
395
|
-
model: Module
|
|
396
|
-
"""
|
|
397
|
-
for module in model.modules():
|
|
398
|
-
if hasattr(module, "weight"):
|
|
399
|
-
if "BatchNorm" not in module.__class__.__name__:
|
|
400
|
-
nn.init.xavier_uniform_(module.weight, gain=1)
|
|
401
|
-
else:
|
|
402
|
-
nn.init.constant_(module.weight, 1)
|
|
403
|
-
if hasattr(module, "bias"):
|
|
404
|
-
if module.bias is not None:
|
|
405
|
-
nn.init.constant_(module.bias, 0)
|
|
472
|
+
glorot_weight_zero_bias(self)
|