braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PhysioNet Challenge 2018 dataset.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
|
|
6
|
+
# Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
7
|
+
# License: BSD (3-clause)
|
|
8
|
+
# Code copied from the repository
|
|
9
|
+
# https://github.com/hubertjb/dynamic-spatial-filtering
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
import os.path as op
|
|
13
|
+
import urllib
|
|
14
|
+
|
|
15
|
+
import mne
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
import wfdb
|
|
19
|
+
from joblib import Parallel, delayed
|
|
20
|
+
from mne.datasets.sleep_physionet._utils import _fetch_one
|
|
21
|
+
from mne.datasets.utils import _get_path
|
|
22
|
+
from mne.utils import warn
|
|
23
|
+
|
|
24
|
+
from braindecode.datasets import BaseConcatDataset, BaseDataset
|
|
25
|
+
from braindecode.preprocessing.preprocess import _preprocess
|
|
26
|
+
|
|
27
|
+
# Remote locations: PhysioNet hosts the recordings, Zenodo hosts the metadata
# files (record list, demographics and checksums) used by this module.
PC18_URL = "https://physionet.org/files/challenge-2018/1.0.0/"
PC18_METAINFO_URL = "https://zenodo.org/records/13823458/files/"

# Local cache for the metadata files, stored next to this module.
PC18_DIR = op.join(op.dirname(__file__), "data", "pc18")
PC18_RECORDS, PC18_INFO, PC18_SHA1_TRAINING, PC18_SHA1_TEST = (
    op.join(PC18_DIR, _meta_fname)
    for _meta_fname in (
        "sleep_records.csv",
        "age-sex.csv",
        "training_SHA1SUMS",
        "test_SHA1SUMS",
    )
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _download_if_missing(file_path, url):
    """Download ``url`` to ``file_path`` unless ``file_path`` already exists.

    The parent directory is created if needed. The download is written to a
    temporary ``.part`` file first and only moved to ``file_path`` once it is
    complete. An interrupted transfer therefore never leaves a truncated file
    that later calls would mistake for a valid local copy.

    Parameters
    ----------
    file_path : str
        Local destination of the file.
    url : str
        Remote location to download from.
    """
    # ``import urllib`` alone does not load the ``urllib.request`` submodule.
    # Import it explicitly instead of relying on some other package having
    # imported it as a side effect.
    import urllib.request

    folder_path = op.dirname(file_path)

    # Ensure the folder exists. An empty dirname means the current directory.
    # ``exist_ok`` avoids a race with concurrent callers.
    if folder_path and not op.exists(folder_path):
        warn(f"Directory {folder_path} not found. Creating directory.")
        os.makedirs(folder_path, exist_ok=True)

    if op.exists(file_path):
        return

    warn(f"{file_path} not found. Downloading from {url}")
    tmp_path = file_path + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic on the same filesystem: file_path is either absent or complete.
        os.replace(tmp_path, file_path)
    finally:
        # Clean up the partial file if the download or the move failed.
        if op.exists(tmp_path):
            os.remove(tmp_path)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def ensure_metafiles_exist():
    """Make sure every PC18 metadata file is available locally.

    The record list, the age/sex table and the training/test SHA1 checksum
    files are each downloaded from ``PC18_METAINFO_URL``. A download happens
    only when the corresponding local file is missing.
    """
    local_and_remote = (
        (PC18_RECORDS, "sleep_records.csv"),
        (PC18_INFO, "age-sex.csv"),
        (PC18_SHA1_TRAINING, "training_SHA1SUMS"),
        (PC18_SHA1_TEST, "test_SHA1SUMS"),
    )
    for local_path, remote_name in local_and_remote:
        _download_if_missing(local_path, PC18_METAINFO_URL + remote_name)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _update_pc18_sleep_records(fname=PC18_RECORDS):
    """Create CSV file with information about available PC18 recordings.

    Combines the training/test SHA1 checksum files with the age/sex table
    into one row per file (header, PSG signal or arousal annotation). The
    result is written to ``fname``.

    Parameters
    ----------
    fname : str
        Output CSV path. Defaults to ``PC18_RECORDS``.
    """
    # Load and massage the checksums (space-separated "sha fname" lines).
    read_kwargs = dict(sep=" ", header=None, names=["sha", "fname"], engine="python")
    sha_train_df = pd.read_csv(PC18_SHA1_TRAINING, **read_kwargs)
    sha_test_df = pd.read_csv(PC18_SHA1_TEST, **read_kwargs)
    sha_train_df["Split"] = "training"
    sha_test_df["Split"] = "test"
    sha_df = pd.concat([sha_train_df, sha_test_df], axis=0, ignore_index=True)

    # Keep only files under "tr*"/"te*" record folders, dropping "*arousal.mat".
    select_records = (
        sha_df.fname.str.startswith("tr") | sha_df.fname.str.startswith("te")
    ) & ~sha_df.fname.str.endswith("arousal.mat")
    # Copy so the column assignments below act on an independent frame
    # rather than on a filtered view (avoids SettingWithCopyWarning).
    sha_df = sha_df[select_records].copy()
    # ``.str[0]`` also works on an empty frame, unlike ``expand=True``.
    sha_df["Record"] = sha_df["fname"].str.split("/").str[0]
    sha_df["fname"] = sha_df[["Split", "fname"]].agg("/".join, axis=1)

    # Load and massage the demographics; the row position becomes "Subject".
    data = pd.read_csv(PC18_INFO)
    data = data.reset_index().rename({"index": "Subject"}, axis=1)
    data["Sex"] = (
        data["Sex"].map({"F": "female", "M": "male", "m": "male"}).astype("category")
    )
    data = sha_df.merge(data, on="Record")

    # File extension -> record type, e.g. "x/y/z.hea" -> "Header".
    data["Record type"] = (
        data["fname"]
        .str.split(".", expand=True)[1]
        .map({"hea": "Header", "mat": "PSG", "arousal": "Arousal"})
        .astype("category")
    )
    # A stable sort keeps the order of files within a subject deterministic.
    data = data[
        ["Subject", "Record", "Record type", "Split", "Age", "Sex", "sha", "fname"]
    ].sort_values(by="Subject", kind="stable")

    # Save the data.
    data.to_csv(fname, index=False)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _data_path(path=None):
    """Get path to local copy of PC18 dataset.

    ``path`` is resolved with MNE's dataset path logic (config key
    ``PC18_DATASET_PATH``). If the resolved folder already contains a
    ``training`` or ``test`` folder, it is returned as is. Otherwise the data
    is expected in its ``pc18-sleep-data`` subfolder.

    Parameters
    ----------
    path : None | str
        Candidate root folder, or None to use the configured/default location.

    Returns
    -------
    str
        Folder that holds (or will hold) the ``training``/``test`` folders.
    """
    key = "PC18_DATASET_PATH"
    name = "PC18_DATASET_SLEEP"
    path = _get_path(path, key, name)
    # A user-supplied folder may not exist yet (nothing downloaded so far).
    # Treat it as empty instead of letting os.listdir raise.
    subdirs = os.listdir(path) if op.isdir(path) else []
    if "training" in subdirs or "test" in subdirs:
        # The given path is already at the training/test folders level.
        return path
    return op.join(path, "pc18-sleep-data")
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def fetch_pc18_data(subjects, path=None, force_update=False, base_url=PC18_URL):
    """Get paths to local copies of PhysioNet Challenge 2018 dataset files.

    This will fetch data from the publicly available PhysioNet Computing in
    Cardiology Challenge 2018 dataset on sleep arousal detection [1]_ [2]_.
    This corresponds to 1983 recordings from individual subjects with
    (suspected) sleep apnea. The dataset is separated into a training set with
    994 recordings for which arousal annotation are available and a test set
    with 989 recordings for which the labels have not been revealed. Across the
    entire dataset, mean age is 55 years old and 65% of recordings are from
    male subjects.

    More information can be found on the
    `physionet website <https://physionet.org/content/challenge-2018/1.0.0/>`_.

    Parameters
    ----------
    subjects : list of int
        The subjects to use. Can be in the range of 0-1982 (inclusive). Test
        recordings are 0-988, while training recordings are 989-1982.
    path : None | str
        Location of where to look for the PC18 data storing location. If None,
        the environment variable or config parameter ``PC18_DATASET_PATH``
        is used. If it doesn't exist, the "~/mne_data" directory is used. If
        the dataset is not found under the given path, the data will be
        automatically downloaded to the specified folder.
    force_update : bool
        Force update of the dataset even if a local copy exists.
    base_url : str
        The URL root.

    Returns
    -------
    paths : list
        One ``[psg, header, arousal]`` entry per recording, in the order of
        ``subjects``. Each element is the value returned by MNE's private
        ``_fetch_one`` helper; ``SleepPhysionetChallenge2018`` uses its first
        element as the local file path. ``arousal`` is None for test-set
        recordings.

    Raises
    ------
    ValueError
        If a subject does not have as many header files as PSG files in the
        records table.

    References
    ----------
    .. [1] Mohammad M Ghassemi, Benjamin E Moody, Li-wei H Lehman, Christopher
       Song, Qiao Li, Haoqi Sun, Roger G Mark, M Brandon Westover, Gari D
       Clifford. You Snooze, You Win: the PhysioNet/Computing in Cardiology
       Challenge 2018.
    .. [2] Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C.,
       Mark, R., ... & Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and
       PhysioNet: Components of a new research resource for complex physiologic
       signals. Circulation [Online]. 101 (23), pp. e215–e220.)
    """
    records = pd.read_csv(PC18_RECORDS)
    psg_records = records[records["Record type"] == "PSG"]
    hea_records = records[records["Record type"] == "Header"]
    arousal_records = records[records["Record type"] == "Arousal"]

    path = _data_path(path=path)
    params = [path, force_update, base_url]

    def _fetch(row):
        # Download (or locate) the file described by one records-table row.
        return _fetch_one(row["fname"], row["sha"], *params)

    fnames = []
    for subject in subjects:
        subj_psg = psg_records[psg_records["Subject"] == subject]
        # Match headers by subject rather than by row position. Otherwise one
        # subject missing a header would shift the header of every later one.
        subj_hea = hea_records[hea_records["Subject"] == subject]
        if len(subj_hea) != len(subj_psg):
            raise ValueError(
                f"Subject {subject}: found {len(subj_psg)} PSG file(s) but "
                f"{len(subj_hea)} header file(s) in {PC18_RECORDS}."
            )
        for (_, psg_row), (_, hea_row) in zip(subj_psg.iterrows(), subj_hea.iterrows()):
            psg_fname = _fetch(psg_row)
            hea_fname = _fetch(hea_row)
            if psg_row["Split"] == "training":
                # Only training recordings come with arousal annotations.
                arousal_row = arousal_records[
                    arousal_records["Subject"] == subject
                ].iloc[0]
                arousal_fname = _fetch(arousal_row)
            else:
                arousal_fname = None
            fnames.append([psg_fname, hea_fname, arousal_fname])

    return fnames
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _convert_wfdb_anns_to_mne_annotations(annots):
    """Convert wfdb.io.Annotation format to MNE's.

    Notes are grouped by annotation channel (``annots.chan``). Channels whose
    notes are all interval markers (``"(event"`` / ``"event)"``) are skipped.
    On every other channel, each note becomes an MNE annotation that lasts
    until the next note of that channel; the last note lasts 30 s.

    Parameters
    ----------
    annots : wfdb.io.Annotation
        Annotation object obtained by e.g. loading an annotation file with
        wfdb.rdann().

    Returns
    -------
    mne.Annotations :
        MNE Annotations object (with ``orig_time=None``).

    Raises
    ------
    ValueError
        If a channel's notes are not in increasing time order, which would
        produce a negative duration.
    """
    onsets = annots.sample / annots.fs  # sample index -> seconds
    aux_notes = np.array(annots.aux_note)
    new_onset, new_duration, new_description = [], [], []
    for channel_name in set(annots.chan):
        mask = annots.chan == channel_name
        ch_onsets = onsets[mask]
        ch_descs = aux_notes[mask]

        # Events with beginning and end, defined by '(event' and 'event)',
        # are not converted.
        if all(desc.startswith("(") or desc.endswith(")") for desc in ch_descs):
            continue

        # Sleep stage-like annotations: each lasts until the next one.
        ch_durations = np.concatenate([np.diff(ch_onsets), [30]])
        if np.any(ch_durations < 0):
            raise ValueError("Negative duration")
        new_onset.extend(ch_onsets)
        new_duration.extend(ch_durations)
        new_description.extend(ch_descs)

    mne_annots = mne.Annotations(
        new_onset, new_duration, new_description, orig_time=None
    )

    return mne_annots
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class SleepPhysionetChallenge2018(BaseConcatDataset):
    """Physionet Challenge 2018 polysomnography dataset.

    Sleep dataset from https://physionet.org/content/challenge-2018/1.0.0/.
    Contains overnight recordings from 1983 subjects with (suspected) sleep
    apnea.

    The total size is 266 GB, so make sure you have enough space before
    downloading.

    See `fetch_pc18_data` for a more complete description.

    Parameters
    ----------
    subject_ids: list(int) | int | str | None
        (list of) int of subject(s) to be loaded, in the range 0-1982.
        Defaults to ``"training"``.

        - If `None`, loads all subjects (both training and test sets [no label associated]).
        - If `"training"`, loads only the training set subjects (989-1982).
        - If `"test"`, loads only the test set subjects (0-988), no label associated!
        - Otherwise, expects an int or an iterable of subject IDs.
    path : None | str
        Location of where to look for the PC18 data storing location. If None,
        the environment variable or config parameter ``PC18_DATASET_PATH``
        is used. If it doesn't exist, the "~/mne_data" directory is used. If
        the dataset is not found under the given path, the data will be
        automatically downloaded to the specified folder.
    load_eeg_only: bool
        If True, only load the EEG channels and discard the others (EOG, EMG,
        temperature, respiration) to avoid resampling the other signals.
    preproc : list(Preprocessor) | None
        List of preprocessors to apply to each file individually. This way the
        data can e.g., be downsampled (temporally and spatially) to limit the
        memory usage of the entire Dataset object. This also enables applying
        preprocessing in parallel over the recordings.
    n_jobs : int
        Number of parallel processes.

    Raises
    ------
    ValueError
        If ``subject_ids`` is an unknown string or contains IDs outside 0-1982.
    """

    # Subject IDs 0-988 form the (unlabelled) test split; 989-1982 the
    # training split.
    _N_SUBJECTS = 1983
    _N_TEST_SUBJECTS = 989

    def __init__(
        self,
        subject_ids="training",
        path=None,
        load_eeg_only=True,
        preproc=None,
        n_jobs=1,
    ):
        subject_ids = self._resolve_subject_ids(subject_ids)

        ensure_metafiles_exist()

        paths = fetch_pc18_data(subject_ids, path=path)
        # Subject numbers are paired with paths positionally below, so one
        # recording per requested subject is required to avoid mislabelling.
        if len(paths) != len(subject_ids):
            raise RuntimeError(
                f"Expected one recording per requested subject, got "
                f"{len(paths)} recordings for {len(subject_ids)} subjects."
            )

        self.info_df = pd.read_csv(PC18_INFO)

        load_kwargs = dict(load_eeg_only=load_eeg_only, preproc=preproc)
        if n_jobs == 1:
            all_base_ds = [
                self._load_raw(subject_id, p[0], p[2], **load_kwargs)
                for subject_id, p in zip(subject_ids, paths)
            ]
        else:
            all_base_ds = Parallel(n_jobs=n_jobs)(
                delayed(self._load_raw)(subject_id, p[0], p[2], **load_kwargs)
                for subject_id, p in zip(subject_ids, paths)
            )
        super().__init__(all_base_ds)

    @classmethod
    def _resolve_subject_ids(cls, subject_ids):
        """Turn the ``subject_ids`` argument into a list, warning on test subjects."""
        if subject_ids is None:
            warn(
                """
                You are loading the complete dataset (0 to 1982),
                which includes a portion of the test set (0 to 988)
                from the Physionet Challenge 2018. Note that the test set
                does not have associated labels, so supervised classification
                cannot be performed on these data.""",
                UserWarning,
            )
            return list(range(cls._N_SUBJECTS))

        # isinstance avoids elementwise ``==`` when an array is passed.
        if isinstance(subject_ids, str):
            if subject_ids == "training":
                return list(range(cls._N_TEST_SUBJECTS, cls._N_SUBJECTS))
            if subject_ids == "test":
                warn(
                    """
                This subset does not have associated labels, so supervised
                classification (sleep stage) cannot be performed on this data.
                You can also use the meta information as a label to perform
                another task.
                """
                )
                return list(range(cls._N_TEST_SUBJECTS))
            raise ValueError(
                'subject_ids must be None, "training", "test" or an iterable '
                f"of subject IDs, got {subject_ids!r}."
            )

        if isinstance(subject_ids, (int, np.integer)):
            subject_ids = [subject_ids]
        # Materialize so a generator is not exhausted by the checks below.
        subject_ids = list(subject_ids)

        out_of_range = [
            sid for sid in subject_ids if not 0 <= sid < cls._N_SUBJECTS
        ]
        if out_of_range:
            raise ValueError(
                f"Subject IDs must be in the range 0-{cls._N_SUBJECTS - 1}, "
                f"got {out_of_range}."
            )

        # Check if the requested subjects include any test set IDs.
        if any(sid < cls._N_TEST_SUBJECTS for sid in subject_ids):
            warn(
                """
                You are loading a subset of the data that includes test set
                subjects (subject IDs: 0 to 988). These subjects do not have
                associated labels, which means supervised classification
                (sleep stage) cannot be performed on this data. You can also
                use the meta information as a label to perform another task.
                """,
                UserWarning,
            )
        return subject_ids

    def _load_raw(self, subj_nb, raw_fname, arousal_fname, load_eeg_only, preproc):
        """Load one PC18 recording as a :class:`BaseDataset`.

        Parameters
        ----------
        subj_nb : int
            Subject number stored in the dataset description.
        raw_fname : tuple
            Entry returned by ``fetch_pc18_data`` for the PSG file. Its first
            element is the local path of the ``.mat`` signal file.
        arousal_fname : tuple | None
            Entry for the arousal annotation file, or None when no annotations
            should be loaded (test recordings).
        load_eeg_only : bool
            If True, only the first 7 channels are read.
        preproc : list(Preprocessor) | None
            Preprocessors applied to the loaded recording.

        Returns
        -------
        BaseDataset
            The recording with its annotations and subject/record description.
        """
        # The first 7 channels are typed as EEG. When all channels are loaded,
        # the remaining ones are typed EMG, 4 x misc and ECG.
        channel_types = ["eeg"] * 7
        if load_eeg_only:
            channels = list(range(7))
        else:
            channel_types += ["emg", "misc", "misc", "misc", "misc", "ecg"]
            channels = None

        # WFDB readers take the record path without file extension.
        record_path = op.splitext(raw_fname[0])[0]

        # Load raw signals and header.
        record = wfdb.io.rdrecord(record_path, channels=channels)

        # Convert to right units for MNE (EEG should be in V).
        data = record.p_signal.T
        units = np.array(record.units)
        data[units == "uV"] /= 1e6
        data[units == "mV"] /= 1e3
        info = mne.create_info(record.sig_name, record.fs, channel_types)
        raw_file = mne.io.RawArray(data, info)

        # Extract annotations.
        if arousal_fname is not None:
            annots = wfdb.rdann(
                record_path,
                "arousal",
                sampfrom=0,
                sampto=None,
                shift_samps=False,
                return_label_elements=["symbol"],
                summarize_labels=False,
            )
            mne_annots = _convert_wfdb_anns_to_mne_annotations(annots)
            raw_file = raw_file.set_annotations(mne_annots)

        record_name = op.basename(record_path)
        record_info = self.info_df[self.info_df["Record"] == record_name].iloc[0]
        if record_info["Record"].startswith("tr"):
            split = "training"
        elif record_info["Record"].startswith("te"):
            split = "test"
        else:
            split = "unknown"

        desc = pd.Series(
            {
                "subject": subj_nb,
                "record": record_info["Record"],
                "split": split,
                "age": record_info["Age"],
                "sex": record_info["Sex"],
            },
            name="",
        )
        base_dataset = BaseDataset(raw_file, desc)

        if preproc is not None:
            _preprocess(base_dataset, None, preproc)

        return base_dataset
|
|
@@ -3,14 +3,16 @@
|
|
|
3
3
|
# License: BSD (3-clause)
|
|
4
4
|
|
|
5
5
|
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
6
8
|
import os
|
|
7
9
|
|
|
10
|
+
import mne
|
|
8
11
|
import numpy as np
|
|
9
12
|
import pandas as pd
|
|
10
|
-
import mne
|
|
11
13
|
from mne.datasets.sleep_physionet.age import fetch_data
|
|
12
14
|
|
|
13
|
-
from .base import
|
|
15
|
+
from .base import BaseConcatDataset, BaseDataset
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
class SleepPhysionet(BaseConcatDataset):
|
|
@@ -19,7 +21,7 @@ class SleepPhysionet(BaseConcatDataset):
|
|
|
19
21
|
Sleep dataset from https://physionet.org/content/sleep-edfx/1.0.0/.
|
|
20
22
|
Contains overnight recordings from 78 healthy subjects.
|
|
21
23
|
|
|
22
|
-
See
|
|
24
|
+
See `MNE example <https://mne.tools/stable/auto_tutorials/clinical/60_sleep.html>`.
|
|
23
25
|
|
|
24
26
|
Parameters
|
|
25
27
|
----------
|
|
@@ -42,34 +44,52 @@ class SleepPhysionet(BaseConcatDataset):
|
|
|
42
44
|
If not None crop the raw files (e.g. to use only the first 3h).
|
|
43
45
|
Example: ``crop=(0, 3600*3)`` to keep only the first 3h.
|
|
44
46
|
"""
|
|
45
|
-
|
|
46
|
-
|
|
47
|
+
|
|
48
|
+
def __init__(
|
|
49
|
+
self,
|
|
50
|
+
subject_ids: list[int] | int | None = None,
|
|
51
|
+
recording_ids: list[int] | None = None,
|
|
52
|
+
preload=False,
|
|
53
|
+
load_eeg_only=True,
|
|
54
|
+
crop_wake_mins=30,
|
|
55
|
+
crop=None,
|
|
56
|
+
):
|
|
47
57
|
if subject_ids is None:
|
|
48
|
-
subject_ids = range(83)
|
|
58
|
+
subject_ids = list(range(83))
|
|
49
59
|
if recording_ids is None:
|
|
50
60
|
recording_ids = [1, 2]
|
|
51
61
|
|
|
52
|
-
paths = fetch_data(
|
|
53
|
-
subject_ids, recording=recording_ids, on_missing='warn')
|
|
62
|
+
paths = fetch_data(subject_ids, recording=recording_ids, on_missing="warn")
|
|
54
63
|
|
|
55
64
|
all_base_ds = list()
|
|
56
65
|
for p in paths:
|
|
57
66
|
raw, desc = self._load_raw(
|
|
58
|
-
p[0],
|
|
59
|
-
|
|
67
|
+
p[0],
|
|
68
|
+
p[1],
|
|
69
|
+
preload=preload,
|
|
70
|
+
load_eeg_only=load_eeg_only,
|
|
71
|
+
crop_wake_mins=crop_wake_mins,
|
|
72
|
+
crop=crop,
|
|
73
|
+
)
|
|
60
74
|
base_ds = BaseDataset(raw, desc)
|
|
61
75
|
all_base_ds.append(base_ds)
|
|
62
76
|
super().__init__(all_base_ds)
|
|
63
77
|
|
|
64
78
|
@staticmethod
|
|
65
|
-
def _load_raw(
|
|
66
|
-
|
|
79
|
+
def _load_raw(
|
|
80
|
+
raw_fname,
|
|
81
|
+
ann_fname,
|
|
82
|
+
preload,
|
|
83
|
+
load_eeg_only=True,
|
|
84
|
+
crop_wake_mins=False,
|
|
85
|
+
crop=None,
|
|
86
|
+
):
|
|
67
87
|
ch_mapping = {
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
88
|
+
"EOG horizontal": "eog",
|
|
89
|
+
"Resp oro-nasal": "misc",
|
|
90
|
+
"EMG submental": "misc",
|
|
91
|
+
"Temp rectal": "misc",
|
|
92
|
+
"Event marker": "misc",
|
|
73
93
|
}
|
|
74
94
|
exclude = list(ch_mapping.keys()) if load_eeg_only else ()
|
|
75
95
|
|
|
@@ -79,19 +99,16 @@ class SleepPhysionet(BaseConcatDataset):
|
|
|
79
99
|
|
|
80
100
|
if crop_wake_mins > 0:
|
|
81
101
|
# Find first and last sleep stages
|
|
82
|
-
mask = [
|
|
83
|
-
x[-1] in ['1', '2', '3', '4', 'R'] for x in annots.description]
|
|
102
|
+
mask = [x[-1] in ["1", "2", "3", "4", "R"] for x in annots.description]
|
|
84
103
|
sleep_event_inds = np.where(mask)[0]
|
|
85
104
|
|
|
86
105
|
# Crop raw
|
|
87
|
-
tmin = annots[int(sleep_event_inds[0])][
|
|
88
|
-
tmax = annots[int(sleep_event_inds[-1])][
|
|
89
|
-
raw.crop(tmin=max(tmin, raw.times[0]),
|
|
90
|
-
tmax=min(tmax, raw.times[-1]))
|
|
106
|
+
tmin = annots[int(sleep_event_inds[0])]["onset"] - crop_wake_mins * 60
|
|
107
|
+
tmax = annots[int(sleep_event_inds[-1])]["onset"] + crop_wake_mins * 60
|
|
108
|
+
raw.crop(tmin=max(tmin, raw.times[0]), tmax=min(tmax, raw.times[-1]))
|
|
91
109
|
|
|
92
110
|
# Rename EEG channels
|
|
93
|
-
ch_names = {
|
|
94
|
-
i: i.replace('EEG ', '') for i in raw.ch_names if 'EEG' in i}
|
|
111
|
+
ch_names = {i: i.replace("EEG ", "") for i in raw.ch_names if "EEG" in i}
|
|
95
112
|
raw.rename_channels(ch_names)
|
|
96
113
|
|
|
97
114
|
if not load_eeg_only:
|
|
@@ -103,6 +120,6 @@ class SleepPhysionet(BaseConcatDataset):
|
|
|
103
120
|
basename = os.path.basename(raw_fname)
|
|
104
121
|
subj_nb = int(basename[3:5])
|
|
105
122
|
sess_nb = int(basename[5])
|
|
106
|
-
desc = pd.Series({
|
|
123
|
+
desc = pd.Series({"subject": subj_nb, "recording": sess_nb}, name="")
|
|
107
124
|
|
|
108
125
|
return raw, desc
|