braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/summary.csv
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Model,Application,Type,Sampling Frequency (Hz),Hyperparameters,#Parameters,get_#Parameters,Categorization
|
|
2
|
-
ATCNet,General,Classification,250,"n_chans, n_outputs, n_times",113732,"ATCNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Attention
|
|
3
|
-
AttentionBaseNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",3692,"AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention
|
|
2
|
+
ATCNet,General,Classification,250,"n_chans, n_outputs, n_times",113732,"ATCNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Small Attention"
|
|
3
|
+
AttentionBaseNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",3692,"AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
|
|
4
4
|
BDTCN,Normal Abnormal,Classification,100,"n_chans, n_outputs, n_times",456502,"BDTCN(n_chans=21, n_outputs=2, n_times=6000, n_blocks=5, n_filters=55, kernel_size=16)","Convolution,Recurrent"
|
|
5
|
-
BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)","
|
|
5
|
+
BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)","Large Brain Model"
|
|
6
6
|
ContraWR,Sleep Staging,"Classification, Embedding",125,"n_chans, n_outputs, sfreq",1160165,"ContraWR(n_chans=2, n_outputs=5, n_times=3750, emb_size=256, sfreq=125)",Convolution
|
|
7
|
-
CTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",26900,"CTNet(n_chans=22, n_outputs=4, n_times=1000, n_filters_time=8, kernel_size=16,
|
|
7
|
+
CTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",26900,"CTNet(n_chans=22, n_outputs=4, n_times=1000, n_filters_time=8, kernel_size=16, heads=2, emb_size=16)","Convolution,Small Attention"
|
|
8
8
|
Deep4Net,General,Classification,250,"n_chans, n_outputs, n_times",282879,"Deep4Net(n_chans=22, n_outputs=4, n_times=1000)","Convolution"
|
|
9
9
|
DeepSleepNet,Sleep Staging,Classification,256,"n_chans, n_outputs",24744837,"DeepSleepNet(n_chans=1, n_outputs=5, n_times=7680, sfreq=256)","Convolution,Recurrent"
|
|
10
|
-
EEGConformer,General,Classification,250,"n_chans, n_outputs, n_times",789572,"EEGConformer(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention
|
|
10
|
+
EEGConformer,General,Classification,250,"n_chans, n_outputs, n_times",789572,"EEGConformer(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
|
|
11
11
|
EEGInceptionERP,"ERP, SSVEP",Classification,128,"n_chans, n_outputs",14926,"EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128)","Convolution"
|
|
12
12
|
EEGInceptionMI,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",558028,"EEGInceptionMI(n_chans=22, n_outputs=4, n_times=1000, n_convs=5, n_filters=12)","Convolution"
|
|
13
13
|
EEGITNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",5212,"EEGITNet(n_chans=22, n_outputs=4, n_times=500)","Convolution,Recurrent"
|
|
@@ -17,18 +17,18 @@ EEGSym,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",299
|
|
|
17
17
|
EEGMiner,Emotion Recognition,Classification,128,"n_chans, n_outputs, n_times, sfreq",7572,"EEGMiner(n_chans=62, n_outputs=2, n_times=2560, sfreq=128)","Convolution,Interpretability"
|
|
18
18
|
EEGSimpleConv,Motor Imagery,Classification,80,"n_chans, n_outputs, sfreq",730404,"EEGSimpleConv(n_chans=22, n_outputs=4, n_times=320, sfreq=80)","Convolution"
|
|
19
19
|
EEGTCNet,Motor Imagery,Classification,250,"n_chans, n_outputs",4516,"EEGTCNet(n_chans=22, n_outputs=4, n_times=1000, kern_length=32)","Convolution,Recurrent"
|
|
20
|
-
Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,
|
|
21
|
-
MSVTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",75494," MSVTNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Attention
|
|
20
|
+
Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Large Brain Model"
|
|
21
|
+
MSVTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",75494," MSVTNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Small Attention"
|
|
22
22
|
SCCNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times, sfreq",12070,"SCCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=125)","Convolution"
|
|
23
|
-
SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,
|
|
24
|
-
SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,
|
|
25
|
-
SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,
|
|
26
|
-
SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,
|
|
23
|
+
SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Large Brain Model"
|
|
24
|
+
SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Large Brain Model"
|
|
25
|
+
SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Large Brain Model"
|
|
26
|
+
SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Large Brain Model"
|
|
27
27
|
SincShallowNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",21892,"SincShallowNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Interpretability"
|
|
28
28
|
ShallowFBCSPNet,General,Classification,250,"n_chans, n_outputs, n_times",46084,"ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution"
|
|
29
29
|
SleepStagerBlanco2020,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times",2845,"SleepStagerBlanco2020(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution"
|
|
30
30
|
SleepStagerChambon2018,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",5835,"SleepStagerChambon2018(n_chans=2, n_outputs=5, n_times=3840, sfreq=128)","Convolution"
|
|
31
|
-
AttnSleep,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times, sfreq",719925,"AttnSleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution, Attention
|
|
31
|
+
AttnSleep,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times, sfreq",719925,"AttnSleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution, Small Attention"
|
|
32
32
|
SPARCNet,Epilepsy,Classification,200,"n_chans, n_outputs, n_times",1141921,"SPARCNet(n_chans=16, n_outputs=6, n_times=2000, sfreq=200)","Convolution"
|
|
33
33
|
SyncNet,"Emotion Recognition, Alcoholism",Classification,256,"n_chans, n_outputs, n_times",554,"SyncNet(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)","Interpretability"
|
|
34
34
|
TSception,Emotion Recognition,Classification,256,"n_chans, n_outputs, n_times, sfreq",2187206,"TSception(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)","Convolution"
|
|
@@ -38,10 +38,8 @@ FBCNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",118
|
|
|
38
38
|
FBMSNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",16231,"FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
|
|
39
39
|
FBLightConvNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",6596,"FBLightConvNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
|
|
40
40
|
IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860,"IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
MEDFormer,General,Classification,250,"n_chans, n_outputs, n_times",5313924,"MEDFormer(n_chans=22, n_outputs=4, n_times=1000)","Foundation Model,Convolution"
|
|
47
|
-
REVE,General,Classification,200,"n_outputs, n_times, n_chans",69481476,"REVE(n_times=1000, n_outputs=4, n_chans=19)","Foundation Model,Attention/Transformer"
|
|
41
|
+
PBT,General,Classification,250,"n_chans, n_outputs, n_times",818948,"PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Large Brain Model"
|
|
42
|
+
SSTDPN,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",19502,"SSTDPN(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
|
|
43
|
+
BENDR,General,"Classification,Embedding",250,"n_chans, n_times, n_outputs",157141049,"BENDR(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
|
|
44
|
+
LUNA,General,"Classification,Embedding",128,"n_chans, n_times, sfreq, chs_info",7100731,"LUNA(n_chans=22, n_times=512, sfreq=128)","Convolution,Channel,Large Brain Model"
|
|
45
|
+
MEDFormer,General,Classification,250,"n_chans, n_outputs, n_times",5313924,"MEDFormer(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
|
braindecode/models/syncnet.py
CHANGED
|
@@ -8,7 +8,7 @@ from braindecode.models.base import EEGModuleMixin
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class SyncNet(EEGModuleMixin, nn.Module):
|
|
11
|
-
|
|
11
|
+
"""Synchronization Network (SyncNet) from Li, Y et al (2017) [Li2017]_.
|
|
12
12
|
|
|
13
13
|
:bdg-warning:`Interpretability`
|
|
14
14
|
|
|
@@ -89,7 +89,7 @@ class SyncNet(EEGModuleMixin, nn.Module):
|
|
|
89
89
|
num_filters=1,
|
|
90
90
|
filter_width=40,
|
|
91
91
|
pool_size=40,
|
|
92
|
-
activation:
|
|
92
|
+
activation: nn.Module = nn.ReLU,
|
|
93
93
|
ampli_init_values: tuple[float, float] = (-0.05, 0.05),
|
|
94
94
|
omega_init_values: tuple[float, float] = (0.0, 1.0),
|
|
95
95
|
beta_init_values: tuple[float, float] = (0.0, 0.05),
|
braindecode/models/tcn.py
CHANGED
|
@@ -12,7 +12,7 @@ from braindecode.modules import Chomp1d, Ensure4d, SqueezeFinalOutput
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class BDTCN(EEGModuleMixin, nn.Module):
|
|
15
|
-
|
|
15
|
+
"""Braindecode TCN from Gemein, L et al (2020) [gemein2020]_.
|
|
16
16
|
|
|
17
17
|
:bdg-success:`Convolution` :bdg-secondary:`Recurrent`
|
|
18
18
|
|
|
@@ -57,7 +57,7 @@ class BDTCN(EEGModuleMixin, nn.Module):
|
|
|
57
57
|
n_filters=30,
|
|
58
58
|
kernel_size=5,
|
|
59
59
|
drop_prob=0.5,
|
|
60
|
-
activation:
|
|
60
|
+
activation: nn.Module = nn.ReLU,
|
|
61
61
|
):
|
|
62
62
|
super().__init__(
|
|
63
63
|
n_outputs=n_outputs,
|
|
@@ -90,7 +90,7 @@ class BDTCN(EEGModuleMixin, nn.Module):
|
|
|
90
90
|
|
|
91
91
|
|
|
92
92
|
class TCN(nn.Module):
|
|
93
|
-
|
|
93
|
+
"""Temporal Convolutional Network (TCN) from Bai et al. 2018 [Bai2018]_.
|
|
94
94
|
|
|
95
95
|
See [Bai2018]_ for details.
|
|
96
96
|
|
|
@@ -126,7 +126,7 @@ class TCN(nn.Module):
|
|
|
126
126
|
n_filters=30,
|
|
127
127
|
kernel_size=5,
|
|
128
128
|
drop_prob=0.5,
|
|
129
|
-
activation:
|
|
129
|
+
activation: nn.Module = nn.ReLU,
|
|
130
130
|
):
|
|
131
131
|
super().__init__()
|
|
132
132
|
self.mapping = {
|
|
@@ -221,7 +221,7 @@ class _TemporalBlock(nn.Module):
|
|
|
221
221
|
dilation,
|
|
222
222
|
padding,
|
|
223
223
|
drop_prob,
|
|
224
|
-
activation:
|
|
224
|
+
activation: nn.Module = nn.ReLU,
|
|
225
225
|
):
|
|
226
226
|
super().__init__()
|
|
227
227
|
self.conv1 = weight_norm(
|
braindecode/models/tidnet.py
CHANGED
|
@@ -11,7 +11,7 @@ from braindecode.modules import Ensure4d
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class TIDNet(EEGModuleMixin, nn.Module):
|
|
14
|
-
|
|
14
|
+
"""Thinker Invariance DenseNet model from Kostas et al. (2020) [TIDNet]_.
|
|
15
15
|
|
|
16
16
|
:bdg-success:`Convolution`
|
|
17
17
|
|
|
@@ -85,7 +85,7 @@ class TIDNet(EEGModuleMixin, nn.Module):
|
|
|
85
85
|
temp_span: float = 0.05,
|
|
86
86
|
bottleneck: int = 3,
|
|
87
87
|
summary: int = -1,
|
|
88
|
-
activation:
|
|
88
|
+
activation: nn.Module = nn.LeakyReLU,
|
|
89
89
|
):
|
|
90
90
|
super().__init__(
|
|
91
91
|
n_outputs=n_outputs,
|
|
@@ -157,7 +157,7 @@ class _BatchNormZG(nn.BatchNorm2d):
|
|
|
157
157
|
|
|
158
158
|
|
|
159
159
|
class _ConvBlock2D(nn.Module):
|
|
160
|
-
|
|
160
|
+
"""Implements Convolution block with order:
|
|
161
161
|
Convolution, dropout, activation, batch-norm
|
|
162
162
|
"""
|
|
163
163
|
|
|
@@ -13,7 +13,7 @@ from braindecode.models.base import EEGModuleMixin
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class TSception(EEGModuleMixin, nn.Module):
|
|
16
|
-
|
|
16
|
+
"""TSception model from Ding et al. (2020) from [ding2020]_.
|
|
17
17
|
|
|
18
18
|
:bdg-success:`Convolution`
|
|
19
19
|
|
|
@@ -78,7 +78,7 @@ class TSception(EEGModuleMixin, nn.Module):
|
|
|
78
78
|
number_filter_spat: int = 6,
|
|
79
79
|
hidden_size: int = 128,
|
|
80
80
|
drop_prob: float = 0.5,
|
|
81
|
-
activation:
|
|
81
|
+
activation: nn.Module = nn.LeakyReLU,
|
|
82
82
|
pool_size: int = 8,
|
|
83
83
|
inception_windows: tuple[float, float, float] = (0.5, 0.25, 0.125),
|
|
84
84
|
):
|
|
@@ -290,6 +290,6 @@ class TSception(EEGModuleMixin, nn.Module):
|
|
|
290
290
|
"this alias will be removed in v1.14."
|
|
291
291
|
)
|
|
292
292
|
class TSceptionV1(TSception):
|
|
293
|
-
|
|
293
|
+
"""Deprecated alias for TSception."""
|
|
294
294
|
|
|
295
295
|
pass
|
braindecode/models/usleep.py
CHANGED
|
@@ -12,8 +12,8 @@ from braindecode.models.base import EEGModuleMixin
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class USleep(EEGModuleMixin, nn.Module):
|
|
15
|
-
|
|
16
|
-
Sleep staging architecture from Perslev et al (2021) [1]_.
|
|
15
|
+
"""
|
|
16
|
+
Sleep staging architecture from Perslev et al. (2021) [1]_.
|
|
17
17
|
|
|
18
18
|
:bdg-success:`Convolution`
|
|
19
19
|
|
|
@@ -182,7 +182,7 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
182
182
|
input_window_seconds=None,
|
|
183
183
|
time_conv_size_s=9 / 128,
|
|
184
184
|
ensure_odd_conv_size=False,
|
|
185
|
-
activation:
|
|
185
|
+
activation: nn.Module = nn.ELU,
|
|
186
186
|
chs_info=None,
|
|
187
187
|
n_times=None,
|
|
188
188
|
):
|
|
@@ -331,7 +331,7 @@ class USleep(EEGModuleMixin, nn.Module):
|
|
|
331
331
|
|
|
332
332
|
|
|
333
333
|
class _EncoderBlock(nn.Module):
|
|
334
|
-
|
|
334
|
+
"""Encoding block for a timeseries x of shape (B, C, T)."""
|
|
335
335
|
|
|
336
336
|
def __init__(
|
|
337
337
|
self,
|
|
@@ -339,7 +339,7 @@ class _EncoderBlock(nn.Module):
|
|
|
339
339
|
out_channels=2,
|
|
340
340
|
kernel_size=9,
|
|
341
341
|
downsample=2,
|
|
342
|
-
activation:
|
|
342
|
+
activation: nn.Module = nn.ELU,
|
|
343
343
|
):
|
|
344
344
|
super().__init__()
|
|
345
345
|
self.in_channels = in_channels
|
|
@@ -371,7 +371,7 @@ class _EncoderBlock(nn.Module):
|
|
|
371
371
|
|
|
372
372
|
|
|
373
373
|
class _DecoderBlock(nn.Module):
|
|
374
|
-
|
|
374
|
+
"""Decoding block for a timeseries x of shape (B, C, T)."""
|
|
375
375
|
|
|
376
376
|
def __init__(
|
|
377
377
|
self,
|
|
@@ -380,7 +380,7 @@ class _DecoderBlock(nn.Module):
|
|
|
380
380
|
kernel_size=9,
|
|
381
381
|
upsample=2,
|
|
382
382
|
with_skip_connection=True,
|
|
383
|
-
activation:
|
|
383
|
+
activation: nn.Module = nn.ELU,
|
|
384
384
|
):
|
|
385
385
|
super().__init__()
|
|
386
386
|
self.in_channels = in_channels
|
braindecode/models/util.py
CHANGED
|
@@ -3,9 +3,8 @@
|
|
|
3
3
|
#
|
|
4
4
|
# License: BSD (3-clause)
|
|
5
5
|
import inspect
|
|
6
|
-
from copy import deepcopy
|
|
7
6
|
from pathlib import Path
|
|
8
|
-
from typing import Any, Dict,
|
|
7
|
+
from typing import Any, Dict, Optional, Sequence
|
|
9
8
|
|
|
10
9
|
import numpy as np
|
|
11
10
|
import pandas as pd
|
|
@@ -30,16 +29,6 @@ def _init_models_dict():
|
|
|
30
29
|
models_dict[m[0]] = m[1]
|
|
31
30
|
|
|
32
31
|
|
|
33
|
-
SigArgName = Literal[
|
|
34
|
-
"n_outputs",
|
|
35
|
-
"n_chans",
|
|
36
|
-
"chs_info",
|
|
37
|
-
"n_times",
|
|
38
|
-
"input_window_seconds",
|
|
39
|
-
"sfreq",
|
|
40
|
-
]
|
|
41
|
-
|
|
42
|
-
|
|
43
32
|
################################################################
|
|
44
33
|
# Test cases for models
|
|
45
34
|
#
|
|
@@ -61,9 +50,7 @@ SigArgName = Literal[
|
|
|
61
50
|
# The keys of this dictionary can only be among those of
|
|
62
51
|
# default_signal_params.
|
|
63
52
|
################################################################
|
|
64
|
-
models_mandatory_parameters
|
|
65
|
-
tuple[str, list[SigArgName], dict[SigArgName, Any] | None]
|
|
66
|
-
] = [
|
|
53
|
+
models_mandatory_parameters = [
|
|
67
54
|
("ATCNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
68
55
|
("BDTCN", ["n_chans", "n_outputs"], None),
|
|
69
56
|
("Deep4Net", ["n_chans", "n_outputs", "n_times"], None),
|
|
@@ -77,60 +64,45 @@ models_mandatory_parameters: list[
|
|
|
77
64
|
(
|
|
78
65
|
"SleepStagerBlanco2020",
|
|
79
66
|
["n_chans", "n_outputs", "n_times"],
|
|
80
|
-
|
|
67
|
+
dict(n_chans=4), # n_chans dividable by n_groups=2
|
|
81
68
|
),
|
|
82
69
|
("SleepStagerChambon2018", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
|
|
83
70
|
(
|
|
84
71
|
"AttnSleep",
|
|
85
72
|
["n_outputs", "n_times", "sfreq"],
|
|
86
|
-
|
|
87
|
-
"sfreq": 100.0,
|
|
88
|
-
"n_times": 3000,
|
|
89
|
-
"chs_info": [{"ch_name": "C1", "kind": "eeg"}],
|
|
90
|
-
},
|
|
73
|
+
dict(sfreq=100.0, n_times=3000, chs_info=[dict(ch_name="C1", kind="eeg")]),
|
|
91
74
|
), # 1 channel
|
|
92
75
|
("TIDNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
93
|
-
("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"],
|
|
76
|
+
("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=128.0)),
|
|
94
77
|
("BIOT", ["n_chans", "n_outputs", "sfreq", "n_times"], None),
|
|
95
78
|
("AttentionBaseNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
96
79
|
("Labram", ["n_chans", "n_outputs", "n_times"], None),
|
|
97
80
|
("EEGSimpleConv", ["n_chans", "n_outputs", "sfreq"], None),
|
|
98
81
|
("SPARCNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
99
|
-
("ContraWR", ["n_chans", "n_outputs", "sfreq", "n_times"],
|
|
82
|
+
("ContraWR", ["n_chans", "n_outputs", "sfreq", "n_times"], dict(sfreq=200.0)),
|
|
100
83
|
("EEGNeX", ["n_chans", "n_outputs", "n_times"], None),
|
|
101
84
|
("EEGSym", ["chs_info", "n_chans", "n_outputs", "n_times", "sfreq"], None),
|
|
102
|
-
("TSception", ["n_chans", "n_outputs", "n_times", "sfreq"],
|
|
85
|
+
("TSception", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
|
|
103
86
|
("EEGTCNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
104
87
|
("SyncNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
105
88
|
("MSVTNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
106
|
-
("EEGMiner", ["n_chans", "n_outputs", "n_times", "sfreq"],
|
|
89
|
+
("EEGMiner", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
|
|
107
90
|
("CTNet", ["n_chans", "n_outputs", "n_times"], None),
|
|
108
|
-
("SincShallowNet", ["n_chans", "n_outputs", "n_times", "sfreq"],
|
|
109
|
-
("SCCNet", ["n_chans", "n_outputs", "n_times", "sfreq"],
|
|
91
|
+
("SincShallowNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=250.0)),
|
|
92
|
+
("SCCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
|
|
110
93
|
("SignalJEPA", ["chs_info"], None),
|
|
111
94
|
("SignalJEPA_Contextual", ["chs_info", "n_times", "n_outputs"], None),
|
|
112
95
|
("SignalJEPA_PostLocal", ["n_chans", "n_times", "n_outputs"], None),
|
|
113
96
|
("SignalJEPA_PreLocal", ["n_chans", "n_times", "n_outputs"], None),
|
|
114
|
-
("FBCNet", ["n_chans", "n_outputs", "n_times", "sfreq"],
|
|
115
|
-
("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"],
|
|
116
|
-
("FBLightConvNet", ["n_chans", "n_outputs", "n_times", "sfreq"],
|
|
117
|
-
("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"],
|
|
97
|
+
("FBCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
|
|
98
|
+
("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
|
|
99
|
+
("FBLightConvNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
|
|
100
|
+
("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
|
|
118
101
|
("PBT", ["n_chans", "n_outputs", "n_times"], None),
|
|
119
102
|
("SSTDPN", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
|
|
120
|
-
("BrainModule", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
|
|
121
103
|
("BENDR", ["n_chans", "n_outputs", "n_times"], None),
|
|
122
104
|
("LUNA", ["n_chans", "n_times", "n_outputs"], None),
|
|
123
105
|
("MEDFormer", ["n_chans", "n_outputs", "n_times"], None),
|
|
124
|
-
(
|
|
125
|
-
"REVE",
|
|
126
|
-
["n_times", "n_outputs", "n_chans", "chs_info"],
|
|
127
|
-
{
|
|
128
|
-
"sfreq": 200.0,
|
|
129
|
-
"n_chans": 19,
|
|
130
|
-
"n_times": 1_000,
|
|
131
|
-
"chs_info": [{"ch_name": f"E{i + 1}", "kind": "eeg"} for i in range(19)],
|
|
132
|
-
},
|
|
133
|
-
),
|
|
134
106
|
]
|
|
135
107
|
|
|
136
108
|
################################################################
|
|
@@ -143,129 +115,6 @@ non_classification_models = [
|
|
|
143
115
|
"SignalJEPA",
|
|
144
116
|
]
|
|
145
117
|
|
|
146
|
-
################################################################
|
|
147
|
-
|
|
148
|
-
rng = np.random.default_rng(12)
|
|
149
|
-
# Generating the channel info
|
|
150
|
-
chs_info = [
|
|
151
|
-
{
|
|
152
|
-
"ch_name": f"C{i}",
|
|
153
|
-
"kind": "eeg",
|
|
154
|
-
"loc": rng.random(12),
|
|
155
|
-
}
|
|
156
|
-
for i in range(1, 4)
|
|
157
|
-
]
|
|
158
|
-
default_signal_params: dict[SigArgName, Any] = {
|
|
159
|
-
"n_times": 1000,
|
|
160
|
-
"sfreq": 250.0,
|
|
161
|
-
"n_outputs": 2,
|
|
162
|
-
"chs_info": chs_info,
|
|
163
|
-
"n_chans": len(chs_info),
|
|
164
|
-
"input_window_seconds": 4.0,
|
|
165
|
-
}
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
def _get_signal_params(
|
|
169
|
-
signal_params: dict[SigArgName, Any] | None,
|
|
170
|
-
required_params: list[SigArgName] | None = None,
|
|
171
|
-
) -> dict[SigArgName, Any]:
|
|
172
|
-
"""Get signal parameters for model initialization in tests."""
|
|
173
|
-
sp = deepcopy(default_signal_params)
|
|
174
|
-
if signal_params is not None:
|
|
175
|
-
sp.update(signal_params)
|
|
176
|
-
if "chs_info" in signal_params and "n_chans" not in signal_params:
|
|
177
|
-
sp["n_chans"] = len(signal_params["chs_info"])
|
|
178
|
-
if "n_chans" in signal_params and "chs_info" not in signal_params:
|
|
179
|
-
sp["chs_info"] = [
|
|
180
|
-
{"ch_name": f"C{i}", "kind": "eeg", "loc": rng.random(12)}
|
|
181
|
-
for i in range(signal_params["n_chans"])
|
|
182
|
-
]
|
|
183
|
-
assert isinstance(sp["n_times"], int)
|
|
184
|
-
assert isinstance(sp["sfreq"], float)
|
|
185
|
-
assert isinstance(sp["input_window_seconds"], float)
|
|
186
|
-
if "input_window_seconds" not in signal_params:
|
|
187
|
-
sp["input_window_seconds"] = sp["n_times"] / sp["sfreq"]
|
|
188
|
-
if "sfreq" not in signal_params:
|
|
189
|
-
sp["sfreq"] = sp["n_times"] / sp["input_window_seconds"]
|
|
190
|
-
if "n_times" not in signal_params:
|
|
191
|
-
sp["n_times"] = int(sp["input_window_seconds"] * sp["sfreq"])
|
|
192
|
-
if required_params is not None:
|
|
193
|
-
sp = {
|
|
194
|
-
k: sp[k] for k in set((signal_params or {}).keys()).union(required_params)
|
|
195
|
-
}
|
|
196
|
-
return sp
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
def _get_possible_signal_params(
|
|
200
|
-
signal_params: dict[SigArgName, Any], required_params: list[SigArgName]
|
|
201
|
-
):
|
|
202
|
-
sp = signal_params
|
|
203
|
-
|
|
204
|
-
# List possible model kwargs:
|
|
205
|
-
output_kwargs = []
|
|
206
|
-
output_kwargs.append(dict(n_outputs=sp["n_outputs"]))
|
|
207
|
-
|
|
208
|
-
if "n_outputs" not in required_params:
|
|
209
|
-
output_kwargs.append(dict(n_outputs=None))
|
|
210
|
-
|
|
211
|
-
channel_kwargs = []
|
|
212
|
-
channel_kwargs.append(dict(chs_info=sp["chs_info"], n_chans=None))
|
|
213
|
-
if "chs_info" not in required_params:
|
|
214
|
-
channel_kwargs.append(dict(n_chans=sp["n_chans"], chs_info=None))
|
|
215
|
-
if "n_chans" not in required_params and "chs_info" not in required_params:
|
|
216
|
-
channel_kwargs.append(dict(n_chans=None, chs_info=None))
|
|
217
|
-
|
|
218
|
-
time_kwargs = []
|
|
219
|
-
time_kwargs.append(
|
|
220
|
-
dict(n_times=sp["n_times"], sfreq=sp["sfreq"], input_window_seconds=None)
|
|
221
|
-
)
|
|
222
|
-
time_kwargs.append(
|
|
223
|
-
dict(
|
|
224
|
-
n_times=None,
|
|
225
|
-
sfreq=sp["sfreq"],
|
|
226
|
-
input_window_seconds=sp["input_window_seconds"],
|
|
227
|
-
)
|
|
228
|
-
)
|
|
229
|
-
time_kwargs.append(
|
|
230
|
-
dict(
|
|
231
|
-
n_times=sp["n_times"],
|
|
232
|
-
sfreq=None,
|
|
233
|
-
input_window_seconds=sp["input_window_seconds"],
|
|
234
|
-
)
|
|
235
|
-
)
|
|
236
|
-
if "n_times" not in required_params and "sfreq" not in required_params:
|
|
237
|
-
time_kwargs.append(
|
|
238
|
-
dict(
|
|
239
|
-
n_times=None,
|
|
240
|
-
sfreq=None,
|
|
241
|
-
input_window_seconds=sp["input_window_seconds"],
|
|
242
|
-
)
|
|
243
|
-
)
|
|
244
|
-
if (
|
|
245
|
-
"n_times" not in required_params
|
|
246
|
-
and "input_window_seconds" not in required_params
|
|
247
|
-
):
|
|
248
|
-
time_kwargs.append(
|
|
249
|
-
dict(n_times=None, sfreq=sp["sfreq"], input_window_seconds=None)
|
|
250
|
-
)
|
|
251
|
-
if "sfreq" not in required_params and "input_window_seconds" not in required_params:
|
|
252
|
-
time_kwargs.append(
|
|
253
|
-
dict(n_times=sp["n_times"], sfreq=None, input_window_seconds=None)
|
|
254
|
-
)
|
|
255
|
-
if (
|
|
256
|
-
"n_times" not in required_params
|
|
257
|
-
and "sfreq" not in required_params
|
|
258
|
-
and "input_window_seconds" not in required_params
|
|
259
|
-
):
|
|
260
|
-
time_kwargs.append(dict(n_times=None, sfreq=None, input_window_seconds=None))
|
|
261
|
-
|
|
262
|
-
return [
|
|
263
|
-
dict(**o, **c, **t)
|
|
264
|
-
for o in output_kwargs
|
|
265
|
-
for c in channel_kwargs
|
|
266
|
-
for t in time_kwargs
|
|
267
|
-
]
|
|
268
|
-
|
|
269
118
|
|
|
270
119
|
################################################################
|
|
271
120
|
def get_summary_table(dir_name=None):
|
braindecode/modules/__init__.py
CHANGED
|
@@ -22,14 +22,7 @@ from .convolution import (
|
|
|
22
22
|
DepthwiseConv2d,
|
|
23
23
|
)
|
|
24
24
|
from .filter import FilterBankLayer, GeneralizedGaussianFilter
|
|
25
|
-
from .layers import
|
|
26
|
-
Chomp1d,
|
|
27
|
-
DropPath,
|
|
28
|
-
Ensure4d,
|
|
29
|
-
SqueezeFinalOutput,
|
|
30
|
-
SubjectLayers,
|
|
31
|
-
TimeDistributed,
|
|
32
|
-
)
|
|
25
|
+
from .layers import Chomp1d, DropPath, Ensure4d, SqueezeFinalOutput, TimeDistributed
|
|
33
26
|
from .linear import LinearWithConstraint, MaxNormLinear
|
|
34
27
|
from .parametrization import MaxNorm, MaxNormParametrize
|
|
35
28
|
from .stats import (
|
|
@@ -72,7 +65,6 @@ __all__ = [
|
|
|
72
65
|
"Chomp1d",
|
|
73
66
|
"DropPath",
|
|
74
67
|
"Ensure4d",
|
|
75
|
-
"SubjectLayers",
|
|
76
68
|
"SqueezeFinalOutput",
|
|
77
69
|
"TimeDistributed",
|
|
78
70
|
"LinearWithConstraint",
|
|
@@ -8,24 +8,14 @@ class SafeLog(nn.Module):
|
|
|
8
8
|
r"""
|
|
9
9
|
Safe logarithm activation function module.
|
|
10
10
|
|
|
11
|
-
:math
|
|
11
|
+
:math:`\text{SafeLog}(x) = \log\left(\max(x, \epsilon)\right)`
|
|
12
12
|
|
|
13
13
|
Parameters
|
|
14
14
|
----------
|
|
15
|
-
|
|
15
|
+
epsilon : float, optional
|
|
16
16
|
A small value to clamp the input tensor to prevent computing log(0) or log of negative numbers.
|
|
17
17
|
Default is 1e-6.
|
|
18
18
|
|
|
19
|
-
Examples
|
|
20
|
-
--------
|
|
21
|
-
>>> import torch
|
|
22
|
-
>>> from braindecode.modules import SafeLog
|
|
23
|
-
>>> module = SafeLog(epsilon=1e-6)
|
|
24
|
-
>>> inputs = torch.rand(2, 3)
|
|
25
|
-
>>> outputs = module(inputs)
|
|
26
|
-
>>> outputs.shape
|
|
27
|
-
torch.Size([2, 3])
|
|
28
|
-
|
|
29
19
|
"""
|
|
30
20
|
|
|
31
21
|
def __init__(self, epsilon: float = 1e-6):
|
|
@@ -54,23 +44,7 @@ class SafeLog(nn.Module):
|
|
|
54
44
|
|
|
55
45
|
|
|
56
46
|
class LogActivation(nn.Module):
|
|
57
|
-
"""Logarithm activation function.
|
|
58
|
-
|
|
59
|
-
Parameters
|
|
60
|
-
----------
|
|
61
|
-
epsilon : float, default=1e-6
|
|
62
|
-
Small float to adjust the activation.
|
|
63
|
-
|
|
64
|
-
Examples
|
|
65
|
-
--------
|
|
66
|
-
>>> import torch
|
|
67
|
-
>>> from braindecode.modules import LogActivation
|
|
68
|
-
>>> module = LogActivation(epsilon=1e-6)
|
|
69
|
-
>>> inputs = torch.rand(2, 3)
|
|
70
|
-
>>> outputs = module(inputs)
|
|
71
|
-
>>> outputs.shape
|
|
72
|
-
torch.Size([2, 3])
|
|
73
|
-
"""
|
|
47
|
+
"""Logarithm activation function."""
|
|
74
48
|
|
|
75
49
|
def __init__(self, epsilon: float = 1e-6, *args, **kwargs):
|
|
76
50
|
"""
|