braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/patchedtransformer.py
CHANGED

@@ -14,9 +14,9 @@ from braindecode.models.base import EEGModuleMixin


 class PBT(EEGModuleMixin, nn.Module):
-    r"""Patched Brain Transformer (PBT) model from Klein et al (2025) [pbt]_.
+    r"""Patched Brain Transformer (PBT) model from Klein et al. (2025) [pbt]_.

-    :bdg-danger:`
+    :bdg-danger:`Large Brain Model`

     This implementation was based in https://github.com/timonkl/PatchedBrainTransformer/

@@ -136,7 +136,7 @@ class PBT(EEGModuleMixin, nn.Module):

     .. warning::

-        **Important:** As the other
+        **Important:** As the other Large Brain Models in Braindecode, :class:`PBT` is
         designed for large-scale pre-training and fine-tuning. Training from
         scratch on small datasets may lead to suboptimal results. Cross-Dataset
         pre-training and subsequent fine-tuning is recommended to leverage the
@@ -146,9 +146,9 @@ class PBT(EEGModuleMixin, nn.Module):
     ----------
     d_input : int, optional
         Size (in samples) of each patch (token) extracted along the time axis.
-
+    d_model : int, optional
         Transformer embedding dimensionality.
-
+    n_blocks : int, optional
         Number of Transformer encoder layers.
     num_heads : int, optional
         Number of attention heads.
@@ -190,13 +190,13 @@ class PBT(EEGModuleMixin, nn.Module):
         sfreq=None,
         # Model parameters
         d_input: int = 64,
-
-
+        d_model: int = 128,
+        n_blocks: int = 4,
         num_heads: int = 4,
         drop_prob: float = 0.1,
         learnable_cls=True,
         bias_transformer=False,
-        activation:
+        activation: nn.Module = nn.GELU,
     ) -> None:
         super().__init__(
             n_outputs=n_outputs,
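With the completed signature above, a minimal instantiation sketch; the public import path and the EEG dimensions used here are assumptions for illustration, not taken from this diff:

```python
# Hedged sketch of the new hyperparameter names; import path, shapes, and
# values are illustrative assumptions.
import torch
from braindecode.models import PBT

model = PBT(
    n_outputs=4, n_chans=22, n_times=1000, sfreq=250,
    d_input=64,    # samples per time patch (token)
    d_model=128,   # transformer embedding size
    n_blocks=4,    # number of encoder layers
    num_heads=4,
)
x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
logits = model(x)             # expected: (8, n_outputs)
```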
@@ -209,8 +209,8 @@ class PBT(EEGModuleMixin, nn.Module):
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
         # Store hyperparameters
         self.d_input = d_input
-        self.
-        self.
+        self.d_model = d_model
+        self.n_blocks = n_blocks
         self.num_heads = num_heads
         self.drop_prob = drop_prob

@@ -219,11 +219,11 @@ class PBT(EEGModuleMixin, nn.Module):

         # Classification token (learnable or fixed zero)
         if learnable_cls:
-            self.cls_token = nn.Parameter(torch.randn(1, 1, self.
+            self.cls_token = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.002)
         else:
             # non-learnable zeroed tensor
             self.cls_token = torch.full(
-                size=(1, 1, self.
+                size=(1, 1, self.d_model),
                 fill_value=0,
                 requires_grad=False,
                 dtype=torch.float32,
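The branch above makes the CLS token either trainable or a frozen zero vector; the `* 0.002` scaling keeps the learnable token near zero at initialization, close to the fixed variant. In isolation (with `d_model` assumed):

```python
# Standalone sketch of the two CLS-token variants; d_model = 128 is assumed.
import torch
import torch.nn as nn

d_model = 128
learnable_cls = nn.Parameter(torch.randn(1, 1, d_model) * 0.002)  # updated by the optimizer
fixed_cls = torch.full(
    size=(1, 1, d_model), fill_value=0, requires_grad=False, dtype=torch.float32
)  # constant zero token, excluded from training
```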
@@ -234,20 +234,20 @@ class PBT(EEGModuleMixin, nn.Module):
             n_chans=self.n_chans, n_times=self.n_times, d_input=self.d_input
         )

-        # Linear patch projection from token raw-size ->
+        # Linear patch projection from token raw-size -> d_model
         self.patching_projection = nn.Linear(
-            in_features=self.d_input, out_features=self.
+            in_features=self.d_input, out_features=self.d_model, bias=False
         )

-        # actual embedding table mapping indices ->
+        # actual embedding table mapping indices -> d_model
         self.pos_embedding = nn.Embedding(
-            num_embeddings=self.num_embeddings + 1, embedding_dim=self.
+            num_embeddings=self.num_embeddings + 1, embedding_dim=self.d_model
         )

         # Transformer encoder stack
         self.transformer_encoder = _TransformerEncoder(
-
-
+            n_blocks=n_blocks,
+            d_model=self.d_model,
             n_head=num_heads,
             drop_prob=drop_prob,
             bias=bias_transformer,
@@ -256,7 +256,7 @@ class PBT(EEGModuleMixin, nn.Module):

         # classification head on classify token - CLS token
         self.final_layer = nn.Linear(
-            in_features=
+            in_features=d_model, out_features=self.n_outputs, bias=True
         )

         # initialize weights
@@ -305,7 +305,7 @@ class PBT(EEGModuleMixin, nn.Module):


 class _LayerNorm(nn.Module):
-
+    """Layer normalization with optional bias.

     Simple wrapper around :func:`torch.nn.functional.layer_norm` exposing a
     learnable scale and optional bias.
@@ -346,7 +346,7 @@ class _LayerNorm(nn.Module):


 class _MHSA(nn.Module):
-
+    """Multi-head self-attention (MHSA) block.

     Implements a standard multi-head attention mechanism with optional
     use of PyTorch's scaled_dot_product_attention (FlashAttention) when
@@ -354,7 +354,7 @@ class _MHSA(nn.Module):

     Parameters
     ----------
-
+    d_model : int
         Dimensionality of the model / embeddings.
     n_head : int
         Number of attention heads.
@@ -366,25 +366,25 @@ class _MHSA(nn.Module):

     def __init__(
         self,
-
+        d_model: int,
         n_head: int,
         bias: bool,
         drop_prob: float = 0.0,
     ) -> None:
         super().__init__()

-        assert
+        assert d_model % n_head == 0, "d_model must be divisible by n_head"

         # qkv and output projection
-        self.attn = nn.Linear(
-        self.proj = nn.Linear(
+        self.attn = nn.Linear(d_model, 3 * d_model, bias=bias)
+        self.proj = nn.Linear(d_model, d_model, bias=bias)

         # dropout modules
         self.attn_drop_prob = nn.Dropout(drop_prob)
         self.resid_drop_prob = nn.Dropout(drop_prob)

         self.n_head = n_head
-        self.
+        self.d_model = d_model
         self.drop_prob = drop_prob

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -393,7 +393,7 @@ class _MHSA(nn.Module):
         Parameters
         ----------
         x : torch.Tensor
-            Input tensor of shape (B, T, C) where C ==
+            Input tensor of shape (B, T, C) where C == d_model.

         Returns
         -------
@@ -404,7 +404,7 @@ class _MHSA(nn.Module):
         B, T, C = x.size()

         # project to q, k, v and reshape for multi-head attention
-        q, k, v = self.attn(x).split(self.
+        q, k, v = self.attn(x).split(self.d_model, dim=2)

         # (B, nh, T, hs)
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
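The fused projection above produces queries, keys, and values in one matmul before splitting them per head; a self-contained sketch of that split-and-reshape (batch, sequence length, and sizes assumed):

```python
# Self-contained sketch of the fused qkv pattern seen in _MHSA above;
# B, T, d_model, n_head are assumed example values.
import torch
import torch.nn as nn

B, T, d_model, n_head = 2, 10, 128, 4
attn = nn.Linear(d_model, 3 * d_model, bias=False)  # one matmul yields q, k, v
x = torch.randn(B, T, d_model)
q, k, v = attn(x).split(d_model, dim=2)             # each (B, T, d_model)
q = q.view(B, T, n_head, d_model // n_head).transpose(1, 2)  # (B, n_head, T, head_size)
```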
@@ -429,14 +429,14 @@ class _MHSA(nn.Module):


 class _FeedForward(nn.Module):
-
+    """Position-wise feed-forward network from Transformer.

     Implements the two-layer MLP with GELU activation and dropout used in
     Transformer architectures.

     Parameters
     ----------
-
+    d_model : int
         Input and output dimensionality.
     dim_feedforward : int, optional
         Hidden dimensionality of the feed-forward layer. If None, must be provided by caller.
@@ -448,20 +448,20 @@ class _FeedForward(nn.Module):

     def __init__(
         self,
-
+        d_model: int,
         dim_feedforward: Optional[int] = None,
         drop_prob: float = 0.0,
         bias: bool = False,
-        activation:
+        activation: nn.Module = nn.GELU,
     ) -> None:
         super().__init__()

         if dim_feedforward is None:
             raise ValueError("dim_feedforward must be provided")

-        self.proj_in = nn.Linear(
+        self.proj_in = nn.Linear(d_model, dim_feedforward, bias=bias)
         self.activation = activation()
-        self.proj = nn.Linear(dim_feedforward,
+        self.proj = nn.Linear(dim_feedforward, d_model, bias=bias)
         self.drop_prob = nn.Dropout(drop_prob)
         self.drop_prob1 = nn.Dropout(drop_prob)

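From the layers defined above, the forward pass presumably chains projection, activation, and dropout; a hedged reconstruction (the exact dropout placement is an assumption):

```python
# Hedged reconstruction of the MLP implied by the layers above; dropout
# placement is assumed, not shown in this diff.
import torch
import torch.nn as nn

d_model, dim_feedforward = 128, 512  # 512 = 4 * d_model, the Vaswani et al. default
ff = nn.Sequential(
    nn.Linear(d_model, dim_feedforward, bias=False),  # proj_in
    nn.GELU(),                                        # activation
    nn.Dropout(0.1),
    nn.Linear(dim_feedforward, d_model, bias=False),  # proj
    nn.Dropout(0.1),
)
out = ff(torch.randn(2, 10, d_model))  # shape preserved: (2, 10, d_model)
```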
@@ -476,7 +476,7 @@ class _FeedForward(nn.Module):


 class _TransformerEncoderLayer(nn.Module):
-
+    """Single Transformer encoder layer (pre-norm) combining MHSA and feed-forward.

     The block follows the pattern:
         x <- x + MHSA(_LayerNorm(x))
@@ -485,27 +485,27 @@ class _TransformerEncoderLayer(nn.Module):

     def __init__(
         self,
-
+        d_model: int,
         n_head: int,
         drop_prob: float = 0.0,
         dim_feedforward: Optional[int] = None,
         bias: bool = False,
-        activation:
+        activation: nn.Module = nn.GELU,
     ) -> None:
         super().__init__()

         if dim_feedforward is None:
-            dim_feedforward = 4 *
+            dim_feedforward = 4 * d_model
             # note: preserve the original behaviour (print) from the provided code
             print(
-                "dim_feedforward is set to 4*
+                "dim_feedforward is set to 4*d_model, the default in Vaswani et al. (Attention is all you need)"
             )

-        self.layer_norm_att = _LayerNorm(
-        self.mhsa = _MHSA(
-        self.layer_norm_ff = _LayerNorm(
+        self.layer_norm_att = _LayerNorm(d_model, bias=bias)
+        self.mhsa = _MHSA(d_model, n_head, bias, drop_prob=drop_prob)
+        self.layer_norm_ff = _LayerNorm(d_model, bias=bias)
         self.feed_forward = _FeedForward(
-
+            d_model=d_model,
             dim_feedforward=dim_feedforward,
             drop_prob=drop_prob,
             bias=bias,
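The pre-norm residual pattern named in the docstring (`x <- x + MHSA(_LayerNorm(x))`, then the same for the feed-forward) looks like this in isolation; a sketch with stock PyTorch modules standing in for `_LayerNorm`/`_MHSA`/`_FeedForward`, not the module's verbatim forward:

```python
# Pre-norm residual pattern from the docstring above, sketched with plain
# PyTorch modules; sizes are assumed example values.
import torch
import torch.nn as nn

d_model = 128
norm1, norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
mhsa = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))

x = torch.randn(2, 10, d_model)
h = norm1(x)
x = x + mhsa(h, h, h, need_weights=False)[0]  # x <- x + MHSA(LayerNorm(x))
x = x + ff(norm2(x))                          # x <- x + FF(LayerNorm(x))
```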
@@ -518,7 +518,7 @@ class _TransformerEncoderLayer(nn.Module):
         Parameters
         ----------
         x : torch.Tensor
-            Input of shape (B, T,
+            Input of shape (B, T, d_model).

         Returns
         -------
@@ -531,13 +531,13 @@ class _TransformerEncoderLayer(nn.Module):


 class _TransformerEncoder(nn.Module):
-
+    """Stack of Transformer encoder layers.

     Parameters
     ----------
-
+    n_blocks : int
         Number of encoder layers to stack.
-
+    d_model : int
         Dimensionality of embeddings.
     n_head : int
         Number of attention heads per layer.
@@ -549,26 +549,26 @@ class _TransformerEncoder(nn.Module):

     def __init__(
         self,
-
-
+        n_blocks: int,
+        d_model: int,
         n_head: int,
         drop_prob: float,
         bias: bool,
-        activation:
+        activation: nn.Module = nn.GELU,
     ) -> None:
         super().__init__()

         self.encoder_block = nn.ModuleList(
             [
                 _TransformerEncoderLayer(
-
+                    d_model=d_model,
                     n_head=n_head,
                     drop_prob=drop_prob,
                     dim_feedforward=None,
                     bias=bias,
                     activation=activation,
                 )
-                for _ in range(
+                for _ in range(n_blocks)
             ]
         )

@@ -576,7 +576,7 @@ class _TransformerEncoder(nn.Module):
         self.apply(self._init_weights)
         for pn, p in self.named_parameters():
             if pn.endswith("proj.weight"):
-                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 *
+                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_blocks))

     @staticmethod
     def _init_weights(module: nn.Module) -> None:
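The `std=0.02 / math.sqrt(2 * n_blocks)` above appears to follow the GPT-2-style depth scaling of residual-branch output projections: with two residual additions per layer, shrinking each projection by √(2·n_blocks) keeps the variance of the residual stream roughly independent of depth. Numerically:

```python
# For PBT's default n_blocks = 4 shown earlier:
import math

std = 0.02 / math.sqrt(2 * 4)  # ≈ 0.00707, vs. 0.02 for ordinary weight init
```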
@@ -593,7 +593,7 @@ class _TransformerEncoder(nn.Module):


 class _Patcher(nn.Module):
-
+    """Patching encoding helper.

     This module "patchifies" the original X entry in a ViT manner.

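"Patchifies in a ViT manner" presumably means slicing each channel's time series into consecutive windows of `d_input` samples and treating each window as one token; a hedged sketch of that operation (the real `_Patcher` may order, pad, or index tokens differently):

```python
# Hedged sketch of ViT-style patching over the time axis; this is an assumed
# layout, not _Patcher's verbatim implementation.
import torch

B, C, T, d_input = 2, 22, 1024, 64
x = torch.randn(B, C, T)
n_patches = T // d_input
patches = x[..., : n_patches * d_input].reshape(B, C, n_patches, d_input)
tokens = patches.reshape(B, C * n_patches, d_input)  # one token per (channel, window)
```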
braindecode/models/sccnet.py
CHANGED

@@ -15,7 +15,7 @@ from braindecode.modules import LogActivation


 class SCCNet(EEGModuleMixin, nn.Module):
-
+    """SCCNet from Wei, C S (2019) [sccnet]_.

     :bdg-success:`Convolution`

@@ -155,7 +155,7 @@ class SCCNet(EEGModuleMixin, nn.Module):
         n_spatial_filters: int = 22,
         n_spatial_filters_smooth: int = 20,
         drop_prob: float = 0.5,
-        activation:
+        activation: nn.Module = LogActivation,
         batch_norm_momentum: float = 0.1,
     ):
         super().__init__(
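A pattern worth noting across these completed signatures: `activation` takes the activation class, not an instance (the modules instantiate it internally, as in `self.activation = activation()` seen in `_FeedForward` above). A hypothetical swap for SCCNet, assuming the usual public import path:

```python
# Hypothetical usage; constructor values are illustrative, and additional
# arguments such as sfreq may be required.
import torch.nn as nn
from braindecode.models import SCCNet

model = SCCNet(n_outputs=4, n_chans=22, n_times=1000)  # default: LogActivation
model_elu = SCCNet(n_outputs=4, n_chans=22, n_times=1000, activation=nn.ELU)  # pass the class
```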
braindecode/models/shallow_fbcsp.py
CHANGED

@@ -2,8 +2,6 @@
 #
 # License: BSD (3-clause)

-from typing import Callable
-
 from einops.layers.torch import Rearrange
 from torch import nn
 from torch.nn import init
@@ -20,7 +18,7 @@ from braindecode.modules import (


 class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
-
+    """Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

     :bdg-success:`Convolution`

@@ -83,9 +81,9 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
         pool_time_length=75,
         pool_time_stride=15,
         final_conv_length="auto",
-        conv_nonlin
+        conv_nonlin=square,
         pool_mode="mean",
-        activation_pool_nonlin:
+        activation_pool_nonlin: nn.Module = SafeLog,
         split_first_layer=True,
         batch_norm=True,
         batch_norm_alpha=0.1,
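The completed defaults `conv_nonlin=square` and `activation_pool_nonlin=SafeLog` encode the classic Shallow ConvNet band-power pipeline: square after the convolutions, mean-pool, then a clipped log. A minimal numeric sketch (the clamp floor stands in for braindecode's `SafeLog` epsilon, whose exact value is an assumption here):

```python
# Minimal sketch of the square -> mean-pool -> safe-log nonlinearity pair.
import torch

x = torch.randn(1, 40, 1000)  # (batch, filters, time)
p = torch.nn.functional.avg_pool1d(x**2, kernel_size=75, stride=15)  # band power
y = torch.log(torch.clamp(p, min=1e-6))  # "safe" log, avoiding log(0)
```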
braindecode/models/signal_jepa.py
CHANGED

@@ -24,7 +24,7 @@ _DEFAULT_CONV_LAYER_SPEC = ( # downsampling: 128Hz -> 1Hz, receptive field 1.18


 class _BaseSignalJEPA(EEGModuleMixin, nn.Module):
-
+    """Base class for the SignalJEPA models

     Parameters
     ----------
@@ -144,9 +144,9 @@ class _BaseSignalJEPA(EEGModuleMixin, nn.Module):


 class SignalJEPA(_BaseSignalJEPA):
-
+    """Architecture introduced in signal-JEPA for self-supervised pre-training, Guetschel, P et al (2024) [1]_

-    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`
+    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`

     This model is not meant for classification but for SSL pre-training.
     Its output shape depends on the input shape.
@@ -232,9 +232,9 @@ class SignalJEPA(_BaseSignalJEPA):


 class SignalJEPA_Contextual(_BaseSignalJEPA):
-
+    """Contextual downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.

-    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`
+    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`

     This architecture is one of the variants of :class:`SignalJEPA`
     that can be used for classification purposes.
@@ -405,9 +405,9 @@ class SignalJEPA_Contextual(_BaseSignalJEPA):


 class SignalJEPA_PostLocal(_BaseSignalJEPA):
-
+    """Post-local downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.

-    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`
+    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`

     This architecture is one of the variants of :class:`SignalJEPA`
     that can be used for classification purposes.
@@ -556,9 +556,9 @@ class SignalJEPA_PostLocal(_BaseSignalJEPA):


 class SignalJEPA_PreLocal(_BaseSignalJEPA):
-
+    """Pre-local downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.

-    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`
+    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`

     This architecture is one of the variants of :class:`SignalJEPA`
     that can be used for classification purposes.
@@ -569,33 +569,6 @@ class SignalJEPA_PreLocal(_BaseSignalJEPA):

     .. versionadded:: 0.9

-    .. important::
-        **Pre-trained Weights Available**
-
-        This model has pre-trained weights available on the Hugging Face Hub.
-        You can load them using:
-
-        .. code-block:: python
-
-            from braindecode.models import SignalJEPA_PreLocal
-
-            # Load pre-trained model from Hugging Face Hub
-            model = SignalJEPA_PreLocal.from_pretrained(
-                "braindecode/SignalJEPA-PreLocal-pretrained"
-            )
-
-        To push your own trained model to the Hub:
-
-        .. code-block:: python
-
-            # After training your model
-            model.push_to_hub(
-                repo_id="username/my-sjepa-model",
-                commit_message="Upload trained SignalJEPA model",
-            )
-
-        Requires installing ``braindecode[hug]`` for Hub integration.
-
     Parameters
     ----------
     n_spat_filters : int
@@ -745,7 +718,7 @@ class SignalJEPA_PreLocal(_BaseSignalJEPA):


 class _ConvFeatureEncoder(nn.Sequential):
-
+    """Convolutional feature encoder for EEG data.

     Computes successive 1D convolutions (with activations) over the time
     dimension of the input EEG signal.
@@ -865,7 +838,7 @@ class _ConvFeatureEncoder(nn.Sequential):


 class _ChannelEmbedding(nn.Embedding):
-
+    """Embedding layer for EEG channels.

     The difference with a regular :class:`nn.Embedding` is that the embedding
     vectors are initialized with a positional encodding of the channel locations.
@@ -926,7 +899,7 @@ class _ChannelEmbedding(nn.Embedding):


 class _PosEncoder(nn.Module):
-
+    """Positional encoder for EEG data.

     Parameters
     ----------
braindecode/models/sinc_shallow.py
CHANGED

@@ -1,4 +1,5 @@
 import math
+from typing import Optional

 import torch
 import torch.nn as nn
@@ -9,7 +10,7 @@ from braindecode.models.base import EEGModuleMixin


 class SincShallowNet(EEGModuleMixin, nn.Module):
-
+    """Sinc-ShallowNet from Borra, D et al (2020) [borra2020]_.

     :bdg-success:`Convolution` :bdg-warning:`Interpretability`

@@ -92,7 +93,7 @@ class SincShallowNet(EEGModuleMixin, nn.Module):
         num_time_filters: int = 32,
         time_filter_len: int = 33,
         depth_multiplier: int = 2,
-        activation:
+        activation: Optional[nn.Module] = nn.ELU,
         drop_prob: float = 0.5,
         first_freq: float = 5.0,
         min_freq: float = 1.0,
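The `Optional[nn.Module]` annotation (the reason for the added `typing.Optional` import) suggests `activation=None` is a legal value; how None is interpreted is not shown in this diff. A hedged helper illustrating the usual convention:

```python
# Hypothetical helper; the identity fallback for activation=None is an
# assumption, not confirmed by this diff.
from typing import Optional
import torch.nn as nn

def build_activation(activation: Optional[type] = nn.ELU) -> nn.Module:
    # the signatures above pass the class (e.g. nn.ELU), which is then
    # instantiated, matching `self.activation = activation()` seen earlier
    return activation() if activation is not None else nn.Identity()
```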
@@ -210,7 +211,7 @@ class SincShallowNet(EEGModuleMixin, nn.Module):


 class _SincFilter(nn.Module):
-
+    """Sinc-Based Convolutional Layer for Band-Pass Filtering from Ravanelli and Bengio (2018) [ravanelli]_.

     The `SincFilter` layer implements a convolutional layer where each kernel is
     defined using a parametrized sinc function.
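For reference, the parametrized band-pass kernel from Ravanelli and Bengio (2018) that this layer is named after is the difference of two low-pass sinc filters, g[n] = 2·f2·sinc(2π·f2·n) − 2·f1·sinc(2π·f1·n); a numpy sketch (cutoffs are illustrative, and the real layer learns f1/f2 and applies a window on top):

```python
# Sketch of the SincNet-style band-pass kernel; cutoff values are assumed,
# and braindecode's windowing details may differ.
import numpy as np

def sinc_bandpass(f1: float, f2: float, n: np.ndarray) -> np.ndarray:
    # np.sinc(x) = sin(pi*x)/(pi*x), so each term is an ideal low-pass filter
    return 2 * f2 * np.sinc(2 * f2 * n) - 2 * f1 * np.sinc(2 * f1 * n)

n = np.arange(-16, 17)                 # 33 taps, matching time_filter_len=33 above
kernel = sinc_bandpass(0.05, 0.15, n)  # normalized cutoff frequencies
```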
braindecode/models/sleep_stager_blanco_2020.py
CHANGED

@@ -9,7 +9,7 @@ from braindecode.models.base import EEGModuleMixin


 class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
-
+    """Sleep staging architecture from Blanco et al. (2020) from [Blanco2020]_

     :bdg-success:`Convolution`

@@ -68,7 +68,7 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
         drop_prob=0.5,
         apply_batch_norm=False,
         return_feats=False,
-        activation:
+        activation: nn.Module = nn.ReLU,
         chs_info=None,
         n_times=None,
     ):
braindecode/models/sleep_stager_chambon_2018.py
CHANGED

@@ -11,7 +11,7 @@ from braindecode.models.base import EEGModuleMixin


 class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
-
+    """Sleep staging architecture from Chambon et al. (2018) [Chambon2018]_.

     :bdg-success:`Convolution`

@@ -70,7 +70,7 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
         time_conv_size_s=0.5,
         max_pool_size_s=0.125,
         pad_size_s=0.25,
-        activation:
+        activation: nn.Module = nn.ReLU,
         input_window_seconds=None,
         n_outputs=5,
         drop_prob=0.25,
braindecode/models/sparcnet.py
CHANGED

@@ -11,7 +11,7 @@ from braindecode.models.base import EEGModuleMixin


 class SPARCNet(EEGModuleMixin, nn.Module):
-
+    """Seizures, Periodic and Rhythmic pattern Continuum Neural Network (SPaRCNet) from Jing et al. (2023) [jing2023]_.

     :bdg-success:`Convolution`

@@ -69,7 +69,7 @@ class SPARCNet(EEGModuleMixin, nn.Module):
         drop_prob: float = 0.5,
         conv_bias: bool = True,
         batch_norm: bool = True,
-        activation:
+        activation: nn.Module = nn.ELU,
         kernel_size_conv0: int = 7,
         kernel_size_conv1: int = 1,
         kernel_size_conv2: int = 3,
@@ -213,7 +213,7 @@ class SPARCNet(EEGModuleMixin, nn.Module):


 class _DenseLayer(nn.Sequential):
-
+    """
     A densely connected layer with batch normalization and dropout.

     Parameters
@@ -252,7 +252,7 @@ class _DenseLayer(nn.Sequential):
         drop_prob: float = 0.5,
         conv_bias: bool = True,
         batch_norm: bool = True,
-        activation:
+        activation: nn.Module = nn.ELU,
         kernel_size_conv1: int = 1,
         kernel_size_conv2: int = 3,
         stride_conv1: int = 1,
@@ -302,7 +302,7 @@ class _DenseLayer(nn.Sequential):


 class _DenseBlock(nn.Sequential):
-
+    """
     A densely connected block that uses DenseLayers.

     Parameters
@@ -344,7 +344,7 @@ class _DenseBlock(nn.Sequential):
         drop_prob=0.5,
         conv_bias=True,
         batch_norm=True,
-        activation:
+        activation: nn.Module = nn.ELU,
         kernel_size_conv1: int = 1,
         kernel_size_conv2: int = 3,
         stride_conv1: int = 1,
@@ -371,7 +371,7 @@ class _DenseBlock(nn.Sequential):


 class _TransitionLayer(nn.Sequential):
-
+    """
     A pooling transition layer.

     Parameters
@@ -403,7 +403,7 @@ class _TransitionLayer(nn.Sequential):
         out_channels,
         conv_bias=True,
         batch_norm=True,
-        activation:
+        activation: nn.Module = nn.ELU,
         kernel_size_trans: int = 2,
         stride_trans: int = 2,
     ):
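The `_DenseLayer`/`_DenseBlock` naming follows DenseNet-style connectivity, where each layer's output is concatenated onto its input so later layers see all earlier feature maps; a minimal sketch of that growth pattern (channel counts are illustrative, not SPaRCNet's actual configuration):

```python
# Minimal sketch of dense (concatenative) connectivity; sizes are assumed
# example values.
import torch
import torch.nn as nn

x = torch.randn(1, 16, 128)                        # (batch, channels, time)
layer = nn.Conv1d(16, 8, kernel_size=3, padding=1)
x = torch.cat([x, layer(x)], dim=1)                # channels grow: 16 -> 24
```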