braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic; see the registry's release page for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/datasets/tuh.py
CHANGED
|
@@ -7,20 +7,22 @@ TUH Abnormal EEG Corpus.
|
|
|
7
7
|
#
|
|
8
8
|
# License: BSD (3-clause)
|
|
9
9
|
|
|
10
|
-
import
|
|
11
|
-
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
12
|
import glob
|
|
13
|
+
import os
|
|
14
|
+
import re
|
|
13
15
|
import warnings
|
|
14
|
-
from unittest import mock
|
|
15
16
|
from datetime import datetime, timezone
|
|
16
17
|
from typing import Iterable
|
|
18
|
+
from unittest import mock
|
|
17
19
|
|
|
18
|
-
import pandas as pd
|
|
19
|
-
import numpy as np
|
|
20
20
|
import mne
|
|
21
|
+
import numpy as np
|
|
22
|
+
import pandas as pd
|
|
21
23
|
from joblib import Parallel, delayed
|
|
22
24
|
|
|
23
|
-
from .base import
|
|
25
|
+
from .base import BaseConcatDataset, BaseDataset
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
class TUH(BaseConcatDataset):
|
|
@@ -44,14 +46,32 @@ class TUH(BaseConcatDataset):
|
|
|
44
46
|
add_physician_reports: bool
|
|
45
47
|
If True, the physician reports will be read from disk and added to the
|
|
46
48
|
description.
|
|
49
|
+
rename_channels: bool
|
|
50
|
+
If True, rename the EEG channels to the standard 10-05 system.
|
|
51
|
+
set_montage: bool
|
|
52
|
+
If True, set the montage to the standard 10-05 system.
|
|
47
53
|
n_jobs: int
|
|
48
54
|
Number of jobs to be used to read files in parallel.
|
|
49
55
|
"""
|
|
50
|
-
|
|
51
|
-
|
|
56
|
+
|
|
57
|
+
def __init__(
|
|
58
|
+
self,
|
|
59
|
+
path: str,
|
|
60
|
+
recording_ids: list[int] | None = None,
|
|
61
|
+
target_name: str | tuple[str, ...] | None = None,
|
|
62
|
+
preload: bool = False,
|
|
63
|
+
add_physician_reports: bool = False,
|
|
64
|
+
rename_channels: bool = False,
|
|
65
|
+
set_montage: bool = False,
|
|
66
|
+
n_jobs: int = 1,
|
|
67
|
+
):
|
|
68
|
+
if set_montage:
|
|
69
|
+
assert rename_channels, (
|
|
70
|
+
"If set_montage is True, rename_channels must be True."
|
|
71
|
+
)
|
|
52
72
|
# create an index of all files and gather easily accessible info
|
|
53
73
|
# without actually touching the files
|
|
54
|
-
file_paths = glob.glob(os.path.join(path,
|
|
74
|
+
file_paths = glob.glob(os.path.join(path, "**/*.edf"), recursive=True)
|
|
55
75
|
descriptions = _create_description(file_paths)
|
|
56
76
|
# sort the descriptions chronologicaly
|
|
57
77
|
descriptions = _sort_chronologically(descriptions)
|
|
@@ -62,59 +82,139 @@ class TUH(BaseConcatDataset):
|
|
|
62
82
|
# of recordings to load
|
|
63
83
|
recording_ids = range(recording_ids)
|
|
64
84
|
descriptions = descriptions[recording_ids]
|
|
85
|
+
|
|
86
|
+
# workaround to ensure warnings are suppressed when running in parallel
|
|
87
|
+
def create_dataset(*args, **kwargs):
|
|
88
|
+
with warnings.catch_warnings():
|
|
89
|
+
warnings.filterwarnings(
|
|
90
|
+
"ignore", message=".*not in description. '__getitem__'"
|
|
91
|
+
)
|
|
92
|
+
return self._create_dataset(*args, **kwargs)
|
|
93
|
+
|
|
65
94
|
# this is the second loop (slow)
|
|
66
95
|
# create datasets gathering more info about the files touching them
|
|
67
96
|
# reading the raws and potentially preloading the data
|
|
68
97
|
# disable joblib for tests. mocking seems to fail otherwise
|
|
69
98
|
if n_jobs == 1:
|
|
70
|
-
base_datasets = [
|
|
71
|
-
|
|
72
|
-
|
|
99
|
+
base_datasets = [
|
|
100
|
+
create_dataset(
|
|
101
|
+
descriptions[i],
|
|
102
|
+
target_name,
|
|
103
|
+
preload,
|
|
104
|
+
add_physician_reports,
|
|
105
|
+
rename_channels,
|
|
106
|
+
set_montage,
|
|
107
|
+
)
|
|
108
|
+
for i in descriptions.columns
|
|
109
|
+
]
|
|
73
110
|
else:
|
|
74
|
-
base_datasets = Parallel(n_jobs)(
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
111
|
+
base_datasets = Parallel(n_jobs)(
|
|
112
|
+
delayed(create_dataset)(
|
|
113
|
+
descriptions[i],
|
|
114
|
+
target_name,
|
|
115
|
+
preload,
|
|
116
|
+
add_physician_reports,
|
|
117
|
+
rename_channels,
|
|
118
|
+
set_montage,
|
|
119
|
+
)
|
|
120
|
+
for i in descriptions.columns
|
|
121
|
+
)
|
|
78
122
|
super().__init__(base_datasets)
|
|
79
123
|
|
|
80
124
|
@staticmethod
|
|
81
|
-
def
|
|
82
|
-
|
|
83
|
-
|
|
125
|
+
def _rename_channels(raw):
|
|
126
|
+
"""
|
|
127
|
+
Renames the EEG channels using mne conventions and sets their type to 'eeg'.
|
|
128
|
+
|
|
129
|
+
See https://isip.piconepress.com/publications/reports/2020/tuh_eeg/electrodes/
|
|
130
|
+
"""
|
|
131
|
+
# remove ref suffix and prefix:
|
|
132
|
+
# TODO: replace with removesuffix and removeprefix when 3.8 is dropped
|
|
133
|
+
mapping_strip = {
|
|
134
|
+
c: c.replace("-REF", "").replace("-LE", "").replace("EEG ", "")
|
|
135
|
+
for c in raw.ch_names
|
|
136
|
+
}
|
|
137
|
+
raw.rename_channels(mapping_strip)
|
|
138
|
+
|
|
139
|
+
montage1005 = mne.channels.make_standard_montage("standard_1005")
|
|
140
|
+
mapping_eeg_names = {
|
|
141
|
+
c.upper(): c for c in montage1005.ch_names if c.upper() in raw.ch_names
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
# Set channels whose type could not be inferred (defaulted to "eeg") to "misc":
|
|
145
|
+
non_eeg_names = [c for c in raw.ch_names if c not in mapping_eeg_names]
|
|
146
|
+
if non_eeg_names:
|
|
147
|
+
non_eeg_types = raw.get_channel_types(picks=non_eeg_names)
|
|
148
|
+
mapping_non_eeg_types = {
|
|
149
|
+
c: "misc" for c, t in zip(non_eeg_names, non_eeg_types) if t == "eeg"
|
|
150
|
+
}
|
|
151
|
+
if mapping_non_eeg_types:
|
|
152
|
+
raw.set_channel_types(mapping_non_eeg_types)
|
|
153
|
+
|
|
154
|
+
if mapping_eeg_names:
|
|
155
|
+
# Set 1005 channels type to "eeg":
|
|
156
|
+
raw.set_channel_types(
|
|
157
|
+
{c: "eeg" for c in mapping_eeg_names}, on_unit_change="ignore"
|
|
158
|
+
)
|
|
159
|
+
# Fix capitalized EEG channel names:
|
|
160
|
+
raw.rename_channels(mapping_eeg_names)
|
|
161
|
+
|
|
162
|
+
@staticmethod
|
|
163
|
+
def _set_montage(raw):
|
|
164
|
+
montage = mne.channels.make_standard_montage("standard_1005")
|
|
165
|
+
raw.set_montage(montage, on_missing="ignore")
|
|
166
|
+
|
|
167
|
+
@staticmethod
|
|
168
|
+
def _create_dataset(
|
|
169
|
+
description,
|
|
170
|
+
target_name,
|
|
171
|
+
preload,
|
|
172
|
+
add_physician_reports,
|
|
173
|
+
rename_channels,
|
|
174
|
+
set_montage,
|
|
175
|
+
):
|
|
176
|
+
file_path = description.loc["path"]
|
|
84
177
|
|
|
85
178
|
# parse age and gender information from EDF header
|
|
86
179
|
age, gender = _parse_age_and_gender_from_edf_header(file_path)
|
|
87
|
-
raw = mne.io.read_raw_edf(
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
180
|
+
raw = mne.io.read_raw_edf(
|
|
181
|
+
file_path, preload=preload, infer_types=True, verbose="error"
|
|
182
|
+
)
|
|
183
|
+
if rename_channels:
|
|
184
|
+
TUH._rename_channels(raw)
|
|
185
|
+
if set_montage:
|
|
186
|
+
TUH._set_montage(raw)
|
|
187
|
+
|
|
188
|
+
meas_date = (
|
|
189
|
+
datetime(1, 1, 1, tzinfo=timezone.utc)
|
|
190
|
+
if raw.info["meas_date"] is None
|
|
191
|
+
else raw.info["meas_date"]
|
|
192
|
+
)
|
|
91
193
|
# if this is old version of the data and the year could be parsed from
|
|
92
194
|
# file paths, use this instead as before
|
|
93
|
-
if
|
|
94
|
-
meas_date = meas_date.replace(
|
|
95
|
-
*description[['year', 'month', 'day']])
|
|
195
|
+
if "year" in description:
|
|
196
|
+
meas_date = meas_date.replace(*description[["year", "month", "day"]])
|
|
96
197
|
raw.set_meas_date(meas_date)
|
|
97
198
|
|
|
98
199
|
d = {
|
|
99
|
-
|
|
100
|
-
|
|
200
|
+
"age": int(age),
|
|
201
|
+
"gender": gender,
|
|
101
202
|
}
|
|
102
203
|
# if year exists in description = old version
|
|
103
204
|
# if not, get it from meas_date in raw.info and add to description
|
|
104
205
|
# if meas_date is None, create fake one
|
|
105
|
-
if
|
|
106
|
-
d[
|
|
107
|
-
d[
|
|
108
|
-
d[
|
|
206
|
+
if "year" not in description:
|
|
207
|
+
d["year"] = raw.info["meas_date"].year
|
|
208
|
+
d["month"] = raw.info["meas_date"].month
|
|
209
|
+
d["day"] = raw.info["meas_date"].day
|
|
109
210
|
|
|
110
211
|
# read info relevant for preprocessing from raw without loading it
|
|
111
212
|
if add_physician_reports:
|
|
112
213
|
physician_report = _read_physician_report(file_path)
|
|
113
|
-
d[
|
|
214
|
+
d["report"] = physician_report
|
|
114
215
|
additional_description = pd.Series(d)
|
|
115
216
|
description = pd.concat([description, additional_description])
|
|
116
|
-
base_dataset = BaseDataset(raw, description,
|
|
117
|
-
target_name=target_name)
|
|
217
|
+
base_dataset = BaseDataset(raw, description, target_name=target_name)
|
|
118
218
|
return base_dataset
|
|
119
219
|
|
|
120
220
|
|
|
@@ -126,30 +226,32 @@ def _create_description(file_paths):
|
|
|
126
226
|
|
|
127
227
|
def _sort_chronologically(descriptions):
|
|
128
228
|
descriptions.sort_values(
|
|
129
|
-
["year", "month", "day", "subject", "session", "segment"],
|
|
130
|
-
|
|
229
|
+
["year", "month", "day", "subject", "session", "segment"], axis=1, inplace=True
|
|
230
|
+
)
|
|
131
231
|
return descriptions
|
|
132
232
|
|
|
133
233
|
|
|
134
234
|
def _read_date(file_path):
|
|
135
|
-
date_path = file_path.replace(
|
|
235
|
+
date_path = file_path.replace(".edf", "_date.txt")
|
|
136
236
|
# if date file exists, read it
|
|
137
237
|
if os.path.exists(date_path):
|
|
138
|
-
description = pd.read_json(date_path, typ=
|
|
238
|
+
description = pd.read_json(date_path, typ="series").to_dict()
|
|
139
239
|
# otherwise read edf file, extract date and store to file
|
|
140
240
|
else:
|
|
141
|
-
raw = mne.io.read_raw_edf(file_path, preload=False, verbose=
|
|
241
|
+
raw = mne.io.read_raw_edf(file_path, preload=False, verbose="error")
|
|
142
242
|
description = {
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
243
|
+
"year": raw.info["meas_date"].year,
|
|
244
|
+
"month": raw.info["meas_date"].month,
|
|
245
|
+
"day": raw.info["meas_date"].day,
|
|
146
246
|
}
|
|
147
247
|
# if the txt file storing the recording date does not exist, create it
|
|
148
248
|
try:
|
|
149
249
|
pd.Series(description).to_json(date_path)
|
|
150
250
|
except OSError:
|
|
151
|
-
warnings.warn(
|
|
152
|
-
|
|
251
|
+
warnings.warn(
|
|
252
|
+
f"Cannot save date file to {date_path}. "
|
|
253
|
+
f"This might slow down creation of the dataset."
|
|
254
|
+
)
|
|
153
255
|
return description
|
|
154
256
|
|
|
155
257
|
|
|
@@ -158,12 +260,12 @@ def _parse_description_from_file_path(file_path):
|
|
|
158
260
|
file_path = os.path.normpath(file_path)
|
|
159
261
|
tokens = file_path.split(os.sep)
|
|
160
262
|
# Extract version number and tuh_eeg_abnormal/tuh_eeg from file path
|
|
161
|
-
if (
|
|
263
|
+
if ("train" in tokens) or ("eval" in tokens): # tuh_eeg_abnormal
|
|
162
264
|
abnormal = True
|
|
163
265
|
# Tokens[-2] is channel configuration (always 01_tcp_ar in abnormal)
|
|
164
266
|
# on new versions, or
|
|
165
267
|
# session (e.g. s004_2013_08_15) on old versions
|
|
166
|
-
if tokens[-2].split(
|
|
268
|
+
if tokens[-2].split("_")[0][0] == "s": # s denoting session number
|
|
167
269
|
version = tokens[-9] # Before dec 2022 updata
|
|
168
270
|
else:
|
|
169
271
|
version = tokens[-6] # After the dec 2022 update
|
|
@@ -181,22 +283,24 @@ def _parse_description_from_file_path(file_path):
|
|
|
181
283
|
# or for abnormal:
|
|
182
284
|
# tuh_eeg_abnormal/v3.0.0/edf/train/normal/
|
|
183
285
|
# 01_tcp_ar/aaaaaaav_s004_t000.edf
|
|
184
|
-
subject_id = tokens[-1].split(
|
|
185
|
-
session = tokens[-1].split(
|
|
186
|
-
segment = tokens[-1].split(
|
|
286
|
+
subject_id = tokens[-1].split("_")[0]
|
|
287
|
+
session = tokens[-1].split("_")[1]
|
|
288
|
+
segment = tokens[-1].split("_")[2].split(".")[0]
|
|
187
289
|
description = _read_date(file_path)
|
|
188
|
-
description.update(
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
290
|
+
description.update(
|
|
291
|
+
{
|
|
292
|
+
"path": file_path,
|
|
293
|
+
"version": version,
|
|
294
|
+
"subject": subject_id,
|
|
295
|
+
"session": int(session[1:]),
|
|
296
|
+
"segment": int(segment[1:]),
|
|
297
|
+
}
|
|
298
|
+
)
|
|
195
299
|
if not abnormal:
|
|
196
|
-
year, month, day = tokens[-3].split(
|
|
197
|
-
description[
|
|
198
|
-
description[
|
|
199
|
-
description[
|
|
300
|
+
year, month, day = tokens[-3].split("_")[1:]
|
|
301
|
+
description["year"] = int(year)
|
|
302
|
+
description["month"] = int(month)
|
|
303
|
+
description["day"] = int(day)
|
|
200
304
|
return description
|
|
201
305
|
else: # Old file path structure
|
|
202
306
|
# expect file paths as tuh_eeg/version/file_type/reference/data_split/
|
|
@@ -208,43 +312,45 @@ def _parse_description_from_file_path(file_path):
|
|
|
208
312
|
# reference/subset/subject/recording session/file
|
|
209
313
|
# v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
|
|
210
314
|
# s004_2013_08_15/00000021_s004_t000.edf
|
|
211
|
-
subject_id = tokens[-1].split(
|
|
212
|
-
session = tokens[-2].split(
|
|
315
|
+
subject_id = tokens[-1].split("_")[0]
|
|
316
|
+
session = tokens[-2].split("_")[0] # string on format 's000'
|
|
213
317
|
# According to the example path in the comment 8 lines above,
|
|
214
318
|
# segment is not included in the file name
|
|
215
|
-
segment = tokens[-1].split(
|
|
216
|
-
year, month, day = tokens[-2].split(
|
|
319
|
+
segment = tokens[-1].split("_")[-1].split(".")[0] # TODO: test with tuh_eeg
|
|
320
|
+
year, month, day = tokens[-2].split("_")[1:]
|
|
217
321
|
return {
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
322
|
+
"path": file_path,
|
|
323
|
+
"version": version,
|
|
324
|
+
"year": int(year),
|
|
325
|
+
"month": int(month),
|
|
326
|
+
"day": int(day),
|
|
327
|
+
"subject": int(subject_id),
|
|
328
|
+
"session": int(session[1:]),
|
|
329
|
+
"segment": int(segment[1:]),
|
|
226
330
|
}
|
|
227
331
|
|
|
228
332
|
|
|
229
333
|
def _read_physician_report(file_path):
|
|
230
334
|
directory = os.path.dirname(file_path)
|
|
231
|
-
txt_file = glob.glob(os.path.join(directory,
|
|
335
|
+
txt_file = glob.glob(os.path.join(directory, "**/*.txt"), recursive=True)
|
|
232
336
|
# check that there is at most one txt file in the same directory
|
|
233
337
|
assert len(txt_file) in [0, 1]
|
|
234
|
-
report =
|
|
338
|
+
report = ""
|
|
235
339
|
if txt_file:
|
|
236
340
|
txt_file = txt_file[0]
|
|
237
341
|
# somewhere in the corpus, encoding apparently changed
|
|
238
342
|
# first try to read as utf-8, if it does not work use latin-1
|
|
239
343
|
try:
|
|
240
|
-
with open(txt_file,
|
|
344
|
+
with open(txt_file, "r", encoding="utf-8") as f:
|
|
241
345
|
report = f.read()
|
|
242
346
|
except UnicodeDecodeError:
|
|
243
|
-
with open(txt_file,
|
|
347
|
+
with open(txt_file, "r", encoding="latin-1") as f:
|
|
244
348
|
report = f.read()
|
|
245
349
|
if not report:
|
|
246
|
-
raise RuntimeError(
|
|
247
|
-
|
|
350
|
+
raise RuntimeError(
|
|
351
|
+
f"Could not read physician report ({txt_file}). "
|
|
352
|
+
f"Disable option or choose appropriate directory."
|
|
353
|
+
)
|
|
248
354
|
return report
|
|
249
355
|
|
|
250
356
|
|
|
@@ -292,20 +398,40 @@ class TUHAbnormal(TUH):
|
|
|
292
398
|
add_physician_reports: bool
|
|
293
399
|
If True, the physician reports will be read from disk and added to the
|
|
294
400
|
description.
|
|
401
|
+
rename_channels: bool
|
|
402
|
+
If True, rename the EEG channels to the standard 10-05 system.
|
|
403
|
+
set_montage: bool
|
|
404
|
+
If True, set the montage to the standard 10-05 system.
|
|
405
|
+
n_jobs: int
|
|
406
|
+
Number of jobs to be used to read files in parallel.
|
|
295
407
|
"""
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
408
|
+
|
|
409
|
+
def __init__(
|
|
410
|
+
self,
|
|
411
|
+
path: str,
|
|
412
|
+
recording_ids: list[int] | None = None,
|
|
413
|
+
target_name: str | tuple[str, ...] | None = "pathological",
|
|
414
|
+
preload: bool = False,
|
|
415
|
+
add_physician_reports: bool = False,
|
|
416
|
+
rename_channels: bool = False,
|
|
417
|
+
set_montage: bool = False,
|
|
418
|
+
n_jobs: int = 1,
|
|
419
|
+
):
|
|
420
|
+
super().__init__(
|
|
421
|
+
path=path,
|
|
422
|
+
recording_ids=recording_ids,
|
|
423
|
+
preload=preload,
|
|
424
|
+
target_name=target_name,
|
|
425
|
+
add_physician_reports=add_physician_reports,
|
|
426
|
+
rename_channels=rename_channels,
|
|
427
|
+
set_montage=set_montage,
|
|
428
|
+
n_jobs=n_jobs,
|
|
429
|
+
)
|
|
305
430
|
additional_descriptions = []
|
|
306
431
|
for file_path in self.description.path:
|
|
307
|
-
additional_description = (
|
|
308
|
-
|
|
432
|
+
additional_description = self._parse_additional_description_from_file_path(
|
|
433
|
+
file_path
|
|
434
|
+
)
|
|
309
435
|
additional_descriptions.append(additional_description)
|
|
310
436
|
additional_descriptions = pd.DataFrame(additional_descriptions)
|
|
311
437
|
self.set_description(additional_descriptions, overwrite=True)
|
|
@@ -318,28 +444,45 @@ class TUHAbnormal(TUH):
|
|
|
318
444
|
# reference/subset/subject/recording session/file
|
|
319
445
|
# e.g. v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
|
|
320
446
|
# s004_2013_08_15/00000021_s004_t000.edf
|
|
321
|
-
assert
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
447
|
+
assert "abnormal" in tokens or "normal" in tokens, "No pathology labels found."
|
|
448
|
+
assert "train" in tokens or "eval" in tokens, (
|
|
449
|
+
"No train or eval set information found."
|
|
450
|
+
)
|
|
325
451
|
return {
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
452
|
+
"version": tokens[-9],
|
|
453
|
+
"train": "train" in tokens,
|
|
454
|
+
"pathological": "abnormal" in tokens,
|
|
329
455
|
}
|
|
330
456
|
|
|
331
457
|
|
|
332
458
|
def _fake_raw(*args, **kwargs):
|
|
333
459
|
sfreq = 10
|
|
334
460
|
ch_names = [
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
461
|
+
"EEG A1-REF",
|
|
462
|
+
"EEG A2-REF",
|
|
463
|
+
"EEG FP1-REF",
|
|
464
|
+
"EEG FP2-REF",
|
|
465
|
+
"EEG F3-REF",
|
|
466
|
+
"EEG F4-REF",
|
|
467
|
+
"EEG C3-REF",
|
|
468
|
+
"EEG C4-REF",
|
|
469
|
+
"EEG P3-REF",
|
|
470
|
+
"EEG P4-REF",
|
|
471
|
+
"EEG O1-REF",
|
|
472
|
+
"EEG O2-REF",
|
|
473
|
+
"EEG F7-REF",
|
|
474
|
+
"EEG F8-REF",
|
|
475
|
+
"EEG T3-REF",
|
|
476
|
+
"EEG T4-REF",
|
|
477
|
+
"EEG T5-REF",
|
|
478
|
+
"EEG T6-REF",
|
|
479
|
+
"EEG FZ-REF",
|
|
480
|
+
"EEG CZ-REF",
|
|
481
|
+
"EEG PZ-REF",
|
|
482
|
+
]
|
|
340
483
|
duration_min = 6
|
|
341
484
|
data = np.random.randn(len(ch_names), duration_min * sfreq * 60)
|
|
342
|
-
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=
|
|
485
|
+
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
|
|
343
486
|
raw = mne.io.RawArray(data=data, info=info)
|
|
344
487
|
return raw
|
|
345
488
|
|
|
@@ -351,54 +494,95 @@ def _get_header(*args, **kwargs):
|
|
|
351
494
|
|
|
352
495
|
_TUH_EEG_PATHS = {
|
|
353
496
|
# These are actual file paths and edf headers from the TUH EEG Corpus (v1.1.0 and v1.2.0)
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
497
|
+
"tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001_2015_12_30/00000000_s001_t000.edf": b"0 00000000 M 01-JAN-1978 00000000 Age:37 ",
|
|
498
|
+
# noqa E501
|
|
499
|
+
"tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004_2014_09_30/00009932_s004_t013.edf": b"0 00009932 F 01-JAN-1961 00009932 Age:53 ",
|
|
500
|
+
# noqa E501
|
|
501
|
+
"tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001_2003_02_05/00000058_s001_t000.edf": b"0 00000058 M 01-JAN-2003 00000058 Age:0.0109 ",
|
|
502
|
+
# noqa E501
|
|
503
|
+
"tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s003_2014_12_14/00012331_s003_t002.edf": b"0 00012331 M 01-JAN-1975 00012331 Age:39 ",
|
|
504
|
+
# noqa E501
|
|
505
|
+
"tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s004_2016_01_15/00014928_s004_t007.edf": b"0 00014928 F 01-JAN-1933 00014928 Age:83 ",
|
|
506
|
+
# noqa E501
|
|
359
507
|
}
|
|
360
508
|
_TUH_EEG_ABNORMAL_PATHS = {
|
|
361
509
|
# these are actual file paths and edf headers from TUH Abnormal EEG Corpus (v2.0.0)
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
510
|
+
"tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/078/00007871/s001_2011_07_05/00007871_s001_t001.edf": b"0 00007871 F 01-JAN-1988 00007871 Age:23 ",
|
|
511
|
+
# noqa E501
|
|
512
|
+
"tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/097/00009777/s001_2012_09_17/00009777_s001_t000.edf": b"0 00009777 M 01-JAN-1986 00009777 Age:26 ",
|
|
513
|
+
# noqa E501
|
|
514
|
+
"tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/083/00008393/s002_2012_02_21/00008393_s002_t000.edf": b"0 00008393 M 01-JAN-1960 00008393 Age:52 ",
|
|
515
|
+
# noqa E501
|
|
516
|
+
"tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/012/00001200/s003_2010_12_06/00001200_s003_t000.edf": b"0 00001200 M 01-JAN-1963 00001200 Age:47 ",
|
|
517
|
+
# noqa E501
|
|
518
|
+
"tuh_abnormal_eeg/v2.0.0/edf/eval/abnormal/01_tcp_ar/059/00005932/s004_2013_03_14/00005932_s004_t000.edf": b"0 00005932 M 01-JAN-1963 00005932 Age:50 ",
|
|
519
|
+
# noqa E501
|
|
367
520
|
}
|
|
368
521
|
|
|
369
522
|
|
|
370
523
|
class _TUHMock(TUH):
|
|
371
524
|
"""Mocked class for testing and examples."""
|
|
372
|
-
|
|
373
|
-
@mock.patch(
|
|
374
|
-
@mock.patch(
|
|
375
|
-
|
|
376
|
-
def __init__(
|
|
377
|
-
|
|
525
|
+
|
|
526
|
+
@mock.patch("glob.glob", return_value=_TUH_EEG_PATHS.keys())
|
|
527
|
+
@mock.patch("mne.io.read_raw_edf", new=_fake_raw)
|
|
528
|
+
@mock.patch("braindecode.datasets.tuh._read_edf_header", new=_get_header)
|
|
529
|
+
def __init__(
|
|
530
|
+
self,
|
|
531
|
+
mock_glob,
|
|
532
|
+
path: str,
|
|
533
|
+
recording_ids: list[int] | None = None,
|
|
534
|
+
target_name: str | tuple[str, ...] | None = None,
|
|
535
|
+
preload: bool = False,
|
|
536
|
+
add_physician_reports: bool = False,
|
|
537
|
+
rename_channels: bool = False,
|
|
538
|
+
set_montage: bool = False,
|
|
539
|
+
n_jobs: int = 1,
|
|
540
|
+
):
|
|
378
541
|
with warnings.catch_warnings():
|
|
379
|
-
warnings.filterwarnings(
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
542
|
+
warnings.filterwarnings("ignore", message="Cannot save date file")
|
|
543
|
+
super().__init__(
|
|
544
|
+
path=path,
|
|
545
|
+
recording_ids=recording_ids,
|
|
546
|
+
target_name=target_name,
|
|
547
|
+
preload=preload,
|
|
548
|
+
add_physician_reports=add_physician_reports,
|
|
549
|
+
rename_channels=rename_channels,
|
|
550
|
+
set_montage=set_montage,
|
|
551
|
+
n_jobs=n_jobs,
|
|
552
|
+
)
|
|
385
553
|
|
|
386
554
|
|
|
387
555
|
class _TUHAbnormalMock(TUHAbnormal):
|
|
388
556
|
"""Mocked class for testing and examples."""
|
|
389
|
-
|
|
390
|
-
@mock.patch(
|
|
391
|
-
@mock.patch(
|
|
392
|
-
|
|
393
|
-
@mock.patch(
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
557
|
+
|
|
558
|
+
@mock.patch("glob.glob", return_value=_TUH_EEG_ABNORMAL_PATHS.keys())
|
|
559
|
+
@mock.patch("mne.io.read_raw_edf", new=_fake_raw)
|
|
560
|
+
@mock.patch("braindecode.datasets.tuh._read_edf_header", new=_get_header)
|
|
561
|
+
@mock.patch(
|
|
562
|
+
"braindecode.datasets.tuh._read_physician_report", return_value="simple_test"
|
|
563
|
+
)
|
|
564
|
+
def __init__(
|
|
565
|
+
self,
|
|
566
|
+
mock_glob,
|
|
567
|
+
mock_report,
|
|
568
|
+
path: str,
|
|
569
|
+
recording_ids: list[int] | None = None,
|
|
570
|
+
target_name: str | tuple[str, ...] | None = "pathological",
|
|
571
|
+
preload: bool = False,
|
|
572
|
+
add_physician_reports: bool = False,
|
|
573
|
+
rename_channels: bool = False,
|
|
574
|
+
set_montage: bool = False,
|
|
575
|
+
n_jobs: int = 1,
|
|
576
|
+
):
|
|
398
577
|
with warnings.catch_warnings():
|
|
399
|
-
warnings.filterwarnings(
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
578
|
+
warnings.filterwarnings("ignore", message="Cannot save date file")
|
|
579
|
+
super().__init__(
|
|
580
|
+
path=path,
|
|
581
|
+
recording_ids=recording_ids,
|
|
582
|
+
target_name=target_name,
|
|
583
|
+
preload=preload,
|
|
584
|
+
add_physician_reports=add_physician_reports,
|
|
585
|
+
rename_channels=rename_channels,
|
|
586
|
+
set_montage=set_montage,
|
|
587
|
+
n_jobs=n_jobs,
|
|
588
|
+
)
|