braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (108):
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/augmentation/transforms.py

@@ -1,5 +1,6 @@
 # Authors: Cédric Rommel <cedric.rommel@inria.fr>
 #          Alexandre Gramfort <alexandre.gramfort@inria.fr>
+#          Gustavo Rodrigues <gustavenrique01@gmail.com>
 #
 # License: BSD (3-clause)

@@ -11,18 +12,22 @@ import torch
 from mne.channels import make_standard_montage

 from .base import Transform
-from .functional import bandstop_filter
-from .functional import channels_dropout
-from .functional import channels_permute
-from .functional import channels_shuffle
-from .functional import frequency_shift
-from .functional import ft_surrogate
-from .functional import gaussian_noise
-from .functional import mixup
-from .functional import sensors_rotation
-from .functional import sign_flip
-from .functional import smooth_time_mask
-from .functional import time_reverse
+from .functional import (
+    bandstop_filter,
+    channels_dropout,
+    channels_permute,
+    channels_shuffle,
+    frequency_shift,
+    ft_surrogate,
+    gaussian_noise,
+    mask_encoding,
+    mixup,
+    segmentation_reconstruction,
+    sensors_rotation,
+    sign_flip,
+    smooth_time_mask,
+    time_reverse,
+)


 class TimeReverse(Transform):
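Note on the refactored import block above: it also surfaces the two new functional augmentations, mask_encoding and segmentation_reconstruction, whose Transform wrappers are appended at the end of this file. A minimal usage sketch of the class-based API this file defines (shapes and seed are illustrative, and it assumes the base Transform returns (X, y) when labels are passed):

    import torch
    from braindecode.augmentation import TimeReverse

    X = torch.randn(8, 22, 1000)   # hypothetical batch (n_examples, n_channels, n_times)
    y = torch.randint(0, 4, (8,))  # hypothetical labels

    transform = TimeReverse(probability=0.5, random_state=87)
    X_aug, y_aug = transform(X, y)  # each example is time-flipped with prob. 0.5
    assert X_aug.shape == X.shape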
@@ -37,7 +42,8 @@ class TimeReverse(Transform):
         Used to decide whether or not to transform given the probability
         argument. Defaults to None.
     """
-    operation = staticmethod(time_reverse)
+
+    operation = staticmethod(time_reverse)  # type: ignore[assignment]

     def __init__(
         self,
@@ -62,17 +68,11 @@ class SignFlip(Transform):
         Used to decide whether or not to transform given the probability
         argument. Defaults to None.
     """
-    operation = staticmethod(sign_flip)

-    def __init__(
-        self,
-        probability,
-        random_state=None
-    ):
-        super().__init__(
-            probability=probability,
-            random_state=random_state
-        )
+    operation = staticmethod(sign_flip)  # type: ignore[assignment]
+
+    def __init__(self, probability, random_state=None):
+        super().__init__(probability=probability, random_state=random_state)


 class FTSurrogate(Transform):
@@ -102,25 +102,26 @@
         Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
         preprint arXiv:1806.08675.
     """
-    operation = staticmethod(ft_surrogate)
+
+    operation = staticmethod(ft_surrogate)  # type: ignore[assignment]

     def __init__(
         self,
         probability,
         phase_noise_magnitude=1,
         channel_indep=False,
-        random_state=None
+        random_state=None,
     ):
-        super().__init__(
-            probability=probability,
-            random_state=random_state
-        )
-        assert isinstance(phase_noise_magnitude, (float, int, torch.Tensor)), \
+        super().__init__(probability=probability, random_state=random_state)
+        assert isinstance(phase_noise_magnitude, (float, int, torch.Tensor)), (
             "phase_noise_magnitude should be a float."
-        assert 0 <= phase_noise_magnitude <= 1, \
+        )
+        assert 0 <= phase_noise_magnitude <= 1, (
             "phase_noise_magnitude should be between 0 and 1."
+        )
         assert isinstance(channel_indep, bool), (
-            "channel_indep is expected to be a boolean")
+            "channel_indep is expected to be a boolean"
+        )
         self.phase_noise_magnitude = phase_noise_magnitude
         self.channel_indep = channel_indep

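The FTSurrogate changes above are formatting-only; the underlying ft_surrogate operation is phase randomization: keep each channel's Fourier magnitudes and perturb the phases. A self-contained sketch of that idea (not braindecode's exact implementation):

    import torch

    def phase_surrogate(X, phase_noise_magnitude=1.0, seed=0):
        # Randomize spectral phases while preserving magnitudes.
        g = torch.Generator().manual_seed(seed)
        F = torch.fft.fft(X, dim=-1)
        # Phase noise in [0, phase_noise_magnitude * 2*pi), per frequency bin.
        noise = phase_noise_magnitude * 2 * torch.pi * torch.rand(F.shape, generator=g)
        return torch.fft.ifft(F * torch.exp(1j * noise), dim=-1).real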
@@ -174,18 +175,11 @@ class ChannelsDropout(Transform):
         Learning from Heterogeneous EEG Signals with Differentiable Channel
         Reordering. arXiv preprint arXiv:2010.13694.
     """
-    operation = staticmethod(channels_dropout)

-    def __init__(
-        self,
-        probability,
-        p_drop=0.2,
-        random_state=None
-    ):
-        super().__init__(
-            probability=probability,
-            random_state=random_state
-        )
+    operation = staticmethod(channels_dropout)  # type: ignore[assignment]
+
+    def __init__(self, probability, p_drop=0.2, random_state=None):
+        super().__init__(probability=probability, random_state=random_state)
         self.p_drop = p_drop

     def get_augmentation_params(self, *batch):
@@ -239,18 +233,11 @@ class ChannelsShuffle(Transform):
         Learning from Heterogeneous EEG Signals with Differentiable Channel
         Reordering. arXiv preprint arXiv:2010.13694.
     """
-    operation = staticmethod(channels_shuffle)

-    def __init__(
-        self,
-        probability,
-        p_shuffle=0.2,
-        random_state=None
-    ):
-        super().__init__(
-            probability=probability,
-            random_state=random_state
-        )
+    operation = staticmethod(channels_shuffle)  # type: ignore[assignment]
+
+    def __init__(self, probability, p_shuffle=0.2, random_state=None):
+        super().__init__(probability=probability, random_state=random_state)
         self.p_shuffle = p_shuffle

     def get_augmentation_params(self, *batch):
@@ -308,14 +295,10 @@ class GaussianNoise(Transform):
         Representation Learning for Electroencephalogram Classification. In
         Machine Learning for Health (pp. 238-253). PMLR.
     """
-    operation = staticmethod(gaussian_noise)

-    def __init__(
-        self,
-        probability,
-        std=0.1,
-        random_state=None
-    ):
+    operation = staticmethod(gaussian_noise)  # type: ignore[assignment]
+
+    def __init__(self, probability, std=0.1, random_state=None):
         super().__init__(
             probability=probability,
             random_state=random_state,
@@ -373,28 +356,23 @@ class ChannelsSymmetry(Transform):
         (2018). HAMLET: interpretable human and machine co-learning technique.
         arXiv preprint arXiv:1803.09702.
     """
-    operation = staticmethod(channels_permute)

-    def __init__(
-        self,
-        probability,
-        ordered_ch_names,
-        random_state=None
-    ):
+    operation = staticmethod(channels_permute)  # type: ignore[assignment]
+
+    def __init__(self, probability, ordered_ch_names, random_state=None):
         super().__init__(
             probability=probability,
             random_state=random_state,
         )
-        assert (
-            isinstance(ordered_ch_names, list) and
-            all(isinstance(ch, str) for ch in ordered_ch_names)
+        assert isinstance(ordered_ch_names, list) and all(
+            isinstance(ch, str) for ch in ordered_ch_names
         ), "ordered_ch_names should be a list of str."

         permutation = list()
         for idx, ch_name in enumerate(ordered_ch_names):
             new_position = idx
             # Find digits in channel name (assuming 10-20 system)
-            d = ''.join(list(filter(str.isdigit, ch_name)))
+            d = "".join(list(filter(str.isdigit, ch_name)))
             if len(d) > 0:
                 d = int(d)
                 if d % 2 == 0:  # pair/right electrodes
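The digit-parity loop above builds the left/right mirror permutation used by channels_permute. A worked check of the rule with a hypothetical helper (same logic: odd trailing digits are left-hemisphere, even are right, midline names carry no digits):

    def mirror(name):
        d = "".join(filter(str.isdigit, name))
        if not d:
            return name  # midline electrode, e.g. "Cz", maps to itself
        n = int(d)
        return name.replace(d, str(n + 1 if n % 2 else n - 1))

    print([mirror(ch) for ch in ["C3", "Cz", "C4"]])  # -> ['C4', 'Cz', 'C3']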
@@ -454,22 +432,17 @@ class SmoothTimeMask(Transform):
         Representation Learning for Electroencephalogram Classification. In
         Machine Learning for Health (pp. 238-253). PMLR.
     """
-    operation = staticmethod(smooth_time_mask)

-    def __init__(
-        self,
-        probability,
-        mask_len_samples=100,
-        random_state=None
-    ):
+    operation = staticmethod(smooth_time_mask)  # type: ignore[assignment]
+
+    def __init__(self, probability, mask_len_samples=100, random_state=None):
         super().__init__(
             probability=probability,
             random_state=random_state,
         )

         assert (
-            isinstance(mask_len_samples, (int, torch.Tensor)) and
-            mask_len_samples > 0
+            isinstance(mask_len_samples, (int, torch.Tensor)) and mask_len_samples > 0
         ), "mask_len_samples has to be a positive integer"
         self.mask_len_samples = mask_len_samples

@@ -504,9 +477,14 @@ class SmoothTimeMask(Transform):
         mask_len_samples = self.mask_len_samples
         if isinstance(mask_len_samples, torch.Tensor):
             mask_len_samples = mask_len_samples.to(X.device)
-        mask_start = torch.as_tensor(self.rng.uniform(
-            low=0, high=1, size=X.shape[0],
-        ), device=X.device) * (seq_length - mask_len_samples)
+        mask_start = torch.as_tensor(
+            self.rng.uniform(
+                low=0,
+                high=1,
+                size=X.shape[0],
+            ),
+            device=X.device,
+        ) * (seq_length - mask_len_samples)
         return {
             "mask_start_per_sample": mask_start,
             "mask_len_samples": mask_len_samples,
@@ -546,27 +524,26 @@ class BandstopFilter(Transform):
         Representation Learning for Electroencephalogram Classification. In
         Machine Learning for Health (pp. 238-253). PMLR.
     """
-    operation = staticmethod(bandstop_filter)
+
+    operation = staticmethod(bandstop_filter)  # type: ignore[assignment]

     def __init__(
-        self,
-        probability,
-        sfreq,
-        bandwidth=1,
-        max_freq=None,
-        random_state=None
+        self, probability, sfreq, bandwidth=1, max_freq=None, random_state=None
     ):
         super().__init__(
             probability=probability,
             random_state=random_state,
         )
-        assert isinstance(bandwidth, Real) and bandwidth >= 0, \
+        assert isinstance(bandwidth, Real) and bandwidth >= 0, (
             "bandwidth should be a non-negative float."
-        assert isinstance(sfreq, Real) and sfreq > 0, \
+        )
+        assert isinstance(sfreq, Real) and sfreq > 0, (
             "sfreq should be a positive float."
+        )
         if max_freq is not None:
-            assert isinstance(max_freq, Real) and max_freq > 0, \
+            assert isinstance(max_freq, Real) and max_freq > 0, (
                 "max_freq should be a positive float."
+            )
         nyq = sfreq / 2
         if max_freq is None or max_freq > nyq:
             max_freq = nyq
@@ -575,8 +552,9 @@ class BandstopFilter(Transform):
                 f" Nyquist frequency ({nyq} Hz)."
                 f" Falling back to max_freq = {nyq}."
             )
-        assert bandwidth < max_freq, \
+        assert bandwidth < max_freq, (
             f"`bandwidth` needs to be smaller than max_freq={max_freq}"
+        )

         # override bandwidth value when a magnitude is passed
         self.sfreq = sfreq
@@ -619,7 +597,7 @@ class BandstopFilter(Transform):
         notched_freqs = self.rng.uniform(
             low=1 + 2 * self.bandwidth,
             high=self.max_freq - 1 - 2 * self.bandwidth,
-            size=X.shape[0]
+            size=X.shape[0],
         )
         return {
             "sfreq": self.sfreq,
@@ -646,21 +624,17 @@ class FrequencyShift(Transform):
         Seed to be used to instantiate numpy random number generator instance.
         Defaults to None.
     """
-    operation = staticmethod(frequency_shift)

-    def __init__(
-        self,
-        probability,
-        sfreq,
-        max_delta_freq=2,
-        random_state=None
-    ):
+    operation = staticmethod(frequency_shift)  # type: ignore[assignment]
+
+    def __init__(self, probability, sfreq, max_delta_freq=2, random_state=None):
         super().__init__(
             probability=probability,
             random_state=random_state,
         )
-        assert isinstance(sfreq, Real) and sfreq > 0, \
+        assert isinstance(sfreq, Real) and sfreq > 0, (
             "sfreq should be a positive float."
+        )
         self.sfreq = sfreq

         self.max_delta_freq = max_delta_freq
@@ -689,10 +663,7 @@ class FrequencyShift(Transform):
             return super().get_augmentation_params(*batch)
         X = batch[0]

-        u = torch.as_tensor(
-            self.rng.uniform(size=X.shape[0]),
-            device=X.device
-        )
+        u = torch.as_tensor(self.rng.uniform(size=X.shape[0]), device=X.device)
         max_delta_freq = self.max_delta_freq
         if isinstance(max_delta_freq, torch.Tensor):
             max_delta_freq = max_delta_freq.to(X.device)
@@ -718,12 +689,13 @@ def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
         matrices that will be fed to `SensorsRotation` transform. By
         default None.
     """
-    assert raw_or_epoch is not None or ordered_ch_names is not None, \
+    assert raw_or_epoch is not None or ordered_ch_names is not None, (
         "At least one of raw_or_epoch and ordered_ch_names is needed."
+    )
     if ordered_ch_names is None:
-        ordered_ch_names = raw_or_epoch.info['ch_names']
-    ten_twenty_montage = make_standard_montage('standard_1020')
-    positions_dict = ten_twenty_montage.get_positions()['ch_pos']
+        ordered_ch_names = raw_or_epoch.info["ch_names"]
+    ten_twenty_montage = make_standard_montage("standard_1020")
+    positions_dict = ten_twenty_montage.get_positions()["ch_pos"]
     positions_subdict = {
         k: positions_dict[k] for k in ordered_ch_names if k in positions_dict
     }
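The quote-style changes in the helper above do not alter behavior: it still reads 3D electrode coordinates from MNE's built-in standard_1020 montage to build the rotation matrices. A minimal check (mne is a braindecode dependency):

    from mne.channels import make_standard_montage

    pos = make_standard_montage("standard_1020").get_positions()["ch_pos"]
    print(pos["Cz"])  # 3D head-frame coordinates used by SensorsRotation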
@@ -770,37 +742,38 @@ class SensorsRotation(Transform):
         Conference of the IEEE Engineering in Medicine and Biology Society
         (EMBC) (pp. 471-474).
     """
-    operation = staticmethod(sensors_rotation)
+
+    operation = staticmethod(sensors_rotation)  # type: ignore[assignment]

     def __init__(
         self,
         probability,
         sensors_positions_matrix,
-        axis='z',
+        axis="z",
         max_degrees=15,
         spherical_splines=True,
-        random_state=None
+        random_state=None,
     ):
-        super().__init__(
-            probability=probability,
-            random_state=random_state
-        )
+        super().__init__(probability=probability, random_state=random_state)
         if isinstance(sensors_positions_matrix, (np.ndarray, list)):
-            sensors_positions_matrix = torch.as_tensor(
-                sensors_positions_matrix
-            )
-        assert isinstance(sensors_positions_matrix, torch.Tensor), \
+            sensors_positions_matrix = torch.as_tensor(sensors_positions_matrix)
+        assert isinstance(sensors_positions_matrix, torch.Tensor), (
             "sensors_positions should be an Tensor"
-        assert (
-            isinstance(max_degrees, (Real, torch.Tensor)) and
-            max_degrees >= 0
-        ), "max_degrees should be non-negative float."
-        assert isinstance(axis, str) and axis in ['x', 'y', 'z'], \
-            "axis can be either x, y or z."
-        assert sensors_positions_matrix.shape[0] == 3, \
+        )
+        assert isinstance(max_degrees, (Real, torch.Tensor)) and max_degrees >= 0, (
+            "max_degrees should be non-negative float."
+        )
+        assert isinstance(axis, str) and axis in [
+            "x",
+            "y",
+            "z",
+        ], "axis can be either x, y or z."
+        assert sensors_positions_matrix.shape[0] == 3, (
             "sensors_positions_matrix shape should be 3 x n_channels."
-        assert isinstance(spherical_splines, bool), \
+        )
+        assert isinstance(spherical_splines, bool), (
             "spherical_splines should be a boolean"
+        )
         self.sensors_positions_matrix = sensors_positions_matrix
         self.axis = axis
         self.spherical_splines = spherical_splines
@@ -841,21 +814,18 @@ class SensorsRotation(Transform):
             return super().get_augmentation_params(*batch)
         X = batch[0]

-        u = self.rng.uniform(
-            low=0,
-            high=1,
-            size=X.shape[0]
-        )
+        u = self.rng.uniform(low=0, high=1, size=X.shape[0])
         max_degrees = self.max_degrees
         if isinstance(max_degrees, torch.Tensor):
             max_degrees = max_degrees.to(X.device)
-        random_angles = torch.as_tensor(
-            u, device=X.device) * 2 * max_degrees - max_degrees
+        random_angles = (
+            torch.as_tensor(u, device=X.device) * 2 * max_degrees - max_degrees
+        )
         return {
             "sensors_positions_matrix": self.sensors_positions_matrix,
             "axis": self.axis,
             "angles": random_angles,
-            "spherical_splines": self.spherical_splines
+            "spherical_splines": self.spherical_splines,
         }


@@ -900,7 +870,7 @@ class SensorsZRotation(SensorsRotation):
         ordered_ch_names,
         max_degrees=15,
         spherical_splines=True,
-        random_state=None
+        random_state=None,
     ):
         sensors_positions_matrix = torch.as_tensor(
             _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
@@ -908,10 +878,10 @@ class SensorsZRotation(SensorsRotation):
         super().__init__(
             probability=probability,
             sensors_positions_matrix=sensors_positions_matrix,
-            axis='z',
+            axis="z",
             max_degrees=max_degrees,
             spherical_splines=spherical_splines,
-            random_state=random_state
+            random_state=random_state,
         )


@@ -956,7 +926,7 @@ class SensorsYRotation(SensorsRotation):
         ordered_ch_names,
         max_degrees=15,
         spherical_splines=True,
-        random_state=None
+        random_state=None,
     ):
         sensors_positions_matrix = torch.as_tensor(
             _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
@@ -964,10 +934,10 @@ class SensorsYRotation(SensorsRotation):
         super().__init__(
             probability=probability,
             sensors_positions_matrix=sensors_positions_matrix,
-            axis='y',
+            axis="y",
             max_degrees=max_degrees,
             spherical_splines=spherical_splines,
-            random_state=random_state
+            random_state=random_state,
         )


@@ -1012,7 +982,7 @@ class SensorsXRotation(SensorsRotation):
         ordered_ch_names,
         max_degrees=15,
         spherical_splines=True,
-        random_state=None
+        random_state=None,
     ):
         sensors_positions_matrix = torch.as_tensor(
             _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
@@ -1020,10 +990,10 @@ class SensorsXRotation(SensorsRotation):
         super().__init__(
             probability=probability,
             sensors_positions_matrix=sensors_positions_matrix,
-            axis='x',
+            axis="x",
             max_degrees=max_degrees,
             spherical_splines=spherical_splines,
-            random_state=random_state
+            random_state=random_state,
         )


@@ -1050,17 +1020,13 @@ class Mixup(Transform):
         Online: https://arxiv.org/abs/1710.09412
     .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
     """
-    operation = staticmethod(mixup)

-    def __init__(
-        self,
-        alpha,
-        beta_per_sample=False,
-        random_state=None
-    ):
+    operation = staticmethod(mixup)  # type: ignore[assignment]
+
+    def __init__(self, alpha, beta_per_sample=False, random_state=None):
         super().__init__(
             probability=1.0,  # Mixup has to be applied to whole batches
-            random_state=random_state
+            random_state=random_state,
         )
         self.alpha = alpha
         self.beta_per_sample = beta_per_sample
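The lam and idx_perm parameters returned below feed the usual mixup convex combination, as in the referenced mixup-cifar10 code [2]. A sketch of that mixing step under the standard formulation (not necessarily the exact mixup function in braindecode.augmentation.functional):

    import torch

    def mix_batch(X, y, lam, idx_perm):
        # Convex-combine each example with a permuted partner; the loss is
        # typically mixed with the same per-example lam.
        lam = lam.view(-1, 1, 1)  # broadcast over (n_channels, n_times)
        X_mix = lam * X + (1 - lam) * X[idx_perm]
        return X_mix, (y, y[idx_perm])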
@@ -1098,9 +1064,210 @@ class Mixup(Transform):
         else:
             lam = torch.ones(batch_size).to(device)

-        idx_perm = torch.as_tensor(self.rng.permutation(batch_size,))
+        idx_perm = torch.as_tensor(
+            self.rng.permutation(
+                batch_size,
+            )
+        )

         return {
             "lam": lam,
             "idx_perm": idx_perm,
         }
+
+
+class SegmentationReconstruction(Transform):
+    """Segmentation Reconstruction from Lotte (2015) [Lotte2015]_.
+
+    Applies a segmentation-reconstruction transform to the input data, as
+    proposed in [Lotte2015]_. It segments each trial in the batch and randomly mix
+    it to generate new synthetic trials by label, preserving the original
+    order of the segments in time domain.
+
+    Parameters
+    ----------
+    probability : float
+        Float setting the probability of applying the operation.
+    random_state: int | numpy.random.Generator, optional
+        Seed to be used to instantiate numpy random number generator instance.
+        Used to decide whether to transform given the probability
+        argument and to sample the segments mixing. Defaults to None.
+    n_segments : int, optional
+        Number of segments to use in the batch. If None, X will be
+        automatically segmented, getting the last element in a list
+        of factors of the number of samples's square root. Defaults to None.
+
+    References
+    ----------
+    .. [Lotte2015] Lotte, F. (2015). Signal processing approaches to minimize
+        or suppress calibration time in oscillatory activity-based brain–computer
+        interfaces. Proceedings of the IEEE, 103(6), 871-890.
+    """
+
+    operation = staticmethod(segmentation_reconstruction)  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        probability,
+        n_segments=None,
+        random_state=None,
+    ):
+        super().__init__(
+            probability=probability,
+            random_state=random_state,
+        )
+        self.n_segments = n_segments
+
+    def get_augmentation_params(self, *batch):
+        """Return transform parameters.
+
+        Parameters
+        ----------
+        X : tensor.Tensor
+            The data.
+        y : tensor.Tensor
+            The labels.
+        Returns
+        -------
+        params : dict
+            Contains the number of segments to split the signal into.
+        """
+        X, y = batch[0], batch[1]
+
+        if y is not None:
+            if not isinstance(X, torch.Tensor) or not isinstance(y, torch.Tensor):
+                raise ValueError("X and y must be torch tensors.")
+
+            if X.shape[0] != y.shape[0]:
+                raise ValueError("Number of samples in X and y must be the same.")
+
+        if self.n_segments is None:
+            self.n_segments = int(X.shape[2])
+            n_segments_list = []
+            for i in range(1, int(self.n_segments**0.5) + 1):
+                if self.n_segments % i == 0:
+                    n_segments_list.append(i)
+            self.n_segments = n_segments_list[-1]
+
+        elif not (
+            isinstance(self.n_segments, (int, float))
+            and 1 <= self.n_segments <= X.shape[2]
+        ):
+            raise ValueError(
+                f"Number of segments must be a positive integer less than "
+                f"(or equal) the window size. Got {self.n_segments}"
+            )
+
+        if y is None:
+            data_classes = [(np.nan, X)]
+
+        else:
+            classes = torch.unique(y)
+
+            data_classes = [(i, X[y == i]) for i in classes]
+
+        rand_indices = dict()
+        for label, X_class in data_classes:
+            n_trials = X_class.shape[0]
+            rand_indices[label] = self.rng.randint(
+                0, n_trials, (n_trials, self.n_segments)
+            )
+
+        idx_shuffle = self.rng.permutation(X.shape[0])
+
+        return {
+            "n_segments": self.n_segments,
+            "data_classes": data_classes,
+            "rand_indices": rand_indices,
+            "idx_shuffle": idx_shuffle,
+        }
+ }
1184
+
1185
+
1186
+ class MaskEncoding(Transform):
1187
+ """MaskEncoding from [1]_.
1188
+
1189
+ Replaces randomly chosen contiguous part (or parts) of all channels by
1190
+ zeros (if more than one segment, it may overlap).
1191
+
1192
+ Implementation based on [1]_
1193
+
1194
+ Parameters
1195
+ ----------
1196
+ probability : float
1197
+ Float setting the probability of applying the operation.
1198
+ max_mask_ratio: float, optional
1199
+ Signal ratio to zero out. Defaults to 0.1.
1200
+ n_segments : int, optional
1201
+ Number of segments to zero out in each example.
1202
+ Defaults to 1.
1203
+ random_state: int | numpy.random.Generator, optional
1204
+ Seed to be used to instantiate numpy random number generator instance.
1205
+ Defaults to None.
1206
+
1207
+ References
1208
+ ----------
1209
+ .. [1] Ding, Wenlong, et al. "A Novel Data Augmentation Approach
1210
+ Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI."
1211
+ IEEE Transactions on Neural Systems and Rehabilitation Engineering
1212
+ 32 (2024): 875-886.
1213
+ """
1214
+
1215
+ operation = staticmethod(mask_encoding) # type: ignore[assignment]
1216
+
1217
+ def __init__(
1218
+ self,
1219
+ probability,
1220
+ max_mask_ratio=0.1,
1221
+ n_segments=1,
1222
+ random_state=None,
1223
+ ):
1224
+ super().__init__(
1225
+ probability=probability,
1226
+ random_state=random_state,
1227
+ )
1228
+ assert isinstance(n_segments, int) and n_segments > 0, (
1229
+ "n_segments should be a positive integer."
1230
+ )
1231
+ assert isinstance(max_mask_ratio, (int, float)) and 0 <= max_mask_ratio <= 1, (
1232
+ "mask_ratio should be a float between 0 and 1."
1233
+ )
1234
+
1235
+ self.mask_ratio = max_mask_ratio
1236
+ self.n_segments = n_segments
1237
+
1238
+ def get_augmentation_params(self, *batch):
1239
+ """Return transform parameters.
1240
+
1241
+ Parameters
1242
+ ----------
1243
+ X : tensor.Tensor
1244
+ The data.
1245
+ y : tensor.Tensor
1246
+ The labels.
1247
+ Returns
1248
+ -------
1249
+ params : dict
1250
+ Contains ...
1251
+ """
1252
+ if len(batch) == 0:
1253
+ return super().get_augmentation_params(*batch)
1254
+ X = batch[0]
1255
+
1256
+ batch_size, _, n_times = X.shape
1257
+
1258
+ segment_length = int((n_times * self.mask_ratio) / self.n_segments)
1259
+
1260
+ assert segment_length >= 1, (
1261
+ "n_segments should be a positive integer not higher than (max_mask_ratio * window size)."
1262
+ )
1263
+
1264
+ time_start = self.rng.randint(
1265
+ 0, n_times - segment_length, (batch_size, self.n_segments)
1266
+ )
1267
+ time_start = torch.from_numpy(time_start)
1268
+
1269
+ return {
1270
+ "time_start": time_start,
1271
+ "segment_length": segment_length,
1272
+ "n_segments": self.n_segments,
1273
+ }
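Both new classes follow the same Transform interface as the rest of this file, so they drop straight into existing augmentation pipelines. A usage sketch (shapes illustrative, assuming the 1.1.0 API shown in this diff and the same (X, y) call convention as the transforms above):

    import torch
    from braindecode.augmentation import MaskEncoding, SegmentationReconstruction

    X = torch.randn(16, 22, 1000)  # hypothetical batch (n_examples, n_channels, n_times)
    y = torch.randint(0, 2, (16,))

    # Recombine same-label segments into new synthetic trials (Lotte, 2015).
    seg_rec = SegmentationReconstruction(probability=0.5, n_segments=8, random_state=87)
    X_aug, y_aug = seg_rec(X, y)

    # Zero out a contiguous stretch covering up to 10% of each window (Ding et al., 2024).
    mask = MaskEncoding(probability=0.5, max_mask_ratio=0.1, n_segments=1, random_state=87)
    X_masked, y_masked = mask(X, y)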