braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/datasets/xy.py
CHANGED
|
@@ -3,19 +3,29 @@
|
|
|
3
3
|
#
|
|
4
4
|
# License: BSD (3-clause)
|
|
5
5
|
|
|
6
|
-
|
|
7
|
-
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
8
|
import logging
|
|
9
|
+
|
|
9
10
|
import mne
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from numpy.typing import ArrayLike, NDArray
|
|
10
14
|
|
|
11
|
-
from .base import
|
|
15
|
+
from .base import BaseConcatDataset, BaseDataset
|
|
12
16
|
|
|
13
17
|
log = logging.getLogger(__name__)
|
|
14
18
|
|
|
15
19
|
|
|
16
20
|
def create_from_X_y(
|
|
17
|
-
|
|
18
|
-
|
|
21
|
+
X: NDArray,
|
|
22
|
+
y: ArrayLike,
|
|
23
|
+
drop_last_window: bool,
|
|
24
|
+
sfreq: float,
|
|
25
|
+
ch_names: ArrayLike = None,
|
|
26
|
+
window_size_samples: int | None = None,
|
|
27
|
+
window_stride_samples: int | None = None,
|
|
28
|
+
) -> BaseConcatDataset:
|
|
19
29
|
"""Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
|
|
20
30
|
decoding with skorch and braindecode, where X is a list of pre-cut trials
|
|
21
31
|
and y are corresponding targets.
|
|
@@ -46,7 +56,9 @@ def create_from_X_y(
|
|
|
46
56
|
"""
|
|
47
57
|
# Prevent circular import
|
|
48
58
|
from ..preprocessing.windowers import (
|
|
49
|
-
create_fixed_length_windows,
|
|
59
|
+
create_fixed_length_windows,
|
|
60
|
+
)
|
|
61
|
+
|
|
50
62
|
n_samples_per_x = []
|
|
51
63
|
base_datasets = []
|
|
52
64
|
if ch_names is None:
|
|
@@ -57,16 +69,19 @@ def create_from_X_y(
|
|
|
57
69
|
n_samples_per_x.append(x.shape[1])
|
|
58
70
|
info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
|
|
59
71
|
raw = mne.io.RawArray(x, info)
|
|
60
|
-
base_dataset = BaseDataset(
|
|
61
|
-
|
|
72
|
+
base_dataset = BaseDataset(
|
|
73
|
+
raw, pd.Series({"target": target}), target_name="target"
|
|
74
|
+
)
|
|
62
75
|
base_datasets.append(base_dataset)
|
|
63
76
|
base_datasets = BaseConcatDataset(base_datasets)
|
|
64
77
|
|
|
65
78
|
if window_size_samples is None and window_stride_samples is None:
|
|
66
79
|
if not len(np.unique(n_samples_per_x)) == 1:
|
|
67
|
-
raise ValueError(
|
|
68
|
-
|
|
69
|
-
|
|
80
|
+
raise ValueError(
|
|
81
|
+
"if 'window_size_samples' and "
|
|
82
|
+
"'window_stride_samples' are None, "
|
|
83
|
+
"all trials have to have the same length"
|
|
84
|
+
)
|
|
70
85
|
window_size_samples = n_samples_per_x[0]
|
|
71
86
|
window_stride_samples = n_samples_per_x[0]
|
|
72
87
|
windows_datasets = create_fixed_length_windows(
|
|
@@ -75,6 +90,6 @@ def create_from_X_y(
|
|
|
75
90
|
stop_offset_samples=None,
|
|
76
91
|
window_size_samples=window_size_samples,
|
|
77
92
|
window_stride_samples=window_stride_samples,
|
|
78
|
-
drop_last_window=drop_last_window
|
|
93
|
+
drop_last_window=drop_last_window,
|
|
79
94
|
)
|
|
80
95
|
return windows_datasets
|
braindecode/datautil/__init__.py
CHANGED
|
@@ -2,32 +2,48 @@
|
|
|
2
2
|
Utilities for data manipulation.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
|
|
6
5
|
from .serialization import (
|
|
7
|
-
|
|
6
|
+
_check_save_dir_empty,
|
|
7
|
+
load_concat_dataset,
|
|
8
|
+
save_concat_dataset,
|
|
9
|
+
)
|
|
8
10
|
|
|
9
11
|
|
|
10
12
|
def __getattr__(name):
|
|
11
13
|
# ideas from https://stackoverflow.com/a/57110249/1469195
|
|
12
|
-
from warnings import warn
|
|
13
14
|
import importlib
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
15
|
+
from warnings import warn
|
|
16
|
+
|
|
17
|
+
if name == "create_from_X_y":
|
|
18
|
+
warn(
|
|
19
|
+
"create_from_X_y has been moved to datasets, please use from braindecode.datasets import create_from_X_y"
|
|
20
|
+
)
|
|
21
|
+
xy = importlib.import_module("..datasets.xy", __package__)
|
|
17
22
|
return xy.create_from_X_y
|
|
18
|
-
if name in [
|
|
19
|
-
warn(
|
|
20
|
-
|
|
23
|
+
if name in ["create_from_mne_raw", "create_from_mne_epochs"]:
|
|
24
|
+
warn(
|
|
25
|
+
f"{name} has been moved to datasets, please use from braindecode.datasets import {name}"
|
|
26
|
+
)
|
|
27
|
+
mne = importlib.import_module("..datasets.mne", __package__)
|
|
21
28
|
return mne.__dict__[name]
|
|
22
|
-
if name in [
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
29
|
+
if name in [
|
|
30
|
+
"scale",
|
|
31
|
+
"exponential_moving_demean",
|
|
32
|
+
"exponential_moving_standardize",
|
|
33
|
+
"filterbank",
|
|
34
|
+
"preprocess",
|
|
35
|
+
"Preprocessor",
|
|
36
|
+
]:
|
|
37
|
+
warn(
|
|
38
|
+
f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
|
|
39
|
+
)
|
|
40
|
+
preprocess = importlib.import_module("..preprocessing.preprocess", __package__)
|
|
27
41
|
return preprocess.__dict__[name]
|
|
28
|
-
if name in [
|
|
29
|
-
warn(
|
|
30
|
-
|
|
42
|
+
if name in ["create_windows_from_events", "create_fixed_length_windows"]:
|
|
43
|
+
warn(
|
|
44
|
+
f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
|
|
45
|
+
)
|
|
46
|
+
windowers = importlib.import_module("..preprocessing.windowers", __package__)
|
|
31
47
|
return windowers.__dict__[name]
|
|
32
48
|
|
|
33
|
-
raise AttributeError(
|
|
49
|
+
raise AttributeError("No possible import named " + name)
|
|
@@ -6,8 +6,8 @@ Convenience functions for storing and loading of windows datasets.
|
|
|
6
6
|
#
|
|
7
7
|
# License: BSD (3-clause)
|
|
8
8
|
|
|
9
|
-
import os
|
|
10
9
|
import json
|
|
10
|
+
import os
|
|
11
11
|
import pickle
|
|
12
12
|
import warnings
|
|
13
13
|
from glob import glob
|
|
@@ -17,17 +17,24 @@ import mne
|
|
|
17
17
|
import pandas as pd
|
|
18
18
|
from joblib import Parallel, delayed
|
|
19
19
|
|
|
20
|
-
from ..datasets.base import
|
|
20
|
+
from ..datasets.base import (
|
|
21
|
+
BaseConcatDataset,
|
|
22
|
+
BaseDataset,
|
|
23
|
+
EEGWindowsDataset,
|
|
24
|
+
WindowsDataset,
|
|
25
|
+
)
|
|
21
26
|
|
|
22
27
|
|
|
23
28
|
def save_concat_dataset(path, concat_dataset, overwrite=False):
|
|
24
|
-
warnings.warn(
|
|
25
|
-
|
|
29
|
+
warnings.warn(
|
|
30
|
+
'"save_concat_dataset()" is deprecated and will be removed in'
|
|
31
|
+
" the future. Use dataset.save() instead.",
|
|
32
|
+
UserWarning,
|
|
33
|
+
)
|
|
26
34
|
concat_dataset.save(path=path, overwrite=overwrite)
|
|
27
35
|
|
|
28
36
|
|
|
29
|
-
def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
|
|
30
|
-
target_name=None):
|
|
37
|
+
def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
|
|
31
38
|
"""Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
|
|
32
39
|
files.
|
|
33
40
|
|
|
@@ -48,15 +55,16 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
|
|
|
48
55
|
concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
|
|
49
56
|
"""
|
|
50
57
|
# assume we have a single concat dataset to load
|
|
51
|
-
is_raw = (path /
|
|
58
|
+
is_raw = (path / "0-raw.fif").is_file()
|
|
52
59
|
assert not (not is_raw and target_name is not None), (
|
|
53
|
-
|
|
54
|
-
|
|
60
|
+
"Setting a new target is only supported for raws."
|
|
61
|
+
)
|
|
62
|
+
is_epochs = (path / "0-epo.fif").is_file()
|
|
55
63
|
paths = [path]
|
|
56
64
|
# assume we have multiple concat datasets to load
|
|
57
65
|
if not (is_raw or is_epochs):
|
|
58
|
-
is_raw = (path /
|
|
59
|
-
is_epochs = (path /
|
|
66
|
+
is_raw = (path / "0" / "0-raw.fif").is_file()
|
|
67
|
+
is_epochs = (path / "0" / "0-epo.fif").is_file()
|
|
60
68
|
paths = path.glob("*/")
|
|
61
69
|
paths = sorted(paths, key=lambda p: int(p.name))
|
|
62
70
|
if ids_to_load is not None:
|
|
@@ -64,33 +72,32 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
|
|
|
64
72
|
ids_to_load = None
|
|
65
73
|
# if we have neither a single nor multiple datasets, something went wrong
|
|
66
74
|
assert is_raw or is_epochs, (
|
|
67
|
-
f
|
|
68
|
-
|
|
75
|
+
f"Expect either raw or epo to exist in {path} or in {path / '0'}"
|
|
76
|
+
)
|
|
69
77
|
|
|
70
78
|
datasets = []
|
|
71
79
|
for path in paths:
|
|
72
80
|
if is_raw and target_name is None:
|
|
73
|
-
target_file_name = path /
|
|
74
|
-
target_name = json.load(open(target_file_name, "r"))[
|
|
81
|
+
target_file_name = path / "target_name.json"
|
|
82
|
+
target_name = json.load(open(target_file_name, "r"))["target_name"]
|
|
75
83
|
|
|
76
84
|
all_signals, description = _load_signals_and_description(
|
|
77
|
-
path=path, preload=preload, is_raw=is_raw,
|
|
78
|
-
ids_to_load=ids_to_load
|
|
85
|
+
path=path, preload=preload, is_raw=is_raw, ids_to_load=ids_to_load
|
|
79
86
|
)
|
|
80
87
|
for i_signal, signal in enumerate(all_signals):
|
|
81
88
|
if is_raw:
|
|
82
89
|
datasets.append(
|
|
83
|
-
BaseDataset(
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
datasets.append(
|
|
87
|
-
WindowsDataset(signal, description.iloc[i_signal])
|
|
90
|
+
BaseDataset(
|
|
91
|
+
signal, description.iloc[i_signal], target_name=target_name
|
|
92
|
+
)
|
|
88
93
|
)
|
|
94
|
+
else:
|
|
95
|
+
datasets.append(WindowsDataset(signal, description.iloc[i_signal]))
|
|
89
96
|
concat_ds = BaseConcatDataset(datasets)
|
|
90
|
-
for kwarg_name in [
|
|
91
|
-
kwarg_path = path /
|
|
97
|
+
for kwarg_name in ["raw_preproc_kwargs", "window_kwargs", "window_preproc_kwargs"]:
|
|
98
|
+
kwarg_path = path / ".".join([kwarg_name, "json"])
|
|
92
99
|
if kwarg_path.exists():
|
|
93
|
-
with open(kwarg_path,
|
|
100
|
+
with open(kwarg_path, "r") as f:
|
|
94
101
|
kwargs = json.load(f)
|
|
95
102
|
kwargs = [tuple(kwarg) for kwarg in kwargs]
|
|
96
103
|
setattr(concat_ds, kwarg_name, kwargs)
|
|
@@ -107,7 +114,8 @@ def _load_signals_and_description(path, preload, is_raw, ids_to_load=None):
|
|
|
107
114
|
# '/home/schirrmr/data/preproced-tuh/all-sensors/11-raw.fif' ->
|
|
108
115
|
# '11-raw.fif' -> 11
|
|
109
116
|
ids_to_load = sorted(
|
|
110
|
-
[int(os.path.split(f)[-1].split(
|
|
117
|
+
[int(os.path.split(f)[-1].split("-")[0]) for f in file_names]
|
|
118
|
+
)
|
|
111
119
|
for i in ids_to_load:
|
|
112
120
|
fif_file = path / file_name.format(i)
|
|
113
121
|
all_signals.append(_load_signals(fif_file, preload, is_raw))
|
|
@@ -133,10 +141,10 @@ def _load_signals(fif_file, preload, is_raw):
|
|
|
133
141
|
# If pickle didn't exist read via mne (likely slower) and save pkl after
|
|
134
142
|
if is_raw:
|
|
135
143
|
signals = mne.io.read_raw_fif(fif_file, preload=preload)
|
|
136
|
-
elif fif_file.name.endswith(
|
|
144
|
+
elif fif_file.name.endswith("-epo.fif"):
|
|
137
145
|
signals = mne.read_epochs(fif_file, preload=preload)
|
|
138
146
|
else:
|
|
139
|
-
raise ValueError(
|
|
147
|
+
raise ValueError("fif_file must end with raw.fif or epo.fif.")
|
|
140
148
|
|
|
141
149
|
# Only do this for raw objects. Epoch objects are not picklable as they
|
|
142
150
|
# hold references to open files in `signals._raw[0].fid`.
|
|
@@ -154,8 +162,7 @@ def _load_signals(fif_file, preload, is_raw):
|
|
|
154
162
|
return signals
|
|
155
163
|
|
|
156
164
|
|
|
157
|
-
def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
|
|
158
|
-
n_jobs=1):
|
|
165
|
+
def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
|
|
159
166
|
"""Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
|
|
160
167
|
files.
|
|
161
168
|
|
|
@@ -183,11 +190,14 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
|
|
|
183
190
|
# if we encounter a dataset that was saved in 'the old way', call the
|
|
184
191
|
# corresponding 'old' loading function
|
|
185
192
|
if _is_outdated_saved(path):
|
|
186
|
-
warnings.warn(
|
|
187
|
-
|
|
193
|
+
warnings.warn(
|
|
194
|
+
"The way your dataset was saved is deprecated by now. "
|
|
195
|
+
"Please save it again using dataset.save().",
|
|
196
|
+
UserWarning,
|
|
197
|
+
)
|
|
188
198
|
return _outdated_load_concat_dataset(
|
|
189
|
-
path=path, preload=preload, ids_to_load=ids_to_load,
|
|
190
|
-
|
|
199
|
+
path=path, preload=preload, ids_to_load=ids_to_load, target_name=target_name
|
|
200
|
+
)
|
|
191
201
|
|
|
192
202
|
# else we have a dataset saved in the new way with subdirectories in path
|
|
193
203
|
# for every dataset with description.json and -epo.fif or -raw.fif,
|
|
@@ -197,9 +207,9 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
|
|
|
197
207
|
ids_to_load = [p.name for p in path.iterdir()]
|
|
198
208
|
ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
|
|
199
209
|
ids_to_load = [str(i) for i in ids_to_load]
|
|
200
|
-
first_raw_fif_path = path / ids_to_load[0] / f
|
|
210
|
+
first_raw_fif_path = path / ids_to_load[0] / f"{ids_to_load[0]}-raw.fif"
|
|
201
211
|
is_raw = first_raw_fif_path.exists()
|
|
202
|
-
metadata_path = path / ids_to_load[0] /
|
|
212
|
+
metadata_path = path / ids_to_load[0] / "metadata_df.pkl"
|
|
203
213
|
has_stored_windows = metadata_path.exists()
|
|
204
214
|
|
|
205
215
|
# Parallelization of mne.read_epochs with preload=False fails with
|
|
@@ -207,8 +217,10 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
|
|
|
207
217
|
# So ignore n_jobs in that case and load with a single job.
|
|
208
218
|
if not is_raw and n_jobs != 1:
|
|
209
219
|
warnings.warn(
|
|
210
|
-
|
|
211
|
-
|
|
220
|
+
"Parallelized reading with `preload=False` is not supported for "
|
|
221
|
+
"windowed data. Will use `n_jobs=1`.",
|
|
222
|
+
UserWarning,
|
|
223
|
+
)
|
|
212
224
|
n_jobs = 1
|
|
213
225
|
datasets = Parallel(n_jobs)(
|
|
214
226
|
delayed(_load_parallel)(path, i, preload, is_raw, has_stored_windows)
|
|
@@ -219,9 +231,9 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
|
|
|
219
231
|
|
|
220
232
|
def _load_parallel(path, i, preload, is_raw, has_stored_windows):
|
|
221
233
|
sub_dir = path / i
|
|
222
|
-
file_name_patterns = [
|
|
234
|
+
file_name_patterns = ["{}-raw.fif", "{}-epo.fif"]
|
|
223
235
|
if all([(sub_dir / p.format(i)).exists() for p in file_name_patterns]):
|
|
224
|
-
raise FileExistsError(
|
|
236
|
+
raise FileExistsError("Found -raw.fif and -epo.fif in directory.")
|
|
225
237
|
|
|
226
238
|
fif_name_pattern = file_name_patterns[0] if is_raw else file_name_patterns[1]
|
|
227
239
|
fif_file_name = fif_name_pattern.format(i)
|
|
@@ -229,48 +241,51 @@ def _load_parallel(path, i, preload, is_raw, has_stored_windows):
|
|
|
229
241
|
|
|
230
242
|
signals = _load_signals(fif_file_path, preload, is_raw)
|
|
231
243
|
|
|
232
|
-
description_file_path = sub_dir /
|
|
233
|
-
description = pd.read_json(description_file_path, typ=
|
|
244
|
+
description_file_path = sub_dir / "description.json"
|
|
245
|
+
description = pd.read_json(description_file_path, typ="series")
|
|
234
246
|
|
|
235
|
-
target_file_path = sub_dir /
|
|
247
|
+
target_file_path = sub_dir / "target_name.json"
|
|
236
248
|
target_name = None
|
|
237
249
|
if target_file_path.exists():
|
|
238
|
-
target_name = json.load(open(target_file_path, "r"))[
|
|
250
|
+
target_name = json.load(open(target_file_path, "r"))["target_name"]
|
|
239
251
|
|
|
240
252
|
if is_raw and (not has_stored_windows):
|
|
241
253
|
dataset = BaseDataset(signals, description, target_name)
|
|
242
254
|
else:
|
|
243
|
-
window_kwargs = _load_kwargs_json(
|
|
244
|
-
windows_ds_kwargs = [
|
|
255
|
+
window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
|
|
256
|
+
windows_ds_kwargs = [
|
|
257
|
+
kwargs[1] for kwargs in window_kwargs if kwargs[0] == "WindowsDataset"
|
|
258
|
+
]
|
|
245
259
|
windows_ds_kwargs = windows_ds_kwargs[0] if len(windows_ds_kwargs) == 1 else {}
|
|
246
260
|
if is_raw:
|
|
247
|
-
metadata = pd.read_pickle(path / i /
|
|
261
|
+
metadata = pd.read_pickle(path / i / "metadata_df.pkl")
|
|
248
262
|
dataset = EEGWindowsDataset(
|
|
249
263
|
signals,
|
|
250
264
|
metadata=metadata,
|
|
251
265
|
description=description,
|
|
252
|
-
targets_from=windows_ds_kwargs.get(
|
|
253
|
-
last_target_only=windows_ds_kwargs.get(
|
|
266
|
+
targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
|
|
267
|
+
last_target_only=windows_ds_kwargs.get("last_target_only", True),
|
|
254
268
|
)
|
|
255
269
|
else:
|
|
256
270
|
# MNE epochs dataset
|
|
257
271
|
dataset = WindowsDataset(
|
|
258
|
-
signals,
|
|
259
|
-
|
|
260
|
-
|
|
272
|
+
signals,
|
|
273
|
+
description,
|
|
274
|
+
targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
|
|
275
|
+
last_target_only=windows_ds_kwargs.get("last_target_only", True),
|
|
261
276
|
)
|
|
262
|
-
setattr(dataset,
|
|
263
|
-
for kwargs_name in [
|
|
277
|
+
setattr(dataset, "window_kwargs", window_kwargs)
|
|
278
|
+
for kwargs_name in ["raw_preproc_kwargs", "window_preproc_kwargs"]:
|
|
264
279
|
kwargs = _load_kwargs_json(kwargs_name, sub_dir)
|
|
265
280
|
setattr(dataset, kwargs_name, kwargs)
|
|
266
281
|
return dataset
|
|
267
282
|
|
|
268
283
|
|
|
269
284
|
def _load_kwargs_json(kwargs_name, sub_dir):
|
|
270
|
-
kwargs_file_name =
|
|
285
|
+
kwargs_file_name = ".".join([kwargs_name, "json"])
|
|
271
286
|
kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
|
|
272
287
|
if os.path.exists(kwargs_file_path):
|
|
273
|
-
kwargs = json.load(open(kwargs_file_path,
|
|
288
|
+
kwargs = json.load(open(kwargs_file_path, "r"))
|
|
274
289
|
kwargs = [tuple(kwarg) for kwarg in kwargs]
|
|
275
290
|
return kwargs
|
|
276
291
|
|
|
@@ -279,18 +294,28 @@ def _is_outdated_saved(path):
|
|
|
279
294
|
"""Data was saved in the old way if there are 'description.json', '-raw.fif'
|
|
280
295
|
or '-epo.fif' in path (no subdirectories) OR if there are more 'fif' files
|
|
281
296
|
than 'description.json' files."""
|
|
282
|
-
description_files = glob(os.path.join(path,
|
|
283
|
-
fif_files = glob(os.path.join(path,
|
|
297
|
+
description_files = glob(os.path.join(path, "**/description.json"))
|
|
298
|
+
fif_files = glob(os.path.join(path, "**/*-raw.fif")) + glob(
|
|
299
|
+
os.path.join(path, "**/*-epo.fif")
|
|
300
|
+
)
|
|
284
301
|
multiple = len(description_files) != len(fif_files)
|
|
285
302
|
kwargs_in_path = any(
|
|
286
|
-
[
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
303
|
+
[
|
|
304
|
+
os.path.exists(os.path.join(path, kwarg_name))
|
|
305
|
+
for kwarg_name in [
|
|
306
|
+
"raw_preproc_kwargs",
|
|
307
|
+
"window_kwargs",
|
|
308
|
+
"window_preproc_kwargs",
|
|
309
|
+
]
|
|
310
|
+
]
|
|
311
|
+
)
|
|
312
|
+
return (
|
|
313
|
+
os.path.exists(os.path.join(path, "description.json"))
|
|
314
|
+
or os.path.exists(os.path.join(path, "0-raw.fif"))
|
|
315
|
+
or os.path.exists(os.path.join(path, "0-epo.fif"))
|
|
316
|
+
or multiple
|
|
317
|
+
or kwargs_in_path
|
|
318
|
+
)
|
|
294
319
|
|
|
295
320
|
|
|
296
321
|
def _check_save_dir_empty(save_dir):
|
|
@@ -306,10 +331,12 @@ def _check_save_dir_empty(save_dir):
|
|
|
306
331
|
FileExistsError
|
|
307
332
|
If ``save_dir`` is not a valid directory for saving.
|
|
308
333
|
"""
|
|
309
|
-
sub_dirs = [
|
|
310
|
-
|
|
334
|
+
sub_dirs = [
|
|
335
|
+
os.path.basename(s).isdigit() for s in glob(os.path.join(save_dir, "*"))
|
|
336
|
+
]
|
|
311
337
|
if any(sub_dirs):
|
|
312
338
|
raise FileExistsError(
|
|
313
|
-
f
|
|
314
|
-
|
|
315
|
-
|
|
339
|
+
f"Directory {save_dir} already contains subdirectories. Please "
|
|
340
|
+
"select a different directory, set overwrite=True, or resolve "
|
|
341
|
+
"manually."
|
|
342
|
+
)
|