braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
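
Read as a map of the release, the file list implies a package-level reorganisation: braindecode/models/modules.py and braindecode/models/functions.py are deleted, new braindecode/modules/ and braindecode/functional/ packages appear, and the deprecated braindecode/datautil shims (mne.py, preprocess.py, windowers.py, xy.py) are removed outright. A minimal migration sketch under that assumption; the target names below (Ensure4d, square) are inferred from the file moves, not verified against the released 1.0.0 API:

    # Hedged migration sketch: Ensure4d and square are examples of helpers
    # that lived under braindecode.models.{modules,functions} in 0.8.1; the
    # new top-level packages are assumed to take over that role in 1.0.0.
    try:
        # 1.0.0 layout (assumed from the added files above)
        from braindecode.functional import square
        from braindecode.modules import Ensure4d
    except ImportError:
        # 0.8.1 layout (these modules are deleted in this diff)
        from braindecode.models.functions import square
        from braindecode.models.modules import Ensure4d
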
@@ -1,28 +1,34 @@
 # Authors: Cédric Rommel <cedric.rommel@inria.fr>
 #          Alexandre Gramfort <alexandre.gramfort@inria.fr>
+#          Gustavo Rodrigues <gustavenrique01@gmail.com>
 #
 # License: BSD (3-clause)
 
 import warnings
 from numbers import Real
+from typing import Callable
 
 import numpy as np
 import torch
 from mne.channels import make_standard_montage
 
 from .base import Transform
-from .functional import bandstop_filter
-from .functional import channels_dropout
-from .functional import channels_permute
-from .functional import channels_shuffle
-from .functional import frequency_shift
-from .functional import ft_surrogate
-from .functional import gaussian_noise
-from .functional import mixup
-from .functional import sensors_rotation
-from .functional import sign_flip
-from .functional import smooth_time_mask
-from .functional import time_reverse
+from .functional import (
+    bandstop_filter,
+    channels_dropout,
+    channels_permute,
+    channels_shuffle,
+    frequency_shift,
+    ft_surrogate,
+    gaussian_noise,
+    mask_encoding,
+    mixup,
+    segmentation_reconstruction,
+    sensors_rotation,
+    sign_flip,
+    smooth_time_mask,
+    time_reverse,
+)
 
 
 class TimeReverse(Transform):
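
The grouped import above also pulls in the two functional operations new in this release, mask_encoding and segmentation_reconstruction (plus a typing.Callable import used for annotations). Since transforms.py lives in braindecode/augmentation/, the relative import suggests, though this diff alone does not prove, that the functional forms are importable as:

    # Hedged sketch: public path inferred from the relative import in
    # braindecode/augmentation/transforms.py, not verified against the docs.
    from braindecode.augmentation.functional import (
        mask_encoding,
        segmentation_reconstruction,
    )
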
@@ -37,7 +43,8 @@ class TimeReverse(Transform):
         Used to decide whether or not to transform given the probability
         argument. Defaults to None.
     """
-    operation = staticmethod(time_reverse)
+
+    operation = staticmethod(time_reverse)  # type: ignore[assignment]
 
     def __init__(
         self,
@@ -62,17 +69,11 @@ class SignFlip(Transform):
         Used to decide whether or not to transform given the probability
         argument. Defaults to None.
     """
-    operation = staticmethod(sign_flip)
 
-    def __init__(
-        self,
-        probability,
-        random_state=None
-    ):
-        super().__init__(
-            probability=probability,
-            random_state=random_state
-        )
+    operation = staticmethod(sign_flip)  # type: ignore[assignment]
+
+    def __init__(self, probability, random_state=None):
+        super().__init__(probability=probability, random_state=random_state)
 
 
 class FTSurrogate(Transform):
@@ -102,25 +103,26 @@ class FTSurrogate(Transform):
        Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
        preprint arXiv:1806.08675.
     """
-    operation = staticmethod(ft_surrogate)
+
+    operation = staticmethod(ft_surrogate)  # type: ignore[assignment]
 
     def __init__(
         self,
         probability,
         phase_noise_magnitude=1,
         channel_indep=False,
-        random_state=None
+        random_state=None,
     ):
-        super().__init__(
-            probability=probability,
-            random_state=random_state
-        )
-        assert isinstance(phase_noise_magnitude, (float, int, torch.Tensor)), \
+        super().__init__(probability=probability, random_state=random_state)
+        assert isinstance(phase_noise_magnitude, (float, int, torch.Tensor)), (
             "phase_noise_magnitude should be a float."
-        assert 0 <= phase_noise_magnitude <= 1, \
+        )
+        assert 0 <= phase_noise_magnitude <= 1, (
             "phase_noise_magnitude should be between 0 and 1."
+        )
         assert isinstance(channel_indep, bool), (
-            "channel_indep is expected to be a boolean")
+            "channel_indep is expected to be a boolean"
+        )
         self.phase_noise_magnitude = phase_noise_magnitude
         self.channel_indep = channel_indep
 
@@ -174,18 +176,11 @@ class ChannelsDropout(Transform):
        Learning from Heterogeneous EEG Signals with Differentiable Channel
        Reordering. arXiv preprint arXiv:2010.13694.
     """
-    operation = staticmethod(channels_dropout)
 
-    def __init__(
-        self,
-        probability,
-        p_drop=0.2,
-        random_state=None
-    ):
-        super().__init__(
-            probability=probability,
-            random_state=random_state
-        )
+    operation = staticmethod(channels_dropout)  # type: ignore[assignment]
+
+    def __init__(self, probability, p_drop=0.2, random_state=None):
+        super().__init__(probability=probability, random_state=random_state)
         self.p_drop = p_drop
 
     def get_augmentation_params(self, *batch):
@@ -239,18 +234,11 @@ class ChannelsShuffle(Transform):
        Learning from Heterogeneous EEG Signals with Differentiable Channel
        Reordering. arXiv preprint arXiv:2010.13694.
     """
-    operation = staticmethod(channels_shuffle)
 
-    def __init__(
-        self,
-        probability,
-        p_shuffle=0.2,
-        random_state=None
-    ):
-        super().__init__(
-            probability=probability,
-            random_state=random_state
-        )
+    operation = staticmethod(channels_shuffle)  # type: ignore[assignment]
+
+    def __init__(self, probability, p_shuffle=0.2, random_state=None):
+        super().__init__(probability=probability, random_state=random_state)
         self.p_shuffle = p_shuffle
 
     def get_augmentation_params(self, *batch):
@@ -308,14 +296,10 @@ class GaussianNoise(Transform):
        Representation Learning for Electroencephalogram Classification. In
        Machine Learning for Health (pp. 238-253). PMLR.
     """
-    operation = staticmethod(gaussian_noise)
 
-    def __init__(
-        self,
-        probability,
-        std=0.1,
-        random_state=None
-    ):
+    operation = staticmethod(gaussian_noise)  # type: ignore[assignment]
+
+    def __init__(self, probability, std=0.1, random_state=None):
         super().__init__(
             probability=probability,
             random_state=random_state,
@@ -373,28 +357,23 @@ class ChannelsSymmetry(Transform):
        (2018). HAMLET: interpretable human and machine co-learning technique.
        arXiv preprint arXiv:1803.09702.
     """
-    operation = staticmethod(channels_permute)
 
-    def __init__(
-        self,
-        probability,
-        ordered_ch_names,
-        random_state=None
-    ):
+    operation = staticmethod(channels_permute)  # type: ignore[assignment]
+
+    def __init__(self, probability, ordered_ch_names, random_state=None):
         super().__init__(
             probability=probability,
             random_state=random_state,
         )
-        assert (
-            isinstance(ordered_ch_names, list) and
-            all(isinstance(ch, str) for ch in ordered_ch_names)
+        assert isinstance(ordered_ch_names, list) and all(
+            isinstance(ch, str) for ch in ordered_ch_names
         ), "ordered_ch_names should be a list of str."
 
         permutation = list()
         for idx, ch_name in enumerate(ordered_ch_names):
             new_position = idx
             # Find digits in channel name (assuming 10-20 system)
-            d = ''.join(list(filter(str.isdigit, ch_name)))
+            d = "".join(list(filter(str.isdigit, ch_name)))
             if len(d) > 0:
                 d = int(d)
                 if d % 2 == 0:  # pair/right electrodes
@@ -454,22 +433,17 @@ class SmoothTimeMask(Transform):
        Representation Learning for Electroencephalogram Classification. In
        Machine Learning for Health (pp. 238-253). PMLR.
     """
-    operation = staticmethod(smooth_time_mask)
 
-    def __init__(
-        self,
-        probability,
-        mask_len_samples=100,
-        random_state=None
-    ):
+    operation = staticmethod(smooth_time_mask)  # type: ignore[assignment]
+
+    def __init__(self, probability, mask_len_samples=100, random_state=None):
         super().__init__(
             probability=probability,
             random_state=random_state,
         )
 
         assert (
-            isinstance(mask_len_samples, (int, torch.Tensor)) and
-            mask_len_samples > 0
+            isinstance(mask_len_samples, (int, torch.Tensor)) and mask_len_samples > 0
         ), "mask_len_samples has to be a positive integer"
         self.mask_len_samples = mask_len_samples
 
@@ -504,9 +478,14 @@ class SmoothTimeMask(Transform):
         mask_len_samples = self.mask_len_samples
         if isinstance(mask_len_samples, torch.Tensor):
             mask_len_samples = mask_len_samples.to(X.device)
-        mask_start = torch.as_tensor(self.rng.uniform(
-            low=0, high=1, size=X.shape[0],
-        ), device=X.device) * (seq_length - mask_len_samples)
+        mask_start = torch.as_tensor(
+            self.rng.uniform(
+                low=0,
+                high=1,
+                size=X.shape[0],
+            ),
+            device=X.device,
+        ) * (seq_length - mask_len_samples)
         return {
             "mask_start_per_sample": mask_start,
             "mask_len_samples": mask_len_samples,
@@ -546,27 +525,26 @@ class BandstopFilter(Transform):
        Representation Learning for Electroencephalogram Classification. In
        Machine Learning for Health (pp. 238-253). PMLR.
     """
-    operation = staticmethod(bandstop_filter)
+
+    operation = staticmethod(bandstop_filter)  # type: ignore[assignment]
 
     def __init__(
-        self,
-        probability,
-        sfreq,
-        bandwidth=1,
-        max_freq=None,
-        random_state=None
+        self, probability, sfreq, bandwidth=1, max_freq=None, random_state=None
     ):
         super().__init__(
             probability=probability,
             random_state=random_state,
         )
-        assert isinstance(bandwidth, Real) and bandwidth >= 0, \
+        assert isinstance(bandwidth, Real) and bandwidth >= 0, (
             "bandwidth should be a non-negative float."
-        assert isinstance(sfreq, Real) and sfreq > 0, \
+        )
+        assert isinstance(sfreq, Real) and sfreq > 0, (
             "sfreq should be a positive float."
+        )
         if max_freq is not None:
-            assert isinstance(max_freq, Real) and max_freq > 0, \
+            assert isinstance(max_freq, Real) and max_freq > 0, (
                 "max_freq should be a positive float."
+            )
         nyq = sfreq / 2
         if max_freq is None or max_freq > nyq:
             max_freq = nyq
@@ -575,8 +553,9 @@ class BandstopFilter(Transform):
                 f" Nyquist frequency ({nyq} Hz)."
                 f" Falling back to max_freq = {nyq}."
             )
-        assert bandwidth < max_freq, \
+        assert bandwidth < max_freq, (
             f"`bandwidth` needs to be smaller than max_freq={max_freq}"
+        )
 
         # override bandwidth value when a magnitude is passed
         self.sfreq = sfreq
@@ -619,7 +598,7 @@ class BandstopFilter(Transform):
         notched_freqs = self.rng.uniform(
             low=1 + 2 * self.bandwidth,
             high=self.max_freq - 1 - 2 * self.bandwidth,
-            size=X.shape[0]
+            size=X.shape[0],
         )
         return {
             "sfreq": self.sfreq,
@@ -646,21 +625,17 @@ class FrequencyShift(Transform):
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.
     """
-    operation = staticmethod(frequency_shift)
 
-    def __init__(
-        self,
-        probability,
-        sfreq,
-        max_delta_freq=2,
-        random_state=None
-    ):
+    operation = staticmethod(frequency_shift)  # type: ignore[assignment]
+
+    def __init__(self, probability, sfreq, max_delta_freq=2, random_state=None):
         super().__init__(
             probability=probability,
             random_state=random_state,
         )
-        assert isinstance(sfreq, Real) and sfreq > 0, \
+        assert isinstance(sfreq, Real) and sfreq > 0, (
             "sfreq should be a positive float."
+        )
         self.sfreq = sfreq
 
         self.max_delta_freq = max_delta_freq
@@ -689,10 +664,7 @@ class FrequencyShift(Transform):
             return super().get_augmentation_params(*batch)
         X = batch[0]
 
-        u = torch.as_tensor(
-            self.rng.uniform(size=X.shape[0]),
-            device=X.device
-        )
+        u = torch.as_tensor(self.rng.uniform(size=X.shape[0]), device=X.device)
         max_delta_freq = self.max_delta_freq
         if isinstance(max_delta_freq, torch.Tensor):
             max_delta_freq = max_delta_freq.to(X.device)
@@ -718,12 +690,13 @@ def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
        matrices that will be fed to `SensorsRotation` transform. By
        default None.
     """
-    assert raw_or_epoch is not None or ordered_ch_names is not None, \
+    assert raw_or_epoch is not None or ordered_ch_names is not None, (
         "At least one of raw_or_epoch and ordered_ch_names is needed."
+    )
     if ordered_ch_names is None:
-        ordered_ch_names = raw_or_epoch.info['ch_names']
-    ten_twenty_montage = make_standard_montage('standard_1020')
-    positions_dict = ten_twenty_montage.get_positions()['ch_pos']
+        ordered_ch_names = raw_or_epoch.info["ch_names"]
+    ten_twenty_montage = make_standard_montage("standard_1020")
+    positions_dict = ten_twenty_montage.get_positions()["ch_pos"]
     positions_subdict = {
         k: positions_dict[k] for k in ordered_ch_names if k in positions_dict
     }
@@ -770,37 +743,38 @@ class SensorsRotation(Transform):
        Conference of the IEEE Engineering in Medicine and Biology Society
        (EMBC) (pp. 471-474).
     """
-    operation = staticmethod(sensors_rotation)
+
+    operation = staticmethod(sensors_rotation)  # type: ignore[assignment]
 
     def __init__(
         self,
         probability,
         sensors_positions_matrix,
-        axis='z',
+        axis="z",
         max_degrees=15,
         spherical_splines=True,
-        random_state=None
+        random_state=None,
     ):
-        super().__init__(
-            probability=probability,
-            random_state=random_state
-        )
+        super().__init__(probability=probability, random_state=random_state)
         if isinstance(sensors_positions_matrix, (np.ndarray, list)):
-            sensors_positions_matrix = torch.as_tensor(
-                sensors_positions_matrix
-            )
-        assert isinstance(sensors_positions_matrix, torch.Tensor), \
+            sensors_positions_matrix = torch.as_tensor(sensors_positions_matrix)
+        assert isinstance(sensors_positions_matrix, torch.Tensor), (
            "sensors_positions should be an Tensor"
-        assert (
-            isinstance(max_degrees, (Real, torch.Tensor)) and
-            max_degrees >= 0
-        ), "max_degrees should be non-negative float."
-        assert isinstance(axis, str) and axis in ['x', 'y', 'z'], \
-            "axis can be either x, y or z."
-        assert sensors_positions_matrix.shape[0] == 3, \
+        )
+        assert isinstance(max_degrees, (Real, torch.Tensor)) and max_degrees >= 0, (
+            "max_degrees should be non-negative float."
+        )
+        assert isinstance(axis, str) and axis in [
+            "x",
+            "y",
+            "z",
+        ], "axis can be either x, y or z."
+        assert sensors_positions_matrix.shape[0] == 3, (
            "sensors_positions_matrix shape should be 3 x n_channels."
-        assert isinstance(spherical_splines, bool), \
+        )
+        assert isinstance(spherical_splines, bool), (
            "spherical_splines should be a boolean"
+        )
         self.sensors_positions_matrix = sensors_positions_matrix
         self.axis = axis
         self.spherical_splines = spherical_splines
@@ -841,21 +815,18 @@ class SensorsRotation(Transform):
             return super().get_augmentation_params(*batch)
         X = batch[0]
 
-        u = self.rng.uniform(
-            low=0,
-            high=1,
-            size=X.shape[0]
-        )
+        u = self.rng.uniform(low=0, high=1, size=X.shape[0])
         max_degrees = self.max_degrees
         if isinstance(max_degrees, torch.Tensor):
             max_degrees = max_degrees.to(X.device)
-        random_angles = torch.as_tensor(
-            u, device=X.device) * 2 * max_degrees - max_degrees
+        random_angles = (
+            torch.as_tensor(u, device=X.device) * 2 * max_degrees - max_degrees
+        )
         return {
             "sensors_positions_matrix": self.sensors_positions_matrix,
             "axis": self.axis,
             "angles": random_angles,
-            "spherical_splines": self.spherical_splines
+            "spherical_splines": self.spherical_splines,
         }
 
 
@@ -900,7 +871,7 @@ class SensorsZRotation(SensorsRotation):
         ordered_ch_names,
         max_degrees=15,
         spherical_splines=True,
-        random_state=None
+        random_state=None,
     ):
         sensors_positions_matrix = torch.as_tensor(
             _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
@@ -908,10 +879,10 @@ class SensorsZRotation(SensorsRotation):
         super().__init__(
             probability=probability,
             sensors_positions_matrix=sensors_positions_matrix,
-            axis='z',
+            axis="z",
             max_degrees=max_degrees,
             spherical_splines=spherical_splines,
-            random_state=random_state
+            random_state=random_state,
         )
 
 
@@ -956,7 +927,7 @@ class SensorsYRotation(SensorsRotation):
         ordered_ch_names,
         max_degrees=15,
         spherical_splines=True,
-        random_state=None
+        random_state=None,
     ):
         sensors_positions_matrix = torch.as_tensor(
             _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
@@ -964,10 +935,10 @@ class SensorsYRotation(SensorsRotation):
         super().__init__(
             probability=probability,
             sensors_positions_matrix=sensors_positions_matrix,
-            axis='y',
+            axis="y",
             max_degrees=max_degrees,
             spherical_splines=spherical_splines,
-            random_state=random_state
+            random_state=random_state,
         )
 
 
@@ -1012,7 +983,7 @@ class SensorsXRotation(SensorsRotation):
         ordered_ch_names,
         max_degrees=15,
         spherical_splines=True,
-        random_state=None
+        random_state=None,
    ):
         sensors_positions_matrix = torch.as_tensor(
             _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
@@ -1020,10 +991,10 @@ class SensorsXRotation(SensorsRotation):
         super().__init__(
             probability=probability,
             sensors_positions_matrix=sensors_positions_matrix,
-            axis='x',
+            axis="x",
             max_degrees=max_degrees,
             spherical_splines=spherical_splines,
-            random_state=random_state
+            random_state=random_state,
         )
 
 
@@ -1050,17 +1021,13 @@ class Mixup(Transform):
        Online: https://arxiv.org/abs/1710.09412
    .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
     """
-    operation = staticmethod(mixup)
 
-    def __init__(
-        self,
-        alpha,
-        beta_per_sample=False,
-        random_state=None
-    ):
+    operation = staticmethod(mixup)  # type: ignore[assignment]
+
+    def __init__(self, alpha, beta_per_sample=False, random_state=None):
         super().__init__(
             probability=1.0,  # Mixup has to be applied to whole batches
-            random_state=random_state
+            random_state=random_state,
         )
         self.alpha = alpha
         self.beta_per_sample = beta_per_sample
@@ -1098,9 +1065,210 @@ class Mixup(Transform):
         else:
             lam = torch.ones(batch_size).to(device)
 
-        idx_perm = torch.as_tensor(self.rng.permutation(batch_size,))
+        idx_perm = torch.as_tensor(
+            self.rng.permutation(
+                batch_size,
+            )
+        )
 
         return {
             "lam": lam,
             "idx_perm": idx_perm,
         }
+
+
+class SegmentationReconstruction(Transform):
+    """Segmentation Reconstruction from Lotte (2015) [Lotte2015]_.
+
+    Applies a segmentation-reconstruction transform to the input data, as
+    proposed in [Lotte2015]_. It segments each trial in the batch and randomly
+    mixes the segments across trials of the same label to generate new
+    synthetic trials, preserving the original order of the segments in the
+    time domain.
+
+    Parameters
+    ----------
+    probability : float
+        Float setting the probability of applying the operation.
+    random_state: int | numpy.random.Generator, optional
+        Seed to be used to instantiate numpy random number generator instance.
+        Used to decide whether to transform given the probability
+        argument and to sample the segment mixing. Defaults to None.
+    n_segments : int, optional
+        Number of segments to split each trial into. If None, the number of
+        segments is set to the largest factor of the number of time samples
+        that does not exceed its square root. Defaults to None.
+
+    References
+    ----------
+    .. [Lotte2015] Lotte, F. (2015). Signal processing approaches to minimize
+       or suppress calibration time in oscillatory activity-based
+       brain–computer interfaces. Proceedings of the IEEE, 103(6), 871-890.
+    """
+
+    operation = staticmethod(segmentation_reconstruction)  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        probability,
+        n_segments=None,
+        random_state=None,
+    ):
+        super().__init__(
+            probability=probability,
+            random_state=random_state,
+        )
+        self.n_segments = n_segments
+
+    def get_augmentation_params(self, *batch):
+        """Return transform parameters.
+
+        Parameters
+        ----------
+        X : tensor.Tensor
+            The data.
+        y : tensor.Tensor
+            The labels.
+
+        Returns
+        -------
+        params : dict
+            Contains the number of segments to split the signal into, the
+            per-label data, the sampled segment indices and the shuffling
+            order.
+        """
+        X, y = batch[0], batch[1]
+
+        if y is not None:
+            if not isinstance(X, torch.Tensor) or not isinstance(y, torch.Tensor):
+                raise ValueError("X and y must be torch tensors.")
+
+            if X.shape[0] != y.shape[0]:
+                raise ValueError("Number of samples in X and y must be the same.")
+
+        if self.n_segments is None:
+            self.n_segments = int(X.shape[2])
+            n_segments_list = []
+            for i in range(1, int(self.n_segments**0.5) + 1):
+                if self.n_segments % i == 0:
+                    n_segments_list.append(i)
+            self.n_segments = n_segments_list[-1]
+
+        elif not (
+            isinstance(self.n_segments, (int, float))
+            and 1 <= self.n_segments <= X.shape[2]
+        ):
+            raise ValueError(
+                f"Number of segments must be a positive integer less than "
+                f"(or equal) the window size. Got {self.n_segments}"
+            )
+
+        if y is None:
+            data_classes = [(np.nan, X)]
+
+        else:
+            classes = torch.unique(y)
+
+            data_classes = [(i, X[y == i]) for i in classes]
+
+        rand_indices = dict()
+        for label, X_class in data_classes:
+            n_trials = X_class.shape[0]
+            rand_indices[label] = self.rng.randint(
+                0, n_trials, (n_trials, self.n_segments)
+            )
+
+        idx_shuffle = self.rng.permutation(X.shape[0])
+
+        return {
+            "n_segments": self.n_segments,
+            "data_classes": data_classes,
+            "rand_indices": rand_indices,
+            "idx_shuffle": idx_shuffle,
+        }
+
+
+class MaskEncoding(Transform):
+    """MaskEncoding from [1]_.
+
+    Replaces a randomly chosen contiguous part (or parts) of all channels
+    with zeros (when more than one segment is used, segments may overlap).
+
+    Implementation based on [1]_.
+
+    Parameters
+    ----------
+    probability : float
+        Float setting the probability of applying the operation.
+    max_mask_ratio: float, optional
+        Signal ratio to zero out. Defaults to 0.1.
+    n_segments : int, optional
+        Number of segments to zero out in each example.
+        Defaults to 1.
+    random_state: int | numpy.random.Generator, optional
+        Seed to be used to instantiate numpy random number generator instance.
+        Defaults to None.
+
+    References
+    ----------
+    .. [1] Ding, Wenlong, et al. "A Novel Data Augmentation Approach
+       Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI."
+       IEEE Transactions on Neural Systems and Rehabilitation Engineering
+       32 (2024): 875-886.
+    """
+
+    operation = staticmethod(mask_encoding)  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        probability,
+        max_mask_ratio=0.1,
+        n_segments=1,
+        random_state=None,
+    ):
+        super().__init__(
+            probability=probability,
+            random_state=random_state,
+        )
+        assert isinstance(n_segments, int) and n_segments > 0, (
+            "n_segments should be a positive integer."
+        )
+        assert isinstance(max_mask_ratio, (int, float)) and 0 <= max_mask_ratio <= 1, (
+            "mask_ratio should be a float between 0 and 1."
+        )
+
+        self.mask_ratio = max_mask_ratio
+        self.n_segments = n_segments
+
+    def get_augmentation_params(self, *batch):
+        """Return transform parameters.
+
+        Parameters
+        ----------
+        X : tensor.Tensor
+            The data.
+        y : tensor.Tensor
+            The labels.
+
+        Returns
+        -------
+        params : dict
+            Contains the mask start positions, the segment length and the
+            number of segments.
+        """
+        if len(batch) == 0:
+            return super().get_augmentation_params(*batch)
+        X = batch[0]
+
+        batch_size, _, n_times = X.shape
+
+        segment_length = int((n_times * self.mask_ratio) / self.n_segments)
+
+        assert segment_length >= 1, (
+            "n_segments should be a positive integer not higher than (max_mask_ratio * window size)."
+        )
+
+        time_start = self.rng.randint(
+            0, n_times - segment_length, (batch_size, self.n_segments)
+        )
+        time_start = torch.from_numpy(time_start)
+
+        return {
+            "time_start": time_start,
+            "segment_length": segment_length,
+            "n_segments": self.n_segments,
+        }
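
Both new classes plug into the Transform API visible throughout this file: probability-gated application, NumPy-seeded randomness behind self.rng, and get_augmentation_params feeding a static operation. A usage sketch under those assumptions follows; the batch shape, the return convention when labels are passed, and the braindecode.augmentation export path are inferred from this diff rather than verified against the released 1.0.0 API:

    # Hedged usage sketch for the two transforms added in this diff.
    import torch

    from braindecode.augmentation import MaskEncoding, SegmentationReconstruction

    X = torch.randn(8, 22, 400)    # (n_trials, n_channels, n_times), synthetic
    y = torch.randint(0, 2, (8,))  # binary labels

    # Recombine segments across same-label trials (assumed to return (X, y)
    # when labels are passed, like the other label-aware transforms here).
    seg_rec = SegmentationReconstruction(probability=0.5, n_segments=4, random_state=0)
    X_aug, y_aug = seg_rec(X, y)

    # Zero out one contiguous chunk covering up to 10% of each trial.
    mask = MaskEncoding(probability=0.5, max_mask_ratio=0.1, n_segments=1, random_state=0)
    X_masked, _ = mask(X, y)
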