braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,18 @@
1
1
  # Authors: Cédric Rommel <cedric.rommel@inria.fr>
2
2
  # Alexandre Gramfort <alexandre.gramfort@inria.fr>
3
+ # Gustavo Rodrigues <gustavenrique01@gmail.com>
3
4
  #
4
5
  # License: BSD (3-clause)
5
6
 
6
7
  from numbers import Real
7
8
 
8
9
  import numpy as np
10
+ import torch
11
+ from mne.filter import notch_filter
9
12
  from scipy.interpolate import Rbf
10
13
  from sklearn.utils import check_random_state
11
- import torch
12
14
  from torch.fft import fft, ifft
13
- from torch.nn.functional import pad, one_hot
14
- from mne.filter import notch_filter
15
+ from torch.nn.functional import one_hot, pad
15
16
 
16
17
 
17
18
  def identity(X, y):
@@ -79,11 +80,14 @@ def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
79
80
  random_phase = torch.from_numpy(
80
81
  2j * np.pi * rng.random((batch_size, c, (n - 1) // 2))
81
82
  ).to(device)
82
- return torch.cat([
83
- torch.zeros((batch_size, c, 1), device=device),
84
- random_phase,
85
- -torch.flip(random_phase, [-1])
86
- ], dim=-1)
83
+ return torch.cat(
84
+ [
85
+ torch.zeros((batch_size, c, 1), device=device),
86
+ random_phase,
87
+ -torch.flip(random_phase, [-1]),
88
+ ],
89
+ dim=-1,
90
+ )
87
91
 
88
92
 
89
93
  def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
@@ -91,27 +95,21 @@ def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
91
95
  random_phase = torch.from_numpy(
92
96
  2j * np.pi * rng.random((batch_size, c, n // 2 - 1))
93
97
  ).to(device)
94
- return torch.cat([
95
- torch.zeros((batch_size, c, 1), device=device),
96
- random_phase,
97
- torch.zeros((batch_size, c, 1), device=device),
98
- -torch.flip(random_phase, [-1])
99
- ], dim=-1)
100
-
101
-
102
- _new_random_fft_phase = {
103
- 0: _new_random_fft_phase_even,
104
- 1: _new_random_fft_phase_odd
105
- }
106
-
107
-
108
- def ft_surrogate(
109
- X,
110
- y,
111
- phase_noise_magnitude,
112
- channel_indep,
113
- random_state=None
114
- ):
98
+ return torch.cat(
99
+ [
100
+ torch.zeros((batch_size, c, 1), device=device),
101
+ random_phase,
102
+ torch.zeros((batch_size, c, 1), device=device),
103
+ -torch.flip(random_phase, [-1]),
104
+ ],
105
+ dim=-1,
106
+ )
107
+
108
+
109
+ _new_random_fft_phase = {0: _new_random_fft_phase_even, 1: _new_random_fft_phase_odd}
110
+
111
+
112
+ def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
115
113
  """FT surrogate augmentation of a single EEG channel, as proposed in [1]_.
116
114
 
117
115
  Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
@@ -148,11 +146,12 @@ def ft_surrogate(
148
146
  Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
149
147
  preprint arXiv:1806.08675.
150
148
  """
151
- assert isinstance(
152
- phase_noise_magnitude,
153
- (Real, torch.FloatTensor, torch.cuda.FloatTensor)
154
- ) and 0 <= phase_noise_magnitude <= 1, (
155
- f"eps must be a float between 0 and 1. Got {phase_noise_magnitude}.")
149
+ assert (
150
+ isinstance(
151
+ phase_noise_magnitude, (Real, torch.FloatTensor, torch.cuda.FloatTensor)
152
+ )
153
+ and 0 <= phase_noise_magnitude <= 1
154
+ ), f"eps must be a float between 0 and 1. Got {phase_noise_magnitude}."
156
155
 
157
156
  f = fft(X.double(), dim=-1)
158
157
  device = X.device
@@ -163,7 +162,7 @@ def ft_surrogate(
163
162
  f.shape[-2] if channel_indep else 1,
164
163
  n,
165
164
  device=device,
166
- random_state=random_state
165
+ random_state=random_state,
167
166
  )
168
167
  if not channel_indep:
169
168
  random_phase = torch.tile(random_phase, (1, f.shape[-2], 1))
@@ -186,7 +185,7 @@ def _pick_channels_randomly(X, p_pick, random_state):
186
185
  device=X.device,
187
186
  )
188
187
  # equivalent to a 0s and 1s mask
189
- return torch.sigmoid(1000*(unif_samples - p_pick))
188
+ return torch.sigmoid(1000 * (unif_samples - p_pick))
190
189
 
191
190
 
192
191
  def channels_dropout(X, y, p_drop, random_state=None):
@@ -234,7 +233,8 @@ def _make_permutation_matrix(X, mask, random_state):
234
233
  channels_to_shuffle = torch.arange(n_channels, device=X.device)
235
234
  channels_to_shuffle = channels_to_shuffle[mask.bool()]
236
235
  reordered_channels = torch.tensor(
237
- rng.permutation(channels_to_shuffle.cpu()), device=X.device)
236
+ rng.permutation(channels_to_shuffle.cpu()), device=X.device
237
+ )
238
238
  channels_permutation = torch.arange(n_channels, device=X.device)
239
239
  channels_permutation[channels_to_shuffle] = reordered_channels
240
240
  batch_permutations[b, ...] = one_hot(channels_permutation)
@@ -320,12 +320,14 @@ def gaussian_noise(X, y, std, random_state=None):
320
320
  rng = check_random_state(random_state)
321
321
  if isinstance(std, torch.Tensor):
322
322
  std = std.to(X.device)
323
- noise = torch.from_numpy(
324
- rng.normal(
325
- loc=np.zeros(X.shape),
326
- scale=1
327
- ),
328
- ).float().to(X.device) * std
323
+ noise = (
324
+ torch.from_numpy(
325
+ rng.normal(loc=np.zeros(X.shape), scale=1),
326
+ )
327
+ .float()
328
+ .to(X.device)
329
+ * std
330
+ )
329
331
  transformed_X = X + noise
330
332
  return transformed_X, y
331
333
 
@@ -399,9 +401,14 @@ def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
399
401
  t = t.repeat(batch_size, n_channels, 1)
400
402
  mask_start_per_sample = mask_start_per_sample.view(-1, 1, 1)
401
403
  s = 1000 / seq_len
402
- mask = (torch.sigmoid(s * -(t - mask_start_per_sample)) +
403
- torch.sigmoid(s * (t - mask_start_per_sample - mask_len_samples))
404
- ).float().to(X.device)
404
+ mask = (
405
+ (
406
+ torch.sigmoid(s * -(t - mask_start_per_sample))
407
+ + torch.sigmoid(s * (t - mask_start_per_sample - mask_len_samples))
408
+ )
409
+ .float()
410
+ .to(X.device)
411
+ )
405
412
  return X * mask, y
406
413
 
407
414
 
@@ -447,17 +454,18 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
447
454
  if bandwidth == 0:
448
455
  return X, y
449
456
  transformed_X = X.clone()
450
- for c, (sample, notched_freq) in enumerate(
451
- zip(transformed_X, freqs_to_notch)):
457
+ for c, (sample, notched_freq) in enumerate(zip(transformed_X, freqs_to_notch)):
452
458
  sample = sample.cpu().numpy().astype(np.float64)
453
- transformed_X[c] = torch.as_tensor(notch_filter(
454
- sample,
455
- Fs=sfreq,
456
- freqs=notched_freq,
457
- method='fir',
458
- notch_widths=bandwidth,
459
- verbose=False
460
- ))
459
+ transformed_X[c] = torch.as_tensor(
460
+ notch_filter(
461
+ sample,
462
+ Fs=sfreq,
463
+ freqs=notched_freq,
464
+ method="fir",
465
+ notch_widths=bandwidth,
466
+ verbose=False,
467
+ )
468
+ )
461
469
  return transformed_X, y
462
470
 
463
471
 
@@ -470,10 +478,10 @@ def _analytic_transform(x):
470
478
  h = torch.zeros_like(f)
471
479
  if N % 2 == 0:
472
480
  h[..., 0] = h[..., N // 2] = 1
473
- h[..., 1:N // 2] = 2
481
+ h[..., 1 : N // 2] = 2
474
482
  else:
475
483
  h[..., 0] = 1
476
- h[..., 1:(N + 1) // 2] = 2
484
+ h[..., 1 : (N + 1) // 2] = 2
477
485
 
478
486
  return ifft(f * h, dim=-1)
479
487
 
@@ -500,7 +508,8 @@ def _frequency_shift(X, fs, f_shift):
500
508
  f_shift = torch.as_tensor(f_shift).float()
501
509
  f_shift_stack = f_shift.repeat(N_padded, n_channels, 1)
502
510
  reshaped_f_shift = f_shift_stack.permute(
503
- *torch.arange(f_shift_stack.ndim - 1, -1, -1))
511
+ *torch.arange(f_shift_stack.ndim - 1, -1, -1)
512
+ )
504
513
  shifted = analytical * torch.exp(2j * np.pi * reshaped_f_shift * t)
505
514
  return shifted[..., :N_orig].real.float()
506
515
 
@@ -539,7 +548,7 @@ def frequency_shift(X, y, delta_freq, sfreq):
539
548
  def _torch_normalize_vectors(rr):
540
549
  """Normalize surface vertices."""
541
550
  norm = torch.linalg.norm(rr, axis=1, keepdim=True)
542
- mask = (norm > 0)
551
+ mask = norm > 0
543
552
  norm[~mask] = 1 # in case norm is zero, divide by 1
544
553
  new_rr = rr / norm
545
554
  return new_rr
@@ -631,7 +640,7 @@ def _torch_legval(x, c, tensor=True):
631
640
  if isinstance(x, (tuple, list)):
632
641
  x = torch.as_tensor(x)
633
642
  if isinstance(x, torch.Tensor) and tensor:
634
- c = c.view(c.shape + (1,)*x.ndim)
643
+ c = c.view(c.shape + (1,) * x.ndim)
635
644
 
636
645
  c = c.to(x.device)
637
646
 
@@ -648,9 +657,9 @@ def _torch_legval(x, c, tensor=True):
648
657
  for i in range(3, len(c) + 1):
649
658
  tmp = c0
650
659
  nd = nd - 1
651
- c0 = c[-i] - (c1*(nd - 1))/nd
652
- c1 = tmp + (c1*x*(2*nd - 1))/nd
653
- return c0 + c1*x
660
+ c0 = c[-i] - (c1 * (nd - 1)) / nd
661
+ c1 = tmp + (c1 * x * (2 * nd - 1)) / nd
662
+ return c0 + c1 * x
654
663
 
655
664
 
656
665
  def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
@@ -702,9 +711,10 @@ def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
702
711
  OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
703
712
  DAMAGE.
704
713
  """
705
- factors = [(2 * n + 1) / (n ** stiffness * (n + 1) ** stiffness *
706
- 4 * np.pi)
707
- for n in range(1, n_legendre_terms + 1)]
714
+ factors = [
715
+ (2 * n + 1) / (n**stiffness * (n + 1) ** stiffness * 4 * np.pi)
716
+ for n in range(1, n_legendre_terms + 1)
717
+ ]
708
718
  return _torch_legval(cosang, [0] + factors)
709
719
 
710
720
 
@@ -783,15 +793,20 @@ def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
783
793
  assert G_to_from.shape == (n_to, n_from)
784
794
 
785
795
  if alpha is not None:
786
- G_from.flatten()[::len(G_from) + 1] += alpha
796
+ G_from.flatten()[:: len(G_from) + 1] += alpha
787
797
 
788
798
  device = G_from.device
789
- C = torch.vstack([
799
+ C = torch.vstack(
800
+ [
790
801
  torch.hstack([G_from, torch.ones((n_from, 1), device=device)]),
791
- torch.hstack([
792
- torch.ones((1, n_from), device=device),
793
- torch.as_tensor([[0]], device=device)])
794
- ])
802
+ torch.hstack(
803
+ [
804
+ torch.ones((1, n_from), device=device),
805
+ torch.as_tensor([[0]], device=device),
806
+ ]
807
+ ),
808
+ ]
809
+ )
795
810
 
796
811
  try:
797
812
  C_inv = torch.linalg.inv(C)
@@ -800,10 +815,9 @@ def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
800
815
  # see https://github.com/pytorch/pytorch/issues/75494
801
816
  C_inv = torch.linalg.pinv(C.cpu()).to(device)
802
817
 
803
- interpolation = torch.hstack([
804
- G_to_from,
805
- torch.ones((n_to, 1), device=device)
806
- ]).matmul(C_inv[:, :-1])
818
+ interpolation = torch.hstack(
819
+ [G_to_from, torch.ones((n_to, 1), device=device)]
820
+ ).matmul(C_inv[:, :-1])
807
821
  assert interpolation.shape == (n_to, n_from)
808
822
  return interpolation
809
823
 
@@ -815,11 +829,15 @@ def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
815
829
  ]
816
830
  if spherical:
817
831
  interpolation_matrix = torch.stack(
818
- [torch.as_tensor(
819
- _torch_make_interpolation_matrix(
820
- sensors_positions_matrix.T, rot_sensors_matrix.T
821
- ), device=X.device
822
- ).float() for rot_sensors_matrix in rot_sensors_matrices]
832
+ [
833
+ torch.as_tensor(
834
+ _torch_make_interpolation_matrix(
835
+ sensors_positions_matrix.T, rot_sensors_matrix.T
836
+ ),
837
+ device=X.device,
838
+ ).float()
839
+ for rot_sensors_matrix in rot_sensors_matrices
840
+ ]
823
841
  )
824
842
  return torch.matmul(interpolation_matrix, X)
825
843
  else:
@@ -836,7 +854,7 @@ def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
836
854
 
837
855
 
838
856
  def _make_rotation_matrix(axis, angle, degrees=True):
839
- assert axis in ['x', 'y', 'z'], "axis should be either x, y or z."
857
+ assert axis in ["x", "y", "z"], "axis should be either x, y or z."
840
858
 
841
859
  if isinstance(angle, (float, int, np.ndarray, list)):
842
860
  angle = torch.as_tensor(angle)
@@ -846,11 +864,13 @@ def _make_rotation_matrix(axis, angle, degrees=True):
846
864
 
847
865
  device = angle.device
848
866
  zero = torch.zeros(1, device=device)
849
- rot = torch.stack([
850
- torch.as_tensor([1, 0, 0], device=device),
851
- torch.hstack([zero, torch.cos(angle), -torch.sin(angle)]),
852
- torch.hstack([zero, torch.sin(angle), torch.cos(angle)]),
853
- ])
867
+ rot = torch.stack(
868
+ [
869
+ torch.as_tensor([1, 0, 0], device=device),
870
+ torch.hstack([zero, torch.cos(angle), -torch.sin(angle)]),
871
+ torch.hstack([zero, torch.sin(angle), torch.cos(angle)]),
872
+ ]
873
+ )
854
874
  if axis == "x":
855
875
  return rot
856
876
  elif axis == "y":
@@ -861,8 +881,7 @@ def _make_rotation_matrix(axis, angle, degrees=True):
861
881
  return rot[:, [1, 2, 0]]
862
882
 
863
883
 
864
- def sensors_rotation(X, y, sensors_positions_matrix, axis, angles,
865
- spherical_splines):
884
+ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_splines):
866
885
  """Interpolates EEG signals over sensors rotated around the desired axis
867
886
  with the desired angle.
868
887
 
@@ -900,13 +919,8 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles,
900
919
  Conference of the IEEE Engineering in Medicine and Biology Society
901
920
  (EMBC) (pp. 471-474).
902
921
  """
903
- rots = [
904
- _make_rotation_matrix(axis, angle, degrees=True)
905
- for angle in angles
906
- ]
907
- rotated_X = _rotate_signals(
908
- X, rots, sensors_positions_matrix, spherical_splines
909
- )
922
+ rots = [_make_rotation_matrix(axis, angle, degrees=True) for angle in angles]
923
+ rotated_X = _rotate_signals(X, rots, sensors_positions_matrix, spherical_splines)
910
924
  return rotated_X, y
911
925
 
912
926
 
@@ -942,7 +956,7 @@ def mixup(X, y, lam, idx_perm):
942
956
  International Conference on Learning Representations (ICLR)
943
957
  Online: https://arxiv.org/abs/1710.09412
944
958
  .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
945
- """
959
+ """
946
960
  device = X.device
947
961
  batch_size, n_channels, n_times = X.shape
948
962
 
@@ -951,9 +965,132 @@ def mixup(X, y, lam, idx_perm):
951
965
  y_b = torch.arange(batch_size).to(device)
952
966
 
953
967
  for idx in range(batch_size):
954
- X_mix[idx] = lam[idx] * X[idx] \
955
- + (1 - lam[idx]) * X[idx_perm[idx]]
968
+ X_mix[idx] = lam[idx] * X[idx] + (1 - lam[idx]) * X[idx_perm[idx]]
956
969
  y_a[idx] = y[idx]
957
970
  y_b[idx] = y[idx_perm[idx]]
958
971
 
959
972
  return X_mix, (y_a, y_b, lam)
973
+
974
+
975
+ def segmentation_reconstruction(
976
+ X, y, n_segments, data_classes, rand_indices, idx_shuffle
977
+ ):
978
+ """Segment and reconstruct EEG data from [1]_.
979
+
980
+ See [1]_ for details.
981
+
982
+ Parameters
983
+ ----------
984
+ X : torch.Tensor
985
+ EEG input example or batch.
986
+ y : torch.Tensor
987
+ EEG labels for the example or batch.
988
+ n_segments : int
989
+ Number of segments to use in the batch.
990
+ rand_indices: array-like
991
+ Array of indices that indicates which trial to use in each segment.
992
+ idx_shuffle: array-like
993
+ Array of indices to shuffle the new generated trials.
994
+ Returns
995
+ -------
996
+ torch.Tensor
997
+ Transformed inputs.
998
+ torch.Tensor
999
+ Transformed labels.
1000
+ References
1001
+ ----------
1002
+ .. [1] Lotte, F. (2015). Signal processing approaches to minimize or
1003
+ suppress calibration time in oscillatory activity-based brain–computer
1004
+ interfaces. Proceedings of the IEEE, 103(6), 871-890.
1005
+ """
1006
+
1007
+ # Initialize lists to store augmented data and corresponding labels
1008
+ aug_data = []
1009
+ aug_label = []
1010
+
1011
+ # Iterate through each class to separate and augment data
1012
+ for class_index, X_class in data_classes:
1013
+ # Determine class-specific dimensions
1014
+ # Store the augmented data and the corresponding class labels
1015
+ n_trials, n_channels, window_size = X_class.shape
1016
+ # Segment Size
1017
+ segment_size = window_size // n_segments
1018
+ # Initialize an empty tensor for augmented data
1019
+ X_aug = torch.zeros_like(X_class)
1020
+ # Generate random indices within the class-specific dataset
1021
+ rand_idx = rand_indices[class_index]
1022
+ for idx_segment in range(n_segments):
1023
+ start = idx_segment * segment_size
1024
+ end = (idx_segment + 1) * segment_size
1025
+
1026
+ # Perform the data augmentation
1027
+ X_aug[np.arange(n_trials), :, start:end] = X_class[
1028
+ rand_idx[:, idx_segment], :, start:end
1029
+ ]
1030
+ aug_data.append(X_aug)
1031
+ aug_label.append(torch.full((n_trials,), class_index))
1032
+ # Concatenate the augmented data and labels
1033
+ aug_data = torch.cat(aug_data, dim=0)
1034
+ aug_data = aug_data.to(dtype=X.dtype, device=X.device)
1035
+ aug_data = aug_data[idx_shuffle]
1036
+
1037
+ if y is not None:
1038
+ aug_label = torch.cat(aug_label, dim=0)
1039
+ aug_label = aug_label.to(dtype=y.dtype, device=y.device)
1040
+ aug_label = aug_label[idx_shuffle]
1041
+ return aug_data, aug_label
1042
+
1043
+ return aug_data, y
1044
+
1045
+
1046
+ def mask_encoding(X, y, time_start, segment_length, n_segments):
1047
+ """Mask encoding from Ding et al. (2024) [ding2024]_.
1048
+
1049
+ Replaces a contiguous part (or parts) of all channels by zeros
1050
+ (if more than one segment, they may overlap).
1051
+
1052
+ Implementation based on [ding2024]_
1053
+
1054
+ Parameters
1055
+ ----------
1056
+ X : torch.Tensor
1057
+ EEG input example or batch.
1058
+ y : torch.Tensor
1059
+ EEG labels for the example or batch.
1060
+ time_start : torch.Tensor
1061
+ Tensor of integers containing the position (in last dimension) where to
1062
+ start masking the signal. Should have "n_segments" times the size of the first
1063
+ dimension of X (i.e. "n_segments" start positions per example in the batch).
1064
+ segment_length : int
1065
+ Length of each segment to zero out.
1066
+ n_segments : int
1067
+ Number of segments to zero out in each example.
1068
+
1069
+ Returns
1070
+ -------
1071
+ torch.Tensor
1072
+ Transformed inputs.
1073
+ torch.Tensor
1074
+ Transformed labels.
1075
+
1076
+ References
1077
+ ----------
1078
+ .. [ding2024] Ding, Wenlong, et al. A Novel Data Augmentation Approach
1079
+ Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI.
1080
+ IEEE Transactions on Neural Systems and Rehabilitation Engineering
1081
+ 32 (2024): 875-886.
1082
+ """
1083
+
1084
+ batch_indices = torch.arange(X.shape[0]).repeat_interleave(n_segments)
1085
+ start_indices = time_start.flatten()
1086
+ mask_indices = start_indices[:, None] + torch.arange(segment_length)
1087
+
1088
+ # Create a boolean mask with the same shape as X
1089
+ mask = torch.zeros_like(X, dtype=torch.bool)
1090
+ for batch_index, grouped_mask_indices in zip(batch_indices, mask_indices):
1091
+ mask[batch_index, :, grouped_mask_indices] = True
1092
+
1093
+ # Apply the mask to set the values to 0
1094
+ X[mask] = 0
1095
+
1096
+ return X, y # Return the masked tensor and labels