braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/labram.py
CHANGED

@@ -21,9 +21,9 @@ from braindecode.modules import MLP, DropPath
 
 
 class Labram(EEGModuleMixin, nn.Module):
-
+    """Labram from Jiang, W B et al (2024) [Jiang2024]_.
 
-    :bdg-success:`Convolution` :bdg-danger:`
+    :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`
 
     .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
        :align: center
@@ -64,30 +64,6 @@ class Labram(EEGModuleMixin, nn.Module):
     - LayerNorm: Apply layer normalization to the data;
     - Linear: An head linear layer to transformer the data into classes.
 
-    .. important::
-        **Pre-trained Weights Available**
-
-        This model has pre-trained weights available on the Hugging Face Hub.
-        You can load them using:
-
-        .. code-block:: python
-
-            from braindecode.models import Labram
-
-            # Load pre-trained model from Hugging Face Hub
-            model = Labram.from_pretrained("braindecode/labram-pretrained")
-
-        To push your own trained model to the Hub:
-
-        .. code-block:: python
-
-            # After training your model
-            model.push_to_hub(
-                repo_id="username/my-labram-model", commit_message="Upload trained Labram model"
-            )
-
-        Requires installing ``braindecode[hug]`` for Hub integration.
-
     .. versionadded:: 0.9
 
 
@@ -107,15 +83,15 @@ class Labram(EEGModuleMixin, nn.Module):
     ----------
     patch_size : int
        The size of the patch to be used in the patch embedding.
-
+    emb_size : int
        The dimension of the embedding.
-
+    in_conv_channels : int
        The number of convolutional input channels.
-
+    out_channels : int
        The number of convolutional output channels.
-
+    n_layers : int (default=12)
        The number of attention layers of the model.
-
+    att_num_heads : int (default=10)
        The number of attention heads.
    mlp_ratio : float (default=4.0)
        The expansion ratio of the mlp layer
@@ -179,26 +155,26 @@ class Labram(EEGModuleMixin, nn.Module):
         sfreq=None,
         input_window_seconds=None,
         patch_size=200,
-
-
-
-
-
+        emb_size=200,
+        in_conv_channels=1,
+        out_channels=8,
+        n_layers=12,
+        att_num_heads=10,
         mlp_ratio=4.0,
         qkv_bias=False,
-        qk_norm
+        qk_norm=nn.LayerNorm,
         qk_scale=None,
         drop_prob=0.0,
         attn_drop_prob=0.0,
         drop_path_prob=0.0,
-        norm_layer
+        norm_layer=nn.LayerNorm,
         init_values=0.1,
         use_abs_pos_emb=True,
         use_mean_pooling=True,
         init_scale=0.001,
         neural_tokenizer=True,
         attn_head_dim=None,
-        activation:
+        activation: nn.Module = nn.GELU,
     ):
         super().__init__(
             n_outputs=n_outputs,
@@ -211,7 +187,7 @@ class Labram(EEGModuleMixin, nn.Module):
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
 
         self.patch_size = patch_size
-        self.num_features = self.
+        self.num_features = self.emb_size = emb_size
         self.neural_tokenizer = neural_tokenizer
         self.init_scale = init_scale
 
@@ -223,20 +199,20 @@ class Labram(EEGModuleMixin, nn.Module):
             )
             self.patch_size = self.n_times
             self.num_features = None
-            self.
+            self.emb_size = None
         else:
             self.patch_size = patch_size
             self.n_path = self.n_times // self.patch_size
 
-        if neural_tokenizer and
+        if neural_tokenizer and in_conv_channels != 1:
             warn(
                 "The model is in Neural Tokenizer mode, but the variable "
-                + "`
-                + "`
-                + "
+                + "`in_conv_channels` is different from the default values."
+                + "`in_conv_channels` is only needed for the Neural Decoder mode."
+                + "in_conv_channels is not used in the Neural Tokenizer mode.",
                 UserWarning,
             )
-
+            in_conv_channels = 1
         # If you can use the model in Neural Tokenizer mode,
         # temporal conv layer will be use over the patched dataset
         if neural_tokenizer:
@@ -255,7 +231,7 @@ class Labram(EEGModuleMixin, nn.Module):
                     (
                         "temporal_conv",
                         _TemporalConv(
-                            out_channels=
+                            out_channels=out_channels, activation=activation
                         ),
                     ),
                 ]
@@ -273,8 +249,8 @@ class Labram(EEGModuleMixin, nn.Module):
                 _PatchEmbed(
                     n_times=self.n_times,
                     patch_size=patch_size,
-                    in_channels=
-                    emb_dim=self.
+                    in_channels=in_conv_channels,
+                    emb_dim=self.emb_size,
                 ),
             )
 
@@ -283,12 +259,12 @@ class Labram(EEGModuleMixin, nn.Module):
             out = self.patch_embed(dummy)
             # out.shape for tokenizer: (1, n_chans, emb_dim)
             # for decoder: (1, n_patch, patch_size, emb_dim), but we want last dim
-            self.
-            self.num_features = self.
+            self.emb_size = out.shape[-1]
+            self.num_features = self.emb_size
 
         # Defining the parameters
         # Creating a parameter list with cls token]
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.emb_size))
         # Positional embedding and time embedding are complementary
         # one is for the spatial information and the other is for the temporal
         # information.
@@ -297,26 +273,26 @@ class Labram(EEGModuleMixin, nn.Module):
         # information.
         if use_abs_pos_emb:
             self.position_embedding = nn.Parameter(
-                torch.zeros(1, self.n_chans + 1, self.
+                torch.zeros(1, self.n_chans + 1, self.emb_size),
                 requires_grad=True,
             )
         else:
             self.position_embedding = None
 
         self.temporal_embedding = nn.Parameter(
-            torch.zeros(1, self.patch_embed[0].n_patchs + 1, self.
+            torch.zeros(1, self.patch_embed[0].n_patchs + 1, self.emb_size),
             requires_grad=True,
         )
         self.pos_drop = nn.Dropout(p=drop_prob)
 
         dpr = [
-            x.item() for x in torch.linspace(0, drop_path_prob,
+            x.item() for x in torch.linspace(0, drop_path_prob, n_layers)
         ] # stochastic depth decay rule
         self.blocks = nn.ModuleList(
             [
                 _WindowsAttentionBlock(
-                    dim=self.
-                    num_heads=
+                    dim=self.emb_size,
+                    num_heads=att_num_heads,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
                     qk_norm=qk_norm,
@@ -334,14 +310,14 @@ class Labram(EEGModuleMixin, nn.Module):
                     attn_head_dim=attn_head_dim,
                     activation=activation,
                 )
-                for i in range(
+                for i in range(n_layers)
             ]
         )
-        self.norm = nn.Identity() if use_mean_pooling else norm_layer(self.
-        self.fc_norm = norm_layer(self.
+        self.norm = nn.Identity() if use_mean_pooling else norm_layer(self.emb_size)
+        self.fc_norm = norm_layer(self.emb_size) if use_mean_pooling else None
 
         if self.n_outputs > 0:
-            self.final_layer = nn.Linear(self.
+            self.final_layer = nn.Linear(self.emb_size, self.n_outputs)
         else:
             self.final_layer = nn.Identity()
 
@@ -439,7 +415,7 @@ class Labram(EEGModuleMixin, nn.Module):
             x = self.patch_embed(x)
             # x shape: (batch, n_chans, emb_dim)
             n_patch = self.n_chans
-            temporal = self.
+            temporal = self.emb_size
         else:
             # For neural decoder: input is (batch, n_chans, n_times)
             # patch_embed returns (batch, n_patchs, emb_dim)
@@ -486,7 +462,7 @@ class Labram(EEGModuleMixin, nn.Module):
         # In decoder mode, we have n_patch patches and don't need to expand
         # Just broadcast the temporal embedding
         if temporal is None:
-            temporal = self.
+            temporal = self.emb_size
 
         # Get temporal embeddings for n_patch patches
         n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
@@ -661,7 +637,7 @@ class Labram(EEGModuleMixin, nn.Module):
 
 
 class _SegmentPatch(nn.Module):
-
+    """Segment and Patch for EEG data.
 
     Adapted Patch Embedding inspired in the Visual Transform approach
     to extract the learned segmentor, we expect get the input shape as:
@@ -767,7 +743,7 @@ class _SegmentPatch(nn.Module):
 
 
 class _PatchEmbed(nn.Module):
-
+    """EEG to Patch Embedding for Neural Decoder mode.
 
     This code is used when we want to apply the patch embedding
     after the codebook layer (Neural Decoder mode).
@@ -873,7 +849,7 @@ class _PatchEmbed(nn.Module):
 
 
 class _Attention(nn.Module):
-
+    """
     Attention with the options of Window-based multi-head self attention (W-MSA).
 
     This code is strong inspired by:
@@ -1071,7 +1047,7 @@ class _Attention(nn.Module):
 
 
 class _WindowsAttentionBlock(nn.Module):
-
+    """Blocks of Windows Attention with Layer norm and MLP.
 
     Notes: This code is strong inspired by:
     BeiTv2 from Microsoft.
@@ -1130,7 +1106,7 @@ class _WindowsAttentionBlock(nn.Module):
         attn_drop=0.0,
         drop_path=0.0,
         init_values=None,
-        activation:
+        activation: nn.Module = nn.GELU,
         norm_layer=nn.LayerNorm,
         window_size=None,
         attn_head_dim=None,
@@ -1206,7 +1182,7 @@ class _WindowsAttentionBlock(nn.Module):
 
 
 class _TemporalConv(nn.Module):
-
+    """
     Temporal Convolutional Module inspired by Visual Transformer.
 
     In this module we apply the follow steps three times repeatedly
@@ -1253,7 +1229,7 @@ class _TemporalConv(nn.Module):
         padding_1=(0, 7),
         kernel_size_2=(1, 3),
         padding_2=(0, 1),
-        activation:
+        activation: nn.Module = nn.GELU,
     ):
         super().__init__()
 
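In short, this change renames Labram's constructor arguments (new names: emb_size, in_conv_channels, out_channels, n_layers, att_num_heads), gives qk_norm, norm_layer, and activation explicit nn.Module defaults, and drops the Hugging Face Hub section from the docstring. A minimal instantiation sketch against the renamed signature; the batch/channel/time sizes below are illustrative assumptions, not values taken from this diff:

    import torch

    from braindecode.models import Labram

    # Post-rename keyword arguments; the values mirror the new defaults
    # shown in the diff above. Batch size, channel count, and trial
    # length are illustrative assumptions.
    model = Labram(
        n_outputs=4,
        n_chans=64,
        n_times=1600,           # kept divisible by patch_size
        patch_size=200,
        emb_size=200,
        out_channels=8,
        n_layers=12,
        att_num_heads=10,
        neural_tokenizer=True,  # tokenizer mode; in_conv_channels is forced to 1
    )

    x = torch.randn(2, 64, 1600)  # (batch, n_chans, n_times)
    with torch.no_grad():
        logits = model(x)         # expected shape: (2, 4)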
braindecode/models/luna.py
CHANGED

@@ -27,9 +27,9 @@ from braindecode.modules.layers import DropPath
 
 
 class LUNA(EEGModuleMixin, nn.Module):
-
+    """LUNA from Döner et al. [LUNA]_.
 
-    :bdg-success:`Convolution` :bdg-danger:`
+    :bdg-success:`Convolution` :bdg-danger:`Large Brain Model` :bdg-dark-line:`Channel`
 
     .. figure:: https://arxiv.org/html/2510.22257v1/x1.png
        :align: center
@@ -44,61 +44,6 @@ class LUNA(EEGModuleMixin, nn.Module):
     3. Patch-wise Temporal Encoder (RoPE-based transformer)
     4. Decoder Heads (classification or reconstruction)
 
-    .. important::
-        **Pre-trained Weights Available**
-
-        This model has pre-trained weights available on the Hugging Face Hub
-        at `thorir/LUNA <https://huggingface.co/thorir/LUNA>`_.
-
-        Available model variants:
-
-        - **LUNA_base.safetensors** - Base model (embed_dim=64, num_queries=4, depth=8)
-        - **LUNA_large.safetensors** - Large model (embed_dim=96, num_queries=6, depth=10)
-        - **LUNA_huge.safetensors** - Huge model (embed_dim=128, num_queries=8, depth=24)
-
-        Example loading for fine-tuning:
-
-        .. code-block:: python
-
-            from huggingface_hub import hf_hub_download
-            from safetensors.torch import load_file
-            from braindecode.models import LUNA
-
-            # Download pre-trained weights
-            model_path = hf_hub_download(
-                repo_id="thorir/LUNA",
-                filename="LUNA_base.safetensors",
-            )
-
-            # Create model for classification (fine-tuning)
-            model = LUNA(
-                n_outputs=2,  # Number of classes for your task
-                n_chans=22,
-                n_times=1000,
-                embed_dim=64,
-                num_queries=4,
-                depth=8,
-            )
-
-            # Load pre-trained encoder weights
-            state_dict = load_file(model_path)
-            # Apply key mapping for pretrained weights
-            mapping = model.mapping.copy()
-            mapping["cross_attn.temparature"] = "cross_attn.temperature"
-            mapped_state_dict = {mapping.get(k, k): v for k, v in state_dict.items()}
-            model.load_state_dict(mapped_state_dict, strict=False)
-
-        To push your own trained model to the Hub:
-
-        .. code-block:: python
-
-            # After training your model
-            model.push_to_hub(
-                repo_id="username/my-luna-model", commit_message="Upload trained LUNA model"
-            )
-
-        Requires installing ``braindecode[hug]`` for Hub integration.
-
     Parameters
     ----------
     patch_size : int
@@ -457,7 +402,7 @@ def nerf_positional_encoding(coords: torch.Tensor, embed_size: int) -> torch.Ten
 
 
 class _ChannelEmbeddings(nn.Module):
-
+    """
     This class creates embeddings for each EEG channel based on a predefined
     mapping of channel names to indices.
 
@@ -485,7 +430,7 @@ class _ChannelEmbeddings(nn.Module):
 
 
 class _FrequencyFeatureEmbedder(nn.Module):
-
+    """
     This class takes data that is of the form (B, C, T) and patches it
     along the time dimension (T) into patches of size P (patch_size).
     The output is of the form (B, C, S, P) where S = T // P.
@@ -861,7 +806,7 @@ class _PatchEmbedNetwork(nn.Module):
 
 
 class _Mlp(nn.Module):
-
+    """MLP as used in Vision Transformer, MLP-Mixer and related networks.
 
     Code copied from timm.models.mlp.Mlp
     """
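The LUNA diff mirrors the Labram one: the badge row is renamed, private helper docstrings are reopened, and the whole Hugging Face Hub section (variant table, hf_hub_download/load_file fine-tuning recipe, push_to_hub example) is removed from the public docstring. For reference, a plain-construction sketch using the base-variant hyperparameters that the deleted text documented; it assumes those keyword arguments survive the docstring pruning unchanged, which this diff does not itself confirm:

    import torch

    from braindecode.models import LUNA

    # Base-variant hyperparameters as listed in the removed docstring
    # (embed_dim=64, num_queries=4, depth=8); assumed still-valid kwargs
    # since only documentation is deleted here. Data sizes follow the
    # removed fine-tuning example.
    model = LUNA(
        n_outputs=2,
        n_chans=22,
        n_times=1000,
        embed_dim=64,
        num_queries=4,
        depth=8,
    )

    x = torch.randn(2, 22, 1000)  # (batch, n_chans, n_times)
    with torch.no_grad():
        logits = model(x)         # expected shape: (2, 2)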
braindecode/models/medformer.py
CHANGED

@@ -18,21 +18,20 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class MEDFormer(EEGModuleMixin, nn.Module):
-    r"""
-    Medformer from Wang et al (2024) [Medformer2024]_.
+    r"""Medformer from Wang et al. (2024) [Medformer2024]_.
 
-    :bdg-success:`Convolution` :bdg-danger:`
+    :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`
 
     .. figure:: https://raw.githubusercontent.com/DL4mHealth/Medformer/refs/heads/main/figs/medformer_architecture.png
        :align: center
       :alt: MEDFormer Architecture.
 
-    a) Workflow. b) For the input sample :math:`{x}_{
+    a) Workflow. b) For the input sample :math:`{x}_{\\textrm{in}}`, the authors apply :math:`n`
     different patch lengths in parallel to create patched features :math:`{x}_p^{(i)}`, where :math:`i`
     ranges from 1 to :math:`n`. Each patch length represents a different granularity. These patched
-    features are linearly transformed into :math:`{x}_e^{(i)}` and augmented into :math
-    c) The final patch embedding :math:`{x}^{(i)}` fuses augmented :math
-    positional embedding :math:`{W}_{
+    features are linearly transformed into :math:`{x}_e^{(i)}` and augmented into :math:`\\widetilde{x}_e^{(i)}`.
+    c) The final patch embedding :math:`{x}^{(i)}` fuses augmented :math:`\\widetilde{{x}}_e^{(i)}` with the
+    positional embedding :math:`{W}_{\\text{pos}}` and the granularity embedding :math:`{W}_{\\text{gr}}^{(i)}`.
     Each granularity employs a router :math:`{u}^{(i)}` to capture aggregated information.
     Intra-granularity attention focuses within individual granularities, and inter-granularity attention
     leverages the routers to integrate information across granularities.
@@ -116,7 +115,6 @@ class MEDFormer(EEGModuleMixin, nn.Module):
     **Role.** Learns representations and correlations within and across temporal scales while
     reducing complexity from :math:`O((\sum_i N_i)^2)` to
     :math:`O(\sum_i N_i^2 + n^2)` through the router mechanism.
-
     .. rubric:: Temporal, Spatial, and Spectral Encoding
 
     - **Temporal:** Multiple patch lengths in :attr:`patch_len_list` capture features at several
@@ -130,7 +128,7 @@ class MEDFormer(EEGModuleMixin, nn.Module):
     .. rubric:: Additional Mechanisms
 
     - **Granularity router:** Each granularity :math:`i` receives a dedicated router token
-      :math
+      :math:`\\mathbf{u}^{(i)}`. Intra-attention updates the token, and inter-attention exchanges
       aggregated information across scales.
     - **Complexity:** Router-mediated two-stage attention maintains :math:`O(T^2)` complexity for
       suitable patch lengths (e.g., power series), preserving transformer-like efficiency while
@@ -141,7 +139,7 @@ class MEDFormer(EEGModuleMixin, nn.Module):
    patch_len_list : list of int, optional
        Patch lengths for multi-granularity patching; each entry selects a temporal scale.
        The default is ``[14, 44, 45]``.
-
+    d_model : int, optional
        Embedding dimensionality. The default is ``128``.
    num_heads : int, optional
        Number of attention heads, which must divide :attr:`d_model`. The default is ``8``.
@@ -149,7 +147,7 @@ class MEDFormer(EEGModuleMixin, nn.Module):
        Dropout probability. The default is ``0.1``.
    no_inter_attn : bool, optional
        If ``True``, disables inter-granularity attention. The default is ``False``.
-
+    n_layers : int, optional
        Number of encoder layers. The default is ``6``.
    dim_feedforward : int, optional
        Feedforward dimensionality. The default is ``256``.
@@ -191,16 +189,16 @@ class MEDFormer(EEGModuleMixin, nn.Module):
         sfreq=None,
         # Model parameters
         patch_len_list: Optional[List[int]] = None,
-
+        d_model: int = 128,
         num_heads: int = 8,
         drop_prob: float = 0.1,
         no_inter_attn: bool = False,
-
+        n_layers: int = 6,
         dim_feedforward: int = 256,
-        activation_trans:
+        activation_trans: Optional[nn.Module] = nn.ReLU,
         single_channel: bool = False,
         output_attention: bool = True,
-        activation_class:
+        activation_class: Optional[nn.Module] = nn.GELU,
     ):
         super().__init__(
             n_outputs=n_outputs,
@@ -217,11 +215,11 @@ class MEDFormer(EEGModuleMixin, nn.Module):
         # - enc_in refers to the number of time points
 
         # Save model parameters as instance variables
-        self.
+        self.d_model = d_model
         self.num_heads = num_heads
         self.drop_prob = drop_prob
         self.no_inter_attn = no_inter_attn
-        self.
+        self.n_layers = n_layers
         self.dim_feedforward = dim_feedforward
         self.activation_trans = activation_trans
         self.output_attention = output_attention
@@ -244,7 +242,7 @@ class MEDFormer(EEGModuleMixin, nn.Module):
         # Initialize the embedding layer.
         self.enc_embedding = _ListPatchEmbedding(
             enc_in=self.n_times,
-            d_model=self.
+            d_model=self.d_model,
             seq_len=self.n_chans,
             patch_len_list=self.patch_len_list,
             stride_list=self.stride_list,
@@ -259,22 +257,22 @@ class MEDFormer(EEGModuleMixin, nn.Module):
                 _EncoderLayer(
                     attention=_MedformerLayer(
                         num_blocks=len(self.patch_len_list),
-                        d_model=self.
+                        d_model=self.d_model,
                         num_heads=self.num_heads,
                         dropout=self.drop_prob,
                         output_attention=self.output_attention,
                         no_inter=self.no_inter_attn,
                     ),
-                    d_model=self.
+                    d_model=self.d_model,
                     dim_feedforward=self.dim_feedforward,
                     dropout=self.drop_prob,
                     activation=self.activation_trans()
                     if self.activation_trans is not None
                     else nn.ReLU(),
                 )
-                for _ in range(self.
+                for _ in range(self.n_layers)
             ],
-            norm_layer=torch.nn.LayerNorm(self.
+            norm_layer=torch.nn.LayerNorm(self.d_model),
         )
 
         # For classification tasks, add additional layers.
@@ -283,7 +281,7 @@ class MEDFormer(EEGModuleMixin, nn.Module):
         )
         self.dropout = nn.Dropout(self.drop_prob)
         self.final_layer = nn.Linear(
-            self.
+            self.d_model
             * len(self.patch_num_list)
             * (1 if not self.single_channel else self.n_chans),
             self.n_outputs,
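Here the renames are d_model and n_layers, plus typed activation_trans/activation_class defaults, threaded through the embedding layer, the encoder stack, and the final linear head. A constructor sketch with the new names at their documented defaults; the channel count and trial length are illustrative assumptions:

    import torch
    from torch import nn

    from braindecode.models import MEDFormer

    # Renamed arguments (d_model, n_layers) with the defaults documented
    # in the diff; patch_len_list defaults to [14, 44, 45]. The input
    # dimensions are illustrative assumptions.
    model = MEDFormer(
        n_outputs=2,
        n_chans=19,
        n_times=1000,
        patch_len_list=[14, 44, 45],
        d_model=128,                # must be divisible by num_heads
        num_heads=8,
        n_layers=6,
        dim_feedforward=256,
        activation_trans=nn.ReLU,
        activation_class=nn.GELU,
    )

    x = torch.randn(2, 19, 1000)  # (batch, n_chans, n_times)
    with torch.no_grad():
        out = model(x)            # expected shape: (2, 2)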
braindecode/models/msvtnet.py
CHANGED

@@ -11,9 +11,9 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class MSVTNet(EEGModuleMixin, nn.Module):
-
+    """MSVTNet model from Liu K et al (2024) from [msvt2024]_.
 
-    :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Attention
+    :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Small Attention`
 
     This model implements a multi-scale convolutional transformer network
     for EEG signal classification, as described in [msvt2024]_.
@@ -41,9 +41,9 @@ class MSVTNet(EEGModuleMixin, nn.Module):
        Dropout probability for convolutional layers, by default 0.3.
    num_heads : int, optional
        Number of attention heads in the transformer encoder, by default 8.
-
+    feedforward_ratio : float, optional
        Ratio to compute feedforward dimension in the transformer, by default 1.
-
+    drop_prob_trans : float, optional
        Dropout probability for the transformer, by default 0.5.
    num_layers : int, optional
        Number of transformer encoder layers, by default 2.
@@ -85,8 +85,8 @@ class MSVTNet(EEGModuleMixin, nn.Module):
         pool2_size: int = 7,
         drop_prob: float = 0.3,
         num_heads: int = 8,
-
-
+        feedforward_ratio: float = 1,
+        drop_prob_trans: float = 0.5,
         num_layers: int = 2,
         activation: Type[nn.Module] = nn.ELU,
         return_features: bool = False,
@@ -139,8 +139,8 @@ class MSVTNet(EEGModuleMixin, nn.Module):
             seq_len,
             d_model,
             num_heads,
-
-
+            feedforward_ratio,
+            drop_prob_trans,
             num_layers,
         )
 
@@ -193,7 +193,7 @@ class MSVTNet(EEGModuleMixin, nn.Module):
 
 
 class _TSConv(nn.Sequential):
-
+    """
     Time-Distributed Separable Convolution block.
 
     The architecture consists of:
@@ -280,7 +280,7 @@ class _TSConv(nn.Sequential):
 
 
 class _PositionalEncoding(nn.Module):
-
+    """
     Positional encoding module that adds learnable positional embeddings.
 
     Parameters
@@ -303,7 +303,7 @@ class _PositionalEncoding(nn.Module):
 
 
 class _Transformer(nn.Module):
-
+    """
     Transformer encoder module with learnable class token and positional encoding.
 
     Parameters
@@ -314,7 +314,7 @@ class _Transformer(nn.Module):
        Dimensionality of the model.
    num_heads : int
        Number of heads in the multihead attention.
-
+    feedforward_ratio : float
        Ratio to compute the dimension of the feedforward network.
    drop_prob : float, optional
        Dropout probability, by default 0.5.
@@ -327,7 +327,7 @@ class _Transformer(nn.Module):
         seq_length: int,
         d_model: int,
         num_heads: int,
-
+        feedforward_ratio: float,
         drop_prob: float = 0.5,
         num_layers: int = 4,
     ) -> None:
@@ -335,7 +335,7 @@ class _Transformer(nn.Module):
         self.cls_embedding = nn.Parameter(torch.zeros(1, 1, d_model))
         self.pos_embedding = _PositionalEncoding(seq_length + 1, d_model)
 
-        dim_ff = int(d_model *
+        dim_ff = int(d_model * feedforward_ratio)
         self.dropout = nn.Dropout(drop_prob)
         self.trans = nn.TransformerEncoder(
             nn.TransformerEncoderLayer(
@@ -359,7 +359,7 @@ class _Transformer(nn.Module):
 
 
 class _DenseLayers(nn.Sequential):
-
+    """
     Final classification layers.
 
     Parameters