braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
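The file list already tells the main structural story of this release: the shared building blocks in `braindecode/models/modules.py` and `braindecode/models/functions.py` were removed in favor of the new top-level `braindecode.modules` and `braindecode.functional` packages, and the `datautil` shims (`datautil/mne.py`, `datautil/preprocess.py`, `datautil/windowers.py`, `datautil/xy.py`) are gone. A minimal, hedged sketch of the import migration, using only the class names that appear in the deep4.py diff below (the commented-out old path is an inference from the removed `models/modules.py`, not something the diff text states):

```python
# braindecode 0.8.1 (old): shared layers were inferred to live under the
# models package; this path is an assumption based on the deleted file.
# from braindecode.models.modules import Ensure4d, CombinedConv, AvgPool2dWithConv

# braindecode 1.1.0 (new): shared layers live in a dedicated package,
# exactly as imported by the updated deep4.py below.
from braindecode.modules import (
    AvgPool2dWithConv,
    CombinedConv,
    Ensure4d,
    SqueezeFinalOutput,
)
```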
braindecode/models/deep4.py
CHANGED
@@ -3,17 +3,25 @@
 # License: BSD (3-clause)
 
 from einops.layers.torch import Rearrange
+from mne.utils import warn
 from torch import nn
 from torch.nn import init
-from torch.nn.functional import elu
 
-from .base import EEGModuleMixin
-from .
-
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import (
+    AvgPool2dWithConv,
+    CombinedConv,
+    Ensure4d,
+    SqueezeFinalOutput,
+)
 
 
 class Deep4Net(EEGModuleMixin, nn.Sequential):
-    """Deep ConvNet model from Schirrmeister et al 2017.
+    """Deep ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
+
+    .. figure:: https://onlinelibrary.wiley.com/cms/asset/fc200ccc-d8c4-45b4-8577-56ce4d15999a/hbm23730-fig-0001-m.jpg
+       :align: center
+       :alt: CTNet Architecture
 
     Model described in [Schirrmeister2017]_.
 
@@ -44,13 +52,13 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         Number of temporal filters in layer 4.
     filter_length_4: int
         Length of the temporal filter in layer 4.
-
+    activation_first_conv_nonlin: nn.Module, default is nn.ELU
         Non-linear activation function to be used after convolution in layer 1.
     first_pool_mode: str
         Pooling mode in layer 1. "max" or "mean".
     first_pool_nonlin: callable
         Non-linear activation function to be used after pooling in layer 1.
-
+    activation_later_conv_nonlin: nn.Module, default is nn.ELU
         Non-linear activation function to be used after convolution in later layers.
     later_pool_mode: str
         Pooling mode in later layers. "max" or "mean".
@@ -67,12 +75,6 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         Momentum for BatchNorm2d.
     stride_before_pool: bool
         Stride before pooling.
-    in_chans :
-        Alias for n_chans.
-    n_classes:
-        Alias for n_outputs.
-    input_window_samples :
-        Alias for n_times.
 
 
     References
@@ -87,47 +89,38 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
     """
 
     def __init__(
-        … (31 removed lines not captured in the source diff)
-        n_classes=None,
-        input_window_samples=None,
-        add_log_softmax=True,
+        self,
+        n_chans=None,
+        n_outputs=None,
+        n_times=None,
+        final_conv_length="auto",
+        n_filters_time=25,
+        n_filters_spat=25,
+        filter_time_length=10,
+        pool_time_length=3,
+        pool_time_stride=3,
+        n_filters_2=50,
+        filter_length_2=10,
+        n_filters_3=100,
+        filter_length_3=10,
+        n_filters_4=200,
+        filter_length_4=10,
+        activation_first_conv_nonlin: nn.Module = nn.ELU,
+        first_pool_mode="max",
+        first_pool_nonlin: nn.Module = nn.Identity,
+        activation_later_conv_nonlin: nn.Module = nn.ELU,
+        later_pool_mode="max",
+        later_pool_nonlin: nn.Module = nn.Identity,
+        drop_prob=0.5,
+        split_first_layer=True,
+        batch_norm=True,
+        batch_norm_alpha=0.1,
+        stride_before_pool=False,
+        # Braindecode EEGModuleMixin parameters
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
     ):
-        n_chans, n_outputs, n_times = deprecated_args(
-            self,
-            ('in_chans', 'n_chans', in_chans, n_chans),
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('input_window_samples', 'n_times', input_window_samples, n_times),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -135,10 +128,9 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-
+
         if final_conv_length == "auto":
             assert self.n_times is not None
         self.final_conv_length = final_conv_length
@@ -153,10 +145,10 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         self.filter_length_3 = filter_length_3
         self.n_filters_4 = n_filters_4
         self.filter_length_4 = filter_length_4
-        self.first_nonlin =
+        self.first_nonlin = activation_first_conv_nonlin
         self.first_pool_mode = first_pool_mode
         self.first_pool_nonlin = first_pool_nonlin
-        self.later_conv_nonlin =
+        self.later_conv_nonlin = activation_later_conv_nonlin
         self.later_pool_mode = later_pool_mode
         self.later_pool_nonlin = later_pool_nonlin
         self.drop_prob = drop_prob
@@ -165,6 +157,27 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         self.batch_norm_alpha = batch_norm_alpha
         self.stride_before_pool = stride_before_pool
 
+        min_n_times = self._get_min_n_times()
+        if self.n_times < min_n_times:
+            scaling_factor = self.n_times / min_n_times
+            warn(
+                f"n_times ({self.n_times}) is smaller than the minimum required "
+                f"({min_n_times}) for the current model parameters configuration. "
+                "Adjusting parameters to ensure compatibility."
+                "Reducing the kernel, pooling, and stride sizes accordingly."
+                "Scaling factor: {:.2f}".format(scaling_factor),
+                UserWarning,
+            )
+            # Calculate a scaling factor to adjust temporal parameters
+            # Apply the scaling factor to all temporal kernel and pooling sizes
+            self.filter_time_length = max(
+                1, int(self.filter_time_length * scaling_factor)
+            )
+            self.pool_time_length = max(1, int(self.pool_time_length * scaling_factor))
+            self.pool_time_stride = max(1, int(self.pool_time_stride * scaling_factor))
+            self.filter_length_2 = max(1, int(self.filter_length_2 * scaling_factor))
+            self.filter_length_3 = max(1, int(self.filter_length_3 * scaling_factor))
+            self.filter_length_4 = max(1, int(self.filter_length_4 * scaling_factor))
         # For the load_state_dict
         # When padronize all layers,
         # add the old's parameters here
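The rescaling block above is new behavior rather than a pure refactor: where 0.8.1 failed with a shape error on windows shorter than the network's receptive field, 1.1.0 warns and shrinks its own temporal hyperparameters. A small sketch of the arithmetic with the default sizes (441 is the minimum derived by `_get_min_n_times`, defined at the end of this diff):

```python
# Worked example of the rescaling branch above, using the 1.1.0 defaults.
# min_n_times = 441 for the default configuration (see _get_min_n_times below).
n_times, min_n_times = 200, 441
scaling_factor = n_times / min_n_times                  # ~0.4535
filter_time_length = max(1, int(10 * scaling_factor))   # 10 -> 4
pool_time_length = max(1, int(3 * scaling_factor))      # 3 -> 1
pool_time_stride = max(1, int(3 * scaling_factor))      # 3 -> 1
```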
@@ -174,7 +187,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
             "conv_time.bias": "conv_time_spat.conv_time.bias",
             "conv_spat.bias": "conv_time_spat.conv_spat.bias",
             "conv_classifier.weight": "final_layer.conv_classifier.weight",
-            "conv_classifier.bias": "final_layer.conv_classifier.bias"
+            "conv_classifier.bias": "final_layer.conv_classifier.bias",
         }
 
         if self.stride_before_pool:
@@ -223,17 +236,17 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                 eps=1e-5,
             ),
         )
-        self.add_module("conv_nonlin",
+        self.add_module("conv_nonlin", self.first_nonlin())
         self.add_module(
            "pool",
            first_pool_class(
                kernel_size=(self.pool_time_length, 1), stride=(pool_stride, 1)
            ),
        )
-        self.add_module("pool_nonlin",
+        self.add_module("pool_nonlin", self.first_pool_nonlin())
 
        def add_conv_pool_block(
-
+            model, n_filters_before, n_filters, filter_length, block_nr
        ):
            suffix = "_{:d}".format(block_nr)
            self.add_module("drop" + suffix, nn.Dropout(p=self.drop_prob))
@@ -257,7 +270,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                    eps=1e-5,
                ),
            )
-            self.add_module("nonlin" + suffix,
+            self.add_module("nonlin" + suffix, self.later_conv_nonlin())
 
            self.add_module(
                "pool" + suffix,
@@ -266,7 +279,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
                    stride=(pool_stride, 1),
                ),
            )
-            self.add_module("pool_nonlin" + suffix,
+            self.add_module("pool_nonlin" + suffix, self.later_pool_nonlin())
 
        add_conv_pool_block(
            self, n_filters_conv, self.n_filters_2, self.filter_length_2, 2
@@ -278,7 +291,6 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
            self, self.n_filters_3, self.n_filters_4, self.filter_length_4, 4
        )
 
-        # self.add_module('drop_classifier', nn.Dropout(p=self.drop_prob))
        self.eval()
        if self.final_conv_length == "auto":
            self.final_conv_length = self.get_output_shape()[2]
@@ -286,17 +298,17 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()
 
-        module.add_module(
-        … (8 removed lines not captured in the source diff)
+        module.add_module(
+            "conv_classifier",
+            nn.Conv2d(
+                self.n_filters_4,
+                self.n_outputs,
+                (self.final_conv_length, 1),
+                bias=True,
+            ),
+        )
 
-        module.add_module("squeeze",
+        module.add_module("squeeze", SqueezeFinalOutput())
 
        self.add_module("final_layer", module)
 
@@ -309,7 +321,7 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
        if self.split_first_layer:
            init.xavier_uniform_(self.conv_time_spat.conv_spat.weight, gain=1)
            if not self.batch_norm:
-                init.constant_(self.conv_spat.bias, 0)
+                init.constant_(self.conv_time_spat.conv_spat.bias, 0)
        if self.batch_norm:
            init.constant_(self.bnorm.weight, 1)
            init.constant_(self.bnorm.bias, 0)
@@ -329,5 +341,32 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
        init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
        init.constant_(self.final_layer.conv_classifier.bias, 0)
 
-        … (2 removed lines not captured in the source diff)
+        self.train()
+
+    def _get_min_n_times(self) -> int:
+        """
+        Calculate the minimum number of time samples required for the model
+        to work with the given temporal parameters.
+        """
+        # Start with the minimum valid output length of the network (1)
+        min_len = 1
+
+        # List of conv kernel sizes and pool parameters for the 4 blocks, in reverse order
+        # Each tuple: (filter_length, pool_length, pool_stride)
+        block_params = [
+            (self.filter_length_4, self.pool_time_length, self.pool_time_stride),
+            (self.filter_length_3, self.pool_time_length, self.pool_time_stride),
+            (self.filter_length_2, self.pool_time_length, self.pool_time_stride),
+            (self.filter_time_length, self.pool_time_length, self.pool_time_stride),
+        ]
+
+        # Work backward from the last layer to the input
+        for filter_len, pool_len, pool_stride in block_params:
+            # Reverse the pooling operation
+            # L_in = stride * (L_out - 1) + kernel_size
+            min_len = pool_stride * (min_len - 1) + pool_len
+            # Reverse the convolution operation (assuming stride=1)
+            # L_in = L_out + kernel_size - 1
+            min_len = min_len + filter_len - 1
+
+        return min_len
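Two practical takeaways from the deep4.py diff. First, the constructor: `in_chans`, `n_classes`, `input_window_samples`, and `add_log_softmax` are gone along with the `deprecated_args` shim, so callers must use `n_chans`, `n_outputs`, and `n_times`, and the activation arguments now take module classes (`nn.ELU`) rather than functions (`elu`). Second, the minimum-input-length rule can be checked independently; below is a standalone sketch that mirrors `_get_min_n_times` with the default hyperparameters taken from the diff:

```python
# Standalone re-computation of Deep4Net's new minimum input length,
# mirroring _get_min_n_times() above. Defaults from the diff: all four
# temporal kernels have length 10; pooling length 3, pooling stride 3.

def min_n_times(filter_lengths=(10, 10, 10, 10), pool_len=3, pool_stride=3):
    min_len = 1  # smallest valid output length after the last block
    for filter_len in filter_lengths:  # walk blocks 4 -> 1, back to the input
        # Undo pooling: L_in = stride * (L_out - 1) + kernel_size
        min_len = pool_stride * (min_len - 1) + pool_len
        # Undo the stride-1 convolution: L_in = L_out + kernel_size - 1
        min_len = min_len + filter_len - 1
    return min_len

print(min_n_times())  # -> 441; shorter windows trigger the rescaling warning
```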
braindecode/models/deepsleepnet.py
CHANGED
@@ -4,11 +4,133 @@
 import torch
 import torch.nn as nn
 
-from .base import EEGModuleMixin
+from braindecode.models.base import EEGModuleMixin
 
 
-class
-
+class DeepSleepNet(EEGModuleMixin, nn.Module):
+    """Sleep staging architecture from Supratak et al. (2017) [Supratak2017]_.
+
+    .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/refs/heads/master/img/deepsleepnet.png
+       :align: center
+       :alt: DeepSleepNet Architecture
+
+    Convolutional neural network and bidirectional-Long Short-Term
+    for single channels sleep staging described in [Supratak2017]_.
+
+    Parameters
+    ----------
+    activation_large: nn.Module, default=nn.ELU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+    activation_small: nn.Module, default=nn.ReLU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+    return_feats : bool
+        If True, return the features, i.e. the output of the feature extractor
+        (before the final linear layer). If False, pass the features through
+        the final linear layer.
+    drop_prob : float, default=0.5
+        The dropout rate for regularization. Values should be between 0 and 1.
+
+
+    References
+    ----------
+    .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
+       DeepSleepNet: A model for automatic sleep stage scoring based
+       on raw single-channel EEG. IEEE Transactions on Neural Systems
+       and Rehabilitation Engineering, 25(11), 1998-2008.
+    """
+
+    def __init__(
+        self,
+        n_outputs=5,
+        return_feats=False,
+        n_chans=None,
+        chs_info=None,
+        n_times=None,
+        input_window_seconds=None,
+        sfreq=None,
+        activation_large: nn.Module = nn.ELU,
+        activation_small: nn.Module = nn.ReLU,
+        drop_prob: float = 0.5,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+        self.cnn1 = _SmallCNN(activation=activation_small, drop_prob=drop_prob)
+        self.cnn2 = _LargeCNN(activation=activation_large, drop_prob=drop_prob)
+        self.dropout = nn.Dropout(0.5)
+        self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
+        self.fc = nn.Sequential(
+            nn.Linear(3072, 1024, bias=False), nn.BatchNorm1d(num_features=1024)
+        )
+
+        self.features_extractor = nn.Identity()
+        self.len_last_layer = 1024
+        self.return_feats = return_feats
+
+        # TODO: Add new way to handle return_features == True
+        if not return_feats:
+            self.final_layer = nn.Linear(1024, self.n_outputs)
+        else:
+            self.final_layer = nn.Identity()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            Batch of EEG windows of shape (batch_size, n_channels, n_times).
+        """
+
+        if x.ndim == 3:
+            x = x.unsqueeze(1)
+
+        x1 = self.cnn1(x)
+        x1 = x1.flatten(start_dim=1)
+
+        x2 = self.cnn2(x)
+        x2 = x2.flatten(start_dim=1)
+
+        x = torch.cat((x1, x2), dim=1)
+        x = self.dropout(x)
+        temp = x.clone()
+        temp = self.fc(temp)
+        x = x.unsqueeze(1)
+        x = self.bilstm(x)
+        x = x.squeeze()
+        x = torch.add(x, temp)
+        x = self.dropout(x)
+
+        feats = self.features_extractor(x)
+
+        if self.return_feats:
+            return feats
+        else:
+            return self.final_layer(feats)
+
+
+class _SmallCNN(nn.Module):
+    """
+    Smaller filter sizes to learn temporal information.
+
+    Parameters
+    ----------
+    activation: nn.Module, default=nn.ReLU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+    drop_prob : float, default=0.5
+        The dropout rate for regularization. Values should be between 0 and 1.
+    """
+
+    def __init__(self, activation: nn.Module = nn.ReLU, drop_prob: float = 0.5):
         super().__init__()
         self.conv1 = nn.Sequential(
             nn.Conv2d(
@@ -20,10 +142,10 @@ class _SmallCNN(nn.Module):  # smaller filter sizes to learn temporal information
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=64),
-
+            activation(),
         )
         self.pool1 = nn.MaxPool2d(kernel_size=(1, 8), stride=(1, 8), padding=(0, 2))
-        self.dropout = nn.Dropout(p=
+        self.dropout = nn.Dropout(p=drop_prob)
         self.conv2 = nn.Sequential(
             nn.Conv2d(
                 in_channels=64,
@@ -34,7 +156,7 @@ class _SmallCNN(nn.Module):  # smaller filter sizes to learn temporal information
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=128),
-
+            activation(),
         )
         self.conv3 = nn.Sequential(
             nn.Conv2d(
@@ -46,7 +168,7 @@ class _SmallCNN(nn.Module):  # smaller filter sizes to learn temporal information
                 bias=False,
             ),
             nn.BatchNorm2d(num_features=128),
-
+            activation(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
@@ -58,7 +180,7 @@ class _SmallCNN(nn.Module):  # smaller filter sizes to learn temporal information
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
-
+            activation(),
        )
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1))
 
@@ -72,8 +194,19 @@ class _SmallCNN(nn.Module):  # smaller filter sizes to learn temporal information
        return x
 
 
-class _LargeCNN(nn.Module):
-
+class _LargeCNN(nn.Module):
+    """
+    Larger filter sizes to learn frequency information.
+
+    Parameters
+    ----------
+    activation: nn.Module, default=nn.ELU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+    """
+
+    def __init__(self, activation: nn.Module = nn.ELU, drop_prob: float = 0.5):
        super().__init__()
 
        self.conv1 = nn.Sequential(
@@ -86,10 +219,10 @@ class _LargeCNN(nn.Module):  # larger filter sizes to learn frequency information
                bias=False,
            ),
            nn.BatchNorm2d(num_features=64),
-
+            activation(),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
-        self.dropout = nn.Dropout(p=
+        self.dropout = nn.Dropout(p=drop_prob)
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=64,
@@ -100,7 +233,7 @@ class _LargeCNN(nn.Module):  # larger filter sizes to learn frequency information
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
-
+            activation(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
@@ -112,7 +245,7 @@ class _LargeCNN(nn.Module):  # larger filter sizes to learn frequency information
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
-
+            activation(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
@@ -124,7 +257,7 @@ class _LargeCNN(nn.Module):  # larger filter sizes to learn frequency information
                bias=False,
            ),
            nn.BatchNorm2d(num_features=128),
-
+            activation(),
        )
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 1))
 
@@ -154,112 +287,9 @@ class _BiLSTM(nn.Module):
 
    def forward(self, x):
        # set initial hidden and cell states
-        h0 = torch.zeros(
-            self.num_layers * 2, x.size(0), self.hidden_size
-        ).to(x.device)
+        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
 
        # forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))
        return out
-
-
-class DeepSleepNet(EEGModuleMixin, nn.Module):
-    """Sleep staging architecture from Supratak et al 2017.
-
-    Convolutional neural network and bidirectional-Long Short-Term
-    for single channels sleep staging described in [Supratak2017]_.
-
-    Parameters
-    ----------
-    return_feats : bool
-        If True, return the features, i.e. the output of the feature extractor
-        (before the final linear layer). If False, pass the features through
-        the final linear layer.
-    n_classes :
-        Alias for n_outputs.
-
-    References
-    ----------
-    .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
-       DeepSleepNet: A model for automatic sleep stage scoring based
-       on raw single-channel EEG. IEEE Transactions on Neural Systems
-       and Rehabilitation Engineering, 25(11), 1998-2008.
-    """
-
-    def __init__(
-        self,
-        n_outputs=5,
-        return_feats=False,
-        n_chans=None,
-        chs_info=None,
-        n_times=None,
-        input_window_seconds=None,
-        sfreq=None,
-        n_classes=None,
-    ):
-        n_outputs, = deprecated_args(
-            self,
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-        )
-        super().__init__(
-            n_outputs=n_outputs,
-            n_chans=n_chans,
-            chs_info=chs_info,
-            n_times=n_times,
-            input_window_seconds=input_window_seconds,
-            sfreq=sfreq,
-        )
-        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del n_classes
-        self.cnn1 = _SmallCNN()
-        self.cnn2 = _LargeCNN()
-        self.dropout = nn.Dropout(0.5)
-        self.bilstm = _BiLSTM(input_size=3072, hidden_size=512, num_layers=2)
-        self.fc = nn.Sequential(nn.Linear(3072, 1024, bias=False),
-                                nn.BatchNorm1d(num_features=1024))
-
-        self.features_extractor = nn.Identity()
-        self.len_last_layer = 1024
-        self.return_feats = return_feats
-
-        # TODO: Add new way to handle return_features == True
-        if not return_feats:
-            self.final_layer = nn.Linear(1024, self.n_outputs)
-        else:
-            self.final_layer = nn.Identity()
-
-    def forward(self, x):
-        """Forward pass.
-
-        Parameters
-        ----------
-        x: torch.Tensor
-            Batch of EEG windows of shape (batch_size, n_channels, n_times).
-        """
-
-        if x.ndim == 3:
-            x = x.unsqueeze(1)
-
-        x1 = self.cnn1(x)
-        x1 = x1.flatten(start_dim=1)
-
-        x2 = self.cnn2(x)
-        x2 = x2.flatten(start_dim=1)
-
-        x = torch.cat((x1, x2), dim=1)
-        x = self.dropout(x)
-        temp = x.clone()
-        temp = self.fc(temp)
-        x = x.unsqueeze(1)
-        x = self.bilstm(x)
-        x = x.squeeze()
-        x = torch.add(x, temp)
-        x = self.dropout(x)
-
-        feats = self.features_extractor(x)
-
-        if self.return_feats:
-            return feats
-        else:
-            return self.final_layer(feats)