braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
Potentially problematic release.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
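Judging by the file list, 1.0.0 reorganises the package: shared building blocks move into new braindecode.modules and braindecode.functional packages, and many new model files appear under braindecode.models. A minimal smoke test one might run against the 1.0.0 wheel (a sketch; ShallowFBCSPNet is part of braindecode's public model zoo, but the exact constructor arguments should be checked against the release docs):

    from braindecode.models import ShallowFBCSPNet

    # Instantiate a model for, e.g., 22 channels, 4 classes, 1000-sample windows.
    model = ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000)
    print(model)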
braindecode/datasets/bcicomp.py
@@ -0,0 +1,194 @@
+# Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
+#          Mohammed Fattouh <mo.fattouh@gmail.com>
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+
+import glob
+import os
+import os.path as osp
+from os import remove
+from shutil import unpack_archive
+
+import mne
+import numpy as np
+from mne.utils import verbose
+from scipy.io import loadmat
+
+from braindecode.datasets import BaseConcatDataset, BaseDataset
+
+DATASET_URL = (
+    "https://stacks.stanford.edu/file/druid:zk881ps0522/"
+    "BCI_Competion4_dataset4_data_fingerflexions.zip"
+)
+
+
+class BCICompetitionIVDataset4(BaseConcatDataset):
+    """BCI competition IV dataset 4.
+
+    Contains ECoG recordings of three patients moving fingers during the experiment.
+    Targets correspond to the time courses of the flexion of each of five fingers.
+    See http://www.bbci.de/competition/iv/desc_4.pdf and
+    http://www.bbci.de/competition/iv/ for the dataset and competition description.
+    ECoG library containing the dataset: https://searchworks.stanford.edu/view/zk881ps0522
+
+    Notes
+    -----
+    When using this dataset please cite [1]_.
+
+    Parameters
+    ----------
+    subject_ids : list(int) | int | None
+        (list of) int of subject(s) to be loaded. If None, load all available
+        subjects. Should be in range 1-3.
+
+    References
+    ----------
+    .. [1] Miller, Kai J. "A library of human electrocorticographic data and analyses."
+       Nature Human Behaviour 3, no. 11 (2019): 1225-1235.
+       https://doi.org/10.1038/s41562-019-0678-3
+    """
+
+    possible_subjects = [1, 2, 3]
+
+    def __init__(self, subject_ids: list[int] | int | None = None):
+        data_path = self.download()
+        if isinstance(subject_ids, int):
+            subject_ids = [subject_ids]
+        if subject_ids is None:
+            subject_ids = self.possible_subjects
+        self._validate_subjects(subject_ids)
+        files_list = [f"{data_path}/sub{i}_comp.mat" for i in subject_ids]
+        datasets = []
+        for file_path in files_list:
+            raw_train, raw_test = self._load_data_to_mne(file_path)
+            desc_train = dict(
+                subject=file_path.split("/")[-1].split("sub")[1][0],
+                file_name=file_path.split("/")[-1],
+                session="train",
+            )
+            desc_test = dict(
+                subject=file_path.split("/")[-1].split("sub")[1][0],
+                file_name=file_path.split("/")[-1],
+                session="test",
+            )
+            datasets.append(BaseDataset(raw_train, description=desc_train))
+            datasets.append(BaseDataset(raw_test, description=desc_test))
+        super().__init__(datasets)
+
+    @staticmethod
+    def download(path=None, force_update=False, verbose=None):
+        """Download the dataset.
+
+        Parameters
+        ----------
+        path : None | str
+            Location of where to look for the data storing location.
+            If None, the environment variable or config parameter
+            MNE_DATASETS_(dataset)_PATH is used. If it doesn't exist, the
+            "~/mne_data" directory is used. If the dataset is not found under
+            the given path, the data will be automatically downloaded to the
+            specified folder.
+        force_update : bool
+            Force update of the dataset even if a local copy exists.
+        verbose : bool, str, int, or None
+            If not None, override default verbose level (see mne.verbose()).
+
+        Returns
+        -------
+        destination : str
+            Path to the folder containing the unpacked dataset files.
+        """
+        signature = "BCICompetitionIVDataset4"
+        folder_name = "BCI_Competion4_dataset4_data_fingerflexions"
+        # Check if the dataset already exists (unpacked). We have to do that manually
+        # because we are removing the .zip file from disk to save disk space.
+
+        from moabb.datasets.download import get_dataset_path  # keep soft dependency
+
+        path = get_dataset_path(signature, path)
+        key_dest = "MNE-{:s}-data".format(signature.lower())
+        # We do not use mne _url_to_local_path due to ':' in the url that causes problems on Windows
+        destination = osp.join(path, key_dest, folder_name)
+        if len(list(glob.glob(osp.join(destination, "*.mat")))) == 6:
+            return destination
+        data_path = _data_dl(
+            DATASET_URL,
+            osp.join(destination, folder_name, signature),
+            force_update=force_update,
+        )
+        unpack_archive(data_path, osp.dirname(destination))
+        # remove the .zip file that the data was unpacked from
+        remove(data_path)
+        return destination
+
+    @staticmethod
+    def _prepare_targets(upsampled_targets, targets_stride):
+        original_targets = np.full_like(upsampled_targets, np.nan)
+        original_targets[::targets_stride] = upsampled_targets[::targets_stride]
+        return original_targets
+
+    def _load_data_to_mne(self, file_path):
+        data = loadmat(file_path)
+        test_labels = loadmat(file_path.replace("comp.mat", "testlabels.mat"))
+        train_data = data["train_data"]
+        test_data = data["test_data"]
+        upsampled_train_targets = data["train_dg"]
+        upsampled_test_targets = test_labels["test_dg"]
+
+        signal_sfreq = 1000
+        original_target_sfreq = 25
+        targets_stride = int(signal_sfreq / original_target_sfreq)
+
+        original_targets = self._prepare_targets(
+            upsampled_train_targets, targets_stride
+        )
+        original_test_targets = self._prepare_targets(
+            upsampled_test_targets, targets_stride
+        )
+
+        ch_names = [f"{i}" for i in range(train_data.shape[1])]
+        ch_names += [f"target_{i}" for i in range(original_targets.shape[1])]
+        ch_types = ["ecog" for _ in range(train_data.shape[1])]
+        ch_types += ["misc" for _ in range(original_targets.shape[1])]
+
+        info = mne.create_info(sfreq=signal_sfreq, ch_names=ch_names, ch_types=ch_types)
+        info["temp"] = dict(target_sfreq=original_target_sfreq)
+        train_data = np.concatenate([train_data, original_targets], axis=1)
+        test_data = np.concatenate([test_data, original_test_targets], axis=1)
+
+        raw_train = mne.io.RawArray(train_data.T, info=info)
+        raw_test = mne.io.RawArray(test_data.T, info=info)
+        # TODO: show how to resample targets
+        return raw_train, raw_test
+
+    def _validate_subjects(self, subject_ids):
+        if isinstance(subject_ids, (list, tuple)):
+            if not all(subject in self.possible_subjects for subject in subject_ids):
+                raise ValueError(
+                    f"Wrong subject_ids parameter. Possible values: {self.possible_subjects}. "
+                    f"Provided {subject_ids}."
+                )
+        else:
+            raise ValueError(
+                "Wrong subject_ids format. Expected types: None, list, tuple, int."
+            )
+
+
+@verbose
+def _data_dl(url, destination, force_update=False, verbose=None):
+    # Code taken from moabb due to problem with ':' occurring in path
+    # On Windows ':' is forbidden in folder names
+    # moabb/datasets/download.py
+
+    from pooch import file_hash, retrieve  # keep soft dependency
+
+    if not osp.isfile(destination) or force_update:
+        if osp.isfile(destination):
+            os.remove(destination)
+        if not osp.isdir(osp.dirname(destination)):
+            os.makedirs(osp.dirname(destination))
+        known_hash = None
+    else:
+        known_hash = file_hash(destination)
+    data_path = retrieve(
+        url, known_hash, fname=osp.basename(url), path=osp.dirname(destination)
+    )
+    return data_path
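For orientation, a minimal usage sketch of the class added above. It assumes the optional moabb and pooch dependencies are installed; the first call downloads and unpacks the data under the MNE data directory:

    from braindecode.datasets.bcicomp import BCICompetitionIVDataset4

    # Load subject 1 only; train and test sessions become separate BaseDatasets.
    dataset = BCICompetitionIVDataset4(subject_ids=1)
    print(dataset.description)  # columns: subject, file_name, session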
braindecode/datasets/bids.py
@@ -0,0 +1,245 @@
+"""Dataset for loading BIDS.
+
+More information on BIDS (Brain Imaging Data Structure) can be found at https://bids.neuroimaging.io
+"""
+
+# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+import mne
+import mne_bids
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+
+from .base import BaseConcatDataset, BaseDataset, WindowsDataset
+
+
+def _description_from_bids_path(bids_path: mne_bids.BIDSPath) -> dict[str, Any]:
+    return {
+        "path": bids_path.fpath,
+        "subject": bids_path.subject,
+        "session": bids_path.session,
+        "task": bids_path.task,
+        "acquisition": bids_path.acquisition,
+        "run": bids_path.run,
+        "processing": bids_path.processing,
+        "recording": bids_path.recording,
+        "space": bids_path.space,
+        "split": bids_path.split,
+        "description": bids_path.description,
+        "suffix": bids_path.suffix,
+        "extension": bids_path.extension,
+        "datatype": bids_path.datatype,
+    }
+
+
+@dataclass
+class BIDSDataset(BaseConcatDataset):
+    """Dataset for loading BIDS.
+
+    This class has the same parameters as the :func:`mne_bids.find_matching_paths` function,
+    as that function is used to find the files to load. The default ``extensions`` parameter was changed.
+
+    More information on BIDS (Brain Imaging Data Structure)
+    can be found at https://bids.neuroimaging.io
+
+    .. Note::
+        For loading "unofficial" BIDS datasets containing epoched data,
+        you can use :class:`BIDSEpochsDataset`.
+
+    Parameters
+    ----------
+    root : pathlib.Path | str
+        The root of the BIDS path.
+    subjects : str | array-like of str | None
+        The subject ID. Corresponds to "sub".
+    sessions : str | array-like of str | None
+        The acquisition session. Corresponds to "ses".
+    tasks : str | array-like of str | None
+        The experimental task. Corresponds to "task".
+    acquisitions : str | array-like of str | None
+        The acquisition parameters. Corresponds to "acq".
+    runs : str | array-like of str | None
+        The run number. Corresponds to "run".
+    processings : str | array-like of str | None
+        The processing label. Corresponds to "proc".
+    recordings : str | array-like of str | None
+        The recording name. Corresponds to "rec".
+    spaces : str | array-like of str | None
+        The coordinate space for anatomical and sensor location
+        files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
+        Corresponds to "space".
+        Note that valid values for ``space`` must come from a list
+        of BIDS keywords as described in the BIDS specification.
+    splits : str | array-like of str | None
+        The split of the continuous recording file for ``.fif`` data.
+        Corresponds to "split".
+    descriptions : str | array-like of str | None
+        This corresponds to the BIDS entity ``desc``. It is used to provide
+        additional information for derivative data, e.g., preprocessed data
+        may be assigned ``description='cleaned'``.
+    suffixes : str | array-like of str | None
+        The filename suffix. This is the entity after the
+        last ``_`` before the extension. E.g., ``'channels'``.
+        The following filename suffixes are accepted:
+        'meg', 'markers', 'eeg', 'ieeg', 'T1w',
+        'participants', 'scans', 'electrodes', 'coordsystem',
+        'channels', 'events', 'headshape', 'digitizer',
+        'beh', 'physio', 'stim'
+    extensions : str | array-like of str | None
+        The extension of the filename. E.g., ``'.json'``.
+        By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
+    datatypes : str | array-like of str | None
+        The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
+        ``'ieeg'``.
+    check : bool
+        If ``True``, only returns paths that conform to BIDS. If ``False``
+        (default), the ``.check`` attribute of the returned
+        :class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
+        do conform to BIDS, and to ``False`` for those that don't.
+    preload : bool
+        If True, preload the data. Defaults to False.
+    n_jobs : int
+        Number of jobs to run in parallel. Defaults to 1.
+    """
+
+    root: Path | str
+    subjects: str | list[str] | None = None
+    sessions: str | list[str] | None = None
+    tasks: str | list[str] | None = None
+    acquisitions: str | list[str] | None = None
+    runs: str | list[str] | None = None
+    processings: str | list[str] | None = None
+    recordings: str | list[str] | None = None
+    spaces: str | list[str] | None = None
+    splits: str | list[str] | None = None
+    descriptions: str | list[str] | None = None
+    suffixes: str | list[str] | None = None
+    extensions: str | list[str] | None = field(
+        default_factory=lambda: [
+            ".con",
+            ".sqd",
+            ".pdf",
+            ".fif",
+            ".ds",
+            ".vhdr",
+            ".set",
+            ".edf",
+            ".bdf",
+            ".EDF",
+            ".snirf",
+            ".cdt",
+            ".mef",
+            ".nwb",
+        ]
+    )
+    datatypes: str | list[str] | None = None
+    check: bool = False
+    preload: bool = False
+    n_jobs: int = 1
+
+    @property
+    def _filter_out_epochs(self):
+        return True
+
+    def __post_init__(self):
+        bids_paths = mne_bids.find_matching_paths(
+            root=self.root,
+            subjects=self.subjects,
+            sessions=self.sessions,
+            tasks=self.tasks,
+            acquisitions=self.acquisitions,
+            runs=self.runs,
+            processings=self.processings,
+            recordings=self.recordings,
+            spaces=self.spaces,
+            splits=self.splits,
+            descriptions=self.descriptions,
+            suffixes=self.suffixes,
+            extensions=self.extensions,
+            datatypes=self.datatypes,
+            check=self.check,
+        )
+        # Filter out .json files:
+        # (argument ignore_json only available in mne-bids>=0.16)
+        bids_paths = [
+            bids_path for bids_path in bids_paths if bids_path.extension != ".json"
+        ]
+        # Filter out _epo.fif files:
+        if self._filter_out_epochs:
+            bids_paths = [
+                bids_path
+                for bids_path in bids_paths
+                if not (bids_path.suffix == "epo" and bids_path.extension == ".fif")
+            ]
+
+        all_base_ds = Parallel(n_jobs=self.n_jobs)(
+            delayed(self._get_dataset)(bids_path) for bids_path in bids_paths
+        )
+        super().__init__(all_base_ds)
+
+    def _get_dataset(self, bids_path: mne_bids.BIDSPath) -> BaseDataset:
+        description = _description_from_bids_path(bids_path)
+        raw = mne_bids.read_raw_bids(bids_path, verbose=False)
+        if self.preload:
+            raw.load_data()
+        return BaseDataset(raw, description)
+
+
+class BIDSEpochsDataset(BIDSDataset):
+    """**Experimental** dataset for loading :class:`mne.Epochs` organised in BIDS.
+
+    The files must end with ``_epo.fif``.
+
+    .. Warning::
+        Epoched data is not officially supported in BIDS.
+
+    .. Note::
+        **Parameters:** This class has the same parameters as :class:`BIDSDataset` except
+        for the arguments ``datatypes``, ``extensions`` and ``check``, which are fixed.
+    """
+
+    @property
+    def _filter_out_epochs(self):
+        return False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(
+            *args,
+            extensions=".fif",
+            suffixes="epo",
+            check=False,
+            **kwargs,
+        )
+
+    def _set_metadata(self, epochs: mne.BaseEpochs) -> None:
+        n_times = epochs.times.shape[0]
+        annotations = epochs.annotations
+        if annotations is not None:
+            target = annotations.description
+        else:
+            id_events = {v: k for k, v in epochs.event_id.items()}
+            target = [id_events[event_id] for event_id in epochs.events[:, -1]]
+        metadata_dict = {
+            "i_window_in_trial": np.zeros(len(epochs)),
+            "i_start_in_trial": np.zeros(len(epochs)),
+            "i_stop_in_trial": np.zeros(len(epochs)) + n_times,
+            "target": target,
+        }
+        epochs.metadata = pd.DataFrame(metadata_dict)
+
+    def _get_dataset(self, bids_path):
+        description = _description_from_bids_path(bids_path)
+        epochs = mne.read_epochs(bids_path.fpath)
+        self._set_metadata(epochs)
+        return WindowsDataset(epochs, description=description, targets_from="metadata")
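A hedged sketch of the new BIDS loader in use; the root path and entity values below are placeholders for an actual BIDS dataset on disk, and mne-bids must be installed:

    from braindecode.datasets.bids import BIDSDataset

    dataset = BIDSDataset(
        root="path/to/bids_root",  # placeholder: point at a real BIDS root
        subjects="01",
        datatypes="eeg",
        n_jobs=1,
    )
    print(len(dataset.datasets))  # one BaseDataset per matching recording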
braindecode/datasets/mne.py
@@ -0,0 +1,172 @@
+# Authors: Lukas Gemein <l.gemein@gmail.com>
+#          Robin Schirrmeister <robintibor@gmail.com>
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+
+import mne
+import numpy as np
+import pandas as pd
+
+from .base import BaseConcatDataset, BaseDataset, WindowsDataset
+
+
+def create_from_mne_raw(
+    raws: list[mne.io.BaseRaw],
+    trial_start_offset_samples: int,
+    trial_stop_offset_samples: int,
+    window_size_samples: int,
+    window_stride_samples: int,
+    drop_last_window: bool,
+    descriptions: list[dict | pd.Series] | None = None,
+    mapping: dict[str, int] | None = None,
+    preload: bool = False,
+    drop_bad_windows: bool = True,
+    accepted_bads_ratio: float = 0.0,
+) -> BaseConcatDataset:
+    """Create WindowsDatasets from mne.RawArrays.
+
+    Parameters
+    ----------
+    raws : array-like
+        List of mne.RawArrays.
+    trial_start_offset_samples : int
+        Start offset from original trial onsets, in samples.
+    trial_stop_offset_samples : int
+        Stop offset from original trial stop, in samples.
+    window_size_samples : int
+        Window size.
+    window_stride_samples : int
+        Stride between windows.
+    drop_last_window : bool
+        Whether or not to have a last overlapping window, when
+        windows do not equally divide the continuous signal.
+    descriptions : array-like
+        List of dicts or pandas.Series with additional information about the raws.
+    mapping : dict(str: int)
+        Mapping from event description to target value.
+    preload : bool
+        If True, preload the data of the Epochs objects.
+    drop_bad_windows : bool
+        If True, call `.drop_bad()` on the resulting mne.Epochs object. This
+        step allows identifying e.g., windows that fall outside of the
+        continuous recording. It is suggested to run this step here as otherwise
+        the BaseConcatDataset has to be updated as well.
+    accepted_bads_ratio : float, optional
+        Acceptable proportion of trials with inconsistent length in a raw. If
+        the number of trials whose length is exceeded by the window size is
+        smaller than this, then only the corresponding trials are dropped, but
+        the computation continues. Otherwise, an error is raised. Defaults to
+        0.0 (raise an error).
+
+    Returns
+    -------
+    windows_datasets : BaseConcatDataset
+        X and y transformed to a dataset format that is compatible with skorch
+        and braindecode.
+    """
+    # Prevent circular import
+    from ..preprocessing.windowers import create_windows_from_events
+
+    if descriptions is not None:
+        if len(descriptions) != len(raws):
+            raise ValueError(
+                f"length of 'raws' ({len(raws)}) and 'descriptions' "
+                f"({len(descriptions)}) has to match"
+            )
+        base_datasets = [
+            BaseDataset(raw, desc) for raw, desc in zip(raws, descriptions)
+        ]
+    else:
+        base_datasets = [BaseDataset(raw) for raw in raws]
+
+    base_datasets = BaseConcatDataset(base_datasets)
+    windows_datasets = create_windows_from_events(
+        base_datasets,
+        trial_start_offset_samples=trial_start_offset_samples,
+        trial_stop_offset_samples=trial_stop_offset_samples,
+        window_size_samples=window_size_samples,
+        window_stride_samples=window_stride_samples,
+        drop_last_window=drop_last_window,
+        mapping=mapping,
+        drop_bad_windows=drop_bad_windows,
+        preload=preload,
+        accepted_bads_ratio=accepted_bads_ratio,
+    )
+    return windows_datasets
+
+
+def create_from_mne_epochs(
+    list_of_epochs: list[mne.BaseEpochs],
+    window_size_samples: int,
+    window_stride_samples: int,
+    drop_last_window: bool,
+) -> BaseConcatDataset:
+    """Create WindowsDatasets from mne.Epochs.
+
+    Parameters
+    ----------
+    list_of_epochs : array-like
+        List of mne.Epochs.
+    window_size_samples : int
+        Window size.
+    window_stride_samples : int
+        Stride between windows.
+    drop_last_window : bool
+        Whether or not to have a last overlapping window, when
+        windows do not equally divide the continuous signal.
+
+    Returns
+    -------
+    windows_datasets : BaseConcatDataset
+        X and y transformed to a dataset format that is compatible with skorch
+        and braindecode.
+    """
+    # Prevent circular import
+    from ..preprocessing.windowers import _check_windowing_arguments
+
+    _check_windowing_arguments(0, 0, window_size_samples, window_stride_samples)
+
+    list_of_windows_ds = []
+    for epochs in list_of_epochs:
+        event_descriptions = epochs.events[:, 2]
+        original_trial_starts = epochs.events[:, 0]
+        stop = len(epochs.times) - window_size_samples
+
+        # already includes last incomplete window start
+        starts = np.arange(0, stop + 1, window_stride_samples)
+
+        if not drop_last_window and starts[-1] < stop:
+            # if the last window does not end at the trial stop, make it stop there
+            starts = np.append(starts, stop)
+
+        fake_events = [[start, window_size_samples, -1] for start in starts]
+
+        for trial_i, trial in enumerate(epochs):
+            metadata = pd.DataFrame(
+                {
+                    "i_window_in_trial": np.arange(len(fake_events)),
+                    "i_start_in_trial": starts + original_trial_starts[trial_i],
+                    "i_stop_in_trial": starts
+                    + original_trial_starts[trial_i]
+                    + window_size_samples,
+                    "target": len(fake_events) * [event_descriptions[trial_i]],
+                }
+            )
+            # window size - 1, since tmax is inclusive
+            mne_epochs = mne.Epochs(
+                mne.io.RawArray(trial, epochs.info),
+                fake_events,
+                baseline=None,
+                tmin=0,
+                tmax=(window_size_samples - 1) / epochs.info["sfreq"],
+                metadata=metadata,
+            )
+
+            mne_epochs.drop_bad(reject=None, flat=None)
+
+            windows_ds = WindowsDataset(mne_epochs)
+            list_of_windows_ds.append(windows_ds)
+
+    return BaseConcatDataset(list_of_windows_ds)
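A minimal sketch of create_from_mne_raw on synthetic data, with values chosen only for illustration (two channels, 100 Hz, two annotated 2 s trials cut into non-overlapping 1 s windows):

    import numpy as np
    import mne
    from braindecode.datasets.mne import create_from_mne_raw

    info = mne.create_info(ch_names=["C3", "C4"], sfreq=100.0, ch_types="eeg")
    raw = mne.io.RawArray(np.random.randn(2, 1000), info)
    # Two fake events so the windower has trials to cut.
    raw.set_annotations(
        mne.Annotations(onset=[1.0, 5.0], duration=[2.0, 2.0],
                        description=["left", "right"])
    )

    windows = create_from_mne_raw(
        [raw],
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        window_size_samples=100,
        window_stride_samples=100,
        drop_last_window=False,
        mapping={"left": 0, "right": 1},
    )
    print(len(windows))  # 4 windows: 2 per annotated trial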