braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,588 @@
1
+ """
2
+ Dataset classes for the Temple University Hospital (TUH) EEG Corpus and the
3
+ TUH Abnormal EEG Corpus.
4
+ """
5
+
6
+ # Authors: Lukas Gemein <l.gemein@gmail.com>
7
+ #
8
+ # License: BSD (3-clause)
9
+
10
+ from __future__ import annotations
11
+
12
+ import glob
13
+ import os
14
+ import re
15
+ import warnings
16
+ from datetime import datetime, timezone
17
+ from typing import Iterable
18
+ from unittest import mock
19
+
20
+ import mne
21
+ import numpy as np
22
+ import pandas as pd
23
+ from joblib import Parallel, delayed
24
+
25
+ from .base import BaseConcatDataset, BaseDataset
26
+
27
+
28
class TUH(BaseConcatDataset):
    """Temple University Hospital (TUH) EEG Corpus
    (www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tueg).

    Parameters
    ----------
    path: str
        Parent directory of the dataset.
    recording_ids: list(int) | int
        A (list of) int of recording id(s) to be read (order matters and will
        overwrite default chronological order, e.g. if recording_ids=[1,0],
        then the first recording returned by this class will be chronologically
        later than the second recording. Provide recording_ids in ascending
        order to preserve chronological order.).
    target_name: str
        Can be 'gender', or 'age'.
    preload: bool
        If True, preload the data of the Raw objects.
    add_physician_reports: bool
        If True, the physician reports will be read from disk and added to the
        description.
    rename_channels: bool
        If True, rename the EEG channels to the standard 10-05 system.
    set_montage: bool
        If True, set the montage to the standard 10-05 system.
    n_jobs: int
        Number of jobs to be used to read files in parallel.
    """

    def __init__(
        self,
        path: str,
        recording_ids: list[int] | None = None,
        target_name: str | tuple[str, ...] | None = None,
        preload: bool = False,
        add_physician_reports: bool = False,
        rename_channels: bool = False,
        set_montage: bool = False,
        n_jobs: int = 1,
    ):
        # montage positions are defined for 10-05 channel names, so channels
        # must be renamed before a montage can be attached
        if set_montage:
            assert rename_channels, (
                "If set_montage is True, rename_channels must be True."
            )
        # create an index of all files and gather easily accessible info
        # without actually touching the files
        file_paths = glob.glob(os.path.join(path, "**/*.edf"), recursive=True)
        descriptions = _create_description(file_paths)
        # sort the descriptions chronologically
        descriptions = _sort_chronologically(descriptions)
        # limit to specified recording ids before doing slow stuff
        if recording_ids is not None:
            if not isinstance(recording_ids, Iterable):
                # Assume it is an integer specifying number
                # of recordings to load
                recording_ids = range(recording_ids)
            # descriptions has one column per recording, so this selects
            # the requested recordings (in the requested order)
            descriptions = descriptions[recording_ids]

        # workaround to ensure warnings are suppressed when running in parallel
        def create_dataset(*args, **kwargs):
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore", message=".*not in description. '__getitem__'"
                )
                return self._create_dataset(*args, **kwargs)

        # this is the second loop (slow)
        # create datasets gathering more info about the files touching them
        # reading the raws and potentially preloading the data
        # disable joblib for tests. mocking seems to fail otherwise
        if n_jobs == 1:
            base_datasets = [
                create_dataset(
                    descriptions[i],
                    target_name,
                    preload,
                    add_physician_reports,
                    rename_channels,
                    set_montage,
                )
                for i in descriptions.columns
            ]
        else:
            base_datasets = Parallel(n_jobs)(
                delayed(create_dataset)(
                    descriptions[i],
                    target_name,
                    preload,
                    add_physician_reports,
                    rename_channels,
                    set_montage,
                )
                for i in descriptions.columns
            )
        super().__init__(base_datasets)

    @staticmethod
    def _rename_channels(raw):
        """
        Renames the EEG channels using mne conventions and sets their type to 'eeg'.

        See https://isip.piconepress.com/publications/reports/2020/tuh_eeg/electrodes/
        """
        # remove ref suffix and prefix:
        # TODO: replace with removesuffix and removeprefix when 3.8 is dropped
        mapping_strip = {
            c: c.replace("-REF", "").replace("-LE", "").replace("EEG ", "")
            for c in raw.ch_names
        }
        raw.rename_channels(mapping_strip)

        # TUH channel names are upper-case; map them onto the canonical
        # capitalization used by the standard 10-05 montage
        montage1005 = mne.channels.make_standard_montage("standard_1005")
        mapping_eeg_names = {
            c.upper(): c for c in montage1005.ch_names if c.upper() in raw.ch_names
        }

        # Set channels whose type could not be inferred (defaulted to "eeg") to "misc":
        non_eeg_names = [c for c in raw.ch_names if c not in mapping_eeg_names]
        if non_eeg_names:
            non_eeg_types = raw.get_channel_types(picks=non_eeg_names)
            mapping_non_eeg_types = {
                c: "misc" for c, t in zip(non_eeg_names, non_eeg_types) if t == "eeg"
            }
            if mapping_non_eeg_types:
                raw.set_channel_types(mapping_non_eeg_types)

        if mapping_eeg_names:
            # Set 1005 channels type to "eeg":
            raw.set_channel_types(
                {c: "eeg" for c in mapping_eeg_names}, on_unit_change="ignore"
            )
            # Fix capitalized EEG channel names:
            raw.rename_channels(mapping_eeg_names)

    @staticmethod
    def _set_montage(raw):
        # attach standard 10-05 electrode positions; channels without a
        # position in the montage are left untouched
        montage = mne.channels.make_standard_montage("standard_1005")
        raw.set_montage(montage, on_missing="ignore")

    @staticmethod
    def _create_dataset(
        description,
        target_name,
        preload,
        add_physician_reports,
        rename_channels,
        set_montage,
    ):
        """Read one EDF recording and wrap it into a BaseDataset together
        with its (extended) description."""
        file_path = description.loc["path"]

        # parse age and gender information from EDF header
        age, gender = _parse_age_and_gender_from_edf_header(file_path)
        raw = mne.io.read_raw_edf(
            file_path, preload=preload, infer_types=True, verbose="error"
        )
        if rename_channels:
            TUH._rename_channels(raw)
        if set_montage:
            TUH._set_montage(raw)

        # fall back to a timezone-aware placeholder when the EDF carries no
        # measurement date
        meas_date = (
            datetime(1, 1, 1, tzinfo=timezone.utc)
            if raw.info["meas_date"] is None
            else raw.info["meas_date"]
        )
        # if this is old version of the data and the year could be parsed from
        # file paths, use this instead as before
        if "year" in description:
            meas_date = meas_date.replace(*description[["year", "month", "day"]])
        raw.set_meas_date(meas_date)

        d = {
            "age": int(age),
            "gender": gender,
        }
        # if year exists in description = old version
        # if not, get it from meas_date in raw.info and add to description
        # if meas_date is None, create fake one
        if "year" not in description:
            d["year"] = raw.info["meas_date"].year
            d["month"] = raw.info["meas_date"].month
            d["day"] = raw.info["meas_date"].day

        # read info relevant for preprocessing from raw without loading it
        if add_physician_reports:
            physician_report = _read_physician_report(file_path)
            d["report"] = physician_report
        additional_description = pd.Series(d)
        description = pd.concat([description, additional_description])
        base_dataset = BaseDataset(raw, description, target_name=target_name)
        return base_dataset
219
+
220
+
221
def _create_description(file_paths):
    """Parse metadata for every EDF file and return it as a DataFrame
    with one column per recording (rows are metadata fields)."""
    parsed = [_parse_description_from_file_path(path) for path in file_paths]
    return pd.DataFrame(parsed).T
225
+
226
+
227
+ def _sort_chronologically(descriptions):
228
+ descriptions.sort_values(
229
+ ["year", "month", "day", "subject", "session", "segment"], axis=1, inplace=True
230
+ )
231
+ return descriptions
232
+
233
+
234
def _read_date(file_path):
    """Return the recording date ({'year', 'month', 'day'}) of an EDF file.

    The date is cached in a sidecar ``*_date.txt`` JSON file next to the EDF
    so subsequent dataset creations avoid re-opening the (slow) EDF.
    """
    date_path = file_path.replace(".edf", "_date.txt")
    # if date file exists, read it
    if os.path.exists(date_path):
        description = pd.read_json(date_path, typ="series").to_dict()
    # otherwise read edf file, extract date and store to file
    else:
        raw = mne.io.read_raw_edf(file_path, preload=False, verbose="error")
        description = {
            "year": raw.info["meas_date"].year,
            "month": raw.info["meas_date"].month,
            "day": raw.info["meas_date"].day,
        }
        # if the txt file storing the recording date does not exist, create it
        try:
            pd.Series(description).to_json(date_path)
        except OSError:
            # e.g. read-only dataset directory; caching is best-effort only
            warnings.warn(
                f"Cannot save date file to {date_path}. "
                f"This might slow down creation of the dataset."
            )
    return description
256
+
257
+
258
+ def _parse_description_from_file_path(file_path):
259
+ # stackoverflow.com/questions/3167154/how-to-split-a-dos-path-into-its-components-in-python # noqa
260
+ file_path = os.path.normpath(file_path)
261
+ tokens = file_path.split(os.sep)
262
+ # Extract version number and tuh_eeg_abnormal/tuh_eeg from file path
263
+ if ("train" in tokens) or ("eval" in tokens): # tuh_eeg_abnormal
264
+ abnormal = True
265
+ # Tokens[-2] is channel configuration (always 01_tcp_ar in abnormal)
266
+ # on new versions, or
267
+ # session (e.g. s004_2013_08_15) on old versions
268
+ if tokens[-2].split("_")[0][0] == "s": # s denoting session number
269
+ version = tokens[-9] # Before dec 2022 updata
270
+ else:
271
+ version = tokens[-6] # After the dec 2022 update
272
+
273
+ else: # tuh_eeg
274
+ abnormal = False
275
+ version = tokens[-7]
276
+ v_number = int(version[1])
277
+
278
+ if (abnormal and v_number >= 3) or ((not abnormal) and v_number >= 2):
279
+ # New file path structure for versions after december 2022,
280
+ # expect file paths as
281
+ # tuh_eeg/v2.0.0/edf/000/aaaaaaaa/
282
+ # s001_2015_12_30/01_tcp_ar/aaaaaaaa_s001_t000.edf
283
+ # or for abnormal:
284
+ # tuh_eeg_abnormal/v3.0.0/edf/train/normal/
285
+ # 01_tcp_ar/aaaaaaav_s004_t000.edf
286
+ subject_id = tokens[-1].split("_")[0]
287
+ session = tokens[-1].split("_")[1]
288
+ segment = tokens[-1].split("_")[2].split(".")[0]
289
+ description = _read_date(file_path)
290
+ description.update(
291
+ {
292
+ "path": file_path,
293
+ "version": version,
294
+ "subject": subject_id,
295
+ "session": int(session[1:]),
296
+ "segment": int(segment[1:]),
297
+ }
298
+ )
299
+ if not abnormal:
300
+ year, month, day = tokens[-3].split("_")[1:]
301
+ description["year"] = int(year)
302
+ description["month"] = int(month)
303
+ description["day"] = int(day)
304
+ return description
305
+ else: # Old file path structure
306
+ # expect file paths as tuh_eeg/version/file_type/reference/data_split/
307
+ # subject/recording session/file
308
+ # e.g. tuh_eeg/v1.1.0/edf/01_tcp_ar/027/00002729/
309
+ # s001_2006_04_12/00002729_s001.edf
310
+ # or for abnormal
311
+ # version/file type/data_split/pathology status/
312
+ # reference/subset/subject/recording session/file
313
+ # v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
314
+ # s004_2013_08_15/00000021_s004_t000.edf
315
+ subject_id = tokens[-1].split("_")[0]
316
+ session = tokens[-2].split("_")[0] # string on format 's000'
317
+ # According to the example path in the comment 8 lines above,
318
+ # segment is not included in the file name
319
+ segment = tokens[-1].split("_")[-1].split(".")[0] # TODO: test with tuh_eeg
320
+ year, month, day = tokens[-2].split("_")[1:]
321
+ return {
322
+ "path": file_path,
323
+ "version": version,
324
+ "year": int(year),
325
+ "month": int(month),
326
+ "day": int(day),
327
+ "subject": int(subject_id),
328
+ "session": int(session[1:]),
329
+ "segment": int(segment[1:]),
330
+ }
331
+
332
+
333
+ def _read_physician_report(file_path):
334
+ directory = os.path.dirname(file_path)
335
+ txt_file = glob.glob(os.path.join(directory, "**/*.txt"), recursive=True)
336
+ # check that there is at most one txt file in the same directory
337
+ assert len(txt_file) in [0, 1]
338
+ report = ""
339
+ if txt_file:
340
+ txt_file = txt_file[0]
341
+ # somewhere in the corpus, encoding apparently changed
342
+ # first try to read as utf-8, if it does not work use latin-1
343
+ try:
344
+ with open(txt_file, "r", encoding="utf-8") as f:
345
+ report = f.read()
346
+ except UnicodeDecodeError:
347
+ with open(txt_file, "r", encoding="latin-1") as f:
348
+ report = f.read()
349
+ if not report:
350
+ raise RuntimeError(
351
+ f"Could not read physician report ({txt_file}). "
352
+ f"Disable option or choose appropriate directory."
353
+ )
354
+ return report
355
+
356
+
357
+ def _read_edf_header(file_path):
358
+ f = open(file_path, "rb")
359
+ header = f.read(88)
360
+ f.close()
361
+ return header
362
+
363
+
364
def _parse_age_and_gender_from_edf_header(file_path):
    """Parse age and gender from the patient-identification field of an EDF.

    Returns
    -------
    age : int
        Parsed age in years, or -1 if it could not be determined.
    gender : str
        'M' or 'F', or 'X' if it could not be determined.
    """
    header = _read_edf_header(file_path)
    # bytes 8 to 88 contain ascii local patient identification
    # see https://www.teuniz.net/edfbrowser/edf%20format%20description.html
    patient_id = header[8:].decode("ascii")
    age = -1
    found_age = re.findall(r"Age:(\d+)", patient_id)
    if len(found_age) == 1:
        age = int(found_age[0])
    gender = "X"
    # Fixed regex: the previous class [F|M] also matched a literal '|';
    # [FM] matches exactly 'F' or 'M' surrounded by whitespace.
    found_gender = re.findall(r"\s([FM])\s", patient_id)
    if len(found_gender) == 1:
        gender = found_gender[0]
    return age, gender
378
+
379
+
380
class TUHAbnormal(TUH):
    """Temple University Hospital (TUH) Abnormal EEG Corpus.
    see www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tuab

    Parameters
    ----------
    path: str
        Parent directory of the dataset.
    recording_ids: list(int) | int
        A (list of) int of recording id(s) to be read (order matters and will
        overwrite default chronological order, e.g. if recording_ids=[1,0],
        then the first recording returned by this class will be chronologically
        later than the second recording. Provide recording_ids in ascending
        order to preserve chronological order.).
    target_name: str
        Can be 'pathological', 'gender', or 'age'.
    preload: bool
        If True, preload the data of the Raw objects.
    add_physician_reports: bool
        If True, the physician reports will be read from disk and added to the
        description.
    rename_channels: bool
        If True, rename the EEG channels to the standard 10-05 system.
    set_montage: bool
        If True, set the montage to the standard 10-05 system.
    n_jobs: int
        Number of jobs to be used to read files in parallel.
    """

    def __init__(
        self,
        path: str,
        recording_ids: list[int] | None = None,
        target_name: str | tuple[str, ...] | None = "pathological",
        preload: bool = False,
        add_physician_reports: bool = False,
        rename_channels: bool = False,
        set_montage: bool = False,
        n_jobs: int = 1,
    ):
        super().__init__(
            path=path,
            recording_ids=recording_ids,
            preload=preload,
            target_name=target_name,
            add_physician_reports=add_physician_reports,
            rename_channels=rename_channels,
            set_montage=set_montage,
            n_jobs=n_jobs,
        )
        # augment the description created by TUH with the pathology label
        # and train/eval split parsed from the file paths
        additional_descriptions = []
        for file_path in self.description.path:
            additional_description = self._parse_additional_description_from_file_path(
                file_path
            )
            additional_descriptions.append(additional_description)
        additional_descriptions = pd.DataFrame(additional_descriptions)
        self.set_description(additional_descriptions, overwrite=True)

    @staticmethod
    def _parse_additional_description_from_file_path(file_path):
        """Parse version, train/eval membership and pathology status from
        the tokens of a TUH Abnormal file path."""
        file_path = os.path.normpath(file_path)
        tokens = file_path.split(os.sep)
        # expect paths as version/file type/data_split/pathology status/
        # reference/subset/subject/recording session/file
        # e.g. v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
        # s004_2013_08_15/00000021_s004_t000.edf
        assert "abnormal" in tokens or "normal" in tokens, "No pathology labels found."
        assert "train" in tokens or "eval" in tokens, (
            "No train or eval set information found."
        )
        return {
            "version": tokens[-9],
            "train": "train" in tokens,
            # "normal" does not contain "abnormal" as a path token, so this
            # membership test is unambiguous
            "pathological": "abnormal" in tokens,
        }
456
+
457
+
458
def _fake_raw(*args, **kwargs):
    """Build a random in-memory Raw object mimicking a TUH recording.

    Used to replace ``mne.io.read_raw_edf`` in the mocked dataset classes
    below, so tests and examples run without the actual corpus on disk.
    """
    sfreq = 10
    duration_min = 6
    electrode_labels = [
        "A1", "A2", "FP1", "FP2", "F3", "F4", "C3",
        "C4", "P3", "P4", "O1", "O2", "F7", "F8",
        "T3", "T4", "T5", "T6", "FZ", "CZ", "PZ",
    ]
    ch_names = [f"EEG {label}-REF" for label in electrode_labels]
    n_samples = duration_min * sfreq * 60
    data = np.random.randn(len(ch_names), n_samples)
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
    return mne.io.RawArray(data=data, info=info)
488
+
489
+
490
def _get_header(*args, **kwargs):
    """Mocked replacement for ``_read_edf_header`` returning a canned
    header for a known fake file path (first positional argument)."""
    path = args[0]
    if path in _TUH_EEG_PATHS:
        return _TUH_EEG_PATHS[path]
    return _TUH_EEG_ABNORMAL_PATHS[path]
493
+
494
+
495
# Canned path -> EDF-header fixtures consumed by _get_header and the mock
# dataset classes below.
# NOTE(review): the patient-identification fields are fixed-width in real EDF
# headers; confirm the byte strings retain their original space padding.
_TUH_EEG_PATHS = {
    # These are actual file paths and edf headers from the TUH EEG Corpus (v1.1.0 and v1.2.0)
    "tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001_2015_12_30/00000000_s001_t000.edf": b"0 00000000 M 01-JAN-1978 00000000 Age:37 ",  # noqa E501
    "tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004_2014_09_30/00009932_s004_t013.edf": b"0 00009932 F 01-JAN-1961 00009932 Age:53 ",  # noqa E501
    "tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001_2003_02_05/00000058_s001_t000.edf": b"0 00000058 M 01-JAN-2003 00000058 Age:0.0109 ",  # noqa E501
    "tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s003_2014_12_14/00012331_s003_t002.edf": b"0 00012331 M 01-JAN-1975 00012331 Age:39 ",  # noqa E501
    "tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s004_2016_01_15/00014928_s004_t007.edf": b"0 00014928 F 01-JAN-1933 00014928 Age:83 ",  # noqa E501
}
_TUH_EEG_ABNORMAL_PATHS = {
    # these are actual file paths and edf headers from TUH Abnormal EEG Corpus (v2.0.0)
    "tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/078/00007871/s001_2011_07_05/00007871_s001_t001.edf": b"0 00007871 F 01-JAN-1988 00007871 Age:23 ",  # noqa E501
    "tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/097/00009777/s001_2012_09_17/00009777_s001_t000.edf": b"0 00009777 M 01-JAN-1986 00009777 Age:26 ",  # noqa E501
    "tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/083/00008393/s002_2012_02_21/00008393_s002_t000.edf": b"0 00008393 M 01-JAN-1960 00008393 Age:52 ",  # noqa E501
    "tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/012/00001200/s003_2010_12_06/00001200_s003_t000.edf": b"0 00001200 M 01-JAN-1963 00001200 Age:47 ",  # noqa E501
    "tuh_abnormal_eeg/v2.0.0/edf/eval/abnormal/01_tcp_ar/059/00005932/s004_2013_03_14/00005932_s004_t000.edf": b"0 00005932 M 01-JAN-1963 00005932 Age:50 ",  # noqa E501
}
521
+
522
+
523
class _TUHMock(TUH):
    """Mocked class for testing and examples."""

    # glob.glob is patched to return the canned fake paths; read_raw_edf and
    # _read_edf_header are replaced (new=...) so no mock argument is injected
    # for them -- only the glob patch injects ``mock_glob``.
    @mock.patch("glob.glob", return_value=_TUH_EEG_PATHS.keys())
    @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
    @mock.patch("braindecode.datasets.tuh._read_edf_header", new=_get_header)
    def __init__(
        self,
        mock_glob,
        path: str,
        recording_ids: list[int] | None = None,
        target_name: str | tuple[str, ...] | None = None,
        preload: bool = False,
        add_physician_reports: bool = False,
        rename_channels: bool = False,
        set_montage: bool = False,
        n_jobs: int = 1,
    ):
        # fake paths point nowhere, so _read_date cannot cache date files;
        # suppress the resulting warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="Cannot save date file")
            super().__init__(
                path=path,
                recording_ids=recording_ids,
                target_name=target_name,
                preload=preload,
                add_physician_reports=add_physician_reports,
                rename_channels=rename_channels,
                set_montage=set_montage,
                n_jobs=n_jobs,
            )
553
+
554
+
555
class _TUHAbnormalMock(TUHAbnormal):
    """Mocked class for testing and examples."""

    # NOTE(review): mock.patch injects arguments bottom-up, so the
    # physician-report mock arrives first: ``mock_glob`` actually receives
    # the report mock and ``mock_report`` the glob mock. Harmless since
    # neither is used in the body, but the names are swapped -- confirm.
    @mock.patch("glob.glob", return_value=_TUH_EEG_ABNORMAL_PATHS.keys())
    @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
    @mock.patch("braindecode.datasets.tuh._read_edf_header", new=_get_header)
    @mock.patch(
        "braindecode.datasets.tuh._read_physician_report", return_value="simple_test"
    )
    def __init__(
        self,
        mock_glob,
        mock_report,
        path: str,
        recording_ids: list[int] | None = None,
        target_name: str | tuple[str, ...] | None = "pathological",
        preload: bool = False,
        add_physician_reports: bool = False,
        rename_channels: bool = False,
        set_montage: bool = False,
        n_jobs: int = 1,
    ):
        # fake paths point nowhere, so _read_date cannot cache date files;
        # suppress the resulting warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="Cannot save date file")
            super().__init__(
                path=path,
                recording_ids=recording_ids,
                target_name=target_name,
                preload=preload,
                add_physician_reports=add_physician_reports,
                rename_channels=rename_channels,
                set_montage=set_montage,
                n_jobs=n_jobs,
            )
@@ -0,0 +1,95 @@
1
+ # Authors: Lukas Gemein <l.gemein@gmail.com>
2
+ # Robin Schirrmeister <robintibor@gmail.com>
3
+ #
4
+ # License: BSD (3-clause)
5
+
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+
10
+ import mne
11
+ import numpy as np
12
+ import pandas as pd
13
+ from numpy.typing import ArrayLike, NDArray
14
+
15
+ from .base import BaseConcatDataset, BaseDataset
16
+
17
+ log = logging.getLogger(__name__)
18
+
19
+
20
def create_from_X_y(
    X: NDArray,
    y: ArrayLike,
    drop_last_window: bool,
    sfreq: float,
    ch_names: ArrayLike | None = None,
    window_size_samples: int | None = None,
    window_stride_samples: int | None = None,
) -> BaseConcatDataset:
    """Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
    decoding with skorch and braindecode, where X is a list of pre-cut trials
    and y are corresponding targets.

    Parameters
    ----------
    X: array-like
        list of pre-cut trials as n_trials x n_channels x n_times
    y: array-like
        targets corresponding to the trials
    drop_last_window: bool
        whether or not have a last overlapping window, when
        windows/windows do not equally divide the continuous signal
    sfreq: float
        Sampling frequency of signals.
    ch_names: array-like | None
        Names of the channels. If None, channels are named by their index.
    window_size_samples: int
        window size
    window_stride_samples: int
        stride between windows

    Returns
    -------
    windows_datasets: BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode

    Raises
    ------
    ValueError
        If ``X`` is empty, or if no window size/stride is given and the trials
        do not all have the same length.
    """
    # Prevent circular import
    from ..preprocessing.windowers import (
        create_fixed_length_windows,
    )

    if not len(X):
        raise ValueError("X must contain at least one trial.")

    if ch_names is None:
        ch_names = [str(i) for i in range(X.shape[1])]
        # lazy %-formatting; message fixed (stray ')' removed, upper bound
        # made inclusive of the last index)
        log.info("No channel names given, set to 0-%d.", X.shape[1] - 1)

    n_samples_per_x = []
    base_datasets = []
    for x, target in zip(X, y):
        n_samples_per_x.append(x.shape[1])
        # one Info per trial: mne Raw objects take ownership of their info
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
        raw = mne.io.RawArray(x, info)
        base_dataset = BaseDataset(
            raw, pd.Series({"target": target}), target_name="target"
        )
        base_datasets.append(base_dataset)
    base_datasets = BaseConcatDataset(base_datasets)

    if window_size_samples is None and window_stride_samples is None:
        # default to one window per trial, which requires equal trial lengths
        if len(set(n_samples_per_x)) != 1:
            raise ValueError(
                "if 'window_size_samples' and "
                "'window_stride_samples' are None, "
                "all trials have to have the same length"
            )
        window_size_samples = n_samples_per_x[0]
        window_stride_samples = n_samples_per_x[0]
    windows_datasets = create_fixed_length_windows(
        base_datasets,
        start_offset_samples=0,
        stop_offset_samples=None,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window,
    )
    return windows_datasets