braindecode 1.2.0.dev184328194__py3-none-any.whl → 1.3.0.dev168011974__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (39) hide show
  1. braindecode/augmentation/functional.py +154 -54
  2. braindecode/augmentation/transforms.py +2 -2
  3. braindecode/datasets/experimental.py +218 -0
  4. braindecode/datasets/sleep_physio_challe_18.py +2 -1
  5. braindecode/datautil/serialization.py +11 -6
  6. braindecode/models/__init__.py +6 -8
  7. braindecode/models/atcnet.py +156 -16
  8. braindecode/models/attentionbasenet.py +151 -26
  9. braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
  10. braindecode/models/biot.py +1 -1
  11. braindecode/models/ctnet.py +1 -1
  12. braindecode/models/deep4.py +6 -2
  13. braindecode/models/deepsleepnet.py +118 -5
  14. braindecode/models/eegconformer.py +114 -15
  15. braindecode/models/eeginception_erp.py +76 -7
  16. braindecode/models/eeginception_mi.py +2 -0
  17. braindecode/models/eegnet.py +64 -177
  18. braindecode/models/eegnex.py +113 -6
  19. braindecode/models/eegsimpleconv.py +2 -0
  20. braindecode/models/eegtcnet.py +1 -1
  21. braindecode/models/sccnet.py +81 -8
  22. braindecode/models/shallow_fbcsp.py +2 -0
  23. braindecode/models/sleep_stager_blanco_2020.py +2 -0
  24. braindecode/models/sleep_stager_chambon_2018.py +2 -0
  25. braindecode/models/sparcnet.py +2 -0
  26. braindecode/models/summary.csv +39 -41
  27. braindecode/models/tidnet.py +2 -0
  28. braindecode/models/tsinception.py +15 -3
  29. braindecode/models/usleep.py +103 -9
  30. braindecode/models/util.py +5 -5
  31. braindecode/preprocessing/preprocess.py +31 -28
  32. braindecode/version.py +1 -1
  33. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev168011974.dist-info}/METADATA +7 -2
  34. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev168011974.dist-info}/RECORD +38 -38
  35. braindecode/models/eegresnet.py +0 -362
  36. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev168011974.dist-info}/WHEEL +0 -0
  37. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev168011974.dist-info}/licenses/LICENSE.txt +0 -0
  38. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev168011974.dist-info}/licenses/NOTICE.txt +0 -0
  39. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev168011974.dist-info}/top_level.txt +0 -0
@@ -4,9 +4,13 @@
4
4
  #
5
5
  # License: BSD (3-clause)
6
6
 
7
+ from __future__ import annotations
8
+
7
9
  from numbers import Real
10
+ from typing import Literal
8
11
 
9
12
  import numpy as np
13
+ import numpy.typing as npt
10
14
  import torch
11
15
  from mne.filter import notch_filter
12
16
  from scipy.interpolate import Rbf
@@ -15,7 +19,7 @@ from torch.fft import fft, ifft
15
19
  from torch.nn.functional import one_hot, pad
16
20
 
17
21
 
18
- def identity(X, y):
22
+ def identity(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
19
23
  """Identity operation.
20
24
 
21
25
  Parameters
@@ -35,7 +39,7 @@ def identity(X, y):
35
39
  return X, y
36
40
 
37
41
 
38
- def time_reverse(X, y):
42
+ def time_reverse(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
39
43
  """Flip the time axis of each input.
40
44
 
41
45
  Parameters
@@ -55,7 +59,7 @@ def time_reverse(X, y):
55
59
  return torch.flip(X, [-1]), y
56
60
 
57
61
 
58
- def sign_flip(X, y):
62
+ def sign_flip(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
59
63
  """Flip the sign axis of each input.
60
64
 
61
65
  Parameters
@@ -75,7 +79,13 @@ def sign_flip(X, y):
75
79
  return -X, y
76
80
 
77
81
 
78
- def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
82
+ def _new_random_fft_phase_odd(
83
+ batch_size: int,
84
+ c: int,
85
+ n: int,
86
+ device: torch.device,
87
+ random_state: int | np.random.RandomState | None,
88
+ ) -> torch.Tensor:
79
89
  rng = check_random_state(random_state)
80
90
  random_phase = torch.from_numpy(
81
91
  2j * np.pi * rng.random((batch_size, c, (n - 1) // 2))
@@ -90,7 +100,13 @@ def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
90
100
  )
91
101
 
92
102
 
93
- def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
103
+ def _new_random_fft_phase_even(
104
+ batch_size: int,
105
+ c: int,
106
+ n: int,
107
+ device: torch.device,
108
+ random_state: int | np.random.RandomState | None,
109
+ ) -> torch.Tensor:
94
110
  rng = check_random_state(random_state)
95
111
  random_phase = torch.from_numpy(
96
112
  2j * np.pi * rng.random((batch_size, c, n // 2 - 1))
@@ -109,7 +125,13 @@ def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
109
125
  _new_random_fft_phase = {0: _new_random_fft_phase_even, 1: _new_random_fft_phase_odd}
110
126
 
111
127
 
112
- def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
128
+ def ft_surrogate(
129
+ X: torch.Tensor,
130
+ y: torch.Tensor,
131
+ phase_noise_magnitude: float,
132
+ channel_indep: bool,
133
+ random_state: int | np.random.RandomState | None = None,
134
+ ) -> tuple[torch.Tensor, torch.Tensor]:
113
135
  """FT surrogate augmentation of a single EEG channel, as proposed in [1]_.
114
136
 
115
137
  Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
@@ -175,7 +197,9 @@ def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
175
197
  return transformed_X, y
176
198
 
177
199
 
178
- def _pick_channels_randomly(X, p_pick, random_state):
200
+ def _pick_channels_randomly(
201
+ X: torch.Tensor, p_pick: float, random_state: int | np.random.RandomState | None
202
+ ) -> torch.Tensor:
179
203
  rng = check_random_state(random_state)
180
204
  batch_size, n_channels, _ = X.shape
181
205
  # allows to use the same RNG
@@ -188,7 +212,12 @@ def _pick_channels_randomly(X, p_pick, random_state):
188
212
  return torch.sigmoid(1000 * (unif_samples - p_pick))
189
213
 
190
214
 
191
- def channels_dropout(X, y, p_drop, random_state=None):
215
+ def channels_dropout(
216
+ X: torch.Tensor,
217
+ y: torch.Tensor,
218
+ p_drop: float,
219
+ random_state: int | np.random.RandomState | None = None,
220
+ ) -> tuple[torch.Tensor, torch.Tensor]:
192
221
  """Randomly set channels to flat signal.
193
222
 
194
223
  Part of the CMSAugment policy proposed in [1]_
@@ -222,7 +251,9 @@ def channels_dropout(X, y, p_drop, random_state=None):
222
251
  return X * mask.unsqueeze(-1), y
223
252
 
224
253
 
225
- def _make_permutation_matrix(X, mask, random_state):
254
+ def _make_permutation_matrix(
255
+ X: torch.Tensor, mask: torch.Tensor, random_state: int | np.random.Generator | None
256
+ ) -> torch.Tensor:
226
257
  rng = check_random_state(random_state)
227
258
  batch_size, n_channels, _ = X.shape
228
259
  hard_mask = mask.round()
@@ -241,7 +272,12 @@ def _make_permutation_matrix(X, mask, random_state):
241
272
  return batch_permutations
242
273
 
243
274
 
244
- def channels_shuffle(X, y, p_shuffle, random_state=None):
275
+ def channels_shuffle(
276
+ X: torch.Tensor,
277
+ y: torch.Tensor,
278
+ p_shuffle: float,
279
+ random_state: int | np.random.RandomState | None = None,
280
+ ) -> tuple[torch.Tensor, torch.Tensor]:
245
281
  """Randomly shuffle channels in EEG data matrix.
246
282
 
247
283
  Part of the CMSAugment policy proposed in [1]_
@@ -280,7 +316,12 @@ def channels_shuffle(X, y, p_shuffle, random_state=None):
280
316
  return torch.matmul(batch_permutations, X), y
281
317
 
282
318
 
283
- def gaussian_noise(X, y, std, random_state=None):
319
+ def gaussian_noise(
320
+ X: torch.Tensor,
321
+ y: torch.Tensor,
322
+ std: float,
323
+ random_state: int | np.random.RandomState | None = None,
324
+ ) -> tuple[torch.Tensor, torch.Tensor]:
284
325
  """Randomly add white Gaussian noise to all channels.
285
326
 
286
327
  Suggested e.g. in [1]_, [2]_ and [3]_
@@ -332,7 +373,9 @@ def gaussian_noise(X, y, std, random_state=None):
332
373
  return transformed_X, y
333
374
 
334
375
 
335
- def channels_permute(X, y, permutation):
376
+ def channels_permute(
377
+ X: torch.Tensor, y: torch.Tensor, permutation: list[int]
378
+ ) -> tuple[torch.Tensor, torch.Tensor]:
336
379
  """Permute EEG channels according to fixed permutation matrix.
337
380
 
338
381
  Suggested e.g. in [1]_
@@ -362,7 +405,12 @@ def channels_permute(X, y, permutation):
362
405
  return X[..., permutation, :], y
363
406
 
364
407
 
365
- def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
408
+ def smooth_time_mask(
409
+ X: torch.Tensor,
410
+ y: torch.Tensor,
411
+ mask_start_per_sample: torch.Tensor,
412
+ mask_len_samples: int,
413
+ ) -> tuple[torch.Tensor, torch.Tensor]:
366
414
  """Smoothly replace a contiguous part of all channels by zeros.
367
415
 
368
416
  Originally proposed in [1]_ and [2]_
@@ -412,7 +460,13 @@ def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
412
460
  return X * mask, y
413
461
 
414
462
 
415
- def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
463
+ def bandstop_filter(
464
+ X: torch.Tensor,
465
+ y: torch.Tensor,
466
+ sfreq: float,
467
+ bandwidth: float,
468
+ freqs_to_notch: npt.ArrayLike | None,
469
+ ) -> tuple[torch.Tensor, torch.Tensor]:
416
470
  """Apply a band-stop filter with desired bandwidth at the desired frequency
417
471
  position.
418
472
 
@@ -451,7 +505,7 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
451
505
  Representation Learning for Electroencephalogram Classification. In
452
506
  Machine Learning for Health (pp. 238-253). PMLR.
453
507
  """
454
- if bandwidth == 0:
508
+ if bandwidth == 0 or freqs_to_notch is None:
455
509
  return X, y
456
510
  transformed_X = X.clone()
457
511
  for c, (sample, notched_freq) in enumerate(zip(transformed_X, freqs_to_notch)):
@@ -469,7 +523,7 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
469
523
  return transformed_X, y
470
524
 
471
525
 
472
- def _analytic_transform(x):
526
+ def _analytic_transform(x: torch.Tensor) -> torch.Tensor:
473
527
  if torch.is_complex(x):
474
528
  raise ValueError("x must be real.")
475
529
 
@@ -486,12 +540,12 @@ def _analytic_transform(x):
486
540
  return ifft(f * h, dim=-1)
487
541
 
488
542
 
489
- def _nextpow2(n):
543
+ def _nextpow2(n: int) -> int:
490
544
  """Return the first integer N such that 2**N >= abs(n)."""
491
545
  return int(np.ceil(np.log2(np.abs(n))))
492
546
 
493
547
 
494
- def _frequency_shift(X, fs, f_shift):
548
+ def _frequency_shift(X: torch.Tensor, fs: float, f_shift: float) -> torch.Tensor:
495
549
  """
496
550
  Shift the specified signal by the specified frequency.
497
551
 
@@ -504,9 +558,13 @@ def _frequency_shift(X, fs, f_shift):
504
558
  t = torch.arange(N_padded, device=X.device) / fs
505
559
  padded = pad(X, (0, N_padded - N_orig))
506
560
  analytical = _analytic_transform(padded)
507
- if isinstance(f_shift, (float, int, np.ndarray, list)):
508
- f_shift = torch.as_tensor(f_shift).float()
509
- f_shift_stack = f_shift.repeat(N_padded, n_channels, 1)
561
+ if isinstance(f_shift, torch.Tensor):
562
+ _f_shift = f_shift
563
+ elif isinstance(f_shift, (float, int, np.ndarray, list)):
564
+ _f_shift = torch.as_tensor(f_shift).float()
565
+ else:
566
+ raise ValueError(f"Invalid f_shift type: {type(f_shift)}")
567
+ f_shift_stack = _f_shift.repeat(N_padded, n_channels, 1)
510
568
  reshaped_f_shift = f_shift_stack.permute(
511
569
  *torch.arange(f_shift_stack.ndim - 1, -1, -1)
512
570
  )
@@ -514,7 +572,9 @@ def _frequency_shift(X, fs, f_shift):
514
572
  return shifted[..., :N_orig].real.float()
515
573
 
516
574
 
517
- def frequency_shift(X, y, delta_freq, sfreq):
575
+ def frequency_shift(
576
+ X: torch.Tensor, y: torch.Tensor, delta_freq: float, sfreq: float
577
+ ) -> tuple[torch.Tensor, torch.Tensor]:
518
578
  """Adds a shift in the frequency domain to all channels.
519
579
 
520
580
  Note that here, the shift is the same for all channels of a single example.
@@ -545,7 +605,7 @@ def frequency_shift(X, y, delta_freq, sfreq):
545
605
  return transformed_X, y
546
606
 
547
607
 
548
- def _torch_normalize_vectors(rr):
608
+ def _torch_normalize_vectors(rr: torch.Tensor) -> torch.Tensor:
549
609
  """Normalize surface vertices."""
550
610
  norm = torch.linalg.norm(rr, axis=1, keepdim=True)
551
611
  mask = norm > 0
@@ -554,7 +614,9 @@ def _torch_normalize_vectors(rr):
554
614
  return new_rr
555
615
 
556
616
 
557
- def _torch_legval(x, c, tensor=True):
617
+ def _torch_legval(
618
+ x: torch.Tensor, c: torch.Tensor, tensor: bool = True
619
+ ) -> torch.Tensor:
558
620
  """
559
621
  Evaluate a Legendre series at points x.
560
622
  If `c` is of length `n + 1`, this function returns the value:
@@ -662,7 +724,9 @@ def _torch_legval(x, c, tensor=True):
662
724
  return c0 + c1 * x
663
725
 
664
726
 
665
- def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
727
+ def _torch_calc_g(
728
+ cosang: torch.Tensor, stiffness: float = 4, n_legendre_terms: int = 50
729
+ ) -> torch.Tensor:
666
730
  """Calculate spherical spline g function between points on a sphere.
667
731
 
668
732
  Parameters
@@ -718,23 +782,25 @@ def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
718
782
  return _torch_legval(cosang, [0] + factors)
719
783
 
720
784
 
721
- def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
785
+ def _torch_make_interpolation_matrix(
786
+ pos_from: torch.Tensor, pos_to: torch.Tensor, alpha: float = 1e-5
787
+ ) -> torch.Tensor:
722
788
  """Compute interpolation matrix based on spherical splines.
723
789
 
724
790
  Implementation based on [1]_
725
791
 
726
792
  Parameters
727
793
  ----------
728
- pos_from : np.ndarray of float, shape(n_good_sensors, 3)
794
+ pos_from : torch.Tensor of float, shape(n_good_sensors, 3)
729
795
  The positions to interpolate from.
730
- pos_to : np.ndarray of float, shape(n_bad_sensors, 3)
796
+ pos_to : torch.Tensor of float, shape(n_bad_sensors, 3)
731
797
  The positions to interpolate.
732
798
  alpha : float
733
799
  Regularization parameter. Defaults to 1e-5.
734
800
 
735
801
  Returns
736
802
  -------
737
- interpolation : np.ndarray of float, shape(len(pos_from), len(pos_to))
803
+ interpolation : torch.Tensor of float, shape(len(pos_from), len(pos_to))
738
804
  The interpolation matrix that maps good signals to the location
739
805
  of bad signals.
740
806
 
@@ -822,7 +888,12 @@ def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
822
888
  return interpolation
823
889
 
824
890
 
825
- def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
891
+ def _rotate_signals(
892
+ X: torch.Tensor,
893
+ rotations: list[torch.Tensor],
894
+ sensors_positions_matrix: torch.Tensor,
895
+ spherical: bool = True,
896
+ ) -> torch.Tensor:
826
897
  sensors_positions_matrix = sensors_positions_matrix.to(X.device)
827
898
  rot_sensors_matrices = [
828
899
  rotation.matmul(sensors_positions_matrix) for rotation in rotations
@@ -853,22 +924,29 @@ def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
853
924
  return transformed_X
854
925
 
855
926
 
856
- def _make_rotation_matrix(axis, angle, degrees=True):
927
+ def _make_rotation_matrix(
928
+ axis: Literal["x", "y", "z"],
929
+ angle: float | int | np.ndarray | list | torch.Tensor,
930
+ degrees: bool = True,
931
+ ) -> torch.Tensor:
857
932
  assert axis in ["x", "y", "z"], "axis should be either x, y or z."
858
-
859
- if isinstance(angle, (float, int, np.ndarray, list)):
860
- angle = torch.as_tensor(angle)
933
+ if isinstance(angle, torch.Tensor):
934
+ _angle = angle
935
+ elif isinstance(angle, (float, int, np.ndarray, list)):
936
+ _angle = torch.as_tensor(angle)
937
+ else:
938
+ raise ValueError(f"Invalid angle type: {type(angle)}")
861
939
 
862
940
  if degrees:
863
- angle = angle * np.pi / 180
941
+ _angle = _angle * np.pi / 180
864
942
 
865
- device = angle.device
943
+ device = _angle.device
866
944
  zero = torch.zeros(1, device=device)
867
945
  rot = torch.stack(
868
946
  [
869
947
  torch.as_tensor([1, 0, 0], device=device),
870
- torch.hstack([zero, torch.cos(angle), -torch.sin(angle)]),
871
- torch.hstack([zero, torch.sin(angle), torch.cos(angle)]),
948
+ torch.hstack([zero, torch.cos(_angle), -torch.sin(_angle)]),
949
+ torch.hstack([zero, torch.sin(_angle), torch.cos(_angle)]),
872
950
  ]
873
951
  )
874
952
  if axis == "x":
@@ -881,7 +959,14 @@ def _make_rotation_matrix(axis, angle, degrees=True):
881
959
  return rot[:, [1, 2, 0]]
882
960
 
883
961
 
884
- def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_splines):
962
+ def sensors_rotation(
963
+ X: torch.Tensor,
964
+ y: torch.Tensor,
965
+ sensors_positions_matrix: torch.Tensor,
966
+ axis: Literal["x", "y", "z"],
967
+ angles: npt.ArrayLike,
968
+ spherical_splines: bool,
969
+ ) -> tuple[torch.Tensor, torch.Tensor]:
885
970
  """Interpolates EEG signals over sensors rotated around the desired axis
886
971
  with the desired angle.
887
972
 
@@ -893,7 +978,7 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_spl
893
978
  EEG input example or batch.
894
979
  y : torch.Tensor
895
980
  EEG labels for the example or batch.
896
- sensors_positions_matrix : numpy.ndarray
981
+ sensors_positions_matrix : torch.Tensor
897
982
  Matrix giving the positions of each sensor in a 3D cartesian coordinate
898
983
  system. Should have shape (3, n_channels), where n_channels is the
899
984
  number of channels. Standard 10-20 positions can be obtained from
@@ -924,7 +1009,9 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_spl
924
1009
  return rotated_X, y
925
1010
 
926
1011
 
927
- def mixup(X, y, lam, idx_perm):
1012
+ def mixup(
1013
+ X: torch.Tensor, y: torch.Tensor, lam: torch.Tensor, idx_perm: torch.Tensor
1014
+ ) -> tuple[torch.Tensor, torch.Tensor]:
928
1015
  """Mixes two channels of EEG data.
929
1016
 
930
1017
  See [1]_ for details.
@@ -973,8 +1060,13 @@ def mixup(X, y, lam, idx_perm):
973
1060
 
974
1061
 
975
1062
  def segmentation_reconstruction(
976
- X, y, n_segments, data_classes, rand_indices, idx_shuffle
977
- ):
1063
+ X: torch.Tensor,
1064
+ y: torch.Tensor,
1065
+ n_segments: int,
1066
+ data_classes: list[tuple[int, torch.Tensor]],
1067
+ rand_indices: npt.ArrayLike,
1068
+ idx_shuffle: npt.ArrayLike,
1069
+ ) -> tuple[torch.Tensor, torch.Tensor]:
978
1070
  """Segment and reconstruct EEG data from [1]_.
979
1071
 
980
1072
  See [1]_ for details.
@@ -987,6 +1079,8 @@ def segmentation_reconstruction(
987
1079
  EEG labels for the example or batch.
988
1080
  n_segments : int
989
1081
  Number of segments to use in the batch.
1082
+ data_classes: list[tuple[int, torch.Tensor]]
1083
+ List of tuples. Each tuple contains the class index and the corresponding EEG data.
990
1084
  rand_indices: array-like
991
1085
  Array of indices that indicates which trial to use in each segment.
992
1086
  idx_shuffle: array-like
@@ -1005,8 +1099,8 @@ def segmentation_reconstruction(
1005
1099
  """
1006
1100
 
1007
1101
  # Initialize lists to store augmented data and corresponding labels
1008
- aug_data = []
1009
- aug_label = []
1102
+ aug_data: list[torch.Tensor] = []
1103
+ aug_label: list[torch.Tensor] = []
1010
1104
 
1011
1105
  # Iterate through each class to separate and augment data
1012
1106
  for class_index, X_class in data_classes:
@@ -1030,20 +1124,26 @@ def segmentation_reconstruction(
1030
1124
  aug_data.append(X_aug)
1031
1125
  aug_label.append(torch.full((n_trials,), class_index))
1032
1126
  # Concatenate the augmented data and labels
1033
- aug_data = torch.cat(aug_data, dim=0)
1034
- aug_data = aug_data.to(dtype=X.dtype, device=X.device)
1035
- aug_data = aug_data[idx_shuffle]
1127
+ concat_aug_data = torch.cat(aug_data, dim=0)
1128
+ concat_aug_data = concat_aug_data.to(dtype=X.dtype, device=X.device)
1129
+ concat_aug_data = concat_aug_data[idx_shuffle]
1036
1130
 
1037
1131
  if y is not None:
1038
- aug_label = torch.cat(aug_label, dim=0)
1039
- aug_label = aug_label.to(dtype=y.dtype, device=y.device)
1040
- aug_label = aug_label[idx_shuffle]
1041
- return aug_data, aug_label
1132
+ concat_label = torch.cat(aug_label, dim=0)
1133
+ concat_label = concat_label.to(dtype=y.dtype, device=y.device)
1134
+ concat_label = concat_label[idx_shuffle]
1135
+ return concat_aug_data, concat_label
1042
1136
 
1043
- return aug_data, y
1137
+ return concat_aug_data, None
1044
1138
 
1045
1139
 
1046
- def mask_encoding(X, y, time_start, segment_length, n_segments):
1140
+ def mask_encoding(
1141
+ X: torch.Tensor,
1142
+ y: torch.Tensor,
1143
+ time_start: torch.Tensor,
1144
+ segment_length: int,
1145
+ n_segments: int,
1146
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1047
1147
  """Mask encoding from Ding et al. (2024) from [ding2024]_.
1048
1148
 
1049
1149
  Replaces a contiguous part (or parts) of all channels by zeros
@@ -161,10 +161,10 @@ class ChannelsDropout(Transform):
161
161
  ----------
162
162
  probability: float
163
163
  Float setting the probability of applying the operation.
164
- proba_drop: float | None, optional
164
+ p_drop: float | None, optional
165
165
  Float between 0 and 1 setting the probability of dropping each channel.
166
166
  Defaults to 0.2.
167
- random_state: int | numpy.random.Generator, optional
167
+ random_state: int | numpy.random.RandomState, optional
168
168
  Seed to be used to instantiate numpy random number generator instance.
169
169
  Used to decide whether or not to transform given the probability
170
170
  argument and to sample channels to erase. Defaults to None.
@@ -0,0 +1,218 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ from pathlib import Path
5
+ from typing import Callable, Sequence
6
+
7
+ import mne_bids
8
+ from torch.utils.data import IterableDataset, get_worker_info
9
+
10
+
11
class BIDSIterableDataset(IterableDataset):
    """Iterable dataset that streams samples from a set of BIDS recordings.

    .. warning::
        This class is experimental and may change in the future.

    .. warning::
        This dataset is not consistent with the Braindecode API.

    Recordings are opened lazily through ``reader_fn``. At any time at most
    ``pool_size`` recordings are kept in a pool; each step picks one pooled
    recording at random and yields one of its not-yet-yielded items. When a
    recording is exhausted it is dropped from the pool and replaced by the
    next unread recording, until every recording has been consumed.

    This class has the same parameters as the
    :func:`mne_bids.find_matching_paths` function as it will be used to find
    the files to load. The default ``extensions`` parameter was changed.

    More information on BIDS (Brain Imaging Data Structure)
    can be found at https://bids.neuroimaging.io

    Examples
    --------
    >>> import mne_bids
    >>> from braindecode.datasets import BaseDataset, BaseConcatDataset
    >>> from braindecode.datasets.bids import _description_from_bids_path
    >>> from braindecode.datasets.experimental import BIDSIterableDataset
    >>> from braindecode.preprocessing import create_fixed_length_windows
    >>>
    >>> def my_reader_fn(path):
    ...     raw = mne_bids.read_raw_bids(path)
    ...     desc = _description_from_bids_path(path)
    ...     ds = BaseDataset(raw, description=desc)
    ...     windows_ds = create_fixed_length_windows(
    ...         BaseConcatDataset([ds]),
    ...         window_size_samples=400,
    ...         window_stride_samples=200,
    ...     )
    ...     return windows_ds
    >>>
    >>> dataset = BIDSIterableDataset(
    ...     reader_fn=my_reader_fn,
    ...     root="root/of/my/bids/dataset/",
    ... )

    Parameters
    ----------
    reader_fn : Callable[[mne_bids.BIDSPath], Sequence]
        A function that takes a BIDSPath and returns a dataset (anything
        supporting ``len()`` and integer indexing). It may return ``None``
        to signal that the recording must be skipped.
    pool_size : int
        The number of recordings to read and sample from. Must be >= 1.
    bids_paths : list[mne_bids.BIDSPath] | None
        A list of BIDSPaths to load. If None, will use the paths found by
        :func:`mne_bids.find_matching_paths` and the arguments below.
        The list is copied, so later changes to it do not affect the dataset.
    root : pathlib.Path | str
        The root of the BIDS path.
    subjects : str | array-like of str | None
        The subject ID. Corresponds to "sub".
    sessions : str | array-like of str | None
        The acquisition session. Corresponds to "ses".
    tasks : str | array-like of str | None
        The experimental task. Corresponds to "task".
    acquisitions: str | array-like of str | None
        The acquisition parameters. Corresponds to "acq".
    runs : str | array-like of str | None
        The run number. Corresponds to "run".
    processings : str | array-like of str | None
        The processing label. Corresponds to "proc".
    recordings : str | array-like of str | None
        The recording name. Corresponds to "rec".
    spaces : str | array-like of str | None
        The coordinate space for anatomical and sensor location
        files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
        Corresponds to "space".
        Note that valid values for ``space`` must come from a list
        of BIDS keywords as described in the BIDS specification.
    splits : str | array-like of str | None
        The split of the continuous recording file for ``.fif`` data.
        Corresponds to "split".
    descriptions : str | array-like of str | None
        This corresponds to the BIDS entity ``desc``. It is used to provide
        additional information for derivative data, e.g., preprocessed data
        may be assigned ``description='cleaned'``.
    suffixes : str | array-like of str | None
        The filename suffix. This is the entity after the
        last ``_`` before the extension. E.g., ``'channels'``.
        The following filename suffixes are accepted:
        'meg', 'markers', 'eeg', 'ieeg', 'T1w',
        'participants', 'scans', 'electrodes', 'coordsystem',
        'channels', 'events', 'headshape', 'digitizer',
        'beh', 'physio', 'stim'
    extensions : str | array-like of str | None
        The extension of the filename. E.g., ``'.json'``.
        By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
    datatypes : str | array-like of str | None
        The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
        ``'ieeg'``.
    check : bool
        If ``True``, only returns paths that conform to BIDS. If ``False``
        (default), the ``.check`` attribute of the returned
        :class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
        do conform to BIDS, and to ``False`` for those that don't.

    Raises
    ------
    ValueError
        If ``pool_size`` is smaller than 1.
    """

    # Kept as an immutable tuple so the default argument can never be mutated
    # between calls; it is converted to a fresh list before use.
    _DEFAULT_EXTENSIONS: tuple[str, ...] = (
        ".con",
        ".sqd",
        ".pdf",
        ".fif",
        ".ds",
        ".vhdr",
        ".set",
        ".edf",
        ".bdf",
        ".EDF",
        ".snirf",
        ".cdt",
        ".mef",
        ".nwb",
    )

    def __init__(
        self,
        reader_fn: Callable[[mne_bids.BIDSPath], Sequence],
        pool_size: int = 4,
        bids_paths: list[mne_bids.BIDSPath] | None = None,
        root: Path | str | None = None,
        subjects: str | list[str] | None = None,
        sessions: str | list[str] | None = None,
        tasks: str | list[str] | None = None,
        acquisitions: str | list[str] | None = None,
        runs: str | list[str] | None = None,
        processings: str | list[str] | None = None,
        recordings: str | list[str] | None = None,
        spaces: str | list[str] | None = None,
        splits: str | list[str] | None = None,
        descriptions: str | list[str] | None = None,
        suffixes: str | list[str] | None = None,
        extensions: str | Sequence[str] | None = _DEFAULT_EXTENSIONS,
        datatypes: str | list[str] | None = None,
        check: bool = False,
    ):
        super().__init__()
        if pool_size < 1:
            # A pool that can never hold a recording cannot yield anything.
            raise ValueError(f"pool_size must be >= 1, got {pool_size}.")

        if bids_paths is None:
            if isinstance(extensions, tuple):
                # Hand mne_bids a plain list, as the original default did.
                extensions = list(extensions)
            found_paths = mne_bids.find_matching_paths(
                root=root,
                subjects=subjects,
                sessions=sessions,
                tasks=tasks,
                acquisitions=acquisitions,
                runs=runs,
                processings=processings,
                recordings=recordings,
                spaces=spaces,
                splits=splits,
                descriptions=descriptions,
                suffixes=suffixes,
                extensions=extensions,
                datatypes=datatypes,
                check=check,
                ignore_json=True,
            )
            # Filter out _epo.fif files:
            bids_paths = [
                bids_path
                for bids_path in found_paths
                if not (bids_path.suffix == "epo" and bids_path.extension == ".fif")
            ]

        # Copy so that ``+=`` on this dataset never mutates the caller's list.
        self.bids_paths = list(bids_paths)
        self.reader_fn = reader_fn
        self.pool_size = pool_size

    def __add__(self, other):
        """Return a new dataset over the paths of both operands.

        The result reuses ``self.reader_fn`` and ``self.pool_size``; the
        reader and pool size of ``other`` are ignored.
        """
        assert isinstance(other, BIDSIterableDataset)
        return BIDSIterableDataset(
            reader_fn=self.reader_fn,
            bids_paths=self.bids_paths + other.bids_paths,
            pool_size=self.pool_size,
        )

    def __iadd__(self, other):
        """Append the paths of ``other`` to this dataset in place."""
        assert isinstance(other, BIDSIterableDataset)
        self.bids_paths += other.bids_paths
        return self

    def __iter__(self):
        """Yield items from all recordings in a pooled, randomized order.

        In a ``DataLoader`` worker process, only every ``num_workers``-th
        path (offset by the worker id) is read, so workers do not overlap.
        """
        worker_info = get_worker_info()
        if worker_info is None:  # single-process data loading, use all paths
            bids_paths = self.bids_paths
        else:  # in a worker process: split workload
            bids_paths = self.bids_paths[worker_info.id :: worker_info.num_workers]

        pool = []  # entries are (dataset, iterator over its shuffled indices)
        paths_exhausted = False
        paths_it = iter(random.sample(bids_paths, k=len(bids_paths)))
        while True:
            # Refill the pool up to ``pool_size`` open recordings.
            while not paths_exhausted and len(pool) < self.pool_size:
                # Keep the ``try`` minimal: a StopIteration leaking out of
                # ``reader_fn`` must not be mistaken for "no more paths".
                try:
                    bids_path = next(paths_it)
                except StopIteration:
                    paths_exhausted = True
                    break
                ds = self.reader_fn(bids_path)
                if ds is None:
                    print(f"Skipping {bids_path} as it is too short.")
                    continue
                idx = iter(random.sample(range(len(ds)), k=len(ds)))
                pool.append((ds, idx))

            if not pool:
                # Nothing left to sample from. Without this guard,
                # ``random.randint(0, -1)`` below raises ValueError whenever
                # the pool drains exactly when the paths run out.
                return

            i_pool = random.randint(0, len(pool) - 1)
            ds, idx = pool[i_pool]
            try:
                i_ds = next(idx)
            except StopIteration:
                # This recording is fully consumed; free its pool slot.
                pool.pop(i_pool)
                continue
            yield ds[i_ds]