braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
# Hubert Banville <hubert.jbanville@gmail.com>
|
|
3
|
+
#
|
|
4
|
+
# License: BSD (3-clause)
|
|
5
|
+
import inspect
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import torch
|
|
11
|
+
from scipy.special import log_softmax
|
|
12
|
+
from sklearn.utils import deprecated
|
|
13
|
+
|
|
14
|
+
import braindecode.models as models
|
|
15
|
+
|
|
16
|
+
# Registry mapping the exported name of each model class to the class itself.
# It starts empty and is filled in by ``_init_models_dict``.
models_dict = {}


def _init_models_dict():
    """Register every model class exported by ``braindecode.models``.

    Walks the classes exposed as attributes of ``braindecode.models`` and adds
    to ``models_dict`` each one that subclasses ``models.base.EEGModuleMixin``,
    excluding the mixin itself. Entries are keyed by the attribute name under
    which the class is exposed. The module-level dict is updated in place and
    nothing is returned.
    """
    mixin = models.base.EEGModuleMixin
    for attr_name, candidate in inspect.getmembers(models, inspect.isclass):
        # The mixin is a base class, not a model, so it is never registered.
        if candidate == mixin:
            continue
        if issubclass(candidate, mixin):
            models_dict[attr_name] = candidate
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
################################################################
# Test cases for models
#
# This list should be updated whenever a new model is added to
# braindecode (otherwise `test_completeness__models_test_cases`
# will fail).
# Each element in the list is a tuple with structure
# (model_name, required_params, signal_params), where:
#
# model_name: str
#     The name of the model class to be tested.
# required_params: list[str]
#     The signal-related parameters that must be passed to
#     initialize the model.
# signal_params: dict | None
#     Signal characteristics to pass to the tested model when
#     default_signal_params are not compatible with it.
#     The keys of this dictionary can only be among those of
#     default_signal_params.
################################################################
models_mandatory_parameters = [
    ("ATCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("BDTCN", ["n_chans", "n_outputs"], None),
    ("Deep4Net", ["n_chans", "n_outputs", "n_times"], None),
    ("DeepSleepNet", ["n_outputs"], None),
    ("EEGConformer", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGInceptionERP", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("EEGInceptionMI", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("EEGITNet", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGNetv1", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGNetv4", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGResNet", ["n_chans", "n_outputs", "n_times"], None),
    ("ShallowFBCSPNet", ["n_chans", "n_outputs", "n_times"], None),
    (
        "SleepStagerBlanco2020",
        ["n_chans", "n_outputs", "n_times"],
        dict(n_chans=4),  # n_chans divisible by n_groups=2
    ),
    ("SleepStagerChambon2018", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    (
        "SleepStagerEldele2021",
        ["n_outputs", "n_times", "sfreq"],
        dict(sfreq=100.0, n_times=3000, chs_info=[dict(ch_name="C1", kind="eeg")]),
    ),  # 1 channel (chs_info has a single entry)
    ("TIDNet", ["n_chans", "n_outputs", "n_times"], None),
    ("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=128.0)),
    ("BIOT", ["n_chans", "n_outputs", "sfreq"], None),
    ("AttentionBaseNet", ["n_chans", "n_outputs", "n_times"], None),
    ("Labram", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGSimpleConv", ["n_chans", "n_outputs", "sfreq"], None),
    ("SPARCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("ContraWR", ["n_chans", "n_outputs", "sfreq"], dict(sfreq=200.0)),
    ("EEGNeX", ["n_chans", "n_outputs", "n_times"], None),
    ("TSceptionV1", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("EEGTCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("SyncNet", ["n_chans", "n_outputs", "n_times"], None),
    ("MSVTNet", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGMiner", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("CTNet", ["n_chans", "n_outputs", "n_times"], None),
    ("SincShallowNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=250.0)),
    ("SCCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("SignalJEPA", ["chs_info"], None),
    ("SignalJEPA_Contextual", ["chs_info", "n_times", "n_outputs"], None),
    ("SignalJEPA_PostLocal", ["n_chans", "n_times", "n_outputs"], None),
    ("SignalJEPA_PreLocal", ["n_chans", "n_times", "n_outputs"], None),
    ("FBCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("FBLightConvNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
]
|
|
103
|
+
|
|
104
|
+
################################################################
# Models that are not meant for classification
#
# Their output shape may differ from the output shape expected
# of classification models.
################################################################
non_classification_models = ["SignalJEPA"]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
################################################################
|
|
116
|
+
def get_summary_table(dir_name=None):
    """Load the model summary table from ``summary.csv``.

    Parameters
    ----------
    dir_name : str | os.PathLike | None
        Directory containing ``summary.csv``. When ``None``, the directory
        holding this module is used.

    Returns
    -------
    pandas.DataFrame
        The CSV contents, with the first row used as the header and the
        ``Model`` column used as the index. Spaces that directly follow a
        delimiter are skipped while parsing.
    """
    # Path() accepts both strings and existing Path objects, so no type
    # check is needed before wrapping.
    base_dir = Path(__file__).parent if dir_name is None else Path(dir_name)
    return pd.read_csv(
        base_dir / "summary.csv",
        header=0,
        index_col="Model",
        skipinitialspace=True,
    )
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# Summary table read from the ``summary.csv`` that sits next to this module;
# evaluated once, when the module is imported.
_summary_table = get_summary_table(dir_name=None)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from .activation import LogActivation, SafeLog
|
|
2
|
+
from .attention import (
|
|
3
|
+
CAT,
|
|
4
|
+
CBAM,
|
|
5
|
+
ECA,
|
|
6
|
+
FCA,
|
|
7
|
+
GCT,
|
|
8
|
+
SRM,
|
|
9
|
+
CATLite,
|
|
10
|
+
EncNet,
|
|
11
|
+
GatherExcite,
|
|
12
|
+
GSoP,
|
|
13
|
+
MultiHeadAttention,
|
|
14
|
+
SqueezeAndExcitation,
|
|
15
|
+
)
|
|
16
|
+
from .blocks import MLP, FeedForwardBlock, InceptionBlock
|
|
17
|
+
from .convolution import (
|
|
18
|
+
AvgPool2dWithConv,
|
|
19
|
+
CausalConv1d,
|
|
20
|
+
CombinedConv,
|
|
21
|
+
Conv2dWithConstraint,
|
|
22
|
+
DepthwiseConv2d,
|
|
23
|
+
)
|
|
24
|
+
from .filter import FilterBankLayer, GeneralizedGaussianFilter
|
|
25
|
+
from .layers import Chomp1d, DropPath, Ensure4d, SqueezeFinalOutput, TimeDistributed
|
|
26
|
+
from .linear import LinearWithConstraint, MaxNormLinear
|
|
27
|
+
from .parametrization import MaxNorm, MaxNormParametrize
|
|
28
|
+
from .stats import (
|
|
29
|
+
LogPowerLayer,
|
|
30
|
+
LogVarLayer,
|
|
31
|
+
MaxLayer,
|
|
32
|
+
MeanLayer,
|
|
33
|
+
StatLayer,
|
|
34
|
+
StdLayer,
|
|
35
|
+
VarLayer,
|
|
36
|
+
)
|
|
37
|
+
from .util import aggregate_probas
|
|
38
|
+
from .wrapper import Expression, IntermediateOutputWrapper
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor, nn
|
|
3
|
+
|
|
4
|
+
import braindecode.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SafeLog(nn.Module):
    r"""
    Safe logarithm activation function module.

    .. math::

        \text{SafeLog}(x) = \log\left(\max(x, \epsilon)\right)

    Parameters
    ----------
    epsilon : float, optional
        A small value to clamp the input tensor to prevent computing log(0)
        or log of negative numbers. Default is 1e-6.

    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the SafeLog module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor after applying safe logarithm.
        """
        return F.safe_log(x=x, eps=self.epsilon)

    def extra_repr(self) -> str:
        # The repr uses the name ``eps`` (the keyword passed to
        # ``F.safe_log``), although the constructor argument and attribute
        # are named ``epsilon``.
        eps_str = f"eps={self.epsilon}"
        return eps_str
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class LogActivation(nn.Module):
    """Logarithm activation function.

    Computes ``log(x + epsilon)`` element-wise. The offset keeps ``log(0)``
    finite; unlike a clamp, inputs below ``-epsilon`` still produce NaN.

    Parameters
    ----------
    epsilon : float
        Small constant added to the input before taking the logarithm.
        Default is 1e-6.
    *args, **kwargs
        Forwarded unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, epsilon: float = 1e-6, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shift the input so that zeros map to log(epsilon) instead of -inf.
        shifted = x + self.epsilon
        return torch.log(shifted)
|