braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/datasets/mne.py

@@ -0,0 +1,170 @@
# Authors: Lukas Gemein <l.gemein@gmail.com>
#          Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)

from __future__ import annotations

import mne
import numpy as np
import pandas as pd

from .base import BaseConcatDataset, RawDataset, WindowsDataset


def create_from_mne_raw(
    raws: list[mne.io.BaseRaw],
    trial_start_offset_samples: int,
    trial_stop_offset_samples: int,
    window_size_samples: int,
    window_stride_samples: int,
    drop_last_window: bool,
    descriptions: list[dict | pd.Series] | None = None,
    mapping: dict[str, int] | None = None,
    preload: bool = False,
    drop_bad_windows: bool = True,
    accepted_bads_ratio: float = 0.0,
) -> BaseConcatDataset:
    """Create WindowsDatasets from mne.RawArrays.

    Parameters
    ----------
    raws : array-like
        list of mne.RawArrays
    trial_start_offset_samples : int
        start offset from original trial onsets in samples
    trial_stop_offset_samples : int
        stop offset from original trial stop in samples
    window_size_samples : int
        window size
    window_stride_samples : int
        stride between windows
    drop_last_window : bool
        whether or not to have a last overlapping window, when
        windows do not equally divide the continuous signal
    descriptions : array-like
        list of dicts or pandas.Series with additional information about the raws
    mapping : dict(str: int)
        mapping from event description to target value
    preload : bool
        if True, preload the data of the Epochs objects.
    drop_bad_windows : bool
        If True, call `.drop_bad()` on the resulting mne.Epochs object. This
        step allows identifying e.g., windows that fall outside of the
        continuous recording. It is suggested to run this step here as otherwise
        the BaseConcatDataset has to be updated as well.
    accepted_bads_ratio : float, optional
        Acceptable proportion of trials with inconsistent length in a raw. If
        the number of trials whose length is exceeded by the window size is
        smaller than this, then only the corresponding trials are dropped, but
        the computation continues. Otherwise, an error is raised. Defaults to
        0.0 (raise an error).

    Returns
    -------
    windows_datasets : BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode
    """
    # Prevent circular import
    from ..preprocessing.windowers import create_windows_from_events

    if descriptions is not None:
        if len(descriptions) != len(raws):
            raise ValueError(
                f"length of 'raws' ({len(raws)}) and 'descriptions' "
                f"({len(descriptions)}) has to match"
            )
        base_datasets = [RawDataset(raw, desc) for raw, desc in zip(raws, descriptions)]
    else:
        base_datasets = [RawDataset(raw) for raw in raws]

    base_datasets = BaseConcatDataset(base_datasets)
    windows_datasets = create_windows_from_events(
        base_datasets,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=trial_stop_offset_samples,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window,
        mapping=mapping,
        drop_bad_windows=drop_bad_windows,
        preload=preload,
        accepted_bads_ratio=accepted_bads_ratio,
    )
    return windows_datasets


def create_from_mne_epochs(
    list_of_epochs: list[mne.BaseEpochs],
    window_size_samples: int,
    window_stride_samples: int,
    drop_last_window: bool,
) -> BaseConcatDataset:
    """Create WindowsDatasets from mne.Epochs.

    Parameters
    ----------
    list_of_epochs : array-like
        list of mne.Epochs
    window_size_samples : int
        window size
    window_stride_samples : int
        stride between windows
    drop_last_window : bool
        whether or not to have a last overlapping window, when
        windows do not equally divide the continuous signal

    Returns
    -------
    windows_datasets : BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode
    """
    # Prevent circular import
    from ..preprocessing.windowers import _check_windowing_arguments

    _check_windowing_arguments(0, 0, window_size_samples, window_stride_samples)

    list_of_windows_ds = []
    for epochs in list_of_epochs:
        event_descriptions = epochs.events[:, 2]
        original_trial_starts = epochs.events[:, 0]
        stop = len(epochs.times) - window_size_samples

        # already includes last incomplete window start
        starts = np.arange(0, stop + 1, window_stride_samples)

        if not drop_last_window and starts[-1] < stop:
            # if last window does not end at trial stop, make it stop there
            starts = np.append(starts, stop)

        fake_events = [[start, window_size_samples, -1] for start in starts]

        for trial_i, trial in enumerate(epochs):
            metadata = pd.DataFrame(
                {
                    "i_window_in_trial": np.arange(len(fake_events)),
                    "i_start_in_trial": starts + original_trial_starts[trial_i],
                    "i_stop_in_trial": starts
                    + original_trial_starts[trial_i]
                    + window_size_samples,
                    "target": len(fake_events) * [event_descriptions[trial_i]],
                }
            )
            # window size - 1, since tmax is inclusive
            mne_epochs = mne.Epochs(
                mne.io.RawArray(trial, epochs.info),
                fake_events,
                baseline=None,
                tmin=0,
                tmax=(window_size_samples - 1) / epochs.info["sfreq"],
                metadata=metadata,
            )

            mne_epochs.drop_bad(reject=None, flat=None)

            windows_ds = WindowsDataset(mne_epochs)
            list_of_windows_ds.append(windows_ds)

    return BaseConcatDataset(list_of_windows_ds)
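For orientation, here is a minimal usage sketch of create_from_mne_raw. It is not part of the package diff: the synthetic signal, annotation labels, and windowing values are illustrative assumptions, and the import path follows braindecode's documented public API.

# Hypothetical example: a 30 s, 2-channel synthetic Raw with two fake trials,
# cut into non-overlapping 50-sample windows.
import mne
import numpy as np

from braindecode.datasets import create_from_mne_raw

sfreq = 100.0
info = mne.create_info(ch_names=["C3", "C4"], sfreq=sfreq, ch_types="eeg")
raw = mne.io.RawArray(np.random.randn(2, int(30 * sfreq)), info)
# two fake 2-second trials labelled "left" / "right"
raw.set_annotations(
    mne.Annotations(onset=[5.0, 15.0], duration=[2.0, 2.0],
                    description=["left", "right"])
)

windows_ds = create_from_mne_raw(
    [raw],
    trial_start_offset_samples=0,
    trial_stop_offset_samples=0,
    window_size_samples=50,
    window_stride_samples=50,
    drop_last_window=False,
    mapping={"left": 0, "right": 1},
)
X, y, window_inds = windows_ds[0]  # one window, its target, and its indices

Each 2-second trial yields four 50-sample windows here; the resulting BaseConcatDataset can be passed directly to a skorch-based braindecode classifier.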
braindecode/datasets/moabb.py

@@ -0,0 +1,219 @@
"""Dataset objects for some public datasets."""

# Authors: Hubert Banville <hubert.jbanville@gmail.com>
#          Lukas Gemein <l.gemein@gmail.com>
#          Simon Brandt <simonbrandt@protonmail.com>
#          David Sabbagh <dav.sabbagh@gmail.com>
#          Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)

from __future__ import annotations

import warnings
from typing import Any

import mne
import pandas as pd
from mne.utils import deprecated

from braindecode.util import _update_moabb_docstring

from .base import BaseConcatDataset, RawDataset


def _find_dataset_in_moabb(dataset_name, dataset_kwargs=None):
    # soft dependency on moabb
    from moabb.datasets.utils import dataset_list

    for dataset in dataset_list:
        if dataset_name == dataset.__name__:
            # return an instance of the found dataset class
            if dataset_kwargs is None:
                return dataset()
            else:
                return dataset(**dataset_kwargs)
    raise ValueError(f"{dataset_name} not found in moabb datasets")


def _fetch_and_unpack_moabb_data(dataset, subject_ids=None, dataset_load_kwargs=None):
    if dataset_load_kwargs is None:
        data = dataset.get_data(subject_ids)
    else:
        data = dataset.get_data(subjects=subject_ids, **dataset_load_kwargs)

    raws, subject_ids, session_ids, run_ids = [], [], [], []
    for subj_id, subj_data in data.items():
        for sess_id, sess_data in subj_data.items():
            for run_id, raw in sess_data.items():
                annots = _annotations_from_moabb_stim_channel(raw, dataset)
                raw.set_annotations(annots)
                raws.append(raw)
                subject_ids.append(subj_id)
                session_ids.append(sess_id)
                run_ids.append(run_id)
    description = pd.DataFrame(
        {"subject": subject_ids, "session": session_ids, "run": run_ids}
    )
    return raws, description


def _annotations_from_moabb_stim_channel(raw, dataset):
    # find events from the stim channel
    stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
    if len(stim_channels) > 0:
        # returns an empty array if none found
        events = mne.find_events(raw, shortest_event=0, verbose=False)
        event_id = dataset.event_id
    else:
        events, event_id = mne.events_from_annotations(raw, verbose=False)

    # get annotations from events
    event_desc = {k: v for v, k in event_id.items()}
    annots = mne.annotations_from_events(events, raw.info["sfreq"], event_desc)

    # set trial on- and offset given by moabb
    onset, offset = dataset.interval
    annots.onset += onset
    annots.duration += offset - onset
    return annots


def fetch_data_with_moabb(
    dataset_name: str,
    subject_ids: list[int] | int | None = None,
    dataset_kwargs: dict[str, Any] | None = None,
    dataset_load_kwargs: dict[str, Any] | None = None,
) -> tuple[list[mne.io.Raw], pd.DataFrame]:
    # ToDo: update path to where moabb downloads / looks for the data
    """Fetch data using moabb.

    Parameters
    ----------
    dataset_name : str | moabb.datasets.base.BaseDataset
        the name of a dataset included in moabb
    subject_ids : list(int) | int
        (list of) int of subject(s) to be fetched
    dataset_kwargs : dict, optional
        optional dictionary containing keyword arguments
        to pass to the moabb dataset when instantiating it.
    dataset_load_kwargs : dict, optional
        optional dictionary containing keyword arguments
        to pass to the moabb dataset's load_data method.
        Allows using the moabb cache_config=None and
        process_pipeline=None.

    Returns
    -------
    raws : mne.Raw
    info : pandas.DataFrame
    """
    if isinstance(dataset_name, str):
        dataset = _find_dataset_in_moabb(dataset_name, dataset_kwargs)
    else:
        from moabb.datasets.base import BaseDataset

        if isinstance(dataset_name, BaseDataset):
            dataset = dataset_name

    subject_id = [subject_ids] if isinstance(subject_ids, int) else subject_ids
    return _fetch_and_unpack_moabb_data(
        dataset, subject_id, dataset_load_kwargs=dataset_load_kwargs
    )


class MOABBDataset(BaseConcatDataset):
    """A class for moabb datasets.

    Parameters
    ----------
    dataset_name : str
        name of dataset included in moabb to be fetched
    subject_ids : list(int) | int | None
        (list of) int of subject(s) to be fetched. If None, data of all
        subjects is fetched.
    dataset_kwargs : dict, optional
        optional dictionary containing keyword arguments
        to pass to the moabb dataset when instantiating it.
    dataset_load_kwargs : dict, optional
        optional dictionary containing keyword arguments
        to pass to the moabb dataset's load_data method.
        Allows using the moabb cache_config=None and
        process_pipeline=None.
    """

    def __init__(
        self,
        dataset_name: str,
        subject_ids: list[int] | int | None = None,
        dataset_kwargs: dict[str, Any] | None = None,
        dataset_load_kwargs: dict[str, Any] | None = None,
    ):
        # soft dependency on moabb
        from moabb import __version__ as moabb_version  # type: ignore

        if moabb_version == "1.0.0":
            warnings.warn(
                "moabb version 1.0.0 generates incorrect annotations. "
                "Please update to another version (e.g., 0.5 or 1.1.0)."
            )

        raws, description = fetch_data_with_moabb(
            dataset_name,
            subject_ids,
            dataset_kwargs,
            dataset_load_kwargs=dataset_load_kwargs,
        )
        all_base_ds = [
            RawDataset(raw, row) for raw, (_, row) in zip(raws, description.iterrows())
        ]
        super().__init__(all_base_ds)


class BNCI2014_001(MOABBDataset):
    doc = """See moabb.datasets.bnci.BNCI2014_001

    Parameters
    ----------
    subject_ids: list(int) | int | None
        (list of) int of subject(s) to be fetched. If None, data of all
        subjects is fetched.
    """
    try:
        from moabb.datasets import BNCI2014_001

        __doc__ = _update_moabb_docstring(BNCI2014_001, doc)
    except ModuleNotFoundError:
        pass  # keep moabb soft dependency, otherwise crash on loading of datasets.__init__.py

    def __init__(self, subject_ids):
        super().__init__("BNCI2014_001", subject_ids=subject_ids)


class HGD(MOABBDataset):
    doc = """See moabb.datasets.schirrmeister2017.Schirrmeister2017

    Parameters
    ----------
    subject_ids: list(int) | int | None
        (list of) int of subject(s) to be fetched. If None, data of all
        subjects is fetched.
    """
    try:
        from moabb.datasets import Schirrmeister2017

        __doc__ = _update_moabb_docstring(Schirrmeister2017, doc)
    except ModuleNotFoundError:
        pass  # keep moabb soft dependency, otherwise crash on loading of datasets.__init__.py

    def __init__(self, subject_ids):
        super().__init__("Schirrmeister2017", subject_ids=subject_ids)


@deprecated(
    "`BNCI2014001` was renamed to `BNCI2014_001` in v1.13; this alias will be removed in v1.14."
)
class BNCI2014001(BNCI2014_001):
    """Deprecated alias for BNCI2014_001."""

    pass
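Again as orientation rather than part of the diff, a minimal sketch of the MOABB wrapper above. The subject choice is an arbitrary assumption, and the call requires the optional moabb dependency and downloads the data on first use.

# Hypothetical example: fetch subject 1 of the BNCI2014_001 motor-imagery
# dataset through the moabb soft dependency.
from braindecode.datasets import MOABBDataset

dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])
print(dataset.description)  # one row per (subject, session, run) recording

The returned object is a BaseConcatDataset of RawDataset instances, so it can be fed straight into the preprocessing and windowing utilities shown in the previous file.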
braindecode/datasets/nmt.py

@@ -0,0 +1,313 @@
"""
Dataset classes for the NMT EEG Corpus dataset.

The NMT Scalp EEG Dataset is an open-source annotated dataset of healthy and
pathological EEG recordings for predictive modeling. This dataset contains
2,417 recordings from unique participants spanning almost 625 h.

Note:
- The signal unit may not be uV and further examination is required.
- The spectrum shows that the signal may have been band-pass filtered from
  about 2-33 Hz, which needs to be further determined.
"""

# Authors: Mohammad Bayazi <mj.darvishi92@gmail.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)

from __future__ import annotations

import glob
import os
import warnings
from pathlib import Path
from unittest import mock

import mne
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from mne.datasets import fetch_dataset

from braindecode.datasets.base import BaseConcatDataset, RawDataset
from braindecode.datasets.utils import _correct_dataset_path

NMT_URL = "https://zenodo.org/record/10909103/files/NMT.zip"
NMT_archive_name = "NMT.zip"
NMT_folder_name = "MNE-NMT-eeg-dataset"
NMT_dataset_name = "NMT-EEG-Corpus"

NMT_dataset_params = {
    "dataset_name": NMT_dataset_name,
    "url": NMT_URL,
    "archive_name": NMT_archive_name,
    "folder_name": NMT_folder_name,
    "hash": "77b3ce12bcaf6c6cce4e6690ea89cb22bed55af10c525077b430f6e1d2e3c6bf",
    "config_key": NMT_dataset_name,
}


class NMT(BaseConcatDataset):
    """The NMT Scalp EEG Dataset.

    An Open-Source Annotated Dataset of Healthy and Pathological EEG
    Recordings for Predictive Modeling.

    This dataset contains 2,417 recordings from unique participants spanning
    almost 625 h.

    Here, the dataset can be used for three tasks: brain-age estimation,
    gender prediction, and abnormality detection.

    The dataset is described in [Khan2022]_.

    .. versionadded:: 0.9

    Parameters
    ----------
    path : str
        Parent directory of the dataset.
    recording_ids : list(int) | int
        A (list of) int of recording id(s) to be read (order matters and will
        overwrite default chronological order, e.g. if recording_ids=[1, 0],
        then the first recording returned by this class will be chronologically
        later than the second recording. Provide recording_ids in ascending
        order to preserve chronological order.).
    target_name : str
        Can be "pathological", "gender", or "age".
    preload : bool
        If True, preload the data of the Raw objects.

    References
    ----------
    .. [Khan2022] Khan, H.A., Ul Ain, R., Kamboh, A.M., Butt, H.T.,
       Shafait, S., Alamgir, W., Stricker, D. and Shafait, F., 2022. The NMT
       scalp EEG dataset: an open-source annotated dataset of healthy and
       pathological EEG recordings for predictive modeling. Frontiers in
       Neuroscience, 15, p.755817.
    """

    def __init__(
        self,
        path=None,
        target_name="pathological",
        recording_ids=None,
        preload=False,
        n_jobs=1,
    ):
        # Convert empty string to None for consistency
        if path == "":
            path = None

        # Download dataset if not present
        if path is None:
            path = fetch_dataset(
                dataset_params=NMT_dataset_params,
                path=None,
                processor="unzip",
                force_update=False,
            )
            # First time we fetch the dataset, we need to move the files to
            # the correct directory.
            path = _correct_dataset_path(
                path, NMT_archive_name, "nmt_scalp_eeg_dataset"
            )
        else:
            # Validate that the provided path is a valid NMT dataset
            if not Path(f"{path}/Labels.csv").exists():
                raise ValueError(
                    f"The provided path {path} does not contain a valid "
                    "NMT dataset (missing Labels.csv). Please ensure the "
                    "path points directly to the NMT dataset directory."
                )
            path = _correct_dataset_path(
                path, NMT_archive_name, "nmt_scalp_eeg_dataset"
            )

        # Get all file paths
        file_paths = glob.glob(
            os.path.join(path, "**" + os.sep + "*.edf"), recursive=True
        )

        # keep only EDF files
        file_paths = [
            file_path
            for file_path in file_paths
            if os.path.splitext(file_path)[1] == ".edf"
        ]

        # sort by subject id
        file_paths = sorted(
            file_paths, key=lambda p: int(os.path.splitext(p)[0].split(os.sep)[-1])
        )
        if recording_ids is not None:
            file_paths = [file_paths[rec_id] for rec_id in recording_ids]

        # read labels and rearrange them to match TUH Abnormal EEG Corpus
        description = pd.read_csv(
            os.path.join(path, "Labels.csv"), index_col="recordname"
        )
        if recording_ids is not None:
            # Match metadata by record name instead of position to fix
            # alignment bug when CSV order differs from sorted file order
            selected_recordnames = [os.path.basename(fp) for fp in file_paths]
            description = description.loc[selected_recordnames]
        description.replace(
            {
                "not specified": "X",
                "female": "F",
                "male": "M",
                "abnormal": True,
                "normal": False,
            },
            inplace=True,
        )
        description.rename(columns={"label": "pathological"}, inplace=True)
        description.reset_index(drop=True, inplace=True)
        description["path"] = file_paths
        description = description[["path", "pathological", "age", "gender"]]

        if n_jobs == 1:
            base_datasets = [
                self._create_dataset(d, target_name, preload)
                for recording_id, d in description.iterrows()
            ]
        else:
            base_datasets = Parallel(n_jobs)(
                delayed(self._create_dataset)(d, target_name, preload)
                for recording_id, d in description.iterrows()
            )

        super().__init__(base_datasets)

    @staticmethod
    def _create_dataset(d, target_name, preload):
        raw = mne.io.read_raw_edf(d.path, preload=preload)
        d["n_samples"] = raw.n_times
        d["sfreq"] = raw.info["sfreq"]
        d["train"] = "train" in d.path.split(os.sep)
        base_dataset = RawDataset(raw, d, target_name)
        return base_dataset


def _get_header(*args):
    all_paths = {**_NMT_PATHS}
    return all_paths[args[0]]


def _fake_pd_read_csv(*args, **kwargs):
    # Create a list of lists to hold the data.
    # Updated to match the file IDs from _NMT_PATHS (0000036-0000042)
    # to align with the mocked glob.glob return value.
    data = [
        ["0000036.edf", "normal", 35, "male", "train"],
        ["0000037.edf", "abnormal", 28, "female", "test"],
        ["0000038.edf", "normal", 62, "male", "train"],
        ["0000039.edf", "abnormal", 41, "female", "test"],
        ["0000040.edf", "normal", 19, "male", "train"],
        ["0000041.edf", "abnormal", 55, "female", "test"],
        ["0000042.edf", "normal", 71, "male", "train"],
    ]

    # Create the DataFrame, specifying column names
    df = pd.DataFrame(data, columns=["recordname", "label", "age", "gender", "loc"])

    # Set recordname as index to match the real pd.read_csv behavior with
    # index_col="recordname"
    df.set_index("recordname", inplace=True)

    return df


def _fake_raw(*args, **kwargs):
    sfreq = 10
    ch_names = [
        "EEG A1-REF",
        "EEG A2-REF",
        "EEG FP1-REF",
        "EEG FP2-REF",
        "EEG F3-REF",
        "EEG F4-REF",
        "EEG C3-REF",
        "EEG C4-REF",
        "EEG P3-REF",
        "EEG P4-REF",
        "EEG O1-REF",
        "EEG O2-REF",
        "EEG F7-REF",
        "EEG F8-REF",
        "EEG T3-REF",
        "EEG T4-REF",
        "EEG T5-REF",
        "EEG T6-REF",
        "EEG FZ-REF",
        "EEG CZ-REF",
        "EEG PZ-REF",
    ]
    duration_min = 6
    data = np.random.randn(len(ch_names), duration_min * sfreq * 60)
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
    raw = mne.io.RawArray(data=data, info=info)
    return raw


_NMT_PATHS = {
    # these are actual file paths and edf headers from NMT EEG Corpus
    "nmt_scalp_eeg_dataset/abnormal/train/0000036.edf": b"0 0000036 M 13-May-1951 0000036 Age:32 ",  # noqa E501
    "nmt_scalp_eeg_dataset/abnormal/eval/0000037.edf": b"0 0000037 M 13-May-1951 0000037 Age:32 ",  # noqa E501
    "nmt_scalp_eeg_dataset/abnormal/eval/0000038.edf": b"0 0000038 M 13-May-1951 0000038 Age:32 ",  # noqa E501
    "nmt_scalp_eeg_dataset/normal/train/0000039.edf": b"0 0000039 M 13-May-1951 0000039 Age:32 ",  # noqa E501
    "nmt_scalp_eeg_dataset/normal/eval/0000040.edf": b"0 0000040 M 13-May-1951 0000040 Age:32 ",  # noqa E501
    "nmt_scalp_eeg_dataset/normal/eval/0000041.edf": b"0 0000041 M 13-May-1951 0000041 Age:32 ",  # noqa E501
    "nmt_scalp_eeg_dataset/abnormal/train/0000042.edf": b"0 0000042 M 13-May-1951 0000042 Age:32 ",  # noqa E501
    "Labels.csv": b"0 recordname,label,age,gender,loc 1 0000001.edf,normal,22,not specified,train ",  # noqa E501
}


class _NMTMock(NMT):
    """Mocked class for testing and examples."""

    @mock.patch("pathlib.Path.exists", return_value=True)
    @mock.patch("braindecode.datasets.nmt._correct_dataset_path")
    @mock.patch("mne.datasets.fetch_dataset")
    @mock.patch("pandas.read_csv", new=_fake_pd_read_csv)
    @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
    @mock.patch("glob.glob", return_value=_NMT_PATHS.keys())
    def __init__(
        self,
        mock_glob,
        mock_fetch,
        mock_correct_path,
        mock_path_exists,
        path,
        recording_ids=None,
        target_name="pathological",
        preload=False,
        n_jobs=1,
    ):
        # Prevent download by providing a dummy path if empty/None
        if not path:
            path = "mocked_nmt_path"

        # Mock fetch_dataset to return a valid path without downloading
        mock_fetch.return_value = path
        # Mock _correct_dataset_path to return the path as-is
        mock_correct_path.side_effect = lambda p, *args, **kwargs: p

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="Cannot save date file")
            super().__init__(
                path=path,
                recording_ids=recording_ids,
                target_name=target_name,
                preload=preload,
                n_jobs=n_jobs,
            )
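To close, a minimal usage sketch for the NMT dataset class, not part of the diff. The local directory is hypothetical; passing path=None would instead trigger the Zenodo download via fetch_dataset as shown above, and the import path assumes NMT is exported from braindecode.datasets.

# Hypothetical example: load the first two recordings of a local NMT copy,
# using the pathological/normal label as the prediction target.
from braindecode.datasets import NMT

nmt = NMT(
    path="/data/nmt_scalp_eeg_dataset",  # hypothetical local dataset directory
    recording_ids=[0, 1],
    target_name="pathological",
    preload=False,
)
print(nmt.description[["pathological", "age", "gender"]])

Because labels are matched by record name rather than position (see the alignment fix in __init__ above), the description rows stay aligned with the selected recordings even when the CSV order differs from the sorted file order.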