braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from einops.layers.torch import Rearrange
|
|
7
|
+
|
|
8
|
+
from braindecode.models.base import EEGModuleMixin
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SincShallowNet(EEGModuleMixin, nn.Module):
    r"""Sinc-ShallowNet from Borra, D et al (2020) [borra2020]_.

    :bdg-success:`Convolution` :bdg-warning:`Interpretability`

    .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S0893608020302021-gr2_lrg.jpg
        :align: center
        :alt: SincShallowNet Architecture

    The Sinc-ShallowNet architecture has these fundamental blocks:

    1. **Block 1: Spectral and Spatial Feature Extraction**

        - *Temporal Sinc-Convolutional Layer*: Uses parametrized sinc functions to learn band-pass filters,
          significantly reducing the number of trainable parameters by only
          learning the lower cutoff frequency and the bandwidth of each filter.
        - *Spatial Depthwise Convolutional Layer*: Applies depthwise convolutions to learn spatial filters for
          each temporal feature map independently, further reducing
          parameters and enhancing interpretability.
        - *Batch Normalization*

    2. **Block 2: Temporal Aggregation**

        - *Activation Function*: ELU by default
        - *Average Pooling Layer*: Aggregation by averaging along the temporal dimension
        - *Dropout Layer*
        - *Flatten Layer*

    3. **Block 3: Classification**

        - *Fully Connected Layer*: Maps the feature vector to n_outputs.

    **Implementation Notes:**

    - The learnable low cutoffs are initialized on a regular grid
      (``first_freq + k * freq_stride`` for ``k = 0 .. num_time_filters - 1``)
      and every filter starts with the same ``bandwidth``. When the kernels are
      built, ``min_freq`` is added to the absolute value of the learned low
      cutoff, so the effective initial low cutoff of filter ``k`` is
      ``min_freq + first_freq + k * freq_stride``. The high cutoff is clamped to
      ``[min_freq, sfreq / 2]``.

    Parameters
    ----------
    num_time_filters : int
        Number of temporal filters in the SincFilter layer.
    time_filter_len : int
        Size of the temporal filters. Must be odd.
    depth_multiplier : int
        Depth multiplier for spatial filtering.
    activation : type[nn.Module] or None, optional
        Activation class, instantiated once inside the model. ``None`` falls
        back to ``nn.ELU``. Default is ``nn.ELU``.
    drop_prob : float, optional
        Dropout probability. Default is 0.5.
    first_freq : float, optional
        Initial (learnable) low-cutoff value of the first Sinc filter; see the
        implementation notes for how ``min_freq`` is added to it. Default is 5.0.
    min_freq : float, optional
        Minimum frequency allowed for the low frequencies of the filters. Default is 1.0.
    freq_stride : float, optional
        Frequency stride for the Sinc filters. Controls the spacing between the
        initial low cutoffs of consecutive filters. Default is 1.0.
    padding : str, optional
        Padding mode for convolution, either 'same' or 'valid' (case-insensitive).
        Default is 'same'.
    bandwidth : float, optional
        Initial bandwidth for each Sinc filter. Default is 4.0.
    pool_size : int, optional
        Size of the pooling window for the average pooling layer. Default is 55.
    pool_stride : int, optional
        Stride of the pooling operation. Default is 12.

    Raises
    ------
    ValueError
        If ``padding`` is neither 'same' nor 'valid'.

    Notes
    -----
    This implementation is based on the implementation from [sincshallowcode]_.

    References
    ----------
    .. [borra2020] Borra, D., Fantozzi, S., & Magosso, E. (2020). Interpretable
       and lightweight convolutional neural network for EEG decoding: Application
       to movement execution and imagination. Neural Networks, 129, 55-74.
    .. [sincshallowcode] Sinc-ShallowNet re-implementation source code:
       https://github.com/marcellosicbaldi/SincNet-Tensorflow
    """

    def __init__(
        self,
        num_time_filters: int = 32,
        time_filter_len: int = 33,
        depth_multiplier: int = 2,
        activation: type[nn.Module] | None = nn.ELU,
        drop_prob: float = 0.5,
        first_freq: float = 5.0,
        min_freq: float = 1.0,
        freq_stride: float = 1.0,
        padding: str = "same",
        bandwidth: float = 4.0,
        pool_size: int = 55,
        pool_stride: int = 12,
        # braindecode parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        chs_info=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Checkers and creating variables.
        # ``activation`` is instantiated below via ``activation()``, so the
        # fallback must be the class itself; an instance (``nn.ELU()``) would be
        # *called* with no input and fail.
        if activation is None:
            activation = nn.ELU

        # Initial low cutoffs: exactly ``num_time_filters`` evenly spaced values.
        # Built from an integer range so that a fractional ``freq_stride`` can
        # never produce an extra element through floating-point rounding, which
        # would desynchronise the layer sizes computed below.
        low_freqs = first_freq + freq_stride * torch.arange(
            num_time_filters, dtype=torch.float32
        )
        self.n_filters = len(low_freqs)

        padding_mode = padding.lower()
        if padding_mode == "valid":
            n_times_after_sinc_filter = self.n_times - time_filter_len + 1
        elif padding_mode == "same":
            n_times_after_sinc_filter = self.n_times
        else:
            raise ValueError("Padding must be 'valid' or 'same'.")

        # Output length of AvgPool2d (no padding, floor mode) along time.
        size_after_pooling = (
            (n_times_after_sinc_filter - pool_size) // pool_stride
        ) + 1
        flattened_size = self.n_filters * depth_multiplier * size_after_pooling

        # PyTorch's BatchNorm ``momentum`` is the weight given to the *new*
        # batch statistic, i.e. ``1 - momentum`` in Keras terms. The Keras
        # default of 0.99 therefore corresponds to 0.01 here.
        bn_momentum = 0.01

        # Layers
        self.ensuredims = Rearrange("batch chans times -> batch chans times 1")

        # Block 1: Sinc filter
        self.sinc_filter_layer = _SincFilter(
            low_freqs=low_freqs,
            kernel_size=time_filter_len,
            sfreq=self.sfreq,
            padding=padding,
            bandwidth=bandwidth,
            min_freq=min_freq,
        )

        self.depthwiseconv = nn.Sequential(
            # Matching dim to depth wise conv: sinc filters become the
            # convolution channels and EEG channels become the "height" axis.
            Rearrange("batch timefil time nfilter -> batch nfilter timefil time"),
            nn.BatchNorm2d(self.n_filters, momentum=bn_momentum),
            nn.Conv2d(
                in_channels=self.n_filters,
                out_channels=depth_multiplier * self.n_filters,
                kernel_size=(self.n_chans, 1),
                groups=self.n_filters,
                bias=False,
            ),
        )

        # Block 2: Batch norm, activation, pooling, dropout
        self.temporal_aggregation = nn.Sequential(
            nn.BatchNorm2d(depth_multiplier * self.n_filters, momentum=bn_momentum),
            activation(),
            nn.AvgPool2d(kernel_size=(1, pool_size), stride=(1, pool_stride)),
            nn.Dropout(p=drop_prob),
            nn.Flatten(),
        )

        # Final classification layer
        self.final_layer = nn.Linear(
            flattened_size,
            self.n_outputs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape [batch_size, num_channels, num_samples].

        Returns
        -------
        torch.Tensor
            Output logits of shape [batch_size, n_outputs].
        """
        x = self.ensuredims(x)
        x = self.sinc_filter_layer(x)
        x = self.depthwiseconv(x)
        x = self.temporal_aggregation(x)

        return self.final_layer(x)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class _SincFilter(nn.Module):
    r"""Band-pass convolution with parametrised sinc kernels, after Ravanelli and Bengio (2018) [ravanelli]_.

    Every kernel is a Hamming-windowed difference of two sinc functions, so
    each filter is a band-pass whose only trainable quantities are a low
    cutoff and a bandwidth. This keeps the number of trainable parameters
    small.

    Parameters
    ----------
    low_freqs : torch.Tensor
        Initial low cutoff frequencies, one per filter.
    kernel_size : int
        Length of each convolution kernel (in samples). Must be odd.
    sfreq : float
        Sampling rate of the input signal.
    bandwidth : float, optional
        Initial bandwidth shared by all filters. Default is 4.0.
    min_freq : float, optional
        Offset added to the absolute learned low cutoff; also the lower clamp
        of the high cutoff. Default is 1.0.
    padding : str, optional
        Padding mode handed to ``F.conv2d``, either 'same' or 'valid'
        (lower-cased). Default is 'same'.

    Raises
    ------
    ValueError
        If ``kernel_size`` is even.

    References
    ----------
    .. [ravanelli] Ravanelli, M., & Bengio, Y. (2018, December). Speaker
       recognition from raw waveform with sincnet. In 2018 IEEE spoken language
       technology workshop (SLT) (pp. 1021-1028). IEEE.
    """

    def __init__(
        self,
        low_freqs: torch.Tensor,
        kernel_size: int,
        sfreq: float,
        bandwidth: float = 4.0,
        min_freq: float = 1.0,
        padding: str = "same",
    ):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError("Kernel size must be odd.")

        half_len = kernel_size // 2
        self.num_filters = low_freqs.numel()
        self.kernel_size = kernel_size
        self.sfreq = sfreq
        self.min_freq = min_freq
        self.padding = padding.lower()

        # Left half of a symmetric Hamming window, stored as a column [half_len, 1].
        hamming = torch.hamming_window(kernel_size, periodic=False)
        self.register_buffer("window", hamming[:half_len].unsqueeze(-1))

        # Angular time offsets 2*pi*n/sfreq for n = -half_len .. -1, as a column.
        sample_offsets = torch.arange(-half_len, 0, dtype=torch.float32)
        angular = sample_offsets / sfreq * 2 * math.pi
        self.register_buffer("n_pi", angular.unsqueeze(-1))

        # Trainable band description, each shaped [1, num_filters].
        self.bandwidths = nn.Parameter(torch.full((1, self.num_filters), bandwidth))
        self.low_freqs = nn.Parameter(low_freqs.unsqueeze(0))

        # Centre tap of every kernel (value 1 after normalisation by 2 * band).
        self.register_buffer("ones", torch.ones(1, 1, 1, self.num_filters))

    def build_sinc_filters(self) -> torch.Tensor:
        """Assemble the band-pass kernels from the current parameters.

        Returns
        -------
        torch.Tensor
            Kernels of shape [1, kernel_size, 1, num_filters], divided by their
            global standard deviation.
        """
        lows = self.min_freq + self.low_freqs.abs()
        highs = (lows + self.bandwidths.abs()).clamp(
            min=self.min_freq, max=self.sfreq / 2.0
        )
        band = highs - lows

        # Left half of each kernel: [half_len, num_filters].
        diff_of_sines = torch.sin(self.n_pi * highs) - torch.sin(self.n_pi * lows)
        half_kernel = diff_of_sines / (self.n_pi / 2.0)
        half_kernel = half_kernel * self.window
        half_kernel = half_kernel / (2.0 * band)

        # Lift to [1, half_len, 1, num_filters] and mirror for the right half.
        half_kernel = half_kernel[None, :, None, :]
        mirrored = half_kernel.flip(1)

        kernels = torch.cat((half_kernel, self.ones, mirrored), dim=1)
        return kernels / kernels.std()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Filter every input channel with every sinc kernel.

        Parameters
        ----------
        inputs : torch.Tensor
            Tensor of shape [batch_size, num_channels, num_samples, 1].

        Returns
        -------
        torch.Tensor
            Tensor of shape [batch_size, num_channels, num_samples_out,
            num_filters], where ``num_samples_out`` equals ``num_samples`` for
            'same' padding and ``num_samples - kernel_size + 1`` for 'valid'.
        """
        kernels = self.build_sinc_filters().to(inputs.device)

        # [B, C, T, 1] -> [B, 1, C, T]: a single input map for conv2d.
        signal = inputs.permute(0, 3, 1, 2)
        # [1, K, 1, F] -> [F, 1, 1, K]: conv2d weight layout (out, in, h, w).
        weight = kernels.permute(3, 2, 0, 1)

        filtered = F.conv2d(signal, weight, padding=self.padding)
        # [B, F, C, T'] -> [B, C, T', F]
        return filtered.permute(0, 2, 3, 1)
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
# Authors: Divyesh Narayanan <divyesh.narayanan@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
from braindecode.models.base import EEGModuleMixin
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
    r"""Sleep staging architecture from Blanco et al (2020) from [Blanco2020]_

    :bdg-success:`Convolution`

    .. figure:: https://media.springernature.com/full/springer-static/image/art%3A10.1007%2Fs00500-019-04174-1/MediaObjects/500_2019_4174_Fig2_HTML.png
        :align: center
        :alt: SleepStagerBlanco2020 Architecture

    Convolutional neural network for sleep staging described in [Blanco2020]_.
    A series of seven convolutional layers with kernel sizes running down from 7 to 3,
    in an attempt to extract more general features at the beginning, while more specific
    and complex features were extracted in the final stages.

    Parameters
    ----------
    n_conv_chans : int
        Number of convolutional channels. Set to 20 in [Blanco2020]_.
    n_groups : int
        Number of groups for the first convolution. Set to 2 in [Blanco2020]_
        for 2 Channel EEG. It controls the connections between inputs and
        outputs: ``n_chans`` and ``n_conv_chans`` must be divisible by
        ``n_groups``. All later convolutions are depthwise
        (``groups=n_conv_chans``).
    max_pool_size : int
        Width of the temporal max-pooling window applied after every
        convolutional layer (the pooling stride equals the window width).
    drop_prob : float
        Dropout rate before the output dense layer.
    apply_batch_norm : bool
        If True, apply batch normalization after each convolutional layer.
    return_feats : bool
        If True, return the features, i.e. the output of the feature extractor
        (before the final linear layer). If False, pass the features through
        the final linear layer.
    activation : type[nn.Module], default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.

    References
    ----------
    .. [Blanco2020] Fernandez-Blanco, E., Rivero, D. & Pazos, A. Convolutional
       neural networks for sleep stage scoring on a two-channel EEG signal.
       Soft Comput 24, 4067–4079 (2020). https://doi.org/10.1007/s00500-019-04174-1
    """

    def __init__(
        self,
        n_chans=None,
        sfreq=None,
        n_conv_chans=20,
        input_window_seconds=None,
        n_outputs=5,
        n_groups=2,
        max_pool_size=2,
        drop_prob=0.5,
        apply_batch_norm=False,
        return_feats=False,
        activation: type[nn.Module] = nn.ReLU,
        chs_info=None,
        n_times=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Maps legacy state_dict keys to the current module names.
        self.mapping = {
            "fc.1.weight": "final_layer.1.weight",
            "fc.1.bias": "final_layer.1.bias",
        }

        batch_norm = nn.BatchNorm2d if apply_batch_norm else nn.Identity

        # Temporal kernel width of each of the seven conv layers. Each layer is
        # followed by (batch norm | identity), activation and max pooling, so
        # the Sequential indices (and hence state_dict keys) are the same as a
        # hand-written list of 28 modules. Only the first conv mixes input
        # channels (``groups=n_groups``); later convs are depthwise.
        kernel_sizes = (7, 7, 5, 5, 5, 3, 3)
        layers = []
        for idx, kernel_size in enumerate(kernel_sizes):
            is_first = idx == 0
            layers += [
                nn.Conv2d(
                    self.n_chans if is_first else n_conv_chans,
                    n_conv_chans,
                    (1, kernel_size),
                    groups=n_groups if is_first else n_conv_chans,
                    padding=0,
                ),
                batch_norm(n_conv_chans),
                activation(),
                nn.MaxPool2d((1, max_pool_size)),
            ]
        self.feature_extractor = nn.Sequential(*layers)

        self.len_last_layer = self._len_last_layer(self.n_chans, self.n_times)
        self.return_feats = return_feats

        # TODO: Add new way to handle return_features == True
        if not return_feats:
            self.final_layer = nn.Sequential(
                nn.Dropout(drop_prob),
                nn.Linear(self.len_last_layer, self.n_outputs),
                nn.Identity(),
            )

    def _len_last_layer(self, n_channels, input_size):
        """Return the number of features the feature extractor yields per window.

        Runs a single dummy window of shape ``(1, n_channels, 1, input_size)``
        through ``feature_extractor`` in eval mode without gradients, then
        restores the extractor's previous training mode.
        """
        was_training = self.feature_extractor.training
        self.feature_extractor.eval()
        try:
            with torch.no_grad():
                # Only the output shape matters; zeros avoid reading the
                # uninitialised memory that ``torch.Tensor(*shape)`` returns.
                out = self.feature_extractor(
                    torch.zeros(1, n_channels, 1, input_size)
                )  # batch_size, n_channels, height, width
        finally:
            self.feature_extractor.train(was_training)
        return out.numel()

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Flattened features of shape (batch_size, len_last_layer) if
            ``return_feats`` is True, otherwise outputs of shape
            (batch_size, n_outputs).
        """
        if x.ndim == 3:
            # (B, C, T) -> (B, C, 1, T): channels act as conv input channels.
            x = x.unsqueeze(2)

        feats = self.feature_extractor(x).flatten(start_dim=1)
        if self.return_feats:
            return feats
        else:
            return self.final_layer(feats)
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import nn
|
|
9
|
+
|
|
10
|
+
from braindecode.models.base import EEGModuleMixin
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
    r"""Sleep staging architecture from Chambon et al. (2018) [Chambon2018]_.

    :bdg-success:`Convolution`

    .. figure:: https://braindecode.org/dev/_static/model/SleepStagerChambon2018.jpg
        :align: center
        :alt: SleepStagerChambon2018 Architecture

    Convolutional neural network for sleep staging described in [Chambon2018]_.
    For multichannel input, a spatial convolution first mixes the channels.
    Two temporal convolution blocks follow, and an optional dense layer
    produces the output.

    Parameters
    ----------
    n_conv_chs : int
        Number of convolutional channels. Set to 8 in [Chambon2018]_.
    time_conv_size_s : float
        Size of filters in temporal convolution layers, in seconds. Set to 0.5
        in [Chambon2018]_ (64 samples at sfreq=128). Rounded up to samples.
    max_pool_size_s : float
        Max pooling size, in seconds. Set to 0.125 in [Chambon2018]_ (16
        samples at sfreq=128). Rounded up to samples.
    pad_size_s : float
        Padding size, in seconds. Set to 0.25 in [Chambon2018]_ (half the
        temporal convolution kernel size). Rounded up to samples.
    activation : type[nn.Module], default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
    drop_prob : float
        Dropout rate before the output dense layer.
    apply_batch_norm : bool
        If True, apply batch normalization after both temporal convolutional
        layers.
    return_feats : bool
        If True, return the features, i.e. the output of the feature extractor
        (before the final linear layer). If False, pass the features through
        the final linear layer.

    References
    ----------
    .. [Chambon2018] Chambon, S., Galtier, M. N., Arnal, P. J., Wainrib, G., &
       Gramfort, A. (2018). A deep learning architecture for temporal sleep
       stage classification using multivariate and multimodal time series.
       IEEE Transactions on Neural Systems and Rehabilitation Engineering,
       26(4), 758-769.
    """

    def __init__(
        self,
        n_chans=None,
        sfreq=None,
        n_conv_chs=8,
        time_conv_size_s=0.5,
        max_pool_size_s=0.125,
        pad_size_s=0.25,
        activation: type[nn.Module] = nn.ReLU,
        input_window_seconds=None,
        n_outputs=5,
        drop_prob=0.25,
        apply_batch_norm=False,
        return_feats=False,
        chs_info=None,
        n_times=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Legacy state_dict key -> current module name.
        self.mapping = {
            "fc.1.weight": "final_layer.1.weight",
            "fc.1.bias": "final_layer.1.bias",
        }

        # Durations in seconds -> sample counts, rounded up.
        kernel_len, pool_len, pad_len = (
            math.ceil(seconds * self.sfreq)
            for seconds in (time_conv_size_s, max_pool_size_s, pad_size_s)
        )

        # Channel mixing is only meaningful with more than one input channel.
        self.spatial_conv = (
            nn.Conv2d(1, self.n_chans, (self.n_chans, 1))
            if self.n_chans > 1
            else nn.Identity()
        )

        norm_layer = nn.BatchNorm2d if apply_batch_norm else nn.Identity

        def temporal_block(in_chs):
            """Return one padded temporal conv -> norm -> activation -> pool stage."""
            return [
                nn.Conv2d(in_chs, n_conv_chs, (1, kernel_len), padding=(0, pad_len)),
                norm_layer(n_conv_chs),
                activation(),
                nn.MaxPool2d((1, pool_len)),
            ]

        self.feature_extractor = nn.Sequential(
            *temporal_block(1),
            *temporal_block(n_conv_chs),
        )
        self.return_feats = return_feats

        # Per stage: the padded conv yields ``len + 2 * pad - (kernel - 1)``
        # samples, then non-overlapping max pooling floors it by ``pool_len``.
        n_steps = self.n_times
        for _ in range(2):
            n_steps = (n_steps + 2 * pad_len - (kernel_len - 1)) // pool_len

        self.len_last_layer = n_conv_chs * self.n_chans * n_steps

        # TODO: Add new way to handle return_features == True
        if not return_feats:
            self.final_layer = nn.Sequential(
                nn.Dropout(p=drop_prob),
                nn.Linear(in_features=self.len_last_layer, out_features=self.n_outputs),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Flattened features if ``return_feats`` is True, otherwise outputs
            of shape (batch_size, n_outputs).
        """
        if x.ndim == 3:
            # (B, C, T) -> (B, 1, C, T)
            x = x.unsqueeze(1)

        if self.n_chans > 1:
            # The spatial conv puts its virtual channels on dim 1; swap them
            # back onto the "height" axis so the temporal convs see one map.
            x = self.spatial_conv(x).transpose(1, 2)

        feats = self.feature_extractor(x).flatten(start_dim=1)
        return feats if self.return_feats else self.final_layer(feats)
|