braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/datasets/tuh.py
@@ -0,0 +1,591 @@
+ """
+ Dataset classes for the Temple University Hospital (TUH) EEG Corpus and the
+ TUH Abnormal EEG Corpus.
+ """
+
+ # Authors: Lukas Gemein <l.gemein@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import glob
+ import os
+ import re
+ import warnings
+ from datetime import datetime, timezone
+ from typing import Iterable
+ from unittest import mock
+
+ import mne
+ import numpy as np
+ import pandas as pd
+ from joblib import Parallel, delayed
+
+ from .base import BaseConcatDataset, RawDataset
+
+
+ class TUH(BaseConcatDataset):
+     """Temple University Hospital (TUH) EEG Corpus.
+
+     (www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tueg).
+
+     Parameters
+     ----------
+     path : str
+         Parent directory of the dataset.
+     recording_ids : list(int) | int
+         A (list of) int of recording id(s) to be read (order matters and will
+         override the default chronological order, e.g. if recording_ids=[1,0],
+         then the first recording returned by this class will be chronologically
+         later than the second recording). Provide recording_ids in ascending
+         order to preserve chronological order.
+     target_name : str
+         Can be 'gender' or 'age'.
+     preload : bool
+         If True, preload the data of the Raw objects.
+     add_physician_reports : bool
+         If True, the physician reports will be read from disk and added to the
+         description.
+     rename_channels : bool
+         If True, rename the EEG channels to the standard 10-05 system.
+     set_montage : bool
+         If True, set the montage to the standard 10-05 system.
+     n_jobs : int
+         Number of jobs to be used to read files in parallel.
+     """
+
+     def __init__(
+         self,
+         path: str,
+         recording_ids: list[int] | None = None,
+         target_name: str | tuple[str, ...] | None = None,
+         preload: bool = False,
+         add_physician_reports: bool = False,
+         rename_channels: bool = False,
+         set_montage: bool = False,
+         n_jobs: int = 1,
+     ):
+         if set_montage:
+             assert rename_channels, (
+                 "If set_montage is True, rename_channels must be True."
+             )
+         # create an index of all files and gather easily accessible info
+         # without actually touching the files
+         file_paths = glob.glob(os.path.join(path, "**/*.edf"), recursive=True)
+         descriptions = _create_description(file_paths)
+         # sort the descriptions chronologically
+         descriptions = _sort_chronologically(descriptions)
+         # limit to specified recording ids before doing slow stuff
+         if recording_ids is not None:
+             if not isinstance(recording_ids, Iterable):
+                 # Assume it is an integer specifying the number
+                 # of recordings to load
+                 recording_ids = range(recording_ids)
+             descriptions = descriptions[recording_ids]
+
+         # workaround to ensure warnings are suppressed when running in parallel
+         def create_dataset(*args, **kwargs):
+             with warnings.catch_warnings():
+                 warnings.filterwarnings(
+                     "ignore", message=".*not in description. '__getitem__'"
+                 )
+                 return self._create_dataset(*args, **kwargs)
+
+         # this is the second loop (slow):
+         # create datasets, gathering more info about the files by touching
+         # them, reading the raws and potentially preloading the data
+         # disable joblib for tests; mocking seems to fail otherwise
+         if n_jobs == 1:
+             base_datasets = [
+                 create_dataset(
+                     descriptions[i],
+                     target_name,
+                     preload,
+                     add_physician_reports,
+                     rename_channels,
+                     set_montage,
+                 )
+                 for i in descriptions.columns
+             ]
+         else:
+             base_datasets = Parallel(n_jobs)(
+                 delayed(create_dataset)(
+                     descriptions[i],
+                     target_name,
+                     preload,
+                     add_physician_reports,
+                     rename_channels,
+                     set_montage,
+                 )
+                 for i in descriptions.columns
+             )
+         super().__init__(base_datasets)
+
+     @staticmethod
+     def _rename_channels(raw):
+         """
+         Renames the EEG channels using mne conventions and sets their type to 'eeg'.
+
+         See https://isip.piconepress.com/publications/reports/2020/tuh_eeg/electrodes/
+         """
+         # remove ref suffix and prefix:
+         # TODO: replace with removesuffix and removeprefix when 3.8 is dropped
+         mapping_strip = {
+             c: c.replace("-REF", "").replace("-LE", "").replace("EEG ", "")
+             for c in raw.ch_names
+         }
+         raw.rename_channels(mapping_strip)
+
+         montage1005 = mne.channels.make_standard_montage("standard_1005")
+         mapping_eeg_names = {
+             c.upper(): c for c in montage1005.ch_names if c.upper() in raw.ch_names
+         }
+
+         # Set channels whose type could not be inferred (defaulted to "eeg") to "misc":
+         non_eeg_names = [c for c in raw.ch_names if c not in mapping_eeg_names]
+         if non_eeg_names:
+             non_eeg_types = raw.get_channel_types(picks=non_eeg_names)
+             mapping_non_eeg_types = {
+                 c: "misc" for c, t in zip(non_eeg_names, non_eeg_types) if t == "eeg"
+             }
+             if mapping_non_eeg_types:
+                 raw.set_channel_types(mapping_non_eeg_types)
+
+         if mapping_eeg_names:
+             # Set 1005 channels type to "eeg":
+             raw.set_channel_types(
+                 {c: "eeg" for c in mapping_eeg_names}, on_unit_change="ignore"
+             )
+             # Fix capitalized EEG channel names:
+             raw.rename_channels(mapping_eeg_names)
+
+     @staticmethod
+     def _set_montage(raw):
+         montage = mne.channels.make_standard_montage("standard_1005")
+         raw.set_montage(montage, on_missing="ignore")
+
+     @staticmethod
+     def _create_dataset(
+         description,
+         target_name,
+         preload,
+         add_physician_reports,
+         rename_channels,
+         set_montage,
+     ):
+         file_path = description.loc["path"]
+
+         # parse age and gender information from the EDF header
+         age, gender = _parse_age_and_gender_from_edf_header(file_path)
+         raw = mne.io.read_raw_edf(
+             file_path, preload=preload, infer_types=True, verbose="error"
+         )
+         if rename_channels:
+             TUH._rename_channels(raw)
+         if set_montage:
+             TUH._set_montage(raw)
+
+         meas_date = (
+             datetime(1, 1, 1, tzinfo=timezone.utc)
+             if raw.info["meas_date"] is None
+             else raw.info["meas_date"]
+         )
+         # if this is an old version of the data and the year could be parsed
+         # from the file path, use it instead, as before
+         if "year" in description:
+             meas_date = meas_date.replace(*description[["year", "month", "day"]])
+         raw.set_meas_date(meas_date)
+
+         d = {
+             "age": int(age),
+             "gender": gender,
+         }
+         # if year exists in the description, this is an old version;
+         # if not, get it from meas_date in raw.info and add it to the
+         # description (if meas_date was None, a fake one was created above)
+         if "year" not in description:
+             d["year"] = raw.info["meas_date"].year
+             d["month"] = raw.info["meas_date"].month
+             d["day"] = raw.info["meas_date"].day
+
+         # read info relevant for preprocessing from raw without loading it
+         if add_physician_reports:
+             physician_report = _read_physician_report(file_path)
+             d["report"] = physician_report
+         additional_description = pd.Series(d)
+         description = pd.concat([description, additional_description])
+         base_dataset = RawDataset(raw, description, target_name=target_name)
+         return base_dataset
+
+
+ def _create_description(file_paths):
+     descriptions = [_parse_description_from_file_path(f) for f in file_paths]
+     descriptions = pd.DataFrame(descriptions)
+     return descriptions.T
+
+
+ def _sort_chronologically(descriptions):
+     descriptions.sort_values(
+         ["year", "month", "day", "subject", "session", "segment"], axis=1, inplace=True
+     )
+     return descriptions
+
+
+ def _read_date(file_path):
+     date_path = file_path.replace(".edf", "_date.txt")
+     # if the date file exists, read it
+     if os.path.exists(date_path):
+         description = pd.read_json(date_path, typ="series").to_dict()
+     # otherwise read the edf file, extract the date and store it to file
+     else:
+         raw = mne.io.read_raw_edf(file_path, preload=False, verbose="error")
+         description = {
+             "year": raw.info["meas_date"].year,
+             "month": raw.info["meas_date"].month,
+             "day": raw.info["meas_date"].day,
+         }
+         # if the txt file storing the recording date does not exist, create it
+         try:
+             pd.Series(description).to_json(date_path)
+         except OSError:
+             warnings.warn(
+                 f"Cannot save date file to {date_path}. "
+                 f"This might slow down creation of the dataset."
+             )
+     return description
+
+
+ def _parse_description_from_file_path(file_path):
+     # stackoverflow.com/questions/3167154/how-to-split-a-dos-path-into-its-components-in-python  # noqa
+     file_path = os.path.normpath(file_path)
+     tokens = file_path.split(os.sep)
+     # Extract the version number and tuh_eeg_abnormal/tuh_eeg from the file path
+     if ("train" in tokens) or ("eval" in tokens):  # tuh_eeg_abnormal
+         abnormal = True
+         # tokens[-2] is the channel configuration (always 01_tcp_ar in
+         # abnormal) on new versions, or the session (e.g. s004_2013_08_15)
+         # on old versions
+         if tokens[-2].split("_")[0][0] == "s":  # s denoting session number
+             version = tokens[-9]  # Before the Dec 2022 update
+         else:
+             version = tokens[-6]  # After the Dec 2022 update
+
+     else:  # tuh_eeg
+         abnormal = False
+         version = tokens[-7]
+     v_number = int(version[1])
+
+     if (abnormal and v_number >= 3) or ((not abnormal) and v_number >= 2):
+         # New file path structure for versions after the December 2022 update;
+         # expect file paths as
+         # tuh_eeg/v2.0.0/edf/000/aaaaaaaa/
+         # s001_2015_12_30/01_tcp_ar/aaaaaaaa_s001_t000.edf
+         # or for abnormal:
+         # tuh_eeg_abnormal/v3.0.0/edf/train/normal/
+         # 01_tcp_ar/aaaaaaav_s004_t000.edf
+         subject_id = tokens[-1].split("_")[0]
+         session = tokens[-1].split("_")[1]
+         segment = tokens[-1].split("_")[2].split(".")[0]
+         description = _read_date(file_path)
+         description.update(
+             {
+                 "path": file_path,
+                 "version": version,
+                 "subject": subject_id,
+                 "session": int(session[1:]),
+                 "segment": int(segment[1:]),
+             }
+         )
+         if not abnormal:
+             year, month, day = tokens[-3].split("_")[1:]
+             description["year"] = int(year)
+             description["month"] = int(month)
+             description["day"] = int(day)
+         return description
+     else:  # Old file path structure
+         # expect file paths as tuh_eeg/version/file_type/reference/data_split/
+         # subject/recording session/file
+         # e.g. tuh_eeg/v1.1.0/edf/01_tcp_ar/027/00002729/
+         # s001_2006_04_12/00002729_s001.edf
+         # or for abnormal:
+         # version/file type/data_split/pathology status/
+         # reference/subset/subject/recording session/file
+         # v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
+         # s004_2013_08_15/00000021_s004_t000.edf
+         subject_id = tokens[-1].split("_")[0]
+         session = tokens[-2].split("_")[0]  # string of the form 's000'
+         # According to the old tuh_eeg example path above, the segment is not
+         # always included in the file name
+         segment = tokens[-1].split("_")[-1].split(".")[0]  # TODO: test with tuh_eeg
+         year, month, day = tokens[-2].split("_")[1:]
+         return {
+             "path": file_path,
+             "version": version,
+             "year": int(year),
+             "month": int(month),
+             "day": int(day),
+             "subject": int(subject_id),
+             "session": int(session[1:]),
+             "segment": int(segment[1:]),
+         }
+
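As a concrete illustration (a sketch, not part of the module), feeding the parser one of the old-structure example paths listed further down in `_TUH_EEG_PATHS` yields:

    _parse_description_from_file_path(
        "tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/"
        "s001_2015_12_30/00000000_s001_t000.edf"
    )
    # -> {"path": "tuh_eeg/v1.1.0/...", "version": "v1.1.0", "year": 2015,
    #     "month": 12, "day": 30, "subject": 0, "session": 1, "segment": 0}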
+
+ def _read_physician_report(file_path):
+     directory = os.path.dirname(file_path)
+     txt_file = glob.glob(os.path.join(directory, "**/*.txt"), recursive=True)
+     # check that there is at most one txt file in the same directory
+     assert len(txt_file) in [0, 1]
+     report = ""
+     if txt_file:
+         txt_file = txt_file[0]
+         # somewhere in the corpus, the encoding apparently changed;
+         # first try to read as utf-8, and if that fails use latin-1
+         try:
+             with open(txt_file, "r", encoding="utf-8") as f:
+                 report = f.read()
+         except UnicodeDecodeError:
+             with open(txt_file, "r", encoding="latin-1") as f:
+                 report = f.read()
+     if not report:
+         raise RuntimeError(
+             f"Could not read physician report ({txt_file}). "
+             f"Disable the option or choose an appropriate directory."
+         )
+     return report
+
+
+ def _read_edf_header(file_path):
+     # read only the fixed-size part of the EDF header
+     with open(file_path, "rb") as f:
+         header = f.read(88)
+     return header
+
+
+ def _parse_age_and_gender_from_edf_header(file_path):
+     header = _read_edf_header(file_path)
+     # bytes 8 to 88 contain the ascii local patient identification
+     # see https://www.teuniz.net/edfbrowser/edf%20format%20description.html
+     patient_id = header[8:].decode("ascii")
+     age = -1
+     found_age = re.findall(r"Age:(\d+)", patient_id)
+     if len(found_age) == 1:
+         age = int(found_age[0])
+     gender = "X"
+     # note: inside a character class, "F|M" would also match a literal "|",
+     # so the class is written as [FM]
+     found_gender = re.findall(r"\s([FM])\s", patient_id)
+     if len(found_gender) == 1:
+         gender = found_gender[0]
+     return age, gender
+
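A quick sanity check of the header parsing above, run against one of the example headers listed later in this file (an illustrative snippet, not part of the module):

    import re

    patient_id = "0 00007871 F 01-JAN-1988 00007871 Age:23 "
    re.findall(r"Age:(\d+)", patient_id)   # -> ["23"]
    re.findall(r"\s([FM])\s", patient_id)  # -> ["F"]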
+
+ class TUHAbnormal(TUH):
+     """Temple University Hospital (TUH) Abnormal EEG Corpus.
+
+     see www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tuab
+
+     Parameters
+     ----------
+     path : str
+         Parent directory of the dataset.
+     recording_ids : list(int) | int
+         A (list of) int of recording id(s) to be read (order matters and will
+         override the default chronological order, e.g. if recording_ids=[1,0],
+         then the first recording returned by this class will be chronologically
+         later than the second recording). Provide recording_ids in ascending
+         order to preserve chronological order.
+     target_name : str
+         Can be 'pathological', 'gender', or 'age'.
+     preload : bool
+         If True, preload the data of the Raw objects.
+     add_physician_reports : bool
+         If True, the physician reports will be read from disk and added to the
+         description.
+     rename_channels : bool
+         If True, rename the EEG channels to the standard 10-05 system.
+     set_montage : bool
+         If True, set the montage to the standard 10-05 system.
+     n_jobs : int
+         Number of jobs to be used to read files in parallel.
+     """
+
+     def __init__(
+         self,
+         path: str,
+         recording_ids: list[int] | None = None,
+         target_name: str | tuple[str, ...] | None = "pathological",
+         preload: bool = False,
+         add_physician_reports: bool = False,
+         rename_channels: bool = False,
+         set_montage: bool = False,
+         n_jobs: int = 1,
+     ):
+         super().__init__(
+             path=path,
+             recording_ids=recording_ids,
+             preload=preload,
+             target_name=target_name,
+             add_physician_reports=add_physician_reports,
+             rename_channels=rename_channels,
+             set_montage=set_montage,
+             n_jobs=n_jobs,
+         )
+         additional_descriptions = []
+         for file_path in self.description.path:
+             additional_description = self._parse_additional_description_from_file_path(
+                 file_path
+             )
+             additional_descriptions.append(additional_description)
+         additional_descriptions = pd.DataFrame(additional_descriptions)
+         self.set_description(additional_descriptions, overwrite=True)
+
+     @staticmethod
+     def _parse_additional_description_from_file_path(file_path):
+         file_path = os.path.normpath(file_path)
+         tokens = file_path.split(os.sep)
+         # expect paths as version/file type/data_split/pathology status/
+         # reference/subset/subject/recording session/file
+         # e.g. v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
+         # s004_2013_08_15/00000021_s004_t000.edf
+         assert "abnormal" in tokens or "normal" in tokens, "No pathology labels found."
+         assert "train" in tokens or "eval" in tokens, (
+             "No train or eval set information found."
+         )
+         return {
+             "version": tokens[-9],
+             "train": "train" in tokens,
+             "pathological": "abnormal" in tokens,
+         }
+
+
+ def _fake_raw(*args, **kwargs):
+     sfreq = 10
+     ch_names = [
+         "EEG A1-REF",
+         "EEG A2-REF",
+         "EEG FP1-REF",
+         "EEG FP2-REF",
+         "EEG F3-REF",
+         "EEG F4-REF",
+         "EEG C3-REF",
+         "EEG C4-REF",
+         "EEG P3-REF",
+         "EEG P4-REF",
+         "EEG O1-REF",
+         "EEG O2-REF",
+         "EEG F7-REF",
+         "EEG F8-REF",
+         "EEG T3-REF",
+         "EEG T4-REF",
+         "EEG T5-REF",
+         "EEG T6-REF",
+         "EEG FZ-REF",
+         "EEG CZ-REF",
+         "EEG PZ-REF",
+     ]
+     duration_min = 6
+     data = np.random.randn(len(ch_names), duration_min * sfreq * 60)
+     info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
+     raw = mne.io.RawArray(data=data, info=info)
+     return raw
+
+
+ def _get_header(*args, **kwargs):
+     all_paths = {**_TUH_EEG_PATHS, **_TUH_EEG_ABNORMAL_PATHS}
+     return all_paths[args[0]]
+
+
+ _TUH_EEG_PATHS = {
+     # These are actual file paths and edf headers from the TUH EEG Corpus (v1.1.0 and v1.2.0)
+     "tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001_2015_12_30/00000000_s001_t000.edf": b"0 00000000 M 01-JAN-1978 00000000 Age:37 ",  # noqa E501
+     "tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004_2014_09_30/00009932_s004_t013.edf": b"0 00009932 F 01-JAN-1961 00009932 Age:53 ",  # noqa E501
+     "tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001_2003_02_05/00000058_s001_t000.edf": b"0 00000058 M 01-JAN-2003 00000058 Age:0.0109 ",  # noqa E501
+     "tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s003_2014_12_14/00012331_s003_t002.edf": b"0 00012331 M 01-JAN-1975 00012331 Age:39 ",  # noqa E501
+     "tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s004_2016_01_15/00014928_s004_t007.edf": b"0 00014928 F 01-JAN-1933 00014928 Age:83 ",  # noqa E501
+ }
+ _TUH_EEG_ABNORMAL_PATHS = {
+     # these are actual file paths and edf headers from the TUH Abnormal EEG Corpus (v2.0.0)
+     "tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/078/00007871/s001_2011_07_05/00007871_s001_t001.edf": b"0 00007871 F 01-JAN-1988 00007871 Age:23 ",  # noqa E501
+     "tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/097/00009777/s001_2012_09_17/00009777_s001_t000.edf": b"0 00009777 M 01-JAN-1986 00009777 Age:26 ",  # noqa E501
+     "tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/083/00008393/s002_2012_02_21/00008393_s002_t000.edf": b"0 00008393 M 01-JAN-1960 00008393 Age:52 ",  # noqa E501
+     "tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/012/00001200/s003_2010_12_06/00001200_s003_t000.edf": b"0 00001200 M 01-JAN-1963 00001200 Age:47 ",  # noqa E501
+     "tuh_abnormal_eeg/v2.0.0/edf/eval/abnormal/01_tcp_ar/059/00005932/s004_2013_03_14/00005932_s004_t000.edf": b"0 00005932 M 01-JAN-1963 00005932 Age:50 ",  # noqa E501
+ }
+
+
+ class _TUHMock(TUH):
+     """Mocked class for testing and examples."""
+
+     @mock.patch("glob.glob", return_value=_TUH_EEG_PATHS.keys())
+     @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
+     @mock.patch("braindecode.datasets.tuh._read_edf_header", new=_get_header)
+     def __init__(
+         self,
+         mock_glob,
+         path: str,
+         recording_ids: list[int] | None = None,
+         target_name: str | tuple[str, ...] | None = None,
+         preload: bool = False,
+         add_physician_reports: bool = False,
+         rename_channels: bool = False,
+         set_montage: bool = False,
+         n_jobs: int = 1,
+     ):
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", message="Cannot save date file")
+             super().__init__(
+                 path=path,
+                 recording_ids=recording_ids,
+                 target_name=target_name,
+                 preload=preload,
+                 add_physician_reports=add_physician_reports,
+                 rename_channels=rename_channels,
+                 set_montage=set_montage,
+                 n_jobs=n_jobs,
+             )
+
+
+ class _TUHAbnormalMock(TUHAbnormal):
+     """Mocked class for testing and examples."""
+
+     @mock.patch("glob.glob", return_value=_TUH_EEG_ABNORMAL_PATHS.keys())
+     @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
+     @mock.patch("braindecode.datasets.tuh._read_edf_header", new=_get_header)
+     @mock.patch(
+         "braindecode.datasets.tuh._read_physician_report", return_value="simple_test"
+     )
+     def __init__(
+         self,
+         mock_glob,
+         mock_report,
+         path: str,
+         recording_ids: list[int] | None = None,
+         target_name: str | tuple[str, ...] | None = "pathological",
+         preload: bool = False,
+         add_physician_reports: bool = False,
+         rename_channels: bool = False,
+         set_montage: bool = False,
+         n_jobs: int = 1,
+     ):
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", message="Cannot save date file")
+             super().__init__(
+                 path=path,
+                 recording_ids=recording_ids,
+                 target_name=target_name,
+                 preload=preload,
+                 add_physician_reports=add_physician_reports,
+                 rename_channels=rename_channels,
+                 set_montage=set_montage,
+                 n_jobs=n_jobs,
+             )
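For orientation, a minimal usage sketch of the classes above; the corpus path and the number of recordings are placeholders and assume a locally downloaded copy of the TUH Abnormal EEG Corpus:

    from braindecode.datasets.tuh import TUHAbnormal

    dataset = TUHAbnormal(
        path="./tuh_eeg_abnormal",      # placeholder: parent dir of the corpus
        recording_ids=list(range(10)),  # ascending ids keep chronological order
        target_name="pathological",     # normal vs. abnormal label as target
        rename_channels=True,           # map channel names to the 10-05 system
        set_montage=True,               # requires rename_channels=True
    )
    # per-recording metadata is collected in a pandas DataFrame
    print(dataset.description[["subject", "session", "pathological"]].head())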
braindecode/datasets/utils.py
@@ -0,0 +1,67 @@
+ """Utility functions for dataset handling."""
+
+ # Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+
+
+ def _correct_dataset_path(
+     path: str, archive_name: str, subfolder_name: str | None = None
+ ) -> str:
+     """
+     Correct the dataset path after download and extraction.
+
+     This function handles two common post-download scenarios:
+     1. Renames '.unzip'-suffixed directories created by some extraction tools
+     2. Navigates into a subfolder if the archive extracts to a nested directory
+
+     Parameters
+     ----------
+     path : str
+         Expected path to the dataset directory.
+     archive_name : str
+         Name of the downloaded archive file without extension
+         (e.g., "chb_mit_bids", "NMT").
+     subfolder_name : str | None
+         Name of the subfolder within the extracted archive that contains the
+         actual data. If provided and the subfolder exists, the path will be
+         updated to point to it. If None, only renaming is attempted.
+         Default is None.
+
+     Returns
+     -------
+     str
+         The corrected path to the dataset directory.
+
+     Raises
+     ------
+     PermissionError
+         If the '.unzip' directory exists but cannot be renamed due to
+         insufficient permissions.
+     """
+     if not Path(path).exists():
+         unzip_file_name = f"{archive_name}.unzip"
+         if (Path(path).parent / unzip_file_name).exists():
+             try:
+                 os.rename(
+                     src=Path(path).parent / unzip_file_name,
+                     dst=Path(path),
+                 )
+             except PermissionError:
+                 raise PermissionError(
+                     f"Please rename {Path(path).parent / unzip_file_name} "
+                     f"manually to {path} and try again."
+                 )
+
+     # Check if the subfolder exists inside the path
+     if subfolder_name is not None:
+         subfolder_path = os.path.join(path, subfolder_name)
+         if Path(subfolder_path).exists():
+             path = subfolder_path
+
+     return path
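A short sketch of the two scenarios this helper covers (all paths below are made up for illustration):

    from braindecode.datasets.utils import _correct_dataset_path

    # Scenario 1: the extractor produced "/data/NMT.unzip" instead of "/data/NMT";
    # the helper renames the directory so that the expected path exists.
    path = _correct_dataset_path(path="/data/NMT", archive_name="NMT")

    # Scenario 2: the archive unpacked into a nested "/data/chb_mit_bids/chb_mit_bids";
    # the helper returns the nested directory as the dataset path.
    path = _correct_dataset_path(
        path="/data/chb_mit_bids",
        archive_name="chb_mit_bids",
        subfolder_name="chb_mit_bids",
    )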