braindecode 1.3.0.dev174777731__py3-none-any.whl → 1.3.0.dev175415232__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

@@ -4,9 +4,13 @@
4
4
  #
5
5
  # License: BSD (3-clause)
6
6
 
7
+ from __future__ import annotations
8
+
7
9
  from numbers import Real
10
+ from typing import Literal
8
11
 
9
12
  import numpy as np
13
+ import numpy.typing as npt
10
14
  import torch
11
15
  from mne.filter import notch_filter
12
16
  from scipy.interpolate import Rbf
@@ -15,7 +19,7 @@ from torch.fft import fft, ifft
15
19
  from torch.nn.functional import one_hot, pad
16
20
 
17
21
 
18
- def identity(X, y):
22
+ def identity(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
19
23
  """Identity operation.
20
24
 
21
25
  Parameters
@@ -35,7 +39,7 @@ def identity(X, y):
35
39
  return X, y
36
40
 
37
41
 
38
- def time_reverse(X, y):
42
+ def time_reverse(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
39
43
  """Flip the time axis of each input.
40
44
 
41
45
  Parameters
@@ -55,7 +59,7 @@ def time_reverse(X, y):
55
59
  return torch.flip(X, [-1]), y
56
60
 
57
61
 
58
- def sign_flip(X, y):
62
+ def sign_flip(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
59
63
  """Flip the sign axis of each input.
60
64
 
61
65
  Parameters
@@ -75,7 +79,13 @@ def sign_flip(X, y):
75
79
  return -X, y
76
80
 
77
81
 
78
- def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
82
+ def _new_random_fft_phase_odd(
83
+ batch_size: int,
84
+ c: int,
85
+ n: int,
86
+ device: torch.device,
87
+ random_state: int | np.random.RandomState | None,
88
+ ) -> torch.Tensor:
79
89
  rng = check_random_state(random_state)
80
90
  random_phase = torch.from_numpy(
81
91
  2j * np.pi * rng.random((batch_size, c, (n - 1) // 2))
@@ -90,7 +100,13 @@ def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
90
100
  )
91
101
 
92
102
 
93
- def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
103
+ def _new_random_fft_phase_even(
104
+ batch_size: int,
105
+ c: int,
106
+ n: int,
107
+ device: torch.device,
108
+ random_state: int | np.random.RandomState | None,
109
+ ) -> torch.Tensor:
94
110
  rng = check_random_state(random_state)
95
111
  random_phase = torch.from_numpy(
96
112
  2j * np.pi * rng.random((batch_size, c, n // 2 - 1))
@@ -109,7 +125,13 @@ def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
109
125
  _new_random_fft_phase = {0: _new_random_fft_phase_even, 1: _new_random_fft_phase_odd}
110
126
 
111
127
 
112
- def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
128
+ def ft_surrogate(
129
+ X: torch.Tensor,
130
+ y: torch.Tensor,
131
+ phase_noise_magnitude: float,
132
+ channel_indep: bool,
133
+ random_state: int | np.random.RandomState | None = None,
134
+ ) -> tuple[torch.Tensor, torch.Tensor]:
113
135
  """FT surrogate augmentation of a single EEG channel, as proposed in [1]_.
114
136
 
115
137
  Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
@@ -175,7 +197,9 @@ def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
175
197
  return transformed_X, y
176
198
 
177
199
 
178
- def _pick_channels_randomly(X, p_pick, random_state):
200
+ def _pick_channels_randomly(
201
+ X: torch.Tensor, p_pick: float, random_state: int | np.random.RandomState | None
202
+ ) -> torch.Tensor:
179
203
  rng = check_random_state(random_state)
180
204
  batch_size, n_channels, _ = X.shape
181
205
  # allows to use the same RNG
@@ -188,7 +212,12 @@ def _pick_channels_randomly(X, p_pick, random_state):
188
212
  return torch.sigmoid(1000 * (unif_samples - p_pick))
189
213
 
190
214
 
191
- def channels_dropout(X, y, p_drop, random_state=None):
215
+ def channels_dropout(
216
+ X: torch.Tensor,
217
+ y: torch.Tensor,
218
+ p_drop: float,
219
+ random_state: int | np.random.RandomState | None = None,
220
+ ) -> tuple[torch.Tensor, torch.Tensor]:
192
221
  """Randomly set channels to flat signal.
193
222
 
194
223
  Part of the CMSAugment policy proposed in [1]_
@@ -222,7 +251,9 @@ def channels_dropout(X, y, p_drop, random_state=None):
222
251
  return X * mask.unsqueeze(-1), y
223
252
 
224
253
 
225
- def _make_permutation_matrix(X, mask, random_state):
254
+ def _make_permutation_matrix(
255
+ X: torch.Tensor, mask: torch.Tensor, random_state: int | np.random.Generator | None
256
+ ) -> torch.Tensor:
226
257
  rng = check_random_state(random_state)
227
258
  batch_size, n_channels, _ = X.shape
228
259
  hard_mask = mask.round()
@@ -241,7 +272,12 @@ def _make_permutation_matrix(X, mask, random_state):
241
272
  return batch_permutations
242
273
 
243
274
 
244
- def channels_shuffle(X, y, p_shuffle, random_state=None):
275
+ def channels_shuffle(
276
+ X: torch.Tensor,
277
+ y: torch.Tensor,
278
+ p_shuffle: float,
279
+ random_state: int | np.random.RandomState | None = None,
280
+ ) -> tuple[torch.Tensor, torch.Tensor]:
245
281
  """Randomly shuffle channels in EEG data matrix.
246
282
 
247
283
  Part of the CMSAugment policy proposed in [1]_
@@ -280,7 +316,12 @@ def channels_shuffle(X, y, p_shuffle, random_state=None):
280
316
  return torch.matmul(batch_permutations, X), y
281
317
 
282
318
 
283
- def gaussian_noise(X, y, std, random_state=None):
319
+ def gaussian_noise(
320
+ X: torch.Tensor,
321
+ y: torch.Tensor,
322
+ std: float,
323
+ random_state: int | np.random.RandomState | None = None,
324
+ ) -> tuple[torch.Tensor, torch.Tensor]:
284
325
  """Randomly add white Gaussian noise to all channels.
285
326
 
286
327
  Suggested e.g. in [1]_, [2]_ and [3]_
@@ -332,7 +373,9 @@ def gaussian_noise(X, y, std, random_state=None):
332
373
  return transformed_X, y
333
374
 
334
375
 
335
- def channels_permute(X, y, permutation):
376
+ def channels_permute(
377
+ X: torch.Tensor, y: torch.Tensor, permutation: list[int]
378
+ ) -> tuple[torch.Tensor, torch.Tensor]:
336
379
  """Permute EEG channels according to fixed permutation matrix.
337
380
 
338
381
  Suggested e.g. in [1]_
@@ -362,7 +405,12 @@ def channels_permute(X, y, permutation):
362
405
  return X[..., permutation, :], y
363
406
 
364
407
 
365
- def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
408
+ def smooth_time_mask(
409
+ X: torch.Tensor,
410
+ y: torch.Tensor,
411
+ mask_start_per_sample: torch.Tensor,
412
+ mask_len_samples: int,
413
+ ) -> tuple[torch.Tensor, torch.Tensor]:
366
414
  """Smoothly replace a contiguous part of all channels by zeros.
367
415
 
368
416
  Originally proposed in [1]_ and [2]_
@@ -412,7 +460,13 @@ def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
412
460
  return X * mask, y
413
461
 
414
462
 
415
- def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
463
+ def bandstop_filter(
464
+ X: torch.Tensor,
465
+ y: torch.Tensor,
466
+ sfreq: float,
467
+ bandwidth: float,
468
+ freqs_to_notch: npt.ArrayLike | None,
469
+ ) -> tuple[torch.Tensor, torch.Tensor]:
416
470
  """Apply a band-stop filter with desired bandwidth at the desired frequency
417
471
  position.
418
472
 
@@ -451,7 +505,7 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
451
505
  Representation Learning for Electroencephalogram Classification. In
452
506
  Machine Learning for Health (pp. 238-253). PMLR.
453
507
  """
454
- if bandwidth == 0:
508
+ if bandwidth == 0 or freqs_to_notch is None:
455
509
  return X, y
456
510
  transformed_X = X.clone()
457
511
  for c, (sample, notched_freq) in enumerate(zip(transformed_X, freqs_to_notch)):
@@ -469,7 +523,7 @@ def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
469
523
  return transformed_X, y
470
524
 
471
525
 
472
- def _analytic_transform(x):
526
+ def _analytic_transform(x: torch.Tensor) -> torch.Tensor:
473
527
  if torch.is_complex(x):
474
528
  raise ValueError("x must be real.")
475
529
 
@@ -486,12 +540,12 @@ def _analytic_transform(x):
486
540
  return ifft(f * h, dim=-1)
487
541
 
488
542
 
489
- def _nextpow2(n):
543
+ def _nextpow2(n: int) -> int:
490
544
  """Return the first integer N such that 2**N >= abs(n)."""
491
545
  return int(np.ceil(np.log2(np.abs(n))))
492
546
 
493
547
 
494
- def _frequency_shift(X, fs, f_shift):
548
+ def _frequency_shift(X: torch.Tensor, fs: float, f_shift: float) -> torch.Tensor:
495
549
  """
496
550
  Shift the specified signal by the specified frequency.
497
551
 
@@ -504,9 +558,13 @@ def _frequency_shift(X, fs, f_shift):
504
558
  t = torch.arange(N_padded, device=X.device) / fs
505
559
  padded = pad(X, (0, N_padded - N_orig))
506
560
  analytical = _analytic_transform(padded)
507
- if isinstance(f_shift, (float, int, np.ndarray, list)):
508
- f_shift = torch.as_tensor(f_shift).float()
509
- f_shift_stack = f_shift.repeat(N_padded, n_channels, 1)
561
+ if isinstance(f_shift, torch.Tensor):
562
+ _f_shift = f_shift
563
+ elif isinstance(f_shift, (float, int, np.ndarray, list)):
564
+ _f_shift = torch.as_tensor(f_shift).float()
565
+ else:
566
+ raise ValueError(f"Invalid f_shift type: {type(f_shift)}")
567
+ f_shift_stack = _f_shift.repeat(N_padded, n_channels, 1)
510
568
  reshaped_f_shift = f_shift_stack.permute(
511
569
  *torch.arange(f_shift_stack.ndim - 1, -1, -1)
512
570
  )
@@ -514,7 +572,9 @@ def _frequency_shift(X, fs, f_shift):
514
572
  return shifted[..., :N_orig].real.float()
515
573
 
516
574
 
517
- def frequency_shift(X, y, delta_freq, sfreq):
575
+ def frequency_shift(
576
+ X: torch.Tensor, y: torch.Tensor, delta_freq: float, sfreq: float
577
+ ) -> tuple[torch.Tensor, torch.Tensor]:
518
578
  """Adds a shift in the frequency domain to all channels.
519
579
 
520
580
  Note that here, the shift is the same for all channels of a single example.
@@ -545,7 +605,7 @@ def frequency_shift(X, y, delta_freq, sfreq):
545
605
  return transformed_X, y
546
606
 
547
607
 
548
- def _torch_normalize_vectors(rr):
608
+ def _torch_normalize_vectors(rr: torch.Tensor) -> torch.Tensor:
549
609
  """Normalize surface vertices."""
550
610
  norm = torch.linalg.norm(rr, axis=1, keepdim=True)
551
611
  mask = norm > 0
@@ -554,7 +614,9 @@ def _torch_normalize_vectors(rr):
554
614
  return new_rr
555
615
 
556
616
 
557
- def _torch_legval(x, c, tensor=True):
617
+ def _torch_legval(
618
+ x: torch.Tensor, c: torch.Tensor, tensor: bool = True
619
+ ) -> torch.Tensor:
558
620
  """
559
621
  Evaluate a Legendre series at points x.
560
622
  If `c` is of length `n + 1`, this function returns the value:
@@ -662,7 +724,9 @@ def _torch_legval(x, c, tensor=True):
662
724
  return c0 + c1 * x
663
725
 
664
726
 
665
- def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
727
+ def _torch_calc_g(
728
+ cosang: torch.Tensor, stiffness: float = 4, n_legendre_terms: int = 50
729
+ ) -> torch.Tensor:
666
730
  """Calculate spherical spline g function between points on a sphere.
667
731
 
668
732
  Parameters
@@ -718,23 +782,25 @@ def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
718
782
  return _torch_legval(cosang, [0] + factors)
719
783
 
720
784
 
721
- def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
785
+ def _torch_make_interpolation_matrix(
786
+ pos_from: torch.Tensor, pos_to: torch.Tensor, alpha: float = 1e-5
787
+ ) -> torch.Tensor:
722
788
  """Compute interpolation matrix based on spherical splines.
723
789
 
724
790
  Implementation based on [1]_
725
791
 
726
792
  Parameters
727
793
  ----------
728
- pos_from : np.ndarray of float, shape(n_good_sensors, 3)
794
+ pos_from : torch.Tensor of float, shape(n_good_sensors, 3)
729
795
  The positions to interpolate from.
730
- pos_to : np.ndarray of float, shape(n_bad_sensors, 3)
796
+ pos_to : torch.Tensor of float, shape(n_bad_sensors, 3)
731
797
  The positions to interpolate.
732
798
  alpha : float
733
799
  Regularization parameter. Defaults to 1e-5.
734
800
 
735
801
  Returns
736
802
  -------
737
- interpolation : np.ndarray of float, shape(len(pos_from), len(pos_to))
803
+ interpolation : torch.Tensor of float, shape(len(pos_from), len(pos_to))
738
804
  The interpolation matrix that maps good signals to the location
739
805
  of bad signals.
740
806
 
@@ -822,7 +888,12 @@ def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
822
888
  return interpolation
823
889
 
824
890
 
825
- def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
891
+ def _rotate_signals(
892
+ X: torch.Tensor,
893
+ rotations: list[torch.Tensor],
894
+ sensors_positions_matrix: torch.Tensor,
895
+ spherical: bool = True,
896
+ ) -> torch.Tensor:
826
897
  sensors_positions_matrix = sensors_positions_matrix.to(X.device)
827
898
  rot_sensors_matrices = [
828
899
  rotation.matmul(sensors_positions_matrix) for rotation in rotations
@@ -853,22 +924,29 @@ def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
853
924
  return transformed_X
854
925
 
855
926
 
856
- def _make_rotation_matrix(axis, angle, degrees=True):
927
+ def _make_rotation_matrix(
928
+ axis: Literal["x", "y", "z"],
929
+ angle: float | int | np.ndarray | list | torch.Tensor,
930
+ degrees: bool = True,
931
+ ) -> torch.Tensor:
857
932
  assert axis in ["x", "y", "z"], "axis should be either x, y or z."
858
-
859
- if isinstance(angle, (float, int, np.ndarray, list)):
860
- angle = torch.as_tensor(angle)
933
+ if isinstance(angle, torch.Tensor):
934
+ _angle = angle
935
+ elif isinstance(angle, (float, int, np.ndarray, list)):
936
+ _angle = torch.as_tensor(angle)
937
+ else:
938
+ raise ValueError(f"Invalid angle type: {type(angle)}")
861
939
 
862
940
  if degrees:
863
- angle = angle * np.pi / 180
941
+ _angle = _angle * np.pi / 180
864
942
 
865
- device = angle.device
943
+ device = _angle.device
866
944
  zero = torch.zeros(1, device=device)
867
945
  rot = torch.stack(
868
946
  [
869
947
  torch.as_tensor([1, 0, 0], device=device),
870
- torch.hstack([zero, torch.cos(angle), -torch.sin(angle)]),
871
- torch.hstack([zero, torch.sin(angle), torch.cos(angle)]),
948
+ torch.hstack([zero, torch.cos(_angle), -torch.sin(_angle)]),
949
+ torch.hstack([zero, torch.sin(_angle), torch.cos(_angle)]),
872
950
  ]
873
951
  )
874
952
  if axis == "x":
@@ -881,7 +959,14 @@ def _make_rotation_matrix(axis, angle, degrees=True):
881
959
  return rot[:, [1, 2, 0]]
882
960
 
883
961
 
884
- def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_splines):
962
+ def sensors_rotation(
963
+ X: torch.Tensor,
964
+ y: torch.Tensor,
965
+ sensors_positions_matrix: torch.Tensor,
966
+ axis: Literal["x", "y", "z"],
967
+ angles: npt.ArrayLike,
968
+ spherical_splines: bool,
969
+ ) -> tuple[torch.Tensor, torch.Tensor]:
885
970
  """Interpolates EEG signals over sensors rotated around the desired axis
886
971
  with the desired angle.
887
972
 
@@ -893,7 +978,7 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_spl
893
978
  EEG input example or batch.
894
979
  y : torch.Tensor
895
980
  EEG labels for the example or batch.
896
- sensors_positions_matrix : numpy.ndarray
981
+ sensors_positions_matrix : torch.Tensor
897
982
  Matrix giving the positions of each sensor in a 3D cartesian coordinate
898
983
  system. Should have shape (3, n_channels), where n_channels is the
899
984
  number of channels. Standard 10-20 positions can be obtained from
@@ -924,7 +1009,9 @@ def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_spl
924
1009
  return rotated_X, y
925
1010
 
926
1011
 
927
- def mixup(X, y, lam, idx_perm):
1012
+ def mixup(
1013
+ X: torch.Tensor, y: torch.Tensor, lam: torch.Tensor, idx_perm: torch.Tensor
1014
+ ) -> tuple[torch.Tensor, torch.Tensor]:
928
1015
  """Mixes two channels of EEG data.
929
1016
 
930
1017
  See [1]_ for details.
@@ -973,8 +1060,13 @@ def mixup(X, y, lam, idx_perm):
973
1060
 
974
1061
 
975
1062
  def segmentation_reconstruction(
976
- X, y, n_segments, data_classes, rand_indices, idx_shuffle
977
- ):
1063
+ X: torch.Tensor,
1064
+ y: torch.Tensor,
1065
+ n_segments: int,
1066
+ data_classes: list[tuple[int, torch.Tensor]],
1067
+ rand_indices: npt.ArrayLike,
1068
+ idx_shuffle: npt.ArrayLike,
1069
+ ) -> tuple[torch.Tensor, torch.Tensor]:
978
1070
  """Segment and reconstruct EEG data from [1]_.
979
1071
 
980
1072
  See [1]_ for details.
@@ -987,6 +1079,8 @@ def segmentation_reconstruction(
987
1079
  EEG labels for the example or batch.
988
1080
  n_segments : int
989
1081
  Number of segments to use in the batch.
1082
+ data_classes: list[tuple[int, torch.Tensor]]
1083
+ List of tuples. Each tuple contains the class index and the corresponding EEG data.
990
1084
  rand_indices: array-like
991
1085
  Array of indices that indicates which trial to use in each segment.
992
1086
  idx_shuffle: array-like
@@ -1005,8 +1099,8 @@ def segmentation_reconstruction(
1005
1099
  """
1006
1100
 
1007
1101
  # Initialize lists to store augmented data and corresponding labels
1008
- aug_data = []
1009
- aug_label = []
1102
+ aug_data: list[torch.Tensor] = []
1103
+ aug_label: list[torch.Tensor] = []
1010
1104
 
1011
1105
  # Iterate through each class to separate and augment data
1012
1106
  for class_index, X_class in data_classes:
@@ -1030,20 +1124,26 @@ def segmentation_reconstruction(
1030
1124
  aug_data.append(X_aug)
1031
1125
  aug_label.append(torch.full((n_trials,), class_index))
1032
1126
  # Concatenate the augmented data and labels
1033
- aug_data = torch.cat(aug_data, dim=0)
1034
- aug_data = aug_data.to(dtype=X.dtype, device=X.device)
1035
- aug_data = aug_data[idx_shuffle]
1127
+ concat_aug_data = torch.cat(aug_data, dim=0)
1128
+ concat_aug_data = concat_aug_data.to(dtype=X.dtype, device=X.device)
1129
+ concat_aug_data = concat_aug_data[idx_shuffle]
1036
1130
 
1037
1131
  if y is not None:
1038
- aug_label = torch.cat(aug_label, dim=0)
1039
- aug_label = aug_label.to(dtype=y.dtype, device=y.device)
1040
- aug_label = aug_label[idx_shuffle]
1041
- return aug_data, aug_label
1132
+ concat_label = torch.cat(aug_label, dim=0)
1133
+ concat_label = concat_label.to(dtype=y.dtype, device=y.device)
1134
+ concat_label = concat_label[idx_shuffle]
1135
+ return concat_aug_data, concat_label
1042
1136
 
1043
- return aug_data, y
1137
+ return concat_aug_data, None
1044
1138
 
1045
1139
 
1046
- def mask_encoding(X, y, time_start, segment_length, n_segments):
1140
+ def mask_encoding(
1141
+ X: torch.Tensor,
1142
+ y: torch.Tensor,
1143
+ time_start: torch.Tensor,
1144
+ segment_length: int,
1145
+ n_segments: int,
1146
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1047
1147
  """Mark encoding from Ding et al. (2024) from [ding2024]_.
1048
1148
 
1049
1149
  Replaces a contiguous part (or parts) of all channels by zeros
@@ -161,10 +161,10 @@ class ChannelsDropout(Transform):
161
161
  ----------
162
162
  probability: float
163
163
  Float setting the probability of applying the operation.
164
- proba_drop: float | None, optional
164
+ p_drop: float | None, optional
165
165
  Float between 0 and 1 setting the probability of dropping each channel.
166
166
  Defaults to 0.2.
167
- random_state: int | numpy.random.Generator, optional
167
+ random_state: int | numpy.random.RandomState, optional
168
168
  Seed to be used to instantiate numpy random number generator instance.
169
169
  Used to decide whether or not to transform given the probability
170
170
  argument and to sample channels to erase. Defaults to None.
@@ -815,7 +815,7 @@ class BaseConcatDataset(ConcatDataset):
815
815
  @staticmethod
816
816
  def _save_description(sub_dir, description):
817
817
  description_file_path = os.path.join(sub_dir, "description.json")
818
- description.to_json(description_file_path)
818
+ description.to_json(description_file_path, default_handler=str)
819
819
 
820
820
  @staticmethod
821
821
  def _save_kwargs(sub_dir, ds):
@@ -22,7 +22,6 @@ from mne.datasets.utils import _get_path
22
22
  from mne.utils import warn
23
23
 
24
24
  from braindecode.datasets import BaseConcatDataset, BaseDataset
25
- from braindecode.preprocessing.preprocess import _preprocess
26
25
 
27
26
  PC18_DIR = op.join(op.dirname(__file__), "data", "pc18")
28
27
  PC18_RECORDS = op.join(PC18_DIR, "sleep_records.csv")
@@ -407,6 +406,8 @@ class SleepPhysionetChallenge2018(BaseConcatDataset):
407
406
  base_dataset = BaseDataset(raw_file, desc)
408
407
 
409
408
  if preproc is not None:
409
+ from braindecode.preprocessing.preprocess import _preprocess
410
+
410
411
  _preprocess(base_dataset, None, preproc)
411
412
 
412
413
  return base_dataset
@@ -138,12 +138,17 @@ def _load_signals(fif_file, preload, is_raw):
138
138
  with open(pkl_file, "rb") as f:
139
139
  signals = pickle.load(f)
140
140
 
141
- # If the file has been moved together with the pickle file, make sure
142
- # the path links to correct fif file.
143
- signals._fname = str(fif_file)
144
- if preload:
145
- signals.load_data()
146
- return signals
141
+ if all(f.exists() for f in signals.filenames):
142
+ if preload:
143
+ signals.load_data()
144
+ return signals
145
+ else: # This may happen if the file has been moved together with the pickle file.
146
+ warnings.warn(
147
+ f"Pickle file {pkl_file} exists, but the referenced fif "
148
+ "file(s) do not exist. Will read the fif file(s) directly "
149
+ "and re-create the pickle file.",
150
+ UserWarning,
151
+ )
147
152
 
148
153
  # If pickle didn't exist read via mne (likely slower) and save pkl after
149
154
  if is_raw:
@@ -27,6 +27,7 @@ from .hybrid import HybridNet
27
27
  from .ifnet import IFNet
28
28
  from .labram import Labram
29
29
  from .msvtnet import MSVTNet
30
+ from .patchedtransformer import PBT
30
31
  from .sccnet import SCCNet
31
32
  from .shallow_fbcsp import ShallowFBCSPNet
32
33
  from .signal_jepa import (
@@ -39,6 +40,7 @@ from .sinc_shallow import SincShallowNet
39
40
  from .sleep_stager_blanco_2020 import SleepStagerBlanco2020
40
41
  from .sleep_stager_chambon_2018 import SleepStagerChambon2018
41
42
  from .sparcnet import SPARCNet
43
+ from .sstdpn import SSTDPN
42
44
  from .syncnet import SyncNet
43
45
  from .tcn import BDTCN, TCN
44
46
  from .tidnet import TIDNet
@@ -77,6 +79,7 @@ __all__ = [
77
79
  "IFNet",
78
80
  "Labram",
79
81
  "MSVTNet",
82
+ "PBT",
80
83
  "SCCNet",
81
84
  "ShallowFBCSPNet",
82
85
  "SignalJEPA",
@@ -84,6 +87,7 @@ __all__ = [
84
87
  "SignalJEPA_PostLocal",
85
88
  "SignalJEPA_PreLocal",
86
89
  "SincShallowNet",
90
+ "SSTDPN",
87
91
  "SleepStagerBlanco2020",
88
92
  "SleepStagerChambon2018",
89
93
  "SPARCNet",
@@ -370,7 +370,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
370
370
  nn.Sequential(
371
371
  *[
372
372
  _TCNResidualBlock(
373
- in_channels=self.F2,
373
+ in_channels=self.F2 if i == 0 else self.tcn_n_filters,
374
374
  kernel_size=self.tcn_kernel_size,
375
375
  n_filters=self.tcn_n_filters,
376
376
  dropout=self.tcn_dropout,
@@ -388,7 +388,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
388
388
  self.final_layer = nn.ModuleList(
389
389
  [
390
390
  MaxNormLinear(
391
- in_features=self.F2 * self.n_windows,
391
+ in_features=self.tcn_n_filters * self.n_windows,
392
392
  out_features=self.n_outputs,
393
393
  max_norm_val=self.max_norm_const,
394
394
  )
@@ -398,7 +398,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
398
398
  self.final_layer = nn.ModuleList(
399
399
  [
400
400
  MaxNormLinear(
401
- in_features=self.F2,
401
+ in_features=self.tcn_n_filters,
402
402
  out_features=self.n_outputs,
403
403
  max_norm_val=self.max_norm_const,
404
404
  )
@@ -408,7 +408,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
408
408
 
409
409
  self.out_fun = nn.Identity()
410
410
 
411
- def forward(self, X):
411
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
412
412
  # Dimension: (batch_size, C, T)
413
413
  X = self.ensuredims(X)
414
414
  # Dimension: (batch_size, C, T, 1)
@@ -695,8 +695,8 @@ class _TCNResidualBlock(nn.Module):
695
695
  # Reshape the input for the residual connection when necessary
696
696
  if in_channels != n_filters:
697
697
  self.reshaping_conv = nn.Conv1d(
698
- in_channels=in_channels,
699
- out_channels=n_filters,
698
+ in_channels=in_channels, # Specify input channels
699
+ out_channels=n_filters, # Specify output channels
700
700
  kernel_size=1,
701
701
  padding="same",
702
702
  )
@@ -716,7 +716,7 @@ class _TCNResidualBlock(nn.Module):
716
716
  out = self.activation(out)
717
717
  out = self.drop2(out)
718
718
 
719
- out = self.reshaping_conv(out)
719
+ X = self.reshaping_conv(X)
720
720
 
721
721
  # ----- Residual connection -----
722
722
  out = X + out
@@ -17,7 +17,7 @@ class BIOT(EEGModuleMixin, nn.Module):
17
17
 
18
18
  BIOT: Cross-data Biosignal Learning in the Wild.
19
19
 
20
- BIOT is a large language model for biosignal classification. It is
20
+ BIOT is a large brain model for biosignal classification. It is
21
21
  a wrapper around the `BIOTEncoder` and `ClassificationHead` modules.
22
22
 
23
23
  It is designed for N-dimensional biosignal data such as EEG, ECG, etc.