braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff shows the content changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/attentionbasenet.py
CHANGED

@@ -24,15 +24,16 @@ from braindecode.modules.attention import (


 class AttentionBaseNet(EEGModuleMixin, nn.Module):
-
+"""AttentionBaseNet from Wimpff M et al. (2023) [Martin2023]_.

-:bdg-success:`Convolution` :bdg-info:`Attention
+:bdg-success:`Convolution` :bdg-info:`Small Attention`

 .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
 :align: center
 :alt: AttentionBaseNet Architecture
 :width: 640px

+
 .. rubric:: Architectural Overview

 AttentionBaseNet is a *convolution-first* network with a *channel-attention* stage.
@@ -49,6 +50,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 attention unit that *re-weights channels* (and optionally temporal positions) before
 classification.

+
 .. rubric:: Macro Components

 - :class:`_FeatureExtractor` **(Shallow conv stem → condensed feature map)**
@@ -90,6 +92,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 *Operations.* :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from
 ``(B, ch_dim·T₂)`` to classes.

+
 .. rubric:: Convolutional Details

 - **Temporal (where time-domain patterns are learned).**
@@ -108,6 +111,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 emerges from learned temporal kernels. When ``attention_mode="fca"``, a frequency
 channel attention (DCT-based) summarizes frequencies to drive channel weights.

+
 .. rubric:: Attention / Sequential Modules

 - **Type.** Channel attention chosen by ``attention_mode`` (SE, ECA, CBAM, CAT, GSoP,
@@ -120,6 +124,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 - **Role.** Re-weights channels (and optionally time) to highlight informative sources
 and suppress distractors, improving SNR ahead of the linear head.

+
 .. rubric:: Additional Mechanisms

 **Attention variants at a glance:**
@@ -158,6 +163,17 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 only after the stem learns stable filters. For small datasets, prefer simpler modes
 (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).

+Notes
+-----
+- Sequence length after each stage is computed internally; the final classifier expects
+a flattened ``ch_dim x T₂`` vector.
+- Attention operates on *channel* dimension by design; temporal gating exists only in
+specific variants (CBAM/CAT).
+- The paper and original code with more details about the methodological
+choices are available at the [Martin2023]_ and [MartinCode]_.
+
+.. versionadded:: 0.9
+
 Parameters
 ----------
 n_temporal_filters : int, optional
@@ -219,24 +235,13 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 kernel_size : int, default=9
 The kernel size used in certain types of attention mechanisms for convolution
 operations.
-activation
+activation: nn.Module, default=nn.ELU
 Activation function class to apply. Should be a PyTorch activation
 module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
 extra_params : bool, default=False
 Flag to indicate whether additional, custom parameters should be passed to
 the attention mechanism.

-Notes
------
-- Sequence length after each stage is computed internally; the final classifier expects
-a flattened ``ch_dim x T₂`` vector.
-- Attention operates on *channel* dimension by design; temporal gating exists only in
-specific variants (CBAM/CAT).
-- The paper and original code with more details about the methodological
-choices are available at the [Martin2023]_ and [MartinCode]_.
-
-.. versionadded:: 0.9
-
 References
 ----------
 .. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
@@ -272,7 +277,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 freq_idx: int = 0,
 n_codewords: int = 4,
 kernel_size: int = 9,
-activation:
+activation: nn.Module = nn.ELU,
 extra_params: bool = False,
 ):
 super(AttentionBaseNet, self).__init__()
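In the new version, the ``activation`` keyword of ``AttentionBaseNet`` is annotated as ``nn.Module`` with an ``nn.ELU`` default. A minimal usage sketch consistent with that signature; the ``n_chans``/``n_outputs``/``n_times`` keywords follow the usual ``EEGModuleMixin`` convention and the concrete values are illustrative assumptions, not taken from this diff:

    import torch
    from torch import nn
    from braindecode.models import AttentionBaseNet

    # Pass the activation *class* (not an instance); the model instantiates it internally.
    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000, activation=nn.ELU)
    out = model(torch.randn(8, 22, 1000))  # expected output shape: (8, 4)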
@@ -392,8 +397,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 pool_length: int,
 ) -> int:
 """
-Calculates the minimum n_times required for the model to work
-
+Calculates the minimum n_times required for the model to work
 with the given parameters.

 The calculation is based on reversing the pooling operations to
@@ -409,15 +413,15 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):


 class _FeatureExtractor(nn.Module):
-
-A module for feature extraction of the data with temporal and spatial
-
+"""
+A module for feature extraction of the data with temporal and spatial
 transformations.

 This module sequentially processes the input through a series of layers:
 rearrangement, temporal convolution, batch normalization, spatial convolution,
 another batch normalization, an ELU non-linearity, average pooling, and dropout.

+
 Parameters
 ----------
 n_chans : int
@@ -435,7 +439,7 @@ class _FeatureExtractor(nn.Module):
 The stride of the average pooling operation. Default is 15.
 drop_prob : float, optional
 The dropout rate for regularization. Default is 0.5.
-activation
+activation: nn.Module, default=nn.ELU
 Activation function class to apply. Should be a PyTorch activation
 module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
 """
@@ -449,7 +453,7 @@ class _FeatureExtractor(nn.Module):
 pool_length: int = 75,
 pool_stride: int = 15,
 drop_prob: float = 0.5,
-activation:
+activation: nn.Module = nn.ELU,
 ):
 super().__init__()

@@ -489,9 +493,8 @@ class _FeatureExtractor(nn.Module):


 class _ChannelAttentionBlock(nn.Module):
-
-A neural network module implementing channel-wise attention mechanisms to enhance
-
+"""
+A neural network module implementing channel-wise attention mechanisms to enhance
 feature representations by selectively emphasizing important channels and suppressing
 less useful ones. This block integrates convolutional layers, pooling, dropout, and
 an optional attention mechanism that can be customized based on the given mode.
@@ -545,7 +548,7 @@ class _ChannelAttentionBlock(nn.Module):
 extra_params : bool, default=False
 Flag to indicate whether additional, custom parameters should be passed to
 the attention mechanism.
-activation
+activation: nn.Module, default=nn.ELU
 Activation function class to apply. Should be a PyTorch activation
 module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

@@ -561,7 +564,7 @@ class _ChannelAttentionBlock(nn.Module):
 attention_block : torch.nn.Module or None
 The attention mechanism applied to the output of the convolutional layers,
 if `attention_mode` is not None. Otherwise, it's set to None.
-activation
+activation: nn.Module, default=nn.ELU
 Activation function class to apply. Should be a PyTorch activation
 module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

@@ -571,6 +574,7 @@ class _ChannelAttentionBlock(nn.Module):
 >>> x = torch.randn(1, 16, 64, 64) # Example input tensor
 >>> output = channel_attention_block(x)
 The output tensor then can be further processed or used as input to another block.
+
 """

 def __init__(
@@ -588,7 +592,7 @@ class _ChannelAttentionBlock(nn.Module):
 n_codewords: int = 4,
 kernel_size: int = 9,
 extra_params: bool = False,
-activation:
+activation: nn.Module = nn.ELU,
 ):
 super().__init__()
 self.conv = nn.Sequential(
@@ -648,31 +652,31 @@ def get_attention_block(

 Parameters
 ----------
-attention_mode
+attention_mode: str
 The type of attention mechanism to apply.
-ch_dim
+ch_dim: int
 The number of input channels to the block.
-reduction_rate
+reduction_rate: int
 The reduction rate used in the attention mechanism to reduce
 dimensionality and computational complexity.
 Used in all the methods, except for the
 encnet and eca.
-use_mlp
+use_mlp: bool
 Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used
 within the attention mechanism for further processing. Used in the ge
 and srm attention mechanism.
-seq_len
+seq_len: int
 The sequence length, used in certain types of attention mechanisms to
 process temporal dimensions. Used in the ge or fca attention mechanism.
-freq_idx
+freq_idx: int
 DCT index used in fca attention mechanism.
-n_codewords
+n_codewords: int
 The number of codewords (clusters) used in attention mechanisms
 that employ quantization or clustering strategies, encnet.
-kernel_size
+kernel_size: int
 The kernel size used in certain types of attention mechanisms for convolution
 operations, used in the cbam, eca, and cat attention mechanisms.
-extra_params
+extra_params: bool
 Parameter to pass additional parameters to the GatherExcite mechanism.

 Returns
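The ``get_attention_block`` docstring now carries explicit types for each parameter. A hedged sketch of a call consistent with that parameter list, assuming the remaining arguments keep their defaults; the import path is inferred from the file being diffed and is an assumption:

    from braindecode.models.attentionbasenet import get_attention_block

    # Squeeze-and-Excitation channel attention over 16 feature channels,
    # with a reduction rate of 4 inside the bottleneck.
    se_block = get_attention_block(attention_mode="se", ch_dim=16, reduction_rate=4)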
braindecode/models/attn_sleep.py
CHANGED
@@ -16,9 +16,9 @@ from braindecode.modules import CausalConv1d


 class AttnSleep(EEGModuleMixin, nn.Module):
-
+"""Sleep Staging Architecture from Eldele et al. (2021) [Eldele2021]_.

-:bdg-success:`Convolution` :bdg-info:`Attention
+:bdg-success:`Convolution` :bdg-info:`Small Attention`

 .. figure:: https://raw.githubusercontent.com/emadeldeen24/AttnSleep/refs/heads/main/imgs/AttnSleep.png
 :align: center
@@ -63,10 +63,10 @@ class AttnSleep(EEGModuleMixin, nn.Module):
 Alias for `n_outputs`.
 input_size_s : float
 Alias for `input_window_seconds`.
-activation
+activation: nn.Module, default=nn.ReLU
 Activation function class to apply. Should be a PyTorch activation
 module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
-activation_mrcnn
+activation_mrcnn: nn.Module, default=nn.ReLU
 Activation function class to apply in the Mask R-CNN layer.
 Should be a PyTorch activation module class like ``nn.ReLU`` or
 ``nn.GELU``. Default is ``nn.GELU``.
@@ -90,8 +90,8 @@ class AttnSleep(EEGModuleMixin, nn.Module):
 d_ff=120,
 n_attn_heads=5,
 drop_prob=0.1,
-activation_mrcnn:
-activation:
+activation_mrcnn: nn.Module = nn.GELU,
+activation: nn.Module = nn.ReLU,
 input_window_seconds=None,
 n_outputs=None,
 after_reduced_cnn_size=30,
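In the new version the two activation keywords of ``AttnSleep`` carry explicit defaults (``activation_mrcnn=nn.GELU``, ``activation=nn.ReLU``). A minimal construction sketch with those defaults spelled out; the single-channel, 100 Hz, 30 s input shape follows the class docstring, and the remaining values are illustrative assumptions:

    import torch
    from torch import nn
    from braindecode.models import AttnSleep

    model = AttnSleep(
        n_chans=1,
        n_outputs=5,                # five sleep stages
        sfreq=100.0,
        input_window_seconds=30.0,  # n_times inferred as 3000
        activation_mrcnn=nn.GELU,
        activation=nn.ReLU,
    )
    out = model(torch.randn(2, 1, 3000))  # expected output shape: (2, 5)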
@@ -175,7 +175,7 @@ class AttnSleep(EEGModuleMixin, nn.Module):

 Parameters
 ----------
-x
+x: torch.Tensor
 Batch of EEG windows of shape (batch_size, n_channels, n_times).
 """

@@ -230,7 +230,7 @@ class _SEBasicBlock(nn.Module):
 planes,
 stride=1,
 downsample=None,
-activation:
+activation: nn.Module = nn.ReLU,
 *,
 reduction=16,
 ):
@@ -278,8 +278,8 @@ class _MRCNN(nn.Module):
 self,
 after_reduced_cnn_size,
 kernel_size=7,
-activation:
-activation_se:
+activation: nn.Module = nn.GELU,
+activation_se: nn.Module = nn.ReLU,
 ):
 super(_MRCNN, self).__init__()
 drate = 0.5
@@ -325,7 +325,7 @@ class _MRCNN(nn.Module):
 )

 def _make_layer(
-self, block, planes, blocks, stride=1, activate:
+self, block, planes, blocks, stride=1, activate: nn.Module = nn.ReLU
 ): # makes residual SE block
 downsample = None
 if stride != 1 or self.inplanes != planes * block.expansion:
@@ -363,7 +363,7 @@ class _MRCNN(nn.Module):
 def _attention(
 query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor]:
-"""Implementation of Scaled dot product attention
+"""Implementation of Scaled dot product attention"""
 # d_k - dimension of the query and key vectors
 d_k = query.size(-1)
 scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
@@ -392,7 +392,7 @@ class _MultiHeadedAttention(nn.Module):
 self.dropout = nn.Dropout(p=dropout)

 def forward(self, query, key, value: torch.Tensor) -> torch.Tensor:
-"""Implements Multi-head attention
+"""Implements Multi-head attention"""
 nbatches = query.size(0)

 query = query.view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
@@ -423,7 +423,9 @@ class _MultiHeadedAttention(nn.Module):


 class _ResidualLayerNormAttn(nn.Module):
-
+"""
+A residual connection followed by a layer norm.
+"""

 def __init__(self, size, dropout, fn_attn):
 super().__init__()
@@ -462,9 +464,8 @@ class _ResidualLayerNormFF(nn.Module):


 class _TCE(nn.Module):
-
-Transformer Encoder
-
+"""
+Transformer Encoder
 It is a stack of n layers.
 """

@@ -482,9 +483,8 @@ class _TCE(nn.Module):


 class _EncoderLayer(nn.Module):
-
-An encoder layer
-
+"""
+An encoder layer
 Made up of self-attention and a feed forward layer.
 Each of these sublayers have residual and layer norm, implemented by _ResidualLayerNorm.
 """
@@ -515,7 +515,7 @@ class _EncoderLayer(nn.Module):
 )

 def forward(self, x_in: torch.Tensor) -> torch.Tensor:
-"""Transformer Encoder
+"""Transformer Encoder"""
 query = self.conv(x_in)
 # Encoder self-attention
 x = self.residual_self_attn(query, x_in, x_in)
@@ -524,11 +524,9 @@ class _EncoderLayer(nn.Module):


 class _PositionwiseFeedForward(nn.Module):
-
+"""Positionwise feed-forward network."""

-def __init__(
-self, d_model, d_ff, dropout=0.1, activation: type[nn.Module] = nn.ReLU
-):
+def __init__(self, d_model, d_ff, dropout=0.1, activation: nn.Module = nn.ReLU):
 super().__init__()
 self.w_1 = nn.Linear(d_model, d_ff)
 self.w_2 = nn.Linear(d_ff, d_model)
@@ -544,6 +542,6 @@ class _PositionwiseFeedForward(nn.Module):
 "`SleepStagerEldele2021` was renamed to `AttnSleep` in v1.12 to follow original author's name; this alias will be removed in v1.14."
 )
 class SleepStagerEldele2021(AttnSleep):
-
+"""Deprecated alias for SleepStagerEldele2021."""

 pass
braindecode/models/base.py
CHANGED
@@ -192,7 +192,7 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
 n_times is not None
 and input_window_seconds is not None
 and sfreq is not None
-and n_times !=
+and n_times != int(input_window_seconds * sfreq)
 ):
 raise ValueError(
 f"{n_times=} different from {input_window_seconds=} * {sfreq=}"
@@ -236,7 +236,7 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
 and self._input_window_seconds is not None
 and self._sfreq is not None
 ):
-return
+return int(self._input_window_seconds * self._sfreq)
 elif self._n_times is None:
 raise ValueError(
 "n_times could not be inferred. "
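In the new version, the consistency check compares ``n_times`` against ``int(input_window_seconds * sfreq)``, and the ``n_times`` property falls back to that same product when ``n_times`` is not given explicitly. A small sketch of that relationship, using an arbitrary model class for illustration (not part of this diff):

    from braindecode.models import ShallowFBCSPNet

    # n_times is omitted and inferred as int(4.0 * 250.0) == 1000.
    model = ShallowFBCSPNet(n_chans=22, n_outputs=4, input_window_seconds=4.0, sfreq=250.0)
    assert model.n_times == 1000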
@@ -284,7 +284,7 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)

 Returns
 -------
-output_shape
+output_shape: tuple[int, ...]
 shape of the network output for `batch_size==1` (1, ...)
 """
 with torch.inference_mode():
@@ -330,14 +330,13 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)

 def to_dense_prediction_model(self, axis: tuple[int, ...] | int = (2, 3)) -> None:
 """
-Transform a sequential model with strides to a model that outputs
-
+Transform a sequential model with strides to a model that outputs
 dense predictions by removing the strides and instead inserting dilations.
 Modifies model in-place.

 Parameters
 ----------
-axis
+axis: int or (int,int)
 Axis to transform (in terms of intermediate output axes)
 can either be 2, 3, or (2,3).

@@ -346,6 +345,7 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
 Does not yet work correctly for average pooling.
 Prior to version 0.1.7, there had been a bug that could move strides
 backwards one layer.
+
 """
 if not hasattr(axis, "__iter__"):
 axis = (axis,)
braindecode/models/bendr.py
CHANGED
@@ -8,15 +8,16 @@ from braindecode.models.base import EEGModuleMixin


 class BENDR(EEGModuleMixin, nn.Module):
-
+"""BENDR (BErt-inspired Neural Data Representations) from Kostas et al. (2021) [bendr]_.

-:bdg-success:`Convolution` :bdg-danger:`
+:bdg-success:`Convolution` :bdg-danger:`Large Brain Model`

 .. figure:: https://www.frontiersin.org/files/Articles/653659/fnhum-15-653659-HTML/image_m/fnhum-15-653659-g001.jpg
 :align: center
 :alt: BENDR Architecture
 :width: 1000px

+
 The **BENDR** architecture adapts techniques used for language modeling (LM) toward the
 development of encephalography modeling (EM) [bendr]_. It utilizes a self-supervised
 training objective to learn compressed representations of raw EEG signals [bendr]_. The
@@ -78,31 +79,6 @@ class BENDR(EEGModuleMixin, nn.Module):
 prepended to the BENDR sequence before input to the transformer, serving as the aggregate
 representation token [bendr]_.

-.. important::
-**Pre-trained Weights Available**
-
-This model has pre-trained weights available on the Hugging Face Hub.
-You can load them using:
-
-.. code-block:: python
-
-from braindecode.models import BENDR
-
-# Load pre-trained model from Hugging Face Hub
-# you can specify `n_outputs` for your downstream task
-model = BENDR.from_pretrained("braindecode/braindecode-bendr", n_outputs=2)
-
-To push your own trained model to the Hub:
-
-.. code-block:: python
-
-# After training your model
-model.push_to_hub(
-repo_id="username/my-bendr-model", commit_message="Upload trained BENDR model"
-)
-
-Requires installing ``braindecode[hug]`` for Hub integration.
-
 Notes
 -----
 * The full BENDR architecture contains a large number of parameters; configuration (1)
@@ -119,27 +95,6 @@ class BENDR(EEGModuleMixin, nn.Module):
 **self-supervised pre-training** on large, unlabeled EEG datasets (like TUEG) followed
 by subsequent fine-tuning on the specific downstream classification task [bendr]_.

-References
-----------
-.. [bendr] Kostas, D., Aroca-Ouellette, S., & Rudzicz, F. (2021).
-BENDR: Using transformers and a contrastive self-supervised learning task to learn from
-massive amounts of EEG data.
-Frontiers in Human Neuroscience, 15, 653659.
-https://doi.org/10.3389/fnhum.2021.653659
-.. [wav2vec2] Baevski, A., Zhou, Y., Mohamed, A., & Auli, M. (2020).
-wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.
-In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, & H. Lin (Eds),
-Advances in Neural Information Processing Systems (Vol. 33, pp. 12449-12460).
-https://dl.acm.org/doi/10.5555/3495724.3496768
-.. [tfixup] Huang, T. K., Liang, S., Jha, A., & Salakhutdinov, R. (2020).
-Improving Transformer Optimization Through Better Initialization.
-In International Conference on Machine Learning (pp. 4475-4483). PMLR.
-https://dl.acm.org/doi/10.5555/3524938.3525354
-.. [layerdrop] Fan, A., Grave, E., & Joulin, A. (2020).
-Reducing Transformer Depth on Demand with Structured Dropout.
-International Conference on Learning Representations.
-Retrieved from https://openreview.net/forum?id=SylO2yStDr
-
 Parameters
 ----------
 encoder_h : int, default=512
@@ -183,6 +138,27 @@ class BENDR(EEGModuleMixin, nn.Module):
 final_layer : bool, default=True
 If True, includes a final linear classification layer that maps from encoder_h to
 n_outputs. If False, the model outputs the contextualized features directly.
+
+References
+----------
+.. [bendr] Kostas, D., Aroca-Ouellette, S., & Rudzicz, F. (2021).
+BENDR: Using transformers and a contrastive self-supervised learning task to learn from
+massive amounts of EEG data.
+Frontiers in Human Neuroscience, 15, 653659.
+https://doi.org/10.3389/fnhum.2021.653659
+.. [wav2vec2] Baevski, A., Zhou, Y., Mohamed, A., & Auli, M. (2020).
+wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.
+In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, & H. Lin (Eds),
+Advances in Neural Information Processing Systems (Vol. 33, pp. 12449-12460).
+https://dl.acm.org/doi/10.5555/3495724.3496768
+.. [tfixup] Huang, T. K., Liang, S., Jha, A., & Salakhutdinov, R. (2020).
+Improving Transformer Optimization Through Better Initialization.
+In International Conference on Machine Learning (pp. 4475-4483). PMLR.
+https://dl.acm.org/doi/10.5555/3524938.3525354
+.. [layerdrop] Fan, A., Grave, E., & Joulin, A. (2020).
+Reducing Transformer Depth on Demand with Structured Dropout.
+International Conference on Learning Representations.
+Retrieved from https://openreview.net/forum?id=SylO2yStDr
 """

 def __init__(
@@ -200,7 +176,7 @@ class BENDR(EEGModuleMixin, nn.Module):
 projection_head=False, # Whether encoder should project back to input feature size (unused in original fine-tuning)
 drop_prob=0.1, # General dropout probability (paper: 0.15 for pretraining, 0.0 for fine-tuning)
 layer_drop=0.0, # Probability of dropping transformer layers during training (paper: 0.01 for pretraining)
-activation
+activation=nn.GELU, # Activation function
 # Transformer specific parameters
 transformer_layers=8,
 transformer_heads=8,
|
@@ -349,7 +325,7 @@ class _ConvEncoderBENDR(nn.Module):
|
|
|
349
325
|
|
|
350
326
|
|
|
351
327
|
class _BENDRContextualizer(nn.Module):
|
|
352
|
-
|
|
328
|
+
"""Transformer-based contextualizer for BENDR."""
|
|
353
329
|
|
|
354
330
|
def __init__(
|
|
355
331
|
self,
|