braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
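The listing shows the shared building blocks moving out of braindecode/models/modules.py and braindecode/models/functions.py into the new braindecode.modules and braindecode.functional subpackages. A minimal sketch of how downstream imports change, assuming the symbols keep their names (the new eegminer.py shown below imports them exactly this way):

    # Before (0.8.1): helper layers and functions lived under braindecode.models
    # from braindecode.models.modules import ...
    # from braindecode.models.functions import ...

    # After (1.1.0): dedicated subpackages, as used by the new eegminer.py below
    import braindecode.functional as F
    from braindecode.modules import GeneralizedGaussianFilter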
braindecode/models/eegminer.py (new file)
@@ -0,0 +1,254 @@
+"""
+* Copyright (C) Cogitat, Ltd.
+* Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
+* Patent GB2609265 - Learnable filters for eeg classification
+* https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
+"""
+
+from functools import partial
+
+import torch
+from einops.layers.torch import Rearrange
+from torch import nn
+
+import braindecode.functional as F
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import GeneralizedGaussianFilter
+
+_eeg_miner_methods = ["mag", "corr", "plv"]
+
+
+class EEGMiner(EEGModuleMixin, nn.Module):
+    """EEGMiner from Ludwig et al (2024) [eegminer]_.
+
+    .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036010/revision2/jnead44d7f1_hr.jpg
+       :align: center
+       :alt: EEGMiner Architecture
+
+    EEGMiner is a neural network model for EEG signal classification using
+    learnable generalized Gaussian filters. The model leverages frequency domain
+    filtering and connectivity metrics or feature extraction, such as Phase Locking
+    Value (PLV), to extract meaningful features from EEG data, enabling effective
+    classification tasks.
+
+    The model has the following steps:
+
+    - **Generalized Gaussian** filters in the frequency domain to the input EEG signals.
+
+    - **Connectivity estimators** (corr, plv) or **Electrode-Wise Band Power** (mag), by default (plv).
+        - `'corr'`: Computes the correlation of the filtered signals.
+        - `'plv'`: Computes the phase locking value of the filtered signals.
+        - `'mag'`: Computes the magnitude of the filtered signals.
+
+    - **Feature Normalization**
+        - Apply batch normalization.
+
+    - **Final Layer**
+        - Feeds the batch-normalized features into a final linear layer for classification.
+
+    Depending on the selected method (`mag`, `corr`, or `plv`),
+    it computes the filtered signals' magnitude, correlation, or phase locking value.
+    These features are then normalized and passed through a batch normalization layer
+    before being fed into a final linear layer for classification.
+
+    The input to EEGMiner should be a three-dimensional tensor representing EEG signals:
+
+    ``(batch_size, n_channels, n_timesteps)``.
+
+    Notes
+    -----
+    EEGMiner incorporates learnable parameters for filter characteristics, allowing the
+    model to adaptively learn optimal frequency bands and phase delays for the classification task.
+    By default, using the PLV as a connectivity metric makes EEGMiner suitable for tasks requiring
+    the analysis of phase relationships between different EEG channels.
+
+    The model and the module have patent [eegminercode]_, and the code is CC BY-NC 4.0.
+
+    .. versionadded:: 0.9
+
+    Parameters
+    ----------
+    method : str, default="plv"
+        The method used for feature extraction. Options are:
+        - "mag": Electrode-Wise band power of the filtered signals.
+        - "corr": Correlation between filtered channels.
+        - "plv": Phase Locking Value connectivity metric.
+    filter_f_mean : list of float, default=[23.0, 23.0]
+        Mean frequencies for the generalized Gaussian filters.
+    filter_bandwidth : list of float, default=[44.0, 44.0]
+        Bandwidths for the generalized Gaussian filters.
+    filter_shape : list of float, default=[2.0, 2.0]
+        Shape parameters for the generalized Gaussian filters.
+    group_delay : tuple of float, default=(20.0, 20.0)
+        Group delay values for the filters in milliseconds.
+    clamp_f_mean : tuple of float, default=(1.0, 45.0)
+        Clamping range for the mean frequency parameters.
+
+    References
+    ----------
+    .. [eegminer] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
+       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
+       of brain activity with learnable filters. Journal of Neural Engineering,
+       21(3), 036010.
+    .. [eegminercode] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis,
+       Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features
+       of brain activity with learnable filters.
+       https://github.com/SMLudwig/EEGminer/.
+       Cogitat, Ltd. "Learnable filters for EEG classification."
+       Patent GB2609265.
+       https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
+    """
+
+    def __init__(
+        self,  # Signal related parameters
+        method: str = "plv",
+        n_chans=None,
+        n_outputs=None,
+        n_times=None,
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
+        # model related
+        filter_f_mean=(23.0, 23.0),
+        filter_bandwidth=(44.0, 44.0),
+        filter_shape=(2.0, 2.0),
+        group_delay=(20.0, 20.0),
+        clamp_f_mean=(1.0, 45.0),
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+        # Initialize filter parameters
+        self.filter_f_mean = filter_f_mean
+        self.filter_bandwidth = filter_bandwidth
+        self.filter_shape = filter_shape
+        self.n_filters = len(self.filter_f_mean)
+        self.group_delay = group_delay
+        self.clamp_f_mean = clamp_f_mean
+        self.method = method.lower()
+
+        if self.method not in _eeg_miner_methods:
+            raise ValueError(
+                f"The method {self.method} is not one of the valid options"
+                f" {_eeg_miner_methods}"
+            )
+
+        if self.method == "mag" or self.method == "corr":
+            inverse_fourier = True
+            in_channels = self.n_chans
+            out_channels = self.n_chans * self.n_filters
+        else:
+            inverse_fourier = False
+            in_channels = 1
+            out_channels = 1 * self.n_filters
+
+        # Generalized Gaussian Filter
+        self.filter = GeneralizedGaussianFilter(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            sequence_length=self.n_times,
+            sample_rate=self.sfreq,
+            f_mean=self.filter_f_mean,
+            bandwidth=self.filter_bandwidth,
+            shape=self.filter_shape,
+            affine_group_delay=False,
+            inverse_fourier=inverse_fourier,
+            group_delay=self.group_delay,
+            clamp_f_mean=self.clamp_f_mean,
+        )
+
+        # Forward method
+        if self.method == "mag":
+            self.method_forward = self._apply_mag_forward
+            self.n_features = self.n_chans * self.n_filters
+            self.ensure_dim = nn.Identity()
+        elif self.method == "corr":
+            self.method_forward = partial(
+                self._apply_corr_forward,
+                n_chans=self.n_chans,
+                n_filters=self.n_filters,
+                n_times=self.n_times,
+            )
+            self.n_features = self.n_filters * self.n_chans * (self.n_chans - 1) // 2
+            self.ensure_dim = nn.Identity()
+        elif self.method == "plv":
+            self.method_forward = partial(self._apply_plv, n_chans=self.n_chans)
+            self.ensure_dim = Rearrange("... d -> ... 1 d")
+            self.n_features = (self.n_filters * self.n_chans * (self.n_chans - 1)) // 2
+
+        self.flatten_layer = nn.Flatten()
+        # Classifier
+        self.batch_layer = nn.BatchNorm1d(self.n_features, affine=False)
+        self.final_layer = nn.Linear(self.n_features, self.n_outputs)
+        nn.init.zeros_(self.final_layer.bias)
+
+    def forward(self, x):
+        """x: (batch, electrodes, time)"""
+        batch = x.shape[0]
+        x = self.ensure_dim(x)
+        # Apply Gaussian filters in frequency domain
+        # x -> (batch, electrodes * filters, time)
+        x = self.filter(x)
+
+        x = self.method_forward(x=x, batch=batch)
+        # Classifier
+        # Note that the order of dimensions before flattening the feature vector is important
+        # for attributing feature weights during interpretation.
+        x = x.reshape(batch, self.n_features)
+        x = self.batch_layer(x)
+        x = self.final_layer(x)
+
+        return x
+
+    @staticmethod
+    def _apply_mag_forward(x, batch=None):
+        # Signal magnitude
+        x = x * x
+        x = x.mean(dim=-1)
+        x = torch.sqrt(x)
+        return x
+
+    @staticmethod
+    def _apply_corr_forward(
+        x, batch, n_chans, n_filters, n_times, epilson: float = 1e-6
+    ):
+        x = x.reshape(batch, n_chans, n_filters, n_times).transpose(-3, -2)
+        x = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(
+            x.var(dim=-1, keepdim=True) + epilson
+        )
+        x = torch.matmul(x, x.transpose(-2, -1)) / x.shape[-1]
+        # Original tensor shape: [batch, n_filters, chans, chans]
+        x = x.permute(0, 2, 3, 1)
+        # New tensor shape: [batch, chans, chans, n_filters]
+        # move filter channels to the end
+        x = x.abs()
+
+        # Get upper triu of symmetric connectivity matrix
+        triu = torch.triu_indices(n_chans, n_chans, 1)
+        x = x[:, triu[0], triu[1], :]
+
+        return x
+
+    @staticmethod
+    def _apply_plv(x, n_chans, batch=None):
+        # Compute PLV connectivity
+        # x -> (batch, electrodes, electrodes, filters)
+        x = x.transpose(-4, -3)  # swap electrodes and filters
+        # adjusting to compute the plv
+        x = F.plv_time(x, forward_fourier=False)
+        # batch, number of filters, connectivity matrix
+        # [batch, n_filters, chans, chans]
+        x = x.permute(0, 2, 3, 1)
+        # [batch, chans, chans, n_filters]
+
+        # Get upper triu of symmetric connectivity matrix
+        triu = torch.triu_indices(n_chans, n_chans, 1)
+        x = x[:, triu[0], triu[1], :]
+        return x
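For orientation, a minimal usage sketch of the class added above (not part of the diff). It assumes EEGMiner is exported from braindecode.models like the other model classes; the channel count, window length, and sampling rate below are hypothetical:

    # Minimal sketch: instantiate EEGMiner and run a forward pass.
    import torch
    from braindecode.models import EEGMiner  # assumed export location

    # Hypothetical setup: 22-channel EEG, 2 s windows at 128 Hz, binary task.
    model = EEGMiner(
        method="plv",   # or "corr" / "mag"
        n_chans=22,
        n_outputs=2,
        n_times=256,
        sfreq=128,
    )

    x = torch.randn(8, 22, 256)   # (batch_size, n_channels, n_timesteps)
    logits = model(x)             # -> shape (8, 2)

With method="plv", the features fed to the final linear layer are the upper-triangular PLV entries for each of the n_filters learnable filters, i.e. n_filters * n_chans * (n_chans - 1) / 2 values per window, as computed in the __init__ above.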