braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,693 @@
+ # Authors: Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import logging
+ import os.path
+ import re
+ import warnings
+ from glob import glob
+
+ import h5py
+ import mne
+ import numpy as np
+
+ log = logging.getLogger(__name__)
+
+
+ class BBCIDataset(object):
+     """BBCIDataset.
+
+     Loader class for files created by saving BBCI files in MATLAB (make
+     sure to save with '-v7.3' in MATLAB, see
+     https://de.mathworks.com/help/matlab/import_export/mat-file-versions.html#buk6i87
+     ).
+
+     Parameters
+     ----------
+     filename : str
+         Path of the BBCI .mat file to load.
+     load_sensor_names : list of str, optional
+         Names of the sensors to load. Loading only a subset of the sensors
+         also speeds up loading. None means load all sensors.
+     check_class_names : bool, optional
+         Check whether the class names are among the known class names used
+         at the Translational NeuroTechnology Lab, AG Ball, Freiburg, Germany.
+     """
+
+     def __init__(
+         self,
+         filename: str,
+         load_sensor_names: list[str] | None = None,
+         check_class_names: bool = False,
+     ):
+         self.__dict__.update(locals())
+
+     def load(self) -> mne.io.RawArray:
+         cnt = self._load_continuous_signal()
+         cnt = self._add_markers(cnt)
+         return cnt
+
+     def _load_continuous_signal(self):
+         wanted_chan_inds, wanted_sensor_names = self._determine_sensors()
+         fs = self._determine_samplingrate()
+         with h5py.File(self.filename, "r") as h5file:
+             samples = int(h5file["nfo"]["T"][0, 0])
+             cnt_signal_shape = (samples, len(wanted_chan_inds))
+             continuous_signal = np.ones(cnt_signal_shape, dtype=np.float32) * np.nan
+             for chan_ind_arr, chan_ind_set in enumerate(wanted_chan_inds):
+                 # + 1 because matlab/this hdf5-naming logic
+                 # uses 1-based indexing, i.e. ch1, ch2, ...
+                 chan_set_name = "ch" + str(chan_ind_set + 1)
+                 # squeeze to unpack into a vector, before it is a 1xN matrix
+                 chan_signal = h5file[chan_set_name][:].squeeze()  # load into memory
+                 continuous_signal[:, chan_ind_arr] = chan_signal
+             assert not np.any(np.isnan(continuous_signal)), "No NaNs expected in signal"
+
+         if self.load_sensor_names is None:
+             ch_types = ["eeg"] * len(wanted_chan_inds)
+         else:
+             warnings.warn("Setting to misc channel type as channel type not known")
+             # Assume we can't know the channel type here automatically
+             ch_types = ["misc"] * len(wanted_chan_inds)
+         info = mne.create_info(
+             ch_names=wanted_sensor_names, sfreq=fs, ch_types=ch_types
+         )
+
+         cnt = mne.io.RawArray(continuous_signal.T, info)
+         return cnt
+
+     def _determine_sensors(self):
+         all_sensor_names = self.get_all_sensors(self.filename, pattern=None)
+         if self.load_sensor_names is None:
+             # if no sensor names given, take all EEG-chans
+             eeg_sensor_names = all_sensor_names
+             eeg_sensor_names = filter(
+                 lambda s: not s.startswith("BIP"), eeg_sensor_names
+             )
+             eeg_sensor_names = filter(lambda s: not s.startswith("E"), eeg_sensor_names)
+             eeg_sensor_names = filter(
+                 lambda s: not s.startswith("Microphone"), eeg_sensor_names
+             )
+             eeg_sensor_names = filter(
+                 lambda s: not s.startswith("Breath"), eeg_sensor_names
+             )
+             eeg_sensor_names = filter(
+                 lambda s: not s.startswith("GSR"), eeg_sensor_names
+             )
+             eeg_sensor_names = list(eeg_sensor_names)
+             assert len(eeg_sensor_names) in (128, 64, 32, 16), (
+                 "Recheck this code if you have different sensors..."
+             )
+             wanted_sensor_names = eeg_sensor_names
+         else:
+             wanted_sensor_names = self.load_sensor_names
+         chan_inds = self._determine_chan_inds(all_sensor_names, wanted_sensor_names)
+         return chan_inds, wanted_sensor_names
+
+     def _determine_samplingrate(self):
+         with h5py.File(self.filename, "r") as h5file:
+             fs = h5file["nfo"]["fs"][0, 0]
+             assert isinstance(fs, int) or fs.is_integer()
+             fs = int(fs)
+         return fs
+
+     @staticmethod
+     def _determine_chan_inds(all_sensor_names, sensor_names):
+         assert sensor_names is not None
+         chan_inds = [all_sensor_names.index(s) for s in sensor_names]
+         assert len(chan_inds) == len(sensor_names), "All sensors should be there."
+         assert len(set(chan_inds)) == len(chan_inds), "No duplicated sensors wanted."
+         return chan_inds
+
+     @staticmethod
+     def get_all_sensors(filename: str, pattern: str | None = None) -> list[str]:
+         """
+         Get all sensors that exist in the given file.
+
+         Parameters
+         ----------
+         filename : str
+             Path of the BBCI .mat file.
+         pattern : str, optional
+             Only return those sensor names that match the given pattern.
+
+         Returns
+         -------
+         sensor_names : list of str
+             Sensor names that match the pattern, or all sensor names in the file.
+         """
+         with h5py.File(filename, "r") as h5file:
+             clab_set = h5file["nfo"]["clab"][:].squeeze()
+             all_sensor_names = [
+                 "".join(chr(c.item()) for c in h5file[obj_ref]) for obj_ref in clab_set
+             ]
+             if pattern is not None:
+                 all_sensor_names = list(
+                     filter(lambda sname: re.search(pattern, sname), all_sensor_names)
+                 )
+         return all_sensor_names
+
+     def _add_markers(self, cnt):
+         with h5py.File(self.filename, "r") as h5file:
+             event_times_in_ms = h5file["mrk"]["time"][:].squeeze()
+             event_classes = h5file["mrk"]["event"]["desc"][:].squeeze().astype(np.int64)
+
+             # Check whether class names are known and in the correct order
+             class_name_set = h5file["nfo"]["className"][:].squeeze()
+             all_class_names = [
+                 "".join(chr(c.item()) for c in h5file[obj_ref])
+                 for obj_ref in class_name_set
+             ]
+
+         if self.check_class_names:
+             _check_class_names(all_class_names, event_times_in_ms, event_classes)
+
+         event_times_in_samples = event_times_in_ms * cnt.info["sfreq"] / 1000.0
+         event_times_in_samples = np.uint32(np.round(event_times_in_samples))
+
+         # Check if there are markers at the same time
+         previous_i_sample = -1
+         for i_event, (i_sample, id_class) in enumerate(
+             zip(event_times_in_samples, event_classes)
+         ):
+             if i_sample == previous_i_sample:
+                 log.warning(
+                     "Same sample has at least two markers.\n"
+                     "{:d}: ({:.0f} and {:.0f}).\n".format(
+                         i_sample,
+                         event_classes[i_event - 1],
+                         event_classes[i_event],
+                     )
+                     + "Marker codes will be summed."
+                 )
+             previous_i_sample = i_sample
+
+         # Now create the stim channel
+         stim_chan = np.zeros_like(cnt.get_data()[0])
+         for i_sample, id_class in zip(event_times_in_samples, event_classes):
+             stim_chan[i_sample] += id_class
+         info = mne.create_info(
+             ch_names=["STI 014"], sfreq=cnt.info["sfreq"], ch_types=["stim"]
+         )
+         stim_cnt = mne.io.RawArray(stim_chan[None], info, verbose="WARNING")
+         cnt = cnt.add_channels([stim_cnt])
+         event_arr = [
+             event_times_in_samples,
+             [0] * len(event_times_in_samples),
+             event_classes,
+         ]
+         cnt.info["events"] = np.array(event_arr).T
+
+         # Generate Annotations
+         event_times_in_sec = event_times_in_ms / 1000.0
+         # Hacky way to try to find out class names for each event:
+         # h5file['mrk']['y'] contains a one-hot label for the event name
+         with h5py.File(self.filename, "r") as h5file:
+             y = h5file["mrk"]["y"][:]
+         # It seems there are cases where y is all zero for the last class
+         # for some reason, and then it is the last of the class names
+         # ('Stimulation'), at least in the file investigated.
+         y[np.sum(y, axis=1) == 0, -1] = 1
+         assert np.all(np.sum(y, axis=1) == 1)
+         event_i_classes = np.argmax(y, axis=1)
+
+         # 4 second trials for the High-Gamma dataset, otherwise how to know?
+         if all_class_names == ["Right Hand", "Left Hand", "Rest", "Feet"]:
+             durations = np.full(event_times_in_ms.shape, 4)
+         else:
+             warnings.warn("Unknown event durations set to 0")
+             durations = np.full(event_times_in_ms.shape, 0)
+
+         # Label information for this dataset
+         descriptions = [all_class_names[y] for y in event_i_classes]
+         annots = mne.Annotations(event_times_in_sec, durations, descriptions)
+         cnt.set_annotations(annots)
+
+         return cnt
+
+
+ def _check_class_names(all_class_names, event_times_in_ms, event_classes):
+     """
+     Check if the class names are among the known class names used at the
+     Translational NeuroTechnology Lab, AG Ball, Freiburg, Germany.
+
+     Logs a warning in case the class names are not known.
+
+     Parameters
+     ----------
+     all_class_names : list of str
+     event_times_in_ms : list of number
+     event_classes : list of number
+     """
+     if all_class_names == ["Right Hand", "Left Hand", "Rest", "Feet"]:
+         pass
+     elif (
+         all_class_names == ["1", "10", "11", "111", "12", "13", "150", "2", "20",
+                             "22", "3", "30", "33", "4", "40", "44", "99"]
+         or all_class_names == ["1", "10", "11", "12", "13", "150", "2", "20",
+                                "22", "3", "30", "33", "4", "40", "44", "99"]
+         or all_class_names == ["1", "2", "3", "4"]
+     ):
+         pass  # Semantic classes
+     elif all_class_names == ["Rest", "Feet", "Left Hand", "Right Hand"]:
+         # Have to swap from
+         # ['Rest', 'Feet', 'Left Hand', 'Right Hand']
+         # to
+         # ['Right Hand', 'Left Hand', 'Rest', 'Feet']
+         right_mask = event_classes == 4
+         left_mask = event_classes == 3
+         rest_mask = event_classes == 1
+         feet_mask = event_classes == 2
+         event_classes[right_mask] = 1
+         event_classes[left_mask] = 2
+         event_classes[rest_mask] = 3
+         event_classes[feet_mask] = 4
+         log.warning(
+             "Swapped class names {}... might cause problems...".format(all_class_names)
+         )
+     elif all_class_names == [
+         "Right Hand Start",
+         "Left Hand Start",
+         "Rest Start",
+         "Feet Start",
+         "Right Hand End",
+         "Left Hand End",
+         "Rest End",
+         "Feet End",
+     ]:
+         pass
+     elif all_class_names == [
+         "Right Hand",
+         "Left Hand",
+         "Rest",
+         "Feet",
+         "Face",
+         "Navigation",
+         "Music",
+         "Rotation",
+         "Subtraction",
+         "Words",
+     ]:
+         pass  # robot hall 10 class decoding
+     elif all_class_names == [
+         "RightHand", "Feet", "Rotation", "Words",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "RightHand_End",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "Feet_End",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "Rotation_End",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "Words_End",
+     ] or all_class_names == [
+         "RightHand", "Feet", "Rotation", "Words", "Rest",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "RightHand_End",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "Feet_End",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "Rotation_End",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "Words_End",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "\x00\x00", "\x00\x00", "\x00\x00", "\x00\x00",
+         "Rest_End",
+     ]:
+         pass  # weird stuff when we recorded cursor in robot hall
+         # on 2016-09-14 and 2016-09-16 :D
+
+     elif all_class_names == [
+         "0004", "0016", "0032", "0056", "0064", "0088", "0095", "0120",
+     ]:
+         pass
+     elif all_class_names == ["0004", "0056", "0088", "0120"]:
+         pass
+     elif all_class_names == [
+         "0004", "0016", "0032", "0048", "0056",
+         "0064", "0080", "0088", "0095", "0120",
+     ]:
+         pass
+     elif all_class_names == ["0004", "0016", "0056", "0088", "0120", "__"]:
+         pass
+     elif all_class_names == ["0004", "0056", "0088", "0120", "__"]:
+         pass
+     elif all_class_names == [
+         "0004", "0032", "0048", "0056", "0064",
+         "0080", "0088", "0095", "0120", "__",
+     ]:
+         pass
+     elif all_class_names == ["0004", "0056", "0080", "0088", "0096", "0120", "__"]:
+         pass
+     elif all_class_names == [
+         "0004", "0032", "0056", "0064", "0080", "0088", "0095", "0120",
+     ]:
+         pass
+     elif all_class_names == [
+         "0004", "0032", "0048", "0056", "0064", "0080", "0088", "0095", "0120",
+     ]:
+         pass
+     elif all_class_names == [
+         "0004", "0016", "0032", "0048", "0056", "0064",
+         "0080", "0088", "0095", "0096", "0120",
+     ]:
+         pass
+     elif all_class_names == ["4", "16", "32", "56", "64", "88", "95", "120"]:
+         pass
+     elif all_class_names == ["4", "56", "88", "120"]:
+         pass
+     elif all_class_names == [
+         "4", "16", "32", "48", "56", "64", "80", "88", "95", "120",
+     ]:
+         pass
+     elif all_class_names == ["0", "4", "56", "88", "120"]:
+         pass
+     elif all_class_names == ["0", "4", "16", "56", "88", "120"]:
+         pass
+     elif all_class_names == [
+         "0", "4", "32", "48", "56", "64", "80", "88", "95", "120",
+     ]:
+         pass
+     elif all_class_names == ["0", "4", "56", "80", "88", "96", "120"]:
+         pass
+     elif all_class_names == ["4", "32", "56", "64", "80", "88", "95", "120"]:
+         pass
+     elif all_class_names == ["One", "Two", "Three", "Four"]:
+         pass
+     elif all_class_names == [
+         "1", "10", "11", "12", "2", "20", "3", "30", "4", "40",
+     ]:
+         pass
+     elif all_class_names == [
+         "1", "10", "12", "13", "2", "20", "3", "30", "4", "40",
+     ]:
+         pass
+     elif all_class_names == [
+         "1", "10", "13", "2", "20", "3", "30", "4", "40", "99",
+     ]:
+         pass
+     elif all_class_names == [
+         "1", "10", "11", "14", "18", "20", "21",
+         "24", "251", "252", "28", "30", "4", "8",
+     ]:
+         pass
+     elif all_class_names == [
+         "1", "10", "11", "14", "18", "20", "21",
+         "24", "252", "253", "28", "30", "4", "8",
+     ]:
+         pass
+     elif len(event_times_in_ms) == len(all_class_names):
+         pass  # weird neuroone(?) logic where class names have event classes
+     elif all_class_names == [
+         "Right_hand_stimulus_onset",
+         "Feet_stimulus_onset",
+         "Rotation_stimulus_onset",
+         "Words_stimulus_onset",
+         "Right_hand_stimulus_offset",
+         "Feet_stimulus_offset",
+         "Rotation_stimulus_offset",
+         "Words_stimulus_offset",
+     ]:
+         pass
+     else:
+         # remove this whole if-else stuff?
+         log.warning("Unknown class names {}".format(all_class_names))
+
+
+ def load_bbci_sets_from_folder(
+     folder: str, runs: list[int] | str = "all"
+ ) -> list[mne.io.RawArray]:
+     """
+     Load BBCI datasets from the files in the given folder.
+
+     Parameters
+     ----------
+     folder : str
+         Folder with .BBCI.mat files inside.
+     runs : list of int | 'all'
+         Run numbers to load, or 'all' to load every file.
+         Assumes filenames contain a part such as S001R02 for run 2,
+         i.e. they match the regex ``'S[0-9]{3,3}R[0-9]{2,2}_'``.
+
+     Returns
+     -------
+     cnts : list of mne.io.RawArray
+         The loaded continuous recordings, one per file.
+     """
+     bbci_mat_files = sorted(glob(os.path.join(folder, "*.BBCI.mat")))
+     if runs != "all":
+         assert isinstance(runs, list), "runs should be list[int] or 'all'"
+         matches = [re.search("S[0-9]{3,3}R[0-9]{2,2}_", f) for f in bbci_mat_files]
+         file_run_numbers = [int(m.group()[5:7]) for m in matches if m is not None]
+         assert len(file_run_numbers) == len(bbci_mat_files), "Some files don't match"
+         indices = [file_run_numbers.index(num) for num in runs]
+
+         wanted_files = np.array(bbci_mat_files)[indices]
+     else:
+         wanted_files = bbci_mat_files
+     cnts = []
+     for f in wanted_files:
+         log.info("Loading {:s}".format(f))
+         cnts.append(BBCIDataset(f).load())
+     return cnts
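
For reference, a minimal usage sketch of the loader defined above (not part of the package diff). The file and folder paths are hypothetical placeholders, and the .mat files are assumed to have been saved in MATLAB with '-v7.3':

    from braindecode.datasets.bbci import BBCIDataset, load_bbci_sets_from_folder

    filename = "/data/bbci/S001R01_1-1.BBCI.mat"  # hypothetical path

    # List the sensors stored in the file, then load only a subset of them
    # (loading a subset is faster but sets the channel type to 'misc').
    sensor_names = BBCIDataset.get_all_sensors(filename, pattern=None)
    cnt = BBCIDataset(filename, load_sensor_names=sensor_names[:32]).load()
    print(cnt.info["sfreq"], len(cnt.annotations))

    # Load runs 1 and 2 from a folder of *.BBCI.mat files (hypothetical folder).
    cnts = load_bbci_sets_from_folder("/data/bbci/subject_01", runs=[1, 2])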