braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic; see the package's registry page for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,28 +1,34 @@
|
|
|
1
1
|
# Authors: Cédric Rommel <cedric.rommel@inria.fr>
|
|
2
2
|
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
3
|
+
# Gustavo Rodrigues <gustavenrique01@gmail.com>
|
|
3
4
|
#
|
|
4
5
|
# License: BSD (3-clause)
|
|
5
6
|
|
|
6
7
|
import warnings
|
|
7
8
|
from numbers import Real
|
|
9
|
+
from typing import Callable
|
|
8
10
|
|
|
9
11
|
import numpy as np
|
|
10
12
|
import torch
|
|
11
13
|
from mne.channels import make_standard_montage
|
|
12
14
|
|
|
13
15
|
from .base import Transform
|
|
14
|
-
from .functional import
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
16
|
+
from .functional import (
|
|
17
|
+
bandstop_filter,
|
|
18
|
+
channels_dropout,
|
|
19
|
+
channels_permute,
|
|
20
|
+
channels_shuffle,
|
|
21
|
+
frequency_shift,
|
|
22
|
+
ft_surrogate,
|
|
23
|
+
gaussian_noise,
|
|
24
|
+
mask_encoding,
|
|
25
|
+
mixup,
|
|
26
|
+
segmentation_reconstruction,
|
|
27
|
+
sensors_rotation,
|
|
28
|
+
sign_flip,
|
|
29
|
+
smooth_time_mask,
|
|
30
|
+
time_reverse,
|
|
31
|
+
)
|
|
26
32
|
|
|
27
33
|
|
|
28
34
|
class TimeReverse(Transform):
|
|
@@ -37,7 +43,8 @@ class TimeReverse(Transform):
|
|
|
37
43
|
Used to decide whether or not to transform given the probability
|
|
38
44
|
argument. Defaults to None.
|
|
39
45
|
"""
|
|
40
|
-
|
|
46
|
+
|
|
47
|
+
operation = staticmethod(time_reverse) # type: ignore[assignment]
|
|
41
48
|
|
|
42
49
|
def __init__(
|
|
43
50
|
self,
|
|
@@ -62,17 +69,11 @@ class SignFlip(Transform):
|
|
|
62
69
|
Used to decide whether or not to transform given the probability
|
|
63
70
|
argument. Defaults to None.
|
|
64
71
|
"""
|
|
65
|
-
operation = staticmethod(sign_flip)
|
|
66
72
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
random_state=
|
|
71
|
-
):
|
|
72
|
-
super().__init__(
|
|
73
|
-
probability=probability,
|
|
74
|
-
random_state=random_state
|
|
75
|
-
)
|
|
73
|
+
operation = staticmethod(sign_flip) # type: ignore[assignment]
|
|
74
|
+
|
|
75
|
+
def __init__(self, probability, random_state=None):
|
|
76
|
+
super().__init__(probability=probability, random_state=random_state)
|
|
76
77
|
|
|
77
78
|
|
|
78
79
|
class FTSurrogate(Transform):
|
|
@@ -102,25 +103,26 @@ class FTSurrogate(Transform):
|
|
|
102
103
|
Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
|
|
103
104
|
preprint arXiv:1806.08675.
|
|
104
105
|
"""
|
|
105
|
-
|
|
106
|
+
|
|
107
|
+
operation = staticmethod(ft_surrogate) # type: ignore[assignment]
|
|
106
108
|
|
|
107
109
|
def __init__(
|
|
108
110
|
self,
|
|
109
111
|
probability,
|
|
110
112
|
phase_noise_magnitude=1,
|
|
111
113
|
channel_indep=False,
|
|
112
|
-
random_state=None
|
|
114
|
+
random_state=None,
|
|
113
115
|
):
|
|
114
|
-
super().__init__(
|
|
115
|
-
|
|
116
|
-
random_state=random_state
|
|
117
|
-
)
|
|
118
|
-
assert isinstance(phase_noise_magnitude, (float, int, torch.Tensor)), \
|
|
116
|
+
super().__init__(probability=probability, random_state=random_state)
|
|
117
|
+
assert isinstance(phase_noise_magnitude, (float, int, torch.Tensor)), (
|
|
119
118
|
"phase_noise_magnitude should be a float."
|
|
120
|
-
|
|
119
|
+
)
|
|
120
|
+
assert 0 <= phase_noise_magnitude <= 1, (
|
|
121
121
|
"phase_noise_magnitude should be between 0 and 1."
|
|
122
|
+
)
|
|
122
123
|
assert isinstance(channel_indep, bool), (
|
|
123
|
-
"channel_indep is expected to be a boolean"
|
|
124
|
+
"channel_indep is expected to be a boolean"
|
|
125
|
+
)
|
|
124
126
|
self.phase_noise_magnitude = phase_noise_magnitude
|
|
125
127
|
self.channel_indep = channel_indep
|
|
126
128
|
|
|
@@ -174,18 +176,11 @@ class ChannelsDropout(Transform):
|
|
|
174
176
|
Learning from Heterogeneous EEG Signals with Differentiable Channel
|
|
175
177
|
Reordering. arXiv preprint arXiv:2010.13694.
|
|
176
178
|
"""
|
|
177
|
-
operation = staticmethod(channels_dropout)
|
|
178
179
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
random_state=None
|
|
184
|
-
):
|
|
185
|
-
super().__init__(
|
|
186
|
-
probability=probability,
|
|
187
|
-
random_state=random_state
|
|
188
|
-
)
|
|
180
|
+
operation = staticmethod(channels_dropout) # type: ignore[assignment]
|
|
181
|
+
|
|
182
|
+
def __init__(self, probability, p_drop=0.2, random_state=None):
|
|
183
|
+
super().__init__(probability=probability, random_state=random_state)
|
|
189
184
|
self.p_drop = p_drop
|
|
190
185
|
|
|
191
186
|
def get_augmentation_params(self, *batch):
|
|
@@ -239,18 +234,11 @@ class ChannelsShuffle(Transform):
|
|
|
239
234
|
Learning from Heterogeneous EEG Signals with Differentiable Channel
|
|
240
235
|
Reordering. arXiv preprint arXiv:2010.13694.
|
|
241
236
|
"""
|
|
242
|
-
operation = staticmethod(channels_shuffle)
|
|
243
237
|
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
random_state=None
|
|
249
|
-
):
|
|
250
|
-
super().__init__(
|
|
251
|
-
probability=probability,
|
|
252
|
-
random_state=random_state
|
|
253
|
-
)
|
|
238
|
+
operation = staticmethod(channels_shuffle) # type: ignore[assignment]
|
|
239
|
+
|
|
240
|
+
def __init__(self, probability, p_shuffle=0.2, random_state=None):
|
|
241
|
+
super().__init__(probability=probability, random_state=random_state)
|
|
254
242
|
self.p_shuffle = p_shuffle
|
|
255
243
|
|
|
256
244
|
def get_augmentation_params(self, *batch):
|
|
@@ -308,14 +296,10 @@ class GaussianNoise(Transform):
|
|
|
308
296
|
Representation Learning for Electroencephalogram Classification. In
|
|
309
297
|
Machine Learning for Health (pp. 238-253). PMLR.
|
|
310
298
|
"""
|
|
311
|
-
operation = staticmethod(gaussian_noise)
|
|
312
299
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
std=0.1,
|
|
317
|
-
random_state=None
|
|
318
|
-
):
|
|
300
|
+
operation = staticmethod(gaussian_noise) # type: ignore[assignment]
|
|
301
|
+
|
|
302
|
+
def __init__(self, probability, std=0.1, random_state=None):
|
|
319
303
|
super().__init__(
|
|
320
304
|
probability=probability,
|
|
321
305
|
random_state=random_state,
|
|
@@ -373,28 +357,23 @@ class ChannelsSymmetry(Transform):
|
|
|
373
357
|
(2018). HAMLET: interpretable human and machine co-learning technique.
|
|
374
358
|
arXiv preprint arXiv:1803.09702.
|
|
375
359
|
"""
|
|
376
|
-
operation = staticmethod(channels_permute)
|
|
377
360
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
ordered_ch_names,
|
|
382
|
-
random_state=None
|
|
383
|
-
):
|
|
361
|
+
operation = staticmethod(channels_permute) # type: ignore[assignment]
|
|
362
|
+
|
|
363
|
+
def __init__(self, probability, ordered_ch_names, random_state=None):
|
|
384
364
|
super().__init__(
|
|
385
365
|
probability=probability,
|
|
386
366
|
random_state=random_state,
|
|
387
367
|
)
|
|
388
|
-
assert (
|
|
389
|
-
isinstance(
|
|
390
|
-
all(isinstance(ch, str) for ch in ordered_ch_names)
|
|
368
|
+
assert isinstance(ordered_ch_names, list) and all(
|
|
369
|
+
isinstance(ch, str) for ch in ordered_ch_names
|
|
391
370
|
), "ordered_ch_names should be a list of str."
|
|
392
371
|
|
|
393
372
|
permutation = list()
|
|
394
373
|
for idx, ch_name in enumerate(ordered_ch_names):
|
|
395
374
|
new_position = idx
|
|
396
375
|
# Find digits in channel name (assuming 10-20 system)
|
|
397
|
-
d =
|
|
376
|
+
d = "".join(list(filter(str.isdigit, ch_name)))
|
|
398
377
|
if len(d) > 0:
|
|
399
378
|
d = int(d)
|
|
400
379
|
if d % 2 == 0: # pair/right electrodes
|
|
@@ -454,22 +433,17 @@ class SmoothTimeMask(Transform):
|
|
|
454
433
|
Representation Learning for Electroencephalogram Classification. In
|
|
455
434
|
Machine Learning for Health (pp. 238-253). PMLR.
|
|
456
435
|
"""
|
|
457
|
-
operation = staticmethod(smooth_time_mask)
|
|
458
436
|
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
mask_len_samples=100,
|
|
463
|
-
random_state=None
|
|
464
|
-
):
|
|
437
|
+
operation = staticmethod(smooth_time_mask) # type: ignore[assignment]
|
|
438
|
+
|
|
439
|
+
def __init__(self, probability, mask_len_samples=100, random_state=None):
|
|
465
440
|
super().__init__(
|
|
466
441
|
probability=probability,
|
|
467
442
|
random_state=random_state,
|
|
468
443
|
)
|
|
469
444
|
|
|
470
445
|
assert (
|
|
471
|
-
isinstance(mask_len_samples, (int, torch.Tensor)) and
|
|
472
|
-
mask_len_samples > 0
|
|
446
|
+
isinstance(mask_len_samples, (int, torch.Tensor)) and mask_len_samples > 0
|
|
473
447
|
), "mask_len_samples has to be a positive integer"
|
|
474
448
|
self.mask_len_samples = mask_len_samples
|
|
475
449
|
|
|
@@ -504,9 +478,14 @@ class SmoothTimeMask(Transform):
|
|
|
504
478
|
mask_len_samples = self.mask_len_samples
|
|
505
479
|
if isinstance(mask_len_samples, torch.Tensor):
|
|
506
480
|
mask_len_samples = mask_len_samples.to(X.device)
|
|
507
|
-
mask_start = torch.as_tensor(
|
|
508
|
-
|
|
509
|
-
|
|
481
|
+
mask_start = torch.as_tensor(
|
|
482
|
+
self.rng.uniform(
|
|
483
|
+
low=0,
|
|
484
|
+
high=1,
|
|
485
|
+
size=X.shape[0],
|
|
486
|
+
),
|
|
487
|
+
device=X.device,
|
|
488
|
+
) * (seq_length - mask_len_samples)
|
|
510
489
|
return {
|
|
511
490
|
"mask_start_per_sample": mask_start,
|
|
512
491
|
"mask_len_samples": mask_len_samples,
|
|
@@ -546,27 +525,26 @@ class BandstopFilter(Transform):
|
|
|
546
525
|
Representation Learning for Electroencephalogram Classification. In
|
|
547
526
|
Machine Learning for Health (pp. 238-253). PMLR.
|
|
548
527
|
"""
|
|
549
|
-
|
|
528
|
+
|
|
529
|
+
operation = staticmethod(bandstop_filter) # type: ignore[assignment]
|
|
550
530
|
|
|
551
531
|
def __init__(
|
|
552
|
-
self,
|
|
553
|
-
probability,
|
|
554
|
-
sfreq,
|
|
555
|
-
bandwidth=1,
|
|
556
|
-
max_freq=None,
|
|
557
|
-
random_state=None
|
|
532
|
+
self, probability, sfreq, bandwidth=1, max_freq=None, random_state=None
|
|
558
533
|
):
|
|
559
534
|
super().__init__(
|
|
560
535
|
probability=probability,
|
|
561
536
|
random_state=random_state,
|
|
562
537
|
)
|
|
563
|
-
assert isinstance(bandwidth, Real) and bandwidth >= 0,
|
|
538
|
+
assert isinstance(bandwidth, Real) and bandwidth >= 0, (
|
|
564
539
|
"bandwidth should be a non-negative float."
|
|
565
|
-
|
|
540
|
+
)
|
|
541
|
+
assert isinstance(sfreq, Real) and sfreq > 0, (
|
|
566
542
|
"sfreq should be a positive float."
|
|
543
|
+
)
|
|
567
544
|
if max_freq is not None:
|
|
568
|
-
assert isinstance(max_freq, Real) and max_freq > 0,
|
|
545
|
+
assert isinstance(max_freq, Real) and max_freq > 0, (
|
|
569
546
|
"max_freq should be a positive float."
|
|
547
|
+
)
|
|
570
548
|
nyq = sfreq / 2
|
|
571
549
|
if max_freq is None or max_freq > nyq:
|
|
572
550
|
max_freq = nyq
|
|
@@ -575,8 +553,9 @@ class BandstopFilter(Transform):
|
|
|
575
553
|
f" Nyquist frequency ({nyq} Hz)."
|
|
576
554
|
f" Falling back to max_freq = {nyq}."
|
|
577
555
|
)
|
|
578
|
-
assert bandwidth < max_freq,
|
|
556
|
+
assert bandwidth < max_freq, (
|
|
579
557
|
f"`bandwidth` needs to be smaller than max_freq={max_freq}"
|
|
558
|
+
)
|
|
580
559
|
|
|
581
560
|
# override bandwidth value when a magnitude is passed
|
|
582
561
|
self.sfreq = sfreq
|
|
@@ -619,7 +598,7 @@ class BandstopFilter(Transform):
|
|
|
619
598
|
notched_freqs = self.rng.uniform(
|
|
620
599
|
low=1 + 2 * self.bandwidth,
|
|
621
600
|
high=self.max_freq - 1 - 2 * self.bandwidth,
|
|
622
|
-
size=X.shape[0]
|
|
601
|
+
size=X.shape[0],
|
|
623
602
|
)
|
|
624
603
|
return {
|
|
625
604
|
"sfreq": self.sfreq,
|
|
@@ -646,21 +625,17 @@ class FrequencyShift(Transform):
|
|
|
646
625
|
Seed to be used to instantiate numpy random number generator instance.
|
|
647
626
|
Defaults to None.
|
|
648
627
|
"""
|
|
649
|
-
operation = staticmethod(frequency_shift)
|
|
650
628
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
sfreq,
|
|
655
|
-
max_delta_freq=2,
|
|
656
|
-
random_state=None
|
|
657
|
-
):
|
|
629
|
+
operation = staticmethod(frequency_shift) # type: ignore[assignment]
|
|
630
|
+
|
|
631
|
+
def __init__(self, probability, sfreq, max_delta_freq=2, random_state=None):
|
|
658
632
|
super().__init__(
|
|
659
633
|
probability=probability,
|
|
660
634
|
random_state=random_state,
|
|
661
635
|
)
|
|
662
|
-
assert isinstance(sfreq, Real) and sfreq > 0,
|
|
636
|
+
assert isinstance(sfreq, Real) and sfreq > 0, (
|
|
663
637
|
"sfreq should be a positive float."
|
|
638
|
+
)
|
|
664
639
|
self.sfreq = sfreq
|
|
665
640
|
|
|
666
641
|
self.max_delta_freq = max_delta_freq
|
|
@@ -689,10 +664,7 @@ class FrequencyShift(Transform):
|
|
|
689
664
|
return super().get_augmentation_params(*batch)
|
|
690
665
|
X = batch[0]
|
|
691
666
|
|
|
692
|
-
u = torch.as_tensor(
|
|
693
|
-
self.rng.uniform(size=X.shape[0]),
|
|
694
|
-
device=X.device
|
|
695
|
-
)
|
|
667
|
+
u = torch.as_tensor(self.rng.uniform(size=X.shape[0]), device=X.device)
|
|
696
668
|
max_delta_freq = self.max_delta_freq
|
|
697
669
|
if isinstance(max_delta_freq, torch.Tensor):
|
|
698
670
|
max_delta_freq = max_delta_freq.to(X.device)
|
|
@@ -718,12 +690,13 @@ def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
|
|
|
718
690
|
matrices that will be fed to `SensorsRotation` transform. By
|
|
719
691
|
default None.
|
|
720
692
|
"""
|
|
721
|
-
assert raw_or_epoch is not None or ordered_ch_names is not None,
|
|
693
|
+
assert raw_or_epoch is not None or ordered_ch_names is not None, (
|
|
722
694
|
"At least one of raw_or_epoch and ordered_ch_names is needed."
|
|
695
|
+
)
|
|
723
696
|
if ordered_ch_names is None:
|
|
724
|
-
ordered_ch_names = raw_or_epoch.info[
|
|
725
|
-
ten_twenty_montage = make_standard_montage(
|
|
726
|
-
positions_dict = ten_twenty_montage.get_positions()[
|
|
697
|
+
ordered_ch_names = raw_or_epoch.info["ch_names"]
|
|
698
|
+
ten_twenty_montage = make_standard_montage("standard_1020")
|
|
699
|
+
positions_dict = ten_twenty_montage.get_positions()["ch_pos"]
|
|
727
700
|
positions_subdict = {
|
|
728
701
|
k: positions_dict[k] for k in ordered_ch_names if k in positions_dict
|
|
729
702
|
}
|
|
@@ -770,37 +743,38 @@ class SensorsRotation(Transform):
|
|
|
770
743
|
Conference of the IEEE Engineering in Medicine and Biology Society
|
|
771
744
|
(EMBC) (pp. 471-474).
|
|
772
745
|
"""
|
|
773
|
-
|
|
746
|
+
|
|
747
|
+
operation = staticmethod(sensors_rotation) # type: ignore[assignment]
|
|
774
748
|
|
|
775
749
|
def __init__(
|
|
776
750
|
self,
|
|
777
751
|
probability,
|
|
778
752
|
sensors_positions_matrix,
|
|
779
|
-
axis=
|
|
753
|
+
axis="z",
|
|
780
754
|
max_degrees=15,
|
|
781
755
|
spherical_splines=True,
|
|
782
|
-
random_state=None
|
|
756
|
+
random_state=None,
|
|
783
757
|
):
|
|
784
|
-
super().__init__(
|
|
785
|
-
probability=probability,
|
|
786
|
-
random_state=random_state
|
|
787
|
-
)
|
|
758
|
+
super().__init__(probability=probability, random_state=random_state)
|
|
788
759
|
if isinstance(sensors_positions_matrix, (np.ndarray, list)):
|
|
789
|
-
sensors_positions_matrix = torch.as_tensor(
|
|
790
|
-
|
|
791
|
-
)
|
|
792
|
-
assert isinstance(sensors_positions_matrix, torch.Tensor), \
|
|
760
|
+
sensors_positions_matrix = torch.as_tensor(sensors_positions_matrix)
|
|
761
|
+
assert isinstance(sensors_positions_matrix, torch.Tensor), (
|
|
793
762
|
"sensors_positions should be an Tensor"
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
max_degrees
|
|
797
|
-
)
|
|
798
|
-
assert isinstance(axis, str) and axis in [
|
|
799
|
-
"
|
|
800
|
-
|
|
763
|
+
)
|
|
764
|
+
assert isinstance(max_degrees, (Real, torch.Tensor)) and max_degrees >= 0, (
|
|
765
|
+
"max_degrees should be non-negative float."
|
|
766
|
+
)
|
|
767
|
+
assert isinstance(axis, str) and axis in [
|
|
768
|
+
"x",
|
|
769
|
+
"y",
|
|
770
|
+
"z",
|
|
771
|
+
], "axis can be either x, y or z."
|
|
772
|
+
assert sensors_positions_matrix.shape[0] == 3, (
|
|
801
773
|
"sensors_positions_matrix shape should be 3 x n_channels."
|
|
802
|
-
|
|
774
|
+
)
|
|
775
|
+
assert isinstance(spherical_splines, bool), (
|
|
803
776
|
"spherical_splines should be a boolean"
|
|
777
|
+
)
|
|
804
778
|
self.sensors_positions_matrix = sensors_positions_matrix
|
|
805
779
|
self.axis = axis
|
|
806
780
|
self.spherical_splines = spherical_splines
|
|
@@ -841,21 +815,18 @@ class SensorsRotation(Transform):
|
|
|
841
815
|
return super().get_augmentation_params(*batch)
|
|
842
816
|
X = batch[0]
|
|
843
817
|
|
|
844
|
-
u = self.rng.uniform(
|
|
845
|
-
low=0,
|
|
846
|
-
high=1,
|
|
847
|
-
size=X.shape[0]
|
|
848
|
-
)
|
|
818
|
+
u = self.rng.uniform(low=0, high=1, size=X.shape[0])
|
|
849
819
|
max_degrees = self.max_degrees
|
|
850
820
|
if isinstance(max_degrees, torch.Tensor):
|
|
851
821
|
max_degrees = max_degrees.to(X.device)
|
|
852
|
-
random_angles =
|
|
853
|
-
u, device=X.device) * 2 * max_degrees - max_degrees
|
|
822
|
+
random_angles = (
|
|
823
|
+
torch.as_tensor(u, device=X.device) * 2 * max_degrees - max_degrees
|
|
824
|
+
)
|
|
854
825
|
return {
|
|
855
826
|
"sensors_positions_matrix": self.sensors_positions_matrix,
|
|
856
827
|
"axis": self.axis,
|
|
857
828
|
"angles": random_angles,
|
|
858
|
-
"spherical_splines": self.spherical_splines
|
|
829
|
+
"spherical_splines": self.spherical_splines,
|
|
859
830
|
}
|
|
860
831
|
|
|
861
832
|
|
|
@@ -900,7 +871,7 @@ class SensorsZRotation(SensorsRotation):
|
|
|
900
871
|
ordered_ch_names,
|
|
901
872
|
max_degrees=15,
|
|
902
873
|
spherical_splines=True,
|
|
903
|
-
random_state=None
|
|
874
|
+
random_state=None,
|
|
904
875
|
):
|
|
905
876
|
sensors_positions_matrix = torch.as_tensor(
|
|
906
877
|
_get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
|
|
@@ -908,10 +879,10 @@ class SensorsZRotation(SensorsRotation):
|
|
|
908
879
|
super().__init__(
|
|
909
880
|
probability=probability,
|
|
910
881
|
sensors_positions_matrix=sensors_positions_matrix,
|
|
911
|
-
axis=
|
|
882
|
+
axis="z",
|
|
912
883
|
max_degrees=max_degrees,
|
|
913
884
|
spherical_splines=spherical_splines,
|
|
914
|
-
random_state=random_state
|
|
885
|
+
random_state=random_state,
|
|
915
886
|
)
|
|
916
887
|
|
|
917
888
|
|
|
@@ -956,7 +927,7 @@ class SensorsYRotation(SensorsRotation):
|
|
|
956
927
|
ordered_ch_names,
|
|
957
928
|
max_degrees=15,
|
|
958
929
|
spherical_splines=True,
|
|
959
|
-
random_state=None
|
|
930
|
+
random_state=None,
|
|
960
931
|
):
|
|
961
932
|
sensors_positions_matrix = torch.as_tensor(
|
|
962
933
|
_get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
|
|
@@ -964,10 +935,10 @@ class SensorsYRotation(SensorsRotation):
|
|
|
964
935
|
super().__init__(
|
|
965
936
|
probability=probability,
|
|
966
937
|
sensors_positions_matrix=sensors_positions_matrix,
|
|
967
|
-
axis=
|
|
938
|
+
axis="y",
|
|
968
939
|
max_degrees=max_degrees,
|
|
969
940
|
spherical_splines=spherical_splines,
|
|
970
|
-
random_state=random_state
|
|
941
|
+
random_state=random_state,
|
|
971
942
|
)
|
|
972
943
|
|
|
973
944
|
|
|
@@ -1012,7 +983,7 @@ class SensorsXRotation(SensorsRotation):
|
|
|
1012
983
|
ordered_ch_names,
|
|
1013
984
|
max_degrees=15,
|
|
1014
985
|
spherical_splines=True,
|
|
1015
|
-
random_state=None
|
|
986
|
+
random_state=None,
|
|
1016
987
|
):
|
|
1017
988
|
sensors_positions_matrix = torch.as_tensor(
|
|
1018
989
|
_get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
|
|
@@ -1020,10 +991,10 @@ class SensorsXRotation(SensorsRotation):
|
|
|
1020
991
|
super().__init__(
|
|
1021
992
|
probability=probability,
|
|
1022
993
|
sensors_positions_matrix=sensors_positions_matrix,
|
|
1023
|
-
axis=
|
|
994
|
+
axis="x",
|
|
1024
995
|
max_degrees=max_degrees,
|
|
1025
996
|
spherical_splines=spherical_splines,
|
|
1026
|
-
random_state=random_state
|
|
997
|
+
random_state=random_state,
|
|
1027
998
|
)
|
|
1028
999
|
|
|
1029
1000
|
|
|
@@ -1050,17 +1021,13 @@ class Mixup(Transform):
|
|
|
1050
1021
|
Online: https://arxiv.org/abs/1710.09412
|
|
1051
1022
|
.. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
|
|
1052
1023
|
"""
|
|
1053
|
-
operation = staticmethod(mixup)
|
|
1054
1024
|
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
beta_per_sample=False,
|
|
1059
|
-
random_state=None
|
|
1060
|
-
):
|
|
1025
|
+
operation = staticmethod(mixup) # type: ignore[assignment]
|
|
1026
|
+
|
|
1027
|
+
def __init__(self, alpha, beta_per_sample=False, random_state=None):
|
|
1061
1028
|
super().__init__(
|
|
1062
1029
|
probability=1.0, # Mixup has to be applied to whole batches
|
|
1063
|
-
random_state=random_state
|
|
1030
|
+
random_state=random_state,
|
|
1064
1031
|
)
|
|
1065
1032
|
self.alpha = alpha
|
|
1066
1033
|
self.beta_per_sample = beta_per_sample
|
|
@@ -1098,9 +1065,210 @@ class Mixup(Transform):
|
|
|
1098
1065
|
else:
|
|
1099
1066
|
lam = torch.ones(batch_size).to(device)
|
|
1100
1067
|
|
|
1101
|
-
idx_perm = torch.as_tensor(
|
|
1068
|
+
idx_perm = torch.as_tensor(
|
|
1069
|
+
self.rng.permutation(
|
|
1070
|
+
batch_size,
|
|
1071
|
+
)
|
|
1072
|
+
)
|
|
1102
1073
|
|
|
1103
1074
|
return {
|
|
1104
1075
|
"lam": lam,
|
|
1105
1076
|
"idx_perm": idx_perm,
|
|
1106
1077
|
}
|
|
1078
|
+
|
|
1079
|
+
|
|
1080
|
+
class SegmentationReconstruction(Transform):
    """Segmentation Reconstruction from Lotte (2015) [Lotte2015]_.

    Applies a segmentation-reconstruction transform to the input data, as
    proposed in [Lotte2015]_. Each trial in the batch is cut into
    ``n_segments`` consecutive segments, and new synthetic trials are built
    by taking each segment from a randomly chosen trial with the same label,
    preserving the original order of the segments in the time domain.

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    n_segments : int, optional
        Number of segments to split each trial into. If None, it is chosen
        automatically for every batch as the largest divisor of the number of
        time samples that is not greater than its square root.
        Defaults to None.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether to transform given the probability
        argument and to sample the segments mixing. Defaults to None.

    References
    ----------
    .. [Lotte2015] Lotte, F. (2015). Signal processing approaches to minimize
       or suppress calibration time in oscillatory activity-based brain–computer
       interfaces. Proceedings of the IEEE, 103(6), 871-890.
    """

    operation = staticmethod(segmentation_reconstruction)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        n_segments=None,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        self.n_segments = n_segments

    @staticmethod
    def _default_n_segments(n_times):
        """Return the largest divisor of ``n_times`` not above its square root."""
        return max(i for i in range(1, int(n_times**0.5) + 1) if n_times % i == 0)

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data. Dimension 0 is the batch and dimension 2 is treated as
            time (presumably (batch_size, n_channels, n_times)).
        y : torch.Tensor | None
            The labels. When None, all trials are mixed together as a single
            group.

        Returns
        -------
        params : dict
            Contains ``n_segments`` (number of segments per trial),
            ``data_classes`` (list of ``(label, trials_of_that_label)`` pairs),
            ``rand_indices`` (per label, an integer array of shape
            ``(n_trials_of_label, n_segments)`` with donor-trial indices) and
            ``idx_shuffle`` (a permutation of the batch indices).

        Raises
        ------
        ValueError
            If ``y`` is given and ``X``/``y`` are not tensors or their batch
            sizes differ, or if ``n_segments`` is not an integer in
            ``[1, n_times]``.
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]
        y = batch[1] if len(batch) > 1 else None

        if y is not None:
            if not isinstance(X, torch.Tensor) or not isinstance(y, torch.Tensor):
                raise ValueError("X and y must be torch tensors.")

            if X.shape[0] != y.shape[0]:
                raise ValueError("Number of samples in X and y must be the same.")

        n_times = X.shape[2]
        n_segments = self.n_segments
        if n_segments is None:
            # Resolved per batch and deliberately NOT written back to
            # ``self.n_segments``: the user's ``None`` must keep meaning
            # "automatic" for later batches with another window length.
            n_segments = self._default_n_segments(int(n_times))
        elif not (
            isinstance(n_segments, (int, float))
            and 1 <= n_segments <= n_times
            and float(n_segments).is_integer()
        ):
            raise ValueError(
                f"Number of segments must be a positive integer less than "
                f"(or equal) the window size. Got {n_segments}"
            )
        # ``randint`` needs an integer size; integral floats (e.g. 4.0) are OK.
        n_segments = int(n_segments)

        if y is None:
            data_classes = [(np.nan, X)]
        else:
            data_classes = [(label, X[y == label]) for label in torch.unique(y)]

        # Draw order (one randint per class, then the permutation) is kept
        # so results stay reproducible for a given random_state.
        rand_indices = dict()
        for label, X_class in data_classes:
            n_trials = X_class.shape[0]
            rand_indices[label] = self.rng.randint(
                0, n_trials, (n_trials, n_segments)
            )

        idx_shuffle = self.rng.permutation(X.shape[0])

        return {
            "n_segments": n_segments,
            "data_classes": data_classes,
            "rand_indices": rand_indices,
            "idx_shuffle": idx_shuffle,
        }
|
|
1185
|
+
|
|
1186
|
+
|
|
1187
|
+
class MaskEncoding(Transform):
    """MaskEncoding from [1]_.

    Replaces randomly chosen contiguous part (or parts) of all channels by
    zeros (if more than one segment, they may overlap).

    Implementation based on [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    max_mask_ratio: float, optional
        Fraction of the window length to zero out, split evenly across the
        ``n_segments`` segments. The ratio is used as-is, not sampled.
        Defaults to 0.1.
    n_segments : int, optional
        Number of segments to zero out in each example.
        Defaults to 1.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Ding, Wenlong, et al. "A Novel Data Augmentation Approach
       Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI."
       IEEE Transactions on Neural Systems and Rehabilitation Engineering
       32 (2024): 875-886.
    """

    operation = staticmethod(mask_encoding)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        max_mask_ratio=0.1,
        n_segments=1,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        assert isinstance(n_segments, int) and n_segments > 0, (
            "n_segments should be a positive integer."
        )
        assert isinstance(max_mask_ratio, (int, float)) and 0 <= max_mask_ratio <= 1, (
            "mask_ratio should be a float between 0 and 1."
        )

        # Stored as ``mask_ratio``; used directly to size every segment.
        self.mask_ratio = max_mask_ratio
        self.n_segments = n_segments

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data, unpacked as (batch_size, n_channels, n_times).
        y : torch.Tensor
            The labels (unused here).

        Returns
        -------
        params : dict
            Contains ``time_start`` (int64 tensor of shape
            ``(batch_size, n_segments)`` with the first masked sample of each
            segment), ``segment_length`` (number of samples per masked
            segment) and ``n_segments``.
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]

        batch_size, _, n_times = X.shape

        segment_length = int((n_times * self.mask_ratio) / self.n_segments)

        assert segment_length >= 1, (
            "n_segments should be a positive integer not higher than (max_mask_ratio * window size)."
        )

        # ``randint``'s upper bound is exclusive: ``+ 1`` lets the last valid
        # start (n_times - segment_length) be drawn, and keeps the range
        # non-empty when one segment spans the whole window
        # (max_mask_ratio=1, n_segments=1), which previously raised.
        time_start = self.rng.randint(
            0, n_times - segment_length + 1, (batch_size, self.n_segments)
        )
        time_start = torch.from_numpy(time_start)

        return {
            "time_start": time_start,
            "segment_length": segment_length,
            "n_segments": self.n_segments,
        }
|