braindecode 1.2.0.dev182094932__py3-none-any.whl → 1.3.0.dev176728557__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic.
- braindecode/datasets/experimental.py +218 -0
- braindecode/models/__init__.py +6 -8
- braindecode/models/atcnet.py +152 -12
- braindecode/models/attentionbasenet.py +151 -26
- braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
- braindecode/models/ctnet.py +1 -1
- braindecode/models/deep4.py +6 -2
- braindecode/models/deepsleepnet.py +118 -5
- braindecode/models/eegconformer.py +114 -15
- braindecode/models/eeginception_erp.py +76 -7
- braindecode/models/eeginception_mi.py +2 -0
- braindecode/models/eegnet.py +25 -189
- braindecode/models/eegnex.py +113 -6
- braindecode/models/eegsimpleconv.py +2 -0
- braindecode/models/eegtcnet.py +1 -1
- braindecode/models/sccnet.py +81 -8
- braindecode/models/shallow_fbcsp.py +2 -0
- braindecode/models/sleep_stager_blanco_2020.py +2 -0
- braindecode/models/sleep_stager_chambon_2018.py +2 -0
- braindecode/models/sparcnet.py +2 -0
- braindecode/models/summary.csv +39 -41
- braindecode/models/tidnet.py +2 -0
- braindecode/models/tsinception.py +15 -3
- braindecode/models/usleep.py +103 -9
- braindecode/models/util.py +5 -5
- braindecode/preprocessing/preprocess.py +20 -26
- braindecode/version.py +1 -1
- {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev176728557.dist-info}/METADATA +7 -2
- {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev176728557.dist-info}/RECORD +33 -33
- braindecode/models/eegresnet.py +0 -362
- {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev176728557.dist-info}/WHEEL +0 -0
- {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev176728557.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev176728557.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev176728557.dist-info}/top_level.txt +0 -0
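Before the per-file diffs, note that this release both renames modules (sleep_stager_eldele_2021.py → attn_sleep.py) and removes eegresnet.py outright. A quick way to check what an installed braindecode build actually exposes is to inspect its public model namespace; the following is only a sketch, assuming the standard braindecode.models package layout:

# Sketch: print the installed braindecode version and its exported model names,
# e.g. to confirm whether EEGResNet (removed in this diff) is still importable.
import braindecode
import braindecode.models as models

print(braindecode.__version__)
print(sorted(name for name in dir(models) if not name.startswith("_")))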
braindecode/models/eegresnet.py
DELETED
@@ -1,362 +0,0 @@
# Authors: Robin Tibor Schirrmeister <robintibor@gmail.com>
#          Tonio Ball
#
# License: BSD-3

import numpy as np
import torch
from einops.layers.torch import Rearrange
from torch import nn
from torch.nn import init

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import (
    AvgPool2dWithConv,
    Ensure4d,
    SqueezeFinalOutput,
)


class EEGResNet(EEGModuleMixin, nn.Sequential):
    """EEGResNet from Schirrmeister et al. 2017 [Schirrmeister2017]_.

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/bed1b768-809f-4bc6-b942-b36970d81271/hbm23730-fig-0003-m.jpg
       :align: center
       :alt: EEGResNet Architecture

    Model described in [Schirrmeister2017]_.

    Parameters
    ----------
    in_chans :
        Alias for ``n_chans``.
    n_classes :
        Alias for ``n_outputs``.
    input_window_samples :
        Alias for ``n_times``.
    activation: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017). Deep learning with convolutional neural networks for
       EEG decoding and visualization. Human Brain Mapping, Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        final_pool_length="auto",
        n_first_filters=20,
        n_layers_per_block=2,
        first_filter_length=3,
        activation=nn.ELU,
        split_first_layer=True,
        batch_norm_alpha=0.1,
        batch_norm_epsilon=1e-4,
        conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0),
        chs_info=None,
        input_window_seconds=None,
        sfreq=250,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        if final_pool_length == "auto":
            assert self.n_times is not None
        assert first_filter_length % 2 == 1
        self.final_pool_length = final_pool_length
        self.n_first_filters = n_first_filters
        self.n_layers_per_block = n_layers_per_block
        self.first_filter_length = first_filter_length
        self.nonlinearity = activation
        self.split_first_layer = split_first_layer
        self.batch_norm_alpha = batch_norm_alpha
        self.batch_norm_epsilon = batch_norm_epsilon
        self.conv_weight_init_fn = conv_weight_init_fn

        self.mapping = {
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        self.add_module("ensuredims", Ensure4d())
        if self.split_first_layer:
            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    1,
                    self.n_first_filters,
                    (self.first_filter_length, 1),
                    stride=1,
                    padding=(self.first_filter_length // 2, 0),
                ),
            )
            self.add_module(
                "conv_spat",
                nn.Conv2d(
                    self.n_first_filters,
                    self.n_first_filters,
                    (1, self.n_chans),
                    stride=(1, 1),
                    bias=False,
                ),
            )
        else:
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.n_chans,
                    self.n_first_filters,
                    (self.first_filter_length, 1),
                    stride=(1, 1),
                    padding=(self.first_filter_length // 2, 0),
                    bias=False,
                ),
            )
        n_filters_conv = self.n_first_filters
        self.add_module(
            "bnorm",
            nn.BatchNorm2d(
                n_filters_conv, momentum=self.batch_norm_alpha, affine=True, eps=1e-5
            ),
        )
        self.add_module("conv_nonlin", self.nonlinearity())
        cur_dilation = np.array([1, 1])
        n_cur_filters = n_filters_conv
        i_block = 1
        for i_layer in range(self.n_layers_per_block):
            self.add_module(
                "res_{:d}_{:d}".format(i_block, i_layer),
                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
            )
        i_block += 1
        cur_dilation[0] *= 2
        n_out_filters = int(2 * n_cur_filters)
        self.add_module(
            "res_{:d}_{:d}".format(i_block, 0),
            _ResidualBlock(
                n_cur_filters,
                n_out_filters,
                dilation=cur_dilation,
            ),
        )
        n_cur_filters = n_out_filters
        for i_layer in range(1, self.n_layers_per_block):
            self.add_module(
                "res_{:d}_{:d}".format(i_block, i_layer),
                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
            )

        i_block += 1
        cur_dilation[0] *= 2
        n_out_filters = int(1.5 * n_cur_filters)
        self.add_module(
            "res_{:d}_{:d}".format(i_block, 0),
            _ResidualBlock(
                n_cur_filters,
                n_out_filters,
                dilation=cur_dilation,
            ),
        )
        n_cur_filters = n_out_filters
        for i_layer in range(1, self.n_layers_per_block):
            self.add_module(
                "res_{:d}_{:d}".format(i_block, i_layer),
                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
            )

        i_block += 1
        cur_dilation[0] *= 2
        self.add_module(
            "res_{:d}_{:d}".format(i_block, 0),
            _ResidualBlock(
                n_cur_filters,
                n_cur_filters,
                dilation=cur_dilation,
            ),
        )
        for i_layer in range(1, self.n_layers_per_block):
            self.add_module(
                "res_{:d}_{:d}".format(i_block, i_layer),
                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
            )

        i_block += 1
        cur_dilation[0] *= 2
        self.add_module(
            "res_{:d}_{:d}".format(i_block, 0),
            _ResidualBlock(
                n_cur_filters,
                n_cur_filters,
                dilation=cur_dilation,
            ),
        )
        for i_layer in range(1, self.n_layers_per_block):
            self.add_module(
                "res_{:d}_{:d}".format(i_block, i_layer),
                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
            )

        i_block += 1
        cur_dilation[0] *= 2
        self.add_module(
            "res_{:d}_{:d}".format(i_block, 0),
            _ResidualBlock(
                n_cur_filters,
                n_cur_filters,
                dilation=cur_dilation,
            ),
        )
        for i_layer in range(1, self.n_layers_per_block):
            self.add_module(
                "res_{:d}_{:d}".format(i_block, i_layer),
                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
            )
        i_block += 1
        cur_dilation[0] *= 2
        self.add_module(
            "res_{:d}_{:d}".format(i_block, 0),
            _ResidualBlock(
                n_cur_filters,
                n_cur_filters,
                dilation=cur_dilation,
            ),
        )
        for i_layer in range(1, self.n_layers_per_block):
            self.add_module(
                "res_{:d}_{:d}".format(i_block, i_layer),
                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
            )

        self.eval()
        if self.final_pool_length == "auto":
            self.add_module("mean_pool", nn.AdaptiveAvgPool2d((1, 1)))
        else:
            pool_dilation = int(cur_dilation[0]), int(cur_dilation[1])
            self.add_module(
                "mean_pool",
                AvgPool2dWithConv(
                    (self.final_pool_length, 1), (1, 1), dilation=pool_dilation
                ),
            )

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()

        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                n_cur_filters,
                self.n_outputs,
                (1, 1),
                bias=True,
            ),
        )

        module.add_module("squeeze", SqueezeFinalOutput())

        self.add_module("final_layer", module)

        # Initialize all weights
        self.apply(lambda module: self._weights_init(module, self.conv_weight_init_fn))

        # Start in train mode
        self.train()

    @staticmethod
    def _weights_init(module, conv_weight_init_fn):
        """
        initialize weights
        """
        classname = module.__class__.__name__
        if "Conv" in classname and classname != "AvgPool2dWithConv":
            conv_weight_init_fn(module.weight)
            if module.bias is not None:
                init.constant_(module.bias, 0)
        elif "BatchNorm" in classname:
            init.constant_(module.weight, 1)
            init.constant_(module.bias, 0)


class _ResidualBlock(nn.Module):
    """
    create a residual learning building block with two stacked 3x3 convlayers as in paper
    """

    def __init__(
        self,
        in_filters,
        out_num_filters,
        dilation,
        filter_time_length=3,
        nonlinearity: nn.Module = nn.ELU,
        batch_norm_alpha=0.1,
        batch_norm_epsilon=1e-4,
    ):
        super(_ResidualBlock, self).__init__()
        time_padding = int((filter_time_length - 1) * dilation[0])
        assert time_padding % 2 == 0
        time_padding = int(time_padding // 2)
        dilation = (int(dilation[0]), int(dilation[1]))
        assert (out_num_filters - in_filters) % 2 == 0, (
            "Need even number of extra channels in order to be able to pad correctly"
        )
        self.n_pad_chans = out_num_filters - in_filters

        self.conv_1 = nn.Conv2d(
            in_filters,
            out_num_filters,
            (filter_time_length, 1),
            stride=(1, 1),
            dilation=dilation,
            padding=(time_padding, 0),
        )
        self.bn1 = nn.BatchNorm2d(
            out_num_filters,
            momentum=batch_norm_alpha,
            affine=True,
            eps=batch_norm_epsilon,
        )
        self.conv_2 = nn.Conv2d(
            out_num_filters,
            out_num_filters,
            (filter_time_length, 1),
            stride=(1, 1),
            dilation=dilation,
            padding=(time_padding, 0),
        )
        self.bn2 = nn.BatchNorm2d(
            out_num_filters,
            momentum=batch_norm_alpha,
            affine=True,
            eps=batch_norm_epsilon,
        )
        # also see https://mail.google.com/mail/u/0/#search/ilya+joos/1576137dd34c3127
        # for resnet options as ilya used them
        self.nonlinearity = nonlinearity()

    def forward(self, x):
        stack_1 = self.nonlinearity(self.bn1(self.conv_1(x)))
        stack_2 = self.bn2(self.conv_2(stack_1))  # next nonlin after sum
        if self.n_pad_chans != 0:
            zeros_for_padding = x.new_zeros(
                (x.shape[0], self.n_pad_chans // 2, x.shape[2], x.shape[3])
            )
            x = torch.cat((zeros_for_padding, x, zeros_for_padding), dim=1)
        out = self.nonlinearity(x + stack_2)
        return out
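For orientation, the deleted class follows braindecode's EEGModuleMixin conventions: the model consumes (batch, n_chans, n_times) tensors and the final_layer emits one score per output class. A minimal usage sketch, assuming a braindecode release that still ships EEGResNet and hypothetical data shapes (22 channels, 4 classes, 1000-sample windows):

import torch
from braindecode.models import EEGResNet  # only importable in releases that still include eegresnet.py

# Hypothetical dimensions: 22 EEG channels, 4 target classes, 1000 samples per window.
model = EEGResNet(n_chans=22, n_outputs=4, n_times=1000)

x = torch.randn(8, 22, 1000)  # (batch, channels, time)
with torch.no_grad():
    scores = model(x)         # SqueezeFinalOutput yields (batch, n_outputs), i.e. (8, 4)
print(scores.shape)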
{braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev176728557.dist-info}/WHEEL
RENAMED
File without changes

{braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev176728557.dist-info}/licenses/LICENSE.txt
RENAMED
File without changes

{braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev176728557.dist-info}/licenses/NOTICE.txt
RENAMED
File without changes

{braindecode-1.2.0.dev182094932.dist-info → braindecode-1.3.0.dev176728557.dist-info}/top_level.txt
RENAMED
File without changes