braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
from einops.layers.torch import Rearrange
|
|
6
|
+
from torch import nn
|
|
7
|
+
from torch.nn import init
|
|
8
|
+
|
|
9
|
+
from braindecode.functional import square
|
|
10
|
+
from braindecode.models.base import EEGModuleMixin
|
|
11
|
+
from braindecode.modules import (
|
|
12
|
+
CombinedConv,
|
|
13
|
+
Ensure4d,
|
|
14
|
+
Expression,
|
|
15
|
+
SafeLog,
|
|
16
|
+
SqueezeFinalOutput,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
    """Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/221ea375-6701-40d3-ab3f-e411aad62d9e/hbm23730-fig-0002-m.jpg
        :align: center
        :alt: ShallowNet Architecture

    Model described in [Schirrmeister2017]_.

    Parameters
    ----------
    n_filters_time: int
        Number of temporal filters.
    filter_time_length: int
        Length of the temporal filter.
    n_filters_spat: int
        Number of spatial filters.
    pool_time_length: int
        Length of temporal pooling filter.
    pool_time_stride: int
        Length of stride between temporal pooling filters.
    final_conv_length: int | str
        Length of the final convolution layer.
        If set to "auto", length of the input signal must be specified.
    conv_nonlin: callable
        Non-linear function to be used after convolution layers.
    pool_mode: str
        Method to use on pooling layers. "max" or "mean".
    activation_pool_nonlin: nn.Module class
        Module class that is instantiated without arguments and applied after
        the pooling layer (default: ``SafeLog``).
    split_first_layer: bool
        Split first layer into temporal and spatial layers (True) or just use temporal (False).
        There would be no non-linearity between the split layers.
    batch_norm: bool
        Whether to use batch normalisation.
    batch_norm_alpha: float
        Momentum for BatchNorm2d.
    drop_prob: float
        Dropout probability.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
        L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
        & Ball, T. (2017).
        Deep learning with convolutional neural networks for EEG decoding and
        visualization.
        Human Brain Mapping, Aug. 2017.
        Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        n_filters_time=40,
        filter_time_length=25,
        n_filters_spat=40,
        pool_time_length=75,
        pool_time_stride=15,
        final_conv_length="auto",
        conv_nonlin=square,
        pool_mode="mean",
        activation_pool_nonlin: nn.Module = SafeLog,
        split_first_layer=True,
        batch_norm=True,
        batch_norm_alpha=0.1,
        drop_prob=0.5,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # The mixin now owns these values; read them back through self.* only.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        if final_conv_length == "auto":
            # The final conv length is derived from a dummy forward pass below,
            # which needs the input length.
            assert self.n_times is not None, (
                'final_conv_length="auto" requires the input length (n_times) '
                "to be known."
            )
        self.n_filters_time = n_filters_time
        self.filter_time_length = filter_time_length
        self.n_filters_spat = n_filters_spat
        self.pool_time_length = pool_time_length
        self.pool_time_stride = pool_time_stride
        self.final_conv_length = final_conv_length
        self.conv_nonlin = conv_nonlin
        self.pool_mode = pool_mode
        self.pool_nonlin = activation_pool_nonlin
        self.split_first_layer = split_first_layer
        self.batch_norm = batch_norm
        self.batch_norm_alpha = batch_norm_alpha
        self.drop_prob = drop_prob

        # Legacy parameter names -> current module layout. The temporal/spatial
        # convs only live under ``conv_time_spat`` when the first layer is
        # split; in the unsplit layout the module is still named ``conv_time``,
        # so remapping it would point at non-existent parameters.
        # NOTE(review): presumably consumed by EEGModuleMixin when loading
        # state dicts — confirm against models/base.py.
        self.mapping = {}
        if self.split_first_layer:
            self.mapping.update(
                {
                    "conv_time.weight": "conv_time_spat.conv_time.weight",
                    "conv_spat.weight": "conv_time_spat.conv_spat.weight",
                    "conv_time.bias": "conv_time_spat.conv_time.bias",
                    "conv_spat.bias": "conv_time_spat.conv_spat.bias",
                }
            )
        self.mapping.update(
            {
                "conv_classifier.weight": "final_layer.conv_classifier.weight",
                "conv_classifier.bias": "final_layer.conv_classifier.bias",
            }
        )

        self.add_module("ensuredims", Ensure4d())
        # Unsupported pool_mode values raise KeyError here.
        pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
        if self.split_first_layer:
            # Move EEG channels to the last axis so the combined temporal +
            # spatial convolution sees a single input feature map.
            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
            self.add_module(
                "conv_time_spat",
                CombinedConv(
                    in_chans=self.n_chans,
                    n_filters_time=self.n_filters_time,
                    n_filters_spat=self.n_filters_spat,
                    filter_time_length=filter_time_length,
                    bias_time=True,
                    bias_spat=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_spat
            # Temporal conv to initialise below.
            temporal_conv = self.conv_time_spat.conv_time
        else:
            # Channels act as input feature maps; convolve along time only.
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.n_chans,
                    self.n_filters_time,
                    (self.filter_time_length, 1),
                    stride=1,
                    bias=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_time
            # Temporal conv to initialise below.
            temporal_conv = self.conv_time
        if self.batch_norm:
            self.add_module(
                "bnorm",
                nn.BatchNorm2d(
                    n_filters_conv, momentum=self.batch_norm_alpha, affine=True
                ),
            )
        self.add_module("conv_nonlin_exp", Expression(self.conv_nonlin))
        self.add_module(
            "pool",
            pool_class(
                kernel_size=(self.pool_time_length, 1),
                stride=(self.pool_time_stride, 1),
            ),
        )
        self.add_module("pool_nonlin_exp", self.pool_nonlin())
        self.add_module("drop", nn.Dropout(p=self.drop_prob))
        # Eval mode for the shape-probing forward pass (dropout/batch-norm
        # statistics are not updated by it).
        self.eval()
        if self.final_conv_length == "auto":
            # Index 2 is the remaining time dimension of the feature map
            # (pooling/convs above act on that axis).
            self.final_conv_length = self.get_output_shape()[2]

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()

        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                n_filters_conv,
                self.n_outputs,
                (self.final_conv_length, 1),
                bias=True,
            ),
        )

        module.add_module("squeeze", SqueezeFinalOutput())

        self.add_module("final_layer", module)

        # Initialization, xavier is same as in paper...
        init.xavier_uniform_(temporal_conv.weight, gain=1)
        # No temporal bias exists when the first layer is not split and batch
        # norm is used (bias=not batch_norm above).
        if temporal_conv.bias is not None:
            init.constant_(temporal_conv.bias, 0)
        if self.split_first_layer:
            init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
            if not self.batch_norm:
                init.constant_(self.conv_time_spat.conv_spat.bias, 0)
        if self.batch_norm:
            init.constant_(self.bnorm.weight, 1)
            init.constant_(self.bnorm.bias, 0)
        init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
        init.constant_(self.final_layer.conv_classifier.bias, 0)

        self.train()
|