braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic; see the registry's advisory for this release for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/datasets/bbci.py
CHANGED
|
@@ -2,6 +2,8 @@
|
|
|
2
2
|
#
|
|
3
3
|
# License: BSD (3-clause)
|
|
4
4
|
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
5
7
|
import logging
|
|
6
8
|
import os.path
|
|
7
9
|
import re
|
|
@@ -16,7 +18,8 @@ log = logging.getLogger(__name__)
|
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
class BBCIDataset(object):
|
|
19
|
-
"""
|
|
21
|
+
"""BBCIDataset.
|
|
22
|
+
|
|
20
23
|
Loader class for files created by saving BBCI files in matlab (make
|
|
21
24
|
sure to save with '-v7.3' in matlab, see
|
|
22
25
|
https://de.mathworks.com/help/matlab/import_export/mat-file-versions.html#buk6i87
|
|
@@ -34,12 +37,14 @@ class BBCIDataset(object):
|
|
|
34
37
|
"""
|
|
35
38
|
|
|
36
39
|
def __init__(
|
|
37
|
-
self,
|
|
40
|
+
self,
|
|
41
|
+
filename: str,
|
|
42
|
+
load_sensor_names: list[str] | None = None,
|
|
43
|
+
check_class_names: bool = False,
|
|
38
44
|
):
|
|
39
45
|
self.__dict__.update(locals())
|
|
40
|
-
del self.self
|
|
41
46
|
|
|
42
|
-
def load(self):
|
|
47
|
+
def load(self) -> mne.io.RawArray:
|
|
43
48
|
cnt = self._load_continuous_signal()
|
|
44
49
|
cnt = self._add_markers(cnt)
|
|
45
50
|
return cnt
|
|
@@ -50,9 +55,7 @@ class BBCIDataset(object):
|
|
|
50
55
|
with h5py.File(self.filename, "r") as h5file:
|
|
51
56
|
samples = int(h5file["nfo"]["T"][0, 0])
|
|
52
57
|
cnt_signal_shape = (samples, len(wanted_chan_inds))
|
|
53
|
-
continuous_signal = (
|
|
54
|
-
np.ones(cnt_signal_shape, dtype=np.float32) * np.nan
|
|
55
|
-
)
|
|
58
|
+
continuous_signal = np.ones(cnt_signal_shape, dtype=np.float32) * np.nan
|
|
56
59
|
for chan_ind_arr, chan_ind_set in enumerate(wanted_chan_inds):
|
|
57
60
|
# + 1 because matlab/this hdf5-naming logic
|
|
58
61
|
# has 1-based indexing
|
|
@@ -63,9 +66,7 @@ class BBCIDataset(object):
|
|
|
63
66
|
:
|
|
64
67
|
].squeeze() # already load into memory
|
|
65
68
|
continuous_signal[:, chan_ind_arr] = chan_signal
|
|
66
|
-
assert not np.any(
|
|
67
|
-
np.isnan(continuous_signal)
|
|
68
|
-
), "No NaNs expected in signal"
|
|
69
|
+
assert not np.any(np.isnan(continuous_signal)), "No NaNs expected in signal"
|
|
69
70
|
|
|
70
71
|
if self.load_sensor_names is None:
|
|
71
72
|
ch_types = ["eeg"] * len(wanted_chan_inds)
|
|
@@ -83,15 +84,12 @@ class BBCIDataset(object):
|
|
|
83
84
|
def _determine_sensors(self):
|
|
84
85
|
all_sensor_names = self.get_all_sensors(self.filename, pattern=None)
|
|
85
86
|
if self.load_sensor_names is None:
|
|
86
|
-
|
|
87
87
|
# if no sensor names given, take all EEG-chans
|
|
88
88
|
eeg_sensor_names = all_sensor_names
|
|
89
89
|
eeg_sensor_names = filter(
|
|
90
90
|
lambda s: not s.startswith("BIP"), eeg_sensor_names
|
|
91
91
|
)
|
|
92
|
-
eeg_sensor_names = filter(
|
|
93
|
-
lambda s: not s.startswith("E"), eeg_sensor_names
|
|
94
|
-
)
|
|
92
|
+
eeg_sensor_names = filter(lambda s: not s.startswith("E"), eeg_sensor_names)
|
|
95
93
|
eeg_sensor_names = filter(
|
|
96
94
|
lambda s: not s.startswith("Microphone"), eeg_sensor_names
|
|
97
95
|
)
|
|
@@ -103,17 +101,15 @@ class BBCIDataset(object):
|
|
|
103
101
|
)
|
|
104
102
|
eeg_sensor_names = list(eeg_sensor_names)
|
|
105
103
|
assert (
|
|
106
|
-
len(eeg_sensor_names) == 128
|
|
107
|
-
len(eeg_sensor_names) == 64
|
|
108
|
-
len(eeg_sensor_names) == 32
|
|
109
|
-
len(eeg_sensor_names) == 16
|
|
104
|
+
len(eeg_sensor_names) == 128
|
|
105
|
+
or len(eeg_sensor_names) == 64
|
|
106
|
+
or len(eeg_sensor_names) == 32
|
|
107
|
+
or len(eeg_sensor_names) == 16
|
|
110
108
|
), "Recheck this code if you have different sensors..."
|
|
111
109
|
wanted_sensor_names = eeg_sensor_names
|
|
112
110
|
else:
|
|
113
111
|
wanted_sensor_names = self.load_sensor_names
|
|
114
|
-
chan_inds = self._determine_chan_inds(
|
|
115
|
-
all_sensor_names, wanted_sensor_names
|
|
116
|
-
)
|
|
112
|
+
chan_inds = self._determine_chan_inds(all_sensor_names, wanted_sensor_names)
|
|
117
113
|
return chan_inds, wanted_sensor_names
|
|
118
114
|
|
|
119
115
|
def _determine_samplingrate(self):
|
|
@@ -127,16 +123,12 @@ class BBCIDataset(object):
|
|
|
127
123
|
def _determine_chan_inds(all_sensor_names, sensor_names):
|
|
128
124
|
assert sensor_names is not None
|
|
129
125
|
chan_inds = [all_sensor_names.index(s) for s in sensor_names]
|
|
130
|
-
assert len(chan_inds) == len(sensor_names),
|
|
131
|
-
|
|
132
|
-
)
|
|
133
|
-
assert len(set(chan_inds)) == len(chan_inds), (
|
|
134
|
-
"No duplicated sensors wanted."
|
|
135
|
-
)
|
|
126
|
+
assert len(chan_inds) == len(sensor_names), "Allsensors should be there."
|
|
127
|
+
assert len(set(chan_inds)) == len(chan_inds), "No duplicated sensors wanted."
|
|
136
128
|
return chan_inds
|
|
137
129
|
|
|
138
130
|
@staticmethod
|
|
139
|
-
def get_all_sensors(filename, pattern=None):
|
|
131
|
+
def get_all_sensors(filename: str, pattern: str | None = None) -> list[str]:
|
|
140
132
|
"""
|
|
141
133
|
Get all sensors that exist in the given file.
|
|
142
134
|
|
|
@@ -157,17 +149,15 @@ class BBCIDataset(object):
|
|
|
157
149
|
"".join(chr(c.item()) for c in h5file[obj_ref]) for obj_ref in clab_set
|
|
158
150
|
]
|
|
159
151
|
if pattern is not None:
|
|
160
|
-
all_sensor_names =
|
|
161
|
-
lambda sname: re.search(pattern, sname), all_sensor_names
|
|
152
|
+
all_sensor_names = list(
|
|
153
|
+
filter(lambda sname: re.search(pattern, sname), all_sensor_names)
|
|
162
154
|
)
|
|
163
155
|
return all_sensor_names
|
|
164
156
|
|
|
165
157
|
def _add_markers(self, cnt):
|
|
166
158
|
with h5py.File(self.filename, "r") as h5file:
|
|
167
159
|
event_times_in_ms = h5file["mrk"]["time"][:].squeeze()
|
|
168
|
-
event_classes = (
|
|
169
|
-
h5file["mrk"]["event"]["desc"][:].squeeze().astype(np.int64)
|
|
170
|
-
)
|
|
160
|
+
event_classes = h5file["mrk"]["event"]["desc"][:].squeeze().astype(np.int64)
|
|
171
161
|
|
|
172
162
|
# Check whether class names known and correct order
|
|
173
163
|
class_name_set = h5file["nfo"]["className"][:].squeeze()
|
|
@@ -177,9 +167,7 @@ class BBCIDataset(object):
|
|
|
177
167
|
]
|
|
178
168
|
|
|
179
169
|
if self.check_class_names:
|
|
180
|
-
_check_class_names(
|
|
181
|
-
all_class_names, event_times_in_ms, event_classes
|
|
182
|
-
)
|
|
170
|
+
_check_class_names(all_class_names, event_times_in_ms, event_classes)
|
|
183
171
|
|
|
184
172
|
event_times_in_samples = event_times_in_ms * cnt.info["sfreq"] / 1000.0
|
|
185
173
|
event_times_in_samples = np.uint32(np.round(event_times_in_samples))
|
|
@@ -196,8 +184,8 @@ class BBCIDataset(object):
|
|
|
196
184
|
i_sample,
|
|
197
185
|
event_classes[i_event - 1],
|
|
198
186
|
event_classes[i_event],
|
|
199
|
-
)
|
|
200
|
-
"Marker codes will be summed."
|
|
187
|
+
)
|
|
188
|
+
+ "Marker codes will be summed."
|
|
201
189
|
)
|
|
202
190
|
previous_i_sample = i_sample
|
|
203
191
|
|
|
@@ -222,7 +210,7 @@ class BBCIDataset(object):
|
|
|
222
210
|
# Hacky way to try to find out class names for each event
|
|
223
211
|
# h5file['mrk']['y'] y contains one-hot label for event name
|
|
224
212
|
with h5py.File(self.filename, "r") as h5file:
|
|
225
|
-
y = h5file[
|
|
213
|
+
y = h5file["mrk"]["y"][:]
|
|
226
214
|
# seems that there are cases where for last class
|
|
227
215
|
# y is just all zero for some reason?
|
|
228
216
|
# and seems then it is last of the class names
|
|
@@ -233,7 +221,7 @@ class BBCIDataset(object):
|
|
|
233
221
|
event_i_classes = np.argmax(y, axis=1)
|
|
234
222
|
|
|
235
223
|
# 4 second trials for High-Gamma dataset, otherwise how to know?
|
|
236
|
-
if all_class_names == [
|
|
224
|
+
if all_class_names == ["Right Hand", "Left Hand", "Rest", "Feet"]:
|
|
237
225
|
durations = np.full(event_times_in_ms.shape, 4)
|
|
238
226
|
else:
|
|
239
227
|
warnings.warn("Unknown event durations set to 0")
|
|
@@ -265,8 +253,8 @@ def _check_class_names(all_class_names, event_times_in_ms, event_classes):
|
|
|
265
253
|
pass
|
|
266
254
|
elif (
|
|
267
255
|
(
|
|
268
|
-
all_class_names
|
|
269
|
-
[
|
|
256
|
+
all_class_names
|
|
257
|
+
== [
|
|
270
258
|
"1",
|
|
271
259
|
"10",
|
|
272
260
|
"11",
|
|
@@ -285,9 +273,10 @@ def _check_class_names(all_class_names, event_times_in_ms, event_classes):
|
|
|
285
273
|
"44",
|
|
286
274
|
"99",
|
|
287
275
|
]
|
|
288
|
-
)
|
|
289
|
-
|
|
290
|
-
|
|
276
|
+
)
|
|
277
|
+
or (
|
|
278
|
+
all_class_names
|
|
279
|
+
== [
|
|
291
280
|
"1",
|
|
292
281
|
"10",
|
|
293
282
|
"11",
|
|
@@ -305,8 +294,8 @@ def _check_class_names(all_class_names, event_times_in_ms, event_classes):
|
|
|
305
294
|
"44",
|
|
306
295
|
"99",
|
|
307
296
|
]
|
|
308
|
-
)
|
|
309
|
-
|
|
297
|
+
)
|
|
298
|
+
or (all_class_names == ["1", "2", "3", "4"])
|
|
310
299
|
):
|
|
311
300
|
pass # Semantic classes
|
|
312
301
|
elif all_class_names == ["Rest", "Feet", "Left Hand", "Right Hand"]:
|
|
@@ -668,7 +657,9 @@ def _check_class_names(all_class_names, event_times_in_ms, event_classes):
|
|
|
668
657
|
log.warn("Unknown class names {:s}".format(all_class_names))
|
|
669
658
|
|
|
670
659
|
|
|
671
|
-
def load_bbci_sets_from_folder(
|
|
660
|
+
def load_bbci_sets_from_folder(
|
|
661
|
+
folder: str, runs: list[int] | str = "all"
|
|
662
|
+
) -> list[mne.io.RawArray]:
|
|
672
663
|
"""
|
|
673
664
|
Load bbci datasets from files in given folder.
|
|
674
665
|
|
|
@@ -687,10 +678,10 @@ def load_bbci_sets_from_folder(folder, runs="all"):
|
|
|
687
678
|
"""
|
|
688
679
|
bbci_mat_files = sorted(glob(os.path.join(folder, "*.BBCI.mat")))
|
|
689
680
|
if runs != "all":
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
681
|
+
assert isinstance(runs, list), "runs should be list[int] or 'all'"
|
|
682
|
+
matches = [re.search("S[0-9]{3,3}R[0-9]{2,2}_", f) for f in bbci_mat_files]
|
|
683
|
+
file_run_numbers = [int(m.group()[5:7]) for m in matches if m is not None]
|
|
684
|
+
assert len(file_run_numbers) == len(bbci_mat_files), "Some files don't match"
|
|
694
685
|
indices = [file_run_numbers.index(num) for num in runs]
|
|
695
686
|
|
|
696
687
|
wanted_files = np.array(bbci_mat_files)[indices]
|
braindecode/datasets/bcicomp.py
CHANGED
|
@@ -3,6 +3,8 @@
|
|
|
3
3
|
#
|
|
4
4
|
# License: BSD (3-clause)
|
|
5
5
|
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
6
8
|
import glob
|
|
7
9
|
import os
|
|
8
10
|
import os.path as osp
|
|
@@ -14,10 +16,12 @@ import numpy as np
|
|
|
14
16
|
from mne.utils import verbose
|
|
15
17
|
from scipy.io import loadmat
|
|
16
18
|
|
|
17
|
-
from braindecode.datasets import
|
|
19
|
+
from braindecode.datasets import BaseConcatDataset, BaseDataset
|
|
18
20
|
|
|
19
|
-
DATASET_URL =
|
|
20
|
-
|
|
21
|
+
DATASET_URL = (
|
|
22
|
+
"https://stacks.stanford.edu/file/druid:zk881ps0522/"
|
|
23
|
+
"BCI_Competion4_dataset4_data_fingerflexions.zip"
|
|
24
|
+
)
|
|
21
25
|
|
|
22
26
|
|
|
23
27
|
class BCICompetitionIVDataset4(BaseConcatDataset):
|
|
@@ -42,30 +46,32 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
|
|
|
42
46
|
References
|
|
43
47
|
----------
|
|
44
48
|
.. [1] Miller, Kai J. "A library of human electrocorticographic data and analyses."
|
|
45
|
-
|
|
49
|
+
Nature human behaviour 3, no. 11 (2019): 1225-1235.
|
|
50
|
+
https://doi.org/10.1038/s41562-019-0678-3
|
|
46
51
|
"""
|
|
52
|
+
|
|
47
53
|
possible_subjects = [1, 2, 3]
|
|
48
54
|
|
|
49
|
-
def __init__(self, subject_ids=None):
|
|
55
|
+
def __init__(self, subject_ids: list[int] | int | None = None):
|
|
50
56
|
data_path = self.download()
|
|
51
57
|
if isinstance(subject_ids, int):
|
|
52
58
|
subject_ids = [subject_ids]
|
|
53
59
|
if subject_ids is None:
|
|
54
60
|
subject_ids = self.possible_subjects
|
|
55
61
|
self._validate_subjects(subject_ids)
|
|
56
|
-
files_list = [f
|
|
62
|
+
files_list = [f"{data_path}/sub{i}_comp.mat" for i in subject_ids]
|
|
57
63
|
datasets = []
|
|
58
64
|
for file_path in files_list:
|
|
59
65
|
raw_train, raw_test = self._load_data_to_mne(file_path)
|
|
60
66
|
desc_train = dict(
|
|
61
|
-
subject=file_path.split(
|
|
62
|
-
file_name=file_path.split(
|
|
63
|
-
session=
|
|
67
|
+
subject=file_path.split("/")[-1].split("sub")[1][0],
|
|
68
|
+
file_name=file_path.split("/")[-1],
|
|
69
|
+
session="train",
|
|
64
70
|
)
|
|
65
71
|
desc_test = dict(
|
|
66
|
-
subject=file_path.split(
|
|
67
|
-
file_name=file_path.split(
|
|
68
|
-
session=
|
|
72
|
+
subject=file_path.split("/")[-1].split("sub")[1][0],
|
|
73
|
+
file_name=file_path.split("/")[-1],
|
|
74
|
+
session="test",
|
|
69
75
|
)
|
|
70
76
|
datasets.append(BaseDataset(raw_train, description=desc_train))
|
|
71
77
|
datasets.append(BaseDataset(raw_test, description=desc_test))
|
|
@@ -90,20 +96,24 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
|
|
|
90
96
|
-------
|
|
91
97
|
|
|
92
98
|
"""
|
|
93
|
-
signature =
|
|
94
|
-
folder_name =
|
|
99
|
+
signature = "BCICompetitionIVDataset4"
|
|
100
|
+
folder_name = "BCI_Competion4_dataset4_data_fingerflexions"
|
|
95
101
|
# Check if the dataset already exists (unpacked). We have to do that manually
|
|
96
102
|
# because we are removing .zip file from disk to save disk space.
|
|
97
103
|
|
|
98
104
|
from moabb.datasets.download import get_dataset_path # keep soft dependency
|
|
105
|
+
|
|
99
106
|
path = get_dataset_path(signature, path)
|
|
100
107
|
key_dest = "MNE-{:s}-data".format(signature.lower())
|
|
101
108
|
# We do not use mne _url_to_local_path due to ':' in the url that causes problems on Windows
|
|
102
109
|
destination = osp.join(path, key_dest, folder_name)
|
|
103
|
-
if len(list(glob.glob(osp.join(destination,
|
|
110
|
+
if len(list(glob.glob(osp.join(destination, "*.mat")))) == 6:
|
|
104
111
|
return destination
|
|
105
|
-
data_path = _data_dl(
|
|
106
|
-
|
|
112
|
+
data_path = _data_dl(
|
|
113
|
+
DATASET_URL,
|
|
114
|
+
osp.join(destination, folder_name, signature),
|
|
115
|
+
force_update=force_update,
|
|
116
|
+
)
|
|
107
117
|
unpack_archive(data_path, osp.dirname(destination))
|
|
108
118
|
# removes .zip file that the data was unpacked from
|
|
109
119
|
remove(data_path)
|
|
@@ -117,26 +127,30 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
|
|
|
117
127
|
|
|
118
128
|
def _load_data_to_mne(self, file_path):
|
|
119
129
|
data = loadmat(file_path)
|
|
120
|
-
test_labels = loadmat(file_path.replace(
|
|
121
|
-
train_data = data[
|
|
122
|
-
test_data = data[
|
|
123
|
-
upsampled_train_targets = data[
|
|
124
|
-
upsampled_test_targets = test_labels[
|
|
130
|
+
test_labels = loadmat(file_path.replace("comp.mat", "testlabels.mat"))
|
|
131
|
+
train_data = data["train_data"]
|
|
132
|
+
test_data = data["test_data"]
|
|
133
|
+
upsampled_train_targets = data["train_dg"]
|
|
134
|
+
upsampled_test_targets = test_labels["test_dg"]
|
|
125
135
|
|
|
126
136
|
signal_sfreq = 1000
|
|
127
137
|
original_target_sfreq = 25
|
|
128
138
|
targets_stride = int(signal_sfreq / original_target_sfreq)
|
|
129
139
|
|
|
130
|
-
original_targets = self._prepare_targets(
|
|
131
|
-
|
|
140
|
+
original_targets = self._prepare_targets(
|
|
141
|
+
upsampled_train_targets, targets_stride
|
|
142
|
+
)
|
|
143
|
+
original_test_targets = self._prepare_targets(
|
|
144
|
+
upsampled_test_targets, targets_stride
|
|
145
|
+
)
|
|
132
146
|
|
|
133
|
-
ch_names = [f
|
|
134
|
-
ch_names += [f
|
|
135
|
-
ch_types = [
|
|
136
|
-
ch_types += [
|
|
147
|
+
ch_names = [f"{i}" for i in range(train_data.shape[1])]
|
|
148
|
+
ch_names += [f"target_{i}" for i in range(original_targets.shape[1])]
|
|
149
|
+
ch_types = ["ecog" for _ in range(train_data.shape[1])]
|
|
150
|
+
ch_types += ["misc" for _ in range(original_targets.shape[1])]
|
|
137
151
|
|
|
138
152
|
info = mne.create_info(sfreq=signal_sfreq, ch_names=ch_names, ch_types=ch_types)
|
|
139
|
-
info[
|
|
153
|
+
info["temp"] = dict(target_sfreq=original_target_sfreq)
|
|
140
154
|
train_data = np.concatenate([train_data, original_targets], axis=1)
|
|
141
155
|
test_data = np.concatenate([test_data, original_test_targets], axis=1)
|
|
142
156
|
|
|
@@ -149,12 +163,12 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
|
|
|
149
163
|
if isinstance(subject_ids, (list, tuple)):
|
|
150
164
|
if not all((subject in self.possible_subjects for subject in subject_ids)):
|
|
151
165
|
raise ValueError(
|
|
152
|
-
f
|
|
153
|
-
f
|
|
166
|
+
f"Wrong subject_ids parameter. Possible values: {self.possible_subjects}. "
|
|
167
|
+
f"Provided {subject_ids}."
|
|
154
168
|
)
|
|
155
169
|
else:
|
|
156
170
|
raise ValueError(
|
|
157
|
-
|
|
171
|
+
"Wrong subject_ids format. Expected types: None, list, tuple, int."
|
|
158
172
|
)
|
|
159
173
|
|
|
160
174
|
|
|
@@ -165,6 +179,7 @@ def _data_dl(url, destination, force_update=False, verbose=None):
|
|
|
165
179
|
# moabb/datasets/download.py
|
|
166
180
|
|
|
167
181
|
from pooch import file_hash, retrieve # keep soft dependency
|
|
182
|
+
|
|
168
183
|
if not osp.isfile(destination) or force_update:
|
|
169
184
|
if osp.isfile(destination):
|
|
170
185
|
os.remove(destination)
|
|
braindecode/datasets/bids.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
"""Dataset for loading BIDS.
|
|
2
|
+
|
|
3
|
+
More information on BIDS (Brain Imaging Data Structure) can be found at https://bids.neuroimaging.io
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
|
|
7
|
+
#
|
|
8
|
+
# License: BSD (3-clause)
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
import mne
|
|
17
|
+
import mne_bids
|
|
18
|
+
import numpy as np
|
|
19
|
+
import pandas as pd
|
|
20
|
+
from joblib import Parallel, delayed
|
|
21
|
+
|
|
22
|
+
from .base import BaseConcatDataset, BaseDataset, WindowsDataset
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _description_from_bids_path(bids_path: mne_bids.BIDSPath) -> dict[str, Any]:
|
|
26
|
+
return {
|
|
27
|
+
"path": bids_path.fpath,
|
|
28
|
+
"subject": bids_path.subject,
|
|
29
|
+
"session": bids_path.session,
|
|
30
|
+
"task": bids_path.task,
|
|
31
|
+
"acquisition": bids_path.acquisition,
|
|
32
|
+
"run": bids_path.run,
|
|
33
|
+
"processing": bids_path.processing,
|
|
34
|
+
"recording": bids_path.recording,
|
|
35
|
+
"space": bids_path.space,
|
|
36
|
+
"split": bids_path.split,
|
|
37
|
+
"description": bids_path.description,
|
|
38
|
+
"suffix": bids_path.suffix,
|
|
39
|
+
"extension": bids_path.extension,
|
|
40
|
+
"datatype": bids_path.datatype,
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class BIDSDataset(BaseConcatDataset):
|
|
46
|
+
"""Dataset for loading BIDS.
|
|
47
|
+
|
|
48
|
+
This class has the same parameters as the :func:`mne_bids.find_matching_paths` function
|
|
49
|
+
as it will be used to find the files to load. The default ``extensions`` parameter was changed.
|
|
50
|
+
|
|
51
|
+
More information on BIDS (Brain Imaging Data Structure)
|
|
52
|
+
can be found at https://bids.neuroimaging.io
|
|
53
|
+
|
|
54
|
+
.. Note::
|
|
55
|
+
For loading "unofficial" BIDS datasets containing epoched data,
|
|
56
|
+
you can use :class:`BIDSEpochsDataset`.
|
|
57
|
+
|
|
58
|
+
Parameters
|
|
59
|
+
----------
|
|
60
|
+
root : pathlib.Path | str
|
|
61
|
+
The root of the BIDS path.
|
|
62
|
+
subjects : str | array-like of str | None
|
|
63
|
+
The subject ID. Corresponds to "sub".
|
|
64
|
+
sessions : str | array-like of str | None
|
|
65
|
+
The acquisition session. Corresponds to "ses".
|
|
66
|
+
tasks : str | array-like of str | None
|
|
67
|
+
The experimental task. Corresponds to "task".
|
|
68
|
+
acquisitions: str | array-like of str | None
|
|
69
|
+
The acquisition parameters. Corresponds to "acq".
|
|
70
|
+
runs : str | array-like of str | None
|
|
71
|
+
The run number. Corresponds to "run".
|
|
72
|
+
processings : str | array-like of str | None
|
|
73
|
+
The processing label. Corresponds to "proc".
|
|
74
|
+
recordings : str | array-like of str | None
|
|
75
|
+
The recording name. Corresponds to "rec".
|
|
76
|
+
spaces : str | array-like of str | None
|
|
77
|
+
The coordinate space for anatomical and sensor location
|
|
78
|
+
files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
|
|
79
|
+
Corresponds to "space".
|
|
80
|
+
Note that valid values for ``space`` must come from a list
|
|
81
|
+
of BIDS keywords as described in the BIDS specification.
|
|
82
|
+
splits : str | array-like of str | None
|
|
83
|
+
The split of the continuous recording file for ``.fif`` data.
|
|
84
|
+
Corresponds to "split".
|
|
85
|
+
descriptions : str | array-like of str | None
|
|
86
|
+
This corresponds to the BIDS entity ``desc``. It is used to provide
|
|
87
|
+
additional information for derivative data, e.g., preprocessed data
|
|
88
|
+
may be assigned ``description='cleaned'``.
|
|
89
|
+
suffixes : str | array-like of str | None
|
|
90
|
+
The filename suffix. This is the entity after the
|
|
91
|
+
last ``_`` before the extension. E.g., ``'channels'``.
|
|
92
|
+
The following filename suffix's are accepted:
|
|
93
|
+
'meg', 'markers', 'eeg', 'ieeg', 'T1w',
|
|
94
|
+
'participants', 'scans', 'electrodes', 'coordsystem',
|
|
95
|
+
'channels', 'events', 'headshape', 'digitizer',
|
|
96
|
+
'beh', 'physio', 'stim'
|
|
97
|
+
extensions : str | array-like of str | None
|
|
98
|
+
The extension of the filename. E.g., ``'.json'``.
|
|
99
|
+
By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
|
|
100
|
+
datatypes : str | array-like of str | None
|
|
101
|
+
The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
|
|
102
|
+
``'ieeg'``.
|
|
103
|
+
check : bool
|
|
104
|
+
If ``True``, only returns paths that conform to BIDS. If ``False``
|
|
105
|
+
(default), the ``.check`` attribute of the returned
|
|
106
|
+
:class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
|
|
107
|
+
do conform to BIDS, and to ``False`` for those that don't.
|
|
108
|
+
preload : bool
|
|
109
|
+
If True, preload the data. Defaults to False.
|
|
110
|
+
n_jobs : int
|
|
111
|
+
Number of jobs to run in parallel. Defaults to 1.
|
|
112
|
+
"""
|
|
113
|
+
|
|
114
|
+
root: Path | str
|
|
115
|
+
subjects: str | list[str] | None = None
|
|
116
|
+
sessions: str | list[str] | None = None
|
|
117
|
+
tasks: str | list[str] | None = None
|
|
118
|
+
acquisitions: str | list[str] | None = None
|
|
119
|
+
runs: str | list[str] | None = None
|
|
120
|
+
processings: str | list[str] | None = None
|
|
121
|
+
recordings: str | list[str] | None = None
|
|
122
|
+
spaces: str | list[str] | None = None
|
|
123
|
+
splits: str | list[str] | None = None
|
|
124
|
+
descriptions: str | list[str] | None = None
|
|
125
|
+
suffixes: str | list[str] | None = None
|
|
126
|
+
extensions: str | list[str] | None = field(
|
|
127
|
+
default_factory=lambda: [
|
|
128
|
+
".con",
|
|
129
|
+
".sqd",
|
|
130
|
+
".pdf",
|
|
131
|
+
".fif",
|
|
132
|
+
".ds",
|
|
133
|
+
".vhdr",
|
|
134
|
+
".set",
|
|
135
|
+
".edf",
|
|
136
|
+
".bdf",
|
|
137
|
+
".EDF",
|
|
138
|
+
".snirf",
|
|
139
|
+
".cdt",
|
|
140
|
+
".mef",
|
|
141
|
+
".nwb",
|
|
142
|
+
]
|
|
143
|
+
)
|
|
144
|
+
datatypes: str | list[str] | None = None
|
|
145
|
+
check: bool = False
|
|
146
|
+
preload: bool = False
|
|
147
|
+
n_jobs: int = 1
|
|
148
|
+
|
|
149
|
+
@property
|
|
150
|
+
def _filter_out_epochs(self):
|
|
151
|
+
return True
|
|
152
|
+
|
|
153
|
+
def __post_init__(self):
|
|
154
|
+
bids_paths = mne_bids.find_matching_paths(
|
|
155
|
+
root=self.root,
|
|
156
|
+
subjects=self.subjects,
|
|
157
|
+
sessions=self.sessions,
|
|
158
|
+
tasks=self.tasks,
|
|
159
|
+
acquisitions=self.acquisitions,
|
|
160
|
+
runs=self.runs,
|
|
161
|
+
processings=self.processings,
|
|
162
|
+
recordings=self.recordings,
|
|
163
|
+
spaces=self.spaces,
|
|
164
|
+
splits=self.splits,
|
|
165
|
+
descriptions=self.descriptions,
|
|
166
|
+
suffixes=self.suffixes,
|
|
167
|
+
extensions=self.extensions,
|
|
168
|
+
datatypes=self.datatypes,
|
|
169
|
+
check=self.check,
|
|
170
|
+
)
|
|
171
|
+
# Filter out .json files files:
|
|
172
|
+
# (argument ignore_json only available in mne-bids>=0.16)
|
|
173
|
+
bids_paths = [
|
|
174
|
+
bids_path for bids_path in bids_paths if bids_path.extension != ".json"
|
|
175
|
+
]
|
|
176
|
+
# Filter out _epo.fif files:
|
|
177
|
+
if self._filter_out_epochs:
|
|
178
|
+
bids_paths = [
|
|
179
|
+
bids_path
|
|
180
|
+
for bids_path in bids_paths
|
|
181
|
+
if not (bids_path.suffix == "epo" and bids_path.extension == ".fif")
|
|
182
|
+
]
|
|
183
|
+
|
|
184
|
+
all_base_ds = Parallel(n_jobs=self.n_jobs)(
|
|
185
|
+
delayed(self._get_dataset)(bids_path) for bids_path in bids_paths
|
|
186
|
+
)
|
|
187
|
+
super().__init__(all_base_ds)
|
|
188
|
+
|
|
189
|
+
def _get_dataset(self, bids_path: mne_bids.BIDSPath) -> BaseDataset:
|
|
190
|
+
description = _description_from_bids_path(bids_path)
|
|
191
|
+
raw = mne_bids.read_raw_bids(bids_path, verbose=False)
|
|
192
|
+
if self.preload:
|
|
193
|
+
raw.load_data()
|
|
194
|
+
return BaseDataset(raw, description)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class BIDSEpochsDataset(BIDSDataset):
|
|
198
|
+
"""**Experimental** dataset for loading :class:`mne.Epochs` organised in BIDS.
|
|
199
|
+
|
|
200
|
+
The files must end with ``_epo.fif``.
|
|
201
|
+
|
|
202
|
+
.. Warning::
|
|
203
|
+
Epoched data is not officially supported in BIDS.
|
|
204
|
+
|
|
205
|
+
.. Note::
|
|
206
|
+
**Parameters:** This class has the same parameters as :class:`BIDSDataset` except
|
|
207
|
+
for arguments ``datatypes``, ``extensions`` and ``check`` which are fixed.
|
|
208
|
+
"""
|
|
209
|
+
|
|
210
|
+
@property
|
|
211
|
+
def _filter_out_epochs(self):
|
|
212
|
+
return False
|
|
213
|
+
|
|
214
|
+
def __init__(self, *args, **kwargs):
|
|
215
|
+
super().__init__(
|
|
216
|
+
*args,
|
|
217
|
+
extensions=".fif",
|
|
218
|
+
suffixes="epo",
|
|
219
|
+
check=False,
|
|
220
|
+
**kwargs,
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
def _set_metadata(self, epochs: mne.BaseEpochs) -> None:
|
|
224
|
+
# events = mne.events_from_annotations(epochs
|
|
225
|
+
n_times = epochs.times.shape[0]
|
|
226
|
+
# id_event = {v: k for k, v in epochs.event_id.items()}
|
|
227
|
+
annotations = epochs.annotations
|
|
228
|
+
if annotations is not None:
|
|
229
|
+
target = annotations.description
|
|
230
|
+
else:
|
|
231
|
+
id_events = {v: k for k, v in epochs.event_id.items()}
|
|
232
|
+
target = [id_events[event_id] for event_id in epochs.events[:, -1]]
|
|
233
|
+
metadata_dict = {
|
|
234
|
+
"i_window_in_trial": np.zeros(len(epochs)),
|
|
235
|
+
"i_start_in_trial": np.zeros(len(epochs)),
|
|
236
|
+
"i_stop_in_trial": np.zeros(len(epochs)) + n_times,
|
|
237
|
+
"target": target,
|
|
238
|
+
}
|
|
239
|
+
epochs.metadata = pd.DataFrame(metadata_dict)
|
|
240
|
+
|
|
241
|
+
def _get_dataset(self, bids_path):
|
|
242
|
+
description = _description_from_bids_path(bids_path)
|
|
243
|
+
epochs = mne.read_epochs(bids_path.fpath)
|
|
244
|
+
self._set_metadata(epochs)
|
|
245
|
+
return WindowsDataset(epochs, description=description, targets_from="metadata")
|