braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/biot.py
CHANGED

@@ -9,9 +9,9 @@ from braindecode.models.base import EEGModuleMixin


class BIOT(EEGModuleMixin, nn.Module):
-
+    """BIOT from Yang et al. (2023) [Yang2023]_

-    :bdg-danger:`
+    :bdg-danger:`Large Brain Model`

    .. figure:: https://braindecode.org/dev/_static/model/biot.jpg
       :align: center
@@ -19,7 +19,7 @@ class BIOT(EEGModuleMixin, nn.Module):

    BIOT: Cross-data Biosignal Learning in the Wild.

-    BIOT is a
+    BIOT is a large brain model for biosignal classification. It is
    a wrapper around the `BIOTEncoder` and `ClassificationHead` modules.

    It is designed for N-dimensional biosignal data such as EEG, ECG, etc.
@@ -41,44 +41,15 @@ class BIOT(EEGModuleMixin, nn.Module):
    linear layer that takes the output of the `BIOTEncoder` and outputs
    the classification probabilities.

-    .. important::
-        **Pre-trained Weights Available**
-
-        This model has pre-trained weights available on the Hugging Face Hub.
-        You can load them using:
-
-        .. code-block:: python
-
-            from braindecode.models import BIOT
-
-            # Load the original pre-trained model from Hugging Face Hub
-            # For 16-channel models:
-            model = BIOT.from_pretrained("braindecode/biot-pretrained-prest-16chs")
-
-            # For 18-channel models:
-            model = BIOT.from_pretrained("braindecode/biot-pretrained-shhs-prest-18chs")
-            model = BIOT.from_pretrained("braindecode/biot-pretrained-six-datasets-18chs")
-
-        To push your own trained model to the Hub:
-
-        .. code-block:: python
-
-            # After training your model
-            model.push_to_hub(
-                repo_id="username/my-biot-model", commit_message="Upload trained BIOT model"
-            )
-
-        Requires installing ``braindecode[hug]`` for Hub integration.
-
    .. versionadded:: 0.9

    Parameters
    ----------
-
+    emb_size : int, optional
        The size of the embedding layer, by default 256
-
+    att_num_heads : int, optional
        The number of attention heads, by default 8
-
+    n_layers : int, optional
        The number of transformer layers, by default 4
    activation: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
@@ -105,9 +76,9 @@ class BIOT(EEGModuleMixin, nn.Module):

    def __init__(
        self,
-
-
-
+        emb_size=256,
+        att_num_heads=8,
+        n_layers=4,
        sfreq=200,
        hop_length=100,
        return_feature=False,
@@ -116,12 +87,12 @@ class BIOT(EEGModuleMixin, nn.Module):
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
-        activation:
+        activation: nn.Module = nn.ELU,
        drop_prob: float = 0.5,
        # Parameters for the encoder
        max_seq_len: int = 1024,
-
-
+        attn_dropout=0.2,
+        attn_layer_dropout=0.2,
    ):
        super().__init__(
            n_outputs=n_outputs,
@@ -132,10 +103,10 @@ class BIOT(EEGModuleMixin, nn.Module):
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq
-        self.
+        self.emb_size = emb_size
        self.hop_length = hop_length
-        self.
-        self.
+        self.att_num_heads = att_num_heads
+        self.n_layers = n_layers
        self.return_feature = return_feature
        if (self.sfreq != 200) & (self.sfreq is not None):
            warn(
@@ -143,7 +114,7 @@ class BIOT(EEGModuleMixin, nn.Module):
                + "no guarantee to generalize well with the default parameters",
                UserWarning,
            )
-        if self.n_chans >
+        if self.n_chans > emb_size:
            warn(
                "The number of channels is larger than the embedding size. "
                + "This may cause overfitting. Consider using a larger "
@@ -171,20 +142,20 @@ class BIOT(EEGModuleMixin, nn.Module):
        self.n_fft = int(self.sfreq)

        self.encoder = _BIOTEncoder(
-            emb_size=
-
-            n_layers=
+            emb_size=emb_size,
+            att_num_heads=att_num_heads,
+            n_layers=n_layers,
            n_chans=self.n_chans,
            n_fft=self.n_fft,
            hop_length=hop_length,
            drop_prob=drop_prob,
            max_seq_len=max_seq_len,
-            attn_dropout=
-            attn_layer_dropout=
+            attn_dropout=attn_dropout,
+            attn_layer_dropout=attn_layer_dropout,
        )

        self.final_layer = _ClassificationHead(
-            emb_size=
+            emb_size=emb_size,
            n_outputs=self.n_outputs,
            activation=activation,
        )
@@ -216,7 +187,7 @@ class BIOT(EEGModuleMixin, nn.Module):


class _PatchFrequencyEmbedding(nn.Module):
-
+    """
    Patch Frequency Embedding.

    A simple linear layer is used to learn some representation over the
@@ -258,7 +229,7 @@ class _PatchFrequencyEmbedding(nn.Module):


class _ClassificationHead(nn.Sequential):
-
+    """
    Classification head for the BIOT model.

    Simple linear layer with ELU activation function.
@@ -279,9 +250,7 @@ class _ClassificationHead(nn.Sequential):
        (batch, n_outputs)
    """

-    def __init__(
-        self, emb_size: int, n_outputs: int, activation: type[nn.Module] = nn.ELU
-    ):
+    def __init__(self, emb_size: int, n_outputs: int, activation: nn.Module = nn.ELU):
        super().__init__()
        self.activation_layer = activation()
        self.classification_head = nn.Linear(emb_size, n_outputs)
@@ -293,7 +262,7 @@ class _ClassificationHead(nn.Sequential):


class _PositionalEncoding(nn.Module):
-
+    """
    Positional Encoding.

    We first create a `pe` zero matrix of shape (max_len, d_model) where max_len is the
@@ -354,7 +323,7 @@ class _PositionalEncoding(nn.Module):


class _BIOTEncoder(nn.Module):
-
+    """
    BIOT Encoder.

    The BIOT encoder is a transformer that takes the time series input data and
@@ -376,7 +345,7 @@ class _BIOTEncoder(nn.Module):
        The number of channels
    emb_size: int
        The size of the embedding layer
-
+    att_num_heads: int
        The number of attention heads
    n_layers: int
        The number of transformer layers
@@ -389,7 +358,7 @@ class _BIOTEncoder(nn.Module):
    def __init__(
        self,
        emb_size=256,  # The size of the embedding layer
-
+        att_num_heads=8,  # The number of attention heads
        n_chans=16,  # The number of channels
        n_layers=4,  # The number of transformer layers
        n_fft=200,  # Related with the frequency resolution
@@ -409,7 +378,7 @@ class _BIOTEncoder(nn.Module):
        )
        self.transformer = LinearAttentionTransformer(
            dim=emb_size,
-            heads=
+            heads=att_num_heads,
            depth=n_layers,
            max_seq_len=max_seq_len,
            attn_layer_dropout=attn_layer_dropout,
braindecode/models/contrawr.py
CHANGED

@@ -8,7 +8,7 @@ from braindecode.models.base import EEGModuleMixin


class ContraWR(EEGModuleMixin, nn.Module):
-
+    """Contrast with the World Representation ContraWR from Yang et al (2021) [Yang2021]_.

    :bdg-success:`Convolution`

@@ -58,7 +58,7 @@ class ContraWR(EEGModuleMixin, nn.Module):
        emb_size: int = 256,
        res_channels: list[int] = [32, 64, 128],
        steps=20,
-        activation:
+        activation: nn.Module = nn.ELU,
        drop_prob: float = 0.5,
        stride_res: int = 2,
        kernel_size_res: int = 3,
@@ -148,7 +148,7 @@ class ContraWR(EEGModuleMixin, nn.Module):


class _ResBlock(nn.Module):
-
+    """Convolutional Residual Block 2D.

    This block stacks two convolutional layers with batch normalization,
    max pooling, dropout, and residual connection.
@@ -195,7 +195,7 @@ class _ResBlock(nn.Module):
        kernel_size=3,
        padding=1,
        drop_prob=0.5,
-        activation:
+        activation: nn.Module = nn.ReLU,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(
@@ -259,7 +259,7 @@ class _ResBlock(nn.Module):


class _STFTModule(nn.Module):
-
+    """
    A PyTorch module that computes the Short-Time Fourier Transform (STFT)
    of an EEG batch tensor.

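
As in biot.py, the contrawr.py changes fill in missing class docstrings and re-annotate the activation arguments as `nn.Module` classes with defaults (nn.ELU for the model, nn.ReLU inside _ResBlock); the class is instantiated inside the module, so callers pass the class itself rather than an instance. The same pattern recurs in deep4.py, deepsleepnet.py, and eegconformer.py below. A hedged sketch, assuming ContraWR is exported from braindecode.models and accepts the usual EEGModuleMixin arguments; the values are illustrative:

    from torch import nn
    from braindecode.models import ContraWR

    # Activation is passed as a class (the module calls activation() internally);
    # n_outputs / n_chans / sfreq values here are assumptions for the example.
    model = ContraWR(n_outputs=5, n_chans=2, sfreq=200, activation=nn.ELU)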
braindecode/models/ctnet.py
CHANGED

@@ -25,9 +25,9 @@ from braindecode.modules import (


class CTNet(EEGModuleMixin, nn.Module):
-
+    """CTNet from Zhao, W et al (2024) [ctnet]_.

-    :bdg-success:`Convolution` :bdg-info:`Attention
+    :bdg-success:`Convolution` :bdg-info:`Small Attention`

    A Convolutional Transformer Network for EEG-Based Motor Imagery Classification

@@ -61,11 +61,11 @@ class CTNet(EEGModuleMixin, nn.Module):
    ----------
    activation : nn.Module, default=nn.GELU
        Activation function to use in the network.
-
+    heads : int, default=4
        Number of attention heads in the Transformer encoder.
-
+    emb_size : int or None, default=None
        Embedding size (dimensionality) for the Transformer encoder.
-
+    depth : int, default=6
        Number of encoder layers in the Transformer.
    n_filters_time : int, default=20
        Number of temporal filters in the first convolutional layer.
@@ -77,11 +77,11 @@ class CTNet(EEGModuleMixin, nn.Module):
        Pooling size for the first average pooling layer.
    pool_size_2 : int, default=8
        Pooling size for the second average pooling layer.
-
+    drop_prob_cnn : float, default=0.3
        Dropout probability after convolutional layers.
-
+    drop_prob_posi : float, default=0.1
        Dropout probability for the positional encoding in the Transformer.
-
+    drop_prob_final : float, default=0.5
        Dropout probability before the final classification layer.

    Notes
@@ -109,15 +109,15 @@ class CTNet(EEGModuleMixin, nn.Module):
        n_times=None,
        input_window_seconds=None,
        # Model specific arguments
-        activation_patch:
-        activation_transformer:
-
-
-
+        activation_patch: nn.Module = nn.ELU,
+        activation_transformer: nn.Module = nn.GELU,
+        drop_prob_cnn: float = 0.3,
+        drop_prob_posi: float = 0.1,
+        drop_prob_final: float = 0.5,
        # other parameters
-
-
-
+        heads: int = 4,
+        emb_size: Optional[int] = 40,
+        depth: int = 6,
        n_filters_time: Optional[int] = None,
        kernel_size: int = 64,
        depth_multiplier: Optional[int] = 2,
@@ -136,14 +136,14 @@ class CTNet(EEGModuleMixin, nn.Module):

        self.activation_patch = activation_patch
        self.activation_transformer = activation_transformer
-        self.
+        self.drop_prob_cnn = drop_prob_cnn
        self.pool_size_1 = pool_size_1
        self.pool_size_2 = pool_size_2
        self.kernel_size = kernel_size
-        self.
-        self.
-        self.
-        self.
+        self.drop_prob_posi = drop_prob_posi
+        self.drop_prob_final = drop_prob_final
+        self.heads = heads
+        self.depth = depth
        # n_times - pool_size_1 / p
        self.sequence_length = math.floor(
            (
@@ -154,8 +154,8 @@ class CTNet(EEGModuleMixin, nn.Module):
            + 1
        )

-        self.depth_multiplier, self.n_filters_time, self.
-            depth_multiplier, n_filters_time,
+        self.depth_multiplier, self.n_filters_time, self.emb_size = self._resolve_dims(
+            depth_multiplier, n_filters_time, emb_size
        )

        # Layers
@@ -168,32 +168,32 @@ class CTNet(EEGModuleMixin, nn.Module):
            depth_multiplier=self.depth_multiplier,
            pool_size_1=self.pool_size_1,
            pool_size_2=self.pool_size_2,
-            drop_prob=self.
+            drop_prob=self.drop_prob_cnn,
            n_chans=self.n_chans,
            activation=self.activation_patch,
        )

        self.position = _PositionalEncoding(
-            emb_size=self.
-            drop_prob=self.
+            emb_size=self.emb_size,
+            drop_prob=self.drop_prob_posi,
            n_times=self.n_times,
            pool_size=self.pool_size_1,
        )

        self.trans = _TransformerEncoder(
-            self.
-            self.
-            self.
+            self.heads,
+            self.depth,
+            self.emb_size,
            activation=self.activation_transformer,
        )

        self.flatten_drop_layer = nn.Sequential(
            nn.Flatten(),
-            nn.Dropout(p=self.
+            nn.Dropout(p=self.drop_prob_final),
        )

        self.final_layer = nn.Linear(
-            in_features=int(self.
+            in_features=int(self.emb_size * self.sequence_length),
            out_features=self.n_outputs,
        )

@@ -213,7 +213,7 @@ class CTNet(EEGModuleMixin, nn.Module):
        """
        x = self.ensuredim(x)
        cnn = self.cnn(x)
-        cnn = cnn * math.sqrt(self.
+        cnn = cnn * math.sqrt(self.emb_size)
        cnn = self.position(cnn)
        trans = self.trans(cnn)
        features = cnn + trans
@@ -312,7 +312,7 @@ class _PatchEmbeddingEEGNet(nn.Module):
        pool_size_2: int = 8,
        drop_prob: float = 0.3,
        n_chans: int = 22,
-        activation:
+        activation: nn.Module = nn.ELU,
    ):
        super().__init__()
        n_filters_out = depth_multiplier * n_filters_time
@@ -416,7 +416,7 @@ class _TransformerEncoderBlock(nn.Module):
        drop_prob: float = 0.5,
        forward_expansion: int = 4,
        forward_drop_p: float = 0.5,
-        activation:
+        activation: nn.Module = nn.GELU,
    ):
        super().__init__()
        self.attention = _ResidualAdd(
@@ -466,7 +466,7 @@ class _TransformerEncoder(nn.Module):
        nheads: int,
        depth: int,
        dim_feedforward: int,
-        activation:
+        activation: nn.Module = nn.GELU,
    ):
        super().__init__()
        self.layers = nn.Sequential(
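
The ctnet.py hunks expand abbreviated attribute names into explicit ones (heads, depth, emb_size, drop_prob_cnn, drop_prob_posi, drop_prob_final) and give the activation arguments class defaults. A sketch of the new keyword interface; the shape values (22 channels, 1000 samples, 4 classes) are assumptions chosen to match a typical motor-imagery setup, not values from the diff:

    from torch import nn
    from braindecode.models import CTNet

    model = CTNet(
        n_outputs=4,
        n_chans=22,
        n_times=1000,
        heads=4,                        # renamed Transformer arguments
        emb_size=40,
        depth=6,
        drop_prob_cnn=0.3,              # renamed dropout arguments
        drop_prob_posi=0.1,
        drop_prob_final=0.5,
        activation_patch=nn.ELU,        # passed as classes, per the new annotations
        activation_transformer=nn.GELU,
    )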
braindecode/models/deep4.py
CHANGED

@@ -17,7 +17,7 @@ from braindecode.modules import (


class Deep4Net(EEGModuleMixin, nn.Sequential):
-
+    """Deep ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

    :bdg-success:`Convolution`

@@ -109,12 +109,12 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
        filter_length_3=10,
        n_filters_4=200,
        filter_length_4=10,
-        activation_first_conv_nonlin:
+        activation_first_conv_nonlin: nn.Module = nn.ELU,
        first_pool_mode="max",
-        first_pool_nonlin:
-        activation_later_conv_nonlin:
+        first_pool_nonlin: nn.Module = nn.Identity,
+        activation_later_conv_nonlin: nn.Module = nn.ELU,
        later_pool_mode="max",
-        later_pool_nonlin:
+        later_pool_nonlin: nn.Module = nn.Identity,
        drop_prob=0.5,
        split_first_layer=True,
        batch_norm=True,
braindecode/models/deepsleepnet.py
CHANGED

@@ -8,7 +8,7 @@ from braindecode.models.base import EEGModuleMixin


class DeepSleepNet(EEGModuleMixin, nn.Module):
-
+    """DeepSleepNet from Supratak et al. (2017) [Supratak2017]_.

    :bdg-success:`Convolution` :bdg-secondary:`Recurrent`

@@ -172,8 +172,8 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
-        activation_large:
-        activation_small:
+        activation_large: nn.Module = nn.ELU,
+        activation_small: nn.Module = nn.ReLU,
        drop_prob: float = 0.5,
    ):
        super().__init__(
@@ -240,7 +240,7 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):


class _SmallCNN(nn.Module):
-
+    """
    Smaller filter sizes to learn temporal information.

    Parameters
@@ -252,7 +252,7 @@ class _SmallCNN(nn.Module):
        The dropout rate for regularization. Values should be between 0 and 1.
    """

-    def __init__(self, activation:
+    def __init__(self, activation: nn.Module = nn.ReLU, drop_prob: float = 0.5):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
@@ -317,7 +317,7 @@ class _SmallCNN(nn.Module):


class _LargeCNN(nn.Module):
-
+    """
    Larger filter sizes to learn frequency information.

    Parameters
@@ -328,7 +328,7 @@ class _LargeCNN(nn.Module):

    """

-    def __init__(self, activation:
+    def __init__(self, activation: nn.Module = nn.ELU, drop_prob: float = 0.5):
        super().__init__()

        self.conv1 = nn.Sequential(
braindecode/models/eegconformer.py
CHANGED

@@ -12,9 +12,9 @@ from braindecode.modules import FeedForwardBlock, MultiHeadAttention


class EEGConformer(EEGModuleMixin, nn.Module):
-
+    """EEG Conformer from Song et al. (2022) [song2022]_.

-    :bdg-success:`Convolution` :bdg-info:`Attention
+    :bdg-success:`Convolution` :bdg-info:`Small Attention`

    .. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
       :align: center
@@ -57,9 +57,9 @@ class EEGConformer(EEGModuleMixin, nn.Module):
    - :class:`_TransformerEncoder` **(context over temporal tokens)**

      - *Operations.*
-      - A stack of ``
+      - A stack of ``att_depth`` encoder blocks. :class:`_TransformerEncoderBlock`
        - Each block applies LayerNorm :class:`torch.nn.LayerNorm`
-        - Multi-Head Self-Attention (``
+        - Multi-Head Self-Attention (``att_heads``) with dropout + residual :class:`MultiHeadAttention` (:class:`torch.nn.Dropout`)
        - LayerNorm :class:`torch.nn.LayerNorm`
        - 2-layer feed-forward (≈4x expansion, :class:`torch.nn.GELU`) with dropout + residual.

@@ -100,7 +100,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):

    .. rubric:: Attention / Sequential Modules

-    - **Type.** Standard multi-head self-attention (MHA) with ``
+    - **Type.** Standard multi-head self-attention (MHA) with ``att_heads`` heads over the token sequence.
    - **Shapes.** Input/Output: ``(B, S_tokens, D)``; attention operates along the ``S_tokens`` axis.
    - **Role.** Re-weights and integrates evidence across pooled windows, capturing dependencies
      longer than any single token while leaving channel relationships to the convolutional stem.
@@ -127,7 +127,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):
    - **Instantiation.** Choose ``n_filters_time`` (embedding size ``D``) and
      ``filter_time_length`` to match the rhythms of interest. Tune
      ``pool_time_length/stride`` to trade temporal resolution for sequence length.
-      Keep ``
+      Keep ``att_depth`` modest (e.g., 4–6) and set ``att_heads`` to divide ``D``.
      ``final_fc_length="auto"`` infers the flattened size from PatchEmbedding.

    Notes
@@ -160,9 +160,9 @@ class EEGConformer(EEGModuleMixin, nn.Module):
        Length of stride between temporal pooling filters.
    drop_prob: float
        Dropout rate of the convolutional layer.
-
+    att_depth: int
        Number of self-attention layers.
-
+    att_heads: int
        Number of attention heads.
    att_drop_prob: float
        Dropout rate of the self-attention layer.
@@ -197,13 +197,13 @@ class EEGConformer(EEGModuleMixin, nn.Module):
        pool_time_length=75,
        pool_time_stride=15,
        drop_prob=0.5,
-
-
+        att_depth=6,
+        att_heads=10,
        att_drop_prob=0.5,
        final_fc_length="auto",
        return_features=False,
-        activation:
-        activation_transfor:
+        activation: nn.Module = nn.ELU,
+        activation_transfor: nn.Module = nn.GELU,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
@@ -250,9 +250,9 @@ class EEGConformer(EEGModuleMixin, nn.Module):
        self.final_fc_length = final_fc_length

        self.transformer = _TransformerEncoder(
-
+            att_depth=att_depth,
            emb_size=n_filters_time,
-
+            att_heads=att_heads,
            att_drop=att_drop_prob,
            activation=activation_transfor,
        )
@@ -284,7 +284,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):


class _PatchEmbedding(nn.Module):
-
+    """Patch Embedding.

    The authors used a convolution module to capture local features,
    instead of position embedding.
@@ -318,7 +318,7 @@ class _PatchEmbedding(nn.Module):
        pool_time_length,
        stride_avg_pool,
        drop_prob,
-        activation:
+        activation: nn.Module = nn.ELU,
    ):
        super().__init__()

@@ -364,16 +364,16 @@ class _TransformerEncoderBlock(nn.Sequential):
    def __init__(
        self,
        emb_size,
-
+        att_heads,
        att_drop,
        forward_expansion=4,
-        activation:
+        activation: nn.Module = nn.GELU,
    ):
        super().__init__(
            _ResidualAdd(
                nn.Sequential(
                    nn.LayerNorm(emb_size),
-                    MultiHeadAttention(emb_size,
+                    MultiHeadAttention(emb_size, att_heads, att_drop),
                    nn.Dropout(att_drop),
                )
            ),
@@ -393,17 +393,17 @@ class _TransformerEncoderBlock(nn.Sequential):


class _TransformerEncoder(nn.Sequential):
-
+    """Transformer encoder module for the transformer encoder.

    Similar to the layers used in ViT.

    Parameters
    ----------
-
+    att_depth : int
        Number of transformer encoder blocks.
    emb_size : int
        Embedding size of the transformer encoder.
-
+    att_heads : int
        Number of attention heads.
    att_drop : float
        Dropout probability for the attention layers.
@@ -411,19 +411,14 @@ class _TransformerEncoder(nn.Sequential):
    """

    def __init__(
-        self,
-        num_layers,
-        emb_size,
-        num_heads,
-        att_drop,
-        activation: type[nn.Module] = nn.GELU,
+        self, att_depth, emb_size, att_heads, att_drop, activation: nn.Module = nn.GELU
    ):
        super().__init__(
            *[
                _TransformerEncoderBlock(
-                    emb_size,
+                    emb_size, att_heads, att_drop, activation=activation
                )
-                for _ in range(
+                for _ in range(att_depth)
            ]
        )

@@ -436,7 +431,7 @@ class _FullyConnected(nn.Module):
        drop_prob_2=0.3,
        out_channels=256,
        hidden_channels=32,
-        activation:
+        activation: nn.Module = nn.ELU,
    ):
        """Fully-connected layer for the transformer encoder.

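The eegconformer.py hunks rename the transformer arguments to att_depth and att_heads and document them; the docstring above also advises that att_heads should divide the embedding size (n_filters_time). A sketch with the renamed arguments; the dataset-shape values are assumptions, and activation_transfor is spelled as it appears in the source:

    from torch import nn
    from braindecode.models import EEGConformer

    model = EEGConformer(
        n_outputs=4,
        n_chans=22,
        n_times=1000,
        att_depth=6,                  # renamed attention arguments
        att_heads=10,                 # should divide the embedding size (n_filters_time)
        att_drop_prob=0.5,
        final_fc_length="auto",
        activation=nn.ELU,
        activation_transfor=nn.GELU,  # parameter name as in the source
    )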