braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,12 @@
1
- """Utilities for data augmentation."""
1
+ """
2
+ Utilities for data augmentation.
3
+ """
2
4
 
3
5
  from . import functional
4
6
  from .base import AugmentedDataLoader, Compose, IdentityTransform, Transform
5
7
  from .transforms import (
6
- AmplitudeScale,
7
8
  BandstopFilter,
8
9
  ChannelsDropout,
9
- ChannelsReref,
10
10
  ChannelsShuffle,
11
11
  ChannelsSymmetry,
12
12
  FrequencyShift,
@@ -46,7 +46,5 @@ __all__ = [
46
46
  "Mixup",
47
47
  "SegmentationReconstruction",
48
48
  "MaskEncoding",
49
- "AmplitudeScale",
50
- "ChannelsReref",
51
49
  "functional",
52
50
  ]
@@ -31,8 +31,7 @@ Operation = Callable[
31
31
 
32
32
 
33
33
  class Transform(torch.nn.Module):
34
- """Basic transform class used for implementing data augmentation.
35
-
34
+ """Basic transform class used for implementing data augmentation
36
35
  operations.
37
36
 
38
37
  Parameters
@@ -43,7 +42,7 @@ class Transform(torch.nn.Module):
43
42
  probability : float, optional
44
43
  Float between 0 and 1 defining the uniform probability of applying the
45
44
  operation. Set to 1.0 by default (e.g always apply the operation).
46
- random_state : int, optional
45
+ random_state: int, optional
47
46
  Seed to be used to instantiate numpy random number generator instance.
48
47
  Used to decide whether or not to transform given the probability
49
48
  argument. Defaults to None.
@@ -126,7 +125,7 @@ class Transform(torch.nn.Module):
126
125
  return out_X
127
126
 
128
127
  def _get_mask(self, batch_size, device) -> torch.Tensor:
129
- """Samples whether to apply operation or not over the whole batch."""
128
+ """Samples whether to apply operation or not over the whole batch"""
130
129
  return torch.as_tensor(self.probability > self.rng.uniform(size=batch_size)).to(
131
130
  device
132
131
  )
@@ -154,7 +153,7 @@ class Compose(Transform):
154
153
 
155
154
  Parameters
156
155
  ----------
157
- transforms : list
156
+ transforms: list
158
157
  Sequence of Transforms to be composed.
159
158
  """
160
159
 
@@ -170,9 +169,7 @@ class Compose(Transform):
170
169
 
171
170
  def _make_collateable(transform, device=None):
172
171
  """Wraps a transform to make it collateable.
173
-
174
- with device control.
175
- """
172
+ with device control."""
176
173
 
177
174
  def _collate_fn(batch):
178
175
  collated_batch = default_collate(batch)
@@ -144,7 +144,7 @@ def ft_surrogate(
144
144
  EEG input example or batch.
145
145
  y : torch.Tensor
146
146
  EEG labels for the example or batch.
147
- phase_noise_magnitude : float
147
+ phase_noise_magnitude: float
148
148
  Float between 0 and 1 setting the range over which the phase
149
149
  perturbation is uniformly sampled:
150
150
  [0, `phase_noise_magnitude` * 2 * `pi`].
@@ -152,7 +152,7 @@ def ft_surrogate(
152
152
  Whether to sample phase perturbations independently for each channel or
153
153
  not. It is advised to set it to False when spatial information is
154
154
  important for the task, like in BCI.
155
- random_state : int | numpy.random.Generator, optional
155
+ random_state: int | numpy.random.Generator, optional
156
156
  Used to draw the phase perturbation. Defaults to None.
157
157
 
158
158
  Returns
@@ -289,10 +289,10 @@ def channels_shuffle(
289
289
  EEG input example or batch.
290
290
  y : torch.Tensor
291
291
  EEG labels for the example or batch.
292
- p_shuffle : float | None
292
+ p_shuffle: float | None
293
293
  Float between 0 and 1 setting the probability of including the channel
294
294
  in the set of permutted channels.
295
- random_state : int | numpy.random.Generator, optional
295
+ random_state: int | numpy.random.Generator, optional
296
296
  Seed to be used to instantiate numpy random number generator instance.
297
297
  Used to sample which channels to shuffle and to carry the shuffle.
298
298
  Defaults to None.
@@ -335,7 +335,7 @@ def gaussian_noise(
335
335
  EEG labels for the example or batch.
336
336
  std : float
337
337
  Standard deviation to use for the additive noise.
338
- random_state : int | numpy.random.Generator, optional
338
+ random_state: int | numpy.random.Generator, optional
339
339
  Seed to be used to instantiate numpy random number generator instance.
340
340
  Defaults to None.
341
341
 
@@ -468,8 +468,7 @@ def bandstop_filter(
468
468
  bandwidth: float,
469
469
  freqs_to_notch: npt.ArrayLike | None,
470
470
  ) -> tuple[torch.Tensor, torch.Tensor]:
471
- """Apply a band-stop filter with desired bandwidth at the desired frequency.
472
-
471
+ """Apply a band-stop filter with desired bandwidth at the desired frequency
473
472
  position.
474
473
 
475
474
  Suggested e.g. in [1]_ and [2]_
@@ -621,7 +620,6 @@ def _torch_legval(
621
620
  ) -> torch.Tensor:
622
621
  """
623
622
  Evaluate a Legendre series at points x.
624
-
625
623
  If `c` is of length `n + 1`, this function returns the value:
626
624
  .. math:: p(x) = c_0 * L_0(x) + c_1 * L_1(x) + ... + c_n * L_n(x)
627
625
  The parameter `x` is converted to an array only if it is a tuple or a
@@ -807,6 +805,12 @@ def _torch_make_interpolation_matrix(
807
805
  The interpolation matrix that maps good signals to the location
808
806
  of bad signals.
809
807
 
808
+ References
809
+ ----------
810
+ [1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
811
+ Spherical splines for scalp potential and current density mapping.
812
+ Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.
813
+
810
814
  Notes
811
815
  -----
812
816
  Code copied and modified from MNE-Python:
@@ -837,12 +841,6 @@ def _torch_make_interpolation_matrix(
837
841
  LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
838
842
  OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
839
843
  DAMAGE.
840
-
841
- References
842
- ----------
843
- [1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
844
- Spherical splines for scalp potential and current density mapping.
845
- Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.
846
844
  """
847
845
  pos_from = pos_from.clone()
848
846
  pos_to = pos_to.clone()
@@ -970,8 +968,7 @@ def sensors_rotation(
970
968
  angles: npt.ArrayLike,
971
969
  spherical_splines: bool,
972
970
  ) -> tuple[torch.Tensor, torch.Tensor]:
973
- """Interpolates EEG signals over sensors rotated around the desired axis.
974
-
971
+ """Interpolates EEG signals over sensors rotated around the desired axis
975
972
  with the desired angle.
976
973
 
977
974
  Suggested in [1]_
@@ -1030,7 +1027,7 @@ def mixup(
1030
1027
  lam : torch.Tensor
1031
1028
  Values between 0 and 1 setting the linear interpolation between
1032
1029
  examples.
1033
- idx_perm : torch.Tensor
1030
+ idx_perm: torch.Tensor
1034
1031
  Permuted indices of example that are mixed into original examples.
1035
1032
 
1036
1033
  Returns
@@ -1083,20 +1080,18 @@ def segmentation_reconstruction(
1083
1080
  EEG labels for the example or batch.
1084
1081
  n_segments : int
1085
1082
  Number of segments to use in the batch.
1086
- data_classes : list[tuple[int, torch.Tensor]]
1083
+ data_classes: list[tuple[int, torch.Tensor]]
1087
1084
  List of tuples. Each tuple contains the class index and the corresponding EEG data.
1088
- rand_indices : array-like
1085
+ rand_indices: array-like
1089
1086
  Array of indices that indicates which trial to use in each segment.
1090
- idx_shuffle : array-like
1087
+ idx_shuffle: array-like
1091
1088
  Array of indices to shuffle the new generated trials.
1092
-
1093
1089
  Returns
1094
1090
  -------
1095
1091
  torch.Tensor
1096
1092
  Transformed inputs.
1097
1093
  torch.Tensor
1098
1094
  Transformed labels.
1099
-
1100
1095
  References
1101
1096
  ----------
1102
1097
  .. [1] Lotte, F. (2015). Signal processing approaches to minimize or
@@ -1150,7 +1145,7 @@ def mask_encoding(
1150
1145
  segment_length: int,
1151
1146
  n_segments: int,
1152
1147
  ) -> tuple[torch.Tensor, torch.Tensor]:
1153
- """Mark encoding from Ding et al (2024) from [ding2024]_.
1148
+ """Mark encoding from Ding et al. (2024) from [ding2024]_.
1154
1149
 
1155
1150
  Replaces a contiguous part (or parts) of all channels by zeros
1156
1151
  (if more than one segment, it may overlap).
@@ -1217,7 +1212,7 @@ def channels_rereference(
1217
1212
  EEG input example or batch.
1218
1213
  y : torch.Tensor
1219
1214
  EEG labels for the example or batch.
1220
- random_state : int | numpy.random.Generator, optional
1215
+ random_state: int | numpy.random.Generator, optional
1221
1216
  Seed to be used to instantiate numpy random number generator instance.
1222
1217
  Defaults to None.
1223
1218
 
@@ -1234,6 +1229,7 @@ def channels_rereference(
1234
1229
  Representation Learning for Electroencephalogram Classification. Proceedings
1235
1230
  of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
1236
1231
  Learning Research 136:238-253
1232
+
1237
1233
  """
1238
1234
 
1239
1235
  rng = check_random_state(random_state)
@@ -1266,7 +1262,7 @@ def amplitude_scale(
1266
1262
  EEG labels for the example or batch.
1267
1263
  scale : tuple of floats
1268
1264
  Interval from which ypu sample the scaling value
1269
- random_state : int | numpy.random.Generator, optional
1265
+ random_state: int | numpy.random.Generator, optional
1270
1266
  Seed to be used to instantiate numpy random number generator instance.
1271
1267
  Defaults to None.
1272
1268
 
@@ -1283,6 +1279,7 @@ def amplitude_scale(
1283
1279
  Representation Learning for Electroencephalogram Classification. Proceedings
1284
1280
  of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
1285
1281
  Learning Research 136:238-253
1282
+
1286
1283
  """
1287
1284
 
1288
1285
  rng = torch.Generator()
@@ -40,7 +40,7 @@ class TimeReverse(Transform):
40
40
  ----------
41
41
  probability : float
42
42
  Float setting the probability of applying the operation.
43
- random_state : int | numpy.random.Generator, optional
43
+ random_state: int | numpy.random.Generator, optional
44
44
  Seed to be used to instantiate numpy random number generator instance.
45
45
  Used to decide whether or not to transform given the probability
46
46
  argument. Defaults to None.
@@ -66,7 +66,7 @@ class SignFlip(Transform):
66
66
  ----------
67
67
  probability : float
68
68
  Float setting the probability of applying the operation.
69
- random_state : int | numpy.random.Generator, optional
69
+ random_state: int | numpy.random.Generator, optional
70
70
  Seed to be used to instantiate numpy random number generator instance.
71
71
  Used to decide whether or not to transform given the probability
72
72
  argument. Defaults to None.
@@ -83,7 +83,7 @@ class FTSurrogate(Transform):
83
83
 
84
84
  Parameters
85
85
  ----------
86
- probability : float
86
+ probability: float
87
87
  Float setting the probability of applying the operation.
88
88
  phase_noise_magnitude : float | torch.Tensor, optional
89
89
  Float between 0 and 1 setting the range over which the phase
@@ -93,7 +93,7 @@ class FTSurrogate(Transform):
93
93
  Whether to sample phase perturbations independently for each channel or
94
94
  not. It is advised to set it to False when spatial information is
95
95
  important for the task, like in BCI. Default False.
96
- random_state : int | numpy.random.Generator, optional
96
+ random_state: int | numpy.random.Generator, optional
97
97
  Seed to be used to instantiate numpy random number generator instance.
98
98
  Used to decide whether or not to transform given the probability
99
99
  argument. Defaults to None.
@@ -162,12 +162,12 @@ class ChannelsDropout(Transform):
162
162
 
163
163
  Parameters
164
164
  ----------
165
- probability : float
165
+ probability: float
166
166
  Float setting the probability of applying the operation.
167
- p_drop : float | None, optional
167
+ p_drop: float | None, optional
168
168
  Float between 0 and 1 setting the probability of dropping each channel.
169
169
  Defaults to 0.2.
170
- random_state : int | numpy.random.RandomState, optional
170
+ random_state: int | numpy.random.RandomState, optional
171
171
  Seed to be used to instantiate numpy random number generator instance.
172
172
  Used to decide whether or not to transform given the probability
173
173
  argument and to sample channels to erase. Defaults to None.
@@ -219,12 +219,12 @@ class ChannelsShuffle(Transform):
219
219
 
220
220
  Parameters
221
221
  ----------
222
- probability : float
222
+ probability: float
223
223
  Float setting the probability of applying the operation.
224
- p_shuffle : float | None, optional
224
+ p_shuffle: float | None, optional
225
225
  Float between 0 and 1 setting the probability of including the channel
226
226
  in the set of permuted channels. Defaults to 0.2.
227
- random_state : int | numpy.random.Generator, optional
227
+ random_state: int | numpy.random.Generator, optional
228
228
  Seed to be used to instantiate numpy random number generator instance.
229
229
  Used to decide whether or not to transform given the probability
230
230
  argument, to sample which channels to shuffle and to carry the shuffle.
@@ -281,7 +281,7 @@ class GaussianNoise(Transform):
281
281
  Float setting the probability of applying the operation.
282
282
  std : float, optional
283
283
  Standard deviation to use for the additive noise. Defaults to 0.1.
284
- random_state : int | numpy.random.Generator, optional
284
+ random_state: int | numpy.random.Generator, optional
285
285
  Seed to be used to instantiate numpy random number generator instance.
286
286
  Defaults to None.
287
287
 
@@ -348,7 +348,7 @@ class ChannelsSymmetry(Transform):
348
348
  nomenclature) of the EEG channels that will be transformed. The
349
349
  first name should correspond the data in the first row of X, the
350
350
  second name in the second row and so on.
351
- random_state : int | numpy.random.Generator, optional
351
+ random_state: int | numpy.random.Generator, optional
352
352
  Seed to be used to instantiate numpy random number generator instance.
353
353
  Used to decide whether or not to transform given the probability
354
354
  argument. Defaults to None.
@@ -410,8 +410,7 @@ class ChannelsSymmetry(Transform):
410
410
 
411
411
 
412
412
  class SmoothTimeMask(Transform):
413
- """Smoothly replace a randomly chosen contiguous part of all channels by.
414
-
413
+ """Smoothly replace a randomly chosen contiguous part of all channels by
415
414
  zeros.
416
415
 
417
416
  Suggested e.g. in [1]_ and [2]_
@@ -423,7 +422,7 @@ class SmoothTimeMask(Transform):
423
422
  mask_len_samples : int | torch.Tensor, optional
424
423
  Number of consecutive samples to zero out. Will be ignored if
425
424
  magnitude is not set to None. Defaults to 100.
426
- random_state : int | numpy.random.Generator, optional
425
+ random_state: int | numpy.random.Generator, optional
427
426
  Seed to be used to instantiate numpy random number generator instance.
428
427
  Defaults to None.
429
428
 
@@ -496,8 +495,7 @@ class SmoothTimeMask(Transform):
496
495
 
497
496
 
498
497
  class BandstopFilter(Transform):
499
- """Apply a band-stop filter with desired bandwidth at a randomly selected.
500
-
498
+ """Apply a band-stop filter with desired bandwidth at a randomly selected
501
499
  frequency position between 0 and ``max_freq``.
502
500
 
503
501
  Suggested e.g. in [1]_ and [2]_
@@ -516,7 +514,7 @@ class BandstopFilter(Transform):
516
514
  that the corresponding high cut frequency + transition (=1Hz) are below
517
515
  ``max_freq``. If omitted or `None`, will default to the Nyquist
518
516
  frequency (``sfreq / 2``).
519
- random_state : int | numpy.random.Generator, optional
517
+ random_state: int | numpy.random.Generator, optional
520
518
  Seed to be used to instantiate numpy random number generator instance.
521
519
  Defaults to None.
522
520
 
@@ -625,7 +623,7 @@ class FrequencyShift(Transform):
625
623
  max_delta_freq : float | torch.Tensor, optional
626
624
  Maximum shift in Hz that can be sampled (in absolute value).
627
625
  Defaults to 2 (shift sampled between -2 and 2 Hz).
628
- random_state : int | numpy.random.Generator, optional
626
+ random_state: int | numpy.random.Generator, optional
629
627
  Seed to be used to instantiate numpy random number generator instance.
630
628
  Defaults to None.
631
629
  """
@@ -680,8 +678,7 @@ class FrequencyShift(Transform):
680
678
 
681
679
 
682
680
  def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
683
- """Returns standard 10-20 sensors position matrix (for instantiating.
684
-
681
+ """Returns standard 10-20 sensors position matrix (for instantiating
685
682
  SensorsRotation for example).
686
683
 
687
684
  Parameters
@@ -709,8 +706,7 @@ def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
709
706
 
710
707
 
711
708
  class SensorsRotation(Transform):
712
- """Interpolates EEG signals over sensors rotated around the desired axis.
713
-
709
+ """Interpolates EEG signals over sensors rotated around the desired axis
714
710
  with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
715
711
 
716
712
  Suggested in [1]_
@@ -738,7 +734,7 @@ class SensorsRotation(Transform):
738
734
  Whether to use spherical splines for the interpolation or not. When
739
735
  ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
740
736
  be used (as in the original paper). Defaults to True.
741
- random_state : int | numpy.random.Generator, optional
737
+ random_state: int | numpy.random.Generator, optional
742
738
  Seed to be used to instantiate numpy random number generator instance.
743
739
  Defaults to None.
744
740
 
@@ -837,8 +833,7 @@ class SensorsRotation(Transform):
837
833
 
838
834
 
839
835
  class SensorsZRotation(SensorsRotation):
840
- """Interpolates EEG signals over sensors rotated around the Z axis.
841
-
836
+ """Interpolates EEG signals over sensors rotated around the Z axis
842
837
  with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
843
838
 
844
839
  Suggested in [1]_
@@ -860,7 +855,7 @@ class SensorsZRotation(SensorsRotation):
860
855
  Whether to use spherical splines for the interpolation or not. When
861
856
  ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
862
857
  be used (as in the original paper). Defaults to True.
863
- random_state : int | numpy.random.Generator, optional
858
+ random_state: int | numpy.random.Generator, optional
864
859
  Seed to be used to instantiate numpy random number generator instance.
865
860
  Defaults to None.
866
861
 
@@ -894,8 +889,7 @@ class SensorsZRotation(SensorsRotation):
894
889
 
895
890
 
896
891
  class SensorsYRotation(SensorsRotation):
897
- """Interpolates EEG signals over sensors rotated around the Y axis.
898
-
892
+ """Interpolates EEG signals over sensors rotated around the Y axis
899
893
  with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
900
894
 
901
895
  Suggested in [1]_
@@ -917,7 +911,7 @@ class SensorsYRotation(SensorsRotation):
917
911
  Whether to use spherical splines for the interpolation or not. When
918
912
  ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
919
913
  be used (as in the original paper). Defaults to True.
920
- random_state : int | numpy.random.Generator, optional
914
+ random_state: int | numpy.random.Generator, optional
921
915
  Seed to be used to instantiate numpy random number generator instance.
922
916
  Defaults to None.
923
917
 
@@ -951,8 +945,7 @@ class SensorsYRotation(SensorsRotation):
951
945
 
952
946
 
953
947
  class SensorsXRotation(SensorsRotation):
954
- """Interpolates EEG signals over sensors rotated around the X axis.
955
-
948
+ """Interpolates EEG signals over sensors rotated around the X axis
956
949
  with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
957
950
 
958
951
  Suggested in [1]_
@@ -974,7 +967,7 @@ class SensorsXRotation(SensorsRotation):
974
967
  Whether to use spherical splines for the interpolation or not. When
975
968
  ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
976
969
  be used (as in the original paper). Defaults to True.
977
- random_state : int | numpy.random.Generator, optional
970
+ random_state: int | numpy.random.Generator, optional
978
971
  Seed to be used to instantiate numpy random number generator instance.
979
972
  Defaults to None.
980
973
 
@@ -1008,19 +1001,17 @@ class SensorsXRotation(SensorsRotation):
1008
1001
 
1009
1002
 
1010
1003
  class Mixup(Transform):
1011
- """Implements Iterator for Mixup for EEG data.
1012
-
1013
- See [1]_.
1004
+ """Implements Iterator for Mixup for EEG data. See [1]_.
1014
1005
  Implementation based on [2]_.
1015
1006
 
1016
1007
  Parameters
1017
1008
  ----------
1018
- alpha : float
1009
+ alpha: float
1019
1010
  Mixup hyperparameter.
1020
- beta_per_sample : bool (default=False)
1011
+ beta_per_sample: bool (default=False)
1021
1012
  By default, one mixing coefficient per batch is drawn from a beta
1022
1013
  distribution. If True, one mixing coefficient per sample is drawn.
1023
- random_state : int | numpy.random.Generator, optional
1014
+ random_state: int | numpy.random.Generator, optional
1024
1015
  Seed to be used to instantiate numpy random number generator instance.
1025
1016
  Defaults to None.
1026
1017
 
@@ -1055,7 +1046,7 @@ class Mixup(Transform):
1055
1046
 
1056
1047
  Returns
1057
1048
  -------
1058
- params : dict
1049
+ params: dict
1059
1050
  Contains the values sampled uniformly between 0 and 1 setting the
1060
1051
  linear interpolation between examples (lam) and the shuffled
1061
1052
  indices of examples that are mixed into original examples
@@ -1100,7 +1091,7 @@ class SegmentationReconstruction(Transform):
1100
1091
  ----------
1101
1092
  probability : float
1102
1093
  Float setting the probability of applying the operation.
1103
- random_state : int | numpy.random.Generator, optional
1094
+ random_state: int | numpy.random.Generator, optional
1104
1095
  Seed to be used to instantiate numpy random number generator instance.
1105
1096
  Used to decide whether to transform given the probability
1106
1097
  argument and to sample the segments mixing. Defaults to None.
@@ -1139,7 +1130,6 @@ class SegmentationReconstruction(Transform):
1139
1130
  The data.
1140
1131
  y : tensor.Tensor
1141
1132
  The labels.
1142
-
1143
1133
  Returns
1144
1134
  -------
1145
1135
  params : dict
@@ -1208,12 +1198,12 @@ class MaskEncoding(Transform):
1208
1198
  ----------
1209
1199
  probability : float
1210
1200
  Float setting the probability of applying the operation.
1211
- max_mask_ratio : float, optional
1201
+ max_mask_ratio: float, optional
1212
1202
  Signal ratio to zero out. Defaults to 0.1.
1213
1203
  n_segments : int, optional
1214
1204
  Number of segments to zero out in each example.
1215
1205
  Defaults to 1.
1216
- random_state : int | numpy.random.Generator, optional
1206
+ random_state: int | numpy.random.Generator, optional
1217
1207
  Seed to be used to instantiate numpy random number generator instance.
1218
1208
  Defaults to None.
1219
1209
 
@@ -1257,7 +1247,6 @@ class MaskEncoding(Transform):
1257
1247
  The data.
1258
1248
  y : tensor.Tensor
1259
1249
  The labels.
1260
-
1261
1250
  Returns
1262
1251
  -------
1263
1252
  params : dict
@@ -1294,9 +1283,9 @@ class ChannelsReref(Transform):
1294
1283
 
1295
1284
  Parameters
1296
1285
  ----------
1297
- probability : float
1286
+ probability: float
1298
1287
  Float setting the probability of applying the operation.
1299
- random_state : int | numpy.random.Generator, optional
1288
+ random_state: int | numpy.random.Generator, optional
1300
1289
  Seed to be used to instantiate numpy random number generator instance.
1301
1290
  Used to decide whether or not to transform given the probability
1302
1291
  argument, to sample which channels to shuffle and to carry the shuffle.
@@ -1308,6 +1297,7 @@ class ChannelsReref(Transform):
1308
1297
  Representation Learning for Electroencephalogram Classification. Proceedings
1309
1298
  of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
1310
1299
  Learning Research 136:238-253 Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
1300
+
1311
1301
  """
1312
1302
 
1313
1303
  operation = staticmethod(channels_rereference) # type: ignore[assignment]
@@ -1316,7 +1306,7 @@ class ChannelsReref(Transform):
1316
1306
  super().__init__(probability=probability, random_state=random_state)
1317
1307
 
1318
1308
  def get_augmentation_params(self, *batch):
1319
- """Return transform parameters."""
1309
+ """Return transform parameters"""
1320
1310
  return {
1321
1311
  "random_state": self.rng,
1322
1312
  }
@@ -1329,9 +1319,9 @@ class AmplitudeScale(Transform):
1329
1319
 
1330
1320
  Parameters
1331
1321
  ----------
1332
- probability : float
1322
+ probability: float
1333
1323
  Float setting the probability of applying the operation.
1334
- random_state : int | numpy.random.Generator, optional
1324
+ random_state: int | numpy.random.Generator, optional
1335
1325
  Seed to be used to instantiate numpy random number generator instance.
1336
1326
  Used to decide whether or not to transform given the probability
1337
1327
  argument, to sample which channels to shuffle and to carry the shuffle.
@@ -1343,6 +1333,7 @@ class AmplitudeScale(Transform):
1343
1333
  Representation Learning for Electroencephalogram Classification. Proceedings
1344
1334
  of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
1345
1335
  Learning Research 136:238-253 Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
1336
+
1346
1337
  """
1347
1338
 
1348
1339
  operation = staticmethod(amplitude_scale) # type: ignore[assignment]
@@ -1352,5 +1343,5 @@ class AmplitudeScale(Transform):
1352
1343
  self.scale = interval
1353
1344
 
1354
1345
  def get_augmentation_params(self, *batch):
1355
- """Return transform parameters."""
1346
+ """Return transform parameters"""
1356
1347
  return {"random_state": self.rng, "scale": self.scale}
braindecode/classifier.py CHANGED
@@ -8,6 +8,7 @@
8
8
 
9
9
  import warnings
10
10
 
11
+ import numpy as np
11
12
  from skorch import NeuralNet
12
13
  from skorch.callbacks import EpochScoring
13
14
  from skorch.classifier import NeuralNetClassifier
@@ -94,8 +95,7 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
94
95
  return iterator
95
96
 
96
97
  def predict_proba(self, X):
97
- """Return the output of the module's forward method as a numpy.
98
-
98
+ """Return the output of the module's forward method as a numpy
99
99
  array. In case of cropped decoding returns averaged values for
100
100
  each trial.
101
101
 
@@ -125,6 +125,7 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
125
125
  Returns
126
126
  -------
127
127
  y_proba : numpy ndarray
128
+
128
129
  """
129
130
  y_pred = super().predict_proba(X)
130
131
  # Normally, we have to average the predictions across crops/timesteps
@@ -191,28 +192,28 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
191
192
  Returns
192
193
  -------
193
194
  y_pred : numpy ndarray
195
+
194
196
  """
195
197
  return self.predict_proba(X).argmax(1)
196
198
 
197
199
  def predict_trials(self, X, return_targets=True):
198
- """Create trialwise predictions and optionally also return trialwise.
199
-
200
+ """Create trialwise predictions and optionally also return trialwise
200
201
  labels from cropped dataset.
201
202
 
202
203
  Parameters
203
204
  ----------
204
- X : braindecode.datasets.BaseConcatDataset
205
+ X: braindecode.datasets.BaseConcatDataset
205
206
  A braindecode dataset to be predicted.
206
- return_targets : bool
207
+ return_targets: bool
207
208
  If True, additionally returns the trial targets.
208
209
 
209
210
  Returns
210
211
  -------
211
- trial_predictions : np.ndarray
212
+ trial_predictions: np.ndarray
212
213
  3-dimensional array (n_trials x n_classes x n_predictions), where
213
214
  the number of predictions depend on the chosen window size and the
214
215
  receptive field of the network.
215
- trial_labels : np.ndarray
216
+ trial_labels: np.ndarray
216
217
  2-dimensional array (n_trials x n_targets) where the number of
217
218
  targets depends on the decoding paradigm and can be either a single
218
219
  value, multiple values, or a sequence.
@@ -236,9 +237,13 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
236
237
  num_workers=self.get_iterator(X, training=False).loader.num_workers,
237
238
  )
238
239
 
239
- @property
240
- def mode(self):
241
- return "classification"
240
+ def _get_n_outputs(self, y, classes):
241
+ classes_y = np.unique(y)
242
+ if classes is not None:
243
+ assert set(classes_y) <= set(classes)
244
+ else:
245
+ classes = classes_y
246
+ return len(classes)
242
247
 
243
248
  # Only add the 'accuracy' callback if we are not in cropped mode.
244
249
  @property
@@ -1,4 +1,6 @@
1
- """Loader code for some datasets."""
1
+ """
2
+ Loader code for some datasets.
3
+ """
2
4
 
3
5
  from .base import (
4
6
  BaseConcatDataset,
@@ -9,11 +11,9 @@ from .base import (
9
11
  )
10
12
  from .bcicomp import BCICompetitionIVDataset4
11
13
  from .bids import BIDSDataset, BIDSEpochsDataset
12
- from .chb_mit import CHBMIT
13
14
  from .mne import create_from_mne_epochs, create_from_mne_raw
14
15
  from .moabb import BNCI2014_001, HGD, MOABBDataset
15
16
  from .nmt import NMT
16
- from .siena import SIENA
17
17
  from .sleep_physio_challe_18 import SleepPhysionetChallenge2018
18
18
  from .sleep_physionet import SleepPhysionet
19
19
  from .tuh import TUH, TUHAbnormal
@@ -34,9 +34,7 @@ __all__ = [
34
34
  "create_from_mne_epochs",
35
35
  "TUH",
36
36
  "TUHAbnormal",
37
- "SIENA",
38
37
  "NMT",
39
- "CHBMIT",
40
38
  "SleepPhysionet",
41
39
  "SleepPhysionetChallenge2018",
42
40
  "create_from_X_y",