braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -6,13 +6,23 @@ from einops.layers.torch import Rearrange
|
|
|
6
6
|
from torch import nn
|
|
7
7
|
from torch.nn import init
|
|
8
8
|
|
|
9
|
-
from .
|
|
10
|
-
from .
|
|
11
|
-
from .modules import
|
|
9
|
+
from braindecode.functional import square
|
|
10
|
+
from braindecode.models.base import EEGModuleMixin
|
|
11
|
+
from braindecode.modules import (
|
|
12
|
+
CombinedConv,
|
|
13
|
+
Ensure4d,
|
|
14
|
+
Expression,
|
|
15
|
+
SafeLog,
|
|
16
|
+
SqueezeFinalOutput,
|
|
17
|
+
)
|
|
12
18
|
|
|
13
19
|
|
|
14
20
|
class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
|
|
15
|
-
"""Shallow ConvNet model from Schirrmeister et al 2017.
|
|
21
|
+
"""Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
|
|
22
|
+
|
|
23
|
+
.. figure:: https://onlinelibrary.wiley.com/cms/asset/221ea375-6701-40d3-ab3f-e411aad62d9e/hbm23730-fig-0002-m.jpg
|
|
24
|
+
:align: center
|
|
25
|
+
:alt: ShallowNet Architecture
|
|
16
26
|
|
|
17
27
|
Model described in [Schirrmeister2017]_.
|
|
18
28
|
|
|
@@ -35,7 +45,7 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
|
|
|
35
45
|
Non-linear function to be used after convolution layers.
|
|
36
46
|
pool_mode: str
|
|
37
47
|
Method to use on pooling layers. "max" or "mean".
|
|
38
|
-
|
|
48
|
+
activation_pool_nonlin: callable
|
|
39
49
|
Non-linear function to be used after pooling layers.
|
|
40
50
|
split_first_layer: bool
|
|
41
51
|
Split first layer into temporal and spatial layers (True) or just use temporal (False).
|
|
@@ -46,12 +56,6 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
|
|
|
46
56
|
Momentum for BatchNorm2d.
|
|
47
57
|
drop_prob: float
|
|
48
58
|
Dropout probability.
|
|
49
|
-
in_chans : int
|
|
50
|
-
Alias for `n_chans`.
|
|
51
|
-
n_classes: int
|
|
52
|
-
Alias for `n_outputs`.
|
|
53
|
-
input_window_samples: int | None
|
|
54
|
-
Alias for `n_times`.
|
|
55
59
|
|
|
56
60
|
References
|
|
57
61
|
----------
|
|
@@ -65,37 +69,27 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
|
|
|
65
69
|
"""
|
|
66
70
|
|
|
67
71
|
def __init__(
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
in_chans=None,
|
|
89
|
-
n_classes=None,
|
|
90
|
-
input_window_samples=None,
|
|
91
|
-
add_log_softmax=True,
|
|
72
|
+
self,
|
|
73
|
+
n_chans=None,
|
|
74
|
+
n_outputs=None,
|
|
75
|
+
n_times=None,
|
|
76
|
+
n_filters_time=40,
|
|
77
|
+
filter_time_length=25,
|
|
78
|
+
n_filters_spat=40,
|
|
79
|
+
pool_time_length=75,
|
|
80
|
+
pool_time_stride=15,
|
|
81
|
+
final_conv_length="auto",
|
|
82
|
+
conv_nonlin=square,
|
|
83
|
+
pool_mode="mean",
|
|
84
|
+
activation_pool_nonlin: nn.Module = SafeLog,
|
|
85
|
+
split_first_layer=True,
|
|
86
|
+
batch_norm=True,
|
|
87
|
+
batch_norm_alpha=0.1,
|
|
88
|
+
drop_prob=0.5,
|
|
89
|
+
chs_info=None,
|
|
90
|
+
input_window_seconds=None,
|
|
91
|
+
sfreq=None,
|
|
92
92
|
):
|
|
93
|
-
n_chans, n_outputs, n_times = deprecated_args(
|
|
94
|
-
self,
|
|
95
|
-
("in_chans", "n_chans", in_chans, n_chans),
|
|
96
|
-
("n_classes", "n_outputs", n_classes, n_outputs),
|
|
97
|
-
("input_window_samples", "n_times", input_window_samples, n_times),
|
|
98
|
-
)
|
|
99
93
|
super().__init__(
|
|
100
94
|
n_outputs=n_outputs,
|
|
101
95
|
n_chans=n_chans,
|
|
@@ -103,10 +97,8 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
|
|
|
103
97
|
n_times=n_times,
|
|
104
98
|
input_window_seconds=input_window_seconds,
|
|
105
99
|
sfreq=sfreq,
|
|
106
|
-
add_log_softmax=add_log_softmax,
|
|
107
100
|
)
|
|
108
101
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
109
|
-
del in_chans, n_classes, input_window_samples
|
|
110
102
|
if final_conv_length == "auto":
|
|
111
103
|
assert self.n_times is not None
|
|
112
104
|
self.n_filters_time = n_filters_time
|
|
@@ -117,7 +109,7 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
|
|
|
117
109
|
self.final_conv_length = final_conv_length
|
|
118
110
|
self.conv_nonlin = conv_nonlin
|
|
119
111
|
self.pool_mode = pool_mode
|
|
120
|
-
self.pool_nonlin =
|
|
112
|
+
self.pool_nonlin = activation_pool_nonlin
|
|
121
113
|
self.split_first_layer = split_first_layer
|
|
122
114
|
self.batch_norm = batch_norm
|
|
123
115
|
self.batch_norm_alpha = batch_norm_alpha
|
|
@@ -129,7 +121,7 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
|
|
|
129
121
|
"conv_time.bias": "conv_time_spat.conv_time.bias",
|
|
130
122
|
"conv_spat.bias": "conv_time_spat.conv_spat.bias",
|
|
131
123
|
"conv_classifier.weight": "final_layer.conv_classifier.weight",
|
|
132
|
-
"conv_classifier.bias": "final_layer.conv_classifier.bias"
|
|
124
|
+
"conv_classifier.bias": "final_layer.conv_classifier.bias",
|
|
133
125
|
}
|
|
134
126
|
|
|
135
127
|
self.add_module("ensuredims", Ensure4d())
|
|
@@ -175,7 +167,7 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
|
|
|
175
167
|
stride=(self.pool_time_stride, 1),
|
|
176
168
|
),
|
|
177
169
|
)
|
|
178
|
-
self.add_module("pool_nonlin_exp",
|
|
170
|
+
self.add_module("pool_nonlin_exp", self.pool_nonlin())
|
|
179
171
|
self.add_module("drop", nn.Dropout(p=self.drop_prob))
|
|
180
172
|
self.eval()
|
|
181
173
|
if self.final_conv_length == "auto":
|
|
@@ -184,17 +176,17 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
|
|
|
184
176
|
# Incorporating classification module and subsequent ones in one final layer
|
|
185
177
|
module = nn.Sequential()
|
|
186
178
|
|
|
187
|
-
module.add_module(
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
179
|
+
module.add_module(
|
|
180
|
+
"conv_classifier",
|
|
181
|
+
nn.Conv2d(
|
|
182
|
+
n_filters_conv,
|
|
183
|
+
self.n_outputs,
|
|
184
|
+
(self.final_conv_length, 1),
|
|
185
|
+
bias=True,
|
|
186
|
+
),
|
|
187
|
+
)
|
|
196
188
|
|
|
197
|
-
module.add_module("squeeze",
|
|
189
|
+
module.add_module("squeeze", SqueezeFinalOutput())
|
|
198
190
|
|
|
199
191
|
self.add_module("final_layer", module)
|
|
200
192
|
|
|
@@ -212,3 +204,5 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
|
|
|
212
204
|
init.constant_(self.bnorm.bias, 0)
|
|
213
205
|
init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
|
|
214
206
|
init.constant_(self.final_layer.conv_classifier.bias, 0)
|
|
207
|
+
|
|
208
|
+
self.train()
|