braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
import mne
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from mne.datasets.sleep_physionet.age import fetch_data
|
|
14
|
+
|
|
15
|
+
from .base import BaseConcatDataset, RawDataset
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SleepPhysionet(BaseConcatDataset):
    """Sleep Physionet dataset.

    Sleep dataset from https://physionet.org/content/sleep-edfx/1.0.0/.
    Contains overnight recordings from 78 healthy subjects.

    See `MNE example <https://mne.tools/stable/auto_tutorials/clinical/60_sleep.html>`_.

    Parameters
    ----------
    subject_ids : list(int) | int | None
        (list of) int of subject(s) to be loaded. If None, load all available
        subjects. A single int is treated as a one-element list.
    recording_ids : list(int) | None
        Recordings to load per subject (each subject except 13 has two
        recordings). Can be [1], [2] or [1, 2] (same as None).
    preload : bool
        If True, preload the data of the Raw objects.
    load_eeg_only : bool
        If True, only load the EEG channels and discard the others (EOG, EMG,
        temperature, respiration) to avoid resampling the other signals.
    crop_wake_mins : float
        Number of minutes of wake time to keep before the first sleep event
        and after the last sleep event. Used to reduce the imbalance in this
        dataset. Default of 30 mins. A value of 0 (or less) disables this
        cropping.
    crop : None | tuple
        If not None crop the raw files (e.g. to use only the first 3h).
        Example: ``crop=(0, 3600*3)`` to keep only the first 3h.
    """

    def __init__(
        self,
        subject_ids: list[int] | int | None = None,
        recording_ids: list[int] | None = None,
        preload=False,
        load_eeg_only=True,
        crop_wake_mins=30,
        crop=None,
    ):
        if subject_ids is None:
            subject_ids = list(range(83))
        elif isinstance(subject_ids, (int, np.integer)):
            # The signature and docstring accept a single subject id, but
            # ``fetch_data`` is documented to take a list of int: wrap the
            # scalar so the fetcher always receives a list.
            subject_ids = [int(subject_ids)]
        if recording_ids is None:
            recording_ids = [1, 2]

        # ``on_missing="warn"``: ids in the requested range without data only
        # trigger a warning instead of aborting the whole fetch.
        paths = fetch_data(subject_ids, recording=recording_ids, on_missing="warn")

        all_base_ds = []
        for p in paths:
            # Each entry of ``paths`` is indexed as (raw file, annotation file).
            raw, desc = self._load_raw(
                p[0],
                p[1],
                preload=preload,
                load_eeg_only=load_eeg_only,
                crop_wake_mins=crop_wake_mins,
                crop=crop,
            )
            base_ds = RawDataset(raw, desc)
            all_base_ds.append(base_ds)
        super().__init__(all_base_ds)

    @staticmethod
    def _load_raw(
        raw_fname,
        ann_fname,
        preload,
        load_eeg_only=True,
        crop_wake_mins=False,
        crop=None,
    ):
        """Load one recording and its annotations, and build its description.

        Parameters
        ----------
        raw_fname : str
            Path to the EDF file holding the signals.
        ann_fname : str
            Path to the annotation file read with ``mne.read_annotations``.
        preload : bool
            Forwarded to ``mne.io.read_raw_edf``.
        load_eeg_only : bool
            If True, the non-EEG channels listed in ``ch_mapping`` are excluded
            at read time. If False, they are kept and their channel types are
            set according to ``ch_mapping``.
        crop_wake_mins : float
            Minutes kept before the onset of the first and after the onset of
            the last sleep-stage annotation. No wake cropping when not > 0
            (the default ``False`` therefore disables it).
        crop : None | tuple
            If not None, unpacked into ``raw.crop`` after the wake cropping.

        Returns
        -------
        raw : mne.io.Raw
            The loaded (and possibly cropped) recording with annotations set.
        desc : pandas.Series
            Series with entries ``"subject"`` and ``"recording"`` parsed from
            the file name.

        Raises
        ------
        IndexError
            If ``crop_wake_mins > 0`` and no annotation description ends with
            one of "1", "2", "3", "4" or "R".
        """
        ch_mapping = {
            "EOG horizontal": "eog",
            "Resp oro-nasal": "misc",
            "EMG submental": "misc",
            "Temp rectal": "misc",
            "Event marker": "misc",
        }
        exclude = list(ch_mapping.keys()) if load_eeg_only else ()

        raw = mne.io.read_raw_edf(raw_fname, preload=preload, exclude=exclude)
        annots = mne.read_annotations(ann_fname)
        raw.set_annotations(annots, emit_warning=False)

        if crop_wake_mins > 0:
            # Find first and last sleep stages: a sleep annotation is one whose
            # description ends with a stage code (1-4 or R).
            mask = [x[-1] in ["1", "2", "3", "4", "R"] for x in annots.description]
            sleep_event_inds = np.where(mask)[0]

            # Crop raw, clamping the window to the bounds of the recording.
            tmin = annots[int(sleep_event_inds[0])]["onset"] - crop_wake_mins * 60
            tmax = annots[int(sleep_event_inds[-1])]["onset"] + crop_wake_mins * 60
            raw.crop(tmin=max(tmin, raw.times[0]), tmax=min(tmax, raw.times[-1]))

        # Rename EEG channels (drop the "EEG " prefix)
        ch_names = {i: i.replace("EEG ", "") for i in raw.ch_names if "EEG" in i}
        raw.rename_channels(ch_names)

        if not load_eeg_only:
            raw.set_channel_types(ch_mapping)

        if crop is not None:
            raw.crop(*crop)

        # Subject and recording numbers are read from fixed positions in the
        # file name (characters 3-4 and 5) -- assumes the Sleep-EDF naming
        # scheme produced by the MNE fetcher; TODO confirm if files are renamed.
        basename = os.path.basename(raw_fname)
        subj_nb = int(basename[3:5])
        sess_nb = int(basename[5])
        desc = pd.Series({"subject": subj_nb, "recording": sess_nb}, name="")

        return raw, desc
|