braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# Authors: Cédric Rommel <cedric.rommel@inria.fr>
|
|
2
2
|
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
3
|
+
# Gustavo Rodrigues <gustavenrique01@gmail.com>
|
|
3
4
|
#
|
|
4
5
|
# License: BSD (3-clause)
|
|
5
6
|
|
|
@@ -11,18 +12,22 @@ import torch
|
|
|
11
12
|
from mne.channels import make_standard_montage
|
|
12
13
|
|
|
13
14
|
from .base import Transform
|
|
14
|
-
from .functional import
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
15
|
+
from .functional import (
|
|
16
|
+
bandstop_filter,
|
|
17
|
+
channels_dropout,
|
|
18
|
+
channels_permute,
|
|
19
|
+
channels_shuffle,
|
|
20
|
+
frequency_shift,
|
|
21
|
+
ft_surrogate,
|
|
22
|
+
gaussian_noise,
|
|
23
|
+
mask_encoding,
|
|
24
|
+
mixup,
|
|
25
|
+
segmentation_reconstruction,
|
|
26
|
+
sensors_rotation,
|
|
27
|
+
sign_flip,
|
|
28
|
+
smooth_time_mask,
|
|
29
|
+
time_reverse,
|
|
30
|
+
)
|
|
26
31
|
|
|
27
32
|
|
|
28
33
|
class TimeReverse(Transform):
|
|
@@ -37,7 +42,8 @@ class TimeReverse(Transform):
|
|
|
37
42
|
Used to decide whether or not to transform given the probability
|
|
38
43
|
argument. Defaults to None.
|
|
39
44
|
"""
|
|
40
|
-
|
|
45
|
+
|
|
46
|
+
operation = staticmethod(time_reverse) # type: ignore[assignment]
|
|
41
47
|
|
|
42
48
|
def __init__(
|
|
43
49
|
self,
|
|
@@ -62,17 +68,11 @@ class SignFlip(Transform):
|
|
|
62
68
|
Used to decide whether or not to transform given the probability
|
|
63
69
|
argument. Defaults to None.
|
|
64
70
|
"""
|
|
65
|
-
operation = staticmethod(sign_flip)
|
|
66
71
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
random_state=
|
|
71
|
-
):
|
|
72
|
-
super().__init__(
|
|
73
|
-
probability=probability,
|
|
74
|
-
random_state=random_state
|
|
75
|
-
)
|
|
72
|
+
operation = staticmethod(sign_flip) # type: ignore[assignment]
|
|
73
|
+
|
|
74
|
+
def __init__(self, probability, random_state=None):
|
|
75
|
+
super().__init__(probability=probability, random_state=random_state)
|
|
76
76
|
|
|
77
77
|
|
|
78
78
|
class FTSurrogate(Transform):
|
|
@@ -102,25 +102,26 @@ class FTSurrogate(Transform):
|
|
|
102
102
|
Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
|
|
103
103
|
preprint arXiv:1806.08675.
|
|
104
104
|
"""
|
|
105
|
-
|
|
105
|
+
|
|
106
|
+
operation = staticmethod(ft_surrogate) # type: ignore[assignment]
|
|
106
107
|
|
|
107
108
|
def __init__(
|
|
108
109
|
self,
|
|
109
110
|
probability,
|
|
110
111
|
phase_noise_magnitude=1,
|
|
111
112
|
channel_indep=False,
|
|
112
|
-
random_state=None
|
|
113
|
+
random_state=None,
|
|
113
114
|
):
|
|
114
|
-
super().__init__(
|
|
115
|
-
|
|
116
|
-
random_state=random_state
|
|
117
|
-
)
|
|
118
|
-
assert isinstance(phase_noise_magnitude, (float, int, torch.Tensor)), \
|
|
115
|
+
super().__init__(probability=probability, random_state=random_state)
|
|
116
|
+
assert isinstance(phase_noise_magnitude, (float, int, torch.Tensor)), (
|
|
119
117
|
"phase_noise_magnitude should be a float."
|
|
120
|
-
|
|
118
|
+
)
|
|
119
|
+
assert 0 <= phase_noise_magnitude <= 1, (
|
|
121
120
|
"phase_noise_magnitude should be between 0 and 1."
|
|
121
|
+
)
|
|
122
122
|
assert isinstance(channel_indep, bool), (
|
|
123
|
-
"channel_indep is expected to be a boolean"
|
|
123
|
+
"channel_indep is expected to be a boolean"
|
|
124
|
+
)
|
|
124
125
|
self.phase_noise_magnitude = phase_noise_magnitude
|
|
125
126
|
self.channel_indep = channel_indep
|
|
126
127
|
|
|
@@ -174,18 +175,11 @@ class ChannelsDropout(Transform):
|
|
|
174
175
|
Learning from Heterogeneous EEG Signals with Differentiable Channel
|
|
175
176
|
Reordering. arXiv preprint arXiv:2010.13694.
|
|
176
177
|
"""
|
|
177
|
-
operation = staticmethod(channels_dropout)
|
|
178
178
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
random_state=None
|
|
184
|
-
):
|
|
185
|
-
super().__init__(
|
|
186
|
-
probability=probability,
|
|
187
|
-
random_state=random_state
|
|
188
|
-
)
|
|
179
|
+
operation = staticmethod(channels_dropout) # type: ignore[assignment]
|
|
180
|
+
|
|
181
|
+
def __init__(self, probability, p_drop=0.2, random_state=None):
|
|
182
|
+
super().__init__(probability=probability, random_state=random_state)
|
|
189
183
|
self.p_drop = p_drop
|
|
190
184
|
|
|
191
185
|
def get_augmentation_params(self, *batch):
|
|
@@ -239,18 +233,11 @@ class ChannelsShuffle(Transform):
|
|
|
239
233
|
Learning from Heterogeneous EEG Signals with Differentiable Channel
|
|
240
234
|
Reordering. arXiv preprint arXiv:2010.13694.
|
|
241
235
|
"""
|
|
242
|
-
operation = staticmethod(channels_shuffle)
|
|
243
236
|
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
random_state=None
|
|
249
|
-
):
|
|
250
|
-
super().__init__(
|
|
251
|
-
probability=probability,
|
|
252
|
-
random_state=random_state
|
|
253
|
-
)
|
|
237
|
+
operation = staticmethod(channels_shuffle) # type: ignore[assignment]
|
|
238
|
+
|
|
239
|
+
def __init__(self, probability, p_shuffle=0.2, random_state=None):
|
|
240
|
+
super().__init__(probability=probability, random_state=random_state)
|
|
254
241
|
self.p_shuffle = p_shuffle
|
|
255
242
|
|
|
256
243
|
def get_augmentation_params(self, *batch):
|
|
@@ -308,14 +295,10 @@ class GaussianNoise(Transform):
|
|
|
308
295
|
Representation Learning for Electroencephalogram Classification. In
|
|
309
296
|
Machine Learning for Health (pp. 238-253). PMLR.
|
|
310
297
|
"""
|
|
311
|
-
operation = staticmethod(gaussian_noise)
|
|
312
298
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
std=0.1,
|
|
317
|
-
random_state=None
|
|
318
|
-
):
|
|
299
|
+
operation = staticmethod(gaussian_noise) # type: ignore[assignment]
|
|
300
|
+
|
|
301
|
+
def __init__(self, probability, std=0.1, random_state=None):
|
|
319
302
|
super().__init__(
|
|
320
303
|
probability=probability,
|
|
321
304
|
random_state=random_state,
|
|
@@ -373,28 +356,23 @@ class ChannelsSymmetry(Transform):
|
|
|
373
356
|
(2018). HAMLET: interpretable human and machine co-learning technique.
|
|
374
357
|
arXiv preprint arXiv:1803.09702.
|
|
375
358
|
"""
|
|
376
|
-
operation = staticmethod(channels_permute)
|
|
377
359
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
ordered_ch_names,
|
|
382
|
-
random_state=None
|
|
383
|
-
):
|
|
360
|
+
operation = staticmethod(channels_permute) # type: ignore[assignment]
|
|
361
|
+
|
|
362
|
+
def __init__(self, probability, ordered_ch_names, random_state=None):
|
|
384
363
|
super().__init__(
|
|
385
364
|
probability=probability,
|
|
386
365
|
random_state=random_state,
|
|
387
366
|
)
|
|
388
|
-
assert (
|
|
389
|
-
isinstance(
|
|
390
|
-
all(isinstance(ch, str) for ch in ordered_ch_names)
|
|
367
|
+
assert isinstance(ordered_ch_names, list) and all(
|
|
368
|
+
isinstance(ch, str) for ch in ordered_ch_names
|
|
391
369
|
), "ordered_ch_names should be a list of str."
|
|
392
370
|
|
|
393
371
|
permutation = list()
|
|
394
372
|
for idx, ch_name in enumerate(ordered_ch_names):
|
|
395
373
|
new_position = idx
|
|
396
374
|
# Find digits in channel name (assuming 10-20 system)
|
|
397
|
-
d =
|
|
375
|
+
d = "".join(list(filter(str.isdigit, ch_name)))
|
|
398
376
|
if len(d) > 0:
|
|
399
377
|
d = int(d)
|
|
400
378
|
if d % 2 == 0: # pair/right electrodes
|
|
@@ -454,22 +432,17 @@ class SmoothTimeMask(Transform):
|
|
|
454
432
|
Representation Learning for Electroencephalogram Classification. In
|
|
455
433
|
Machine Learning for Health (pp. 238-253). PMLR.
|
|
456
434
|
"""
|
|
457
|
-
operation = staticmethod(smooth_time_mask)
|
|
458
435
|
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
mask_len_samples=100,
|
|
463
|
-
random_state=None
|
|
464
|
-
):
|
|
436
|
+
operation = staticmethod(smooth_time_mask) # type: ignore[assignment]
|
|
437
|
+
|
|
438
|
+
def __init__(self, probability, mask_len_samples=100, random_state=None):
|
|
465
439
|
super().__init__(
|
|
466
440
|
probability=probability,
|
|
467
441
|
random_state=random_state,
|
|
468
442
|
)
|
|
469
443
|
|
|
470
444
|
assert (
|
|
471
|
-
isinstance(mask_len_samples, (int, torch.Tensor)) and
|
|
472
|
-
mask_len_samples > 0
|
|
445
|
+
isinstance(mask_len_samples, (int, torch.Tensor)) and mask_len_samples > 0
|
|
473
446
|
), "mask_len_samples has to be a positive integer"
|
|
474
447
|
self.mask_len_samples = mask_len_samples
|
|
475
448
|
|
|
@@ -504,9 +477,14 @@ class SmoothTimeMask(Transform):
|
|
|
504
477
|
mask_len_samples = self.mask_len_samples
|
|
505
478
|
if isinstance(mask_len_samples, torch.Tensor):
|
|
506
479
|
mask_len_samples = mask_len_samples.to(X.device)
|
|
507
|
-
mask_start = torch.as_tensor(
|
|
508
|
-
|
|
509
|
-
|
|
480
|
+
mask_start = torch.as_tensor(
|
|
481
|
+
self.rng.uniform(
|
|
482
|
+
low=0,
|
|
483
|
+
high=1,
|
|
484
|
+
size=X.shape[0],
|
|
485
|
+
),
|
|
486
|
+
device=X.device,
|
|
487
|
+
) * (seq_length - mask_len_samples)
|
|
510
488
|
return {
|
|
511
489
|
"mask_start_per_sample": mask_start,
|
|
512
490
|
"mask_len_samples": mask_len_samples,
|
|
@@ -546,27 +524,26 @@ class BandstopFilter(Transform):
|
|
|
546
524
|
Representation Learning for Electroencephalogram Classification. In
|
|
547
525
|
Machine Learning for Health (pp. 238-253). PMLR.
|
|
548
526
|
"""
|
|
549
|
-
|
|
527
|
+
|
|
528
|
+
operation = staticmethod(bandstop_filter) # type: ignore[assignment]
|
|
550
529
|
|
|
551
530
|
def __init__(
|
|
552
|
-
self,
|
|
553
|
-
probability,
|
|
554
|
-
sfreq,
|
|
555
|
-
bandwidth=1,
|
|
556
|
-
max_freq=None,
|
|
557
|
-
random_state=None
|
|
531
|
+
self, probability, sfreq, bandwidth=1, max_freq=None, random_state=None
|
|
558
532
|
):
|
|
559
533
|
super().__init__(
|
|
560
534
|
probability=probability,
|
|
561
535
|
random_state=random_state,
|
|
562
536
|
)
|
|
563
|
-
assert isinstance(bandwidth, Real) and bandwidth >= 0,
|
|
537
|
+
assert isinstance(bandwidth, Real) and bandwidth >= 0, (
|
|
564
538
|
"bandwidth should be a non-negative float."
|
|
565
|
-
|
|
539
|
+
)
|
|
540
|
+
assert isinstance(sfreq, Real) and sfreq > 0, (
|
|
566
541
|
"sfreq should be a positive float."
|
|
542
|
+
)
|
|
567
543
|
if max_freq is not None:
|
|
568
|
-
assert isinstance(max_freq, Real) and max_freq > 0,
|
|
544
|
+
assert isinstance(max_freq, Real) and max_freq > 0, (
|
|
569
545
|
"max_freq should be a positive float."
|
|
546
|
+
)
|
|
570
547
|
nyq = sfreq / 2
|
|
571
548
|
if max_freq is None or max_freq > nyq:
|
|
572
549
|
max_freq = nyq
|
|
@@ -575,8 +552,9 @@ class BandstopFilter(Transform):
|
|
|
575
552
|
f" Nyquist frequency ({nyq} Hz)."
|
|
576
553
|
f" Falling back to max_freq = {nyq}."
|
|
577
554
|
)
|
|
578
|
-
assert bandwidth < max_freq,
|
|
555
|
+
assert bandwidth < max_freq, (
|
|
579
556
|
f"`bandwidth` needs to be smaller than max_freq={max_freq}"
|
|
557
|
+
)
|
|
580
558
|
|
|
581
559
|
# override bandwidth value when a magnitude is passed
|
|
582
560
|
self.sfreq = sfreq
|
|
@@ -619,7 +597,7 @@ class BandstopFilter(Transform):
|
|
|
619
597
|
notched_freqs = self.rng.uniform(
|
|
620
598
|
low=1 + 2 * self.bandwidth,
|
|
621
599
|
high=self.max_freq - 1 - 2 * self.bandwidth,
|
|
622
|
-
size=X.shape[0]
|
|
600
|
+
size=X.shape[0],
|
|
623
601
|
)
|
|
624
602
|
return {
|
|
625
603
|
"sfreq": self.sfreq,
|
|
@@ -646,21 +624,17 @@ class FrequencyShift(Transform):
|
|
|
646
624
|
Seed to be used to instantiate numpy random number generator instance.
|
|
647
625
|
Defaults to None.
|
|
648
626
|
"""
|
|
649
|
-
operation = staticmethod(frequency_shift)
|
|
650
627
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
sfreq,
|
|
655
|
-
max_delta_freq=2,
|
|
656
|
-
random_state=None
|
|
657
|
-
):
|
|
628
|
+
operation = staticmethod(frequency_shift) # type: ignore[assignment]
|
|
629
|
+
|
|
630
|
+
def __init__(self, probability, sfreq, max_delta_freq=2, random_state=None):
|
|
658
631
|
super().__init__(
|
|
659
632
|
probability=probability,
|
|
660
633
|
random_state=random_state,
|
|
661
634
|
)
|
|
662
|
-
assert isinstance(sfreq, Real) and sfreq > 0,
|
|
635
|
+
assert isinstance(sfreq, Real) and sfreq > 0, (
|
|
663
636
|
"sfreq should be a positive float."
|
|
637
|
+
)
|
|
664
638
|
self.sfreq = sfreq
|
|
665
639
|
|
|
666
640
|
self.max_delta_freq = max_delta_freq
|
|
@@ -689,10 +663,7 @@ class FrequencyShift(Transform):
|
|
|
689
663
|
return super().get_augmentation_params(*batch)
|
|
690
664
|
X = batch[0]
|
|
691
665
|
|
|
692
|
-
u = torch.as_tensor(
|
|
693
|
-
self.rng.uniform(size=X.shape[0]),
|
|
694
|
-
device=X.device
|
|
695
|
-
)
|
|
666
|
+
u = torch.as_tensor(self.rng.uniform(size=X.shape[0]), device=X.device)
|
|
696
667
|
max_delta_freq = self.max_delta_freq
|
|
697
668
|
if isinstance(max_delta_freq, torch.Tensor):
|
|
698
669
|
max_delta_freq = max_delta_freq.to(X.device)
|
|
@@ -718,12 +689,13 @@ def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
|
|
|
718
689
|
matrices that will be fed to `SensorsRotation` transform. By
|
|
719
690
|
default None.
|
|
720
691
|
"""
|
|
721
|
-
assert raw_or_epoch is not None or ordered_ch_names is not None,
|
|
692
|
+
assert raw_or_epoch is not None or ordered_ch_names is not None, (
|
|
722
693
|
"At least one of raw_or_epoch and ordered_ch_names is needed."
|
|
694
|
+
)
|
|
723
695
|
if ordered_ch_names is None:
|
|
724
|
-
ordered_ch_names = raw_or_epoch.info[
|
|
725
|
-
ten_twenty_montage = make_standard_montage(
|
|
726
|
-
positions_dict = ten_twenty_montage.get_positions()[
|
|
696
|
+
ordered_ch_names = raw_or_epoch.info["ch_names"]
|
|
697
|
+
ten_twenty_montage = make_standard_montage("standard_1020")
|
|
698
|
+
positions_dict = ten_twenty_montage.get_positions()["ch_pos"]
|
|
727
699
|
positions_subdict = {
|
|
728
700
|
k: positions_dict[k] for k in ordered_ch_names if k in positions_dict
|
|
729
701
|
}
|
|
@@ -770,37 +742,38 @@ class SensorsRotation(Transform):
|
|
|
770
742
|
Conference of the IEEE Engineering in Medicine and Biology Society
|
|
771
743
|
(EMBC) (pp. 471-474).
|
|
772
744
|
"""
|
|
773
|
-
|
|
745
|
+
|
|
746
|
+
operation = staticmethod(sensors_rotation) # type: ignore[assignment]
|
|
774
747
|
|
|
775
748
|
def __init__(
|
|
776
749
|
self,
|
|
777
750
|
probability,
|
|
778
751
|
sensors_positions_matrix,
|
|
779
|
-
axis=
|
|
752
|
+
axis="z",
|
|
780
753
|
max_degrees=15,
|
|
781
754
|
spherical_splines=True,
|
|
782
|
-
random_state=None
|
|
755
|
+
random_state=None,
|
|
783
756
|
):
|
|
784
|
-
super().__init__(
|
|
785
|
-
probability=probability,
|
|
786
|
-
random_state=random_state
|
|
787
|
-
)
|
|
757
|
+
super().__init__(probability=probability, random_state=random_state)
|
|
788
758
|
if isinstance(sensors_positions_matrix, (np.ndarray, list)):
|
|
789
|
-
sensors_positions_matrix = torch.as_tensor(
|
|
790
|
-
|
|
791
|
-
)
|
|
792
|
-
assert isinstance(sensors_positions_matrix, torch.Tensor), \
|
|
759
|
+
sensors_positions_matrix = torch.as_tensor(sensors_positions_matrix)
|
|
760
|
+
assert isinstance(sensors_positions_matrix, torch.Tensor), (
|
|
793
761
|
"sensors_positions should be an Tensor"
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
max_degrees
|
|
797
|
-
)
|
|
798
|
-
assert isinstance(axis, str) and axis in [
|
|
799
|
-
"
|
|
800
|
-
|
|
762
|
+
)
|
|
763
|
+
assert isinstance(max_degrees, (Real, torch.Tensor)) and max_degrees >= 0, (
|
|
764
|
+
"max_degrees should be non-negative float."
|
|
765
|
+
)
|
|
766
|
+
assert isinstance(axis, str) and axis in [
|
|
767
|
+
"x",
|
|
768
|
+
"y",
|
|
769
|
+
"z",
|
|
770
|
+
], "axis can be either x, y or z."
|
|
771
|
+
assert sensors_positions_matrix.shape[0] == 3, (
|
|
801
772
|
"sensors_positions_matrix shape should be 3 x n_channels."
|
|
802
|
-
|
|
773
|
+
)
|
|
774
|
+
assert isinstance(spherical_splines, bool), (
|
|
803
775
|
"spherical_splines should be a boolean"
|
|
776
|
+
)
|
|
804
777
|
self.sensors_positions_matrix = sensors_positions_matrix
|
|
805
778
|
self.axis = axis
|
|
806
779
|
self.spherical_splines = spherical_splines
|
|
@@ -841,21 +814,18 @@ class SensorsRotation(Transform):
|
|
|
841
814
|
return super().get_augmentation_params(*batch)
|
|
842
815
|
X = batch[0]
|
|
843
816
|
|
|
844
|
-
u = self.rng.uniform(
|
|
845
|
-
low=0,
|
|
846
|
-
high=1,
|
|
847
|
-
size=X.shape[0]
|
|
848
|
-
)
|
|
817
|
+
u = self.rng.uniform(low=0, high=1, size=X.shape[0])
|
|
849
818
|
max_degrees = self.max_degrees
|
|
850
819
|
if isinstance(max_degrees, torch.Tensor):
|
|
851
820
|
max_degrees = max_degrees.to(X.device)
|
|
852
|
-
random_angles =
|
|
853
|
-
u, device=X.device) * 2 * max_degrees - max_degrees
|
|
821
|
+
random_angles = (
|
|
822
|
+
torch.as_tensor(u, device=X.device) * 2 * max_degrees - max_degrees
|
|
823
|
+
)
|
|
854
824
|
return {
|
|
855
825
|
"sensors_positions_matrix": self.sensors_positions_matrix,
|
|
856
826
|
"axis": self.axis,
|
|
857
827
|
"angles": random_angles,
|
|
858
|
-
"spherical_splines": self.spherical_splines
|
|
828
|
+
"spherical_splines": self.spherical_splines,
|
|
859
829
|
}
|
|
860
830
|
|
|
861
831
|
|
|
@@ -900,7 +870,7 @@ class SensorsZRotation(SensorsRotation):
|
|
|
900
870
|
ordered_ch_names,
|
|
901
871
|
max_degrees=15,
|
|
902
872
|
spherical_splines=True,
|
|
903
|
-
random_state=None
|
|
873
|
+
random_state=None,
|
|
904
874
|
):
|
|
905
875
|
sensors_positions_matrix = torch.as_tensor(
|
|
906
876
|
_get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
|
|
@@ -908,10 +878,10 @@ class SensorsZRotation(SensorsRotation):
|
|
|
908
878
|
super().__init__(
|
|
909
879
|
probability=probability,
|
|
910
880
|
sensors_positions_matrix=sensors_positions_matrix,
|
|
911
|
-
axis=
|
|
881
|
+
axis="z",
|
|
912
882
|
max_degrees=max_degrees,
|
|
913
883
|
spherical_splines=spherical_splines,
|
|
914
|
-
random_state=random_state
|
|
884
|
+
random_state=random_state,
|
|
915
885
|
)
|
|
916
886
|
|
|
917
887
|
|
|
@@ -956,7 +926,7 @@ class SensorsYRotation(SensorsRotation):
|
|
|
956
926
|
ordered_ch_names,
|
|
957
927
|
max_degrees=15,
|
|
958
928
|
spherical_splines=True,
|
|
959
|
-
random_state=None
|
|
929
|
+
random_state=None,
|
|
960
930
|
):
|
|
961
931
|
sensors_positions_matrix = torch.as_tensor(
|
|
962
932
|
_get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
|
|
@@ -964,10 +934,10 @@ class SensorsYRotation(SensorsRotation):
|
|
|
964
934
|
super().__init__(
|
|
965
935
|
probability=probability,
|
|
966
936
|
sensors_positions_matrix=sensors_positions_matrix,
|
|
967
|
-
axis=
|
|
937
|
+
axis="y",
|
|
968
938
|
max_degrees=max_degrees,
|
|
969
939
|
spherical_splines=spherical_splines,
|
|
970
|
-
random_state=random_state
|
|
940
|
+
random_state=random_state,
|
|
971
941
|
)
|
|
972
942
|
|
|
973
943
|
|
|
@@ -1012,7 +982,7 @@ class SensorsXRotation(SensorsRotation):
|
|
|
1012
982
|
ordered_ch_names,
|
|
1013
983
|
max_degrees=15,
|
|
1014
984
|
spherical_splines=True,
|
|
1015
|
-
random_state=None
|
|
985
|
+
random_state=None,
|
|
1016
986
|
):
|
|
1017
987
|
sensors_positions_matrix = torch.as_tensor(
|
|
1018
988
|
_get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
|
|
@@ -1020,10 +990,10 @@ class SensorsXRotation(SensorsRotation):
|
|
|
1020
990
|
super().__init__(
|
|
1021
991
|
probability=probability,
|
|
1022
992
|
sensors_positions_matrix=sensors_positions_matrix,
|
|
1023
|
-
axis=
|
|
993
|
+
axis="x",
|
|
1024
994
|
max_degrees=max_degrees,
|
|
1025
995
|
spherical_splines=spherical_splines,
|
|
1026
|
-
random_state=random_state
|
|
996
|
+
random_state=random_state,
|
|
1027
997
|
)
|
|
1028
998
|
|
|
1029
999
|
|
|
@@ -1050,17 +1020,13 @@ class Mixup(Transform):
|
|
|
1050
1020
|
Online: https://arxiv.org/abs/1710.09412
|
|
1051
1021
|
.. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
|
|
1052
1022
|
"""
|
|
1053
|
-
operation = staticmethod(mixup)
|
|
1054
1023
|
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
beta_per_sample=False,
|
|
1059
|
-
random_state=None
|
|
1060
|
-
):
|
|
1024
|
+
operation = staticmethod(mixup) # type: ignore[assignment]
|
|
1025
|
+
|
|
1026
|
+
def __init__(self, alpha, beta_per_sample=False, random_state=None):
|
|
1061
1027
|
super().__init__(
|
|
1062
1028
|
probability=1.0, # Mixup has to be applied to whole batches
|
|
1063
|
-
random_state=random_state
|
|
1029
|
+
random_state=random_state,
|
|
1064
1030
|
)
|
|
1065
1031
|
self.alpha = alpha
|
|
1066
1032
|
self.beta_per_sample = beta_per_sample
|
|
@@ -1098,9 +1064,210 @@ class Mixup(Transform):
|
|
|
1098
1064
|
else:
|
|
1099
1065
|
lam = torch.ones(batch_size).to(device)
|
|
1100
1066
|
|
|
1101
|
-
idx_perm = torch.as_tensor(
|
|
1067
|
+
idx_perm = torch.as_tensor(
|
|
1068
|
+
self.rng.permutation(
|
|
1069
|
+
batch_size,
|
|
1070
|
+
)
|
|
1071
|
+
)
|
|
1102
1072
|
|
|
1103
1073
|
return {
|
|
1104
1074
|
"lam": lam,
|
|
1105
1075
|
"idx_perm": idx_perm,
|
|
1106
1076
|
}
|
|
1077
|
+
|
|
1078
|
+
|
|
1079
|
+
class SegmentationReconstruction(Transform):
    """Segmentation Reconstruction from Lotte (2015) [Lotte2015]_.

    Applies a segmentation-reconstruction transform to the input data, as
    proposed in [Lotte2015]_. Each trial of the batch is split into
    consecutive segments, and new synthetic trials are built by randomly
    mixing segments taken from trials that share the same label, while
    preserving the original order of the segments in the time domain.

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    n_segments : int, optional
        Number of segments to split each trial into. If None, it is derived
        independently for every batch from the window size: the largest
        divisor of the number of time samples that does not exceed its
        square root. Defaults to None.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether to transform given the probability
        argument and to sample the segments mixing. Defaults to None.

    References
    ----------
    .. [Lotte2015] Lotte, F. (2015). Signal processing approaches to minimize
       or suppress calibration time in oscillatory activity-based brain–computer
       interfaces. Proceedings of the IEEE, 103(6), 871-890.
    """

    operation = staticmethod(segmentation_reconstruction)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        n_segments=None,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        self.n_segments = n_segments

    @staticmethod
    def _default_n_segments(n_times):
        """Return the largest divisor of ``n_times`` not exceeding its square root."""
        upper = int(n_times**0.5)
        divisors = [i for i in range(1, upper + 1) if n_times % i == 0]
        # 1 always divides n_times, so the list is non-empty for n_times >= 1.
        return divisors[-1]

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor, optional
            The labels. When absent or None, all trials are treated as a
            single group.

        Returns
        -------
        params : dict
            Contains the number of segments to split the signal into
            (``n_segments``), the per-label data groups (``data_classes``),
            the sampled per-label trial indices for each segment
            (``rand_indices``) and a permutation of the batch
            (``idx_shuffle``).
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]
        y = batch[1] if len(batch) > 1 else None

        if y is not None:
            if not isinstance(X, torch.Tensor) or not isinstance(y, torch.Tensor):
                raise ValueError("X and y must be torch tensors.")

            if X.shape[0] != y.shape[0]:
                raise ValueError("Number of samples in X and y must be the same.")

        n_times = int(X.shape[2])
        # Resolve into a local variable: mutating self.n_segments would freeze
        # the value computed for the first batch and break later batches
        # with a different window size.
        n_segments = self.n_segments
        if n_segments is None:
            n_segments = self._default_n_segments(n_times)
        elif not (
            isinstance(n_segments, (int, np.integer)) and 1 <= n_segments <= n_times
        ):
            raise ValueError(
                "Number of segments must be a positive integer less than "
                f"(or equal) the window size. Got {n_segments}"
            )
        n_segments = int(n_segments)

        if y is None:
            data_classes = [(np.nan, X)]
        else:
            data_classes = [(label, X[y == label]) for label in torch.unique(y)]

        # For every label group, draw which trial each segment is taken from.
        rand_indices = dict()
        for label, X_class in data_classes:
            n_trials = X_class.shape[0]
            rand_indices[label] = self.rng.randint(
                0, n_trials, (n_trials, n_segments)
            )

        idx_shuffle = self.rng.permutation(X.shape[0])

        return {
            "n_segments": n_segments,
            "data_classes": data_classes,
            "rand_indices": rand_indices,
            "idx_shuffle": idx_shuffle,
        }
|
|
1184
|
+
|
|
1185
|
+
|
|
1186
|
+
class MaskEncoding(Transform):
    """MaskEncoding from [1]_.

    Replaces randomly chosen contiguous segments of all channels with zeros.
    When more than one segment is used, the segments may overlap.

    Implementation based on [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    max_mask_ratio: float, optional
        Fraction of the window to zero out, split evenly across the segments
        (each segment spans ``int(n_times * max_mask_ratio / n_segments)``
        samples). Defaults to 0.1.
    n_segments : int, optional
        Number of segments to zero out in each example.
        Defaults to 1.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Ding, Wenlong, et al. "A Novel Data Augmentation Approach
       Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI."
       IEEE Transactions on Neural Systems and Rehabilitation Engineering
       32 (2024): 875-886.
    """

    operation = staticmethod(mask_encoding)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        max_mask_ratio=0.1,
        n_segments=1,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        assert isinstance(n_segments, int) and n_segments > 0, (
            "n_segments should be a positive integer."
        )
        assert isinstance(max_mask_ratio, (int, float)) and 0 <= max_mask_ratio <= 1, (
            "mask_ratio should be a float between 0 and 1."
        )

        self.mask_ratio = max_mask_ratio
        self.n_segments = n_segments

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains ``time_start``, an integer tensor of shape
            ``(batch_size, n_segments)`` holding the first masked sample of
            each segment; ``segment_length``, the number of samples in every
            segment; and ``n_segments``.
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]

        batch_size, _, n_times = X.shape

        segment_length = int((n_times * self.mask_ratio) / self.n_segments)

        assert segment_length >= 1, (
            "n_segments should be a positive integer not higher than (max_mask_ratio * window size)."
        )

        # randint's upper bound is exclusive, so +1 makes n_times - segment_length
        # a possible start, i.e. a segment may end exactly on the last sample
        # (assuming the operation masks [start, start + segment_length)).
        # Without it, max_mask_ratio=1 with one segment gives an empty range.
        time_start = self.rng.randint(
            0, n_times - segment_length + 1, (batch_size, self.n_segments)
        )
        # Keep sampled params on the data's device, like the other transforms.
        time_start = torch.as_tensor(time_start, device=X.device)

        return {
            "time_start": time_start,
            "segment_length": segment_length,
            "n_segments": self.n_segments,
        }
|