braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/datasets/{bids/hub_validation.py → hub_validation.py}
CHANGED

@@ -1,4 +1,3 @@
-# mypy: ignore-errors
 """
 Shared validation utilities for Hub format operations.
 
@@ -11,7 +10,7 @@ This module provides validation functions used by hub.py to avoid code duplicati
 
 from typing import Any, List, Tuple
 
-from
+from .registry import get_dataset_type
 
 
 def validate_dataset_uniformity(
braindecode/datasets/mne.py
CHANGED

@@ -25,35 +25,35 @@ def create_from_mne_raw(
     drop_bad_windows: bool = True,
     accepted_bads_ratio: float = 0.0,
 ) -> BaseConcatDataset:
-    """Create WindowsDatasets from mne.RawArrays
+    """Create WindowsDatasets from mne.RawArrays
 
     Parameters
     ----------
-    raws
+    raws: array-like
         list of mne.RawArrays
-    trial_start_offset_samples
+    trial_start_offset_samples: int
         start offset from original trial onsets in samples
-    trial_stop_offset_samples
+    trial_stop_offset_samples: int
         stop offset from original trial stop in samples
-    window_size_samples
+    window_size_samples: int
         window size
-    window_stride_samples
+    window_stride_samples: int
         stride between windows
-    drop_last_window
+    drop_last_window: bool
         whether or not have a last overlapping window, when
         windows do not equally divide the continuous signal
-    descriptions
+    descriptions: array-like
         list of dicts or pandas.Series with additional information about the raws
-    mapping
+    mapping: dict(str: int)
         mapping from event description to target value
-    preload
+    preload: bool
         if True, preload the data of the Epochs objects.
-    drop_bad_windows
+    drop_bad_windows: bool
         If True, call `.drop_bad()` on the resulting mne.Epochs object. This
         step allows identifying e.g., windows that fall outside of the
         continuous recording. It is suggested to run this step here as otherwise
         the BaseConcatDataset has to be updated as well.
-    accepted_bads_ratio
+    accepted_bads_ratio: float, optional
         Acceptable proportion of trials withinconsistent length in a raw. If
         the number of trials whose length is exceeded by the window size is
         smaller than this, then only the corresponding trials are dropped, but
@@ -62,7 +62,7 @@ def create_from_mne_raw(
 
     Returns
     -------
-    windows_datasets
+    windows_datasets: BaseConcatDataset
         X and y transformed to a dataset format that is compatible with skorch
         and braindecode
     """
@@ -101,23 +101,23 @@ def create_from_mne_epochs(
     window_stride_samples: int,
     drop_last_window: bool,
 ) -> BaseConcatDataset:
-    """Create WindowsDatasets from mne.Epochs
+    """Create WindowsDatasets from mne.Epochs
 
     Parameters
     ----------
-    list_of_epochs
+    list_of_epochs: array-like
         list of mne.Epochs
-    window_size_samples
+    window_size_samples: int
         window size
-    window_stride_samples
+    window_stride_samples: int
         stride between windows
-    drop_last_window
+    drop_last_window: bool
         whether or not have a last overlapping window, when
         windows do not equally divide the continuous signal
 
     Returns
     -------
-    windows_datasets
+    windows_datasets: BaseConcatDataset
         X and y transformed to a dataset format that is compatible with skorch
         and braindecode
     """
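For orientation, a minimal usage sketch of `create_from_mne_raw` based on the annotated parameters above; the synthetic recording, channel names, and event mapping are illustrative and not part of this diff:

```python
import mne
import numpy as np

from braindecode.datasets import create_from_mne_raw

# Build a tiny synthetic Raw with two annotated trials (one per class).
info = mne.create_info(ch_names=["C3", "C4"], sfreq=100.0, ch_types="eeg")
raw = mne.io.RawArray(np.random.randn(2, 1000) * 1e-6, info)
raw.set_annotations(
    mne.Annotations(onset=[1.0, 5.0], duration=[2.0, 2.0], description=["left", "right"])
)

windows_dataset = create_from_mne_raw(
    raws=[raw],
    trial_start_offset_samples=0,
    trial_stop_offset_samples=0,
    window_size_samples=100,
    window_stride_samples=100,
    drop_last_window=False,
    mapping={"left": 0, "right": 1},  # event description -> target value
)
print(len(windows_dataset))  # number of extracted windows
```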
braindecode/datasets/moabb.py
CHANGED

@@ -90,14 +90,14 @@ def fetch_data_with_moabb(
 
     Parameters
     ----------
-    dataset_name
+    dataset_name: str | moabb.datasets.base.BaseDataset
        the name of a dataset included in moabb
-    subject_ids
+    subject_ids: list(int) | int
        (list of) int of subject(s) to be fetched
-    dataset_kwargs
+    dataset_kwargs: dict, optional
        optional dictionary containing keyword arguments
        to pass to the moabb dataset when instantiating it.
-    data_load_kwargs
+    data_load_kwargs: dict, optional
        optional dictionary containing keyword arguments
        to pass to the moabb dataset's load_data method.
        Allows using the moabb cache_config=None and
@@ -105,8 +105,8 @@ def fetch_data_with_moabb(
 
     Returns
     -------
-    raws
-    info
+    raws: mne.Raw
+    info: pandas.DataFrame
     """
     if isinstance(dataset_name, str):
         dataset = _find_dataset_in_moabb(dataset_name, dataset_kwargs)
@@ -127,15 +127,15 @@ class MOABBDataset(BaseConcatDataset):
 
     Parameters
     ----------
-    dataset_name
+    dataset_name: str
        name of dataset included in moabb to be fetched
-    subject_ids
+    subject_ids: list(int) | int | None
        (list of) int of subject(s) to be fetched. If None, data of all
        subjects is fetched.
-    dataset_kwargs
+    dataset_kwargs: dict, optional
        optional dictionary containing keyword arguments
        to pass to the moabb dataset when instantiating it.
-    dataset_load_kwargs
+    dataset_load_kwargs: dict, optional
        optional dictionary containing keyword arguments
        to pass to the moabb dataset's load_data method.
        Allows using the moabb cache_config=None and
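The annotations above document how `MOABBDataset` forwards keyword arguments to MOABB; a hedged sketch, where the dataset name and subject selection are illustrative and instantiating the class downloads data through MOABB:

```python
from braindecode.datasets import MOABBDataset

# "BNCI2014_001" is an illustrative MOABB dataset name; two subjects are fetched.
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1, 2])
print(dataset.description)  # one row of metadata per fetched recording
```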
braindecode/datasets/nmt.py
CHANGED

@@ -9,6 +9,7 @@ Note:
 - The signal unit may not be uV and further examination is required.
 - The spectrum shows that the signal may have been band-pass filtered from about 2 - 33Hz,
   which needs to be further determined.
+
 """
 
 # Authors: Mohammad Bayazi <mj.darvishi92@gmail.com>
@@ -31,7 +32,6 @@ from joblib import Parallel, delayed
 from mne.datasets import fetch_dataset
 
 from braindecode.datasets.base import BaseConcatDataset, RawDataset
-from braindecode.datasets.utils import _correct_dataset_path
 
 NMT_URL = "https://zenodo.org/record/10909103/files/NMT.zip"
 NMT_archive_name = "NMT.zip"
@@ -66,17 +66,17 @@ class NMT(BaseConcatDataset):
 
     Parameters
     ----------
-    path
+    path: str
        Parent directory of the dataset.
-    recording_ids
+    recording_ids: list(int) | int
        A (list of) int of recording id(s) to be read (order matters and will
        overwrite default chronological order, e.g. if recording_ids=[1,0],
        then the first recording returned by this class will be chronologically
        later than the second recording. Provide recording_ids in ascending
        order to preserve chronological order.).
-    target_name
+    target_name: str
        Can be "pathological", "gender", or "age".
-    preload
+    preload: bool
        If True, preload the data of the Raw objects.
 
     References
@@ -96,34 +96,22 @@ class NMT(BaseConcatDataset):
         preload=False,
         n_jobs=1,
     ):
-        #
-        if path
-            path =
+        # correct the path if needed
+        if path is not None:
+            list_csv = glob.glob(f"{path}/**/Labels.csv", recursive=True)
+            if isinstance(list_csv, list) and len(list_csv) > 0:
+                path = Path(list_csv[0]).parent
 
-
-        if path is None:
+        if path is None or len(list_csv) == 0:
             path = fetch_dataset(
                 dataset_params=NMT_dataset_params,
-                path=None,
+                path=Path(path) if path is not None else None,
                 processor="unzip",
                 force_update=False,
             )
             # First time we fetch the dataset, we need to move the files to the
             # correct directory.
-            path =
-                path, NMT_archive_name, "nmt_scalp_eeg_dataset"
-            )
-        else:
-            # Validate that the provided path is a valid NMT dataset
-            if not Path(f"{path}/Labels.csv").exists():
-                raise ValueError(
-                    f"The provided path {path} does not contain a valid "
-                    "NMT dataset (missing Labels.csv). Please ensure the "
-                    "path points directly to the NMT dataset directory."
-                )
-            path = _correct_dataset_path(
-                path, NMT_archive_name, "nmt_scalp_eeg_dataset"
-            )
+            path = _correct_path(path)
 
         # Get all file paths
         file_paths = glob.glob(
@@ -149,10 +137,7 @@ class NMT(BaseConcatDataset):
             os.path.join(path, "Labels.csv"), index_col="recordname"
         )
         if recording_ids is not None:
-
-            # when CSV order differs from sorted file order
-            selected_recordnames = [os.path.basename(fp) for fp in file_paths]
-            description = description.loc[selected_recordnames]
+            description = description.iloc[recording_ids]
         description.replace(
             {
                 "not specified": "X",
@@ -191,6 +176,39 @@ class NMT(BaseConcatDataset):
         return base_dataset
 
 
+def _correct_path(path: str):
+    """
+    Check if the path is correct and rename the file if needed.
+
+    Parameters
+    ----------
+    path: basestring
+        Path to the file.
+
+    Returns
+    -------
+    path: basestring
+        Corrected path.
+    """
+    if not Path(path).exists():
+        unzip_file_name = f"{NMT_archive_name}.unzip"
+        if (Path(path).parent / unzip_file_name).exists():
+            try:
+                os.rename(
+                    src=Path(path).parent / unzip_file_name,
+                    dst=Path(path),
+                )
+
+            except PermissionError:
+                raise PermissionError(
+                    f"Please rename {Path(path).parent / unzip_file_name}"
+                    + f"manually to {path} and try again."
+                )
+    path = os.path.join(path, "nmt_scalp_eeg_dataset")
+
+    return path
+
+
 def _get_header(*args):
     all_paths = {**_NMT_PATHS}
     return all_paths[args[0]]
@@ -198,24 +216,19 @@ def _get_header(*args):
 
 def _fake_pd_read_csv(*args, **kwargs):
     # Create a list of lists to hold the data
-    # Updated to match the file IDs from _NMT_PATHS (0000036-0000042)
-    # to align with the mocked glob.glob return value
     data = [
-        ["
-        ["
-        ["
-        ["
-        ["
-        ["
-        ["
+        ["0000001.edf", "normal", 35, "male", "train"],
+        ["0000002.edf", "abnormal", 28, "female", "test"],
+        ["0000003.edf", "normal", 62, "male", "train"],
+        ["0000004.edf", "abnormal", 41, "female", "test"],
+        ["0000005.edf", "normal", 19, "male", "train"],
+        ["0000006.edf", "abnormal", 55, "female", "test"],
+        ["0000007.edf", "normal", 71, "male", "train"],
     ]
 
     # Create the DataFrame, specifying column names
     df = pd.DataFrame(data, columns=["recordname", "label", "age", "gender", "loc"])
 
-    # Set recordname as index to match the real pd.read_csv behavior with index_col="recordname"
-    df.set_index("recordname", inplace=True)
-
     return df
 
 
@@ -275,33 +288,18 @@ _NMT_PATHS = {
 class _NMTMock(NMT):
     """Mocked class for testing and examples."""
 
-    @mock.patch("pathlib.Path.exists", return_value=True)
-    @mock.patch("braindecode.datasets.nmt._correct_dataset_path")
-    @mock.patch("mne.datasets.fetch_dataset")
-    @mock.patch("pandas.read_csv", new=_fake_pd_read_csv)
-    @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
     @mock.patch("glob.glob", return_value=_NMT_PATHS.keys())
+    @mock.patch("mne.io.read_raw_edf", new=_fake_raw)
+    @mock.patch("pandas.read_csv", new=_fake_pd_read_csv)
     def __init__(
         self,
         mock_glob,
-        mock_fetch,
-        mock_correct_path,
-        mock_path_exists,
         path,
         recording_ids=None,
         target_name="pathological",
         preload=False,
         n_jobs=1,
     ):
-        # Prevent download by providing a dummy path if empty/None
-        if not path:
-            path = "mocked_nmt_path"
-
-        # Mock fetch_dataset to return a valid path without downloading
-        mock_fetch.return_value = path
-        # Mock _correct_dataset_path to return the path as-is
-        mock_correct_path.side_effect = lambda p, *args, **kwargs: p
-
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", message="Cannot save date file")
             super().__init__(
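A hedged usage sketch of the NMT loader after the path handling change above; the local directory is hypothetical, and passing `path=None` instead triggers a download via `fetch_dataset`:

```python
from braindecode.datasets.nmt import NMT

# Hypothetical local copy; per the hunk above, any parent directory containing
# a Labels.csv somewhere below it is resolved via glob, so the exact nesting
# level of "nmt_scalp_eeg_dataset" no longer matters.
dataset = NMT(
    path="/data/NMT/",           # illustrative path
    recording_ids=[0, 1],        # first two recordings in chronological order
    target_name="pathological",
    preload=False,
)
print(dataset.description.head())
```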
braindecode/datasets/sleep_physio_challe_18.py
CHANGED

@@ -1,4 +1,6 @@
-"""
+"""
+PhysioNet Challenge 2018 dataset.
+"""
 
 # Authors: Hubert Banville <hubert.jbanville@gmail.com>
 #          Bruno Aristimunha <b.aristimunha@gmail.com>
@@ -251,7 +253,7 @@ class SleepPhysionetChallenge2018(BaseConcatDataset):
 
     Parameters
     ----------
-    subject_ids
+    subject_ids: list(int) | str | None
        (list of) int of subject(s) to be loaded.
        - If `None`, loads all subjects (both training and test sets [no label associated]).
        - If `"training"`, loads only the training set subjects.
@@ -263,7 +265,7 @@ class SleepPhysionetChallenge2018(BaseConcatDataset):
        is used. If it doesn't exist, the "~/mne_data" directory is used. If
        the dataset is not found under the given path, the data will be
        automatically downloaded to the specified folder.
-    load_eeg_only
+    load_eeg_only: bool
        If True, only load the EEG channels and discard the others (EOG, EMG,
        temperature, respiration) to avoid resampling the other signals.
     preproc : list(Preprocessor) | None
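A hedged instantiation sketch based on the parameters documented above; the full PhysioNet Challenge 2018 corpus is a large download, so this is illustrative only:

```python
from braindecode.datasets.sleep_physio_challe_18 import SleepPhysionetChallenge2018

# Load only the labelled training-set subjects and keep just the EEG channels.
dataset = SleepPhysionetChallenge2018(subject_ids="training", load_eeg_only=True)
print(dataset.description.shape)
```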
braindecode/datasets/sleep_physionet.py
CHANGED

@@ -25,18 +25,18 @@ class SleepPhysionet(BaseConcatDataset):
 
     Parameters
     ----------
-    subject_ids
+    subject_ids: list(int) | int | None
        (list of) int of subject(s) to be loaded. If None, load all available
        subjects.
-    recording_ids
+    recording_ids: list(int) | None
        Recordings to load per subject (each subject except 13 has two
        recordings). Can be [1], [2] or [1, 2] (same as None).
-    preload
+    preload: bool
        If True, preload the data of the Raw objects.
-    load_eeg_only
+    load_eeg_only: bool
        If True, only load the EEG channels and discard the others (EOG, EMG,
        temperature, respiration) to avoid resampling the other signals.
-    crop_wake_mins
+    crop_wake_mins: float
        Number of minutes of wake time to keep before the first sleep event
        and after the last sleep event. Used to reduce the imbalance in this
        dataset. Default of 30 mins.
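The same pattern applies to `SleepPhysionet`; a short sketch using the parameters above (the recordings are fetched through MNE on first use):

```python
from braindecode.datasets import SleepPhysionet

# Two subjects, first recording each, EEG only, 30 min of wake kept around sleep.
dataset = SleepPhysionet(
    subject_ids=[0, 1],
    recording_ids=[1],
    load_eeg_only=True,
    crop_wake_mins=30,
)
print(dataset.description)
```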
braindecode/datasets/tuh.py
CHANGED

@@ -1,6 +1,5 @@
 """
-Dataset classes for the Temple University Hospital (TUH) EEG Corpus and the
-
+Dataset classes for the Temple University Hospital (TUH) EEG Corpus and the
 TUH Abnormal EEG Corpus.
 """
 
@@ -27,32 +26,31 @@ from .base import BaseConcatDataset, RawDataset
 
 
 class TUH(BaseConcatDataset):
-    """Temple University Hospital (TUH) EEG Corpus
-
+    """Temple University Hospital (TUH) EEG Corpus
     (www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tueg).
 
     Parameters
     ----------
-    path
+    path: str
        Parent directory of the dataset.
-    recording_ids
+    recording_ids: list(int) | int
        A (list of) int of recording id(s) to be read (order matters and will
        overwrite default chronological order, e.g. if recording_ids=[1,0],
        then the first recording returned by this class will be chronologically
        later then the second recording. Provide recording_ids in ascending
        order to preserve chronological order.).
-    target_name
+    target_name: str
        Can be 'gender', or 'age'.
-    preload
+    preload: bool
        If True, preload the data of the Raw objects.
-    add_physician_reports
+    add_physician_reports: bool
        If True, the physician reports will be read from disk and added to the
        description.
-    rename_channels
+    rename_channels: bool
        If True, rename the EEG channels to the standard 10-05 system.
-    set_montage
+    set_montage: bool
        If True, set the montage to the standard 10-05 system.
-    n_jobs
+    n_jobs: int
        Number of jobs to be used to read files in parallel.
    """
 
@@ -381,31 +379,30 @@ def _parse_age_and_gender_from_edf_header(file_path):
 
 class TUHAbnormal(TUH):
     """Temple University Hospital (TUH) Abnormal EEG Corpus.
-
     see www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tuab
 
     Parameters
     ----------
-    path
+    path: str
        Parent directory of the dataset.
-    recording_ids
+    recording_ids: list(int) | int
        A (list of) int of recording id(s) to be read (order matters and will
        overwrite default chronological order, e.g. if recording_ids=[1,0],
        then the first recording returned by this class will be chronologically
        later then the second recording. Provide recording_ids in ascending
        order to preserve chronological order.).
-    target_name
+    target_name: str
        Can be 'pathological', 'gender', or 'age'.
-    preload
+    preload: bool
        If True, preload the data of the Raw objects.
-    add_physician_reports
+    add_physician_reports: bool
        If True, the physician reports will be read from disk and added to the
        description.
-    rename_channels
+    rename_channels: bool
        If True, rename the EEG channels to the standard 10-05 system.
-    set_montage
+    set_montage: bool
        If True, set the montage to the standard 10-05 system.
-    n_jobs
+    n_jobs: int
        Number of jobs to be used to read files in parallel.
    """
 
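A hedged sketch of `TUHAbnormal` with the parameters documented above; the corpus is not downloadable automatically, so the path is purely illustrative:

```python
from braindecode.datasets import TUHAbnormal

dataset = TUHAbnormal(
    path="/data/tuh_eeg_abnormal/",   # hypothetical local copy of the corpus
    recording_ids=list(range(10)),    # first ten recordings in chronological order
    target_name="pathological",
    preload=False,
    add_physician_reports=True,
    n_jobs=2,
)
print(dataset.description.head())
```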
braindecode/datasets/xy.py
CHANGED

@@ -26,32 +26,31 @@ def create_from_X_y(
     window_size_samples: int | None = None,
     window_stride_samples: int | None = None,
 ) -> BaseConcatDataset:
-    """Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
-
+    """Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
     decoding with skorch and braindecode, where X is a list of pre-cut trials
     and y are corresponding targets.
 
     Parameters
     ----------
-    X
+    X: array-like
        list of pre-cut trials as n_trials x n_channels x n_times
-    y
+    y: array-like
        targets corresponding to the trials
-    drop_last_window
+    drop_last_window: bool
        whether or not have a last overlapping window, when
        windows/windows do not equally divide the continuous signal
-    sfreq
+    sfreq: float
        Sampling frequency of signals.
-    ch_names
+    ch_names: array-like
        Names of the channels.
-    window_size_samples
+    window_size_samples: int
        window size
-    window_stride_samples
+    window_stride_samples: int
        stride between windows
 
     Returns
     -------
-    windows_datasets
+    windows_datasets: BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode
    """
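`create_from_X_y` is the simplest entry point touched by this diff; a self-contained sketch with synthetic trials (the three-element window return follows braindecode's windows convention):

```python
import numpy as np

from braindecode.datasets import create_from_X_y

# 40 synthetic trials: 4 channels x 500 samples, with binary targets.
X = np.random.randn(40, 4, 500).astype("float32")
y = np.random.randint(0, 2, size=40)

windows_dataset = create_from_X_y(
    X,
    y,
    drop_last_window=False,
    sfreq=250.0,
    ch_names=["C3", "C4", "Cz", "Pz"],
    window_size_samples=250,
    window_stride_samples=250,
)
x0, y0, _ = windows_dataset[0]  # one window, its target, and the crop indices
print(x0.shape, y0)
```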
braindecode/datautil/__init__.py
CHANGED

@@ -1,4 +1,6 @@
-"""
+"""
+Utilities for data manipulation.
+"""
 
 from .channel_utils import (
     division_channels_idx,
@@ -9,7 +11,6 @@ from .serialization import (
     load_concat_dataset,
     save_concat_dataset,
 )
-from .util import infer_signal_properties
 
 
 def __getattr__(name):
@@ -58,5 +59,4 @@ __all__ = [
     "_check_save_dir_empty",
     "match_hemisphere_chans",
     "division_channels_idx",
-    "infer_signal_properties",
 ]
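`save_concat_dataset` and `load_concat_dataset` remain the public serialization entry points re-exported above; a hedged round-trip sketch with a small synthetic dataset (the data and temporary directory are illustrative):

```python
import tempfile

import numpy as np

from braindecode.datasets import create_from_X_y
from braindecode.datautil import load_concat_dataset

X = np.random.randn(10, 2, 300).astype("float32")
y = np.random.randint(0, 2, size=10)
dataset = create_from_X_y(
    X, y, drop_last_window=False, sfreq=100.0, ch_names=["C3", "C4"],
    window_size_samples=100, window_stride_samples=100,
)

with tempfile.TemporaryDirectory() as tmp_dir:
    dataset.save(tmp_dir, overwrite=True)  # one subdirectory per recording
    restored = load_concat_dataset(tmp_dir, preload=True, n_jobs=1)
    print(len(restored) == len(dataset))
```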
braindecode/datautil/serialization.py
CHANGED

@@ -1,4 +1,6 @@
-"""
+"""
+Convenience functions for storing and loading of windows datasets.
+"""
 
 # Authors: Lukas Gemein <l.gemein@gmail.com>
 #
@@ -33,25 +35,24 @@ def save_concat_dataset(path, concat_dataset, overwrite=False):
 
 
 def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
-    """Load a stored BaseConcatDataset from
-
+    """Load a stored BaseConcatDataset from
     files.
 
     Parameters
     ----------
-    path
+    path: pathlib.Path
        Path to the directory of the .fif / -epo.fif and .json files.
-    preload
+    preload: bool
        Whether to preload the data.
-    ids_to_load
+    ids_to_load: None | list(int)
        Ids of specific files to load.
-    target_name
+    target_name: None or str
        Load specific description column as target. If not given, take saved
        target name.
 
     Returns
     -------
-    concat_dataset
+    concat_dataset: BaseConcatDataset
    """
    # assume we have a single concat dataset to load
    is_raw = (path / "0-raw.fif").is_file()
@@ -137,7 +138,7 @@ def _load_signals(fif_file, preload, is_raw):
     with open(pkl_file, "rb") as f:
         signals = pickle.load(f)
 
-    if all(
+    if all(f.exists() for f in signals.filenames):
         if preload:
             signals.load_data()
         return signals
@@ -174,27 +175,26 @@ def _load_signals(fif_file, preload, is_raw):
 
 
 def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
-    """Load a stored BaseConcatDataset from
-
+    """Load a stored BaseConcatDataset from
     files.
 
     Parameters
     ----------
-    path
+    path: str | pathlib.Path
        Path to the directory of the .fif / -epo.fif and .json files.
-    preload
+    preload: bool
        Whether to preload the data.
-    ids_to_load
+    ids_to_load: list of int | None
        Ids of specific files to load.
-    target_name
+    target_name: str | list | None
        Load specific description column as target. If not given, take saved
        target name.
-    n_jobs
+    n_jobs: int
        Number of jobs to be used to read files in parallel.
 
     Returns
     -------
-    concat_dataset
+    concat_dataset: BaseConcatDataset
    """
    # Make sure we always work with a pathlib.Path
    path = Path(path)
@@ -306,11 +306,9 @@ def _load_kwargs_json(kwargs_name, sub_dir):
 
 
 def _is_outdated_saved(path):
-    """Data was saved in the old way if there are 'description.json', '-raw.fif'
-
+    """Data was saved in the old way if there are 'description.json', '-raw.fif'
     or '-epo.fif' in path (no subdirectories) OR if there are more 'fif' files
-    than 'description.json' files.
-    """
+    than 'description.json' files."""
    description_files = glob(os.path.join(path, "**/description.json"))
    fif_files = glob(os.path.join(path, "**/*-raw.fif")) + glob(
        os.path.join(path, "**/*-epo.fif")
@@ -344,7 +342,7 @@ def _check_save_dir_empty(save_dir):
     Directory under which a `BaseConcatDataset` will be saved.
 
     Raises
-
+    -------
     FileExistsError
        If ``save_dir`` is not a valid directory for saving.
    """