braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,412 @@
1
+ """
2
+ PhysioNet Challenge 2018 dataset.
3
+ """
4
+
5
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
6
+ # Bruno Aristimunha <b.aristimunha@gmail.com>
7
+ # License: BSD (3-clause)
8
+ # Code copied from the repository
9
+ # https://github.com/hubertjb/dynamic-spatial-filtering
10
+
11
+ import os
12
+ import os.path as op
13
+ import urllib
14
+
15
+ import mne
16
+ import numpy as np
17
+ import pandas as pd
18
+ import wfdb
19
+ from joblib import Parallel, delayed
20
+ from mne.datasets.sleep_physionet._utils import _fetch_one
21
+ from mne.datasets.utils import _get_path
22
+ from mne.utils import warn
23
+
24
+ from braindecode.datasets import BaseConcatDataset, BaseDataset
25
+ from braindecode.preprocessing.preprocess import _preprocess
26
+
27
# Folder next to this module that holds the PC18 metadata files.
PC18_DIR = op.join(op.dirname(__file__), "data", "pc18")
# CSV describing the available recordings (read by ``fetch_pc18_data``,
# written by ``_update_pc18_sleep_records``).
PC18_RECORDS = op.join(PC18_DIR, "sleep_records.csv")
# CSV with per-record information (``Record``, ``Age`` and ``Sex`` columns
# are read from it in this module).
PC18_INFO = op.join(PC18_DIR, "age-sex.csv")
# Default URL root handed to ``fetch_pc18_data`` for the recordings.
PC18_URL = "https://physionet.org/files/challenge-2018/1.0.0/"
# Checksum listings of the training and test splits.
PC18_SHA1_TRAINING = op.join(PC18_DIR, "training_SHA1SUMS")
PC18_SHA1_TEST = op.join(PC18_DIR, "test_SHA1SUMS")
# URL root the four metadata files above are downloaded from when missing.
PC18_METAINFO_URL = "https://zenodo.org/records/13823458/files/"
34
+
35
+
36
+ # Function to download a file if it doesn't exist
37
+ def _download_if_missing(file_path, url):
38
+ folder_path = op.dirname(file_path)
39
+
40
+ # Ensure the folder exists
41
+ if not op.exists(folder_path):
42
+ warn(f"Directory {folder_path} not found. Creating directory.")
43
+ os.makedirs(folder_path)
44
+
45
+ # Check if file exists, if not download it
46
+ if not op.exists(file_path):
47
+ warn(f"{file_path} not found. Downloading from {url}")
48
+ urllib.request.urlretrieve(url, file_path)
49
+
50
+
51
def ensure_metafiles_exist():
    """Make sure the four PC18 metadata files are available locally.

    Each file that is not found at its expected local path is downloaded
    from ``PC18_METAINFO_URL``.
    """
    metafiles = (
        (PC18_RECORDS, "sleep_records.csv"),
        (PC18_INFO, "age-sex.csv"),
        (PC18_SHA1_TRAINING, "training_SHA1SUMS"),
        (PC18_SHA1_TEST, "test_SHA1SUMS"),
    )
    for local_path, remote_name in metafiles:
        _download_if_missing(local_path, PC18_METAINFO_URL + remote_name)
61
+
62
+
63
def _update_pc18_sleep_records(fname=PC18_RECORDS):
    """Create CSV file with information about available PC18 recordings.

    The checksum listings of both splits are combined with the per-record
    information found in ``PC18_INFO`` and the result is written to ``fname``.

    Parameters
    ----------
    fname : str
        Path of the CSV file to write.
    """
    # Load and massage the checksums.
    # NOTE(review): the separator is reproduced as seen; checksum listings
    # commonly use two spaces between hash and file name - confirm.
    checksum_format = dict(
        sep=" ", header=None, names=["sha", "fname"], engine="python"
    )
    sha_train_df = pd.read_csv(PC18_SHA1_TRAINING, **checksum_format)
    sha_test_df = pd.read_csv(PC18_SHA1_TEST, **checksum_format)
    sha_train_df["Split"] = "training"
    sha_test_df["Split"] = "test"
    sha_df = pd.concat([sha_train_df, sha_test_df], axis=0, ignore_index=True)
    select_records = (
        sha_df.fname.str.startswith("tr") | sha_df.fname.str.startswith("te")
    ) & ~sha_df.fname.str.endswith("arousal.mat")
    # Work on an explicit copy: columns are added below, and assigning to a
    # boolean-mask selection triggers pandas' chained-assignment warning.
    sha_df = sha_df.loc[select_records].copy()
    sha_df["Record"] = sha_df["fname"].str.split("/", expand=True)[0]
    sha_df["fname"] = sha_df[["Split", "fname"]].agg("/".join, axis=1)

    # Load and massage the data.
    data = pd.read_csv(PC18_INFO)

    # The row position in the information file becomes the subject number.
    data = data.reset_index().rename({"index": "Subject"}, axis=1)
    data["Sex"] = (
        data["Sex"].map({"F": "female", "M": "male", "m": "male"}).astype("category")
    )
    data = sha_df.merge(data, on="Record")

    # The text between the first and second dot of the file name tells the
    # kind of file.
    data["Record type"] = (
        data["fname"]
        .str.split(".", expand=True)[1]
        .map({"hea": "Header", "mat": "PSG", "arousal": "Arousal"})
        .astype("category")
    )
    data = data[
        ["Subject", "Record", "Record type", "Split", "Age", "Sex", "sha", "fname"]
    ].sort_values(by="Subject")

    # Save the data.
    data.to_csv(fname, index=False)
107
+
108
+
109
def _data_path(path=None):
    """Get path to local copy of PC18 dataset.

    Parameters
    ----------
    path : None | str
        Requested location, resolved through ``_get_path`` with the
        ``PC18_DATASET_PATH`` key.

    Returns
    -------
    str
        The resolved folder itself when it already contains a ``training``
        or ``test`` sub-folder, otherwise its ``pc18-sleep-data`` sub-folder.
    """
    key = "PC18_DATASET_PATH"
    name = "PC18_DATASET_SLEEP"
    path = _get_path(path, key, name)
    # A folder that does not exist yet cannot be listed; treat it as empty so
    # that a fresh download location does not raise FileNotFoundError.
    subdirs = os.listdir(path) if op.isdir(path) else []
    if "training" in subdirs or "test" in subdirs:
        # the specified path is already at the training and test folders level
        return path
    return op.join(path, "pc18-sleep-data")
120
+
121
+
122
def fetch_pc18_data(subjects, path=None, force_update=False, base_url=PC18_URL):
    """Get paths to local copies of PhysioNet Challenge 2018 dataset files.

    This will fetch data from the publicly available PhysioNet Computing in
    Cardiology Challenge 2018 dataset on sleep arousal detection [1]_ [2]_.
    This corresponds to 1983 recordings from individual subjects with
    (suspected) sleep apnea. The dataset is separated into a training set with
    994 recordings for which arousal annotation are available and a test set
    with 989 recordings for which the labels have not been revealed. Across the
    entire dataset, mean age is 55 years old and 65% of recordings are from
    male subjects.

    More information can be found on the
    `physionet website <https://physionet.org/content/challenge-2018/1.0.0/>`_.

    Parameters
    ----------
    subjects : list of int
        The subjects to use. Can be in the range of 0-1982 (inclusive). Test
        recordings are 0-988, while training recordings are 989-1982.
    path : None | str
        Location of where to look for the PC18 data storing location. If None,
        the environment variable or config parameter ``PC18_DATASET_PATH``
        is used. If it doesn't exist, the "~/mne_data" directory is used. If
        the dataset is not found under the given path, the data will be
        automatically downloaded to the specified folder.
    force_update : bool
        Force update of the dataset even if a local copy exists.
    base_url : str
        The URL root.

    Returns
    -------
    paths : list of list
        One ``[psg, header, arousal]`` entry per fetched recording. Each item
        is the value returned by ``_fetch_one`` for the corresponding file;
        ``arousal`` is None for recordings of the test split.

    References
    ----------
    .. [1] Mohammad M Ghassemi, Benjamin E Moody, Li-wei H Lehman, Christopher
       Song, Qiao Li, Haoqi Sun, Roger G Mark, M Brandon Westover, Gari D
       Clifford. You Snooze, You Win: the PhysioNet/Computing in Cardiology
       Challenge 2018.
    .. [2] Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C.,
       Mark, R., ... & Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and
       PhysioNet: Components of a new research resource for complex physiologic
       signals. Circulation [Online]. 101 (23), pp. e215–e220.)
    """
    records = pd.read_csv(PC18_RECORDS)
    psg_records = records[records["Record type"] == "PSG"]
    hea_records = records[records["Record type"] == "Header"]
    arousal_records = records[records["Record type"] == "Arousal"]

    path = _data_path(path=path)
    params = [path, force_update, base_url]

    def _fetch_for_subject(table, subject):
        # Select the row by subject instead of reusing a positional index of
        # another table, so the lookup does not depend on row ordering.
        row = table[table["Subject"] == subject].iloc[0]
        return _fetch_one(row["fname"], row["sha"], *params)

    fnames = []
    for subject in subjects:
        subject_psg = psg_records[psg_records["Subject"] == subject]
        for _, psg_row in subject_psg.iterrows():
            psg_fname = _fetch_one(psg_row["fname"], psg_row["sha"], *params)
            hea_fname = _fetch_for_subject(hea_records, subject)
            # Only the training split is fetched with an arousal file.
            if psg_row["Split"] == "training":
                arousal_fname = _fetch_for_subject(arousal_records, subject)
            else:
                arousal_fname = None
            fnames.append([psg_fname, hea_fname, arousal_fname])

    return fnames
202
+
203
+
204
+ def _convert_wfdb_anns_to_mne_annotations(annots):
205
+ """Convert wfdb.io.Annotation format to MNE's.
206
+
207
+ Parameters
208
+ ----------
209
+ annots : wfdb.io.Annotation
210
+ Annotation object obtained by e.g. loading an annotation file with
211
+ wfdb.rdann().
212
+
213
+ Returns
214
+ -------
215
+ mne.Annotations :
216
+ MNE Annotations object.
217
+ """
218
+ ann_chs = set(annots.chan)
219
+ onsets = annots.sample / annots.fs
220
+ new_onset, new_duration, new_description = list(), list(), list()
221
+ for channel_name in ann_chs:
222
+ mask = annots.chan == channel_name
223
+ ch_onsets = onsets[mask]
224
+ ch_descs = np.array(annots.aux_note)[mask]
225
+
226
+ # Events with beginning and end, defined by '(event' and 'event)'
227
+ if all([(i.startswith("(") or i.endswith(")")) for i in ch_descs]):
228
+ pass
229
+ else: # Sleep stage-like annotations
230
+ ch_durations = np.concatenate([np.diff(ch_onsets), [30]])
231
+ if all(ch_durations > 0):
232
+ ValueError("Negative duration")
233
+ new_onset.extend(ch_onsets)
234
+ new_duration.extend(ch_durations)
235
+ new_description.extend(ch_descs)
236
+
237
+ mne_annots = mne.Annotations(
238
+ new_onset, new_duration, new_description, orig_time=None
239
+ )
240
+
241
+ return mne_annots
242
+
243
+
244
class SleepPhysionetChallenge2018(BaseConcatDataset):
    """Physionet Challenge 2018 polysomnography dataset.

    Sleep dataset from https://physionet.org/content/challenge-2018/1.0.0/.
    Contains overnight recordings from 1983 subjects with (suspected) sleep
    apnea.

    The total size is 266 GB, so make sure you have enough space before
    downloading.

    See `fetch_pc18_data` for a more complete description.

    Parameters
    ----------
    subject_ids: list(int) | int | str | None
        (list of) int of subject(s) to be loaded.

        - If `None`, loads all subjects (both training and test sets [no label associated]).
        - If `"training"`, loads only the training set subjects.
        - If `"test"`, loads only the test set subjects, no label associated!
        - Otherwise, expects a subject ID or an iterable of subject IDs.
    path : None | str
        Location of where to look for the PC18 data storing location. If None,
        the environment variable or config parameter ``PC18_DATASET_PATH``
        is used. If it doesn't exist, the "~/mne_data" directory is used. If
        the dataset is not found under the given path, the data will be
        automatically downloaded to the specified folder.
    load_eeg_only: bool
        If True, only load the EEG channels and discard the others (EOG, EMG,
        temperature, respiration) to avoid resampling the other signals.
    preproc : list(Preprocessor) | None
        List of preprocessors to apply to each file individually. This way the
        data can e.g., be downsampled (temporally and spatially) to limit the
        memory usage of the entire Dataset object. This also enables applying
        preprocessing in parallel over the recordings.
    n_jobs : int
        Number of parallel processes.
    """

    def __init__(
        self,
        subject_ids="training",
        path=None,
        load_eeg_only=True,
        preproc=None,
        n_jobs=1,
    ):
        if subject_ids is None:
            subject_ids = range(1983)
            # The message used to start with a stray double quote.
            warn(
                """
                You are loading the complete dataset (0 to 1982),
                which includes a portion of the test set (0 to 988)
                from the Physionet Challenge 2018. Note that the test set
                does not have associated labels, so supervised classification
                cannot be performed on these data.""",
                UserWarning,
            )
        elif subject_ids == "training":
            subject_ids = range(989, 1983)
        elif subject_ids == "test":
            subject_ids = range(989)
            # Same category as the other two warnings of this constructor.
            warn(
                """
                This subset does not have associated labels, so supervised
                classification (sleep stage) cannot be performed on this data.
                You can also use the meta information as a label to perform
                another task.
                """,
                UserWarning,
            )
        else:
            # Accept a single ID, and materialise iterables so that a one-shot
            # iterator is not exhausted by the check below before it is used.
            if isinstance(subject_ids, (int, np.integer)):
                subject_ids = [subject_ids]
            else:
                subject_ids = list(subject_ids)
            # Check if the selection includes any test set IDs
            if any(sid < 989 for sid in subject_ids):
                warn(
                    """
                    You are loading a subset of the data that includes test set
                    subjects (subject IDs: 0 to 988). These subjects do not have
                    associated labels, which means supervised classification
                    (sleep stage) cannot be performed on this data. You can also
                    use the meta information as a label to perform another task.
                    """,
                    UserWarning,
                )

        ensure_metafiles_exist()

        paths = fetch_pc18_data(subject_ids, path=path)

        self.info_df = pd.read_csv(PC18_INFO)

        # NOTE(review): pairing by position assumes ``fetch_pc18_data`` returns
        # exactly one entry per requested subject, in order - confirm.
        jobs = list(zip(subject_ids, paths))
        load_kwargs = dict(load_eeg_only=load_eeg_only, preproc=preproc)
        if n_jobs == 1:
            all_base_ds = [
                self._load_raw(subject_id, p[0], p[2], **load_kwargs)
                for subject_id, p in jobs
            ]
        else:
            all_base_ds = Parallel(n_jobs=n_jobs)(
                delayed(self._load_raw)(subject_id, p[0], p[2], **load_kwargs)
                for subject_id, p in jobs
            )
        super().__init__(all_base_ds)

    def _load_raw(self, subj_nb, raw_fname, arousal_fname, load_eeg_only, preproc):
        """Load one recording and wrap it into a ``BaseDataset``.

        Parameters
        ----------
        subj_nb : int
            Subject number stored in the description of the dataset.
        raw_fname : sequence
            PSG entry produced by ``fetch_pc18_data``; its first item is used
            as the path of the signal file.
        arousal_fname : object | None
            Arousal entry produced by ``fetch_pc18_data``. Only tested against
            None: annotations are read when it is set.
        load_eeg_only : bool
            If True, only the first seven channels are read.
        preproc : list(Preprocessor) | None
            Preprocessors applied to the loaded recording.

        Returns
        -------
        BaseDataset
            The recording with its description (subject, record, split, age,
            sex).
        """
        channel_types = ["eeg"] * 7
        if load_eeg_only:
            channels = list(range(7))
        else:
            channel_types += ["emg", "misc", "misc", "misc", "misc", "ecg"]
            channels = None

        # wfdb addresses a record by its path without file extension.
        record_path = op.splitext(raw_fname[0])[0]

        # Load raw signals and header
        record = wfdb.io.rdrecord(record_path, channels=channels)

        # Convert to right units for MNE (EEG should be in V)
        data = record.p_signal.T
        units = np.array(record.units)
        data[units == "uV"] /= 1e6
        data[units == "mV"] /= 1e3
        info = mne.create_info(record.sig_name, record.fs, channel_types)
        raw_file = mne.io.RawArray(data, info)

        # Extract annotations
        if arousal_fname is not None:
            annots = wfdb.rdann(
                record_path,
                "arousal",
                sampfrom=0,
                sampto=None,
                shift_samps=False,
                return_label_elements=["symbol"],
                summarize_labels=False,
            )
            mne_annots = _convert_wfdb_anns_to_mne_annotations(annots)
            raw_file = raw_file.set_annotations(mne_annots)

        record_name = op.basename(record_path)
        record_info = self.info_df[self.info_df["Record"] == record_name].iloc[0]
        # The split is derived from the prefix of the record name.
        if record_info["Record"].startswith("tr"):
            split = "training"
        elif record_info["Record"].startswith("te"):
            split = "test"
        else:
            split = "unknown"

        desc = pd.Series(
            {
                "subject": subj_nb,
                "record": record_info["Record"],
                "split": split,
                "age": record_info["Age"],
                "sex": record_info["Sex"],
            },
            name="",
        )
        base_dataset = BaseDataset(raw_file, desc)

        if preproc is not None:
            _preprocess(base_dataset, None, preproc)

        return base_dataset
@@ -3,14 +3,16 @@
3
3
  # License: BSD (3-clause)
4
4
 
5
5
 
6
+ from __future__ import annotations
7
+
6
8
  import os
7
9
 
10
+ import mne
8
11
  import numpy as np
9
12
  import pandas as pd
10
- import mne
11
13
  from mne.datasets.sleep_physionet.age import fetch_data
12
14
 
13
- from .base import BaseDataset, BaseConcatDataset
15
+ from .base import BaseConcatDataset, BaseDataset
14
16
 
15
17
 
16
18
  class SleepPhysionet(BaseConcatDataset):
@@ -19,7 +21,7 @@ class SleepPhysionet(BaseConcatDataset):
19
21
  Sleep dataset from https://physionet.org/content/sleep-edfx/1.0.0/.
20
22
  Contains overnight recordings from 78 healthy subjects.
21
23
 
22
- See [MNE example](https://mne.tools/stable/auto_tutorials/sample-datasets/plot_sleep.html).
24
+ See `MNE example <https://mne.tools/stable/auto_tutorials/clinical/60_sleep.html>`_.
23
25
 
24
26
  Parameters
25
27
  ----------
@@ -42,34 +44,52 @@ class SleepPhysionet(BaseConcatDataset):
42
44
  If not None crop the raw files (e.g. to use only the first 3h).
43
45
  Example: ``crop=(0, 3600*3)`` to keep only the first 3h.
44
46
  """
45
- def __init__(self, subject_ids=None, recording_ids=None, preload=False,
46
- load_eeg_only=True, crop_wake_mins=30, crop=None):
47
+
48
+ def __init__(
49
+ self,
50
+ subject_ids: list[int] | int | None = None,
51
+ recording_ids: list[int] | None = None,
52
+ preload=False,
53
+ load_eeg_only=True,
54
+ crop_wake_mins=30,
55
+ crop=None,
56
+ ):
47
57
  if subject_ids is None:
48
- subject_ids = range(83)
58
+ subject_ids = list(range(83))
49
59
  if recording_ids is None:
50
60
  recording_ids = [1, 2]
51
61
 
52
- paths = fetch_data(
53
- subject_ids, recording=recording_ids, on_missing='warn')
62
+ paths = fetch_data(subject_ids, recording=recording_ids, on_missing="warn")
54
63
 
55
64
  all_base_ds = list()
56
65
  for p in paths:
57
66
  raw, desc = self._load_raw(
58
- p[0], p[1], preload=preload, load_eeg_only=load_eeg_only,
59
- crop_wake_mins=crop_wake_mins, crop=crop)
67
+ p[0],
68
+ p[1],
69
+ preload=preload,
70
+ load_eeg_only=load_eeg_only,
71
+ crop_wake_mins=crop_wake_mins,
72
+ crop=crop,
73
+ )
60
74
  base_ds = BaseDataset(raw, desc)
61
75
  all_base_ds.append(base_ds)
62
76
  super().__init__(all_base_ds)
63
77
 
64
78
  @staticmethod
65
- def _load_raw(raw_fname, ann_fname, preload, load_eeg_only=True,
66
- crop_wake_mins=False, crop=None):
79
+ def _load_raw(
80
+ raw_fname,
81
+ ann_fname,
82
+ preload,
83
+ load_eeg_only=True,
84
+ crop_wake_mins=False,
85
+ crop=None,
86
+ ):
67
87
  ch_mapping = {
68
- 'EOG horizontal': 'eog',
69
- 'Resp oro-nasal': 'misc',
70
- 'EMG submental': 'misc',
71
- 'Temp rectal': 'misc',
72
- 'Event marker': 'misc'
88
+ "EOG horizontal": "eog",
89
+ "Resp oro-nasal": "misc",
90
+ "EMG submental": "misc",
91
+ "Temp rectal": "misc",
92
+ "Event marker": "misc",
73
93
  }
74
94
  exclude = list(ch_mapping.keys()) if load_eeg_only else ()
75
95
 
@@ -79,19 +99,16 @@ class SleepPhysionet(BaseConcatDataset):
79
99
 
80
100
  if crop_wake_mins > 0:
81
101
  # Find first and last sleep stages
82
- mask = [
83
- x[-1] in ['1', '2', '3', '4', 'R'] for x in annots.description]
102
+ mask = [x[-1] in ["1", "2", "3", "4", "R"] for x in annots.description]
84
103
  sleep_event_inds = np.where(mask)[0]
85
104
 
86
105
  # Crop raw
87
- tmin = annots[int(sleep_event_inds[0])]['onset'] - crop_wake_mins * 60
88
- tmax = annots[int(sleep_event_inds[-1])]['onset'] + crop_wake_mins * 60
89
- raw.crop(tmin=max(tmin, raw.times[0]),
90
- tmax=min(tmax, raw.times[-1]))
106
+ tmin = annots[int(sleep_event_inds[0])]["onset"] - crop_wake_mins * 60
107
+ tmax = annots[int(sleep_event_inds[-1])]["onset"] + crop_wake_mins * 60
108
+ raw.crop(tmin=max(tmin, raw.times[0]), tmax=min(tmax, raw.times[-1]))
91
109
 
92
110
  # Rename EEG channels
93
- ch_names = {
94
- i: i.replace('EEG ', '') for i in raw.ch_names if 'EEG' in i}
111
+ ch_names = {i: i.replace("EEG ", "") for i in raw.ch_names if "EEG" in i}
95
112
  raw.rename_channels(ch_names)
96
113
 
97
114
  if not load_eeg_only:
@@ -103,6 +120,6 @@ class SleepPhysionet(BaseConcatDataset):
103
120
  basename = os.path.basename(raw_fname)
104
121
  subj_nb = int(basename[3:5])
105
122
  sess_nb = int(basename[5])
106
- desc = pd.Series({'subject': subj_nb, 'recording': sess_nb}, name='')
123
+ desc = pd.Series({"subject": subj_nb, "recording": sess_nb}, name="")
107
124
 
108
125
  return raw, desc