braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,693 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import os.path
|
|
9
|
+
import re
|
|
10
|
+
import warnings
|
|
11
|
+
from glob import glob
|
|
12
|
+
|
|
13
|
+
import h5py
|
|
14
|
+
import mne
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
# Module logger; the marker/class-name warnings and the per-file "Loading"
# messages of the loaders in this module are emitted through it.
log = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BBCIDataset(object):
    """BBCIDataset.

    Loader class for files created by saving BBCI files in matlab (make
    sure to save with '-v7.3' in matlab, see
    https://de.mathworks.com/help/matlab/import_export/mat-file-versions.html#buk6i87
    ). Such files are read here with :mod:`h5py`.

    Parameters
    ----------
    filename : str
        Path of the ``.BBCI.mat`` file to load.
    load_sensor_names : list of str, optional
        Names of the sensors to load. Also speeds up loading if you only
        load some sensors. None means load all EEG sensors, i.e. all channels
        whose names do not start with "BIP", "E", "Microphone", "Breath" or
        "GSR".
    check_class_names : bool, optional
        Check if the class names are part of some known class names at
        Translational NeuroTechnology Lab, AG Ball, Freiburg, Germany.
    """

    def __init__(
        self,
        filename: str,
        load_sensor_names: list[str] | None = None,
        check_class_names: bool = False,
    ):
        # Explicit assignments rather than ``self.__dict__.update(locals())``,
        # which would also store ``self`` on the instance (a reference cycle).
        self.filename = filename
        self.load_sensor_names = load_sensor_names
        self.check_class_names = check_class_names

    def load(self) -> mne.io.RawArray:
        """Load the continuous signal and add the markers of the file.

        Returns
        -------
        cnt : mne.io.RawArray
            Signal of the selected sensors plus a "STI 014" stim channel,
            ``info["events"]`` and annotations built from the markers.
        """
        cnt = self._load_continuous_signal()
        cnt = self._add_markers(cnt)
        return cnt

    def _load_continuous_signal(self):
        """Read the selected channels into an :class:`mne.io.RawArray`.

        Channels are typed "eeg" when all EEG sensors are loaded
        (``load_sensor_names`` is None) and "misc" otherwise, since the type of
        user-selected sensors cannot be determined here.
        """
        wanted_chan_inds, wanted_sensor_names = self._determine_sensors()
        fs = self._determine_samplingrate()
        with h5py.File(self.filename, "r") as h5file:
            samples = int(h5file["nfo"]["T"][0, 0])
            cnt_signal_shape = (samples, len(wanted_chan_inds))
            # NaN-prefilled so that any column not written below is detected.
            continuous_signal = np.full(cnt_signal_shape, np.nan, dtype=np.float32)
            for chan_ind_arr, chan_ind_set in enumerate(wanted_chan_inds):
                # + 1 because matlab/this hdf5-naming logic has 1-based
                # indexing, i.e. ch1, ch2, ...
                chan_set_name = "ch" + str(chan_ind_set + 1)
                # squeeze to unpack the stored 1xN matrix into a vector;
                # [:] loads the data into memory
                chan_signal = h5file[chan_set_name][:].squeeze()
                continuous_signal[:, chan_ind_arr] = chan_signal
            assert not np.any(np.isnan(continuous_signal)), "No NaNs expected in signal"

        if self.load_sensor_names is None:
            ch_types = ["eeg"] * len(wanted_chan_inds)
        else:
            warnings.warn("Setting to misc channel type as channel type not known")
            # Assume we can't know channel type here automatically
            ch_types = ["misc"] * len(wanted_chan_inds)
        info = mne.create_info(
            ch_names=wanted_sensor_names, sfreq=fs, ch_types=ch_types
        )
        return mne.io.RawArray(continuous_signal.T, info)

    def _determine_sensors(self):
        """Return ``(channel indices, sensor names)`` of the sensors to load."""
        all_sensor_names = self.get_all_sensors(self.filename, pattern=None)
        if self.load_sensor_names is None:
            # No sensor names given: take all channels whose name does not
            # start with one of these prefixes (note that "E" matches every
            # name starting with an uppercase E).
            non_eeg_prefixes = ("BIP", "E", "Microphone", "Breath", "GSR")
            eeg_sensor_names = [
                name
                for name in all_sensor_names
                if not name.startswith(non_eeg_prefixes)
            ]
            assert len(eeg_sensor_names) in (128, 64, 32, 16), (
                "Recheck this code if you have different sensors..."
            )
            wanted_sensor_names = eeg_sensor_names
        else:
            wanted_sensor_names = self.load_sensor_names
        chan_inds = self._determine_chan_inds(all_sensor_names, wanted_sensor_names)
        return chan_inds, wanted_sensor_names

    def _determine_samplingrate(self):
        """Return the sampling rate stored in ``nfo/fs`` as an ``int``."""
        with h5py.File(self.filename, "r") as h5file:
            fs = h5file["nfo"]["fs"][0, 0]
            assert isinstance(fs, int) or fs.is_integer()
            fs = int(fs)
        return fs

    @staticmethod
    def _determine_chan_inds(all_sensor_names, sensor_names):
        """Map ``sensor_names`` to their positions in ``all_sensor_names``.

        Raises
        ------
        ValueError
            If some of ``sensor_names`` are not present in the file.
        """
        assert sensor_names is not None
        missing = [name for name in sensor_names if name not in all_sensor_names]
        if missing:
            raise ValueError(f"Sensors not found in file: {missing}")
        chan_inds = [all_sensor_names.index(name) for name in sensor_names]
        assert len(set(chan_inds)) == len(chan_inds), "No duplicated sensors wanted."
        return chan_inds

    @staticmethod
    def get_all_sensors(filename: str, pattern: str | None = None) -> list[str]:
        """
        Get all sensors that exist in the given file.

        Parameters
        ----------
        filename : str
        pattern : str, optional
            Only return those sensor names that match the given pattern
            (checked with :func:`re.search`).

        Returns
        -------
        sensor_names : list of str
            Sensor names that match the pattern or all sensor names in the file.
        """
        with h5py.File(filename, "r") as h5file:
            clab_set = h5file["nfo"]["clab"][:].squeeze()
            # Each entry is an object reference to an array of character codes.
            all_sensor_names = [
                "".join(chr(c.item()) for c in h5file[obj_ref]) for obj_ref in clab_set
            ]
            if pattern is not None:
                all_sensor_names = [
                    sname for sname in all_sensor_names if re.search(pattern, sname)
                ]
        return all_sensor_names

    def _add_markers(self, cnt):
        """Add a stim channel, ``info["events"]`` and annotations to ``cnt``.

        Marker times (``mrk/time``, in ms) are converted to samples; marker
        codes (``mrk/event/desc``) sharing a sample are summed in the stim
        channel. Annotation descriptions come from the one-hot ``mrk/y``.
        """
        with h5py.File(self.filename, "r") as h5file:
            # atleast_1d keeps single-marker files working: squeeze() would
            # otherwise yield 0-d arrays that cannot be iterated or len()'ed.
            event_times_in_ms = np.atleast_1d(h5file["mrk"]["time"][:].squeeze())
            event_classes = np.atleast_1d(
                h5file["mrk"]["event"]["desc"][:].squeeze().astype(np.int64)
            )
            class_name_set = h5file["nfo"]["className"][:].squeeze()
            all_class_names = [
                "".join(chr(c.item()) for c in h5file[obj_ref])
                for obj_ref in class_name_set
            ]
            # y contains one-hot labels (class of each event); read in the
            # same file session instead of reopening the file later.
            y = h5file["mrk"]["y"][:]

        # Check whether class names known and correct order; may remap
        # ``event_classes`` in place for a known swapped class order.
        if self.check_class_names:
            _check_class_names(all_class_names, event_times_in_ms, event_classes)

        event_times_in_samples = event_times_in_ms * cnt.info["sfreq"] / 1000.0
        event_times_in_samples = np.round(event_times_in_samples).astype(np.uint32)

        # Warn about consecutive markers that fall on the same sample.
        duplicate_inds = (
            np.flatnonzero(event_times_in_samples[1:] == event_times_in_samples[:-1])
            + 1
        )
        for i_event in duplicate_inds:
            log.warning(
                "Same sample has at least two markers.\n"
                "%d: (%.0f and %.0f).\n"
                "Marker codes will be summed.",
                event_times_in_samples[i_event],
                event_classes[i_event - 1],
                event_classes[i_event],
            )

        # Now create stim chan; sized from n_times instead of copying the data.
        stim_chan = np.zeros(cnt.n_times)
        # np.add.at accumulates repeated indices, so codes on one sample sum up.
        np.add.at(stim_chan, event_times_in_samples, event_classes)
        info = mne.create_info(
            ch_names=["STI 014"], sfreq=cnt.info["sfreq"], ch_types=["stim"]
        )
        stim_cnt = mne.io.RawArray(stim_chan[None], info, verbose="WARNING")
        cnt = cnt.add_channels([stim_cnt])
        event_arr = [
            event_times_in_samples,
            [0] * len(event_times_in_samples),
            event_classes,
        ]
        cnt.info["events"] = np.array(event_arr).T

        # Generate Annotations.
        # Hacky way to find out the class name of each event from y: there are
        # cases where y is all zero for some events, and then it seems to be
        # the last class name ('Stimulation'), at least in the file investigated.
        y[np.sum(y, axis=1) == 0, -1] = 1
        assert np.all(np.sum(y, axis=1) == 1)
        event_i_classes = np.argmax(y, axis=1)

        # 4 second trials for High-Gamma dataset, otherwise how to know?
        if all_class_names == ["Right Hand", "Left Hand", "Rest", "Feet"]:
            durations = np.full(event_times_in_ms.shape, 4)
        else:
            warnings.warn("Unknown event durations set to 0")
            durations = np.full(event_times_in_ms.shape, 0)

        descriptions = [all_class_names[i_class] for i_class in event_i_classes]
        annots = mne.Annotations(event_times_in_ms / 1000.0, durations, descriptions)
        cnt.set_annotations(annots)
        return cnt
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
# Placeholder entry found in some class-name lists (see below).
_EMPTY_CLASS_NAME = "\x00\x00"

# Class-name lists (in file order) known from recordings at the
# Translational NeuroTechnology Lab, AG Ball, Freiburg. ``_check_class_names``
# accepts these without modification or warning.
_KNOWN_CLASS_NAME_LISTS = (
    ["Right Hand", "Left Hand", "Rest", "Feet"],
    # Semantic classes
    ["1", "10", "11", "111", "12", "13", "150", "2", "20", "22", "3", "30", "33", "4", "40", "44", "99"],
    ["1", "10", "11", "12", "13", "150", "2", "20", "22", "3", "30", "33", "4", "40", "44", "99"],
    ["1", "2", "3", "4"],
    [
        "Right Hand Start",
        "Left Hand Start",
        "Rest Start",
        "Feet Start",
        "Right Hand End",
        "Left Hand End",
        "Rest End",
        "Feet End",
    ],
    # robot hall 10 class decoding
    [
        "Right Hand",
        "Left Hand",
        "Rest",
        "Feet",
        "Face",
        "Navigation",
        "Music",
        "Rotation",
        "Subtraction",
        "Words",
    ],
    # weird stuff when we recorded cursor in robot hall on 2016-09-14 and
    # 2016-09-16 :D (lists padded with "\x00\x00" entries)
    ["RightHand", "Feet", "Rotation", "Words"]
    + [_EMPTY_CLASS_NAME] * 5
    + ["RightHand_End"]
    + [_EMPTY_CLASS_NAME] * 9
    + ["Feet_End"]
    + [_EMPTY_CLASS_NAME] * 9
    + ["Rotation_End"]
    + [_EMPTY_CLASS_NAME] * 9
    + ["Words_End"],
    ["RightHand", "Feet", "Rotation", "Words", "Rest"]
    + [_EMPTY_CLASS_NAME] * 4
    + ["RightHand_End"]
    + [_EMPTY_CLASS_NAME] * 9
    + ["Feet_End"]
    + [_EMPTY_CLASS_NAME] * 9
    + ["Rotation_End"]
    + [_EMPTY_CLASS_NAME] * 9
    + ["Words_End"]
    + [_EMPTY_CLASS_NAME] * 9
    + ["Rest_End"],
    ["0004", "0016", "0032", "0056", "0064", "0088", "0095", "0120"],
    ["0004", "0056", "0088", "0120"],
    ["0004", "0016", "0032", "0048", "0056", "0064", "0080", "0088", "0095", "0120"],
    ["0004", "0016", "0056", "0088", "0120", "__"],
    ["0004", "0056", "0088", "0120", "__"],
    ["0004", "0032", "0048", "0056", "0064", "0080", "0088", "0095", "0120", "__"],
    ["0004", "0056", "0080", "0088", "0096", "0120", "__"],
    ["0004", "0032", "0056", "0064", "0080", "0088", "0095", "0120"],
    ["0004", "0032", "0048", "0056", "0064", "0080", "0088", "0095", "0120"],
    ["0004", "0016", "0032", "0048", "0056", "0064", "0080", "0088", "0095", "0096", "0120"],
    ["4", "16", "32", "56", "64", "88", "95", "120"],
    ["4", "56", "88", "120"],
    ["4", "16", "32", "48", "56", "64", "80", "88", "95", "120"],
    ["0", "4", "56", "88", "120"],
    ["0", "4", "16", "56", "88", "120"],
    ["0", "4", "32", "48", "56", "64", "80", "88", "95", "120"],
    ["0", "4", "56", "80", "88", "96", "120"],
    ["4", "32", "56", "64", "80", "88", "95", "120"],
    ["One", "Two", "Three", "Four"],
    ["1", "10", "11", "12", "2", "20", "3", "30", "4", "40"],
    ["1", "10", "12", "13", "2", "20", "3", "30", "4", "40"],
    ["1", "10", "13", "2", "20", "3", "30", "4", "40", "99"],
    ["1", "10", "11", "14", "18", "20", "21", "24", "251", "252", "28", "30", "4", "8"],
    ["1", "10", "11", "14", "18", "20", "21", "24", "252", "253", "28", "30", "4", "8"],
    [
        "Right_hand_stimulus_onset",
        "Feet_stimulus_onset",
        "Rotation_stimulus_onset",
        "Words_stimulus_onset",
        "Right_hand_stimulus_offset",
        "Feet_stimulus_offset",
        "Rotation_stimulus_offset",
        "Words_stimulus_offset",
    ],
)


def _check_class_names(all_class_names, event_times_in_ms, event_classes):
    """
    Check if the class names are part of some known class names.

    The known class names are those used in the Translational
    NeuroTechnology Lab, AG Ball, Freiburg. Logs a warning in case the class
    names are not known.

    For the class order ``["Rest", "Feet", "Left Hand", "Right Hand"]`` the
    codes in ``event_classes`` are remapped in place to the order
    ``["Right Hand", "Left Hand", "Rest", "Feet"]`` (4->1, 3->2, 1->3, 2->4).

    Parameters
    ----------
    all_class_names : list of str
    event_times_in_ms : list of number
        Only its length is used: if it equals the number of class names,
        the class names are accepted.
    event_classes : numpy.ndarray of int
        Event codes; modified in place for the swapped class order above.
    """
    if all_class_names == ["Rest", "Feet", "Left Hand", "Right Hand"]:
        # Masks are computed before any assignment so that remapped codes are
        # not remapped a second time.
        right_mask = event_classes == 4
        left_mask = event_classes == 3
        rest_mask = event_classes == 1
        feet_mask = event_classes == 2
        event_classes[right_mask] = 1
        event_classes[left_mask] = 2
        event_classes[rest_mask] = 3
        event_classes[feet_mask] = 4
        # Lazy %-formatting: ``"{:s}".format(list)`` raised TypeError here.
        log.warning(
            "Swapped class names %s... might cause problems...", all_class_names
        )
        return
    if all_class_names in _KNOWN_CLASS_NAME_LISTS:
        return
    if len(event_times_in_ms) == len(all_class_names):
        return  # weird neuroone(?) logic where class names have event classes
    log.warning("Unknown class names %s", all_class_names)
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
def _select_run_files(bbci_mat_files, runs):
    """Return the entries of ``bbci_mat_files`` for ``runs``, in ``runs`` order.

    The run number of a file is parsed from the ``S###R##_`` part of its base
    name (e.g. ``S001R02_`` is run 2). Raises AssertionError if some file does
    not contain that part and ValueError if a requested run is not present.
    """
    run_pattern = re.compile(r"S[0-9]{3}R[0-9]{2}_")
    file_run_numbers = []
    for f in bbci_mat_files:
        # Search only the base name so that directory names cannot match.
        match = run_pattern.search(os.path.basename(f))
        if match is not None:
            # "S001R02_"[5:7] == "02"
            file_run_numbers.append(int(match.group()[5:7]))
    assert len(file_run_numbers) == len(bbci_mat_files), "Some files don't match"
    return [bbci_mat_files[file_run_numbers.index(num)] for num in runs]


def load_bbci_sets_from_folder(
    folder: str, runs: list[int] | str = "all"
) -> list[mne.io.RawArray]:
    """
    Load bbci datasets from files in given folder.

    Parameters
    ----------
    folder : str
        Folder with .BBCI.mat files inside
    runs : list of int or "all"
        If you only want to load specific runs.
        Assumes filenames with such kind of part: S001R02 for Run 2.
        Tries to match this regex: ``'S[0-9]{3,3}R[0-9]{2,2}_'`` in the
        file's base name. "all" loads every .BBCI.mat file in the folder.

    Returns
    -------
    cnts : list of mne.io.RawArray
        One loaded recording per selected file; sorted by filename for
        ``runs="all"``, otherwise in the order given by ``runs``.
    """
    bbci_mat_files = sorted(glob(os.path.join(folder, "*.BBCI.mat")))
    if runs != "all":
        assert isinstance(runs, list), "runs should be list[int] or 'all'"
        wanted_files = _select_run_files(bbci_mat_files, runs)
    else:
        wanted_files = bbci_mat_files
    cnts = []
    for f in wanted_files:
        log.info("Loading %s", f)
        cnts.append(BBCIDataset(f).load())
    return cnts
|