braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/datasets/mne.py
CHANGED

@@ -3,18 +3,28 @@
 #
 # License: BSD (3-clause)
 
+from __future__ import annotations
+
+import mne
 import numpy as np
 import pandas as pd
-import mne
 
-from .base import
+from .base import BaseConcatDataset, BaseDataset, WindowsDataset
 
 
 def create_from_mne_raw(
-        raws, trial_start_offset_samples, trial_stop_offset_samples,
-        window_size_samples, window_stride_samples, drop_last_window,
-        descriptions=None, mapping=None, preload=False,
-        drop_bad_windows=True, accepted_bads_ratio=0.0):
+    raws: list[mne.io.BaseRaw],
+    trial_start_offset_samples: int,
+    trial_stop_offset_samples: int,
+    window_size_samples: int,
+    window_stride_samples: int,
+    drop_last_window: bool,
+    descriptions: list[dict | pd.Series] | None = None,
+    mapping: dict[str, int] | None = None,
+    preload: bool = False,
+    drop_bad_windows: bool = True,
+    accepted_bads_ratio: float = 0.0,
+) -> BaseConcatDataset:
     """Create WindowsDatasets from mne.RawArrays
 
     Parameters
@@ -58,13 +68,16 @@ def create_from_mne_raw(
     """
     # Prevent circular import
    from ..preprocessing.windowers import create_windows_from_events
+
    if descriptions is not None:
        if len(descriptions) != len(raws):
            raise ValueError(
                f"length of 'raws' ({len(raws)}) and 'description' "
-                f"({len(descriptions)}) has to match"
-
-
+                f"({len(descriptions)}) has to match"
+            )
+        base_datasets = [
+            BaseDataset(raw, desc) for raw, desc in zip(raws, descriptions)
+        ]
    else:
        base_datasets = [BaseDataset(raw) for raw in raws]
 
@@ -84,8 +97,12 @@ def create_from_mne_raw(
     return windows_datasets
 
 
-def create_from_mne_epochs(list_of_epochs, window_size_samples,
-                           window_stride_samples, drop_last_window):
+def create_from_mne_epochs(
+    list_of_epochs: list[mne.BaseEpochs],
+    window_size_samples: int,
+    window_stride_samples: int,
+    drop_last_window: bool,
+) -> BaseConcatDataset:
     """Create WindowsDatasets from mne.Epochs
 
     Parameters
@@ -108,8 +125,8 @@ def create_from_mne_epochs(list_of_epochs, window_size_samples,
     """
     # Prevent circular import
    from ..preprocessing.windowers import _check_windowing_arguments
-    _check_windowing_arguments(0, 0, window_size_samples,
-                               window_stride_samples)
+
+    _check_windowing_arguments(0, 0, window_size_samples, window_stride_samples)
 
    list_of_windows_ds = []
    for epochs in list_of_epochs:
@@ -124,24 +141,28 @@ def create_from_mne_epochs(list_of_epochs, window_size_samples,
         # if last window does not end at trial stop, make it stop there
        starts = np.append(starts, stop)
 
-        fake_events = [[start, window_size_samples, -1] for start in
-                       starts]
+        fake_events = [[start, window_size_samples, -1] for start in starts]
 
        for trial_i, trial in enumerate(epochs):
-            metadata = pd.DataFrame(
-
-
-
-
-
+            metadata = pd.DataFrame(
+                {
+                    "i_window_in_trial": np.arange(len(fake_events)),
+                    "i_start_in_trial": starts + original_trial_starts[trial_i],
+                    "i_stop_in_trial": starts
+                    + original_trial_starts[trial_i]
+                    + window_size_samples,
+                    "target": len(fake_events) * [event_descriptions[trial_i]],
+                }
+            )
            # window size - 1, since tmax is inclusive
            mne_epochs = mne.Epochs(
-                mne.io.RawArray(trial, epochs.info),
+                mne.io.RawArray(trial, epochs.info),
+                fake_events,
                baseline=None,
                tmin=0,
                tmax=(window_size_samples - 1) / epochs.info["sfreq"],
-                metadata=metadata
+                metadata=metadata,
+            )
 
            mne_epochs.drop_bad(reject=None, flat=None)
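For context, the rewritten create_from_mne_raw above is now called with explicit keyword arguments. The following is a minimal usage sketch, assuming a synthetic annotated mne.io.RawArray; the channel names, event labels, and mapping are illustrative assumptions, not taken from this diff.

import mne
import numpy as np

from braindecode.datasets.mne import create_from_mne_raw

# Synthetic two-channel raw with two annotated trials (illustrative values).
sfreq = 100.0
info = mne.create_info(ch_names=["C3", "C4"], sfreq=sfreq, ch_types="eeg")
raw = mne.io.RawArray(np.random.randn(2, int(20 * sfreq)), info)
raw.set_annotations(
    mne.Annotations(onset=[2.0, 10.0], duration=[4.0, 4.0],
                    description=["left", "right"])
)

# Cut 2-second windows (200 samples at 100 Hz) out of each annotated trial.
windows = create_from_mne_raw(
    raws=[raw],
    trial_start_offset_samples=0,
    trial_stop_offset_samples=0,
    window_size_samples=200,
    window_stride_samples=200,
    drop_last_window=False,
    mapping={"left": 0, "right": 1},
)
print(len(windows))  # total number of windows across all raws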
braindecode/datasets/moabb.py
CHANGED

@@ -1,5 +1,4 @@
-"""Dataset objects for some public datasets.
-"""
+"""Dataset objects for some public datasets."""
 
 # Authors: Hubert Banville <hubert.jbanville@gmail.com>
 #          Lukas Gemein <l.gemein@gmail.com>
@@ -9,16 +8,23 @@
 #
 # License: BSD (3-clause)
 
-
+from __future__ import annotations
+
+import warnings
+from typing import Any
+
 import mne
+import pandas as pd
 
-from .base import BaseDataset, BaseConcatDataset
 from braindecode.util import _update_moabb_docstring
 
+from .base import BaseConcatDataset, BaseDataset
+
 
 def _find_dataset_in_moabb(dataset_name, dataset_kwargs=None):
     # soft dependency on moabb
     from moabb.datasets.utils import dataset_list
+
     for dataset in dataset_list:
         if dataset_name == dataset.__name__:
             # return an instance of the found dataset class
@@ -29,35 +35,41 @@ def _find_dataset_in_moabb(dataset_name, dataset_kwargs=None):
     raise ValueError(f"{dataset_name} not found in moabb datasets")
 
 
-def _fetch_and_unpack_moabb_data(dataset, subject_ids):
-
+def _fetch_and_unpack_moabb_data(dataset, subject_ids=None, dataset_load_kwargs=None):
+    if dataset_load_kwargs is None:
+        data = dataset.get_data(subject_ids)
+    else:
+        data = dataset.get_data(subjects=subject_ids, **dataset_load_kwargs)
+
     raws, subject_ids, session_ids, run_ids = [], [], [], []
     for subj_id, subj_data in data.items():
         for sess_id, sess_data in subj_data.items():
             for run_id, raw in sess_data.items():
-
-
-                annots = _annotations_from_moabb_stim_channel(raw, dataset)
-                raw.set_annotations(annots)
+                annots = _annotations_from_moabb_stim_channel(raw, dataset)
+                raw.set_annotations(annots)
                 raws.append(raw)
                 subject_ids.append(subj_id)
                 session_ids.append(sess_id)
                 run_ids.append(run_id)
-    description = pd.DataFrame(
-
-
-        'run': run_ids
-    })
+    description = pd.DataFrame(
+        {"subject": subject_ids, "session": session_ids, "run": run_ids}
+    )
     return raws, description
 
 
 def _annotations_from_moabb_stim_channel(raw, dataset):
-    # find events from stim channel
-
+    # find events from the stim channel
+    stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
+    if len(stim_channels) > 0:
+        # returns an empty array if none found
+        events = mne.find_events(raw, shortest_event=0, verbose=False)
+        event_id = dataset.event_id
+    else:
+        events, event_id = mne.events_from_annotations(raw, verbose=False)
 
     # get annotations from events
-    event_desc = {k: v for v, k in
-    annots = mne.annotations_from_events(events, raw.info[
+    event_desc = {k: v for v, k in event_id.items()}
+    annots = mne.annotations_from_events(events, raw.info["sfreq"], event_desc)
 
     # set trial on and offset given by moabb
     onset, offset = dataset.interval
@@ -66,28 +78,47 @@ def _annotations_from_moabb_stim_channel(raw, dataset):
     return annots
 
 
-def fetch_data_with_moabb(
+def fetch_data_with_moabb(
+    dataset_name: str,
+    subject_ids: list[int] | int | None = None,
+    dataset_kwargs: dict[str, Any] | None = None,
+    dataset_load_kwargs: dict[str, Any] | None = None,
+) -> tuple[list[mne.io.Raw], pd.DataFrame]:
     # ToDo: update path to where moabb downloads / looks for the data
     """Fetch data using moabb.
 
     Parameters
     ----------
-    dataset_name: str
+    dataset_name: str | moabb.datasets.base.BaseDataset
         the name of a dataset included in moabb
     subject_ids: list(int) | int
         (list of) int of subject(s) to be fetched
     dataset_kwargs: dict, optional
         optional dictionary containing keyword arguments
         to pass to the moabb dataset when instantiating it.
+    data_load_kwargs: dict, optional
+        optional dictionary containing keyword arguments
+        to pass to the moabb dataset's load_data method.
+        Allows using the moabb cache_config=None and
+        process_pipeline=None.
 
     Returns
     -------
     raws: mne.Raw
     info: pandas.DataFrame
     """
-
+    if isinstance(dataset_name, str):
+        dataset = _find_dataset_in_moabb(dataset_name, dataset_kwargs)
+    else:
+        from moabb.datasets.base import BaseDataset
+
+        if isinstance(dataset_name, BaseDataset):
+            dataset = dataset_name
+
     subject_id = [subject_ids] if isinstance(subject_ids, int) else subject_ids
-    return _fetch_and_unpack_moabb_data(
+    return _fetch_and_unpack_moabb_data(
+        dataset, subject_id, dataset_load_kwargs=dataset_load_kwargs
+    )
 
 
 class MOABBDataset(BaseConcatDataset):
@@ -103,11 +134,38 @@ class MOABBDataset(BaseConcatDataset):
     dataset_kwargs: dict, optional
         optional dictionary containing keyword arguments
         to pass to the moabb dataset when instantiating it.
+    dataset_load_kwargs: dict, optional
+        optional dictionary containing keyword arguments
+        to pass to the moabb dataset's load_data method.
+        Allows using the moabb cache_config=None and
+        process_pipeline=None.
     """
-
-
-
-
+
+    def __init__(
+        self,
+        dataset_name: str,
+        subject_ids: list[int] | int | None = None,
+        dataset_kwargs: dict[str, Any] | None = None,
+        dataset_load_kwargs: dict[str, Any] | None = None,
+    ):
+        # soft dependency on moabb
+        from moabb import __version__ as moabb_version  # type: ignore
+
+        if moabb_version == "1.0.0":
+            warnings.warn(
+                "moabb version 1.0.0 generates incorrect annotations. "
+                "Please update to another version, version 0.5 or 1.1.0 "
+            )
+
+        raws, description = fetch_data_with_moabb(
+            dataset_name,
+            subject_ids,
+            dataset_kwargs,
+            dataset_load_kwargs=dataset_load_kwargs,
+        )
+        all_base_ds = [
+            BaseDataset(raw, row) for raw, (_, row) in zip(raws, description.iterrows())
+        ]
         super().__init__(all_base_ds)
 
 
@@ -122,6 +180,7 @@ class BNCI2014001(MOABBDataset):
     """
     try:
         from moabb.datasets import BNCI2014001
+
         __doc__ = _update_moabb_docstring(BNCI2014001, doc)
     except ModuleNotFoundError:
         pass  # keep moabb soft dependency, otherwise crash on loading of datasets.__init__.py
@@ -141,6 +200,7 @@ class HGD(MOABBDataset):
     """
     try:
         from moabb.datasets import Schirrmeister2017
+
         __doc__ = _update_moabb_docstring(Schirrmeister2017, doc)
     except ModuleNotFoundError:
         pass  # keep moabb soft dependency, otherwise crash on loading of datasets.__init__.py
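The new dataset_load_kwargs pass-through above can be exercised with a short sketch like the following, assuming moabb is installed and still exposes the BNCI2014001 dataset imported earlier in this file; cache_config and process_pipeline are the load_data keywords the docstring names.

from braindecode.datasets import MOABBDataset

# dataset_kwargs goes to the moabb dataset constructor, while
# dataset_load_kwargs is forwarded to its get_data/load_data call.
ds = MOABBDataset(
    dataset_name="BNCI2014001",
    subject_ids=[1],
    dataset_load_kwargs=dict(cache_config=None, process_pipeline=None),
)
print(ds.description.head())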
braindecode/datasets/nmt.py
ADDED

@@ -0,0 +1,311 @@
+"""
+Dataset classes for the NMT EEG Corpus dataset.
+
+The NMT Scalp EEG Dataset is an open-source annotated dataset of healthy and
+pathological EEG recordings for predictive modeling. This dataset contains
+2,417 recordings from unique participants spanning almost 625 h.
+
+Note:
+ - The signal unit may not be uV and further examination is required.
+ - The spectrum shows that the signal may have been band-pass filtered from
+   about 2-33 Hz, which needs to be further determined.
+
+"""
+
+# Authors: Mohammad Bayazi <mj.darvishi92@gmail.com>
+#          Bruno Aristimunha <b.aristimunha@gmail.com>
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+
+import glob
+import os
+import warnings
+from pathlib import Path
+from unittest import mock
+
+import mne
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from mne.datasets import fetch_dataset
+
+from braindecode.datasets.base import BaseConcatDataset, BaseDataset
+
+NMT_URL = "https://zenodo.org/record/10909103/files/NMT.zip"
+NMT_archive_name = "NMT.zip"
+NMT_folder_name = "MNE-NMT-eeg-dataset"
+NMT_dataset_name = "NMT-EEG-Corpus"
+
+NMT_dataset_params = {
+    "dataset_name": NMT_dataset_name,
+    "url": NMT_URL,
+    "archive_name": NMT_archive_name,
+    "folder_name": NMT_folder_name,
+    "hash": "77b3ce12bcaf6c6cce4e6690ea89cb22bed55af10c525077b430f6e1d2e3c6bf",
+    "config_key": NMT_dataset_name,
+}
+
+
+class NMT(BaseConcatDataset):
+    """The NMT Scalp EEG Dataset.
+
+    An Open-Source Annotated Dataset of Healthy and Pathological EEG
+    Recordings for Predictive Modeling.
+
+    This dataset contains 2,417 recordings from unique participants spanning
+    almost 625 h.
+
+    Here, the dataset can be used for three tasks: brain-age, gender
+    prediction, and abnormality detection.
+
+    The dataset is described in [Khan2022]_.
+
+    .. versionadded:: 0.9
+
+    Parameters
+    ----------
+    path: str
+        Parent directory of the dataset.
+    recording_ids: list(int) | int
+        A (list of) int of recording id(s) to be read (order matters and will
+        overwrite default chronological order, e.g. if recording_ids=[1,0],
+        then the first recording returned by this class will be chronologically
+        later than the second recording. Provide recording_ids in ascending
+        order to preserve chronological order.).
+    target_name: str
+        Can be "pathological", "gender", or "age".
+    preload: bool
+        If True, preload the data of the Raw objects.
+
+    References
+    ----------
+    .. [Khan2022] Khan, H.A., Ul Ain, R., Kamboh, A.M., Butt, H.T., Shafait, S.,
+        Alamgir, W., Stricker, D. and Shafait, F., 2022. The NMT scalp EEG
+        dataset: an open-source annotated dataset of healthy and pathological
+        EEG recordings for predictive modeling. Frontiers in neuroscience,
+        15, p.755817.
+    """
+
+    def __init__(
+        self,
+        path=None,
+        target_name="pathological",
+        recording_ids=None,
+        preload=False,
+        n_jobs=1,
+    ):
+        # correct the path if needed
+        if path is not None:
+            list_csv = glob.glob(f"{path}/**/Labels.csv", recursive=True)
+            if isinstance(list_csv, list) and len(list_csv) > 0:
+                path = Path(list_csv[0]).parent
+
+        if path is None or len(list_csv) == 0:
+            path = fetch_dataset(
+                dataset_params=NMT_dataset_params,
+                path=Path(path) if path is not None else None,
+                processor="unzip",
+                force_update=False,
+            )
+            # First time we fetch the dataset, we need to move the files to the
+            # correct directory.
+            path = _correct_path(path)
+
+        # Get all file paths
+        file_paths = glob.glob(
+            os.path.join(path, "**" + os.sep + "*.edf"), recursive=True
+        )
+
+        # keep only .edf files
+        file_paths = [
+            file_path
+            for file_path in file_paths
+            if os.path.splitext(file_path)[1] == ".edf"
+        ]
+
+        # sort by subject id
+        file_paths = sorted(
+            file_paths, key=lambda p: int(os.path.splitext(p)[0].split(os.sep)[-1])
+        )
+        if recording_ids is not None:
+            file_paths = [file_paths[rec_id] for rec_id in recording_ids]
+
+        # read labels and rearrange them to match TUH Abnormal EEG Corpus
+        description = pd.read_csv(
+            os.path.join(path, "Labels.csv"), index_col="recordname"
+        )
+        if recording_ids is not None:
+            description = description.iloc[recording_ids]
+        description.replace(
+            {
+                "not specified": "X",
+                "female": "F",
+                "male": "M",
+                "abnormal": True,
+                "normal": False,
+            },
+            inplace=True,
+        )
+        description.rename(columns={"label": "pathological"}, inplace=True)
+        description.reset_index(drop=True, inplace=True)
+        description["path"] = file_paths
+        description = description[["path", "pathological", "age", "gender"]]
+
+        if n_jobs == 1:
+            base_datasets = [
+                self._create_dataset(d, target_name, preload)
+                for recording_id, d in description.iterrows()
+            ]
+        else:
+            base_datasets = Parallel(n_jobs)(
+                delayed(self._create_dataset)(d, target_name, preload)
+                for recording_id, d in description.iterrows()
+            )
+
+        super().__init__(base_datasets)
+
+    @staticmethod
+    def _create_dataset(d, target_name, preload):
+        raw = mne.io.read_raw_edf(d.path, preload=preload)
+        d["n_samples"] = raw.n_times
+        d["sfreq"] = raw.info["sfreq"]
+        d["train"] = "train" in d.path.split(os.sep)
+        base_dataset = BaseDataset(raw, d, target_name)
+        return base_dataset
+
+
+def _correct_path(path: str):
+    """
+    Check if the path is correct and rename the file if needed.
+
+    Parameters
+    ----------
+    path: basestring
+        Path to the file.
+
+    Returns
+    -------
+    path: basestring
+        Corrected path.
+    """
+    if not Path(path).exists():
+        unzip_file_name = f"{NMT_archive_name}.unzip"
+        if (Path(path).parent / unzip_file_name).exists():
+            try:
+                os.rename(
+                    src=Path(path).parent / unzip_file_name,
+                    dst=Path(path),
+                )
+
+            except PermissionError:
+                raise PermissionError(
+                    f"Please rename {Path(path).parent / unzip_file_name}"
+                    + f"manually to {path} and try again."
+                )
+    path = os.path.join(path, "nmt_scalp_eeg_dataset")
+
+    return path
+
+
+def _get_header(*args):
+    all_paths = {**_NMT_PATHS}
+    return all_paths[args[0]]
+
+
+def _fake_pd_read_csv(*args, **kwargs):
+    # Create a list of lists to hold the data
+    data = [
+        ["0000001.edf", "normal", 35, "male", "train"],
+        ["0000002.edf", "abnormal", 28, "female", "test"],
+        ["0000003.edf", "normal", 62, "male", "train"],
+        ["0000004.edf", "abnormal", 41, "female", "test"],
+        ["0000005.edf", "normal", 19, "male", "train"],
+        ["0000006.edf", "abnormal", 55, "female", "test"],
+        ["0000007.edf", "normal", 71, "male", "train"],
+    ]
+
+    # Create the DataFrame, specifying column names
+    df = pd.DataFrame(data, columns=["recordname", "label", "age", "gender", "loc"])
+
+    return df
+
+
+def _fake_raw(*args, **kwargs):
+    sfreq = 10
+    ch_names = [
+        "EEG A1-REF",
+        "EEG A2-REF",
+        "EEG FP1-REF",
+        "EEG FP2-REF",
+        "EEG F3-REF",
+        "EEG F4-REF",
+        "EEG C3-REF",
+        "EEG C4-REF",
+        "EEG P3-REF",
+        "EEG P4-REF",
+        "EEG O1-REF",
+        "EEG O2-REF",
+        "EEG F7-REF",
+        "EEG F8-REF",
+        "EEG T3-REF",
+        "EEG T4-REF",
+        "EEG T5-REF",
+        "EEG T6-REF",
+        "EEG FZ-REF",
+        "EEG CZ-REF",
+        "EEG PZ-REF",
+    ]
+    duration_min = 6
+    data = np.random.randn(len(ch_names), duration_min * sfreq * 60)
+    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
+    raw = mne.io.RawArray(data=data, info=info)
+    return raw
+
+
+_NMT_PATHS = {
+    # these are actual file paths and edf headers from NMT EEG Corpus
+    "nmt_scalp_eeg_dataset/abnormal/train/0000036.edf": b"0       0000036 M 13-May-1951 0000036 Age:32",  # noqa E501
+    "nmt_scalp_eeg_dataset/abnormal/eval/0000037.edf": b"0       0000037 M 13-May-1951 0000037 Age:32",  # noqa E501
+    "nmt_scalp_eeg_dataset/abnormal/eval/0000038.edf": b"0       0000038 M 13-May-1951 0000038 Age:32",  # noqa E501
+    "nmt_scalp_eeg_dataset/normal/train/0000039.edf": b"0       0000039 M 13-May-1951 0000039 Age:32",  # noqa E501
+    "nmt_scalp_eeg_dataset/normal/eval/0000040.edf": b"0       0000040 M 13-May-1951 0000040 Age:32",  # noqa E501
+    "nmt_scalp_eeg_dataset/normal/eval/0000041.edf": b"0       0000041 M 13-May-1951 0000041 Age:32",  # noqa E501
+    "nmt_scalp_eeg_dataset/abnormal/train/0000042.edf": b"0       0000042 M 13-May-1951 0000042 Age:32",  # noqa E501
+    "Labels.csv": b"0 recordname,label,age,gender,loc 1 0000001.edf,normal,22,not specified,train ",  # noqa E501
+}
+
+
+class _NMTMock(NMT):
+    """Mocked class for testing and examples."""
+
+    @mock.patch("glob.glob", return_value=_NMT_PATHS.keys())
+    @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
+    @mock.patch("pandas.read_csv", new=_fake_pd_read_csv)
+    def __init__(
+        self,
+        mock_glob,
+        path,
+        recording_ids=None,
+        target_name="pathological",
+        preload=False,
+        n_jobs=1,
+    ):
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="Cannot save date file")
+            super().__init__(
+                path=path,
+                recording_ids=recording_ids,
+                target_name=target_name,
+                preload=preload,
+                n_jobs=n_jobs,
+            )