braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,133 @@
1
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
2
+ # Hubert Banville <hubert.jbanville@gmail.com>
3
+ #
4
+ # License: BSD (3-clause)
5
+ import inspect
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ from scipy.special import log_softmax
12
+ from sklearn.utils import deprecated
13
+
14
+ import braindecode.models as models
15
+
16
models_dict = {}

# Registry of braindecode model classes keyed by their attribute name in
# ``braindecode.models``. It is filled by ``_init_models_dict`` with every
# exported class that inherits from ``EEGModuleMixin`` (the mixin itself is
# left out).


def _init_models_dict():
    """Populate ``models_dict`` with the EEG model classes of ``braindecode.models``.

    Every class exposed at the top level of ``braindecode.models`` that is a
    subclass of ``models.base.EEGModuleMixin`` -- excluding the mixin class
    itself -- is stored in ``models_dict`` under the name it is exported as.
    Existing entries with the same name are overwritten. Returns ``None``.
    """
    mixin = models.base.EEGModuleMixin
    models_dict.update(
        {
            name: cls
            for name, cls in inspect.getmembers(models, inspect.isclass)
            if issubclass(cls, mixin) and cls != mixin
        }
    )
30
+
31
+
32
################################################################
# Test cases for models
#
# This list should be updated whenever a new model is added to
# braindecode (otherwise `test_completeness__models_test_cases`
# will fail).
# Each element in the list should be a tuple with structure
# (model_name, required_params, signal_params), such that:
#
# model_name: str
#     The name of the class of the model to be tested.
# required_params: list[str]
#     The signal-related parameters that are needed to initialize
#     the model.
# signal_params: dict | None
#     The characteristics of the signal that should be passed to
#     the model tested in case the default_signal_params are not
#     compatible with this model.
#     The keys of this dictionary can only be among those of
#     default_signal_params.
################################################################
models_mandatory_parameters = [
    ("ATCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("BDTCN", ["n_chans", "n_outputs"], None),
    ("Deep4Net", ["n_chans", "n_outputs", "n_times"], None),
    ("DeepSleepNet", ["n_outputs"], None),
    ("EEGConformer", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGInceptionERP", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("EEGInceptionMI", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("EEGITNet", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGNetv1", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGNetv4", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGResNet", ["n_chans", "n_outputs", "n_times"], None),
    ("ShallowFBCSPNet", ["n_chans", "n_outputs", "n_times"], None),
    (
        "SleepStagerBlanco2020",
        ["n_chans", "n_outputs", "n_times"],
        dict(n_chans=4),  # chosen so that n_chans is divisible by n_groups=2
    ),
    ("SleepStagerChambon2018", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    (
        "SleepStagerEldele2021",
        ["n_outputs", "n_times", "sfreq"],
        # single-channel signal ("C1")
        dict(sfreq=100.0, n_times=3000, chs_info=[dict(ch_name="C1", kind="eeg")]),
    ),
    ("TIDNet", ["n_chans", "n_outputs", "n_times"], None),
    ("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=128.0)),
    ("BIOT", ["n_chans", "n_outputs", "sfreq"], None),
    ("AttentionBaseNet", ["n_chans", "n_outputs", "n_times"], None),
    ("Labram", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGSimpleConv", ["n_chans", "n_outputs", "sfreq"], None),
    ("SPARCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("ContraWR", ["n_chans", "n_outputs", "sfreq"], dict(sfreq=200.0)),
    ("EEGNeX", ["n_chans", "n_outputs", "n_times"], None),
    ("TSceptionV1", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("EEGTCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("SyncNet", ["n_chans", "n_outputs", "n_times"], None),
    ("MSVTNet", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGMiner", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("CTNet", ["n_chans", "n_outputs", "n_times"], None),
    ("SincShallowNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=250.0)),
    ("SCCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("SignalJEPA", ["chs_info"], None),
    ("SignalJEPA_Contextual", ["chs_info", "n_times", "n_outputs"], None),
    ("SignalJEPA_PostLocal", ["n_chans", "n_times", "n_outputs"], None),
    ("SignalJEPA_PreLocal", ["n_chans", "n_times", "n_outputs"], None),
    ("FBCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("FBLightConvNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
]
103
+
104
################################################################
# List of models that are not meant for classification
#
# Their output shape may differ from the expected output shape
# for classification models.
################################################################
non_classification_models = [
    "SignalJEPA",
]
113
+
114
+
115
+ ################################################################
116
def get_summary_table(dir_name=None):
    """Load the model summary table from a ``summary.csv`` file.

    Parameters
    ----------
    dir_name : str | os.PathLike | None
        Directory containing ``summary.csv``. If None (default), the
        directory of this module is used.

    Returns
    -------
    pandas.DataFrame
        The CSV contents, using the first row as header and indexed by the
        ``Model`` column. Whitespace following each delimiter is skipped.
    """
    # ``Path`` accepts both strings and existing Path objects, so no
    # isinstance check is needed before converting.
    dir_path = Path(__file__).parent if dir_name is None else Path(dir_name)
    return pd.read_csv(
        dir_path / "summary.csv",
        header=0,
        index_col="Model",
        skipinitialspace=True,
    )
131
+
132
+
133
# Loaded once, at import time, from the ``summary.csv`` stored in this
# module's directory (the default location used by ``get_summary_table``).
_summary_table = get_summary_table(dir_name=None)
@@ -0,0 +1,38 @@
1
+ from .activation import LogActivation, SafeLog
2
+ from .attention import (
3
+ CAT,
4
+ CBAM,
5
+ ECA,
6
+ FCA,
7
+ GCT,
8
+ SRM,
9
+ CATLite,
10
+ EncNet,
11
+ GatherExcite,
12
+ GSoP,
13
+ MultiHeadAttention,
14
+ SqueezeAndExcitation,
15
+ )
16
+ from .blocks import MLP, FeedForwardBlock, InceptionBlock
17
+ from .convolution import (
18
+ AvgPool2dWithConv,
19
+ CausalConv1d,
20
+ CombinedConv,
21
+ Conv2dWithConstraint,
22
+ DepthwiseConv2d,
23
+ )
24
+ from .filter import FilterBankLayer, GeneralizedGaussianFilter
25
+ from .layers import Chomp1d, DropPath, Ensure4d, SqueezeFinalOutput, TimeDistributed
26
+ from .linear import LinearWithConstraint, MaxNormLinear
27
+ from .parametrization import MaxNorm, MaxNormParametrize
28
+ from .stats import (
29
+ LogPowerLayer,
30
+ LogVarLayer,
31
+ MaxLayer,
32
+ MeanLayer,
33
+ StatLayer,
34
+ StdLayer,
35
+ VarLayer,
36
+ )
37
+ from .util import aggregate_probas
38
+ from .wrapper import Expression, IntermediateOutputWrapper
@@ -0,0 +1,60 @@
1
+ import torch
2
+ from torch import Tensor, nn
3
+
4
+ import braindecode.functional as F
5
+
6
+
7
class SafeLog(nn.Module):
    r"""
    Safe logarithm activation function module.

    :math:`\text{SafeLog}(x) = \log\left(\max(x, \epsilon)\right)`

    Parameters
    ----------
    epsilon : float, optional
        A small value to clamp the input tensor to prevent computing log(0)
        or log of negative numbers. Default is 1e-6.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the SafeLog module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor after applying safe logarithm.
        """
        return F.safe_log(x=x, eps=self.epsilon)

    def extra_repr(self) -> str:
        # Shown as ``eps`` to match the keyword passed to ``F.safe_log``.
        return f"eps={self.epsilon}"
44
+
45
+
46
class LogActivation(nn.Module):
    """Element-wise logarithm of the input shifted by a constant offset.

    Computes ``log(x + epsilon)``. The offset keeps the output finite for
    zero-valued inputs; it does not clamp, so inputs below ``-epsilon``
    still produce NaN.
    """

    def __init__(self, epsilon: float = 1e-6, *args, **kwargs):
        """
        Parameters
        ----------
        epsilon : float
            Constant added to the input before taking the logarithm.
        *args, **kwargs
            Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shift first so that exact zeros map to log(epsilon) instead of -inf.
        shifted = x + self.epsilon
        return shifted.log()