braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
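
The four removed braindecode/datautil/* modules near the end of the list (mne.py, preprocess.py, windowers.py, xy.py) were small deprecation shims; their functionality lives under braindecode.preprocessing and braindecode.datasets in 1.1.0. A minimal import-migration sketch, assuming the 1.1.0 layout shown above (the symbols are taken from braindecode's public API, but treat the exact names as illustrative):

    # Hypothetical 0.8.x-era imports that relied on the removed datautil shims:
    # from braindecode.datautil import preprocess, create_windows_from_events

    # 1.1.0: import from the dedicated subpackages instead
    from braindecode.preprocessing import (
        Preprocessor,                # wraps a raw/epochs transform
        preprocess,                  # applies a list of Preprocessors to a dataset
        create_windows_from_events,  # windowing, formerly in datautil.windowers
    )
    from braindecode.datasets import create_from_X_y  # formerly datautil.xy

The largest additions are new model files; the diff below is the new braindecode/models/contrawr.py.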
braindecode/models/contrawr.py (added)
@@ -0,0 +1,317 @@
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from mne.utils import warn
+
+from braindecode.models.base import EEGModuleMixin
+
+
+class ContraWR(EEGModuleMixin, nn.Module):
+    """Contrast with the World Representation (ContraWR) from Yang et al. (2021) [Yang2021]_.
+
+    This model is a convolutional neural network that uses a spectral
+    representation with a series of convolutional layers and residual blocks.
+    The model is designed to learn a representation of the EEG signal that can
+    be used for sleep staging.
+
+    Parameters
+    ----------
+    steps : int, optional
+        Number of steps used to set the frequency-decomposition hop length,
+        i.e. ``hop_length = n_fft // steps``; by default 20.
+    emb_size : int, optional
+        Embedding size for the final layer, by default 256.
+    res_channels : list[int], optional
+        Number of channels for each residual block, by default [32, 64, 128].
+    activation : nn.Module, default=nn.ELU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+    drop_prob : float, default=0.5
+        The dropout rate for regularization. Values should be between 0 and 1.
+
+    .. versionadded:: 0.9
+
+    Notes
+    -----
+    This implementation is not guaranteed to be correct and has not been checked
+    by the original authors; the modifications are minimal and the model is
+    expected to work as intended. It is based on the original code from [Code2023]_.
+
+    References
+    ----------
+    .. [Yang2021] Yang, C., Xiao, C., Westover, M. B., & Sun, J. (2023).
+       Self-supervised electroencephalogram representation learning for automatic
+       sleep staging: model development and evaluation study. JMIR AI, 2(1), e46769.
+    .. [Code2023] Yang, C., Westover, M. B., & Sun, J. (2023). BIOT:
+       Biosignal Transformer for Cross-data Learning in the Wild.
+       GitHub, https://github.com/ycq091044/BIOT (accessed 2024-02-13).
+    """
+
+    def __init__(
+        self,
+        n_chans=None,
+        n_outputs=None,
+        sfreq=None,
+        emb_size: int = 256,
+        res_channels: list[int] = [32, 64, 128],
+        steps=20,
+        activation: nn.Module = nn.ELU,
+        drop_prob: float = 0.5,
+        stride_res: int = 2,
+        kernel_size_res: int = 3,
+        padding_res: int = 1,
+        # Another way to pass the EEG parameters
+        chs_info=None,
+        n_times=None,
+        input_window_seconds=None,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
+        if not isinstance(res_channels, list):
+            raise ValueError("res_channels must be a list of integers.")
+
+        if self.input_window_seconds < 1.0:
+            warning_msg = (
+                "The input window is less than 1 second, which may not be "
+                "sufficient for the model to learn meaningful representations; "
+                "falling back to setting `n_fft` to `n_times`."
+            )
+            warn(warning_msg, UserWarning)
+            self.n_fft = self.n_times
+        else:
+            self.n_fft = int(self.sfreq)
+
+        self.steps = steps
+
+        res_channels = [self.n_chans] + res_channels + [emb_size]
+
+        self.torch_stft = _STFTModule(
+            n_fft=self.n_fft,
+            hop_length=int(self.n_fft // self.steps),
+        )
+
+        self.convs = nn.ModuleList(
+            [
+                _ResBlock(
+                    in_channels=res_channels[i],
+                    out_channels=res_channels[i + 1],
+                    stride=stride_res,
+                    use_downsampling=True,
+                    pooling=True,
+                    drop_prob=drop_prob,
+                    kernel_size=kernel_size_res,
+                    padding=padding_res,
+                    activation=activation,
+                )
+                for i in range(len(res_channels) - 1)
+            ]
+        )
+        self.adaptative_pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.flatten_layer = nn.Flatten()
+
+        self.activation_layer = activation()
+        self.final_layer = nn.Linear(emb_size, self.n_outputs)
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Parameters
+        ----------
+        X : Tensor
+            Input tensor of shape (batch_size, n_channels, n_times).
+        Returns
+        -------
+        Tensor
+            Output tensor of shape (batch_size, n_outputs).
+        """
+        X = self.torch_stft(X)
+
+        for conv in self.convs:
+            X = conv(X)
+
+        emb = self.adaptative_pool(X)
+        emb = self.flatten_layer(emb)
+        emb = self.activation_layer(emb)
+
+        return self.final_layer(emb)
+
+
+class _ResBlock(nn.Module):
+    """Convolutional Residual Block 2D.
+
+    This block stacks two convolutional layers with batch normalization,
+    max pooling, dropout, and a residual connection.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    stride : int (default=1)
+        Stride of the convolutional layers.
+    use_downsampling : bool (default=True)
+        Whether to use a downsampling residual connection.
+    pooling : bool (default=True)
+        Whether to use max pooling.
+    kernel_size : int (default=3)
+        Kernel size of the convolutional layers.
+    padding : int (default=1)
+        Padding of the convolutional layers.
+    activation : nn.Module, default=nn.ReLU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+    drop_prob : float, default=0.5
+        The dropout rate for regularization. Values should be between 0 and 1.
+
+    Examples
+    --------
+    >>> import torch
+    >>> model = _ResBlock(6, 16, 2, True, False)
+    >>> input_ = torch.randn((16, 6, 28, 150))  # (batch, channel, height, width)
+    >>> output = model(input_)
+    >>> output.shape
+    torch.Size([16, 16, 14, 75])
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        stride=1,
+        use_downsampling=True,
+        pooling=True,
+        kernel_size=3,
+        padding=1,
+        drop_prob=0.5,
+        activation: nn.Module = nn.ReLU,
+    ):
+        super().__init__()
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+        )
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.relu = activation()
+        self.conv2 = nn.Conv2d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+        )
+        self.bn2 = nn.BatchNorm2d(out_channels)
+        self.maxpool = nn.MaxPool2d(
+            kernel_size=kernel_size, stride=stride, padding=padding
+        )
+        self.downsample = nn.Sequential(
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+            ),
+            nn.BatchNorm2d(out_channels),
+        )
+        self.use_downsampling = use_downsampling
+        self.pooling = pooling
+        self.dropout = nn.Dropout(drop_prob)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor of shape (batch_size, in_channels, n_freqs, n_times).
+
+        Returns
+        -------
+        Tensor
+            Output tensor of shape (batch_size, out_channels, n_freqs_out, n_times_out).
+        """
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        if self.use_downsampling:
+            residual = self.downsample(x)
+            out += residual
+        if self.pooling:
+            out = self.maxpool(out)
+        out = self.dropout(out)
+        return out
+
+
+class _STFTModule(nn.Module):
+    """
+    A PyTorch module that computes the Short-Time Fourier Transform (STFT)
+    of an EEG batch tensor.
+
+    Expects input of shape (batch_size, n_channels, n_times) and returns the
+    magnitude spectrogram of shape (batch_size, n_channels, n_freqs, n_frames).
+    """
+
+    def __init__(
+        self,
+        n_fft: int,
+        hop_length: int,
+        center: bool = True,
+        onesided: bool = True,
+        return_complex: bool = True,
+        normalized: bool = True,
+    ):
+        """
+        Parameters
+        ----------
+        n_fft : int
+            Number of FFT points (window size).
+        hop_length : int
+            Hop length in samples between successive windows (``n_fft // steps``).
+        """
+        super().__init__()
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.center = center
+        self.one_sided = onesided
+        self.return_complex = return_complex
+        self.normalized = normalized
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        window = torch.ones(self.n_fft, device=x.device)
+
+        # x: (B, C, T)
+        B, C, T = x.shape
+        # flatten batch & channel into one dim
+        x_flat = x.reshape(B * C, T)
+
+        # compute the STFT on the flattened 2D tensor
+        spec_flat = torch.stft(
+            x_flat,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.n_fft,
+            window=window,
+            normalized=self.normalized,
+            center=self.center,
+            onesided=self.one_sided,
+            return_complex=self.return_complex,
+        )
+
+        F, L = spec_flat.shape[-2], spec_flat.shape[-1]
+        spec = spec_flat.view(B, C, F, L)
+
+        return torch.abs(spec)
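
For orientation, a minimal usage sketch of the new model added above. The dimensions are hypothetical, and ContraWR is assumed to be importable from braindecode.models (consistent with the models/__init__.py changes listed earlier):

    import torch
    from braindecode.models import ContraWR

    # Hypothetical data: 8 windows of 22-channel EEG, 4 s at 100 Hz (400 samples)
    model = ContraWR(n_chans=22, n_outputs=5, sfreq=100, n_times=400)

    X = torch.randn(8, 22, 400)  # (batch, channels, time)
    out = model(X)               # STFT -> residual blocks -> pooling -> linear head
    print(out.shape)             # torch.Size([8, 5])

With these settings the window is at least 1 second, so n_fft = sfreq = 100 and hop_length = n_fft // steps = 5; the centered STFT then yields n_fft // 2 + 1 = 51 frequency bins and 400 // 5 + 1 = 81 frames, which the residual blocks progressively downsample before the adaptive pooling and final linear layer.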