braindecode 1.3.0.dev181065563__py3-none-any.whl → 1.3.0.dev181594385__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/augmentation/base.py +1 -1
- braindecode/augmentation/functional.py +154 -54
- braindecode/augmentation/transforms.py +2 -2
- braindecode/datasets/__init__.py +10 -2
- braindecode/datasets/base.py +116 -152
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +2 -2
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +2 -2
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +4 -3
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/serialization.py +18 -13
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +6 -0
- braindecode/models/atcnet.py +33 -34
- braindecode/models/attentionbasenet.py +39 -32
- braindecode/models/attn_sleep.py +2 -0
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +3 -1
- braindecode/models/contrawr.py +2 -0
- braindecode/models/ctnet.py +8 -3
- braindecode/models/deepsleepnet.py +28 -19
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegitnet.py +2 -0
- braindecode/models/eegminer.py +2 -0
- braindecode/models/eegnet.py +1 -1
- braindecode/models/eegtcnet.py +2 -0
- braindecode/models/fbcnet.py +2 -0
- braindecode/models/fblightconvnet.py +2 -0
- braindecode/models/fbmsnet.py +2 -0
- braindecode/models/ifnet.py +2 -0
- braindecode/models/labram.py +193 -87
- braindecode/models/msvtnet.py +2 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/signal_jepa.py +111 -27
- braindecode/models/sinc_shallow.py +12 -9
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +9 -6
- braindecode/models/syncnet.py +2 -0
- braindecode/models/tcn.py +2 -0
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +3 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +24 -0
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/preprocess.py +23 -14
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +24 -19
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev181594385.dist-info}/METADATA +6 -2
- braindecode-1.3.0.dev181594385.dist-info/RECORD +106 -0
- braindecode-1.3.0.dev181065563.dist-info/RECORD +0 -101
- {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev181594385.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev181594385.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev181594385.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev181594385.dist-info}/top_level.txt +0 -0
braindecode/augmentation/base.py
CHANGED
|
@@ -189,7 +189,7 @@ class AugmentedDataLoader(DataLoader):
|
|
|
189
189
|
|
|
190
190
|
Parameters
|
|
191
191
|
----------
|
|
192
|
-
dataset :
|
|
192
|
+
dataset : RecordDataset
|
|
193
193
|
The dataset containing the signals.
|
|
194
194
|
transforms : list | Transform, optional
|
|
195
195
|
Transform or sequence of Transform to be applied to each batch.
|
|
@@ -4,9 +4,13 @@
|
|
|
4
4
|
#
|
|
5
5
|
# License: BSD (3-clause)
|
|
6
6
|
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
7
9
|
from numbers import Real
|
|
10
|
+
from typing import Literal
|
|
8
11
|
|
|
9
12
|
import numpy as np
|
|
13
|
+
import numpy.typing as npt
|
|
10
14
|
import torch
|
|
11
15
|
from mne.filter import notch_filter
|
|
12
16
|
from scipy.interpolate import Rbf
|
|
@@ -15,7 +19,7 @@ from torch.fft import fft, ifft
|
|
|
15
19
|
from torch.nn.functional import one_hot, pad
|
|
16
20
|
|
|
17
21
|
|
|
18
|
-
def identity(X, y):
|
|
22
|
+
def identity(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
19
23
|
"""Identity operation.
|
|
20
24
|
|
|
21
25
|
Parameters
|
|
@@ -35,7 +39,7 @@ def identity(X, y):
|
|
|
35
39
|
return X, y
|
|
36
40
|
|
|
37
41
|
|
|
38
|
-
def time_reverse(X, y):
|
|
42
|
+
def time_reverse(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
39
43
|
"""Flip the time axis of each input.
|
|
40
44
|
|
|
41
45
|
Parameters
|
|
@@ -55,7 +59,7 @@ def time_reverse(X, y):
|
|
|
55
59
|
return torch.flip(X, [-1]), y
|
|
56
60
|
|
|
57
61
|
|
|
58
|
-
def sign_flip(X, y):
|
|
62
|
+
def sign_flip(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
59
63
|
"""Flip the sign axis of each input.
|
|
60
64
|
|
|
61
65
|
Parameters
|
|
@@ -75,7 +79,13 @@ def sign_flip(X, y):
|
|
|
75
79
|
return -X, y
|
|
76
80
|
|
|
77
81
|
|
|
78
|
-
def _new_random_fft_phase_odd(
|
|
82
|
+
def _new_random_fft_phase_odd(
|
|
83
|
+
batch_size: int,
|
|
84
|
+
c: int,
|
|
85
|
+
n: int,
|
|
86
|
+
device: torch.device,
|
|
87
|
+
random_state: int | np.random.RandomState | None,
|
|
88
|
+
) -> torch.Tensor:
|
|
79
89
|
rng = check_random_state(random_state)
|
|
80
90
|
random_phase = torch.from_numpy(
|
|
81
91
|
2j * np.pi * rng.random((batch_size, c, (n - 1) // 2))
|
|
@@ -90,7 +100,13 @@ def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
|
|
|
90
100
|
)
|
|
91
101
|
|
|
92
102
|
|
|
93
|
-
def _new_random_fft_phase_even(
|
|
103
|
+
def _new_random_fft_phase_even(
|
|
104
|
+
batch_size: int,
|
|
105
|
+
c: int,
|
|
106
|
+
n: int,
|
|
107
|
+
device: torch.device,
|
|
108
|
+
random_state: int | np.random.RandomState | None,
|
|
109
|
+
) -> torch.Tensor:
|
|
94
110
|
rng = check_random_state(random_state)
|
|
95
111
|
random_phase = torch.from_numpy(
|
|
96
112
|
2j * np.pi * rng.random((batch_size, c, n // 2 - 1))
|
|
@@ -109,7 +125,13 @@ def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
|
|
|
109
125
|
_new_random_fft_phase = {0: _new_random_fft_phase_even, 1: _new_random_fft_phase_odd}
|
|
110
126
|
|
|
111
127
|
|
|
112
|
-
def ft_surrogate(
|
|
128
|
+
def ft_surrogate(
|
|
129
|
+
X: torch.Tensor,
|
|
130
|
+
y: torch.Tensor,
|
|
131
|
+
phase_noise_magnitude: float,
|
|
132
|
+
channel_indep: bool,
|
|
133
|
+
random_state: int | np.random.RandomState | None = None,
|
|
134
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
113
135
|
"""FT surrogate augmentation of a single EEG channel, as proposed in [1]_.
|
|
114
136
|
|
|
115
137
|
Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
|
|
@@ -175,7 +197,9 @@ def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
|
|
|
175
197
|
return transformed_X, y
|
|
176
198
|
|
|
177
199
|
|
|
178
|
-
def _pick_channels_randomly(
|
|
200
|
+
def _pick_channels_randomly(
|
|
201
|
+
X: torch.Tensor, p_pick: float, random_state: int | np.random.RandomState | None
|
|
202
|
+
) -> torch.Tensor:
|
|
179
203
|
rng = check_random_state(random_state)
|
|
180
204
|
batch_size, n_channels, _ = X.shape
|
|
181
205
|
# allows to use the same RNG
|
|
@@ -188,7 +212,12 @@ def _pick_channels_randomly(X, p_pick, random_state):
|
|
|
188
212
|
return torch.sigmoid(1000 * (unif_samples - p_pick))
|
|
189
213
|
|
|
190
214
|
|
|
191
|
-
def channels_dropout(
|
|
215
|
+
def channels_dropout(
|
|
216
|
+
X: torch.Tensor,
|
|
217
|
+
y: torch.Tensor,
|
|
218
|
+
p_drop: float,
|
|
219
|
+
random_state: int | np.random.RandomState | None = None,
|
|
220
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
192
221
|
"""Randomly set channels to flat signal.
|
|
193
222
|
|
|
194
223
|
Part of the CMSAugment policy proposed in [1]_
|
|
@@ -222,7 +251,9 @@ def channels_dropout(X, y, p_drop, random_state=None):
|
|
|
222
251
|
return X * mask.unsqueeze(-1), y
|
|
223
252
|
|
|
224
253
|
|
|
225
|
-
def _make_permutation_matrix(
|
|
254
|
+
def _make_permutation_matrix(
|
|
255
|
+
X: torch.Tensor, mask: torch.Tensor, random_state: int | np.random.Generator | None
|
|
256
|
+
) -> torch.Tensor:
|
|
226
257
|
rng = check_random_state(random_state)
|
|
227
258
|
batch_size, n_channels, _ = X.shape
|
|
228
259
|
hard_mask = mask.round()
|
|
@@ -241,7 +272,12 @@ def _make_permutation_matrix(X, mask, random_state):
|
|
|
241
272
|
return batch_permutations
|
|
242
273
|
|
|
243
274
|
|
|
244
|
-
def channels_shuffle(
|
|
275
|
+
def channels_shuffle(
|
|
276
|
+
X: torch.Tensor,
|
|
277
|
+
y: torch.Tensor,
|
|
278
|
+
p_shuffle: float,
|
|
279
|
+
random_state: int | np.random.RandomState | None = None,
|
|
280
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
245
281
|
"""Randomly shuffle channels in EEG data matrix.
|
|
246
282
|
|
|
247
283
|
Part of the CMSAugment policy proposed in [1]_
|
|
@@ -280,7 +316,12 @@ def channels_shuffle(X, y, p_shuffle, random_state=None):
|
|
|
280
316
|
return torch.matmul(batch_permutations, X), y
|
|
281
317
|
|
|
282
318
|
|
|
283
|
-
def gaussian_noise(
|
|
319
|
+
def gaussian_noise(
|
|
320
|
+
X: torch.Tensor,
|
|
321
|
+
y: torch.Tensor,
|
|
322
|
+
std: float,
|
|
323
|
+
random_state: int | np.random.RandomState | None = None,
|
|
324
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
284
325
|
"""Randomly add white Gaussian noise to all channels.
|
|
285
326
|
|
|
286
327
|
Suggested e.g. in [1]_, [2]_ and [3]_
|
|
@@ -332,7 +373,9 @@ def gaussian_noise(X, y, std, random_state=None):
|
|
|
332
373
|
return transformed_X, y
|
|
333
374
|
|
|
334
375
|
|
|
335
|
-
def channels_permute(
|
|
376
|
+
def channels_permute(
|
|
377
|
+
X: torch.Tensor, y: torch.Tensor, permutation: list[int]
|
|
378
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
336
379
|
"""Permute EEG channels according to fixed permutation matrix.
|
|
337
380
|
|
|
338
381
|
Suggested e.g. in [1]_
|
|
@@ -362,7 +405,12 @@ def channels_permute(X, y, permutation):
|
|
|
362
405
|
return X[..., permutation, :], y
|
|
363
406
|
|
|
364
407
|
|
|
365
|
-
def smooth_time_mask(
|
|
408
|
+
def smooth_time_mask(
|
|
409
|
+
X: torch.Tensor,
|
|
410
|
+
y: torch.Tensor,
|
|
411
|
+
mask_start_per_sample: torch.Tensor,
|
|
412
|
+
mask_len_samples: int,
|
|
413
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
366
414
|
"""Smoothly replace a contiguous part of all channels by zeros.
|
|
367
415
|
|
|
368
416
|
Originally proposed in [1]_ and [2]_
|
|
@@ -412,7 +460,13 @@ def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
|
|
|
412
460
|
return X * mask, y
|
|
413
461
|
|
|
414
462
|
|
|
415
|
-
def bandstop_filter(
|
|
463
|
+
def bandstop_filter(
|
|
464
|
+
X: torch.Tensor,
|
|
465
|
+
y: torch.Tensor,
|
|
466
|
+
sfreq: float,
|
|
467
|
+
bandwidth: float,
|
|
468
|
+
freqs_to_notch: npt.ArrayLike | None,
|
|
469
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
416
470
|
"""Apply a band-stop filter with desired bandwidth at the desired frequency
|
|
417
471
|
position.
|
|
418
472
|
|
|
@@ -451,7 +505,7 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
|
|
|
451
505
|
Representation Learning for Electroencephalogram Classification. In
|
|
452
506
|
Machine Learning for Health (pp. 238-253). PMLR.
|
|
453
507
|
"""
|
|
454
|
-
if bandwidth == 0:
|
|
508
|
+
if bandwidth == 0 or freqs_to_notch is None:
|
|
455
509
|
return X, y
|
|
456
510
|
transformed_X = X.clone()
|
|
457
511
|
for c, (sample, notched_freq) in enumerate(zip(transformed_X, freqs_to_notch)):
|
|
@@ -469,7 +523,7 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
|
|
|
469
523
|
return transformed_X, y
|
|
470
524
|
|
|
471
525
|
|
|
472
|
-
def _analytic_transform(x):
|
|
526
|
+
def _analytic_transform(x: torch.Tensor) -> torch.Tensor:
|
|
473
527
|
if torch.is_complex(x):
|
|
474
528
|
raise ValueError("x must be real.")
|
|
475
529
|
|
|
@@ -486,12 +540,12 @@ def _analytic_transform(x):
|
|
|
486
540
|
return ifft(f * h, dim=-1)
|
|
487
541
|
|
|
488
542
|
|
|
489
|
-
def _nextpow2(n):
|
|
543
|
+
def _nextpow2(n: int) -> int:
|
|
490
544
|
"""Return the first integer N such that 2**N >= abs(n)."""
|
|
491
545
|
return int(np.ceil(np.log2(np.abs(n))))
|
|
492
546
|
|
|
493
547
|
|
|
494
|
-
def _frequency_shift(X, fs, f_shift):
|
|
548
|
+
def _frequency_shift(X: torch.Tensor, fs: float, f_shift: float) -> torch.Tensor:
|
|
495
549
|
"""
|
|
496
550
|
Shift the specified signal by the specified frequency.
|
|
497
551
|
|
|
@@ -504,9 +558,13 @@ def _frequency_shift(X, fs, f_shift):
|
|
|
504
558
|
t = torch.arange(N_padded, device=X.device) / fs
|
|
505
559
|
padded = pad(X, (0, N_padded - N_orig))
|
|
506
560
|
analytical = _analytic_transform(padded)
|
|
507
|
-
if isinstance(f_shift,
|
|
508
|
-
|
|
509
|
-
|
|
561
|
+
if isinstance(f_shift, torch.Tensor):
|
|
562
|
+
_f_shift = f_shift
|
|
563
|
+
elif isinstance(f_shift, (float, int, np.ndarray, list)):
|
|
564
|
+
_f_shift = torch.as_tensor(f_shift).float()
|
|
565
|
+
else:
|
|
566
|
+
raise ValueError(f"Invalid f_shift type: {type(f_shift)}")
|
|
567
|
+
f_shift_stack = _f_shift.repeat(N_padded, n_channels, 1)
|
|
510
568
|
reshaped_f_shift = f_shift_stack.permute(
|
|
511
569
|
*torch.arange(f_shift_stack.ndim - 1, -1, -1)
|
|
512
570
|
)
|
|
@@ -514,7 +572,9 @@ def _frequency_shift(X, fs, f_shift):
|
|
|
514
572
|
return shifted[..., :N_orig].real.float()
|
|
515
573
|
|
|
516
574
|
|
|
517
|
-
def frequency_shift(
|
|
575
|
+
def frequency_shift(
|
|
576
|
+
X: torch.Tensor, y: torch.Tensor, delta_freq: float, sfreq: float
|
|
577
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
518
578
|
"""Adds a shift in the frequency domain to all channels.
|
|
519
579
|
|
|
520
580
|
Note that here, the shift is the same for all channels of a single example.
|
|
@@ -545,7 +605,7 @@ def frequency_shift(X, y, delta_freq, sfreq):
|
|
|
545
605
|
return transformed_X, y
|
|
546
606
|
|
|
547
607
|
|
|
548
|
-
def _torch_normalize_vectors(rr):
|
|
608
|
+
def _torch_normalize_vectors(rr: torch.Tensor) -> torch.Tensor:
|
|
549
609
|
"""Normalize surface vertices."""
|
|
550
610
|
norm = torch.linalg.norm(rr, axis=1, keepdim=True)
|
|
551
611
|
mask = norm > 0
|
|
@@ -554,7 +614,9 @@ def _torch_normalize_vectors(rr):
|
|
|
554
614
|
return new_rr
|
|
555
615
|
|
|
556
616
|
|
|
557
|
-
def _torch_legval(
|
|
617
|
+
def _torch_legval(
|
|
618
|
+
x: torch.Tensor, c: torch.Tensor, tensor: bool = True
|
|
619
|
+
) -> torch.Tensor:
|
|
558
620
|
"""
|
|
559
621
|
Evaluate a Legendre series at points x.
|
|
560
622
|
If `c` is of length `n + 1`, this function returns the value:
|
|
@@ -662,7 +724,9 @@ def _torch_legval(x, c, tensor=True):
|
|
|
662
724
|
return c0 + c1 * x
|
|
663
725
|
|
|
664
726
|
|
|
665
|
-
def _torch_calc_g(
|
|
727
|
+
def _torch_calc_g(
|
|
728
|
+
cosang: torch.Tensor, stiffness: float = 4, n_legendre_terms: int = 50
|
|
729
|
+
) -> torch.Tensor:
|
|
666
730
|
"""Calculate spherical spline g function between points on a sphere.
|
|
667
731
|
|
|
668
732
|
Parameters
|
|
@@ -718,23 +782,25 @@ def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
|
|
|
718
782
|
return _torch_legval(cosang, [0] + factors)
|
|
719
783
|
|
|
720
784
|
|
|
721
|
-
def _torch_make_interpolation_matrix(
|
|
785
|
+
def _torch_make_interpolation_matrix(
|
|
786
|
+
pos_from: torch.Tensor, pos_to: torch.Tensor, alpha: float = 1e-5
|
|
787
|
+
) -> torch.Tensor:
|
|
722
788
|
"""Compute interpolation matrix based on spherical splines.
|
|
723
789
|
|
|
724
790
|
Implementation based on [1]_
|
|
725
791
|
|
|
726
792
|
Parameters
|
|
727
793
|
----------
|
|
728
|
-
pos_from :
|
|
794
|
+
pos_from : torch.Tensor of float, shape(n_good_sensors, 3)
|
|
729
795
|
The positions to interpolate from.
|
|
730
|
-
pos_to :
|
|
796
|
+
pos_to : torch.Tensor of float, shape(n_bad_sensors, 3)
|
|
731
797
|
The positions to interpolate.
|
|
732
798
|
alpha : float
|
|
733
799
|
Regularization parameter. Defaults to 1e-5.
|
|
734
800
|
|
|
735
801
|
Returns
|
|
736
802
|
-------
|
|
737
|
-
interpolation :
|
|
803
|
+
interpolation : torch.Tensor of float, shape(len(pos_from), len(pos_to))
|
|
738
804
|
The interpolation matrix that maps good signals to the location
|
|
739
805
|
of bad signals.
|
|
740
806
|
|
|
@@ -822,7 +888,12 @@ def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
|
|
|
822
888
|
return interpolation
|
|
823
889
|
|
|
824
890
|
|
|
825
|
-
def _rotate_signals(
|
|
891
|
+
def _rotate_signals(
|
|
892
|
+
X: torch.Tensor,
|
|
893
|
+
rotations: list[torch.Tensor],
|
|
894
|
+
sensors_positions_matrix: torch.Tensor,
|
|
895
|
+
spherical: bool = True,
|
|
896
|
+
) -> torch.Tensor:
|
|
826
897
|
sensors_positions_matrix = sensors_positions_matrix.to(X.device)
|
|
827
898
|
rot_sensors_matrices = [
|
|
828
899
|
rotation.matmul(sensors_positions_matrix) for rotation in rotations
|
|
@@ -853,22 +924,29 @@ def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
|
|
|
853
924
|
return transformed_X
|
|
854
925
|
|
|
855
926
|
|
|
856
|
-
def _make_rotation_matrix(
|
|
927
|
+
def _make_rotation_matrix(
|
|
928
|
+
axis: Literal["x", "y", "z"],
|
|
929
|
+
angle: float | int | np.ndarray | list | torch.Tensor,
|
|
930
|
+
degrees: bool = True,
|
|
931
|
+
) -> torch.Tensor:
|
|
857
932
|
assert axis in ["x", "y", "z"], "axis should be either x, y or z."
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
933
|
+
if isinstance(angle, torch.Tensor):
|
|
934
|
+
_angle = angle
|
|
935
|
+
elif isinstance(angle, (float, int, np.ndarray, list)):
|
|
936
|
+
_angle = torch.as_tensor(angle)
|
|
937
|
+
else:
|
|
938
|
+
raise ValueError(f"Invalid angle type: {type(angle)}")
|
|
861
939
|
|
|
862
940
|
if degrees:
|
|
863
|
-
|
|
941
|
+
_angle = _angle * np.pi / 180
|
|
864
942
|
|
|
865
|
-
device =
|
|
943
|
+
device = _angle.device
|
|
866
944
|
zero = torch.zeros(1, device=device)
|
|
867
945
|
rot = torch.stack(
|
|
868
946
|
[
|
|
869
947
|
torch.as_tensor([1, 0, 0], device=device),
|
|
870
|
-
torch.hstack([zero, torch.cos(
|
|
871
|
-
torch.hstack([zero, torch.sin(
|
|
948
|
+
torch.hstack([zero, torch.cos(_angle), -torch.sin(_angle)]),
|
|
949
|
+
torch.hstack([zero, torch.sin(_angle), torch.cos(_angle)]),
|
|
872
950
|
]
|
|
873
951
|
)
|
|
874
952
|
if axis == "x":
|
|
@@ -881,7 +959,14 @@ def _make_rotation_matrix(axis, angle, degrees=True):
|
|
|
881
959
|
return rot[:, [1, 2, 0]]
|
|
882
960
|
|
|
883
961
|
|
|
884
|
-
def sensors_rotation(
|
|
962
|
+
def sensors_rotation(
|
|
963
|
+
X: torch.Tensor,
|
|
964
|
+
y: torch.Tensor,
|
|
965
|
+
sensors_positions_matrix: torch.Tensor,
|
|
966
|
+
axis: Literal["x", "y", "z"],
|
|
967
|
+
angles: npt.ArrayLike,
|
|
968
|
+
spherical_splines: bool,
|
|
969
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
885
970
|
"""Interpolates EEG signals over sensors rotated around the desired axis
|
|
886
971
|
with the desired angle.
|
|
887
972
|
|
|
@@ -893,7 +978,7 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_spl
|
|
|
893
978
|
EEG input example or batch.
|
|
894
979
|
y : torch.Tensor
|
|
895
980
|
EEG labels for the example or batch.
|
|
896
|
-
sensors_positions_matrix :
|
|
981
|
+
sensors_positions_matrix : torch.Tensor
|
|
897
982
|
Matrix giving the positions of each sensor in a 3D cartesian coordinate
|
|
898
983
|
system. Should have shape (3, n_channels), where n_channels is the
|
|
899
984
|
number of channels. Standard 10-20 positions can be obtained from
|
|
@@ -924,7 +1009,9 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_spl
|
|
|
924
1009
|
return rotated_X, y
|
|
925
1010
|
|
|
926
1011
|
|
|
927
|
-
def mixup(
|
|
1012
|
+
def mixup(
|
|
1013
|
+
X: torch.Tensor, y: torch.Tensor, lam: torch.Tensor, idx_perm: torch.Tensor
|
|
1014
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
928
1015
|
"""Mixes two channels of EEG data.
|
|
929
1016
|
|
|
930
1017
|
See [1]_ for details.
|
|
@@ -973,8 +1060,13 @@ def mixup(X, y, lam, idx_perm):
|
|
|
973
1060
|
|
|
974
1061
|
|
|
975
1062
|
def segmentation_reconstruction(
|
|
976
|
-
X
|
|
977
|
-
|
|
1063
|
+
X: torch.Tensor,
|
|
1064
|
+
y: torch.Tensor,
|
|
1065
|
+
n_segments: int,
|
|
1066
|
+
data_classes: list[tuple[int, torch.Tensor]],
|
|
1067
|
+
rand_indices: npt.ArrayLike,
|
|
1068
|
+
idx_shuffle: npt.ArrayLike,
|
|
1069
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
978
1070
|
"""Segment and reconstruct EEG data from [1]_.
|
|
979
1071
|
|
|
980
1072
|
See [1]_ for details.
|
|
@@ -987,6 +1079,8 @@ def segmentation_reconstruction(
|
|
|
987
1079
|
EEG labels for the example or batch.
|
|
988
1080
|
n_segments : int
|
|
989
1081
|
Number of segments to use in the batch.
|
|
1082
|
+
data_classes: list[tuple[int, torch.Tensor]]
|
|
1083
|
+
List of tuples. Each tuple contains the class index and the corresponding EEG data.
|
|
990
1084
|
rand_indices: array-like
|
|
991
1085
|
Array of indices that indicates which trial to use in each segment.
|
|
992
1086
|
idx_shuffle: array-like
|
|
@@ -1005,8 +1099,8 @@ def segmentation_reconstruction(
|
|
|
1005
1099
|
"""
|
|
1006
1100
|
|
|
1007
1101
|
# Initialize lists to store augmented data and corresponding labels
|
|
1008
|
-
aug_data = []
|
|
1009
|
-
aug_label = []
|
|
1102
|
+
aug_data: list[torch.Tensor] = []
|
|
1103
|
+
aug_label: list[torch.Tensor] = []
|
|
1010
1104
|
|
|
1011
1105
|
# Iterate through each class to separate and augment data
|
|
1012
1106
|
for class_index, X_class in data_classes:
|
|
@@ -1030,20 +1124,26 @@ def segmentation_reconstruction(
|
|
|
1030
1124
|
aug_data.append(X_aug)
|
|
1031
1125
|
aug_label.append(torch.full((n_trials,), class_index))
|
|
1032
1126
|
# Concatenate the augmented data and labels
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1127
|
+
concat_aug_data = torch.cat(aug_data, dim=0)
|
|
1128
|
+
concat_aug_data = concat_aug_data.to(dtype=X.dtype, device=X.device)
|
|
1129
|
+
concat_aug_data = concat_aug_data[idx_shuffle]
|
|
1036
1130
|
|
|
1037
1131
|
if y is not None:
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
return
|
|
1132
|
+
concat_label = torch.cat(aug_label, dim=0)
|
|
1133
|
+
concat_label = concat_label.to(dtype=y.dtype, device=y.device)
|
|
1134
|
+
concat_label = concat_label[idx_shuffle]
|
|
1135
|
+
return concat_aug_data, concat_label
|
|
1042
1136
|
|
|
1043
|
-
return
|
|
1137
|
+
return concat_aug_data, None
|
|
1044
1138
|
|
|
1045
1139
|
|
|
1046
|
-
def mask_encoding(
|
|
1140
|
+
def mask_encoding(
|
|
1141
|
+
X: torch.Tensor,
|
|
1142
|
+
y: torch.Tensor,
|
|
1143
|
+
time_start: torch.Tensor,
|
|
1144
|
+
segment_length: int,
|
|
1145
|
+
n_segments: int,
|
|
1146
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1047
1147
|
"""Mark encoding from Ding et al. (2024) from [ding2024]_.
|
|
1048
1148
|
|
|
1049
1149
|
Replaces a contiguous part (or parts) of all channels by zeros
|
|
@@ -161,10 +161,10 @@ class ChannelsDropout(Transform):
|
|
|
161
161
|
----------
|
|
162
162
|
probability: float
|
|
163
163
|
Float setting the probability of applying the operation.
|
|
164
|
-
|
|
164
|
+
p_drop: float | None, optional
|
|
165
165
|
Float between 0 and 1 setting the probability of dropping each channel.
|
|
166
166
|
Defaults to 0.2.
|
|
167
|
-
random_state: int | numpy.random.
|
|
167
|
+
random_state: int | numpy.random.RandomState, optional
|
|
168
168
|
Seed to be used to instantiate numpy random number generator instance.
|
|
169
169
|
Used to decide whether or not to transform given the probability
|
|
170
170
|
argument and to sample channels to erase. Defaults to None.
|
braindecode/datasets/__init__.py
CHANGED
|
@@ -2,7 +2,13 @@
|
|
|
2
2
|
Loader code for some datasets.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from .base import
|
|
5
|
+
from .base import (
|
|
6
|
+
BaseConcatDataset,
|
|
7
|
+
EEGWindowsDataset,
|
|
8
|
+
RawDataset,
|
|
9
|
+
RecordDataset,
|
|
10
|
+
WindowsDataset,
|
|
11
|
+
)
|
|
6
12
|
from .bcicomp import BCICompetitionIVDataset4
|
|
7
13
|
from .bids import BIDSDataset, BIDSEpochsDataset
|
|
8
14
|
from .mne import create_from_mne_epochs, create_from_mne_raw
|
|
@@ -15,7 +21,9 @@ from .xy import create_from_X_y
|
|
|
15
21
|
|
|
16
22
|
__all__ = [
|
|
17
23
|
"WindowsDataset",
|
|
18
|
-
"
|
|
24
|
+
"EEGWindowsDataset",
|
|
25
|
+
"RecordDataset",
|
|
26
|
+
"RawDataset",
|
|
19
27
|
"BaseConcatDataset",
|
|
20
28
|
"BIDSDataset",
|
|
21
29
|
"BIDSEpochsDataset",
|