braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
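
The 0.8.1 model internals (braindecode/models/modules.py and braindecode/models/functions.py, removed above) are replaced by the new braindecode/modules and braindecode/functional subpackages. A hypothetical migration sketch, assuming commonly used names such as Expression and safe_log kept their spelling across the move (not shown in this diff):

# 0.8.1 layout (removed in this release):
#   from braindecode.models.modules import Expression
#   from braindecode.models.functions import safe_log
# 1.1.0 layout (assumed, based on the file moves listed above):
from braindecode.modules import Expression
from braindecode.functional import safe_log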
braindecode/datasets/bbci.py
@@ -2,6 +2,8 @@
 #
 # License: BSD (3-clause)

+from __future__ import annotations
+
 import logging
 import os.path
 import re
@@ -16,7 +18,8 @@ log = logging.getLogger(__name__)


 class BBCIDataset(object):
-    """
+    """BBCIDataset.
+
     Loader class for files created by saving BBCI files in matlab (make
     sure to save with '-v7.3' in matlab, see
     https://de.mathworks.com/help/matlab/import_export/mat-file-versions.html#buk6i87
@@ -34,12 +37,14 @@ class BBCIDataset(object):
     """

     def __init__(
-        self, filename, load_sensor_names=None, check_class_names=False
+        self,
+        filename: str,
+        load_sensor_names: list[str] | None = None,
+        check_class_names: bool = False,
     ):
         self.__dict__.update(locals())
-        del self.self

-    def load(self):
+    def load(self) -> mne.io.RawArray:
         cnt = self._load_continuous_signal()
         cnt = self._add_markers(cnt)
         return cnt
@@ -50,9 +55,7 @@
         with h5py.File(self.filename, "r") as h5file:
             samples = int(h5file["nfo"]["T"][0, 0])
             cnt_signal_shape = (samples, len(wanted_chan_inds))
-            continuous_signal = (
-                np.ones(cnt_signal_shape, dtype=np.float32) * np.nan
-            )
+            continuous_signal = np.ones(cnt_signal_shape, dtype=np.float32) * np.nan
             for chan_ind_arr, chan_ind_set in enumerate(wanted_chan_inds):
                 # + 1 because matlab/this hdf5-naming logic
                 # has 1-based indexing
@@ -63,9 +66,7 @@
                     :
                 ].squeeze()  # already load into memory
                 continuous_signal[:, chan_ind_arr] = chan_signal
-            assert not np.any(
-                np.isnan(continuous_signal)
-            ), "No NaNs expected in signal"
+            assert not np.any(np.isnan(continuous_signal)), "No NaNs expected in signal"

         if self.load_sensor_names is None:
             ch_types = ["eeg"] * len(wanted_chan_inds)
@@ -83,15 +84,12 @@
     def _determine_sensors(self):
         all_sensor_names = self.get_all_sensors(self.filename, pattern=None)
         if self.load_sensor_names is None:
-
             # if no sensor names given, take all EEG-chans
             eeg_sensor_names = all_sensor_names
             eeg_sensor_names = filter(
                 lambda s: not s.startswith("BIP"), eeg_sensor_names
             )
-            eeg_sensor_names = filter(
-                lambda s: not s.startswith("E"), eeg_sensor_names
-            )
+            eeg_sensor_names = filter(lambda s: not s.startswith("E"), eeg_sensor_names)
             eeg_sensor_names = filter(
                 lambda s: not s.startswith("Microphone"), eeg_sensor_names
             )
@@ -103,17 +101,15 @@
             )
             eeg_sensor_names = list(eeg_sensor_names)
             assert (
-                len(eeg_sensor_names) == 128 or
-                len(eeg_sensor_names) == 64 or
-                len(eeg_sensor_names) == 32 or
-                len(eeg_sensor_names) == 16
+                len(eeg_sensor_names) == 128
+                or len(eeg_sensor_names) == 64
+                or len(eeg_sensor_names) == 32
+                or len(eeg_sensor_names) == 16
             ), "Recheck this code if you have different sensors..."
             wanted_sensor_names = eeg_sensor_names
         else:
             wanted_sensor_names = self.load_sensor_names
-        chan_inds = self._determine_chan_inds(
-            all_sensor_names, wanted_sensor_names
-        )
+        chan_inds = self._determine_chan_inds(all_sensor_names, wanted_sensor_names)
         return chan_inds, wanted_sensor_names

     def _determine_samplingrate(self):
@@ -127,16 +123,12 @@
     def _determine_chan_inds(all_sensor_names, sensor_names):
         assert sensor_names is not None
         chan_inds = [all_sensor_names.index(s) for s in sensor_names]
-        assert len(chan_inds) == len(sensor_names), (
-            "All" "sensors should be there."
-        )
-        assert len(set(chan_inds)) == len(chan_inds), (
-            "No duplicated sensors wanted."
-        )
+        assert len(chan_inds) == len(sensor_names), "Allsensors should be there."
+        assert len(set(chan_inds)) == len(chan_inds), "No duplicated sensors wanted."
         return chan_inds

     @staticmethod
-    def get_all_sensors(filename, pattern=None):
+    def get_all_sensors(filename: str, pattern: str | None = None) -> list[str]:
         """
         Get all sensors that exist in the given file.

@@ -157,17 +149,15 @@
                 "".join(chr(c.item()) for c in h5file[obj_ref]) for obj_ref in clab_set
             ]
             if pattern is not None:
-                all_sensor_names = filter(
-                    lambda sname: re.search(pattern, sname), all_sensor_names
+                all_sensor_names = list(
+                    filter(lambda sname: re.search(pattern, sname), all_sensor_names)
                 )
         return all_sensor_names

     def _add_markers(self, cnt):
         with h5py.File(self.filename, "r") as h5file:
             event_times_in_ms = h5file["mrk"]["time"][:].squeeze()
-            event_classes = (
-                h5file["mrk"]["event"]["desc"][:].squeeze().astype(np.int64)
-            )
+            event_classes = h5file["mrk"]["event"]["desc"][:].squeeze().astype(np.int64)

             # Check whether class names known and correct order
             class_name_set = h5file["nfo"]["className"][:].squeeze()
@@ -177,9 +167,7 @@
             ]

             if self.check_class_names:
-                _check_class_names(
-                    all_class_names, event_times_in_ms, event_classes
-                )
+                _check_class_names(all_class_names, event_times_in_ms, event_classes)

         event_times_in_samples = event_times_in_ms * cnt.info["sfreq"] / 1000.0
         event_times_in_samples = np.uint32(np.round(event_times_in_samples))
@@ -196,8 +184,8 @@
                         i_sample,
                         event_classes[i_event - 1],
                         event_classes[i_event],
-                    ) +
-                    "Marker codes will be summed."
+                    )
+                    + "Marker codes will be summed."
                 )
             previous_i_sample = i_sample

@@ -222,7 +210,7 @@
         # Hacky way to try to find out class names for each event
         # h5file['mrk']['y'] y contains one-hot label for event name
         with h5py.File(self.filename, "r") as h5file:
-            y = h5file['mrk']['y'][:]
+            y = h5file["mrk"]["y"][:]
             # seems that there are cases where for last class
             # y is just all zero for some reason?
             # and seems then it is last of the class names
@@ -233,7 +221,7 @@
         event_i_classes = np.argmax(y, axis=1)

         # 4 second trials for High-Gamma dataset, otherwise how to know?
-        if all_class_names == ['Right Hand', 'Left Hand', 'Rest', 'Feet']:
+        if all_class_names == ["Right Hand", "Left Hand", "Rest", "Feet"]:
             durations = np.full(event_times_in_ms.shape, 4)
         else:
             warnings.warn("Unknown event durations set to 0")
@@ -265,8 +253,8 @@ def _check_class_names(all_class_names, event_times_in_ms, event_classes):
         pass
     elif (
         (
-            all_class_names ==
-            [
+            all_class_names
+            == [
                 "1",
                 "10",
                 "11",
@@ -285,9 +273,10 @@
                 "44",
                 "99",
             ]
-        ) or (
-            all_class_names ==
-            [
+        )
+        or (
+            all_class_names
+            == [
                 "1",
                 "10",
                 "11",
@@ -305,8 +294,8 @@
                 "44",
                 "99",
             ]
-        ) or (
-            all_class_names == ["1", "2", "3", "4"])
+        )
+        or (all_class_names == ["1", "2", "3", "4"])
     ):
         pass  # Semantic classes
     elif all_class_names == ["Rest", "Feet", "Left Hand", "Right Hand"]:
@@ -668,7 +657,9 @@
         log.warn("Unknown class names {:s}".format(all_class_names))


-def load_bbci_sets_from_folder(folder, runs="all"):
+def load_bbci_sets_from_folder(
+    folder: str, runs: list[int] | str = "all"
+) -> list[mne.io.RawArray]:
     """
     Load bbci datasets from files in given folder.

@@ -687,10 +678,10 @@ def load_bbci_sets_from_folder(folder, runs="all"):
     """
    bbci_mat_files = sorted(glob(os.path.join(folder, "*.BBCI.mat")))
    if runs != "all":
-        file_run_numbers = [
-            int(re.search("S[0-9]{3,3}R[0-9]{2,2}_", f).group()[5:7])
-            for f in bbci_mat_files
-        ]
+        assert isinstance(runs, list), "runs should be list[int] or 'all'"
+        matches = [re.search("S[0-9]{3,3}R[0-9]{2,2}_", f) for f in bbci_mat_files]
+        file_run_numbers = [int(m.group()[5:7]) for m in matches if m is not None]
+        assert len(file_run_numbers) == len(bbci_mat_files), "Some files don't match"
        indices = [file_run_numbers.index(num) for num in runs]

        wanted_files = np.array(bbci_mat_files)[indices]
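
A minimal usage sketch for the loader retyped above; the file path and sensor list are placeholders, and load() returning an mne.io.RawArray follows the new annotations in this diff:

from braindecode.datasets.bbci import BBCIDataset

# Hypothetical recording saved from matlab with '-v7.3':
cnt = BBCIDataset(
    filename="S001R01_data.BBCI.mat",
    load_sensor_names=["C3", "Cz", "C4"],  # None loads all EEG channels
    check_class_names=False,
).load()  # -> mne.io.RawArray with markers attached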
braindecode/datasets/bcicomp.py
@@ -3,6 +3,8 @@
 #
 # License: BSD (3-clause)

+from __future__ import annotations
+
 import glob
 import os
 import os.path as osp
@@ -14,10 +16,12 @@ import numpy as np
 from mne.utils import verbose
 from scipy.io import loadmat

-from braindecode.datasets import BaseDataset, BaseConcatDataset
+from braindecode.datasets import BaseConcatDataset, BaseDataset

-DATASET_URL = 'https://stacks.stanford.edu/file/druid:zk881ps0522/' \
-              'BCI_Competion4_dataset4_data_fingerflexions.zip'
+DATASET_URL = (
+    "https://stacks.stanford.edu/file/druid:zk881ps0522/"
+    "BCI_Competion4_dataset4_data_fingerflexions.zip"
+)


 class BCICompetitionIVDataset4(BaseConcatDataset):
@@ -42,30 +46,32 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
     References
     ----------
     .. [1] Miller, Kai J. "A library of human electrocorticographic data and analyses."
-       Nature human behaviour 3, no. 11 (2019): 1225-1235. https://doi.org/10.1038/s41562-019-0678-3
+       Nature human behaviour 3, no. 11 (2019): 1225-1235.
+       https://doi.org/10.1038/s41562-019-0678-3
     """
+
     possible_subjects = [1, 2, 3]

-    def __init__(self, subject_ids=None):
+    def __init__(self, subject_ids: list[int] | int | None = None):
         data_path = self.download()
         if isinstance(subject_ids, int):
             subject_ids = [subject_ids]
         if subject_ids is None:
             subject_ids = self.possible_subjects
         self._validate_subjects(subject_ids)
-        files_list = [f'{data_path}/sub{i}_comp.mat' for i in subject_ids]
+        files_list = [f"{data_path}/sub{i}_comp.mat" for i in subject_ids]
         datasets = []
         for file_path in files_list:
             raw_train, raw_test = self._load_data_to_mne(file_path)
             desc_train = dict(
-                subject=file_path.split('/')[-1].split('sub')[1][0],
-                file_name=file_path.split('/')[-1],
-                session='train'
+                subject=file_path.split("/")[-1].split("sub")[1][0],
+                file_name=file_path.split("/")[-1],
+                session="train",
             )
             desc_test = dict(
-                subject=file_path.split('/')[-1].split('sub')[1][0],
-                file_name=file_path.split('/')[-1],
-                session='test'
+                subject=file_path.split("/")[-1].split("sub")[1][0],
+                file_name=file_path.split("/")[-1],
+                session="test",
             )
             datasets.append(BaseDataset(raw_train, description=desc_train))
             datasets.append(BaseDataset(raw_test, description=desc_test))
@@ -90,20 +96,24 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
         -------

         """
-        signature = 'BCICompetitionIVDataset4'
-        folder_name = 'BCI_Competion4_dataset4_data_fingerflexions'
+        signature = "BCICompetitionIVDataset4"
+        folder_name = "BCI_Competion4_dataset4_data_fingerflexions"
         # Check if the dataset already exists (unpacked). We have to do that manually
         # because we are removing .zip file from disk to save disk space.

         from moabb.datasets.download import get_dataset_path  # keep soft dependency
+
         path = get_dataset_path(signature, path)
         key_dest = "MNE-{:s}-data".format(signature.lower())
         # We do not use mne _url_to_local_path due to ':' in the url that causes problems on Windows
         destination = osp.join(path, key_dest, folder_name)
-        if len(list(glob.glob(osp.join(destination, '*.mat')))) == 6:
+        if len(list(glob.glob(osp.join(destination, "*.mat")))) == 6:
             return destination
-        data_path = _data_dl(DATASET_URL, osp.join(destination, folder_name, signature),
-                             force_update=force_update)
+        data_path = _data_dl(
+            DATASET_URL,
+            osp.join(destination, folder_name, signature),
+            force_update=force_update,
+        )
         unpack_archive(data_path, osp.dirname(destination))
         # removes .zip file that the data was unpacked from
         remove(data_path)
@@ -117,26 +127,30 @@ class BCICompetitionIVDataset4(BaseConcatDataset):

     def _load_data_to_mne(self, file_path):
         data = loadmat(file_path)
-        test_labels = loadmat(file_path.replace('comp.mat', 'testlabels.mat'))
-        train_data = data['train_data']
-        test_data = data['test_data']
-        upsampled_train_targets = data['train_dg']
-        upsampled_test_targets = test_labels['test_dg']
+        test_labels = loadmat(file_path.replace("comp.mat", "testlabels.mat"))
+        train_data = data["train_data"]
+        test_data = data["test_data"]
+        upsampled_train_targets = data["train_dg"]
+        upsampled_test_targets = test_labels["test_dg"]

         signal_sfreq = 1000
         original_target_sfreq = 25
         targets_stride = int(signal_sfreq / original_target_sfreq)

-        original_targets = self._prepare_targets(upsampled_train_targets, targets_stride)
-        original_test_targets = self._prepare_targets(upsampled_test_targets, targets_stride)
+        original_targets = self._prepare_targets(
+            upsampled_train_targets, targets_stride
+        )
+        original_test_targets = self._prepare_targets(
+            upsampled_test_targets, targets_stride
+        )

-        ch_names = [f'{i}' for i in range(train_data.shape[1])]
-        ch_names += [f'target_{i}' for i in range(original_targets.shape[1])]
-        ch_types = ['ecog' for _ in range(train_data.shape[1])]
-        ch_types += ['misc' for _ in range(original_targets.shape[1])]
+        ch_names = [f"{i}" for i in range(train_data.shape[1])]
+        ch_names += [f"target_{i}" for i in range(original_targets.shape[1])]
+        ch_types = ["ecog" for _ in range(train_data.shape[1])]
+        ch_types += ["misc" for _ in range(original_targets.shape[1])]

         info = mne.create_info(sfreq=signal_sfreq, ch_names=ch_names, ch_types=ch_types)
-        info['temp'] = dict(target_sfreq=original_target_sfreq)
+        info["temp"] = dict(target_sfreq=original_target_sfreq)
         train_data = np.concatenate([train_data, original_targets], axis=1)
         test_data = np.concatenate([test_data, original_test_targets], axis=1)

@@ -149,12 +163,12 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
         if isinstance(subject_ids, (list, tuple)):
             if not all((subject in self.possible_subjects for subject in subject_ids)):
                 raise ValueError(
-                    f'Wrong subject_ids parameter. Possible values: {self.possible_subjects}. '
-                    f'Provided {subject_ids}.'
+                    f"Wrong subject_ids parameter. Possible values: {self.possible_subjects}. "
+                    f"Provided {subject_ids}."
                 )
         else:
             raise ValueError(
-                'Wrong subject_ids format. Expected types: None, list, tuple, int.'
+                "Wrong subject_ids format. Expected types: None, list, tuple, int."
             )


@@ -165,6 +179,7 @@ def _data_dl(url, destination, force_update=False, verbose=None):
     # moabb/datasets/download.py

     from pooch import file_hash, retrieve  # keep soft dependency
+
     if not osp.isfile(destination) or force_update:
         if osp.isfile(destination):
             os.remove(destination)
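
A short sketch of the reformatted BCICompetitionIVDataset4 entry point; subject_ids is validated against possible_subjects = [1, 2, 3] as shown above:

from braindecode.datasets.bcicomp import BCICompetitionIVDataset4

# Downloads (or reuses) the fingerflexion data and builds one 'train'
# and one 'test' BaseDataset per requested subject:
dataset = BCICompetitionIVDataset4(subject_ids=[1])
print(dataset.description)  # subject / file_name / session per recording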
braindecode/datasets/bids.py (new file)
@@ -0,0 +1,245 @@
+"""Dataset for loading BIDS.
+
+More information on BIDS (Brain Imaging Data Structure) can be found at https://bids.neuroimaging.io
+"""
+
+# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+import mne
+import mne_bids
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+
+from .base import BaseConcatDataset, BaseDataset, WindowsDataset
+
+
+def _description_from_bids_path(bids_path: mne_bids.BIDSPath) -> dict[str, Any]:
+    return {
+        "path": bids_path.fpath,
+        "subject": bids_path.subject,
+        "session": bids_path.session,
+        "task": bids_path.task,
+        "acquisition": bids_path.acquisition,
+        "run": bids_path.run,
+        "processing": bids_path.processing,
+        "recording": bids_path.recording,
+        "space": bids_path.space,
+        "split": bids_path.split,
+        "description": bids_path.description,
+        "suffix": bids_path.suffix,
+        "extension": bids_path.extension,
+        "datatype": bids_path.datatype,
+    }
+
+
+@dataclass
+class BIDSDataset(BaseConcatDataset):
+    """Dataset for loading BIDS.
+
+    This class has the same parameters as the :func:`mne_bids.find_matching_paths` function
+    as it will be used to find the files to load. The default ``extensions`` parameter was changed.
+
+    More information on BIDS (Brain Imaging Data Structure)
+    can be found at https://bids.neuroimaging.io
+
+    .. Note::
+        For loading "unofficial" BIDS datasets containing epoched data,
+        you can use :class:`BIDSEpochsDataset`.
+
+    Parameters
+    ----------
+    root : pathlib.Path | str
+        The root of the BIDS path.
+    subjects : str | array-like of str | None
+        The subject ID. Corresponds to "sub".
+    sessions : str | array-like of str | None
+        The acquisition session. Corresponds to "ses".
+    tasks : str | array-like of str | None
+        The experimental task. Corresponds to "task".
+    acquisitions: str | array-like of str | None
+        The acquisition parameters. Corresponds to "acq".
+    runs : str | array-like of str | None
+        The run number. Corresponds to "run".
+    processings : str | array-like of str | None
+        The processing label. Corresponds to "proc".
+    recordings : str | array-like of str | None
+        The recording name. Corresponds to "rec".
+    spaces : str | array-like of str | None
+        The coordinate space for anatomical and sensor location
+        files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
+        Corresponds to "space".
+        Note that valid values for ``space`` must come from a list
+        of BIDS keywords as described in the BIDS specification.
+    splits : str | array-like of str | None
+        The split of the continuous recording file for ``.fif`` data.
+        Corresponds to "split".
+    descriptions : str | array-like of str | None
+        This corresponds to the BIDS entity ``desc``. It is used to provide
+        additional information for derivative data, e.g., preprocessed data
+        may be assigned ``description='cleaned'``.
+    suffixes : str | array-like of str | None
+        The filename suffix. This is the entity after the
+        last ``_`` before the extension. E.g., ``'channels'``.
+        The following filename suffix's are accepted:
+        'meg', 'markers', 'eeg', 'ieeg', 'T1w',
+        'participants', 'scans', 'electrodes', 'coordsystem',
+        'channels', 'events', 'headshape', 'digitizer',
+        'beh', 'physio', 'stim'
+    extensions : str | array-like of str | None
+        The extension of the filename. E.g., ``'.json'``.
+        By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
+    datatypes : str | array-like of str | None
+        The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
+        ``'ieeg'``.
+    check : bool
+        If ``True``, only returns paths that conform to BIDS. If ``False``
+        (default), the ``.check`` attribute of the returned
+        :class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
+        do conform to BIDS, and to ``False`` for those that don't.
+    preload : bool
+        If True, preload the data. Defaults to False.
+    n_jobs : int
+        Number of jobs to run in parallel. Defaults to 1.
+    """
+
+    root: Path | str
+    subjects: str | list[str] | None = None
+    sessions: str | list[str] | None = None
+    tasks: str | list[str] | None = None
+    acquisitions: str | list[str] | None = None
+    runs: str | list[str] | None = None
+    processings: str | list[str] | None = None
+    recordings: str | list[str] | None = None
+    spaces: str | list[str] | None = None
+    splits: str | list[str] | None = None
+    descriptions: str | list[str] | None = None
+    suffixes: str | list[str] | None = None
+    extensions: str | list[str] | None = field(
+        default_factory=lambda: [
+            ".con",
+            ".sqd",
+            ".pdf",
+            ".fif",
+            ".ds",
+            ".vhdr",
+            ".set",
+            ".edf",
+            ".bdf",
+            ".EDF",
+            ".snirf",
+            ".cdt",
+            ".mef",
+            ".nwb",
+        ]
+    )
+    datatypes: str | list[str] | None = None
+    check: bool = False
+    preload: bool = False
+    n_jobs: int = 1
+
+    @property
+    def _filter_out_epochs(self):
+        return True
+
+    def __post_init__(self):
+        bids_paths = mne_bids.find_matching_paths(
+            root=self.root,
+            subjects=self.subjects,
+            sessions=self.sessions,
+            tasks=self.tasks,
+            acquisitions=self.acquisitions,
+            runs=self.runs,
+            processings=self.processings,
+            recordings=self.recordings,
+            spaces=self.spaces,
+            splits=self.splits,
+            descriptions=self.descriptions,
+            suffixes=self.suffixes,
+            extensions=self.extensions,
+            datatypes=self.datatypes,
+            check=self.check,
+        )
+        # Filter out .json files files:
+        # (argument ignore_json only available in mne-bids>=0.16)
+        bids_paths = [
+            bids_path for bids_path in bids_paths if bids_path.extension != ".json"
+        ]
+        # Filter out _epo.fif files:
+        if self._filter_out_epochs:
+            bids_paths = [
+                bids_path
+                for bids_path in bids_paths
+                if not (bids_path.suffix == "epo" and bids_path.extension == ".fif")
+            ]
+
+        all_base_ds = Parallel(n_jobs=self.n_jobs)(
+            delayed(self._get_dataset)(bids_path) for bids_path in bids_paths
+        )
+        super().__init__(all_base_ds)
+
+    def _get_dataset(self, bids_path: mne_bids.BIDSPath) -> BaseDataset:
+        description = _description_from_bids_path(bids_path)
+        raw = mne_bids.read_raw_bids(bids_path, verbose=False)
+        if self.preload:
+            raw.load_data()
+        return BaseDataset(raw, description)
+
+
+class BIDSEpochsDataset(BIDSDataset):
+    """**Experimental** dataset for loading :class:`mne.Epochs` organised in BIDS.
+
+    The files must end with ``_epo.fif``.
+
+    .. Warning::
+        Epoched data is not officially supported in BIDS.
+
+    .. Note::
+        **Parameters:** This class has the same parameters as :class:`BIDSDataset` except
+        for arguments ``datatypes``, ``extensions`` and ``check`` which are fixed.
+    """
+
+    @property
+    def _filter_out_epochs(self):
+        return False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(
+            *args,
+            extensions=".fif",
+            suffixes="epo",
+            check=False,
+            **kwargs,
+        )
+
+    def _set_metadata(self, epochs: mne.BaseEpochs) -> None:
+        # events = mne.events_from_annotations(epochs
+        n_times = epochs.times.shape[0]
+        # id_event = {v: k for k, v in epochs.event_id.items()}
+        annotations = epochs.annotations
+        if annotations is not None:
+            target = annotations.description
+        else:
+            id_events = {v: k for k, v in epochs.event_id.items()}
+            target = [id_events[event_id] for event_id in epochs.events[:, -1]]
+        metadata_dict = {
+            "i_window_in_trial": np.zeros(len(epochs)),
+            "i_start_in_trial": np.zeros(len(epochs)),
+            "i_stop_in_trial": np.zeros(len(epochs)) + n_times,
+            "target": target,
+        }
+        epochs.metadata = pd.DataFrame(metadata_dict)
+
+    def _get_dataset(self, bids_path):
+        description = _description_from_bids_path(bids_path)
+        epochs = mne.read_epochs(bids_path.fpath)
+        self._set_metadata(epochs)
+        return WindowsDataset(epochs, description=description, targets_from="metadata")
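
A usage sketch for the new BIDS loaders; the root path is hypothetical, and the keyword arguments mirror the dataclass fields shown above:

from braindecode.datasets.bids import BIDSDataset, BIDSEpochsDataset

# Raw recordings of two subjects from a BIDS EEG dataset, loaded in parallel:
ds = BIDSDataset(
    root="/data/my_bids_dataset",
    subjects=["01", "02"],
    datatypes="eeg",
    n_jobs=2,
)

# "Unofficial" epoched BIDS data stored as *_epo.fif files:
epochs_ds = BIDSEpochsDataset(root="/data/my_bids_epochs")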