braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,17 +1,18 @@
|
|
|
1
1
|
# Authors: Cédric Rommel <cedric.rommel@inria.fr>
|
|
2
2
|
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
3
|
+
# Gustavo Rodrigues <gustavenrique01@gmail.com>
|
|
3
4
|
#
|
|
4
5
|
# License: BSD (3-clause)
|
|
5
6
|
|
|
6
7
|
from numbers import Real
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
from mne.filter import notch_filter
|
|
9
12
|
from scipy.interpolate import Rbf
|
|
10
13
|
from sklearn.utils import check_random_state
|
|
11
|
-
import torch
|
|
12
14
|
from torch.fft import fft, ifft
|
|
13
|
-
from torch.nn.functional import
|
|
14
|
-
from mne.filter import notch_filter
|
|
15
|
+
from torch.nn.functional import one_hot, pad
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
def identity(X, y):
|
|
@@ -79,11 +80,14 @@ def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
|
|
|
79
80
|
random_phase = torch.from_numpy(
|
|
80
81
|
2j * np.pi * rng.random((batch_size, c, (n - 1) // 2))
|
|
81
82
|
).to(device)
|
|
82
|
-
return torch.cat(
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
83
|
+
return torch.cat(
|
|
84
|
+
[
|
|
85
|
+
torch.zeros((batch_size, c, 1), device=device),
|
|
86
|
+
random_phase,
|
|
87
|
+
-torch.flip(random_phase, [-1]),
|
|
88
|
+
],
|
|
89
|
+
dim=-1,
|
|
90
|
+
)
|
|
87
91
|
|
|
88
92
|
|
|
89
93
|
def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
|
|
@@ -91,27 +95,21 @@ def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
|
|
|
91
95
|
random_phase = torch.from_numpy(
|
|
92
96
|
2j * np.pi * rng.random((batch_size, c, n // 2 - 1))
|
|
93
97
|
).to(device)
|
|
94
|
-
return torch.cat(
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
}
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def ft_surrogate(
|
|
109
|
-
X,
|
|
110
|
-
y,
|
|
111
|
-
phase_noise_magnitude,
|
|
112
|
-
channel_indep,
|
|
113
|
-
random_state=None
|
|
114
|
-
):
|
|
98
|
+
return torch.cat(
|
|
99
|
+
[
|
|
100
|
+
torch.zeros((batch_size, c, 1), device=device),
|
|
101
|
+
random_phase,
|
|
102
|
+
torch.zeros((batch_size, c, 1), device=device),
|
|
103
|
+
-torch.flip(random_phase, [-1]),
|
|
104
|
+
],
|
|
105
|
+
dim=-1,
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
_new_random_fft_phase = {
    # Keyed by the parity of the signal length n (0: even, 1: odd). Each
    # helper returns n phase values laid out for its parity; presumably the
    # caller looks this up with ``n % 2`` -- confirm at the call site.
    0: _new_random_fft_phase_even,
    1: _new_random_fft_phase_odd,
}
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
|
|
115
113
|
"""FT surrogate augmentation of a single EEG channel, as proposed in [1]_.
|
|
116
114
|
|
|
117
115
|
Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
|
|
@@ -148,11 +146,12 @@ def ft_surrogate(
|
|
|
148
146
|
Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
|
|
149
147
|
preprint arXiv:1806.08675.
|
|
150
148
|
"""
|
|
151
|
-
assert
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
149
|
+
assert (
|
|
150
|
+
isinstance(
|
|
151
|
+
phase_noise_magnitude, (Real, torch.FloatTensor, torch.cuda.FloatTensor)
|
|
152
|
+
)
|
|
153
|
+
and 0 <= phase_noise_magnitude <= 1
|
|
154
|
+
), f"eps must be a float between 0 and 1. Got {phase_noise_magnitude}."
|
|
156
155
|
|
|
157
156
|
f = fft(X.double(), dim=-1)
|
|
158
157
|
device = X.device
|
|
@@ -163,7 +162,7 @@ def ft_surrogate(
|
|
|
163
162
|
f.shape[-2] if channel_indep else 1,
|
|
164
163
|
n,
|
|
165
164
|
device=device,
|
|
166
|
-
random_state=random_state
|
|
165
|
+
random_state=random_state,
|
|
167
166
|
)
|
|
168
167
|
if not channel_indep:
|
|
169
168
|
random_phase = torch.tile(random_phase, (1, f.shape[-2], 1))
|
|
@@ -186,7 +185,7 @@ def _pick_channels_randomly(X, p_pick, random_state):
|
|
|
186
185
|
device=X.device,
|
|
187
186
|
)
|
|
188
187
|
# equivalent to a 0s and 1s mask
|
|
189
|
-
return torch.sigmoid(1000*(unif_samples - p_pick))
|
|
188
|
+
return torch.sigmoid(1000 * (unif_samples - p_pick))
|
|
190
189
|
|
|
191
190
|
|
|
192
191
|
def channels_dropout(X, y, p_drop, random_state=None):
|
|
@@ -234,7 +233,8 @@ def _make_permutation_matrix(X, mask, random_state):
|
|
|
234
233
|
channels_to_shuffle = torch.arange(n_channels, device=X.device)
|
|
235
234
|
channels_to_shuffle = channels_to_shuffle[mask.bool()]
|
|
236
235
|
reordered_channels = torch.tensor(
|
|
237
|
-
rng.permutation(channels_to_shuffle.cpu()), device=X.device
|
|
236
|
+
rng.permutation(channels_to_shuffle.cpu()), device=X.device
|
|
237
|
+
)
|
|
238
238
|
channels_permutation = torch.arange(n_channels, device=X.device)
|
|
239
239
|
channels_permutation[channels_to_shuffle] = reordered_channels
|
|
240
240
|
batch_permutations[b, ...] = one_hot(channels_permutation)
|
|
@@ -320,12 +320,14 @@ def gaussian_noise(X, y, std, random_state=None):
|
|
|
320
320
|
rng = check_random_state(random_state)
|
|
321
321
|
if isinstance(std, torch.Tensor):
|
|
322
322
|
std = std.to(X.device)
|
|
323
|
-
noise =
|
|
324
|
-
|
|
325
|
-
loc=np.zeros(X.shape),
|
|
326
|
-
|
|
327
|
-
)
|
|
328
|
-
|
|
323
|
+
noise = (
|
|
324
|
+
torch.from_numpy(
|
|
325
|
+
rng.normal(loc=np.zeros(X.shape), scale=1),
|
|
326
|
+
)
|
|
327
|
+
.float()
|
|
328
|
+
.to(X.device)
|
|
329
|
+
* std
|
|
330
|
+
)
|
|
329
331
|
transformed_X = X + noise
|
|
330
332
|
return transformed_X, y
|
|
331
333
|
|
|
@@ -399,9 +401,14 @@ def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
|
|
|
399
401
|
t = t.repeat(batch_size, n_channels, 1)
|
|
400
402
|
mask_start_per_sample = mask_start_per_sample.view(-1, 1, 1)
|
|
401
403
|
s = 1000 / seq_len
|
|
402
|
-
mask = (
|
|
403
|
-
|
|
404
|
-
|
|
404
|
+
mask = (
|
|
405
|
+
(
|
|
406
|
+
torch.sigmoid(s * -(t - mask_start_per_sample))
|
|
407
|
+
+ torch.sigmoid(s * (t - mask_start_per_sample - mask_len_samples))
|
|
408
|
+
)
|
|
409
|
+
.float()
|
|
410
|
+
.to(X.device)
|
|
411
|
+
)
|
|
405
412
|
return X * mask, y
|
|
406
413
|
|
|
407
414
|
|
|
@@ -447,17 +454,18 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
|
|
|
447
454
|
if bandwidth == 0:
|
|
448
455
|
return X, y
|
|
449
456
|
transformed_X = X.clone()
|
|
450
|
-
for c, (sample, notched_freq) in enumerate(
|
|
451
|
-
zip(transformed_X, freqs_to_notch)):
|
|
457
|
+
for c, (sample, notched_freq) in enumerate(zip(transformed_X, freqs_to_notch)):
|
|
452
458
|
sample = sample.cpu().numpy().astype(np.float64)
|
|
453
|
-
transformed_X[c] = torch.as_tensor(
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
459
|
+
transformed_X[c] = torch.as_tensor(
|
|
460
|
+
notch_filter(
|
|
461
|
+
sample,
|
|
462
|
+
Fs=sfreq,
|
|
463
|
+
freqs=notched_freq,
|
|
464
|
+
method="fir",
|
|
465
|
+
notch_widths=bandwidth,
|
|
466
|
+
verbose=False,
|
|
467
|
+
)
|
|
468
|
+
)
|
|
461
469
|
return transformed_X, y
|
|
462
470
|
|
|
463
471
|
|
|
@@ -470,10 +478,10 @@ def _analytic_transform(x):
|
|
|
470
478
|
h = torch.zeros_like(f)
|
|
471
479
|
if N % 2 == 0:
|
|
472
480
|
h[..., 0] = h[..., N // 2] = 1
|
|
473
|
-
h[..., 1:N // 2] = 2
|
|
481
|
+
h[..., 1 : N // 2] = 2
|
|
474
482
|
else:
|
|
475
483
|
h[..., 0] = 1
|
|
476
|
-
h[..., 1:(N + 1) // 2] = 2
|
|
484
|
+
h[..., 1 : (N + 1) // 2] = 2
|
|
477
485
|
|
|
478
486
|
return ifft(f * h, dim=-1)
|
|
479
487
|
|
|
@@ -500,7 +508,8 @@ def _frequency_shift(X, fs, f_shift):
|
|
|
500
508
|
f_shift = torch.as_tensor(f_shift).float()
|
|
501
509
|
f_shift_stack = f_shift.repeat(N_padded, n_channels, 1)
|
|
502
510
|
reshaped_f_shift = f_shift_stack.permute(
|
|
503
|
-
*torch.arange(f_shift_stack.ndim - 1, -1, -1)
|
|
511
|
+
*torch.arange(f_shift_stack.ndim - 1, -1, -1)
|
|
512
|
+
)
|
|
504
513
|
shifted = analytical * torch.exp(2j * np.pi * reshaped_f_shift * t)
|
|
505
514
|
return shifted[..., :N_orig].real.float()
|
|
506
515
|
|
|
@@ -539,7 +548,7 @@ def frequency_shift(X, y, delta_freq, sfreq):
|
|
|
539
548
|
def _torch_normalize_vectors(rr):
|
|
540
549
|
"""Normalize surface vertices."""
|
|
541
550
|
norm = torch.linalg.norm(rr, axis=1, keepdim=True)
|
|
542
|
-
mask =
|
|
551
|
+
mask = norm > 0
|
|
543
552
|
norm[~mask] = 1 # in case norm is zero, divide by 1
|
|
544
553
|
new_rr = rr / norm
|
|
545
554
|
return new_rr
|
|
@@ -631,7 +640,7 @@ def _torch_legval(x, c, tensor=True):
|
|
|
631
640
|
if isinstance(x, (tuple, list)):
|
|
632
641
|
x = torch.as_tensor(x)
|
|
633
642
|
if isinstance(x, torch.Tensor) and tensor:
|
|
634
|
-
c = c.view(c.shape + (1,)*x.ndim)
|
|
643
|
+
c = c.view(c.shape + (1,) * x.ndim)
|
|
635
644
|
|
|
636
645
|
c = c.to(x.device)
|
|
637
646
|
|
|
@@ -648,9 +657,9 @@ def _torch_legval(x, c, tensor=True):
|
|
|
648
657
|
for i in range(3, len(c) + 1):
|
|
649
658
|
tmp = c0
|
|
650
659
|
nd = nd - 1
|
|
651
|
-
c0 = c[-i] - (c1*(nd - 1))/nd
|
|
652
|
-
c1 = tmp + (c1*x*(2*nd - 1))/nd
|
|
653
|
-
return c0 + c1*x
|
|
660
|
+
c0 = c[-i] - (c1 * (nd - 1)) / nd
|
|
661
|
+
c1 = tmp + (c1 * x * (2 * nd - 1)) / nd
|
|
662
|
+
return c0 + c1 * x
|
|
654
663
|
|
|
655
664
|
|
|
656
665
|
def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
|
|
@@ -702,9 +711,10 @@ def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
|
|
|
702
711
|
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
|
|
703
712
|
DAMAGE.
|
|
704
713
|
"""
|
|
705
|
-
factors = [
|
|
706
|
-
|
|
707
|
-
|
|
714
|
+
factors = [
|
|
715
|
+
(2 * n + 1) / (n**stiffness * (n + 1) ** stiffness * 4 * np.pi)
|
|
716
|
+
for n in range(1, n_legendre_terms + 1)
|
|
717
|
+
]
|
|
708
718
|
return _torch_legval(cosang, [0] + factors)
|
|
709
719
|
|
|
710
720
|
|
|
@@ -783,15 +793,20 @@ def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
|
|
|
783
793
|
assert G_to_from.shape == (n_to, n_from)
|
|
784
794
|
|
|
785
795
|
if alpha is not None:
|
|
786
|
-
G_from.flatten()[::len(G_from) + 1] += alpha
|
|
796
|
+
G_from.flatten()[:: len(G_from) + 1] += alpha
|
|
787
797
|
|
|
788
798
|
device = G_from.device
|
|
789
|
-
C = torch.vstack(
|
|
799
|
+
C = torch.vstack(
|
|
800
|
+
[
|
|
790
801
|
torch.hstack([G_from, torch.ones((n_from, 1), device=device)]),
|
|
791
|
-
torch.hstack(
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
802
|
+
torch.hstack(
|
|
803
|
+
[
|
|
804
|
+
torch.ones((1, n_from), device=device),
|
|
805
|
+
torch.as_tensor([[0]], device=device),
|
|
806
|
+
]
|
|
807
|
+
),
|
|
808
|
+
]
|
|
809
|
+
)
|
|
795
810
|
|
|
796
811
|
try:
|
|
797
812
|
C_inv = torch.linalg.inv(C)
|
|
@@ -800,10 +815,9 @@ def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
|
|
|
800
815
|
# see https://github.com/pytorch/pytorch/issues/75494
|
|
801
816
|
C_inv = torch.linalg.pinv(C.cpu()).to(device)
|
|
802
817
|
|
|
803
|
-
interpolation = torch.hstack(
|
|
804
|
-
G_to_from,
|
|
805
|
-
|
|
806
|
-
]).matmul(C_inv[:, :-1])
|
|
818
|
+
interpolation = torch.hstack(
|
|
819
|
+
[G_to_from, torch.ones((n_to, 1), device=device)]
|
|
820
|
+
).matmul(C_inv[:, :-1])
|
|
807
821
|
assert interpolation.shape == (n_to, n_from)
|
|
808
822
|
return interpolation
|
|
809
823
|
|
|
@@ -815,11 +829,15 @@ def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
|
|
|
815
829
|
]
|
|
816
830
|
if spherical:
|
|
817
831
|
interpolation_matrix = torch.stack(
|
|
818
|
-
[
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
832
|
+
[
|
|
833
|
+
torch.as_tensor(
|
|
834
|
+
_torch_make_interpolation_matrix(
|
|
835
|
+
sensors_positions_matrix.T, rot_sensors_matrix.T
|
|
836
|
+
),
|
|
837
|
+
device=X.device,
|
|
838
|
+
).float()
|
|
839
|
+
for rot_sensors_matrix in rot_sensors_matrices
|
|
840
|
+
]
|
|
823
841
|
)
|
|
824
842
|
return torch.matmul(interpolation_matrix, X)
|
|
825
843
|
else:
|
|
@@ -836,7 +854,7 @@ def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
|
|
|
836
854
|
|
|
837
855
|
|
|
838
856
|
def _make_rotation_matrix(axis, angle, degrees=True):
|
|
839
|
-
assert axis in [
|
|
857
|
+
assert axis in ["x", "y", "z"], "axis should be either x, y or z."
|
|
840
858
|
|
|
841
859
|
if isinstance(angle, (float, int, np.ndarray, list)):
|
|
842
860
|
angle = torch.as_tensor(angle)
|
|
@@ -846,11 +864,13 @@ def _make_rotation_matrix(axis, angle, degrees=True):
|
|
|
846
864
|
|
|
847
865
|
device = angle.device
|
|
848
866
|
zero = torch.zeros(1, device=device)
|
|
849
|
-
rot = torch.stack(
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
867
|
+
rot = torch.stack(
|
|
868
|
+
[
|
|
869
|
+
torch.as_tensor([1, 0, 0], device=device),
|
|
870
|
+
torch.hstack([zero, torch.cos(angle), -torch.sin(angle)]),
|
|
871
|
+
torch.hstack([zero, torch.sin(angle), torch.cos(angle)]),
|
|
872
|
+
]
|
|
873
|
+
)
|
|
854
874
|
if axis == "x":
|
|
855
875
|
return rot
|
|
856
876
|
elif axis == "y":
|
|
@@ -861,8 +881,7 @@ def _make_rotation_matrix(axis, angle, degrees=True):
|
|
|
861
881
|
return rot[:, [1, 2, 0]]
|
|
862
882
|
|
|
863
883
|
|
|
864
|
-
def sensors_rotation(X, y, sensors_positions_matrix, axis, angles,
|
|
865
|
-
spherical_splines):
|
|
884
|
+
def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_splines):
|
|
866
885
|
"""Interpolates EEG signals over sensors rotated around the desired axis
|
|
867
886
|
with the desired angle.
|
|
868
887
|
|
|
@@ -900,13 +919,8 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles,
|
|
|
900
919
|
Conference of the IEEE Engineering in Medicine and Biology Society
|
|
901
920
|
(EMBC) (pp. 471-474).
|
|
902
921
|
"""
|
|
903
|
-
rots = [
|
|
904
|
-
|
|
905
|
-
for angle in angles
|
|
906
|
-
]
|
|
907
|
-
rotated_X = _rotate_signals(
|
|
908
|
-
X, rots, sensors_positions_matrix, spherical_splines
|
|
909
|
-
)
|
|
922
|
+
rots = [_make_rotation_matrix(axis, angle, degrees=True) for angle in angles]
|
|
923
|
+
rotated_X = _rotate_signals(X, rots, sensors_positions_matrix, spherical_splines)
|
|
910
924
|
return rotated_X, y
|
|
911
925
|
|
|
912
926
|
|
|
@@ -942,7 +956,7 @@ def mixup(X, y, lam, idx_perm):
|
|
|
942
956
|
International Conference on Learning Representations (ICLR)
|
|
943
957
|
Online: https://arxiv.org/abs/1710.09412
|
|
944
958
|
.. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
|
|
945
|
-
|
|
959
|
+
"""
|
|
946
960
|
device = X.device
|
|
947
961
|
batch_size, n_channels, n_times = X.shape
|
|
948
962
|
|
|
@@ -951,9 +965,132 @@ def mixup(X, y, lam, idx_perm):
|
|
|
951
965
|
y_b = torch.arange(batch_size).to(device)
|
|
952
966
|
|
|
953
967
|
for idx in range(batch_size):
|
|
954
|
-
X_mix[idx] = lam[idx] * X[idx]
|
|
955
|
-
+ (1 - lam[idx]) * X[idx_perm[idx]]
|
|
968
|
+
X_mix[idx] = lam[idx] * X[idx] + (1 - lam[idx]) * X[idx_perm[idx]]
|
|
956
969
|
y_a[idx] = y[idx]
|
|
957
970
|
y_b[idx] = y[idx_perm[idx]]
|
|
958
971
|
|
|
959
972
|
return X_mix, (y_a, y_b, lam)
|
|
973
|
+
|
|
974
|
+
|
|
975
|
+
def segmentation_reconstruction(
    X, y, n_segments, data_classes, rand_indices, idx_shuffle
):
    """Segment and reconstruct EEG data from [1]_.

    For each class, every trial is cut into ``n_segments`` consecutive time
    segments, and new trials are assembled by taking each segment from a
    (randomly selected) trial of the same class. See [1]_ for details.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch. Only its ``dtype`` and ``device`` are
        used; the trials themselves are read from ``data_classes``.
    y : torch.Tensor | None
        EEG labels for the example or batch. If None, no labels are
        returned (None is passed through).
    n_segments : int
        Number of segments each trial is split into. When the number of
        time samples is not divisible by ``n_segments``, the last segment
        also covers the remaining samples.
    data_classes : iterable of (class_index, torch.Tensor)
        Pairs of a class label and the trials of that class, each tensor of
        shape (n_trials, n_channels, n_times).
    rand_indices : mapping or sequence of array-like
        Indexed by ``class_index``. Each entry is an integer array of shape
        (n_trials, n_segments) whose entry ``[i, s]`` is the trial (within
        that class) providing segment ``s`` of the ``i``-th new trial.
    idx_shuffle : array-like
        Indices used to shuffle the newly generated trials, applied after
        the per-class results are concatenated in ``data_classes`` order.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor | None
        Transformed labels (None if ``y`` is None).

    References
    ----------
    .. [1] Lotte, F. (2015). Signal processing approaches to minimize or
       suppress calibration time in oscillatory activity-based brain–computer
       interfaces. Proceedings of the IEEE, 103(6), 871-890.
    """
    aug_data = []
    aug_label = []

    for class_index, X_class in data_classes:
        n_trials, _, window_size = X_class.shape
        segment_size = window_size // n_segments
        rand_idx = rand_indices[class_index]

        X_aug = torch.zeros_like(X_class)
        for idx_segment in range(n_segments):
            start = idx_segment * segment_size
            # The last segment absorbs the remainder so that no trailing
            # samples are left as zeros when window_size % n_segments != 0.
            if idx_segment == n_segments - 1:
                end = window_size
            else:
                end = start + segment_size
            X_aug[:, :, start:end] = X_class[rand_idx[:, idx_segment], :, start:end]

        aug_data.append(X_aug)
        aug_label.append(torch.full((n_trials,), class_index))

    # Stack all classes, match the caller's dtype/device, then shuffle.
    aug_data = torch.cat(aug_data, dim=0).to(dtype=X.dtype, device=X.device)
    aug_data = aug_data[idx_shuffle]

    if y is None:
        return aug_data, y

    aug_label = torch.cat(aug_label, dim=0).to(dtype=y.dtype, device=y.device)
    return aug_data, aug_label[idx_shuffle]
|
|
1044
|
+
|
|
1045
|
+
|
|
1046
|
+
def mask_encoding(X, y, time_start, segment_length, n_segments):
    """Mask encoding from Ding et al. (2024) from [ding2024]_.

    Replaces a contiguous part (or parts) of all channels by zeros
    (if more than one segment, they may overlap).

    Implementation based on [ding2024]_.

    Parameters
    ----------
    X : torch.Tensor
        EEG input batch of shape (batch_size, n_channels, n_times). It is
        not modified in place; a masked copy is returned.
    y : torch.Tensor
        EEG labels for the batch. Returned unchanged.
    time_start : torch.Tensor | array-like
        Integer positions (in the last dimension of ``X``) where masking
        starts. Must contain ``n_segments`` start positions per example,
        grouped example by example once flattened (i.e. ``n_segments`` times
        the size of the first dimension of ``X``).
    segment_length : int
        Length of each segment to zero out.
    n_segments : int
        Number of segments to zero out in each example.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [ding2024] Ding, Wenlong, et al. A Novel Data Augmentation Approach
       Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI.
       IEEE Transactions on Neural Systems and Rehabilitation Engineering
       32 (2024): 875-886.
    """
    device = X.device
    batch_size, n_times = X.shape[0], X.shape[-1]

    # All index tensors live on X's device to avoid CPU/GPU mismatches.
    start_indices = torch.as_tensor(time_start, device=device).flatten()
    batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
        n_segments
    )
    # Pair each start position with its example; like zip(), surplus entries
    # on either side are ignored rather than raising.
    n_pairs = min(batch_indices.numel(), start_indices.numel())
    # (n_pairs, segment_length) time indices to zero out.
    time_indices = start_indices[:n_pairs, None] + torch.arange(
        segment_length, device=device
    )

    # (batch_size, n_times) mask, broadcast over the channel axis below.
    time_mask = torch.zeros((batch_size, n_times), dtype=torch.bool, device=device)
    time_mask[batch_indices[:n_pairs, None], time_indices] = True

    # Out-of-place so the caller's tensor (and any autograd leaf) is untouched.
    return X.masked_fill(time_mask.unsqueeze(1), 0), y
|