braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import OrderedDict
|
|
4
|
+
from math import floor, log2
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
|
|
11
|
+
from braindecode.models.base import EEGModuleMixin
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SPARCNet(EEGModuleMixin, nn.Module):
    """Seizures, Periodic and Rhythmic pattern Continuum Neural Network (SPaRCNet) from Jing et al. (2023) [jing2023]_.

    This is a temporal CNN model for biosignal classification based on the DenseNet
    architecture.

    The model is based on the unofficial implementation [Code2023]_.

    .. versionadded:: 0.9

    Notes
    -----
    This implementation is not guaranteed to be correct, has not been checked
    by original authors.

    The encoder depth is derived from ``n_times``. The stem (a stride-2
    convolution followed by a stride-2 max pooling) reduces the temporal
    length to ``ceil(n_times / 4)``. One dense block plus transition layer
    (which halves the length) is then added per power of two contained in
    that length, so the encoder output has a temporal length of exactly 1.

    Parameters
    ----------
    block_layers : int, optional
        Number of layers per dense block. Default is 4.
    growth_rate : int, optional
        Growth rate of the DenseNet. Default is 16.
    bottleneck_size : int, optional
        Bottleneck size (multiplicative factor of ``growth_rate`` for the
        1x1 bottleneck convolution). Default is 16.
    drop_prob : float, optional
        Dropout rate. Default is 0.5.
    conv_bias : bool, optional
        Whether to use bias in convolutional layers. Default is True.
    batch_norm : bool, optional
        Whether to use batch normalization in the dense blocks and transition
        layers. The stem convolution is always followed by batch
        normalization. Default is True.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    References
    ----------
    .. [jing2023] Jing, J., Ge, W., Hong, S., Fernandes, M. B., Lin, Z.,
       Yang, C., ... & Westover, M. B. (2023). Development of expert-level
       classification of seizures and rhythmic and periodic
       patterns during eeg interpretation. Neurology, 100(17), e1750-e1762.
    .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT
       Biosignal Transformer for Cross-data Learning in the Wild.
       GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13)

    """

    def __init__(
        self,
        n_chans=None,
        n_times=None,
        n_outputs=None,
        # Neural network parameters
        block_layers: int = 4,
        growth_rate: int = 16,
        bottleneck_size: int = 16,
        drop_prob: float = 0.5,
        conv_bias: bool = True,
        batch_norm: bool = True,
        activation: nn.Module = nn.ELU,
        # EEGModuleMixin parameters
        # (another way to present the same parameters)
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        # add initial convolutional layer
        # the number of output channels is the smallest power of 2
        # that is strictly greater than the number of input channels
        out_channels = 2 ** (floor(log2(self.n_chans)) + 1)
        first_conv = OrderedDict(
            [
                (
                    "conv0",
                    nn.Conv1d(
                        in_channels=self.n_chans,
                        out_channels=out_channels,
                        kernel_size=7,
                        stride=2,
                        padding=3,
                        bias=conv_bias,
                    ),
                )
            ]
        )
        first_conv["norm0"] = nn.BatchNorm1d(out_channels)
        first_conv["act_layer"] = activation()
        first_conv["pool0"] = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        self.encoder = nn.Sequential(first_conv)

        n_channels = out_channels

        # Temporal length after the stem: the stride-2 conv (k=7, p=3) and the
        # stride-2 max pool (k=3, p=1) each map L -> ceil(L / 2), so the stem
        # outputs ceil(n_times / 4) samples. Counting blocks from
        # ``n_times // 4`` would undercount whenever ceil(n_times / 4) is a
        # power of two, leaving a temporal length of 2 that ``squeeze(-1)``
        # in ``forward`` cannot remove.
        n_times_after_stem = (self.n_times + 3) // 4
        # Each transition layer halves the length (floor), so this many blocks
        # bring it down to exactly 1.
        n_blocks = floor(log2(n_times_after_stem))

        # Adding dense blocks
        for n_layer in range(n_blocks):
            block = _DenseBlock(
                num_layers=block_layers,
                in_channels=n_channels,
                growth_rate=growth_rate,
                bottleneck_size=bottleneck_size,
                drop_prob=drop_prob,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
                activation=activation,
            )
            self.encoder.add_module("denseblock%d" % (n_layer + 1), block)
            # update the number of channels after each dense block
            n_channels = n_channels + block_layers * growth_rate

            trans = _TransitionLayer(
                in_channels=n_channels,
                out_channels=n_channels // 2,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
                activation=activation,
            )
            self.encoder.add_module("transition%d" % (n_layer + 1), trans)
            # update the number of channels after each transition layer
            n_channels = n_channels // 2

        # add final classification layer
        self.final_layer = nn.Sequential(
            activation(),
            nn.Linear(n_channels, self.n_outputs),
        )

        self._init_weights()

    def _init_weights(self):
        """
        Initialize the weights of the model.

        Follows the DenseNet initialization from the torchvision repo:
        ``kaiming_normal_`` for the weights of ``Conv1d`` layers (their biases
        keep PyTorch's default init), weight 1 and bias 0 for ``BatchNorm1d``
        layers, and bias 0 for ``Linear`` layers (their weights keep PyTorch's
        default init).

        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight.data)
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, X: torch.Tensor):
        """
        Forward pass of the model.

        Parameters
        ----------
        X: torch.Tensor
            The input tensor of the model with shape (batch_size, n_channels, n_times)

        Returns
        -------
        torch.Tensor
            The output tensor of the model with shape (batch_size, n_outputs)
        """
        # The encoder reduces the temporal axis to length 1 (see __init__),
        # which is dropped before the linear classifier.
        emb = self.encoder(X).squeeze(-1)
        out = self.final_layer(emb)
        return out
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class _DenseLayer(nn.Sequential):
|
|
187
|
+
"""
|
|
188
|
+
A densely connected layer with batch normalization and dropout.
|
|
189
|
+
|
|
190
|
+
Parameters
|
|
191
|
+
----------
|
|
192
|
+
in_channels : int
|
|
193
|
+
Number of input channels.
|
|
194
|
+
growth_rate : int
|
|
195
|
+
Rate of growth of channels in this layer.
|
|
196
|
+
bottleneck_size : int
|
|
197
|
+
Multiplicative factor for the bottleneck layer (does not affect the output size).
|
|
198
|
+
drop_prob : float, optional
|
|
199
|
+
Dropout rate. Default is 0.5.
|
|
200
|
+
conv_bias : bool, optional
|
|
201
|
+
Whether to use bias in convolutional layers. Default is True.
|
|
202
|
+
batch_norm : bool, optional
|
|
203
|
+
Whether to use batch normalization. Default is True.
|
|
204
|
+
activation: nn.Module, default=nn.ELU
|
|
205
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
206
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
207
|
+
|
|
208
|
+
Examples
|
|
209
|
+
--------
|
|
210
|
+
>>> x = torch.randn(128, 5, 1000)
|
|
211
|
+
>>> batch, channels, length = x.shape
|
|
212
|
+
>>> model = _DenseLayer(channels, 5, 2)
|
|
213
|
+
>>> y = model(x)
|
|
214
|
+
>>> y.shape
|
|
215
|
+
torch.Size([128, 10, 1000])
|
|
216
|
+
"""
|
|
217
|
+
|
|
218
|
+
def __init__(
|
|
219
|
+
self,
|
|
220
|
+
in_channels: int,
|
|
221
|
+
growth_rate: int,
|
|
222
|
+
bottleneck_size: int,
|
|
223
|
+
drop_prob: float = 0.5,
|
|
224
|
+
conv_bias: bool = True,
|
|
225
|
+
batch_norm: bool = True,
|
|
226
|
+
activation: nn.Module = nn.ELU,
|
|
227
|
+
):
|
|
228
|
+
super().__init__()
|
|
229
|
+
if batch_norm:
|
|
230
|
+
self.add_module("norm1", nn.BatchNorm1d(in_channels))
|
|
231
|
+
|
|
232
|
+
self.add_module("elu1", activation())
|
|
233
|
+
self.add_module(
|
|
234
|
+
"conv1",
|
|
235
|
+
nn.Conv1d(
|
|
236
|
+
in_channels=in_channels,
|
|
237
|
+
out_channels=bottleneck_size * growth_rate,
|
|
238
|
+
kernel_size=1,
|
|
239
|
+
stride=1,
|
|
240
|
+
bias=conv_bias,
|
|
241
|
+
),
|
|
242
|
+
)
|
|
243
|
+
if batch_norm:
|
|
244
|
+
self.add_module("norm2", nn.BatchNorm1d(bottleneck_size * growth_rate))
|
|
245
|
+
self.add_module("elu2", activation())
|
|
246
|
+
self.add_module(
|
|
247
|
+
"conv2",
|
|
248
|
+
nn.Conv1d(
|
|
249
|
+
in_channels=bottleneck_size * growth_rate,
|
|
250
|
+
out_channels=growth_rate,
|
|
251
|
+
kernel_size=3,
|
|
252
|
+
stride=1,
|
|
253
|
+
padding=1,
|
|
254
|
+
bias=conv_bias,
|
|
255
|
+
),
|
|
256
|
+
)
|
|
257
|
+
self.drop_prob = drop_prob
|
|
258
|
+
|
|
259
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
260
|
+
# Manually pass through each submodule
|
|
261
|
+
out = x
|
|
262
|
+
for layer in self:
|
|
263
|
+
out = layer(out)
|
|
264
|
+
# apply dropout using the functional API
|
|
265
|
+
out = F.dropout(out, p=self.drop_prob, training=self.training)
|
|
266
|
+
# concatenate input and new features
|
|
267
|
+
return torch.cat([x, out], dim=1)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class _DenseBlock(nn.Sequential):
    """
    Stack of ``_DenseLayer`` modules forming one dense block.

    Each layer receives every feature map produced before it, so layer ``i``
    (0-based) takes ``in_channels + i * growth_rate`` input channels. The
    block outputs ``in_channels + num_layers * growth_rate`` channels.

    Parameters
    ----------
    num_layers : int
        Number of layers in this block.
    in_channels : int
        Number of input channels.
    growth_rate : int
        Number of feature maps added by each layer.
    bottleneck_size : int
        Multiplicative factor for the bottleneck layer (does not affect the output size).
    drop_prob : float, optional
        Dropout rate. Default is 0.5.
    conv_bias : bool, optional
        Whether to use bias in convolutional layers. Default is True.
    batch_norm : bool, optional
        Whether to use batch normalization. Default is True.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Examples
    --------
    >>> x = torch.randn(128, 5, 1000)
    >>> batch, channels, length = x.shape
    >>> model = _DenseBlock(3, channels, 5, 2)
    >>> y = model(x)
    >>> y.shape
    torch.Size([128, 20, 1000])
    """

    def __init__(
        self,
        num_layers,
        in_channels,
        growth_rate,
        bottleneck_size,
        drop_prob=0.5,
        conv_bias=True,
        batch_norm=True,
        activation: nn.Module = nn.ELU,
    ):
        super().__init__()
        running_channels = in_channels
        for position in range(1, num_layers + 1):
            self.add_module(
                f"denselayer{position}",
                _DenseLayer(
                    in_channels=running_channels,
                    growth_rate=growth_rate,
                    bottleneck_size=bottleneck_size,
                    drop_prob=drop_prob,
                    conv_bias=conv_bias,
                    batch_norm=batch_norm,
                    activation=activation,
                ),
            )
            # every dense layer appends ``growth_rate`` maps to its input
            running_channels += growth_rate
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
class _TransitionLayer(nn.Sequential):
|
|
330
|
+
"""
|
|
331
|
+
A pooling transition layer.
|
|
332
|
+
|
|
333
|
+
Parameters
|
|
334
|
+
----------
|
|
335
|
+
in_channels : int
|
|
336
|
+
Number of input channels.
|
|
337
|
+
out_channels : int
|
|
338
|
+
Number of output channels.
|
|
339
|
+
conv_bias : bool, optional
|
|
340
|
+
Whether to use bias in convolutional layers. Default is True.
|
|
341
|
+
batch_norm : bool, optional
|
|
342
|
+
Whether to use batch normalization. Default is True.
|
|
343
|
+
activation: nn.Module, default=nn.ELU
|
|
344
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
345
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
346
|
+
|
|
347
|
+
Examples
|
|
348
|
+
--------
|
|
349
|
+
>>> x = torch.randn(128, 5, 1000)
|
|
350
|
+
>>> model = _TransitionLayer(5, 18)
|
|
351
|
+
>>> y = model(x)
|
|
352
|
+
>>> y.shape
|
|
353
|
+
torch.Size([128, 18, 500])
|
|
354
|
+
"""
|
|
355
|
+
|
|
356
|
+
def __init__(
|
|
357
|
+
self,
|
|
358
|
+
in_channels,
|
|
359
|
+
out_channels,
|
|
360
|
+
conv_bias=True,
|
|
361
|
+
batch_norm=True,
|
|
362
|
+
activation: nn.Module = nn.ELU,
|
|
363
|
+
):
|
|
364
|
+
super(_TransitionLayer, self).__init__()
|
|
365
|
+
if batch_norm:
|
|
366
|
+
self.add_module("norm", nn.BatchNorm1d(in_channels))
|
|
367
|
+
self.add_module("elu", activation())
|
|
368
|
+
self.add_module(
|
|
369
|
+
"conv",
|
|
370
|
+
nn.Conv1d(
|
|
371
|
+
in_channels=in_channels,
|
|
372
|
+
out_channels=out_channels,
|
|
373
|
+
kernel_size=1,
|
|
374
|
+
stride=1,
|
|
375
|
+
bias=conv_bias,
|
|
376
|
+
),
|
|
377
|
+
)
|
|
378
|
+
self.add_module("pool", nn.AvgPool1d(kernel_size=2, stride=2))
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
Model,Paradigm,Type,Freq(Hz),Hyperparameters,#Parameters,get_#Parameters
|
|
2
|
+
ATCNet,General,Classification,250,"n_chans, n_outputs, n_times",113732,"ATCNet(n_chans=22, n_outputs=4, n_times=1000)"
|
|
3
|
+
AttentionBaseNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",3692,"AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)"
|
|
4
|
+
BDTCN,Normal/Abnormal,Classification,100,"n_chans, n_outputs, n_times",456502,"BDTCN(n_chans=21, n_outputs=2, n_times=6000, n_blocks=5, n_filters=55, kernel_size=16)"
|
|
5
|
+
BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)"
|
|
6
|
+
ContraWR,Sleep Staging,"Classification, Embedding",125,"n_chans, n_outputs, sfreq",1160165,"ContraWR(n_chans=2, n_outputs=5, n_times=3750, emb_size=256, sfreq=125)"
|
|
7
|
+
CTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",26900,"CTNet(n_chans=22, n_outputs=4, n_times=1000, n_filters_time=8, kernel_size=16, heads=2, emb_size=16)"
|
|
8
|
+
Deep4Net,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",282879,"Deep4Net(n_chans=22, n_outputs=4, n_times=1000)"
|
|
9
|
+
DeepSleepNet,Sleep Staging,Classification,256,"n_chans, n_outputs",24744837,"DeepSleepNet(n_chans=1, n_outputs=5, n_times=7680, sfreq=256)"
|
|
10
|
+
EEGConformer,General,Classification,250,"n_chans, n_outputs, n_times",789572,"EEGConformer(n_chans=22, n_outputs=4, n_times=1000)"
|
|
11
|
+
EEGInceptionERP,"ERP, SSVEP",Classification,128,"n_chans, n_outputs",14926,"EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128)"
|
|
12
|
+
EEGInceptionMI,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",558028,"EEGInceptionMI(n_chans=22, n_outputs=4, n_times=1000, n_convs=5, n_filters=12)"
|
|
13
|
+
EEGITNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",5212,"EEGITNet(n_chans=22, n_outputs=4, n_times=500)"
|
|
14
|
+
EEGNetv1,General,Classification,128,"n_chans, n_outputs, n_times",3052,"EEGNetv1(n_chans=22, n_outputs=4, n_times=512)"
|
|
15
|
+
EEGNetv4,General,Classification,128,"n_chans, n_outputs, n_times",2484,"EEGNetv4(n_chans=22, n_outputs=4, n_times=512)"
|
|
16
|
+
EEGNeX,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",55940,"EEGNeX(n_chans=22, n_outputs=4, n_times=500)"
|
|
17
|
+
EEGMiner,Emotion Recognition,Classification,128,"n_chans, n_outputs, n_times, sfreq",7572,"EEGMiner(n_chans=62, n_outputs=2, n_times=2560, sfreq=128)"
|
|
18
|
+
EEGResNet,General,Classification,250,"n_chans, n_outputs, n_times",247484,"EEGResNet(n_chans=22, n_outputs=4, n_times=1000)"
|
|
19
|
+
EEGSimpleConv,Motor Imagery,Classification,80,"n_chans, n_outputs, sfreq",730404,"EEGSimpleConv(n_chans=22, n_outputs=4, n_times=320, sfreq=80)"
|
|
20
|
+
EEGTCNet,Motor Imagery,Classification,250,"n_chans, n_outputs",4516,"EEGTCNet(n_chans=22, n_outputs=4, n_times=1000, kern_length=32)"
|
|
21
|
+
Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
|
|
22
|
+
MSVTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",75494,"MSVTNet(n_chans=22, n_outputs=4, n_times=1000)"
|
|
23
|
+
SCCNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times, sfreq",12070,"SCCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=125)"
|
|
24
|
+
SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])"
|
|
25
|
+
SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])"
|
|
26
|
+
SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)"
|
|
27
|
+
SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)"
|
|
28
|
+
SincShallowNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",21892,"SincShallowNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
|
|
29
|
+
ShallowFBCSPNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",46084,"ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
|
|
30
|
+
SleepStagerBlanco2020,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times",2845,"SleepStagerBlanco2020(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)"
|
|
31
|
+
SleepStagerChambon2018,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",5835,"SleepStagerChambon2018(n_chans=2, n_outputs=5, n_times=3840, sfreq=128)"
|
|
32
|
+
SleepStagerEldele2021,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times, sfreq",719925,"SleepStagerEldele2021(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)"
|
|
33
|
+
SPARCNet,Epilepsy,Classification,200,"n_chans, n_outputs, n_times",1141921,"SPARCNet(n_chans=16, n_outputs=6, n_times=2000, sfreq=200)"
|
|
34
|
+
SyncNet,"Emotion Recognition, Alcoholism",Classification,256,"n_chans, n_outputs, n_times",554,"SyncNet(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)"
|
|
35
|
+
TSceptionV1,Emotion Recognition,Classification,256,"n_chans, n_outputs, n_times, sfreq",2187206,"TSceptionV1(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)"
|
|
36
|
+
TIDNet,General,Classification,250,"n_chans, n_outputs, n_times",240404,"TIDNet(n_chans=22, n_outputs=4, n_times=1000)"
|
|
37
|
+
USleep,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",2482011,"USleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)"
|
|
38
|
+
FBCNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",11812,"FBCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
|
|
39
|
+
FBMSNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",16231,"FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
|
|
40
|
+
FBLightConvNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",6596,"FBLightConvNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
|
|
41
|
+
IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860,"IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from einops.layers.torch import Rearrange
|
|
5
|
+
from numpy import arange, ceil
|
|
6
|
+
|
|
7
|
+
from braindecode.models.base import EEGModuleMixin
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SyncNet(EEGModuleMixin, nn.Module):
    """Synchronization Network (SyncNet) from Li, Y et al (2017) [Li2017]_.

    .. figure:: https://braindecode.org/dev/_static/model/SyncNet.png
        :align: center
        :alt: SyncNet Architecture

    SyncNet uses parameterized 1-dimensional convolutional filters inspired by
    the Morlet wavelet to extract features from EEG signals. The filters are
    dynamically generated based on learnable parameters that control the
    oscillation and decay characteristics.

    The filter for channel ``c`` and filter ``k`` is defined as:

    .. math::

        f_c^{(k)}(\\tau) = amplitude_c^{(k)} \\cos(\\omega^{(k)} \\tau + \\phi_c^{(k)}) \\exp(-\\beta^{(k)} \\tau^2)

    where:
    - :math:`amplitude_c^{(k)}` is the amplitude parameter (channel-specific).
    - :math:`\\omega^{(k)}` is the frequency parameter (shared across channels).
    - :math:`\\phi_c^{(k)}` is the phase shift (channel-specific).
    - :math:`\\beta^{(k)}` is the decay parameter (shared across channels).
    - :math:`\\tau` is the time index.

    Parameters
    ----------
    num_filters : int, optional
        Number of filters in the convolutional layer. Default is 1.
    filter_width : int, optional
        Width of the convolutional filters. Default is 40.
    pool_size : int, optional
        Size of the pooling window. Default is 40.
    activation : nn.Module, optional
        Activation function to apply after pooling. Default is ``nn.ReLU``.
    ampli_init_values : tuple of float, optional
        The initialization range for amplitude parameter using uniform
        distribution. Default is (-0.05, 0.05).
    omega_init_values : tuple of float, optional
        The initialization range for omega parameters using uniform
        distribution. Default is (0, 1).
    beta_init_values : tuple of float, optional
        The initialization range for beta parameters using uniform
        distribution. Default is (0, 0.05).
    phase_init_values : tuple of float, optional
        The ``(mean, std)`` of the normal distribution used to initialize
        the phase parameters. Default is (0, 0.05).


    Notes
    -----
    This implementation is not guaranteed to be correct! It has not been checked
    by original authors. The modifications are based on code derived from
    [CodeICASSP2025]_.


    References
    ----------
    .. [Li2017] Li, Y., Dzirasa, K., Carin, L., & Carlson, D. E. (2017).
       Targeting EEG/LFP synchrony with neural nets. Advances in neural
       information processing systems, 30.
    .. [CodeICASSP2025] Code from Baselines for EEG-Music Emotion Recognition
       Grand Challenge at ICASSP 2025.
       https://github.com/SalvoCalcagno/eeg-music-challenge-icassp-2025-baselines

    """

    def __init__(
        self,
        # braindecode convention
        n_chans=None,
        n_times=None,
        n_outputs=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # model parameters
        num_filters=1,
        filter_width=40,
        pool_size=40,
        activation: nn.Module = nn.ReLU,
        ampli_init_values: tuple[float, float] = (-0.05, 0.05),
        omega_init_values: tuple[float, float] = (0.0, 1.0),
        beta_init_values: tuple[float, float] = (0.0, 0.05),
        phase_init_values: tuple[float, float] = (0.0, 0.05),
    ):
        super().__init__(
            n_chans=n_chans,
            n_times=n_times,
            n_outputs=n_outputs,
            chs_info=chs_info,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.num_filters = num_filters
        self.filter_width = filter_width
        self.pool_size = pool_size
        self.activation = activation()
        self.ampli_init_values = ampli_init_values
        self.omega_init_values = omega_init_values
        self.beta_init_values = beta_init_values
        self.phase_init_values = phase_init_values

        # Channel-specific amplitude, shape (1, 1, n_chans, num_filters).
        self.amplitude = nn.Parameter(
            torch.FloatTensor(1, 1, self.n_chans, self.num_filters).uniform_(
                self.ampli_init_values[0], self.ampli_init_values[1]
            )
        )
        # Frequency shared across channels, shape (1, 1, 1, num_filters).
        self.omega = nn.Parameter(
            torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
                self.omega_init_values[0], self.omega_init_values[1]
            )
        )

        self.bias = nn.Parameter(torch.zeros(self.num_filters))

        # Output size after pooling: ceil(n_times / pool_size) values per filter.
        self.classifier_input_size = int(
            ceil(float(self.n_times) / float(self.pool_size)) * self.num_filters
        )

        # Integer time grid with exactly ``filter_width`` samples centred on 0.
        if self.filter_width % 2 == 0:
            t_range = arange(-int(self.filter_width / 2), int(self.filter_width / 2))
        else:
            t_range = arange(
                -int((self.filter_width - 1) / 2), int((self.filter_width - 1) / 2) + 1
            )

        t_np = t_range.reshape(1, self.filter_width, 1, 1)
        # NOTE(review): ``t`` is registered as a learnable parameter (kept for
        # state-dict / parameter-count compatibility), although in the filter
        # formula it is a fixed time grid.
        self.t = nn.Parameter(torch.FloatTensor(t_np))
        # Phase shift, channel-specific: normal init with
        # ``phase_init_values`` interpreted as (mean, std).
        self.phi_ini = nn.Parameter(
            torch.FloatTensor(1, 1, self.n_chans, self.num_filters).normal_(
                self.phase_init_values[0], self.phase_init_values[1]
            )
        )
        # Decay, shared across channels: uniform init in ``beta_init_values``.
        self.beta = nn.Parameter(
            torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
                self.beta_init_values[0], self.beta_init_values[1]
            )
        )

        # "Same" padding for the stride-1 filter convolution.
        self.padding = self._compute_padding(filter_width=self.filter_width)
        self.pad_input = nn.ConstantPad1d(self.padding, 0.0)
        # Pad the convolution output by ``pool_size - 1`` so that the
        # non-overlapping pooling yields exactly ceil(n_times / pool_size)
        # windows, matching ``classifier_input_size`` for any pool_size.
        self.pad_res = nn.ConstantPad1d(
            self._compute_padding(filter_width=self.pool_size), 0.0
        )

        # Define pooling and classifier layers
        self.pool = nn.MaxPool2d((1, self.pool_size), stride=(1, self.pool_size))

        self.ensuredim = Rearrange("batch ch time -> batch ch 1 time")

        self.final_layer = nn.Linear(self.classifier_input_size, self.n_outputs)

    def forward(self, x):
        """Forward pass of the SyncNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times)

        Returns
        -------
        out : torch.Tensor
            Output tensor of shape (batch_size, n_outputs).

        """
        # Keep the decay rate non-negative *before* building the filters, so
        # the current pass never uses a negative beta (which would turn the
        # Gaussian envelope into an exponentially growing one).
        self.beta.data.clamp_(min=0)

        # (batch_size, n_chans, n_times) -> (batch_size, n_chans, 1, n_times)
        x = self.ensuredim(x)

        # Oscillatory component: (1, filter_width, n_chans, num_filters)
        W_osc = self.amplitude * torch.cos(self.t * self.omega + self.phi_ini)

        # Gaussian decay component: (1, filter_width, 1, num_filters)
        t_squared = torch.pow(self.t, 2)
        W_decay = torch.exp(-(t_squared * self.beta))

        # Combined filters: (1, filter_width, n_chans, num_filters)
        W = W_osc * W_decay

        # Reorder axes to the conv2d weight layout
        # (out=num_filters, in=n_chans, kH=1, kW=filter_width). A plain
        # ``view`` would reinterpret memory and mix up time and channel taps.
        W = W.permute(3, 2, 0, 1)

        # Apply convolution -> (batch_size, num_filters, 1, n_times)
        x_padded = self.pad_input(x.float())
        res = F.conv2d(x_padded, W.float(), bias=self.bias, stride=1)

        # Pad and pool -> (batch_size, num_filters, 1, ceil(n_times / pool_size))
        res_padded = self.pad_res(res)
        res_pooled = self.pool(res_padded)

        # Flatten per sample (never silently changes the batch dimension).
        res_flat = res_pooled.reshape(res_pooled.shape[0], -1)

        # Apply activation
        out = self.activation(res_flat)
        # Apply classifier
        out = self.final_layer(out)

        return out

    @staticmethod
    def _compute_padding(filter_width):
        """Return ``(left, right)`` zero padding totalling ``filter_width - 1``.

        With this padding, a window of ``filter_width`` samples slid with
        stride 1 over a length-L sequence yields L outputs ("same"
        convolution). Slid with stride ``filter_width``, it yields
        ``ceil(L / filter_width)`` outputs. For even widths the extra sample
        goes on the right.
        """
        total = filter_width - 1
        left = total // 2
        return (left, total - left)
|