braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -3,163 +3,131 @@
3
3
  #
4
4
  # License: BSD (3-clause)
5
5
  import inspect
6
+ from pathlib import Path
6
7
 
7
8
  import numpy as np
9
+ import pandas as pd
8
10
  import torch
9
11
  from scipy.special import log_softmax
10
12
  from sklearn.utils import deprecated
11
13
 
12
14
  import braindecode.models as models
13
15
 
16
+ models_dict = {}
14
17
 
15
- @deprecated(
16
- "will be removed in version 1.0. Use EEGModuleMixin.to_dense_prediction_model method directly "
17
- "on the model object."
18
- )
19
- def to_dense_prediction_model(model, axis=(2, 3)):
20
- """
21
- Transform a sequential model with strides to a model that outputs
22
- dense predictions by removing the strides and instead inserting dilations.
23
- Modifies model in-place.
24
-
25
- Parameters
26
- ----------
27
- model: torch.nn.Module
28
- Model which modules will be modified
29
- axis: int or (int,int)
30
- Axis to transform (in terms of intermediate output axes)
31
- can either be 2, 3, or (2,3).
32
-
33
- Notes
34
- -----
35
- Does not yet work correctly for average pooling.
36
- Prior to version 0.1.7, there had been a bug that could move strides
37
- backwards one layer.
38
-
39
- """
40
- if not hasattr(axis, "__len__"):
41
- axis = [axis]
42
- assert all([ax in [2, 3] for ax in axis]), "Only 2 and 3 allowed for axis"
43
- axis = np.array(axis) - 2
44
- stride_so_far = np.array([1, 1])
45
- for module in model.modules():
46
- if hasattr(module, "dilation"):
47
- assert module.dilation == 1 or (module.dilation == (1, 1)), (
48
- "Dilation should equal 1 before conversion, maybe the model is "
49
- "already converted?"
50
- )
51
- new_dilation = [1, 1]
52
- for ax in axis:
53
- new_dilation[ax] = int(stride_so_far[ax])
54
- module.dilation = tuple(new_dilation)
55
- if hasattr(module, "stride"):
56
- if not hasattr(module.stride, "__len__"):
57
- module.stride = (module.stride, module.stride)
58
- stride_so_far *= np.array(module.stride)
59
- new_stride = list(module.stride)
60
- for ax in axis:
61
- new_stride[ax] = 1
62
- module.stride = tuple(new_stride)
63
-
64
-
65
- @deprecated(
66
- "will be removed in version 1.0. Use EEGModuleMixin.get_output_shape method directly on the "
67
- "model object."
68
- )
69
- def get_output_shape(model, in_chans, input_window_samples):
70
- """Returns shape of neural network output for batch size equal 1.
71
-
72
- Returns
73
- -------
74
- output_shape: tuple
75
- shape of the network output for `batch_size==1` (1, ...)
76
- """
77
- with torch.no_grad():
78
- dummy_input = torch.ones(
79
- 1, in_chans, input_window_samples,
80
- dtype=next(model.parameters()).dtype,
81
- device=next(model.parameters()).device,
82
- )
83
- output_shape = model(dummy_input).shape
84
- return output_shape
85
-
86
-
87
- def _pad_shift_array(x, stride=1):
88
- """Zero-pad and shift rows of a 3D array.
89
-
90
- E.g., used to align predictions of corresponding windows in
91
- sequence-to-sequence models.
92
-
93
- Parameters
94
- ----------
95
- x : np.ndarray
96
- Array of shape (n_rows, n_classes, n_windows).
97
- stride : int
98
- Number of non-overlapping elements between two consecutive sequences.
99
-
100
- Returns
101
- -------
102
- np.ndarray :
103
- Array of shape (n_rows, n_classes, (n_rows - 1) * stride + n_windows)
104
- where each row is obtained by zero-padding the corresponding row in
105
- ``x`` before and after in the last dimension.
106
- """
107
- if x.ndim != 3:
108
- raise NotImplementedError(
109
- 'x must be of shape (n_rows, n_classes, n_windows), got '
110
- f'{x.shape}')
111
- x_padded = np.pad(x, ((0, 0), (0, 0), (0, (x.shape[0] - 1) * stride)))
112
- orig_strides = x_padded.strides
113
- new_strides = (orig_strides[0] - stride * orig_strides[2],
114
- orig_strides[1],
115
- orig_strides[2])
116
- return np.lib.stride_tricks.as_strided(x_padded, strides=new_strides)
117
-
18
+ # For each class exported by braindecode.models, check whether it
19
+ # inherits from EEGModuleMixin. If it does, add it to the
20
+ # models_dict registry.
118
21
 
119
- def aggregate_probas(logits, n_windows_stride=1):
120
- """Aggregate predicted probabilities with self-ensembling.
121
22
 
122
- Aggregate window-wise predicted probabilities obtained on overlapping
123
- sequences of windows using multiplicative voting as described in
124
- [Phan2018]_.
23
def _init_models_dict():
    """Register every EEG model class exported by ``braindecode.models``.

    Each class that subclasses ``models.base.EEGModuleMixin`` (other than the
    mixin itself) is stored in the module-level ``models_dict`` under its
    class name. The dictionary is updated in place.
    """
    members = inspect.getmembers(models, inspect.isclass)
    models_dict.update(
        (class_name, klass)
        for class_name, klass in members
        if issubclass(klass, models.base.EEGModuleMixin)
        and klass != models.base.EEGModuleMixin
    )
125
30
 
126
- Parameters
127
- ----------
128
- logits : np.ndarray
129
- Array of shape (n_sequences, n_classes, n_windows) containing the
130
- logits (i.e. the raw unnormalized scores for each class) for each
131
- window of each sequence.
132
- n_windows_stride : int
133
- Number of windows between two consecutive sequences. Default is 1
134
- (maximally overlapping sequences).
135
31
 
136
- Returns
137
- -------
138
- np.ndarray :
139
- Array of shape ((n_rows - 1) * stride + n_windows, n_classes)
140
- containing the aggregated predicted probabilities for each window
141
- contained in the input sequences.
32
+ ################################################################
33
+ # Test cases for models
34
+ #
35
+ # This list should be updated whenever a new model is added to
36
+ # braindecode (otherwise `test_completeness__models_test_cases`
37
+ # will fail).
38
+ # Each element in the list should be a tuple with structure
39
+ # (model_class, required_params, signal_params), such that:
40
+ #
41
+ # model_name: str
42
+ # The name of the class of the model to be tested.
43
+ # required_params: list[str]
44
+ # The signal-related parameters that are needed to initialize
45
+ # the model.
46
+ # signal_params: dict | None
47
+ # The characteristics of the signal that should be passed to
48
+ # the model tested in case the default_signal_params are not
49
+ # compatible with this model.
50
+ # The keys of this dictionary can only be among those of
51
+ # default_signal_params.
52
+ ################################################################
53
+ models_mandatory_parameters = [
54
+ ("ATCNet", ["n_chans", "n_outputs", "n_times"], None),
55
+ ("BDTCN", ["n_chans", "n_outputs"], None),
56
+ ("Deep4Net", ["n_chans", "n_outputs", "n_times"], None),
57
+ ("DeepSleepNet", ["n_outputs"], None),
58
+ ("EEGConformer", ["n_chans", "n_outputs", "n_times"], None),
59
+ ("EEGInceptionERP", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
60
+ ("EEGInceptionMI", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
61
+ ("EEGITNet", ["n_chans", "n_outputs", "n_times"], None),
62
+ ("EEGNetv1", ["n_chans", "n_outputs", "n_times"], None),
63
+ ("EEGNetv4", ["n_chans", "n_outputs", "n_times"], None),
64
+ ("EEGResNet", ["n_chans", "n_outputs", "n_times"], None),
65
+ ("ShallowFBCSPNet", ["n_chans", "n_outputs", "n_times"], None),
66
+ (
67
+ "SleepStagerBlanco2020",
68
+ ["n_chans", "n_outputs", "n_times"],
69
+ dict(n_chans=4), # n_chans dividable by n_groups=2
70
+ ),
71
+ ("SleepStagerChambon2018", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
72
+ (
73
+ "SleepStagerEldele2021",
74
+ ["n_outputs", "n_times", "sfreq"],
75
+ dict(sfreq=100.0, n_times=3000, chs_info=[dict(ch_name="C1", kind="eeg")]),
76
+ ), # 1 channel
77
+ ("TIDNet", ["n_chans", "n_outputs", "n_times"], None),
78
+ ("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=128.0)),
79
+ ("BIOT", ["n_chans", "n_outputs", "sfreq"], None),
80
+ ("AttentionBaseNet", ["n_chans", "n_outputs", "n_times"], None),
81
+ ("Labram", ["n_chans", "n_outputs", "n_times"], None),
82
+ ("EEGSimpleConv", ["n_chans", "n_outputs", "sfreq"], None),
83
+ ("SPARCNet", ["n_chans", "n_outputs", "n_times"], None),
84
+ ("ContraWR", ["n_chans", "n_outputs", "sfreq"], dict(sfreq=200.0)),
85
+ ("EEGNeX", ["n_chans", "n_outputs", "n_times"], None),
86
+ ("TSceptionV1", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
87
+ ("EEGTCNet", ["n_chans", "n_outputs", "n_times"], None),
88
+ ("SyncNet", ["n_chans", "n_outputs", "n_times"], None),
89
+ ("MSVTNet", ["n_chans", "n_outputs", "n_times"], None),
90
+ ("EEGMiner", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
91
+ ("CTNet", ["n_chans", "n_outputs", "n_times"], None),
92
+ ("SincShallowNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=250.0)),
93
+ ("SCCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
94
+ ("SignalJEPA", ["chs_info"], None),
95
+ ("SignalJEPA_Contextual", ["chs_info", "n_times", "n_outputs"], None),
96
+ ("SignalJEPA_PostLocal", ["n_chans", "n_times", "n_outputs"], None),
97
+ ("SignalJEPA_PreLocal", ["n_chans", "n_times", "n_outputs"], None),
98
+ ("FBCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
99
+ ("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
100
+ ("FBLightConvNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
101
+ ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
102
+ ]
103
+
104
+ ################################################################
105
+ # List of models that are not meant for classification
106
+ #
107
+ # Their output shape may differ from the expected output shape
108
+ # for classification models.
109
+ ################################################################
110
+ non_classification_models = [
111
+ "SignalJEPA",
112
+ ]
142
113
 
143
- References
144
- ----------
145
- .. [Phan2018] Phan, H., Andreotti, F., Cooray, N., Chén, O. Y., &
146
- De Vos, M. (2018). Joint classification and prediction CNN framework
147
- for automatic sleep stage classification. IEEE Transactions on
148
- Biomedical Engineering, 66(5), 1285-1296.
149
- """
150
- log_probas = log_softmax(logits, axis=1)
151
- return _pad_shift_array(log_probas, stride=n_windows_stride).sum(axis=0).T
152
114
 
115
+ ################################################################
116
def get_summary_table(dir_name=None):
    """Load the model summary table from ``summary.csv``.

    Parameters
    ----------
    dir_name : str | pathlib.Path | None
        Directory that contains ``summary.csv``. If ``None`` (default), the
        directory of this module is used.

    Returns
    -------
    pandas.DataFrame
        Contents of ``summary.csv``, using the first row as header and the
        ``Model`` column as index.
    """
    # ``Path`` accepts both strings and existing ``Path`` objects, so no
    # isinstance check is needed before wrapping ``dir_name``.
    dir_path = Path(__file__).parent if dir_name is None else Path(dir_name)
    path = dir_path / "summary.csv"

    # skipinitialspace strips the blanks that follow each delimiter, so
    # "a, b" is parsed as the fields "a" and "b".
    df = pd.read_csv(
        path,
        header=0,
        index_col="Model",
        skipinitialspace=True,
    )
    return df
159
131
 
160
132
 
161
- def _init_models_dict():
162
- for m in inspect.getmembers(models, inspect.isclass):
163
- if (issubclass(m[1], models.base.EEGModuleMixin)
164
- and m[1] != models.base.EEGModuleMixin):
165
- models_dict[m[0]] = m[1]
133
+ _summary_table = get_summary_table()
@@ -0,0 +1,38 @@
1
+ from .activation import LogActivation, SafeLog
2
+ from .attention import (
3
+ CAT,
4
+ CBAM,
5
+ ECA,
6
+ FCA,
7
+ GCT,
8
+ SRM,
9
+ CATLite,
10
+ EncNet,
11
+ GatherExcite,
12
+ GSoP,
13
+ MultiHeadAttention,
14
+ SqueezeAndExcitation,
15
+ )
16
+ from .blocks import MLP, FeedForwardBlock, InceptionBlock
17
+ from .convolution import (
18
+ AvgPool2dWithConv,
19
+ CausalConv1d,
20
+ CombinedConv,
21
+ Conv2dWithConstraint,
22
+ DepthwiseConv2d,
23
+ )
24
+ from .filter import FilterBankLayer, GeneralizedGaussianFilter
25
+ from .layers import Chomp1d, DropPath, Ensure4d, SqueezeFinalOutput, TimeDistributed
26
+ from .linear import LinearWithConstraint, MaxNormLinear
27
+ from .parametrization import MaxNorm, MaxNormParametrize
28
+ from .stats import (
29
+ LogPowerLayer,
30
+ LogVarLayer,
31
+ MaxLayer,
32
+ MeanLayer,
33
+ StatLayer,
34
+ StdLayer,
35
+ VarLayer,
36
+ )
37
+ from .util import aggregate_probas
38
+ from .wrapper import Expression, IntermediateOutputWrapper
@@ -0,0 +1,60 @@
1
+ import torch
2
+ from torch import Tensor, nn
3
+
4
+ import braindecode.functional as F
5
+
6
+
7
class SafeLog(nn.Module):
    r"""
    Safe logarithm activation function module.

    :math:`\text{SafeLog}(x) = \log\left(\max(x, \epsilon)\right)`

    Parameters
    ----------
    epsilon : float, optional
        A small value to clamp the input tensor to prevent computing log(0) or
        log of negative numbers. Default is 1e-6.

    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the SafeLog module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor after applying safe logarithm.
        """
        # Delegates to the functional implementation, which takes the
        # threshold under the keyword ``eps``.
        return F.safe_log(x=x, eps=self.epsilon)

    def extra_repr(self) -> str:
        # NOTE: the repr uses the short label ``eps`` rather than the
        # constructor argument name ``epsilon``.
        eps_str = f"eps={self.epsilon}"
        return eps_str
44
+
45
+
46
class LogActivation(nn.Module):
    """Logarithm activation function.

    Computes ``log(x + epsilon)`` element-wise. The offset keeps the output
    finite for zero inputs. Unlike a clamp, it does not protect against
    inputs below ``-epsilon``, which still yield NaN.

    Parameters
    ----------
    epsilon : float
        Small constant added to the input before taking the logarithm.
        Default is 1e-6.
    *args, **kwargs
        Forwarded unchanged to :class:`torch.nn.Module`.
    """

    def __init__(self, epsilon: float = 1e-6, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``log(x + epsilon)`` computed element-wise on ``x``."""
        return torch.log(x + self.epsilon)  # Adding epsilon to prevent log(0)

    def extra_repr(self) -> str:
        # Show the offset in the module repr, as SafeLog does for its
        # threshold.
        return f"epsilon={self.epsilon}"