braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from einops.layers.torch import Rearrange
|
|
8
|
+
|
|
9
|
+
from braindecode.models.base import EEGModuleMixin
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SincShallowNet(EEGModuleMixin, nn.Module):
    """Sinc-ShallowNet from Borra, D et al (2020) [borra2020]_.

    .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S0893608020302021-gr2_lrg.jpg
        :align: center
        :alt: SincShallowNet Architecture

    The Sinc-ShallowNet architecture has these fundamental blocks:

    1. **Block 1: Spectral and Spatial Feature Extraction**
        - *Temporal Sinc-Convolutional Layer*:
            Uses parametrized sinc functions to learn band-pass filters,
            significantly reducing the number of trainable parameters by only
            learning the lower cutoff frequency and the bandwidth of each filter.
        - *Batch Normalization*
        - *Spatial Depthwise Convolutional Layer*:
            Applies depthwise convolutions to learn spatial filters for
            each temporal feature map independently, further reducing
            parameters and enhancing interpretability.

    2. **Block 2: Temporal Aggregation**
        - *Batch Normalization*
        - *Activation Function*: ELU by default
        - *Average Pooling Layer*: aggregation by averaging over the time
          dimension
        - *Dropout Layer*
        - *Flatten Layer*

    3. **Block 3: Classification**
        - *Fully Connected Layer*: Maps the feature vector to n_outputs.

    **Implementation Notes:**

    - The lower cutoff frequencies are initialised on an evenly spaced grid
      ``first_freq + k * freq_stride`` (``k = 0 .. num_time_filters - 1``) and
      every filter starts with the same ``bandwidth``; both are learned. The
      effective lower cutoff is ``min_freq + |low|`` and the upper cutoff is
      ``lower + |bandwidth|`` clipped to the Nyquist frequency.

    Parameters
    ----------
    num_time_filters : int
        Number of temporal filters in the SincFilter layer.
    time_filter_len : int
        Size of the temporal filters. Must be odd.
    depth_multiplier : int
        Depth multiplier for spatial filtering.
    activation : type[nn.Module], optional
        Activation class, instantiated without arguments. ``None`` falls back
        to ``nn.ELU``. Default is ``nn.ELU``.
    drop_prob : float, optional
        Dropout probability. Default is 0.5.
    first_freq : float, optional
        The starting frequency for the first Sinc filter. Default is 5.0.
    min_freq : float, optional
        Minimum frequency allowed for the low frequencies of the filters. Default is 1.0.
    freq_stride : float, optional
        Frequency stride for the Sinc filters. Controls the spacing between the filter frequencies.
        Default is 1.0.
    padding : str, optional
        Padding mode for convolution, either 'same' or 'valid'. Default is 'same'.
    bandwidth : float, optional
        Initial bandwidth for each Sinc filter. Default is 4.0.
    pool_size : int, optional
        Size of the pooling window for the average pooling layer. Default is 55.
    pool_stride : int, optional
        Stride of the pooling operation. Default is 12.

    Raises
    ------
    ValueError
        If ``padding`` is not 'same' or 'valid', or if the input is too short
        for at least one pooling window.

    Notes
    -----
    This implementation is based on the implementation from [sincshallowcode]_.

    References
    ----------
    .. [borra2020] Borra, D., Fantozzi, S., & Magosso, E. (2020). Interpretable
       and lightweight convolutional neural network for EEG decoding: Application
       to movement execution and imagination. Neural Networks, 129, 55-74.
    .. [sincshallowcode] Sinc-ShallowNet re-implementation source code:
       https://github.com/marcellosicbaldi/SincNet-Tensorflow
    """

    def __init__(
        self,
        num_time_filters: int = 32,
        time_filter_len: int = 33,
        depth_multiplier: int = 2,
        activation: Optional[nn.Module] = nn.ELU,
        drop_prob: float = 0.5,
        first_freq: float = 5.0,
        min_freq: float = 1.0,
        freq_stride: float = 1.0,
        padding: str = "same",
        bandwidth: float = 4.0,
        pool_size: int = 55,
        pool_stride: int = 12,
        # braindecode parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        chs_info=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # ``activation`` is a class that is instantiated further below, so the
        # fallback must also be a class (an instance would be *called* as a
        # module, i.e. ``forward()`` without input).
        if activation is None:
            activation = nn.ELU

        # Lower cutoff frequencies: exactly ``num_time_filters`` values starting
        # at ``first_freq`` and spaced by ``freq_stride``. Built from an integer
        # range because ``torch.arange`` with a float step can return one
        # element too many or too few due to rounding.
        low_freqs = first_freq + freq_stride * torch.arange(
            num_time_filters, dtype=torch.float32
        )
        self.n_filters = len(low_freqs)

        padding_mode = padding.lower()
        if padding_mode == "valid":
            n_times_after_sinc_filter = self.n_times - time_filter_len + 1
        elif padding_mode == "same":
            n_times_after_sinc_filter = self.n_times
        else:
            raise ValueError("Padding must be 'valid' or 'same'.")

        # Output length of AvgPool2d (no padding, floor mode) along time.
        size_after_pooling = (
            (n_times_after_sinc_filter - pool_size) // pool_stride
        ) + 1
        if size_after_pooling < 1:
            raise ValueError(
                f"Input of {self.n_times} samples is too short: after the sinc "
                f"filter it has {n_times_after_sinc_filter} samples, fewer "
                f"than pool_size={pool_size}."
            )
        flattened_size = self.n_filters * depth_multiplier * size_after_pooling

        # PyTorch's BatchNorm ``momentum`` is the weight given to the *new*
        # batch statistic, i.e. ``1 - momentum`` of Keras. Keras' default
        # momentum of 0.99 therefore corresponds to 0.01 here.
        bn_momentum = 0.01

        # Layers
        self.ensuredims = Rearrange("batch chans times -> batch chans times 1")

        # Block 1: Sinc filter
        self.sinc_filter_layer = _SincFilter(
            low_freqs=low_freqs,
            kernel_size=time_filter_len,
            sfreq=self.sfreq,
            padding=padding,
            bandwidth=bandwidth,
            min_freq=min_freq,
        )

        self.depthwiseconv = nn.Sequential(
            # Sinc output is [batch, chans, times, n_filters]; move the filter
            # axis to the channel position expected by Conv2d:
            # -> [batch, n_filters, chans, times]
            Rearrange("batch timefil time nfilter -> batch nfilter timefil time"),
            nn.BatchNorm2d(self.n_filters, momentum=bn_momentum),
            # Depthwise spatial filter: collapses the EEG-channel axis to 1,
            # producing ``depth_multiplier`` maps per sinc filter.
            nn.Conv2d(
                in_channels=self.n_filters,
                out_channels=depth_multiplier * self.n_filters,
                kernel_size=(self.n_chans, 1),
                groups=self.n_filters,
                bias=False,
            ),
        )

        # Block 2: Batch norm, activation, pooling over time, dropout
        self.temporal_aggregation = nn.Sequential(
            nn.BatchNorm2d(depth_multiplier * self.n_filters, momentum=bn_momentum),
            activation(),
            nn.AvgPool2d(kernel_size=(1, pool_size), stride=(1, pool_stride)),
            nn.Dropout(p=drop_prob),
            nn.Flatten(),
        )

        # Final classification layer
        self.final_layer = nn.Linear(
            flattened_size,
            self.n_outputs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape [batch_size, n_chans, n_times].

        Returns
        -------
        torch.Tensor
            Output of shape [batch_size, n_outputs].
        """
        x = self.ensuredims(x)
        x = self.sinc_filter_layer(x)
        x = self.depthwiseconv(x)
        x = self.temporal_aggregation(x)

        return self.final_layer(x)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class _SincFilter(nn.Module):
|
|
211
|
+
"""Sinc-Based Convolutional Layer for Band-Pass Filtering from Ravanelli and Bengio (2018) [ravanelli]_.
|
|
212
|
+
|
|
213
|
+
The `SincFilter` layer implements a convolutional layer where each kernel is
|
|
214
|
+
defined using a parametrized sinc function.
|
|
215
|
+
This design enforces each kernel to represent a band-pass filter,
|
|
216
|
+
reducing the number of trainable parameters.
|
|
217
|
+
|
|
218
|
+
Parameters
|
|
219
|
+
----------
|
|
220
|
+
low_freqs : torch.Tensor
|
|
221
|
+
Initial low cutoff frequencies for each filter.
|
|
222
|
+
kernel_size : int
|
|
223
|
+
Size of the convolutional kernels (filters). Must be odd.
|
|
224
|
+
sfreq : float
|
|
225
|
+
Sampling rate of the input signal.
|
|
226
|
+
bandwidth : float, optional
|
|
227
|
+
Initial bandwidth for each filter. Default is 4.0.
|
|
228
|
+
min_freq : float, optional
|
|
229
|
+
Minimum frequency allowed for low frequencies. Default is 1.0.
|
|
230
|
+
padding : str, optional
|
|
231
|
+
Padding mode, either 'same' or 'valid'. Default is 'same'.
|
|
232
|
+
|
|
233
|
+
References
|
|
234
|
+
----------
|
|
235
|
+
.. [ravanelli] Ravanelli, M., & Bengio, Y. (2018, December). Speaker
|
|
236
|
+
recognition from raw waveform with sincnet. In 2018 IEEE spoken language
|
|
237
|
+
technology workshop (SLT) (pp. 1021-1028). IEEE.
|
|
238
|
+
"""
|
|
239
|
+
|
|
240
|
+
def __init__(
|
|
241
|
+
self,
|
|
242
|
+
low_freqs: torch.Tensor,
|
|
243
|
+
kernel_size: int,
|
|
244
|
+
sfreq: float,
|
|
245
|
+
bandwidth: float = 4.0,
|
|
246
|
+
min_freq: float = 1.0,
|
|
247
|
+
padding: str = "same",
|
|
248
|
+
):
|
|
249
|
+
super().__init__()
|
|
250
|
+
if kernel_size % 2 == 0:
|
|
251
|
+
raise ValueError("Kernel size must be odd.")
|
|
252
|
+
|
|
253
|
+
self.num_filters = low_freqs.numel()
|
|
254
|
+
self.kernel_size = kernel_size
|
|
255
|
+
self.sfreq = sfreq
|
|
256
|
+
self.min_freq = min_freq
|
|
257
|
+
self.padding = padding.lower()
|
|
258
|
+
|
|
259
|
+
# Precompute constants
|
|
260
|
+
window = torch.hamming_window(kernel_size, periodic=False)
|
|
261
|
+
|
|
262
|
+
self.register_buffer("window", window[: kernel_size // 2].unsqueeze(-1))
|
|
263
|
+
|
|
264
|
+
n_pi = (
|
|
265
|
+
torch.arange(-(kernel_size // 2), 0, dtype=torch.float32)
|
|
266
|
+
/ sfreq
|
|
267
|
+
* 2
|
|
268
|
+
* math.pi
|
|
269
|
+
)
|
|
270
|
+
self.register_buffer("n_pi", n_pi.unsqueeze(-1))
|
|
271
|
+
|
|
272
|
+
# Initialize learnable parameters
|
|
273
|
+
bandwidths = torch.full((1, self.num_filters), bandwidth)
|
|
274
|
+
self.bandwidths = nn.Parameter(bandwidths)
|
|
275
|
+
self.low_freqs = nn.Parameter(low_freqs.unsqueeze(0))
|
|
276
|
+
|
|
277
|
+
# Constant tensor of ones for filter construction
|
|
278
|
+
self.register_buffer("ones", torch.ones(1, 1, 1, self.num_filters))
|
|
279
|
+
|
|
280
|
+
def build_sinc_filters(self) -> torch.Tensor:
|
|
281
|
+
"""Builds the sinc filters based on current parameters."""
|
|
282
|
+
# Computing the low frequencies of the filters
|
|
283
|
+
low_freqs = self.min_freq + torch.abs(self.low_freqs)
|
|
284
|
+
# Setting a minimum band and minimum freq
|
|
285
|
+
high_freqs = torch.clamp(
|
|
286
|
+
low_freqs + torch.abs(self.bandwidths),
|
|
287
|
+
min=self.min_freq,
|
|
288
|
+
max=self.sfreq / 2.0,
|
|
289
|
+
)
|
|
290
|
+
bandwidths = high_freqs - low_freqs
|
|
291
|
+
|
|
292
|
+
# Passing from n_ to the corresponding f_times_t domain
|
|
293
|
+
low = self.n_pi * low_freqs # [kernel_size // 2, num_filters]
|
|
294
|
+
high = self.n_pi * high_freqs # [kernel_size // 2, num_filters]
|
|
295
|
+
|
|
296
|
+
filters_left = (torch.sin(high) - torch.sin(low)) / (self.n_pi / 2.0)
|
|
297
|
+
filters_left *= self.window
|
|
298
|
+
filters_left /= 2.0 * bandwidths
|
|
299
|
+
|
|
300
|
+
# [1, kernel_size // 2, 1, num_filters]
|
|
301
|
+
filters_left = filters_left.unsqueeze(0).unsqueeze(2)
|
|
302
|
+
filters_right = torch.flip(filters_left, dims=[1])
|
|
303
|
+
|
|
304
|
+
filters = torch.cat(
|
|
305
|
+
[filters_left, self.ones, filters_right], dim=1
|
|
306
|
+
) # [1, kernel_size, 1, num_filters]
|
|
307
|
+
filters = filters / torch.std(filters)
|
|
308
|
+
return filters
|
|
309
|
+
|
|
310
|
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
|
311
|
+
"""
|
|
312
|
+
Apply sinc filters to the input signal.
|
|
313
|
+
|
|
314
|
+
Parameters
|
|
315
|
+
----------
|
|
316
|
+
inputs : torch.Tensor
|
|
317
|
+
Input tensor of shape [batch_size, num_channels, num_samples, 1].
|
|
318
|
+
|
|
319
|
+
Returns
|
|
320
|
+
-------
|
|
321
|
+
torch.Tensor
|
|
322
|
+
Filtered output tensor of shape [batch_size, num_channels, num_samples, num_filters].
|
|
323
|
+
"""
|
|
324
|
+
filters = self.build_sinc_filters().to(
|
|
325
|
+
inputs.device
|
|
326
|
+
) # [1, kernel_size, 1, num_filters]
|
|
327
|
+
|
|
328
|
+
# Convert from channels_last to channels_first format
|
|
329
|
+
inputs = inputs.permute(0, 3, 1, 2)
|
|
330
|
+
# Permuting to match conv:
|
|
331
|
+
filters = filters.permute(3, 2, 0, 1)
|
|
332
|
+
# Apply convolution
|
|
333
|
+
outputs = F.conv2d(inputs, filters, padding=self.padding)
|
|
334
|
+
# Changing the dimensional
|
|
335
|
+
outputs = outputs.permute(0, 2, 3, 1)
|
|
336
|
+
|
|
337
|
+
return outputs
|
|
@@ -5,11 +5,15 @@
|
|
|
5
5
|
import torch
|
|
6
6
|
from torch import nn
|
|
7
7
|
|
|
8
|
-
from .base import EEGModuleMixin
|
|
8
|
+
from braindecode.models.base import EEGModuleMixin
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
|
|
12
|
-
"""Sleep staging architecture from Blanco et al 2020
|
|
12
|
+
"""Sleep staging architecture from Blanco et al. (2020) from [Blanco2020]_
|
|
13
|
+
|
|
14
|
+
.. figure:: https://media.springernature.com/full/springer-static/image/art%3A10.1007%2Fs00500-019-04174-1/MediaObjects/500_2019_4174_Fig2_HTML.png
|
|
15
|
+
:align: center
|
|
16
|
+
:alt: SleepStagerBlanco2020 Architecture
|
|
13
17
|
|
|
14
18
|
Convolutional neural network for sleep staging described in [Blanco2020]_.
|
|
15
19
|
A series of seven convolutional layers with kernel sizes running down from 7 to 3,
|
|
@@ -24,7 +28,7 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
|
|
|
24
28
|
Number of groups for the convolution. Set to 2 in [Blanco2020]_ for 2 Channel EEG.
|
|
25
29
|
controls the connections between inputs and outputs. n_channels and n_conv_chans must be
|
|
26
30
|
divisible by n_groups.
|
|
27
|
-
|
|
31
|
+
drop_prob : float
|
|
28
32
|
Dropout rate before the output dense layer.
|
|
29
33
|
apply_batch_norm : bool
|
|
30
34
|
If True, apply batch normalization after both temporal convolutional
|
|
@@ -39,6 +43,9 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
|
|
|
39
43
|
Alias for `n_outputs`.
|
|
40
44
|
input_size_s : float
|
|
41
45
|
Alias for `input_window_seconds`.
|
|
46
|
+
activation: nn.Module, default=nn.ReLU
|
|
47
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
48
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
|
|
42
49
|
|
|
43
50
|
References
|
|
44
51
|
----------
|
|
@@ -48,30 +55,21 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
|
|
|
48
55
|
"""
|
|
49
56
|
|
|
50
57
|
def __init__(
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
n_classes=None,
|
|
66
|
-
input_size_s=None,
|
|
67
|
-
add_log_softmax=True,
|
|
58
|
+
self,
|
|
59
|
+
n_chans=None,
|
|
60
|
+
sfreq=None,
|
|
61
|
+
n_conv_chans=20,
|
|
62
|
+
input_window_seconds=None,
|
|
63
|
+
n_outputs=5,
|
|
64
|
+
n_groups=2,
|
|
65
|
+
max_pool_size=2,
|
|
66
|
+
drop_prob=0.5,
|
|
67
|
+
apply_batch_norm=False,
|
|
68
|
+
return_feats=False,
|
|
69
|
+
activation: nn.Module = nn.ReLU,
|
|
70
|
+
chs_info=None,
|
|
71
|
+
n_times=None,
|
|
68
72
|
):
|
|
69
|
-
n_chans, n_outputs, input_window_seconds, = deprecated_args(
|
|
70
|
-
self,
|
|
71
|
-
("n_channels", "n_chans", n_channels, n_chans),
|
|
72
|
-
("n_classes", "n_outputs", n_classes, n_outputs),
|
|
73
|
-
("input_size_s", "input_window_seconds", input_size_s, input_window_seconds),
|
|
74
|
-
)
|
|
75
73
|
super().__init__(
|
|
76
74
|
n_outputs=n_outputs,
|
|
77
75
|
n_chans=n_chans,
|
|
@@ -79,14 +77,12 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
|
|
|
79
77
|
n_times=n_times,
|
|
80
78
|
input_window_seconds=input_window_seconds,
|
|
81
79
|
sfreq=sfreq,
|
|
82
|
-
add_log_softmax=add_log_softmax,
|
|
83
80
|
)
|
|
84
81
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
85
|
-
del n_channels, n_classes, input_size_s
|
|
86
82
|
|
|
87
83
|
self.mapping = {
|
|
88
84
|
"fc.1.weight": "final_layer.1.weight",
|
|
89
|
-
"fc.1.bias": "final_layer.1.bias"
|
|
85
|
+
"fc.1.bias": "final_layer.1.bias",
|
|
90
86
|
}
|
|
91
87
|
|
|
92
88
|
batch_norm = nn.BatchNorm2d if apply_batch_norm else nn.Identity
|
|
@@ -94,32 +90,44 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
|
|
|
94
90
|
self.feature_extractor = nn.Sequential(
|
|
95
91
|
nn.Conv2d(self.n_chans, n_conv_chans, (1, 7), groups=n_groups, padding=0),
|
|
96
92
|
batch_norm(n_conv_chans),
|
|
97
|
-
|
|
93
|
+
activation(),
|
|
98
94
|
nn.MaxPool2d((1, max_pool_size)),
|
|
99
|
-
nn.Conv2d(
|
|
95
|
+
nn.Conv2d(
|
|
96
|
+
n_conv_chans, n_conv_chans, (1, 7), groups=n_conv_chans, padding=0
|
|
97
|
+
),
|
|
100
98
|
batch_norm(n_conv_chans),
|
|
101
|
-
|
|
99
|
+
activation(),
|
|
102
100
|
nn.MaxPool2d((1, max_pool_size)),
|
|
103
|
-
nn.Conv2d(
|
|
101
|
+
nn.Conv2d(
|
|
102
|
+
n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0
|
|
103
|
+
),
|
|
104
104
|
batch_norm(n_conv_chans),
|
|
105
|
-
|
|
105
|
+
activation(),
|
|
106
106
|
nn.MaxPool2d((1, max_pool_size)),
|
|
107
|
-
nn.Conv2d(
|
|
107
|
+
nn.Conv2d(
|
|
108
|
+
n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0
|
|
109
|
+
),
|
|
108
110
|
batch_norm(n_conv_chans),
|
|
109
|
-
|
|
111
|
+
activation(),
|
|
110
112
|
nn.MaxPool2d((1, max_pool_size)),
|
|
111
|
-
nn.Conv2d(
|
|
113
|
+
nn.Conv2d(
|
|
114
|
+
n_conv_chans, n_conv_chans, (1, 5), groups=n_conv_chans, padding=0
|
|
115
|
+
),
|
|
112
116
|
batch_norm(n_conv_chans),
|
|
113
|
-
|
|
117
|
+
activation(),
|
|
114
118
|
nn.MaxPool2d((1, max_pool_size)),
|
|
115
|
-
nn.Conv2d(
|
|
119
|
+
nn.Conv2d(
|
|
120
|
+
n_conv_chans, n_conv_chans, (1, 3), groups=n_conv_chans, padding=0
|
|
121
|
+
),
|
|
116
122
|
batch_norm(n_conv_chans),
|
|
117
|
-
|
|
123
|
+
activation(),
|
|
118
124
|
nn.MaxPool2d((1, max_pool_size)),
|
|
119
|
-
nn.Conv2d(
|
|
125
|
+
nn.Conv2d(
|
|
126
|
+
n_conv_chans, n_conv_chans, (1, 3), groups=n_conv_chans, padding=0
|
|
127
|
+
),
|
|
120
128
|
batch_norm(n_conv_chans),
|
|
121
|
-
|
|
122
|
-
nn.MaxPool2d((1, max_pool_size))
|
|
129
|
+
activation(),
|
|
130
|
+
nn.MaxPool2d((1, max_pool_size)),
|
|
123
131
|
)
|
|
124
132
|
|
|
125
133
|
self.len_last_layer = self._len_last_layer(self.n_chans, self.n_times)
|
|
@@ -128,16 +136,17 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
|
|
|
128
136
|
# TODO: Add new way to handle return_features == True
|
|
129
137
|
if not return_feats:
|
|
130
138
|
self.final_layer = nn.Sequential(
|
|
131
|
-
nn.Dropout(
|
|
139
|
+
nn.Dropout(drop_prob),
|
|
132
140
|
nn.Linear(self.len_last_layer, self.n_outputs),
|
|
133
|
-
nn.
|
|
141
|
+
nn.Identity(),
|
|
134
142
|
)
|
|
135
143
|
|
|
136
144
|
def _len_last_layer(self, n_channels, input_size):
|
|
137
145
|
self.feature_extractor.eval()
|
|
138
146
|
with torch.no_grad():
|
|
139
147
|
out = self.feature_extractor(
|
|
140
|
-
torch.Tensor(1, n_channels, 1, input_size)
|
|
148
|
+
torch.Tensor(1, n_channels, 1, input_size)
|
|
149
|
+
) # batch_size,n_channels,height,width
|
|
141
150
|
self.feature_extractor.train()
|
|
142
151
|
return len(out.flatten())
|
|
143
152
|
|
|
@@ -2,14 +2,20 @@
|
|
|
2
2
|
#
|
|
3
3
|
# License: BSD (3-clause)
|
|
4
4
|
|
|
5
|
+
import math
|
|
6
|
+
|
|
5
7
|
import torch
|
|
6
8
|
from torch import nn
|
|
7
|
-
|
|
8
|
-
from .base import EEGModuleMixin
|
|
9
|
+
|
|
10
|
+
from braindecode.models.base import EEGModuleMixin
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
|
|
12
|
-
"""Sleep staging architecture from Chambon et al 2018.
|
|
14
|
+
"""Sleep staging architecture from Chambon et al. (2018) [Chambon2018]_.
|
|
15
|
+
|
|
16
|
+
.. figure:: https://braindecode.org/dev/_static/model/SleepStagerChambon2018.jpg
|
|
17
|
+
:align: center
|
|
18
|
+
:alt: SleepStagerChambon2018 Architecture
|
|
13
19
|
|
|
14
20
|
Convolutional neural network for sleep staging described in [Chambon2018]_.
|
|
15
21
|
|
|
@@ -26,7 +32,7 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
|
|
|
26
32
|
pad_size_s : float
|
|
27
33
|
Padding size, in seconds. Set to 0.25 in [Chambon2018]_ (half the
|
|
28
34
|
temporal convolution kernel size).
|
|
29
|
-
|
|
35
|
+
drop_prob : float
|
|
30
36
|
Dropout rate before the output dense layer.
|
|
31
37
|
apply_batch_norm : bool
|
|
32
38
|
If True, apply batch normalization after both temporal convolutional
|
|
@@ -41,6 +47,9 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
|
|
|
41
47
|
Alias for `input_window_seconds`.
|
|
42
48
|
n_classes:
|
|
43
49
|
Alias for `n_outputs`.
|
|
50
|
+
activation: nn.Module, default=nn.ReLU
|
|
51
|
+
Activation function class to apply. Should be a PyTorch activation
|
|
52
|
+
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
|
|
44
53
|
|
|
45
54
|
References
|
|
46
55
|
----------
|
|
@@ -52,30 +61,22 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
|
|
|
52
61
|
"""
|
|
53
62
|
|
|
54
63
|
def __init__(
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
input_size_s=None,
|
|
71
|
-
n_classes=None,
|
|
64
|
+
self,
|
|
65
|
+
n_chans=None,
|
|
66
|
+
sfreq=None,
|
|
67
|
+
n_conv_chs=8,
|
|
68
|
+
time_conv_size_s=0.5,
|
|
69
|
+
max_pool_size_s=0.125,
|
|
70
|
+
pad_size_s=0.25,
|
|
71
|
+
activation: nn.Module = nn.ReLU,
|
|
72
|
+
input_window_seconds=None,
|
|
73
|
+
n_outputs=5,
|
|
74
|
+
drop_prob=0.25,
|
|
75
|
+
apply_batch_norm=False,
|
|
76
|
+
return_feats=False,
|
|
77
|
+
chs_info=None,
|
|
78
|
+
n_times=None,
|
|
72
79
|
):
|
|
73
|
-
n_chans, n_outputs, input_window_seconds, = deprecated_args(
|
|
74
|
-
self,
|
|
75
|
-
("n_channels", "n_chans", n_channels, n_chans),
|
|
76
|
-
("n_classes", "n_outputs", n_classes, n_outputs),
|
|
77
|
-
("input_size_s", "input_window_seconds", input_size_s, input_window_seconds),
|
|
78
|
-
)
|
|
79
80
|
super().__init__(
|
|
80
81
|
n_outputs=n_outputs,
|
|
81
82
|
n_chans=n_chans,
|
|
@@ -85,54 +86,54 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
|
|
|
85
86
|
sfreq=sfreq,
|
|
86
87
|
)
|
|
87
88
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
88
|
-
del n_channels, n_classes, input_size_s
|
|
89
89
|
|
|
90
90
|
self.mapping = {
|
|
91
91
|
"fc.1.weight": "final_layer.1.weight",
|
|
92
|
-
"fc.1.bias": "final_layer.1.bias"
|
|
92
|
+
"fc.1.bias": "final_layer.1.bias",
|
|
93
93
|
}
|
|
94
94
|
|
|
95
|
-
time_conv_size =
|
|
96
|
-
max_pool_size =
|
|
97
|
-
pad_size =
|
|
95
|
+
time_conv_size = math.ceil(time_conv_size_s * self.sfreq)
|
|
96
|
+
max_pool_size = math.ceil(max_pool_size_s * self.sfreq)
|
|
97
|
+
pad_size = math.ceil(pad_size_s * self.sfreq)
|
|
98
98
|
|
|
99
99
|
if self.n_chans > 1:
|
|
100
100
|
self.spatial_conv = nn.Conv2d(1, self.n_chans, (self.n_chans, 1))
|
|
101
|
+
else:
|
|
102
|
+
self.spatial_conv = nn.Identity()
|
|
101
103
|
|
|
102
104
|
batch_norm = nn.BatchNorm2d if apply_batch_norm else nn.Identity
|
|
103
105
|
|
|
104
106
|
self.feature_extractor = nn.Sequential(
|
|
105
|
-
nn.Conv2d(
|
|
106
|
-
1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)),
|
|
107
|
+
nn.Conv2d(1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)),
|
|
107
108
|
batch_norm(n_conv_chs),
|
|
108
|
-
|
|
109
|
+
activation(),
|
|
109
110
|
nn.MaxPool2d((1, max_pool_size)),
|
|
110
111
|
nn.Conv2d(
|
|
111
|
-
n_conv_chs, n_conv_chs, (1, time_conv_size),
|
|
112
|
-
|
|
112
|
+
n_conv_chs, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)
|
|
113
|
+
),
|
|
113
114
|
batch_norm(n_conv_chs),
|
|
114
|
-
|
|
115
|
-
nn.MaxPool2d((1, max_pool_size))
|
|
115
|
+
activation(),
|
|
116
|
+
nn.MaxPool2d((1, max_pool_size)),
|
|
116
117
|
)
|
|
117
|
-
self.len_last_layer = self._len_last_layer(self.n_chans, self.n_times)
|
|
118
118
|
self.return_feats = return_feats
|
|
119
119
|
|
|
120
|
+
dim_conv_1 = (
|
|
121
|
+
self.n_times + 2 * pad_size - (time_conv_size - 1)
|
|
122
|
+
) // max_pool_size
|
|
123
|
+
dim_after_conv = (
|
|
124
|
+
dim_conv_1 + 2 * pad_size - (time_conv_size - 1)
|
|
125
|
+
) // max_pool_size
|
|
126
|
+
|
|
127
|
+
self.len_last_layer = n_conv_chs * self.n_chans * dim_after_conv
|
|
128
|
+
|
|
120
129
|
# TODO: Add new way to handle return_features == True
|
|
121
130
|
if not return_feats:
|
|
122
131
|
self.final_layer = nn.Sequential(
|
|
123
|
-
nn.Dropout(
|
|
124
|
-
nn.Linear(self.len_last_layer, self.n_outputs),
|
|
132
|
+
nn.Dropout(p=drop_prob),
|
|
133
|
+
nn.Linear(in_features=self.len_last_layer, out_features=self.n_outputs),
|
|
125
134
|
)
|
|
126
135
|
|
|
127
|
-
def
|
|
128
|
-
self.feature_extractor.eval()
|
|
129
|
-
with torch.no_grad():
|
|
130
|
-
out = self.feature_extractor(
|
|
131
|
-
torch.Tensor(1, 1, n_channels, input_size))
|
|
132
|
-
self.feature_extractor.train()
|
|
133
|
-
return len(out.flatten())
|
|
134
|
-
|
|
135
|
-
def forward(self, x):
|
|
136
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
136
137
|
"""
|
|
137
138
|
Forward pass.
|
|
138
139
|
|
|
@@ -152,5 +153,5 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
|
|
|
152
153
|
|
|
153
154
|
if self.return_feats:
|
|
154
155
|
return feats
|
|
155
|
-
|
|
156
|
-
|
|
156
|
+
|
|
157
|
+
return self.final_layer(feats)
|