braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/models/tcn.py
CHANGED
|
@@ -2,22 +2,23 @@
|
|
|
2
2
|
# Lukas Gemein <l.gemein@gmail.com>
|
|
3
3
|
#
|
|
4
4
|
# License: BSD-3
|
|
5
|
-
|
|
5
|
+
import torch
|
|
6
6
|
from torch import nn
|
|
7
7
|
from torch.nn import init
|
|
8
|
-
from torch.nn.utils import weight_norm
|
|
8
|
+
from torch.nn.utils.parametrizations import weight_norm
|
|
9
9
|
|
|
10
|
-
from .
|
|
11
|
-
from .
|
|
12
|
-
from .base import EEGModuleMixin, deprecated_args
|
|
10
|
+
from braindecode.models.base import EEGModuleMixin
|
|
11
|
+
from braindecode.modules import Chomp1d, Ensure4d, SqueezeFinalOutput
|
|
13
12
|
|
|
14
13
|
|
|
15
|
-
class
|
|
16
|
-
"""
|
|
14
|
+
class BDTCN(EEGModuleMixin, nn.Module):
|
|
15
|
+
"""Braindecode TCN from Gemein, L et al (2020) [gemein2020]_.
|
|
17
16
|
|
|
18
|
-
|
|
17
|
+
.. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S1053811920305073-gr3_lrg.jpg
|
|
18
|
+
:align: center
|
|
19
|
+
:alt: Braindecode TCN Architecture
|
|
19
20
|
|
|
20
|
-
|
|
21
|
+
See [gemein2020]_ for details.
|
|
21
22
|
|
|
22
23
|
Parameters
|
|
23
24
|
----------
|
|
@@ -29,36 +30,33 @@ class TCN(EEGModuleMixin, nn.Module):
|
|
|
29
30
|
kernel size of the convolutions
|
|
30
31
|
drop_prob: float
|
|
31
32
|
dropout probability
|
|
32
|
-
|
|
33
|
-
|
|
33
|
+
activation: nn.Module, default=nn.ReLU
|
|
34
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
35
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
|
|
34
36
|
|
|
35
37
|
References
|
|
36
38
|
----------
|
|
37
|
-
.. [
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
arXiv preprint arXiv:1803.01271.
|
|
39
|
+
.. [gemein2020] Gemein, L. A., Schirrmeister, R. T., Chrabąszcz, P., Wilson, D.,
|
|
40
|
+
Boedecker, J., Schulze-Bonhage, A., ... & Ball, T. (2020). Machine-learning-based
|
|
41
|
+
diagnostics of EEG pathology. NeuroImage, 220, 117021.
|
|
41
42
|
"""
|
|
42
43
|
|
|
43
44
|
def __init__(
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
45
|
+
self,
|
|
46
|
+
# Braindecode parameters
|
|
47
|
+
n_chans=None,
|
|
48
|
+
n_outputs=None,
|
|
49
|
+
chs_info=None,
|
|
50
|
+
n_times=None,
|
|
51
|
+
sfreq=None,
|
|
52
|
+
input_window_seconds=None,
|
|
53
|
+
# Model's parameters
|
|
54
|
+
n_blocks=3,
|
|
55
|
+
n_filters=30,
|
|
56
|
+
kernel_size=5,
|
|
57
|
+
drop_prob=0.5,
|
|
58
|
+
activation: nn.Module = nn.ReLU,
|
|
57
59
|
):
|
|
58
|
-
n_chans, = deprecated_args(
|
|
59
|
-
self,
|
|
60
|
-
("n_in_chans", "n_chans", n_in_chans, n_chans),
|
|
61
|
-
)
|
|
62
60
|
super().__init__(
|
|
63
61
|
n_outputs=n_outputs,
|
|
64
62
|
n_chans=n_chans,
|
|
@@ -66,43 +64,106 @@ class TCN(EEGModuleMixin, nn.Module):
|
|
|
66
64
|
n_times=n_times,
|
|
67
65
|
input_window_seconds=input_window_seconds,
|
|
68
66
|
sfreq=sfreq,
|
|
69
|
-
add_log_softmax=add_log_softmax,
|
|
70
67
|
)
|
|
71
|
-
del n_outputs, n_chans, chs_info, n_times,
|
|
72
|
-
|
|
68
|
+
del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
|
|
69
|
+
|
|
70
|
+
self.base_tcn = TCN(
|
|
71
|
+
n_chans=self.n_chans,
|
|
72
|
+
n_outputs=self.n_outputs,
|
|
73
|
+
n_blocks=n_blocks,
|
|
74
|
+
n_filters=n_filters,
|
|
75
|
+
kernel_size=kernel_size,
|
|
76
|
+
drop_prob=drop_prob,
|
|
77
|
+
activation=activation,
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
self.final_layer = torch.nn.Sequential(
|
|
81
|
+
torch.nn.AdaptiveAvgPool1d(1), torch.nn.Flatten()
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
def forward(self, x):
|
|
85
|
+
x = self.base_tcn(x)
|
|
86
|
+
x = self.final_layer(x)
|
|
87
|
+
return x
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class TCN(nn.Module):
|
|
91
|
+
"""Temporal Convolutional Network (TCN) from Bai et al. 2018 [Bai2018]_.
|
|
92
|
+
|
|
93
|
+
See [Bai2018]_ for details.
|
|
94
|
+
|
|
95
|
+
Code adapted from https://github.com/locuslab/TCN/blob/master/TCN/tcn.py
|
|
96
|
+
|
|
97
|
+
Parameters
|
|
98
|
+
----------
|
|
99
|
+
n_filters: int
|
|
100
|
+
number of output filters of each convolution
|
|
101
|
+
n_blocks: int
|
|
102
|
+
number of temporal blocks in the network
|
|
103
|
+
kernel_size: int
|
|
104
|
+
kernel size of the convolutions
|
|
105
|
+
drop_prob: float
|
|
106
|
+
dropout probability
|
|
107
|
+
activation: nn.Module, default=nn.ReLU
|
|
108
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
109
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
|
|
110
|
+
|
|
111
|
+
References
|
|
112
|
+
----------
|
|
113
|
+
.. [Bai2018] Bai, S., Kolter, J. Z., & Koltun, V. (2018).
|
|
114
|
+
An empirical evaluation of generic convolutional and recurrent networks
|
|
115
|
+
for sequence modeling.
|
|
116
|
+
arXiv preprint arXiv:1803.01271.
|
|
117
|
+
"""
|
|
73
118
|
|
|
119
|
+
def __init__(
|
|
120
|
+
self,
|
|
121
|
+
n_chans=None,
|
|
122
|
+
n_outputs=None,
|
|
123
|
+
n_blocks=3,
|
|
124
|
+
n_filters=30,
|
|
125
|
+
kernel_size=5,
|
|
126
|
+
drop_prob=0.5,
|
|
127
|
+
activation: nn.Module = nn.ReLU,
|
|
128
|
+
):
|
|
129
|
+
super().__init__()
|
|
74
130
|
self.mapping = {
|
|
75
131
|
"fc.weight": "final_layer.fc.weight",
|
|
76
|
-
"fc.bias": "final_layer.fc.bias"
|
|
132
|
+
"fc.bias": "final_layer.fc.bias",
|
|
77
133
|
}
|
|
78
134
|
self.ensuredims = Ensure4d()
|
|
79
135
|
t_blocks = nn.Sequential()
|
|
80
136
|
for i in range(n_blocks):
|
|
81
|
-
n_inputs =
|
|
82
|
-
dilation_size = 2
|
|
83
|
-
t_blocks.add_module(
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
137
|
+
n_inputs = n_chans if i == 0 else n_filters
|
|
138
|
+
dilation_size = 2**i
|
|
139
|
+
t_blocks.add_module(
|
|
140
|
+
"temporal_block_{:d}".format(i),
|
|
141
|
+
_TemporalBlock(
|
|
142
|
+
n_inputs=n_inputs,
|
|
143
|
+
n_outputs=n_filters,
|
|
144
|
+
kernel_size=kernel_size,
|
|
145
|
+
stride=1,
|
|
146
|
+
dilation=dilation_size,
|
|
147
|
+
padding=(kernel_size - 1) * dilation_size,
|
|
148
|
+
drop_prob=drop_prob,
|
|
149
|
+
activation=activation,
|
|
150
|
+
),
|
|
151
|
+
)
|
|
92
152
|
self.temporal_blocks = t_blocks
|
|
93
153
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
154
|
+
self.final_layer = _FinalLayer(
|
|
155
|
+
in_features=n_filters,
|
|
156
|
+
out_features=n_outputs,
|
|
157
|
+
)
|
|
97
158
|
self.min_len = 1
|
|
98
159
|
for i in range(n_blocks):
|
|
99
|
-
dilation = 2
|
|
160
|
+
dilation = 2**i
|
|
100
161
|
self.min_len += 2 * (kernel_size - 1) * dilation
|
|
101
162
|
|
|
102
163
|
# start in eval mode
|
|
103
|
-
self.
|
|
164
|
+
self.train()
|
|
104
165
|
|
|
105
|
-
def forward(self, x):
|
|
166
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
106
167
|
"""Forward pass.
|
|
107
168
|
|
|
108
169
|
Parameters
|
|
@@ -126,21 +187,18 @@ class TCN(EEGModuleMixin, nn.Module):
|
|
|
126
187
|
|
|
127
188
|
|
|
128
189
|
class _FinalLayer(nn.Module):
|
|
129
|
-
def __init__(self, in_features, out_features
|
|
130
|
-
|
|
190
|
+
def __init__(self, in_features, out_features):
|
|
131
191
|
super().__init__()
|
|
132
192
|
|
|
133
193
|
self.fc = nn.Linear(in_features=in_features, out_features=out_features)
|
|
134
194
|
|
|
135
|
-
|
|
136
|
-
self.out_fun = nn.LogSoftmax(dim=1)
|
|
137
|
-
else:
|
|
138
|
-
self.out_fun = nn.Identity()
|
|
195
|
+
self.out_fun = nn.Identity()
|
|
139
196
|
|
|
140
|
-
self.squeeze =
|
|
141
|
-
|
|
142
|
-
def forward(self, x, batch_size, time_size, min_len):
|
|
197
|
+
self.squeeze = SqueezeFinalOutput()
|
|
143
198
|
|
|
199
|
+
def forward(
|
|
200
|
+
self, x: torch.Tensor, batch_size: int, time_size: int, min_len: int
|
|
201
|
+
) -> torch.Tensor:
|
|
144
202
|
fc_out = self.fc(x.view(batch_size * time_size, x.size(2)))
|
|
145
203
|
fc_out = self.out_fun(fc_out)
|
|
146
204
|
fc_out = fc_out.view(batch_size, time_size, fc_out.size(1))
|
|
@@ -151,27 +209,51 @@ class _FinalLayer(nn.Module):
|
|
|
151
209
|
return self.squeeze(out[:, :, :, None])
|
|
152
210
|
|
|
153
211
|
|
|
154
|
-
class
|
|
155
|
-
def __init__(
|
|
156
|
-
|
|
212
|
+
class _TemporalBlock(nn.Module):
|
|
213
|
+
def __init__(
|
|
214
|
+
self,
|
|
215
|
+
n_inputs,
|
|
216
|
+
n_outputs,
|
|
217
|
+
kernel_size,
|
|
218
|
+
stride,
|
|
219
|
+
dilation,
|
|
220
|
+
padding,
|
|
221
|
+
drop_prob,
|
|
222
|
+
activation: nn.Module = nn.ReLU,
|
|
223
|
+
):
|
|
157
224
|
super().__init__()
|
|
158
|
-
self.conv1 = weight_norm(
|
|
159
|
-
|
|
160
|
-
|
|
225
|
+
self.conv1 = weight_norm(
|
|
226
|
+
nn.Conv1d(
|
|
227
|
+
n_inputs,
|
|
228
|
+
n_outputs,
|
|
229
|
+
kernel_size,
|
|
230
|
+
stride=stride,
|
|
231
|
+
padding=padding,
|
|
232
|
+
dilation=dilation,
|
|
233
|
+
)
|
|
234
|
+
)
|
|
161
235
|
self.chomp1 = Chomp1d(padding)
|
|
162
|
-
self.relu1 =
|
|
236
|
+
self.relu1 = activation()
|
|
163
237
|
self.dropout1 = nn.Dropout2d(drop_prob)
|
|
164
238
|
|
|
165
|
-
self.conv2 = weight_norm(
|
|
166
|
-
|
|
167
|
-
|
|
239
|
+
self.conv2 = weight_norm(
|
|
240
|
+
nn.Conv1d(
|
|
241
|
+
n_outputs,
|
|
242
|
+
n_outputs,
|
|
243
|
+
kernel_size,
|
|
244
|
+
stride=stride,
|
|
245
|
+
padding=padding,
|
|
246
|
+
dilation=dilation,
|
|
247
|
+
)
|
|
248
|
+
)
|
|
168
249
|
self.chomp2 = Chomp1d(padding)
|
|
169
|
-
self.relu2 =
|
|
250
|
+
self.relu2 = activation()
|
|
170
251
|
self.dropout2 = nn.Dropout2d(drop_prob)
|
|
171
252
|
|
|
172
|
-
self.downsample = (
|
|
173
|
-
|
|
174
|
-
|
|
253
|
+
self.downsample = (
|
|
254
|
+
nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
|
|
255
|
+
)
|
|
256
|
+
self.relu = activation()
|
|
175
257
|
|
|
176
258
|
init.normal_(self.conv1.weight, 0, 0.01)
|
|
177
259
|
init.normal_(self.conv2.weight, 0, 0.01)
|
|
@@ -189,15 +271,3 @@ class TemporalBlock(nn.Module):
|
|
|
189
271
|
out = self.dropout2(out)
|
|
190
272
|
res = x if self.downsample is None else self.downsample(x)
|
|
191
273
|
return self.relu(out + res)
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
class Chomp1d(nn.Module):
|
|
195
|
-
def __init__(self, chomp_size):
|
|
196
|
-
super().__init__()
|
|
197
|
-
self.chomp_size = chomp_size
|
|
198
|
-
|
|
199
|
-
def extra_repr(self):
|
|
200
|
-
return 'chomp_size={}'.format(self.chomp_size)
|
|
201
|
-
|
|
202
|
-
def forward(self, x):
|
|
203
|
-
return x[:, :, :-self.chomp_size].contiguous()
|