braindecode-0.8.1-py3-none-any.whl → braindecode-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic; see the registry's advisory page for this release for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -7,20 +7,22 @@ TUH Abnormal EEG Corpus.
7
7
  #
8
8
  # License: BSD (3-clause)
9
9
 
10
- import re
11
- import os
10
+ from __future__ import annotations
11
+
12
12
  import glob
13
+ import os
14
+ import re
13
15
  import warnings
14
- from unittest import mock
15
16
  from datetime import datetime, timezone
16
17
  from typing import Iterable
18
+ from unittest import mock
17
19
 
18
- import pandas as pd
19
- import numpy as np
20
20
  import mne
21
+ import numpy as np
22
+ import pandas as pd
21
23
  from joblib import Parallel, delayed
22
24
 
23
- from .base import BaseDataset, BaseConcatDataset
25
+ from .base import BaseConcatDataset, BaseDataset
24
26
 
25
27
 
26
28
  class TUH(BaseConcatDataset):
@@ -44,14 +46,32 @@ class TUH(BaseConcatDataset):
44
46
  add_physician_reports: bool
45
47
  If True, the physician reports will be read from disk and added to the
46
48
  description.
49
+ rename_channels: bool
50
+ If True, rename the EEG channels to the standard 10-05 system.
51
+ set_montage: bool
52
+ If True, set the montage to the standard 10-05 system.
47
53
  n_jobs: int
48
54
  Number of jobs to be used to read files in parallel.
49
55
  """
50
- def __init__(self, path, recording_ids=None, target_name=None,
51
- preload=False, add_physician_reports=False, n_jobs=1):
56
+
57
+ def __init__(
58
+ self,
59
+ path: str,
60
+ recording_ids: list[int] | None = None,
61
+ target_name: str | tuple[str, ...] | None = None,
62
+ preload: bool = False,
63
+ add_physician_reports: bool = False,
64
+ rename_channels: bool = False,
65
+ set_montage: bool = False,
66
+ n_jobs: int = 1,
67
+ ):
68
+ if set_montage:
69
+ assert rename_channels, (
70
+ "If set_montage is True, rename_channels must be True."
71
+ )
52
72
  # create an index of all files and gather easily accessible info
53
73
  # without actually touching the files
54
- file_paths = glob.glob(os.path.join(path, '**/*.edf'), recursive=True)
74
+ file_paths = glob.glob(os.path.join(path, "**/*.edf"), recursive=True)
55
75
  descriptions = _create_description(file_paths)
56
76
  # sort the descriptions chronologicaly
57
77
  descriptions = _sort_chronologically(descriptions)
@@ -62,59 +82,139 @@ class TUH(BaseConcatDataset):
62
82
  # of recordings to load
63
83
  recording_ids = range(recording_ids)
64
84
  descriptions = descriptions[recording_ids]
85
+
86
+ # workaround to ensure warnings are suppressed when running in parallel
87
+ def create_dataset(*args, **kwargs):
88
+ with warnings.catch_warnings():
89
+ warnings.filterwarnings(
90
+ "ignore", message=".*not in description. '__getitem__'"
91
+ )
92
+ return self._create_dataset(*args, **kwargs)
93
+
65
94
  # this is the second loop (slow)
66
95
  # create datasets gathering more info about the files touching them
67
96
  # reading the raws and potentially preloading the data
68
97
  # disable joblib for tests. mocking seems to fail otherwise
69
98
  if n_jobs == 1:
70
- base_datasets = [self._create_dataset(
71
- descriptions[i], target_name, preload, add_physician_reports)
72
- for i in descriptions.columns]
99
+ base_datasets = [
100
+ create_dataset(
101
+ descriptions[i],
102
+ target_name,
103
+ preload,
104
+ add_physician_reports,
105
+ rename_channels,
106
+ set_montage,
107
+ )
108
+ for i in descriptions.columns
109
+ ]
73
110
  else:
74
- base_datasets = Parallel(n_jobs)(delayed(
75
- self._create_dataset)(
76
- descriptions[i], target_name, preload, add_physician_reports
77
- ) for i in descriptions.columns)
111
+ base_datasets = Parallel(n_jobs)(
112
+ delayed(create_dataset)(
113
+ descriptions[i],
114
+ target_name,
115
+ preload,
116
+ add_physician_reports,
117
+ rename_channels,
118
+ set_montage,
119
+ )
120
+ for i in descriptions.columns
121
+ )
78
122
  super().__init__(base_datasets)
79
123
 
80
124
  @staticmethod
81
- def _create_dataset(description, target_name, preload,
82
- add_physician_reports):
83
- file_path = description.loc['path']
125
+ def _rename_channels(raw):
126
+ """
127
+ Renames the EEG channels using mne conventions and sets their type to 'eeg'.
128
+
129
+ See https://isip.piconepress.com/publications/reports/2020/tuh_eeg/electrodes/
130
+ """
131
+ # remove ref suffix and prefix:
132
+ # TODO: replace with removesuffix and removeprefix when 3.8 is dropped
133
+ mapping_strip = {
134
+ c: c.replace("-REF", "").replace("-LE", "").replace("EEG ", "")
135
+ for c in raw.ch_names
136
+ }
137
+ raw.rename_channels(mapping_strip)
138
+
139
+ montage1005 = mne.channels.make_standard_montage("standard_1005")
140
+ mapping_eeg_names = {
141
+ c.upper(): c for c in montage1005.ch_names if c.upper() in raw.ch_names
142
+ }
143
+
144
+ # Set channels whose type could not be inferred (defaulted to "eeg") to "misc":
145
+ non_eeg_names = [c for c in raw.ch_names if c not in mapping_eeg_names]
146
+ if non_eeg_names:
147
+ non_eeg_types = raw.get_channel_types(picks=non_eeg_names)
148
+ mapping_non_eeg_types = {
149
+ c: "misc" for c, t in zip(non_eeg_names, non_eeg_types) if t == "eeg"
150
+ }
151
+ if mapping_non_eeg_types:
152
+ raw.set_channel_types(mapping_non_eeg_types)
153
+
154
+ if mapping_eeg_names:
155
+ # Set 1005 channels type to "eeg":
156
+ raw.set_channel_types(
157
+ {c: "eeg" for c in mapping_eeg_names}, on_unit_change="ignore"
158
+ )
159
+ # Fix capitalized EEG channel names:
160
+ raw.rename_channels(mapping_eeg_names)
161
+
162
+ @staticmethod
163
+ def _set_montage(raw):
164
+ montage = mne.channels.make_standard_montage("standard_1005")
165
+ raw.set_montage(montage, on_missing="ignore")
166
+
167
+ @staticmethod
168
+ def _create_dataset(
169
+ description,
170
+ target_name,
171
+ preload,
172
+ add_physician_reports,
173
+ rename_channels,
174
+ set_montage,
175
+ ):
176
+ file_path = description.loc["path"]
84
177
 
85
178
  # parse age and gender information from EDF header
86
179
  age, gender = _parse_age_and_gender_from_edf_header(file_path)
87
- raw = mne.io.read_raw_edf(file_path, preload=preload)
88
-
89
- meas_date = datetime(1, 1, 1, tzinfo=timezone.utc) \
90
- if raw.info['meas_date'] is None else raw.info['meas_date']
180
+ raw = mne.io.read_raw_edf(
181
+ file_path, preload=preload, infer_types=True, verbose="error"
182
+ )
183
+ if rename_channels:
184
+ TUH._rename_channels(raw)
185
+ if set_montage:
186
+ TUH._set_montage(raw)
187
+
188
+ meas_date = (
189
+ datetime(1, 1, 1, tzinfo=timezone.utc)
190
+ if raw.info["meas_date"] is None
191
+ else raw.info["meas_date"]
192
+ )
91
193
  # if this is old version of the data and the year could be parsed from
92
194
  # file paths, use this instead as before
93
- if 'year' in description:
94
- meas_date = meas_date.replace(
95
- *description[['year', 'month', 'day']])
195
+ if "year" in description:
196
+ meas_date = meas_date.replace(*description[["year", "month", "day"]])
96
197
  raw.set_meas_date(meas_date)
97
198
 
98
199
  d = {
99
- 'age': int(age),
100
- 'gender': gender,
200
+ "age": int(age),
201
+ "gender": gender,
101
202
  }
102
203
  # if year exists in description = old version
103
204
  # if not, get it from meas_date in raw.info and add to description
104
205
  # if meas_date is None, create fake one
105
- if 'year' not in description:
106
- d['year'] = raw.info['meas_date'].year
107
- d['month'] = raw.info['meas_date'].month
108
- d['day'] = raw.info['meas_date'].day
206
+ if "year" not in description:
207
+ d["year"] = raw.info["meas_date"].year
208
+ d["month"] = raw.info["meas_date"].month
209
+ d["day"] = raw.info["meas_date"].day
109
210
 
110
211
  # read info relevant for preprocessing from raw without loading it
111
212
  if add_physician_reports:
112
213
  physician_report = _read_physician_report(file_path)
113
- d['report'] = physician_report
214
+ d["report"] = physician_report
114
215
  additional_description = pd.Series(d)
115
216
  description = pd.concat([description, additional_description])
116
- base_dataset = BaseDataset(raw, description,
117
- target_name=target_name)
217
+ base_dataset = BaseDataset(raw, description, target_name=target_name)
118
218
  return base_dataset
119
219
 
120
220
 
@@ -126,30 +226,32 @@ def _create_description(file_paths):
126
226
 
127
227
  def _sort_chronologically(descriptions):
128
228
  descriptions.sort_values(
129
- ["year", "month", "day", "subject", "session", "segment"],
130
- axis=1, inplace=True)
229
+ ["year", "month", "day", "subject", "session", "segment"], axis=1, inplace=True
230
+ )
131
231
  return descriptions
132
232
 
133
233
 
134
234
  def _read_date(file_path):
135
- date_path = file_path.replace('.edf', '_date.txt')
235
+ date_path = file_path.replace(".edf", "_date.txt")
136
236
  # if date file exists, read it
137
237
  if os.path.exists(date_path):
138
- description = pd.read_json(date_path, typ='series').to_dict()
238
+ description = pd.read_json(date_path, typ="series").to_dict()
139
239
  # otherwise read edf file, extract date and store to file
140
240
  else:
141
- raw = mne.io.read_raw_edf(file_path, preload=False, verbose='error')
241
+ raw = mne.io.read_raw_edf(file_path, preload=False, verbose="error")
142
242
  description = {
143
- 'year': raw.info['meas_date'].year,
144
- 'month': raw.info['meas_date'].month,
145
- 'day': raw.info['meas_date'].day,
243
+ "year": raw.info["meas_date"].year,
244
+ "month": raw.info["meas_date"].month,
245
+ "day": raw.info["meas_date"].day,
146
246
  }
147
247
  # if the txt file storing the recording date does not exist, create it
148
248
  try:
149
249
  pd.Series(description).to_json(date_path)
150
250
  except OSError:
151
- warnings.warn(f'Cannot save date file to {date_path}. '
152
- f'This might slow down creation of the dataset.')
251
+ warnings.warn(
252
+ f"Cannot save date file to {date_path}. "
253
+ f"This might slow down creation of the dataset."
254
+ )
153
255
  return description
154
256
 
155
257
 
@@ -158,12 +260,12 @@ def _parse_description_from_file_path(file_path):
158
260
  file_path = os.path.normpath(file_path)
159
261
  tokens = file_path.split(os.sep)
160
262
  # Extract version number and tuh_eeg_abnormal/tuh_eeg from file path
161
- if ('train' in tokens) or ('eval' in tokens): # tuh_eeg_abnormal
263
+ if ("train" in tokens) or ("eval" in tokens): # tuh_eeg_abnormal
162
264
  abnormal = True
163
265
  # Tokens[-2] is channel configuration (always 01_tcp_ar in abnormal)
164
266
  # on new versions, or
165
267
  # session (e.g. s004_2013_08_15) on old versions
166
- if tokens[-2].split('_')[0][0] == 's': # s denoting session number
268
+ if tokens[-2].split("_")[0][0] == "s": # s denoting session number
167
269
  version = tokens[-9] # Before dec 2022 updata
168
270
  else:
169
271
  version = tokens[-6] # After the dec 2022 update
@@ -181,22 +283,24 @@ def _parse_description_from_file_path(file_path):
181
283
  # or for abnormal:
182
284
  # tuh_eeg_abnormal/v3.0.0/edf/train/normal/
183
285
  # 01_tcp_ar/aaaaaaav_s004_t000.edf
184
- subject_id = tokens[-1].split('_')[0]
185
- session = tokens[-1].split('_')[1]
186
- segment = tokens[-1].split('_')[2].split('.')[0]
286
+ subject_id = tokens[-1].split("_")[0]
287
+ session = tokens[-1].split("_")[1]
288
+ segment = tokens[-1].split("_")[2].split(".")[0]
187
289
  description = _read_date(file_path)
188
- description.update({
189
- 'path': file_path,
190
- 'version': version,
191
- 'subject': subject_id,
192
- 'session': int(session[1:]),
193
- 'segment': int(segment[1:]),
194
- })
290
+ description.update(
291
+ {
292
+ "path": file_path,
293
+ "version": version,
294
+ "subject": subject_id,
295
+ "session": int(session[1:]),
296
+ "segment": int(segment[1:]),
297
+ }
298
+ )
195
299
  if not abnormal:
196
- year, month, day = tokens[-3].split('_')[1:]
197
- description['year'] = int(year)
198
- description['month'] = int(month)
199
- description['day'] = int(day)
300
+ year, month, day = tokens[-3].split("_")[1:]
301
+ description["year"] = int(year)
302
+ description["month"] = int(month)
303
+ description["day"] = int(day)
200
304
  return description
201
305
  else: # Old file path structure
202
306
  # expect file paths as tuh_eeg/version/file_type/reference/data_split/
@@ -208,43 +312,45 @@ def _parse_description_from_file_path(file_path):
208
312
  # reference/subset/subject/recording session/file
209
313
  # v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
210
314
  # s004_2013_08_15/00000021_s004_t000.edf
211
- subject_id = tokens[-1].split('_')[0]
212
- session = tokens[-2].split('_')[0] # string on format 's000'
315
+ subject_id = tokens[-1].split("_")[0]
316
+ session = tokens[-2].split("_")[0] # string on format 's000'
213
317
  # According to the example path in the comment 8 lines above,
214
318
  # segment is not included in the file name
215
- segment = tokens[-1].split('_')[-1].split('.')[0] # TODO: test with tuh_eeg
216
- year, month, day = tokens[-2].split('_')[1:]
319
+ segment = tokens[-1].split("_")[-1].split(".")[0] # TODO: test with tuh_eeg
320
+ year, month, day = tokens[-2].split("_")[1:]
217
321
  return {
218
- 'path': file_path,
219
- 'version': version,
220
- 'year': int(year),
221
- 'month': int(month),
222
- 'day': int(day),
223
- 'subject': int(subject_id),
224
- 'session': int(session[1:]),
225
- 'segment': int(segment[1:]),
322
+ "path": file_path,
323
+ "version": version,
324
+ "year": int(year),
325
+ "month": int(month),
326
+ "day": int(day),
327
+ "subject": int(subject_id),
328
+ "session": int(session[1:]),
329
+ "segment": int(segment[1:]),
226
330
  }
227
331
 
228
332
 
229
333
  def _read_physician_report(file_path):
230
334
  directory = os.path.dirname(file_path)
231
- txt_file = glob.glob(os.path.join(directory, '**/*.txt'), recursive=True)
335
+ txt_file = glob.glob(os.path.join(directory, "**/*.txt"), recursive=True)
232
336
  # check that there is at most one txt file in the same directory
233
337
  assert len(txt_file) in [0, 1]
234
- report = ''
338
+ report = ""
235
339
  if txt_file:
236
340
  txt_file = txt_file[0]
237
341
  # somewhere in the corpus, encoding apparently changed
238
342
  # first try to read as utf-8, if it does not work use latin-1
239
343
  try:
240
- with open(txt_file, 'r', encoding='utf-8') as f:
344
+ with open(txt_file, "r", encoding="utf-8") as f:
241
345
  report = f.read()
242
346
  except UnicodeDecodeError:
243
- with open(txt_file, 'r', encoding='latin-1') as f:
347
+ with open(txt_file, "r", encoding="latin-1") as f:
244
348
  report = f.read()
245
349
  if not report:
246
- raise RuntimeError(f'Could not read physician report ({txt_file}). '
247
- f'Disable option or choose appropriate directory.')
350
+ raise RuntimeError(
351
+ f"Could not read physician report ({txt_file}). "
352
+ f"Disable option or choose appropriate directory."
353
+ )
248
354
  return report
249
355
 
250
356
 
@@ -292,20 +398,40 @@ class TUHAbnormal(TUH):
292
398
  add_physician_reports: bool
293
399
  If True, the physician reports will be read from disk and added to the
294
400
  description.
401
+ rename_channels: bool
402
+ If True, rename the EEG channels to the standard 10-05 system.
403
+ set_montage: bool
404
+ If True, set the montage to the standard 10-05 system.
405
+ n_jobs: int
406
+ Number of jobs to be used to read files in parallel.
295
407
  """
296
- def __init__(self, path, recording_ids=None, target_name='pathological',
297
- preload=False, add_physician_reports=False, n_jobs=1):
298
- with warnings.catch_warnings():
299
- warnings.filterwarnings(
300
- "ignore", message=".*not in description. '__getitem__'")
301
- super().__init__(path=path, recording_ids=recording_ids,
302
- preload=preload, target_name=target_name,
303
- add_physician_reports=add_physician_reports,
304
- n_jobs=n_jobs)
408
+
409
+ def __init__(
410
+ self,
411
+ path: str,
412
+ recording_ids: list[int] | None = None,
413
+ target_name: str | tuple[str, ...] | None = "pathological",
414
+ preload: bool = False,
415
+ add_physician_reports: bool = False,
416
+ rename_channels: bool = False,
417
+ set_montage: bool = False,
418
+ n_jobs: int = 1,
419
+ ):
420
+ super().__init__(
421
+ path=path,
422
+ recording_ids=recording_ids,
423
+ preload=preload,
424
+ target_name=target_name,
425
+ add_physician_reports=add_physician_reports,
426
+ rename_channels=rename_channels,
427
+ set_montage=set_montage,
428
+ n_jobs=n_jobs,
429
+ )
305
430
  additional_descriptions = []
306
431
  for file_path in self.description.path:
307
- additional_description = (
308
- self._parse_additional_description_from_file_path(file_path))
432
+ additional_description = self._parse_additional_description_from_file_path(
433
+ file_path
434
+ )
309
435
  additional_descriptions.append(additional_description)
310
436
  additional_descriptions = pd.DataFrame(additional_descriptions)
311
437
  self.set_description(additional_descriptions, overwrite=True)
@@ -318,28 +444,45 @@ class TUHAbnormal(TUH):
318
444
  # reference/subset/subject/recording session/file
319
445
  # e.g. v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
320
446
  # s004_2013_08_15/00000021_s004_t000.edf
321
- assert ('abnormal' in tokens or 'normal' in tokens), (
322
- 'No pathology labels found.')
323
- assert ('train' in tokens or 'eval' in tokens), (
324
- 'No train or eval set information found.')
447
+ assert "abnormal" in tokens or "normal" in tokens, "No pathology labels found."
448
+ assert "train" in tokens or "eval" in tokens, (
449
+ "No train or eval set information found."
450
+ )
325
451
  return {
326
- 'version': tokens[-9],
327
- 'train': 'train' in tokens,
328
- 'pathological': 'abnormal' in tokens,
452
+ "version": tokens[-9],
453
+ "train": "train" in tokens,
454
+ "pathological": "abnormal" in tokens,
329
455
  }
330
456
 
331
457
 
332
458
  def _fake_raw(*args, **kwargs):
333
459
  sfreq = 10
334
460
  ch_names = [
335
- 'EEG A1-REF', 'EEG A2-REF',
336
- 'EEG FP1-REF', 'EEG FP2-REF', 'EEG F3-REF', 'EEG F4-REF', 'EEG C3-REF',
337
- 'EEG C4-REF', 'EEG P3-REF', 'EEG P4-REF', 'EEG O1-REF', 'EEG O2-REF',
338
- 'EEG F7-REF', 'EEG F8-REF', 'EEG T3-REF', 'EEG T4-REF', 'EEG T5-REF',
339
- 'EEG T6-REF', 'EEG FZ-REF', 'EEG CZ-REF', 'EEG PZ-REF']
461
+ "EEG A1-REF",
462
+ "EEG A2-REF",
463
+ "EEG FP1-REF",
464
+ "EEG FP2-REF",
465
+ "EEG F3-REF",
466
+ "EEG F4-REF",
467
+ "EEG C3-REF",
468
+ "EEG C4-REF",
469
+ "EEG P3-REF",
470
+ "EEG P4-REF",
471
+ "EEG O1-REF",
472
+ "EEG O2-REF",
473
+ "EEG F7-REF",
474
+ "EEG F8-REF",
475
+ "EEG T3-REF",
476
+ "EEG T4-REF",
477
+ "EEG T5-REF",
478
+ "EEG T6-REF",
479
+ "EEG FZ-REF",
480
+ "EEG CZ-REF",
481
+ "EEG PZ-REF",
482
+ ]
340
483
  duration_min = 6
341
484
  data = np.random.randn(len(ch_names), duration_min * sfreq * 60)
342
- info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
485
+ info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
343
486
  raw = mne.io.RawArray(data=data, info=info)
344
487
  return raw
345
488
 
@@ -351,54 +494,95 @@ def _get_header(*args, **kwargs):
351
494
 
352
495
  _TUH_EEG_PATHS = {
353
496
  # These are actual file paths and edf headers from the TUH EEG Corpus (v1.1.0 and v1.2.0)
354
- 'tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001_2015_12_30/00000000_s001_t000.edf': b'0 00000000 M 01-JAN-1978 00000000 Age:37 ', # noqa E501
355
- 'tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004_2014_09_30/00009932_s004_t013.edf': b'0 00009932 F 01-JAN-1961 00009932 Age:53 ', # noqa E501
356
- 'tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001_2003_02_05/00000058_s001_t000.edf': b'0 00000058 M 01-JAN-2003 00000058 Age:0.0109 ', # noqa E501
357
- 'tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s003_2014_12_14/00012331_s003_t002.edf': b'0 00012331 M 01-JAN-1975 00012331 Age:39 ', # noqa E501
358
- 'tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s004_2016_01_15/00014928_s004_t007.edf': b'0 00014928 F 01-JAN-1933 00014928 Age:83 ', # noqa E501
497
+ "tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001_2015_12_30/00000000_s001_t000.edf": b"0 00000000 M 01-JAN-1978 00000000 Age:37 ",
498
+ # noqa E501
499
+ "tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004_2014_09_30/00009932_s004_t013.edf": b"0 00009932 F 01-JAN-1961 00009932 Age:53 ",
500
+ # noqa E501
501
+ "tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001_2003_02_05/00000058_s001_t000.edf": b"0 00000058 M 01-JAN-2003 00000058 Age:0.0109 ",
502
+ # noqa E501
503
+ "tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s003_2014_12_14/00012331_s003_t002.edf": b"0 00012331 M 01-JAN-1975 00012331 Age:39 ",
504
+ # noqa E501
505
+ "tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s004_2016_01_15/00014928_s004_t007.edf": b"0 00014928 F 01-JAN-1933 00014928 Age:83 ",
506
+ # noqa E501
359
507
  }
360
508
  _TUH_EEG_ABNORMAL_PATHS = {
361
509
  # these are actual file paths and edf headers from TUH Abnormal EEG Corpus (v2.0.0)
362
- 'tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/078/00007871/s001_2011_07_05/00007871_s001_t001.edf': b'0 00007871 F 01-JAN-1988 00007871 Age:23 ', # noqa E501
363
- 'tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/097/00009777/s001_2012_09_17/00009777_s001_t000.edf': b'0 00009777 M 01-JAN-1986 00009777 Age:26 ', # noqa E501
364
- 'tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/083/00008393/s002_2012_02_21/00008393_s002_t000.edf': b'0 00008393 M 01-JAN-1960 00008393 Age:52 ', # noqa E501
365
- 'tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/012/00001200/s003_2010_12_06/00001200_s003_t000.edf': b'0 00001200 M 01-JAN-1963 00001200 Age:47 ', # noqa E501
366
- 'tuh_abnormal_eeg/v2.0.0/edf/eval/abnormal/01_tcp_ar/059/00005932/s004_2013_03_14/00005932_s004_t000.edf': b'0 00005932 M 01-JAN-1963 00005932 Age:50 ', # noqa E501
510
+ "tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/078/00007871/s001_2011_07_05/00007871_s001_t001.edf": b"0 00007871 F 01-JAN-1988 00007871 Age:23 ",
511
+ # noqa E501
512
+ "tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/097/00009777/s001_2012_09_17/00009777_s001_t000.edf": b"0 00009777 M 01-JAN-1986 00009777 Age:26 ",
513
+ # noqa E501
514
+ "tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/083/00008393/s002_2012_02_21/00008393_s002_t000.edf": b"0 00008393 M 01-JAN-1960 00008393 Age:52 ",
515
+ # noqa E501
516
+ "tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/012/00001200/s003_2010_12_06/00001200_s003_t000.edf": b"0 00001200 M 01-JAN-1963 00001200 Age:47 ",
517
+ # noqa E501
518
+ "tuh_abnormal_eeg/v2.0.0/edf/eval/abnormal/01_tcp_ar/059/00005932/s004_2013_03_14/00005932_s004_t000.edf": b"0 00005932 M 01-JAN-1963 00005932 Age:50 ",
519
+ # noqa E501
367
520
  }
368
521
 
369
522
 
370
523
  class _TUHMock(TUH):
371
524
  """Mocked class for testing and examples."""
372
- @mock.patch('glob.glob', return_value=_TUH_EEG_PATHS.keys())
373
- @mock.patch('mne.io.read_raw_edf', new=_fake_raw)
374
- @mock.patch('braindecode.datasets.tuh._read_edf_header',
375
- new=_get_header)
376
- def __init__(self, mock_glob, path, recording_ids=None, target_name=None,
377
- preload=False, add_physician_reports=False, n_jobs=1):
525
+
526
+ @mock.patch("glob.glob", return_value=_TUH_EEG_PATHS.keys())
527
+ @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
528
+ @mock.patch("braindecode.datasets.tuh._read_edf_header", new=_get_header)
529
+ def __init__(
530
+ self,
531
+ mock_glob,
532
+ path: str,
533
+ recording_ids: list[int] | None = None,
534
+ target_name: str | tuple[str, ...] | None = None,
535
+ preload: bool = False,
536
+ add_physician_reports: bool = False,
537
+ rename_channels: bool = False,
538
+ set_montage: bool = False,
539
+ n_jobs: int = 1,
540
+ ):
378
541
  with warnings.catch_warnings():
379
- warnings.filterwarnings(
380
- "ignore", message="Cannot save date file")
381
- super().__init__(path=path, recording_ids=recording_ids,
382
- target_name=target_name, preload=preload,
383
- add_physician_reports=add_physician_reports,
384
- n_jobs=n_jobs)
542
+ warnings.filterwarnings("ignore", message="Cannot save date file")
543
+ super().__init__(
544
+ path=path,
545
+ recording_ids=recording_ids,
546
+ target_name=target_name,
547
+ preload=preload,
548
+ add_physician_reports=add_physician_reports,
549
+ rename_channels=rename_channels,
550
+ set_montage=set_montage,
551
+ n_jobs=n_jobs,
552
+ )
385
553
 
386
554
 
387
555
  class _TUHAbnormalMock(TUHAbnormal):
388
556
  """Mocked class for testing and examples."""
389
- @mock.patch('glob.glob', return_value=_TUH_EEG_ABNORMAL_PATHS.keys())
390
- @mock.patch('mne.io.read_raw_edf', new=_fake_raw)
391
- @mock.patch('braindecode.datasets.tuh._read_edf_header',
392
- new=_get_header)
393
- @mock.patch('braindecode.datasets.tuh._read_physician_report',
394
- return_value='simple_test')
395
- def __init__(self, mock_glob, mock_report, path, recording_ids=None,
396
- target_name='pathological', preload=False,
397
- add_physician_reports=False, n_jobs=1):
557
+
558
+ @mock.patch("glob.glob", return_value=_TUH_EEG_ABNORMAL_PATHS.keys())
559
+ @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
560
+ @mock.patch("braindecode.datasets.tuh._read_edf_header", new=_get_header)
561
+ @mock.patch(
562
+ "braindecode.datasets.tuh._read_physician_report", return_value="simple_test"
563
+ )
564
+ def __init__(
565
+ self,
566
+ mock_glob,
567
+ mock_report,
568
+ path: str,
569
+ recording_ids: list[int] | None = None,
570
+ target_name: str | tuple[str, ...] | None = "pathological",
571
+ preload: bool = False,
572
+ add_physician_reports: bool = False,
573
+ rename_channels: bool = False,
574
+ set_montage: bool = False,
575
+ n_jobs: int = 1,
576
+ ):
398
577
  with warnings.catch_warnings():
399
- warnings.filterwarnings(
400
- "ignore", message="Cannot save date file")
401
- super().__init__(path=path, recording_ids=recording_ids,
402
- target_name=target_name, preload=preload,
403
- add_physician_reports=add_physician_reports,
404
- n_jobs=n_jobs)
578
+ warnings.filterwarnings("ignore", message="Cannot save date file")
579
+ super().__init__(
580
+ path=path,
581
+ recording_ids=recording_ids,
582
+ target_name=target_name,
583
+ preload=preload,
584
+ add_physician_reports=add_physician_reports,
585
+ rename_channels=rename_channels,
586
+ set_montage=set_montage,
587
+ n_jobs=n_jobs,
588
+ )