braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
from einops.layers.torch import Rearrange
|
|
6
|
+
from torch import nn
|
|
7
|
+
from torch.nn import init
|
|
8
|
+
|
|
9
|
+
from braindecode.models.base import EEGModuleMixin
|
|
10
|
+
from braindecode.modules import (
|
|
11
|
+
AvgPool2dWithConv,
|
|
12
|
+
CombinedConv,
|
|
13
|
+
Ensure4d,
|
|
14
|
+
SqueezeFinalOutput,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Deep4Net(EEGModuleMixin, nn.Sequential):
    """Deep ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/fc200ccc-d8c4-45b4-8577-56ce4d15999a/hbm23730-fig-0001-m.jpg
        :align: center
        :alt: Deep4Net Architecture

    Model described in [Schirrmeister2017]_.

    The network is a sequence of one (optionally split temporal + spatial)
    convolution block followed by three conv-pool blocks and a final
    convolutional classification layer.

    Parameters
    ----------
    final_conv_length: int | str
        Length of the final convolution layer.
        If set to "auto", n_times must not be None. Default: "auto".
    n_filters_time: int
        Number of temporal filters.
    n_filters_spat: int
        Number of spatial filters.
    filter_time_length: int
        Length of the temporal filter in layer 1.
    pool_time_length: int
        Length of temporal pooling filter.
    pool_time_stride: int
        Length of stride between temporal pooling filters.
    n_filters_2: int
        Number of temporal filters in layer 2.
    filter_length_2: int
        Length of the temporal filter in layer 2.
    n_filters_3: int
        Number of temporal filters in layer 3.
    filter_length_3: int
        Length of the temporal filter in layer 3.
    n_filters_4: int
        Number of temporal filters in layer 4.
    filter_length_4: int
        Length of the temporal filter in layer 4.
    activation_first_conv_nonlin: nn.Module, default is nn.ELU
        Non-linear activation function to be used after convolution in layer 1.
        A module class is expected; it is instantiated without arguments.
    first_pool_mode: str
        Pooling mode in layer 1. "max" or "mean".
    first_pool_nonlin: callable
        Non-linear activation function to be used after pooling in layer 1.
        It is called without arguments to create the module.
    activation_later_conv_nonlin: nn.Module, default is nn.ELU
        Non-linear activation function to be used after convolution in later layers.
        A module class is expected; it is instantiated without arguments.
    later_pool_mode: str
        Pooling mode in later layers. "max" or "mean".
    later_pool_nonlin: callable
        Non-linear activation function to be used after pooling in later layers.
        It is called without arguments to create the module.
    drop_prob: float
        Dropout probability.
    split_first_layer: bool
        Split first layer into temporal and spatial layers (True) or just use temporal (False).
        There would be no non-linearity between the split layers.
    batch_norm: bool
        Whether to use batch normalisation.
    batch_norm_alpha: float
        Momentum for BatchNorm2d.
    stride_before_pool: bool
        Stride before pooling.


    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       Human Brain Mapping , Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        final_conv_length="auto",
        n_filters_time=25,
        n_filters_spat=25,
        filter_time_length=10,
        pool_time_length=3,
        pool_time_stride=3,
        n_filters_2=50,
        filter_length_2=10,
        n_filters_3=100,
        filter_length_3=10,
        n_filters_4=200,
        filter_length_4=10,
        activation_first_conv_nonlin: nn.Module = nn.ELU,
        first_pool_mode="max",
        first_pool_nonlin: nn.Module = nn.Identity,
        activation_later_conv_nonlin: nn.Module = nn.ELU,
        later_pool_mode="max",
        later_pool_nonlin: nn.Module = nn.Identity,
        drop_prob=0.5,
        split_first_layer=True,
        batch_norm=True,
        batch_norm_alpha=0.1,
        stride_before_pool=False,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # From here on only the attributes set by the mixin are used.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        if final_conv_length == "auto":
            assert self.n_times is not None
        self.final_conv_length = final_conv_length
        self.n_filters_time = n_filters_time
        self.n_filters_spat = n_filters_spat
        self.filter_time_length = filter_time_length
        self.pool_time_length = pool_time_length
        self.pool_time_stride = pool_time_stride
        self.n_filters_2 = n_filters_2
        self.filter_length_2 = filter_length_2
        self.n_filters_3 = n_filters_3
        self.filter_length_3 = filter_length_3
        self.n_filters_4 = n_filters_4
        self.filter_length_4 = filter_length_4
        self.first_nonlin = activation_first_conv_nonlin
        self.first_pool_mode = first_pool_mode
        self.first_pool_nonlin = first_pool_nonlin
        self.later_conv_nonlin = activation_later_conv_nonlin
        self.later_pool_mode = later_pool_mode
        self.later_pool_nonlin = later_pool_nonlin
        self.drop_prob = drop_prob
        self.split_first_layer = split_first_layer
        self.batch_norm = batch_norm
        self.batch_norm_alpha = batch_norm_alpha
        self.stride_before_pool = stride_before_pool

        # For the load_state_dict:
        # maps old parameter names to their current location. When more
        # layers are standardized, add the old parameter names here.
        self.mapping = {
            "conv_time.weight": "conv_time_spat.conv_time.weight",
            "conv_spat.weight": "conv_time_spat.conv_spat.weight",
            "conv_time.bias": "conv_time_spat.conv_time.bias",
            "conv_spat.bias": "conv_time_spat.conv_spat.bias",
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        # The temporal stride is applied either by the convolutions or by
        # the pooling layers, never by both.
        if self.stride_before_pool:
            conv_stride = self.pool_time_stride
            pool_stride = 1
        else:
            conv_stride = 1
            pool_stride = self.pool_time_stride
        self.add_module("ensuredims", Ensure4d())
        pool_class_dict = dict(max=nn.MaxPool2d, mean=AvgPool2dWithConv)
        first_pool_class = pool_class_dict[self.first_pool_mode]
        later_pool_class = pool_class_dict[self.later_pool_mode]
        if self.split_first_layer:
            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
            self.add_module(
                "conv_time_spat",
                CombinedConv(
                    in_chans=self.n_chans,
                    n_filters_time=self.n_filters_time,
                    n_filters_spat=self.n_filters_spat,
                    filter_time_length=filter_time_length,
                    bias_time=True,
                    bias_spat=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_spat
        else:
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.n_chans,
                    self.n_filters_time,
                    (self.filter_time_length, 1),
                    stride=(conv_stride, 1),
                    bias=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_time
        if self.batch_norm:
            self.add_module(
                "bnorm",
                nn.BatchNorm2d(
                    n_filters_conv,
                    momentum=self.batch_norm_alpha,
                    affine=True,
                    eps=1e-5,
                ),
            )
        self.add_module("conv_nonlin", self.first_nonlin())
        self.add_module(
            "pool",
            first_pool_class(
                kernel_size=(self.pool_time_length, 1), stride=(pool_stride, 1)
            ),
        )
        self.add_module("pool_nonlin", self.first_pool_nonlin())

        def add_conv_pool_block(
            model, n_filters_before, n_filters, filter_length, block_nr
        ):
            """Append dropout, conv, (batch norm), nonlin, pool, pool nonlin."""
            suffix = "_{:d}".format(block_nr)
            model.add_module("drop" + suffix, nn.Dropout(p=self.drop_prob))
            model.add_module(
                "conv" + suffix,
                nn.Conv2d(
                    n_filters_before,
                    n_filters,
                    (filter_length, 1),
                    stride=(conv_stride, 1),
                    bias=not self.batch_norm,
                ),
            )
            if self.batch_norm:
                model.add_module(
                    "bnorm" + suffix,
                    nn.BatchNorm2d(
                        n_filters,
                        momentum=self.batch_norm_alpha,
                        affine=True,
                        eps=1e-5,
                    ),
                )
            model.add_module("nonlin" + suffix, self.later_conv_nonlin())

            model.add_module(
                "pool" + suffix,
                later_pool_class(
                    kernel_size=(self.pool_time_length, 1),
                    stride=(pool_stride, 1),
                ),
            )
            model.add_module("pool_nonlin" + suffix, self.later_pool_nonlin())

        add_conv_pool_block(
            self, n_filters_conv, self.n_filters_2, self.filter_length_2, 2
        )
        add_conv_pool_block(
            self, self.n_filters_2, self.n_filters_3, self.filter_length_3, 3
        )
        add_conv_pool_block(
            self, self.n_filters_3, self.n_filters_4, self.filter_length_4, 4
        )

        # Switch to eval mode while the output shape of the layers added so
        # far is queried; training mode is restored at the end of __init__.
        self.eval()
        if self.final_conv_length == "auto":
            # NOTE(review): index 2 is presumably the time axis of the
            # feature map returned by the mixin's get_output_shape() - confirm.
            self.final_conv_length = self.get_output_shape()[2]

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()

        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                self.n_filters_4,
                self.n_outputs,
                (self.final_conv_length, 1),
                bias=True,
            ),
        )

        module.add_module("squeeze", SqueezeFinalOutput())

        self.add_module("final_layer", module)

        # Initialization, xavier is same as in our paper...
        # was default from lasagne
        # The first temporal convolution lives inside ``conv_time_spat`` when
        # the first layer is split and is registered directly as ``conv_time``
        # otherwise, so pick the module that actually exists.
        if self.split_first_layer:
            first_conv = self.conv_time_spat.conv_time
        else:
            first_conv = self.conv_time
        init.xavier_uniform_(first_conv.weight, gain=1)
        # maybe no bias in case of no split layer and batch norm
        if self.split_first_layer or (not self.batch_norm):
            init.constant_(first_conv.bias, 0)
        if self.split_first_layer:
            init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
            if not self.batch_norm:
                # The spatial conv only has a bias when batch norm is off.
                init.constant_(self.conv_time_spat.conv_spat.bias, 0)
        if self.batch_norm:
            init.constant_(self.bnorm.weight, 1)
            init.constant_(self.bnorm.bias, 0)
        param_dict = dict(list(self.named_parameters()))
        for block_nr in range(2, 5):
            conv_weight = param_dict["conv_{:d}.weight".format(block_nr)]
            init.xavier_uniform_(conv_weight, gain=1)
            if not self.batch_norm:
                conv_bias = param_dict["conv_{:d}.bias".format(block_nr)]
                init.constant_(conv_bias, 0)
            else:
                bnorm_weight = param_dict["bnorm_{:d}.weight".format(block_nr)]
                bnorm_bias = param_dict["bnorm_{:d}.bias".format(block_nr)]
                init.constant_(bnorm_weight, 1)
                init.constant_(bnorm_bias, 0)

        init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
        init.constant_(self.final_layer.conv_classifier.bias, 0)

        self.train()
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
# Authors: Théo Gnassounou <theo.gnassounou@inria.fr>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
|
|
7
|
+
from braindecode.models.base import EEGModuleMixin
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DeepSleepNet(EEGModuleMixin, nn.Module):
    """Sleep staging architecture from Supratak et al. (2017) [Supratak2017]_.

    .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/refs/heads/master/img/deepsleepnet.png
        :align: center
        :alt: DeepSleepNet Architecture

    Convolutional neural network and bidirectional Long Short-Term Memory
    (LSTM) network for single-channel sleep staging described in
    [Supratak2017]_.

    Parameters
    ----------
    activation_large: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    activation_small: nn.Module, default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
    return_feats : bool
        If True, return the features, i.e. the output of the feature extractor
        (before the final linear layer). If False, pass the features through
        the final linear layer.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.
        It is forwarded to the two CNN branches; the dropout applied to the
        concatenated features is fixed at 0.5.


    References
    ----------
    .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
       DeepSleepNet: A model for automatic sleep stage scoring based
       on raw single-channel EEG. IEEE Transactions on Neural Systems
       and Rehabilitation Engineering, 25(11), 1998-2008.
    """

    def __init__(
        self,
        n_outputs=5,
        return_feats=False,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        activation_large: nn.Module = nn.ELU,
        activation_small: nn.Module = nn.ReLU,
        drop_prob: float = 0.5,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.cnn1 = _SmallCNN(activation=activation_small, drop_prob=drop_prob)
        self.cnn2 = _LargeCNN(activation=activation_large, drop_prob=drop_prob)
        # NOTE(review): this rate is hard-coded and does not follow
        # ``drop_prob`` - confirm whether that is intentional.
        self.dropout = nn.Dropout(0.5)
        # 3072 is the hard-coded size of the concatenated, flattened CNN
        # features. NOTE(review): presumably this corresponds to windows of
        # 3000 time samples from a single channel - confirm.
        self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
        # Shortcut branch projecting the CNN features to the BiLSTM output
        # width (2 * 512 = 1024) so both can be summed in ``forward``.
        self.fc = nn.Sequential(
            nn.Linear(3072, 1024, bias=False), nn.BatchNorm1d(num_features=1024)
        )

        self.features_extractor = nn.Identity()
        self.len_last_layer = 1024
        self.return_feats = return_feats

        # TODO: Add new way to handle return_features == True
        if not return_feats:
            self.final_layer = nn.Linear(1024, self.n_outputs)
        else:
            self.final_layer = nn.Identity()

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            The 1024-dimensional features per window if ``return_feats`` is
            True, otherwise the output of the final linear layer with
            ``n_outputs`` values per window.
        """

        if x.ndim == 3:
            # Add the singleton input-channel axis expected by the 2D convs.
            x = x.unsqueeze(1)

        x1 = self.cnn1(x)
        x1 = x1.flatten(start_dim=1)

        x2 = self.cnn2(x)
        x2 = x2.flatten(start_dim=1)

        x = torch.cat((x1, x2), dim=1)
        x = self.dropout(x)
        # Shortcut branch; ``self.fc`` does not modify its input in place,
        # so no copy of ``x`` is required.
        temp = self.fc(x)
        # Feed the features to the BiLSTM as a sequence of length one.
        x = x.unsqueeze(1)
        x = self.bilstm(x)
        # Drop only the length-one sequence axis; a bare ``squeeze()`` would
        # also remove the batch axis when the batch holds a single window.
        x = x.squeeze(1)
        x = torch.add(x, temp)
        x = self.dropout(x)

        feats = self.features_extractor(x)

        if self.return_feats:
            return feats
        else:
            return self.final_layer(feats)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class _SmallCNN(nn.Module):
    """
    CNN branch using the smaller filter sizes, to learn temporal information.

    Four convolution / batch-norm / activation stages are applied, with max
    pooling and dropout after the first stage and max pooling after the last.
    The first convolution takes a single input channel and all kernels are
    of the form ``(1, k)``.

    Parameters
    ----------
    activation: nn.Module, default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.
    """

    def __init__(self, activation: nn.Module = nn.ReLU, drop_prob: float = 0.5):
        super().__init__()

        def stage(n_in, n_out, width, **conv_kwargs):
            # One bias-free convolution followed by batch norm and a fresh
            # activation instance.
            return nn.Sequential(
                nn.Conv2d(n_in, n_out, (1, width), bias=False, **conv_kwargs),
                nn.BatchNorm2d(n_out),
                activation(),
            )

        self.conv1 = stage(1, 64, 50, stride=(1, 6), padding=(0, 22))
        self.pool1 = nn.MaxPool2d((1, 8), stride=(1, 8), padding=(0, 2))
        self.dropout = nn.Dropout(drop_prob)
        self.conv2 = stage(64, 128, 8, stride=1, padding="same")
        self.conv3 = stage(128, 128, 8, stride=1, padding="same")
        self.conv4 = stage(128, 128, 8, stride=1, padding="same")
        self.pool2 = nn.MaxPool2d((1, 4), stride=(1, 4), padding=(0, 1))

    def forward(self, x):
        """Apply the four convolution stages with pooling and dropout."""
        pooled = self.pool1(self.conv1(x))
        feats = self.dropout(pooled)
        for conv_stage in (self.conv2, self.conv3, self.conv4):
            feats = conv_stage(feats)
        return self.pool2(feats)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class _LargeCNN(nn.Module):
    """
    CNN branch using the larger filter sizes, to learn frequency information.

    Four convolution / batch-norm / activation stages are applied, with max
    pooling and dropout after the first stage and max pooling after the last.
    The first convolution takes a single input channel and all kernels are
    of the form ``(1, k)``.

    Parameters
    ----------
    activation: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.

    """

    def __init__(self, activation: nn.Module = nn.ELU, drop_prob: float = 0.5):
        super().__init__()

        wide_conv = nn.Conv2d(
            1, 64, (1, 400), stride=(1, 50), padding=(0, 175), bias=False
        )
        self.conv1 = nn.Sequential(wide_conv, nn.BatchNorm2d(64), activation())
        self.pool1 = nn.MaxPool2d((1, 4), stride=(1, 4))
        self.dropout = nn.Dropout(drop_prob)
        # conv2..conv4 differ only in their number of input channels.
        for attr_name, n_in in (("conv2", 64), ("conv3", 128), ("conv4", 128)):
            later_stage = nn.Sequential(
                nn.Conv2d(n_in, 128, (1, 6), stride=1, padding="same", bias=False),
                nn.BatchNorm2d(128),
                activation(),
            )
            setattr(self, attr_name, later_stage)
        self.pool2 = nn.MaxPool2d((1, 2), stride=(1, 2), padding=(0, 1))

    def forward(self, x):
        """Apply the four convolution stages with pooling and dropout."""
        feats = self.dropout(self.pool1(self.conv1(x)))
        feats = self.conv4(self.conv3(self.conv2(feats)))
        return self.pool2(feats)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class _BiLSTM(nn.Module):
    """Bidirectional multi-layer LSTM returning the per-step outputs.

    Wraps ``nn.LSTM`` with ``batch_first=True``, ``bidirectional=True`` and a
    fixed dropout of 0.5 between layers.

    Parameters
    ----------
    input_size : int
        Number of features of each input step.
    hidden_size : int
        Number of hidden units per direction; the output therefore has
        ``2 * hidden_size`` features per step.
    num_layers : int
        Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=0.5,
            bidirectional=True,
        )

    def forward(self, x):
        """Run the LSTM from zero initial states and return its outputs.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch_size, seq_len, input_size).

        Returns
        -------
        torch.Tensor
            Output features of shape (batch_size, seq_len, 2 * hidden_size).
        """
        # nn.LSTM uses zero initial hidden and cell states when none are
        # given, created with the dtype and device of the input. Building
        # them by hand as float32 tensors broke non-float32 models.
        out, _ = self.lstm(x)
        return out
|