braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/datasets/tuh.py
@@ -0,0 +1,591 @@
+"""
+Dataset classes for the Temple University Hospital (TUH) EEG Corpus and the
+TUH Abnormal EEG Corpus.
+"""
+
+# Authors: Lukas Gemein <l.gemein@gmail.com>
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+
+import glob
+import os
+import re
+import warnings
+from datetime import datetime, timezone
+from typing import Iterable
+from unittest import mock
+
+import mne
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+
+from .base import BaseConcatDataset, RawDataset
+
+
+class TUH(BaseConcatDataset):
+    """Temple University Hospital (TUH) EEG Corpus.
+
+    see www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tueg
+
+    Parameters
+    ----------
+    path : str
+        Parent directory of the dataset.
+    recording_ids : list(int) | int
+        A (list of) int of recording id(s) to be read (order matters and will
+        overwrite the default chronological order; e.g. if recording_ids=[1, 0],
+        the first recording returned by this class will be chronologically
+        later than the second one. Provide recording_ids in ascending order
+        to preserve chronological order).
+    target_name : str
+        Can be 'gender' or 'age'.
+    preload : bool
+        If True, preload the data of the Raw objects.
+    add_physician_reports : bool
+        If True, the physician reports will be read from disk and added to the
+        description.
+    rename_channels : bool
+        If True, rename the EEG channels to the standard 10-05 system.
+    set_montage : bool
+        If True, set the montage to the standard 10-05 system.
+    n_jobs : int
+        Number of jobs to be used to read files in parallel.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        recording_ids: list[int] | None = None,
+        target_name: str | tuple[str, ...] | None = None,
+        preload: bool = False,
+        add_physician_reports: bool = False,
+        rename_channels: bool = False,
+        set_montage: bool = False,
+        n_jobs: int = 1,
+    ):
+        if set_montage:
+            assert rename_channels, (
+                "If set_montage is True, rename_channels must be True."
+            )
+        # create an index of all files and gather easily accessible info
+        # without actually touching the files
+        file_paths = glob.glob(os.path.join(path, "**/*.edf"), recursive=True)
+        descriptions = _create_description(file_paths)
+        # sort the descriptions chronologically
+        descriptions = _sort_chronologically(descriptions)
+        # limit to specified recording ids before doing slow stuff
+        if recording_ids is not None:
+            if not isinstance(recording_ids, Iterable):
+                # Assume it is an integer specifying the number
+                # of recordings to load
+                recording_ids = range(recording_ids)
+            descriptions = descriptions[recording_ids]
+
+        # workaround to ensure warnings are suppressed when running in parallel
+        def create_dataset(*args, **kwargs):
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    "ignore", message=".*not in description. '__getitem__'"
+                )
+                return self._create_dataset(*args, **kwargs)
+
+        # this is the second loop (slow):
+        # create datasets gathering more info about the files by touching them,
+        # reading the raws and potentially preloading the data
+        # disable joblib for tests. mocking seems to fail otherwise
+        if n_jobs == 1:
+            base_datasets = [
+                create_dataset(
+                    descriptions[i],
+                    target_name,
+                    preload,
+                    add_physician_reports,
+                    rename_channels,
+                    set_montage,
+                )
+                for i in descriptions.columns
+            ]
+        else:
+            base_datasets = Parallel(n_jobs)(
+                delayed(create_dataset)(
+                    descriptions[i],
+                    target_name,
+                    preload,
+                    add_physician_reports,
+                    rename_channels,
+                    set_montage,
+                )
+                for i in descriptions.columns
+            )
+        super().__init__(base_datasets)
+
+    @staticmethod
+    def _rename_channels(raw):
+        """
+        Rename the EEG channels using MNE conventions and set their type to 'eeg'.
+
+        See https://isip.piconepress.com/publications/reports/2020/tuh_eeg/electrodes/
+        """
+        # remove ref suffix and prefix:
+        # TODO: replace with removesuffix and removeprefix when 3.8 is dropped
+        mapping_strip = {
+            c: c.replace("-REF", "").replace("-LE", "").replace("EEG ", "")
+            for c in raw.ch_names
+        }
+        raw.rename_channels(mapping_strip)
+
+        montage1005 = mne.channels.make_standard_montage("standard_1005")
+        mapping_eeg_names = {
+            c.upper(): c for c in montage1005.ch_names if c.upper() in raw.ch_names
+        }
+
+        # Set channels whose type could not be inferred (defaulted to "eeg") to "misc":
+        non_eeg_names = [c for c in raw.ch_names if c not in mapping_eeg_names]
+        if non_eeg_names:
+            non_eeg_types = raw.get_channel_types(picks=non_eeg_names)
+            mapping_non_eeg_types = {
+                c: "misc" for c, t in zip(non_eeg_names, non_eeg_types) if t == "eeg"
+            }
+            if mapping_non_eeg_types:
+                raw.set_channel_types(mapping_non_eeg_types)
+
+        if mapping_eeg_names:
+            # Set 1005 channels type to "eeg":
+            raw.set_channel_types(
+                {c: "eeg" for c in mapping_eeg_names}, on_unit_change="ignore"
+            )
+            # Fix capitalized EEG channel names:
+            raw.rename_channels(mapping_eeg_names)
+
+    @staticmethod
+    def _set_montage(raw):
+        montage = mne.channels.make_standard_montage("standard_1005")
+        raw.set_montage(montage, on_missing="ignore")
+
+    @staticmethod
+    def _create_dataset(
+        description,
+        target_name,
+        preload,
+        add_physician_reports,
+        rename_channels,
+        set_montage,
+    ):
+        file_path = description.loc["path"]
+
+        # parse age and gender information from EDF header
+        age, gender = _parse_age_and_gender_from_edf_header(file_path)
+        raw = mne.io.read_raw_edf(
+            file_path, preload=preload, infer_types=True, verbose="error"
+        )
+        if rename_channels:
+            TUH._rename_channels(raw)
+        if set_montage:
+            TUH._set_montage(raw)
+
+        meas_date = (
+            datetime(1, 1, 1, tzinfo=timezone.utc)
+            if raw.info["meas_date"] is None
+            else raw.info["meas_date"]
+        )
+        # if this is an old version of the data and the year could be parsed
+        # from the file paths, use this instead as before
+        if "year" in description:
+            meas_date = meas_date.replace(*description[["year", "month", "day"]])
+        raw.set_meas_date(meas_date)
+
+        d = {
+            "age": int(age),
+            "gender": gender,
+        }
+        # if year exists in description = old version
+        # if not, get it from meas_date in raw.info and add to description
+        # if meas_date is None, create a fake one
+        if "year" not in description:
+            d["year"] = raw.info["meas_date"].year
+            d["month"] = raw.info["meas_date"].month
+            d["day"] = raw.info["meas_date"].day
+
+        # read info relevant for preprocessing from raw without loading it
+        if add_physician_reports:
+            physician_report = _read_physician_report(file_path)
+            d["report"] = physician_report
+        additional_description = pd.Series(d)
+        description = pd.concat([description, additional_description])
+        base_dataset = RawDataset(raw, description, target_name=target_name)
+        return base_dataset
+
+
+def _create_description(file_paths):
+    descriptions = [_parse_description_from_file_path(f) for f in file_paths]
+    descriptions = pd.DataFrame(descriptions)
+    return descriptions.T
+
+
+def _sort_chronologically(descriptions):
+    descriptions.sort_values(
+        ["year", "month", "day", "subject", "session", "segment"], axis=1, inplace=True
+    )
+    return descriptions
+
+
+def _read_date(file_path):
+    date_path = file_path.replace(".edf", "_date.txt")
+    # if the date file exists, read it
+    if os.path.exists(date_path):
+        description = pd.read_json(date_path, typ="series").to_dict()
+    # otherwise read the edf file, extract the date, and store it to file
+    else:
+        raw = mne.io.read_raw_edf(file_path, preload=False, verbose="error")
+        description = {
+            "year": raw.info["meas_date"].year,
+            "month": raw.info["meas_date"].month,
+            "day": raw.info["meas_date"].day,
+        }
+        # if the txt file storing the recording date does not exist, create it
+        try:
+            pd.Series(description).to_json(date_path)
+        except OSError:
+            warnings.warn(
+                f"Cannot save date file to {date_path}. "
+                "This might slow down creation of the dataset."
+            )
+    return description
+
+
+def _parse_description_from_file_path(file_path):
+    # stackoverflow.com/questions/3167154/how-to-split-a-dos-path-into-its-components-in-python  # noqa
+    file_path = os.path.normpath(file_path)
+    tokens = file_path.split(os.sep)
+    # Extract version number and tuh_eeg_abnormal/tuh_eeg from file path
+    if ("train" in tokens) or ("eval" in tokens):  # tuh_eeg_abnormal
+        abnormal = True
+        # tokens[-2] is the channel configuration (always 01_tcp_ar in
+        # abnormal) on new versions, or the session (e.g. s004_2013_08_15)
+        # on old versions
+        if tokens[-2].split("_")[0][0] == "s":  # s denoting session number
+            version = tokens[-9]  # before the Dec 2022 update
+        else:
+            version = tokens[-6]  # after the Dec 2022 update
+
+    else:  # tuh_eeg
+        abnormal = False
+        version = tokens[-7]
+    v_number = int(version[1])
+
+    if (abnormal and v_number >= 3) or ((not abnormal) and v_number >= 2):
+        # New file path structure for versions after December 2022,
+        # expect file paths as
+        # tuh_eeg/v2.0.0/edf/000/aaaaaaaa/
+        # s001_2015_12_30/01_tcp_ar/aaaaaaaa_s001_t000.edf
+        # or for abnormal:
+        # tuh_eeg_abnormal/v3.0.0/edf/train/normal/
+        # 01_tcp_ar/aaaaaaav_s004_t000.edf
+        subject_id = tokens[-1].split("_")[0]
+        session = tokens[-1].split("_")[1]
+        segment = tokens[-1].split("_")[2].split(".")[0]
+        description = _read_date(file_path)
+        description.update(
+            {
+                "path": file_path,
+                "version": version,
+                "subject": subject_id,
+                "session": int(session[1:]),
+                "segment": int(segment[1:]),
+            }
+        )
+        if not abnormal:
+            year, month, day = tokens[-3].split("_")[1:]
+            description["year"] = int(year)
+            description["month"] = int(month)
+            description["day"] = int(day)
+        return description
+    else:  # Old file path structure
+        # expect file paths as tuh_eeg/version/file_type/reference/data_split/
+        # subject/recording session/file
+        # e.g. tuh_eeg/v1.1.0/edf/01_tcp_ar/027/00002729/
+        # s001_2006_04_12/00002729_s001.edf
+        # or for abnormal
+        # version/file type/data_split/pathology status/
+        # reference/subset/subject/recording session/file
+        # v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
+        # s004_2013_08_15/00000021_s004_t000.edf
+        subject_id = tokens[-1].split("_")[0]
+        session = tokens[-2].split("_")[0]  # string of format 's000'
+        # According to the example path in the comment above,
+        # the segment is not included in the file name
+        segment = tokens[-1].split("_")[-1].split(".")[0]  # TODO: test with tuh_eeg
+        year, month, day = tokens[-2].split("_")[1:]
+        return {
+            "path": file_path,
+            "version": version,
+            "year": int(year),
+            "month": int(month),
+            "day": int(day),
+            "subject": int(subject_id),
+            "session": int(session[1:]),
+            "segment": int(segment[1:]),
+        }
+
+
+def _read_physician_report(file_path):
+    directory = os.path.dirname(file_path)
+    txt_file = glob.glob(os.path.join(directory, "**/*.txt"), recursive=True)
+    # check that there is at most one txt file in the same directory
+    assert len(txt_file) in [0, 1]
+    report = ""
+    if txt_file:
+        txt_file = txt_file[0]
+        # somewhere in the corpus, the encoding apparently changed;
+        # first try to read as utf-8, if that does not work use latin-1
+        try:
+            with open(txt_file, "r", encoding="utf-8") as f:
+                report = f.read()
+        except UnicodeDecodeError:
+            with open(txt_file, "r", encoding="latin-1") as f:
+                report = f.read()
+    if not report:
+        raise RuntimeError(
+            f"Could not read physician report ({txt_file}). "
+            "Disable the option or choose an appropriate directory."
+        )
+    return report
+
+
+def _read_edf_header(file_path):
+    with open(file_path, "rb") as f:
+        header = f.read(88)
+    return header
+
+
+def _parse_age_and_gender_from_edf_header(file_path):
+    header = _read_edf_header(file_path)
+    # bytes 8 to 88 contain the ascii local patient identification
+    # see https://www.teuniz.net/edfbrowser/edf%20format%20description.html
+    patient_id = header[8:].decode("ascii")
+    age = -1
+    found_age = re.findall(r"Age:(\d+)", patient_id)
+    if len(found_age) == 1:
+        age = int(found_age[0])
+    gender = "X"
+    found_gender = re.findall(r"\s([FM])\s", patient_id)
+    if len(found_gender) == 1:
+        gender = found_gender[0]
+    return age, gender
+
+
+class TUHAbnormal(TUH):
+    """Temple University Hospital (TUH) Abnormal EEG Corpus.
+
+    see www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tuab
+
+    Parameters
+    ----------
+    path : str
+        Parent directory of the dataset.
+    recording_ids : list(int) | int
+        A (list of) int of recording id(s) to be read (order matters and will
+        overwrite the default chronological order; e.g. if recording_ids=[1, 0],
+        the first recording returned by this class will be chronologically
+        later than the second one. Provide recording_ids in ascending order
+        to preserve chronological order).
+    target_name : str
+        Can be 'pathological', 'gender', or 'age'.
+    preload : bool
+        If True, preload the data of the Raw objects.
+    add_physician_reports : bool
+        If True, the physician reports will be read from disk and added to the
+        description.
+    rename_channels : bool
+        If True, rename the EEG channels to the standard 10-05 system.
+    set_montage : bool
+        If True, set the montage to the standard 10-05 system.
+    n_jobs : int
+        Number of jobs to be used to read files in parallel.
+    """
+
+    def __init__(
+        self,
+        path: str,
+        recording_ids: list[int] | None = None,
+        target_name: str | tuple[str, ...] | None = "pathological",
+        preload: bool = False,
+        add_physician_reports: bool = False,
+        rename_channels: bool = False,
+        set_montage: bool = False,
+        n_jobs: int = 1,
+    ):
+        super().__init__(
+            path=path,
+            recording_ids=recording_ids,
+            preload=preload,
+            target_name=target_name,
+            add_physician_reports=add_physician_reports,
+            rename_channels=rename_channels,
+            set_montage=set_montage,
+            n_jobs=n_jobs,
+        )
+        additional_descriptions = []
+        for file_path in self.description.path:
+            additional_description = self._parse_additional_description_from_file_path(
+                file_path
+            )
+            additional_descriptions.append(additional_description)
+        additional_descriptions = pd.DataFrame(additional_descriptions)
+        self.set_description(additional_descriptions, overwrite=True)
+
+    @staticmethod
+    def _parse_additional_description_from_file_path(file_path):
+        file_path = os.path.normpath(file_path)
+        tokens = file_path.split(os.sep)
+        # expect paths as version/file type/data_split/pathology status/
+        # reference/subset/subject/recording session/file
+        # e.g. v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
+        # s004_2013_08_15/00000021_s004_t000.edf
+        assert "abnormal" in tokens or "normal" in tokens, "No pathology labels found."
+        assert "train" in tokens or "eval" in tokens, (
+            "No train or eval set information found."
+        )
+        return {
+            "version": tokens[-9],
+            "train": "train" in tokens,
+            "pathological": "abnormal" in tokens,
+        }
+
+
+def _fake_raw(*args, **kwargs):
+    sfreq = 10
+    ch_names = [
+        "EEG A1-REF",
+        "EEG A2-REF",
+        "EEG FP1-REF",
+        "EEG FP2-REF",
+        "EEG F3-REF",
+        "EEG F4-REF",
+        "EEG C3-REF",
+        "EEG C4-REF",
+        "EEG P3-REF",
+        "EEG P4-REF",
+        "EEG O1-REF",
+        "EEG O2-REF",
+        "EEG F7-REF",
+        "EEG F8-REF",
+        "EEG T3-REF",
+        "EEG T4-REF",
+        "EEG T5-REF",
+        "EEG T6-REF",
+        "EEG FZ-REF",
+        "EEG CZ-REF",
+        "EEG PZ-REF",
+    ]
+    duration_min = 6
+    data = np.random.randn(len(ch_names), duration_min * sfreq * 60)
+    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
+    raw = mne.io.RawArray(data=data, info=info)
+    return raw
+
+
+def _get_header(*args, **kwargs):
+    all_paths = {**_TUH_EEG_PATHS, **_TUH_EEG_ABNORMAL_PATHS}
+    return all_paths[args[0]]
+
+
+_TUH_EEG_PATHS = {
+    # These are actual file paths and edf headers from the TUH EEG Corpus (v1.1.0 and v1.2.0)
+    "tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001_2015_12_30/00000000_s001_t000.edf": b"0 00000000 M 01-JAN-1978 00000000 Age:37 ",  # noqa E501
+    "tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004_2014_09_30/00009932_s004_t013.edf": b"0 00009932 F 01-JAN-1961 00009932 Age:53 ",  # noqa E501
+    "tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001_2003_02_05/00000058_s001_t000.edf": b"0 00000058 M 01-JAN-2003 00000058 Age:0.0109 ",  # noqa E501
+    "tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s003_2014_12_14/00012331_s003_t002.edf": b"0 00012331 M 01-JAN-1975 00012331 Age:39 ",  # noqa E501
+    "tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s004_2016_01_15/00014928_s004_t007.edf": b"0 00014928 F 01-JAN-1933 00014928 Age:83 ",  # noqa E501
+}
+_TUH_EEG_ABNORMAL_PATHS = {
+    # these are actual file paths and edf headers from the TUH Abnormal EEG Corpus (v2.0.0)
+    "tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/078/00007871/s001_2011_07_05/00007871_s001_t001.edf": b"0 00007871 F 01-JAN-1988 00007871 Age:23 ",  # noqa E501
+    "tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/097/00009777/s001_2012_09_17/00009777_s001_t000.edf": b"0 00009777 M 01-JAN-1986 00009777 Age:26 ",  # noqa E501
+    "tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/083/00008393/s002_2012_02_21/00008393_s002_t000.edf": b"0 00008393 M 01-JAN-1960 00008393 Age:52 ",  # noqa E501
+    "tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/012/00001200/s003_2010_12_06/00001200_s003_t000.edf": b"0 00001200 M 01-JAN-1963 00001200 Age:47 ",  # noqa E501
+    "tuh_abnormal_eeg/v2.0.0/edf/eval/abnormal/01_tcp_ar/059/00005932/s004_2013_03_14/00005932_s004_t000.edf": b"0 00005932 M 01-JAN-1963 00005932 Age:50 ",  # noqa E501
+}
+
+
+class _TUHMock(TUH):
+    """Mocked class for testing and examples."""
+
+    @mock.patch("glob.glob", return_value=_TUH_EEG_PATHS.keys())
+    @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
+    @mock.patch("braindecode.datasets.tuh._read_edf_header", new=_get_header)
+    def __init__(
+        self,
+        mock_glob,
+        path: str,
+        recording_ids: list[int] | None = None,
+        target_name: str | tuple[str, ...] | None = None,
+        preload: bool = False,
+        add_physician_reports: bool = False,
+        rename_channels: bool = False,
+        set_montage: bool = False,
+        n_jobs: int = 1,
+    ):
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="Cannot save date file")
+            super().__init__(
+                path=path,
+                recording_ids=recording_ids,
+                target_name=target_name,
+                preload=preload,
+                add_physician_reports=add_physician_reports,
+                rename_channels=rename_channels,
+                set_montage=set_montage,
+                n_jobs=n_jobs,
+            )
+
+
+class _TUHAbnormalMock(TUHAbnormal):
+    """Mocked class for testing and examples."""
+
+    @mock.patch("glob.glob", return_value=_TUH_EEG_ABNORMAL_PATHS.keys())
+    @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
+    @mock.patch("braindecode.datasets.tuh._read_edf_header", new=_get_header)
+    @mock.patch(
+        "braindecode.datasets.tuh._read_physician_report", return_value="simple_test"
+    )
+    def __init__(
+        self,
+        mock_glob,
+        mock_report,
+        path: str,
+        recording_ids: list[int] | None = None,
+        target_name: str | tuple[str, ...] | None = "pathological",
+        preload: bool = False,
+        add_physician_reports: bool = False,
+        rename_channels: bool = False,
+        set_montage: bool = False,
+        n_jobs: int = 1,
+    ):
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="Cannot save date file")
+            super().__init__(
+                path=path,
+                recording_ids=recording_ids,
+                target_name=target_name,
+                preload=preload,
+                add_physician_reports=add_physician_reports,
+                rename_channels=rename_channels,
+                set_montage=set_montage,
+                n_jobs=n_jobs,
+            )
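For orientation, a minimal usage sketch of the dataset classes added above. The corpus path is hypothetical (TUH data must be obtained separately), and the import assumes the usual re-export from braindecode.datasets; the mocked variants _TUHMock/_TUHAbnormalMock accept the same keyword arguments for tests and examples.

from braindecode.datasets import TUHAbnormal

# Hypothetical path to a locally downloaded TUH Abnormal EEG Corpus.
TUH_PATH = "/data/tuh_eeg_abnormal/v3.0.0/edf/"

ds = TUHAbnormal(
    path=TUH_PATH,
    recording_ids=[0, 1, 2],     # or an int n to load the first n recordings
    target_name="pathological",  # 'pathological', 'age', or 'gender'
    preload=False,
    add_physician_reports=True,
    rename_channels=True,
    set_montage=True,            # requires rename_channels=True
    n_jobs=1,
)
print(ds.description[["subject", "session", "pathological"]].head())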
braindecode/datasets/utils.py
@@ -0,0 +1,67 @@
+"""Utility functions for dataset handling."""
+
+# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
+#
+# License: BSD (3-clause)
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+
+def _correct_dataset_path(
+    path: str, archive_name: str, subfolder_name: str | None = None
+) -> str:
+    """
+    Correct the dataset path after download and extraction.
+
+    This function handles two common post-download scenarios:
+    1. Renames '.unzip'-suffixed directories created by some extraction tools
+    2. Navigates into a subfolder if the archive extracts to a nested directory
+
+    Parameters
+    ----------
+    path : str
+        Expected path to the dataset directory.
+    archive_name : str
+        Name of the downloaded archive file without extension
+        (e.g., "chb_mit_bids", "NMT").
+    subfolder_name : str | None
+        Name of the subfolder within the extracted archive that contains the
+        actual data. If provided and the subfolder exists, the path will be
+        updated to point to it. If None, only renaming is attempted.
+        Default is None.
+
+    Returns
+    -------
+    str
+        The corrected path to the dataset directory.
+
+    Raises
+    ------
+    PermissionError
+        If the '.unzip' directory exists but cannot be renamed due to
+        insufficient permissions.
+    """
+    if not Path(path).exists():
+        unzip_file_name = f"{archive_name}.unzip"
+        if (Path(path).parent / unzip_file_name).exists():
+            try:
+                os.rename(
+                    src=Path(path).parent / unzip_file_name,
+                    dst=Path(path),
+                )
+            except PermissionError:
+                raise PermissionError(
+                    f"Please rename {Path(path).parent / unzip_file_name} "
+                    f"manually to {path} and try again."
+                )
+
+    # Check if the subfolder exists inside the path
+    if subfolder_name is not None:
+        subfolder_path = os.path.join(path, subfolder_name)
+        if Path(subfolder_path).exists():
+            path = subfolder_path
+
+    return path
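And a short sketch of how a dataset loader might call this helper after extracting a downloaded archive. The directory and subfolder names below are made up for illustration, and the function is private, so this mirrors internal use rather than a public API:

from braindecode.datasets.utils import _correct_dataset_path

# If extraction produced "/data/NMT.unzip" instead of "/data/NMT", the helper
# renames the directory; if the data lives in a nested subfolder, it then
# descends into it and returns the usable dataset root.
data_path = _correct_dataset_path(
    path="/data/NMT",
    archive_name="NMT",
    subfolder_name="nmt_scalp_eeg_dataset",  # hypothetical nested folder
)
print(data_path)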