braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""
|
|
2
|
+
* Copyright (C) Cogitat, Ltd.
|
|
3
|
+
* Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
|
|
4
|
+
* Patent GB2609265 - Learnable filters for eeg classification
|
|
5
|
+
* https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from functools import partial
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from einops.layers.torch import Rearrange
|
|
12
|
+
from torch import nn
|
|
13
|
+
|
|
14
|
+
import braindecode.functional as F
|
|
15
|
+
from braindecode.models.base import EEGModuleMixin
|
|
16
|
+
from braindecode.modules import GeneralizedGaussianFilter
|
|
17
|
+
|
|
18
|
+
_eeg_miner_methods = ["mag", "corr", "plv"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class EEGMiner(EEGModuleMixin, nn.Module):
    r"""EEGMiner from Ludwig et al (2024) [eegminer]_.

    :bdg-success:`Convolution` :bdg-warning:`Interpretability`

    .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036010/revision2/jnead44d7f1_hr.jpg
        :align: center
        :alt: EEGMiner Architecture

    EEGMiner is a neural network model for EEG signal classification using
    learnable generalized Gaussian filters. The model combines frequency-domain
    filtering with connectivity metrics (such as the Phase Locking Value, PLV)
    or band-power features to extract meaningful features from EEG data,
    enabling effective classification.

    The model has the following steps:

    - **Generalized Gaussian** filters are applied to the input EEG signals in
      the frequency domain.

    - **Connectivity estimators** (corr, plv) or **Electrode-Wise Band Power**
      (mag), with ``plv`` as the default.

      - `'corr'`: Computes the correlation of the filtered signals.
      - `'plv'`: Computes the phase locking value of the filtered signals.
      - `'mag'`: Computes the magnitude of the filtered signals.

    - **Feature Normalization**

      - Apply batch normalization.

    - **Final Layer**

      - Feeds the batch-normalized features into a final linear layer for
        classification.

    Depending on the selected method (`mag`, `corr`, or `plv`), the model
    computes the filtered signals' magnitude, correlation, or phase locking
    value. These features are then flattened, normalized by a (non-affine)
    batch normalization layer, and fed into a final linear layer for
    classification.

    The input to EEGMiner should be a three-dimensional tensor representing
    EEG signals:

    ``(batch_size, n_channels, n_timesteps)``.

    Notes
    -----
    EEGMiner incorporates learnable parameters for filter characteristics,
    allowing the model to adaptively learn optimal frequency bands and phase
    delays for the classification task. By default, using the PLV as a
    connectivity metric makes EEGMiner suitable for tasks requiring the
    analysis of phase relationships between different EEG channels.

    The model and the module have patent [eegminercode]_, and the code is
    CC BY-NC 4.0.

    .. versionadded:: 0.9

    Parameters
    ----------
    method : str, default="plv"
        The method used for feature extraction (case-insensitive). Options are:

        - "mag": Electrode-Wise band power of the filtered signals.
        - "corr": Correlation between filtered channels.
        - "plv": Phase Locking Value connectivity metric.
    filter_f_mean : sequence of float, default=(23.0, 23.0)
        Mean frequencies for the generalized Gaussian filters. Its length sets
        the number of filters.
    filter_bandwidth : sequence of float, default=(44.0, 44.0)
        Bandwidths for the generalized Gaussian filters.
    filter_shape : sequence of float, default=(2.0, 2.0)
        Shape parameters for the generalized Gaussian filters.
    group_delay : tuple of float, default=(20.0, 20.0)
        Group delay values for the filters in milliseconds.
    clamp_f_mean : tuple of float, default=(1.0, 45.0)
        Clamping range for the mean frequency parameters.

    Raises
    ------
    ValueError
        If ``method`` (lower-cased) is not one of ``"mag"``, ``"corr"``,
        ``"plv"``.

    References
    ----------
    .. [eegminer] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters. Journal of Neural Engineering,
       21(3), 036010.
    .. [eegminercode] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
       of brain activity with learnable filters.
       https://github.com/SMLudwig/EEGminer/.
       Cogitat, Ltd. "Learnable filters for EEG classification."
       Patent GB2609265.
       https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
    """

    def __init__(
        self,  # Signal related parameters
        method: str = "plv",
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # model related
        filter_f_mean=(23.0, 23.0),
        filter_bandwidth=(44.0, 44.0),
        filter_shape=(2.0, 2.0),
        group_delay=(20.0, 20.0),
        clamp_f_mean=(1.0, 45.0),
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Initialize filter parameters
        self.filter_f_mean = filter_f_mean
        self.filter_bandwidth = filter_bandwidth
        self.filter_shape = filter_shape
        # One learnable filter per entry of ``filter_f_mean``.
        self.n_filters = len(self.filter_f_mean)
        self.group_delay = group_delay
        self.clamp_f_mean = clamp_f_mean
        self.method = method.lower()

        if self.method not in _eeg_miner_methods:
            raise ValueError(
                f"The method {self.method} is not one of the valid options"
                f" {_eeg_miner_methods}"
            )

        if self.method == "mag" or self.method == "corr":
            # Time-domain output: every electrode gets its own bank of filters.
            inverse_fourier = True
            in_channels = self.n_chans
            out_channels = self.n_chans * self.n_filters
        else:
            # PLV works on the frequency-domain output; all electrodes share
            # one bank of filters (a singleton channel axis is added in forward).
            inverse_fourier = False
            in_channels = 1
            out_channels = 1 * self.n_filters

        # Generalized Gaussian Filter
        self.filter = GeneralizedGaussianFilter(
            in_channels=in_channels,
            out_channels=out_channels,
            sequence_length=self.n_times,
            sample_rate=self.sfreq,
            f_mean=self.filter_f_mean,
            bandwidth=self.filter_bandwidth,
            shape=self.filter_shape,
            affine_group_delay=False,
            inverse_fourier=inverse_fourier,
            group_delay=self.group_delay,
            clamp_f_mean=self.clamp_f_mean,
        )

        # Forward method
        if self.method == "mag":
            self.method_forward = self._apply_mag_forward
            self.n_features = self.n_chans * self.n_filters
            self.ensure_dim = nn.Identity()
        elif self.method == "corr":
            self.method_forward = partial(
                self._apply_corr_forward,
                n_chans=self.n_chans,
                n_filters=self.n_filters,
                n_times=self.n_times,
            )
            # Upper triangle (without diagonal) of one chans x chans matrix
            # per filter.
            self.n_features = self.n_filters * self.n_chans * (self.n_chans - 1) // 2
            self.ensure_dim = nn.Identity()
        elif self.method == "plv":
            self.method_forward = partial(self._apply_plv, n_chans=self.n_chans)
            self.ensure_dim = Rearrange("... d -> ... 1 d")
            self.n_features = (self.n_filters * self.n_chans * (self.n_chans - 1)) // 2

        self.flatten_layer = nn.Flatten()
        # Classifier
        self.batch_layer = nn.BatchNorm1d(self.n_features, affine=False)
        self.final_layer = nn.Linear(self.n_features, self.n_outputs)
        nn.init.zeros_(self.final_layer.bias)

    def forward(self, x):
        """Compute the model output.

        Parameters
        ----------
        x : torch.Tensor
            EEG input of shape ``(batch, electrodes, time)``.

        Returns
        -------
        torch.Tensor
            Output of shape ``(batch, n_outputs)``.
        """
        batch = x.shape[0]
        # For "plv" this inserts a singleton axis before time so that every
        # electrode goes through the same filter bank; identity otherwise.
        x = self.ensure_dim(x)
        # Apply Gaussian filters in frequency domain
        x = self.filter(x)

        x = self.method_forward(x=x, batch=batch)
        # Classifier
        # Note that the order of dimensions before flattening the feature vector is important
        # for attributing feature weights during interpretation.
        x = x.reshape(batch, self.n_features)
        x = self.batch_layer(x)
        x = self.final_layer(x)

        return x

    @staticmethod
    def _upper_triangle(x, n_chans):
        """Keep the strictly upper-triangular electrode pairs of ``x``.

        ``x`` is indexed as ``(batch, chans, chans, n_filters)``; the result is
        ``(batch, n_chans * (n_chans - 1) // 2, n_filters)``.
        """
        # Build the indices on x's device to avoid a host->device index copy
        # on every forward pass.
        triu = torch.triu_indices(n_chans, n_chans, 1, device=x.device)
        return x[:, triu[0], triu[1], :]

    @staticmethod
    def _apply_mag_forward(x, batch=None):
        """Electrode-wise band power: root mean square over the last axis.

        ``batch`` is unused; it is accepted so that every ``method_forward``
        callable can be invoked as ``method_forward(x=..., batch=...)``.
        """
        # Signal magnitude
        x = x * x
        x = x.mean(dim=-1)
        x = torch.sqrt(x)
        return x

    @staticmethod
    def _apply_corr_forward(
        x, batch, n_chans, n_filters, n_times, epilson: float = 1e-6
    ):
        """Absolute correlation between every pair of electrodes, per filter.

        Parameters
        ----------
        x : torch.Tensor
            Filter output; reshaped here to ``(batch, n_chans, n_filters, n_times)``.
        batch, n_chans, n_filters, n_times : int
            Sizes used for that reshape.
        epilson : float
            Added to the variance before the square root to avoid division by
            zero for flat signals.

        Returns
        -------
        torch.Tensor
            ``(batch, n_chans * (n_chans - 1) // 2, n_filters)`` upper-triangular
            entries of the per-filter correlation matrices.
        """
        # -> (batch, n_filters, n_chans, n_times)
        x = x.reshape(batch, n_chans, n_filters, n_times).transpose(-3, -2)
        # Standardize each filtered signal along time.
        x = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(
            x.var(dim=-1, keepdim=True) + epilson
        )
        x = torch.matmul(x, x.transpose(-2, -1)) / x.shape[-1]
        # [batch, n_filters, chans, chans] -> [batch, chans, chans, n_filters]
        # (move filter channels to the end)
        x = x.permute(0, 2, 3, 1)
        x = x.abs()

        # Get upper triu of symmetric connectivity matrix
        return EEGMiner._upper_triangle(x, n_chans)

    @staticmethod
    def _apply_plv(x, n_chans, batch=None):
        """Phase locking value between every pair of electrodes, per filter.

        ``batch`` is unused; it keeps the common ``method_forward`` signature.

        Returns
        -------
        torch.Tensor
            ``(batch, n_chans * (n_chans - 1) // 2, n_filters)``.
        """
        # Swap electrodes and filters.
        # NOTE(review): using dims -4/-3 assumes the frequency-domain filter
        # output (inverse_fourier=False) carries a trailing real/imag axis,
        # i.e. (batch, electrodes, filters, freq, 2) — TODO confirm.
        x = x.transpose(-4, -3)
        # The input is already in the frequency domain, so no forward FFT.
        x = F.plv_time(x, forward_fourier=False)
        # [batch, n_filters, chans, chans] -> [batch, chans, chans, n_filters]
        x = x.permute(0, 2, 3, 1)

        # Get upper triu of symmetric connectivity matrix
        return EEGMiner._upper_triangle(x, n_chans)
|
|
@@ -0,0 +1,359 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import Dict, Optional
|
|
7
|
+
|
|
8
|
+
from einops.layers.torch import Rearrange
|
|
9
|
+
from mne.utils import deprecated, warn
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
from braindecode.functional import glorot_weight_zero_bias
|
|
13
|
+
from braindecode.models.base import EEGModuleMixin
|
|
14
|
+
from braindecode.modules import (
|
|
15
|
+
Conv2dWithConstraint,
|
|
16
|
+
Ensure4d,
|
|
17
|
+
LinearWithConstraint,
|
|
18
|
+
SqueezeFinalOutput,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class EEGNet(EEGModuleMixin, nn.Sequential):
    r"""EEGNet model from Lawhern et al (2018) [Lawhern2018]_.

    :bdg-success:`Convolution`

    .. figure:: https://content.cld.iop.org/journals/1741-2552/15/5/056013/revision2/jneaace8cf01_hr.jpg
        :align: center
        :alt: EEGNet Architecture
        :width: 600px

    .. rubric:: Architectural Overview

    EEGNet is a compact convolutional network designed for EEG decoding with a
    pipeline that mirrors classical EEG processing:

    - (i) learn temporal frequency-selective filters,
    - (ii) learn spatial filters for those frequencies, and
    - (iii) condense features with depthwise-separable convolutions before a
      lightweight classifier.

    The architecture is deliberately small, with interpretable temporal
    convolutional and spatial patterns [Lawhern2018]_.

    .. rubric:: Macro Components

    - **Temporal convolution.**
      Temporal convolution applied per channel; learns ``F1`` kernels that act
      as data-driven band-pass filters.
    - **Depthwise Spatial Filtering.**
      Depthwise convolution spanning the channel dimension with
      ``groups = F1``, yielding ``D`` spatial filters for each temporal filter
      (no cross-filter mixing).
    - **Norm-Nonlinearity-Pooling (+ dropout).**
      Batch normalization → activation (ELU by default) → temporal pooling,
      with dropout.
    - **Depthwise-Separable Convolution Block.**
      (a) depthwise temporal conv to refine temporal structure;
      (b) pointwise 1x1 conv to mix feature maps into ``F2`` combinations.
    - **Classifier Head.**
      By default a convolution whose kernel spans the remaining feature map
      (its full temporal extent when ``final_conv_length="auto"``); with
      ``final_layer_with_constraint=True``, a flattened dense layer with a
      max-norm constraint.

    .. rubric:: Convolutional Details

    - **Temporal.** The initial temporal convs serve as a *learned filter bank*:
      long 1-D kernels (implemented as 2-D with singleton spatial extent)
      emphasize oscillatory bands and transients. Because this stage is linear
      prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each
      feature's spectrum [Lawhern2018]_.

    - **Spatial.** The depthwise spatial conv spans the full channel axis
      (kernel height = #electrodes; temporal size = 1). With ``groups = F1``,
      each temporal filter learns its own set of ``D`` spatial projections—akin
      to CSP, learned end-to-end and regularized with a max-norm constraint.

    - **Spectral.** No explicit Fourier/wavelet transform is used. Frequency
      structure is captured implicitly by the temporal filter bank; later
      depthwise temporal kernels act as short-time integrators/refiners.

    .. rubric:: Additional Comments

    - **Filter-bank structure:** Parallel temporal kernels (``F1``) emulate
      classical filter banks; pairing them with frequency-specific spatial
      filters yields features mappable to rhythms and topographies.
    - **Depthwise & separable convs:** Parameter-efficient decomposition
      (depthwise + pointwise) retains power while limiting overfitting
      [Chollet2017]_ and keeps temporal vs. mixing steps interpretable.
    - **Regularization:** Batch norm, dropout, pooling, and max-norm on the
      spatial kernels aid stability on small EEG datasets.
    - The v4 refers to version 4 of the arXiv paper [Lawhern2018]_.


    Parameters
    ----------
    final_conv_length : int or "auto", default="auto"
        Length of the final convolution layer. If "auto", it is set to the
        temporal length of the feature extractor's output (requires n_times).
    pool_mode : {"mean", "max"}, default="mean"
        Pooling method to use in pooling layers.
    F1 : int, default=8
        Number of temporal filters in the first convolutional layer.
    D : int, default=2
        Depth multiplier for the depthwise convolution.
    F2 : int or None, default=None
        Number of pointwise filters in the separable convolution. If None, set
        to ``F1 * D``.
    depthwise_kernel_length : int, default=16
        Length of the depthwise convolution kernel in the separable convolution.
    pool1_kernel_size : int, default=4
        Kernel size of the first pooling layer.
    pool2_kernel_size : int, default=8
        Kernel size of the second pooling layer.
    kernel_length : int, default=64
        Length of the temporal convolution kernel.
    conv_spatial_max_norm : float, default=1
        Maximum norm constraint for the spatial (depthwise) convolution.
    activation : nn.Module, default=nn.ELU
        Non-linear activation function to be used in the layers.
    batch_norm_momentum : float, default=0.01
        Momentum used to update the running mean/variance of the batch norm
        layers.
    batch_norm_affine : bool, default=True
        If True, batch norm has learnable affine parameters.
    batch_norm_eps : float, default=1e-3
        Epsilon for numeric stability in batch norm layers.
    drop_prob : float, default=0.25
        Dropout probability.
    final_layer_with_constraint : bool, default=False
        If ``False``, uses a convolution-based classification layer (this
        setting is deprecated and emits a ``DeprecationWarning``). If ``True``,
        apply a flattened linear layer with constraint on the weights norm as
        the final classification step.
    norm_rate : float, default=0.25
        Max-norm constraint value for the linear layer (used only if
        ``final_layer_with_constraint=True``).
    **kwargs
        Only the deprecated ``third_kernel_size`` is accepted (it is ignored,
        with a warning); any other keyword raises ``TypeError``.

    References
    ----------
    .. [Lawhern2018] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon, S. M.,
       Hung, C. P., & Lance, B. J. (2018). EEGNet: a compact convolutional
       neural network for EEG-based brain–computer interfaces. Journal of
       neural engineering, 15(5), 056013.
    .. [Chollet2017] Chollet, F., *Xception: Deep Learning with Depthwise Separable
       Convolutions*, CVPR, 2017.

    """

    def __init__(
        self,
        # signal's parameters
        n_chans: Optional[int] = None,
        n_outputs: Optional[int] = None,
        n_times: Optional[int] = None,
        # model's parameters
        final_conv_length: str | int = "auto",
        pool_mode: str = "mean",
        F1: int = 8,
        D: int = 2,
        F2: Optional[int] = None,
        kernel_length: int = 64,
        *,
        depthwise_kernel_length: int = 16,
        pool1_kernel_size: int = 4,
        pool2_kernel_size: int = 8,
        conv_spatial_max_norm: float = 1,
        activation: type[nn.Module] = nn.ELU,
        batch_norm_momentum: float = 0.01,
        batch_norm_affine: bool = True,
        batch_norm_eps: float = 1e-3,
        drop_prob: float = 0.25,
        final_layer_with_constraint: bool = False,
        norm_rate: float = 0.25,
        # Other ways to construct the signal related parameters
        chs_info: Optional[list[Dict]] = None,
        input_window_seconds=None,
        sfreq=None,
        **kwargs,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        if final_conv_length == "auto":
            # "auto" derives the classifier size from the input length.
            assert self.n_times is not None

        if not final_layer_with_constraint:
            # NOTE: False is the default, so default construction warns too.
            # The replacement is this same flag set to True (there is no
            # separate ``final_layer_linear`` parameter).
            warn(
                "Parameter 'final_layer_with_constraint=False' is deprecated and will be "
                "removed in a future release. Please use `final_layer_with_constraint=True`.",
                DeprecationWarning,
            )

        if "third_kernel_size" in kwargs:
            warn(
                "The parameter `third_kernel_size` is deprecated "
                "and will be removed in a future version.",
            )
        unexpected_kwargs = set(kwargs) - {"third_kernel_size"}
        if unexpected_kwargs:
            raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}")

        self.final_conv_length = final_conv_length
        self.pool_mode = pool_mode
        self.F1 = F1
        self.D = D

        if F2 is None:
            F2 = self.F1 * self.D
        self.F2 = F2

        self.kernel_length = kernel_length
        self.depthwise_kernel_length = depthwise_kernel_length
        self.pool1_kernel_size = pool1_kernel_size
        self.pool2_kernel_size = pool2_kernel_size
        self.drop_prob = drop_prob
        self.activation = activation
        self.batch_norm_momentum = batch_norm_momentum
        self.batch_norm_affine = batch_norm_affine
        self.batch_norm_eps = batch_norm_eps
        self.conv_spatial_max_norm = conv_spatial_max_norm
        self.norm_rate = norm_rate

        # For the load_state_dict:
        # once all layers are standardized, add the old parameter names here
        # (old name -> new name).
        self.mapping = {
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
        self.add_module("ensuredims", Ensure4d())

        # (batch, ch, t, 1) -> (batch, 1, ch, t): electrodes become the
        # "height" axis so temporal kernels are (1, k) and spatial ones (ch, 1).
        self.add_module("dimshuffle", Rearrange("batch ch t 1 -> batch 1 ch t"))
        self.add_module(
            "conv_temporal",
            nn.Conv2d(
                1,
                self.F1,
                (1, self.kernel_length),
                bias=False,
                padding=(0, self.kernel_length // 2),
            ),
        )
        self.add_module(
            "bnorm_temporal",
            nn.BatchNorm2d(
                self.F1,
                momentum=self.batch_norm_momentum,
                affine=self.batch_norm_affine,
                eps=self.batch_norm_eps,
            ),
        )
        # Depthwise spatial filter: kernel spans all electrodes, so the
        # electrode axis collapses to size 1 here.
        self.add_module(
            "conv_spatial",
            Conv2dWithConstraint(
                in_channels=self.F1,
                out_channels=self.F1 * self.D,
                kernel_size=(self.n_chans, 1),
                max_norm=self.conv_spatial_max_norm,
                bias=False,
                groups=self.F1,
            ),
        )

        self.add_module(
            "bnorm_1",
            nn.BatchNorm2d(
                self.F1 * self.D,
                momentum=self.batch_norm_momentum,
                affine=self.batch_norm_affine,
                eps=self.batch_norm_eps,
            ),
        )
        self.add_module("elu_1", activation())

        self.add_module(
            "pool_1",
            pool_class(
                kernel_size=(1, self.pool1_kernel_size),
            ),
        )
        self.add_module("drop_1", nn.Dropout(p=self.drop_prob))

        # https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7
        self.add_module(
            "conv_separable_depth",
            nn.Conv2d(
                self.F1 * self.D,
                self.F1 * self.D,
                (1, self.depthwise_kernel_length),
                bias=False,
                groups=self.F1 * self.D,
                padding=(0, self.depthwise_kernel_length // 2),
            ),
        )
        self.add_module(
            "conv_separable_point",
            nn.Conv2d(
                self.F1 * self.D,
                self.F2,
                kernel_size=(1, 1),
                bias=False,
            ),
        )

        self.add_module(
            "bnorm_2",
            nn.BatchNorm2d(
                self.F2,
                momentum=self.batch_norm_momentum,
                affine=self.batch_norm_affine,
                eps=self.batch_norm_eps,
            ),
        )
        self.add_module("elu_2", self.activation())
        self.add_module(
            "pool_2",
            pool_class(
                kernel_size=(1, self.pool2_kernel_size),
            ),
        )
        self.add_module("drop_2", nn.Dropout(p=self.drop_prob))

        # Shape of the feature extractor's output: (batch, F2, virtual chans, time).
        output_shape = self.get_output_shape()
        n_out_virtual_chans = output_shape[2]
        n_out_time = output_shape[3]

        if self.final_conv_length == "auto":
            self.final_conv_length = n_out_time

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()
        if not final_layer_with_constraint:
            module.add_module(
                "conv_classifier",
                nn.Conv2d(
                    self.F2,
                    self.n_outputs,
                    (n_out_virtual_chans, self.final_conv_length),
                    bias=True,
                ),
            )

            # Transpose back to the logic of braindecode,
            # so time in third dimension (axis=2)
            module.add_module(
                "permute_back",
                Rearrange("batch x y z -> batch x z y"),
            )

            module.add_module("squeeze", SqueezeFinalOutput())
        else:
            module.add_module("flatten", nn.Flatten())
            # Size the dense layer from the actual flattened feature map so an
            # explicit int ``final_conv_length`` that differs from the temporal
            # output length no longer produces a shape mismatch. With "auto"
            # this equals F2 * final_conv_length, as n_out_virtual_chans is 1
            # after the full-height spatial convolution.
            module.add_module(
                "linearconstraint",
                LinearWithConstraint(
                    in_features=self.F2 * n_out_virtual_chans * n_out_time,
                    out_features=self.n_outputs,
                    max_norm=norm_rate,
                ),
            )
        self.add_module("final_layer", module)

        glorot_weight_zero_bias(self)
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
@deprecated(
    "`EEGNetv4` was renamed to `EEGNet` in v1.12; this alias will be removed in v1.14."
)
class EEGNetv4(EEGNet):
    r"""Deprecated alias for :class:`EEGNet`.

    Kept only for backward compatibility. It adds nothing on top of
    :class:`EEGNet`; new code should use :class:`EEGNet` directly.
    """
|