braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/datasets/xy.py
CHANGED
|
@@ -3,19 +3,29 @@
|
|
|
3
3
|
#
|
|
4
4
|
# License: BSD (3-clause)
|
|
5
5
|
|
|
6
|
-
|
|
7
|
-
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
8
|
import logging
|
|
9
|
+
|
|
9
10
|
import mne
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from numpy.typing import ArrayLike, NDArray
|
|
10
14
|
|
|
11
|
-
from .base import
|
|
15
|
+
from .base import BaseConcatDataset, BaseDataset
|
|
12
16
|
|
|
13
17
|
log = logging.getLogger(__name__)
|
|
14
18
|
|
|
15
19
|
|
|
16
20
|
def create_from_X_y(
|
|
17
|
-
|
|
18
|
-
|
|
21
|
+
X: NDArray,
|
|
22
|
+
y: ArrayLike,
|
|
23
|
+
drop_last_window: bool,
|
|
24
|
+
sfreq: float,
|
|
25
|
+
ch_names: ArrayLike = None,
|
|
26
|
+
window_size_samples: int | None = None,
|
|
27
|
+
window_stride_samples: int | None = None,
|
|
28
|
+
) -> BaseConcatDataset:
|
|
19
29
|
"""Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
|
|
20
30
|
decoding with skorch and braindecode, where X is a list of pre-cut trials
|
|
21
31
|
and y are corresponding targets.
|
|
@@ -46,7 +56,9 @@ def create_from_X_y(
|
|
|
46
56
|
"""
|
|
47
57
|
# Prevent circular import
|
|
48
58
|
from ..preprocessing.windowers import (
|
|
49
|
-
create_fixed_length_windows,
|
|
59
|
+
create_fixed_length_windows,
|
|
60
|
+
)
|
|
61
|
+
|
|
50
62
|
n_samples_per_x = []
|
|
51
63
|
base_datasets = []
|
|
52
64
|
if ch_names is None:
|
|
@@ -57,16 +69,19 @@ def create_from_X_y(
|
|
|
57
69
|
n_samples_per_x.append(x.shape[1])
|
|
58
70
|
info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
|
|
59
71
|
raw = mne.io.RawArray(x, info)
|
|
60
|
-
base_dataset = BaseDataset(
|
|
61
|
-
|
|
72
|
+
base_dataset = BaseDataset(
|
|
73
|
+
raw, pd.Series({"target": target}), target_name="target"
|
|
74
|
+
)
|
|
62
75
|
base_datasets.append(base_dataset)
|
|
63
76
|
base_datasets = BaseConcatDataset(base_datasets)
|
|
64
77
|
|
|
65
78
|
if window_size_samples is None and window_stride_samples is None:
|
|
66
79
|
if not len(np.unique(n_samples_per_x)) == 1:
|
|
67
|
-
raise ValueError(
|
|
68
|
-
|
|
69
|
-
|
|
80
|
+
raise ValueError(
|
|
81
|
+
"if 'window_size_samples' and "
|
|
82
|
+
"'window_stride_samples' are None, "
|
|
83
|
+
"all trials have to have the same length"
|
|
84
|
+
)
|
|
70
85
|
window_size_samples = n_samples_per_x[0]
|
|
71
86
|
window_stride_samples = n_samples_per_x[0]
|
|
72
87
|
windows_datasets = create_fixed_length_windows(
|
|
@@ -75,6 +90,6 @@ def create_from_X_y(
|
|
|
75
90
|
stop_offset_samples=None,
|
|
76
91
|
window_size_samples=window_size_samples,
|
|
77
92
|
window_stride_samples=window_stride_samples,
|
|
78
|
-
drop_last_window=drop_last_window
|
|
93
|
+
drop_last_window=drop_last_window,
|
|
79
94
|
)
|
|
80
95
|
return windows_datasets
|
braindecode/datautil/__init__.py
CHANGED
|
@@ -2,32 +2,51 @@
|
|
|
2
2
|
Utilities for data manipulation.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
|
|
6
5
|
from .serialization import (
|
|
7
|
-
|
|
6
|
+
_check_save_dir_empty,
|
|
7
|
+
load_concat_dataset,
|
|
8
|
+
save_concat_dataset,
|
|
9
|
+
)
|
|
8
10
|
|
|
9
11
|
|
|
10
12
|
def __getattr__(name):
|
|
11
13
|
# ideas from https://stackoverflow.com/a/57110249/1469195
|
|
12
|
-
from warnings import warn
|
|
13
14
|
import importlib
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
15
|
+
from warnings import warn
|
|
16
|
+
|
|
17
|
+
if name == "create_from_X_y":
|
|
18
|
+
warn(
|
|
19
|
+
"create_from_X_y has been moved to datasets, please use from braindecode.datasets import create_from_X_y"
|
|
20
|
+
)
|
|
21
|
+
xy = importlib.import_module("..datasets.xy", __package__)
|
|
17
22
|
return xy.create_from_X_y
|
|
18
|
-
if name in [
|
|
19
|
-
warn(
|
|
20
|
-
|
|
23
|
+
if name in ["create_from_mne_raw", "create_from_mne_epochs"]:
|
|
24
|
+
warn(
|
|
25
|
+
f"{name} has been moved to datasets, please use from braindecode.datasets import {name}"
|
|
26
|
+
)
|
|
27
|
+
mne = importlib.import_module("..datasets.mne", __package__)
|
|
21
28
|
return mne.__dict__[name]
|
|
22
|
-
if name in [
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
29
|
+
if name in [
|
|
30
|
+
"scale",
|
|
31
|
+
"exponential_moving_demean",
|
|
32
|
+
"exponential_moving_standardize",
|
|
33
|
+
"filterbank",
|
|
34
|
+
"preprocess",
|
|
35
|
+
"Preprocessor",
|
|
36
|
+
]:
|
|
37
|
+
warn(
|
|
38
|
+
f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
|
|
39
|
+
)
|
|
40
|
+
preprocess = importlib.import_module("..preprocessing.preprocess", __package__)
|
|
27
41
|
return preprocess.__dict__[name]
|
|
28
|
-
if name in [
|
|
29
|
-
warn(
|
|
30
|
-
|
|
42
|
+
if name in ["create_windows_from_events", "create_fixed_length_windows"]:
|
|
43
|
+
warn(
|
|
44
|
+
f"{name} has been moved to preprocessing, please use from braindecode.preprocessing import {name}"
|
|
45
|
+
)
|
|
46
|
+
windowers = importlib.import_module("..preprocessing.windowers", __package__)
|
|
31
47
|
return windowers.__dict__[name]
|
|
32
48
|
|
|
33
|
-
raise AttributeError(
|
|
49
|
+
raise AttributeError("No possible import named " + name)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
__all__ = ["load_concat_dataset", "save_concat_dataset", "_check_save_dir_empty"]
|
|
@@ -6,8 +6,8 @@ Convenience functions for storing and loading of windows datasets.
|
|
|
6
6
|
#
|
|
7
7
|
# License: BSD (3-clause)
|
|
8
8
|
|
|
9
|
-
import os
|
|
10
9
|
import json
|
|
10
|
+
import os
|
|
11
11
|
import pickle
|
|
12
12
|
import warnings
|
|
13
13
|
from glob import glob
|
|
@@ -17,17 +17,24 @@ import mne
|
|
|
17
17
|
import pandas as pd
|
|
18
18
|
from joblib import Parallel, delayed
|
|
19
19
|
|
|
20
|
-
from ..datasets.base import
|
|
20
|
+
from ..datasets.base import (
|
|
21
|
+
BaseConcatDataset,
|
|
22
|
+
BaseDataset,
|
|
23
|
+
EEGWindowsDataset,
|
|
24
|
+
WindowsDataset,
|
|
25
|
+
)
|
|
21
26
|
|
|
22
27
|
|
|
23
28
|
def save_concat_dataset(path, concat_dataset, overwrite=False):
|
|
24
|
-
warnings.warn(
|
|
25
|
-
|
|
29
|
+
warnings.warn(
|
|
30
|
+
'"save_concat_dataset()" is deprecated and will be removed in'
|
|
31
|
+
" the future. Use dataset.save() instead.",
|
|
32
|
+
UserWarning,
|
|
33
|
+
)
|
|
26
34
|
concat_dataset.save(path=path, overwrite=overwrite)
|
|
27
35
|
|
|
28
36
|
|
|
29
|
-
def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
|
|
30
|
-
target_name=None):
|
|
37
|
+
def _outdated_load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
|
|
31
38
|
"""Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
|
|
32
39
|
files.
|
|
33
40
|
|
|
@@ -48,15 +55,16 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
|
|
|
48
55
|
concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
|
|
49
56
|
"""
|
|
50
57
|
# assume we have a single concat dataset to load
|
|
51
|
-
is_raw = (path /
|
|
58
|
+
is_raw = (path / "0-raw.fif").is_file()
|
|
52
59
|
assert not (not is_raw and target_name is not None), (
|
|
53
|
-
|
|
54
|
-
|
|
60
|
+
"Setting a new target is only supported for raws."
|
|
61
|
+
)
|
|
62
|
+
is_epochs = (path / "0-epo.fif").is_file()
|
|
55
63
|
paths = [path]
|
|
56
64
|
# assume we have multiple concat datasets to load
|
|
57
65
|
if not (is_raw or is_epochs):
|
|
58
|
-
is_raw = (path /
|
|
59
|
-
is_epochs = (path /
|
|
66
|
+
is_raw = (path / "0" / "0-raw.fif").is_file()
|
|
67
|
+
is_epochs = (path / "0" / "0-epo.fif").is_file()
|
|
60
68
|
paths = path.glob("*/")
|
|
61
69
|
paths = sorted(paths, key=lambda p: int(p.name))
|
|
62
70
|
if ids_to_load is not None:
|
|
@@ -64,33 +72,32 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
|
|
|
64
72
|
ids_to_load = None
|
|
65
73
|
# if we have neither a single nor multiple datasets, something went wrong
|
|
66
74
|
assert is_raw or is_epochs, (
|
|
67
|
-
f
|
|
68
|
-
|
|
75
|
+
f"Expect either raw or epo to exist in {path} or in {path / '0'}"
|
|
76
|
+
)
|
|
69
77
|
|
|
70
78
|
datasets = []
|
|
71
79
|
for path in paths:
|
|
72
80
|
if is_raw and target_name is None:
|
|
73
|
-
target_file_name = path /
|
|
74
|
-
target_name = json.load(open(target_file_name, "r"))[
|
|
81
|
+
target_file_name = path / "target_name.json"
|
|
82
|
+
target_name = json.load(open(target_file_name, "r"))["target_name"]
|
|
75
83
|
|
|
76
84
|
all_signals, description = _load_signals_and_description(
|
|
77
|
-
path=path, preload=preload, is_raw=is_raw,
|
|
78
|
-
ids_to_load=ids_to_load
|
|
85
|
+
path=path, preload=preload, is_raw=is_raw, ids_to_load=ids_to_load
|
|
79
86
|
)
|
|
80
87
|
for i_signal, signal in enumerate(all_signals):
|
|
81
88
|
if is_raw:
|
|
82
89
|
datasets.append(
|
|
83
|
-
BaseDataset(
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
datasets.append(
|
|
87
|
-
WindowsDataset(signal, description.iloc[i_signal])
|
|
90
|
+
BaseDataset(
|
|
91
|
+
signal, description.iloc[i_signal], target_name=target_name
|
|
92
|
+
)
|
|
88
93
|
)
|
|
94
|
+
else:
|
|
95
|
+
datasets.append(WindowsDataset(signal, description.iloc[i_signal]))
|
|
89
96
|
concat_ds = BaseConcatDataset(datasets)
|
|
90
|
-
for kwarg_name in [
|
|
91
|
-
kwarg_path = path /
|
|
97
|
+
for kwarg_name in ["raw_preproc_kwargs", "window_kwargs", "window_preproc_kwargs"]:
|
|
98
|
+
kwarg_path = path / ".".join([kwarg_name, "json"])
|
|
92
99
|
if kwarg_path.exists():
|
|
93
|
-
with open(kwarg_path,
|
|
100
|
+
with open(kwarg_path, "r") as f:
|
|
94
101
|
kwargs = json.load(f)
|
|
95
102
|
kwargs = [tuple(kwarg) for kwarg in kwargs]
|
|
96
103
|
setattr(concat_ds, kwarg_name, kwargs)
|
|
@@ -100,14 +107,22 @@ def _outdated_load_concat_dataset(path, preload, ids_to_load=None,
|
|
|
100
107
|
def _load_signals_and_description(path, preload, is_raw, ids_to_load=None):
|
|
101
108
|
all_signals = []
|
|
102
109
|
file_name = "{}-raw.fif" if is_raw else "{}-epo.fif"
|
|
103
|
-
description_df = pd.read_json(
|
|
110
|
+
description_df = pd.read_json(
|
|
111
|
+
path / "description.json", typ="series", convert_dates=False
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
if "timestamp" in description_df.index:
|
|
115
|
+
timestamp_numeric = pd.to_numeric(description_df["timestamp"])
|
|
116
|
+
description_df["timestamp"] = pd.to_datetime(timestamp_numeric)
|
|
117
|
+
|
|
104
118
|
if ids_to_load is None:
|
|
105
119
|
file_names = path.glob(f"*{file_name.lstrip('{}')}")
|
|
106
120
|
# Extract ids, e.g.,
|
|
107
121
|
# '/home/schirrmr/data/preproced-tuh/all-sensors/11-raw.fif' ->
|
|
108
122
|
# '11-raw.fif' -> 11
|
|
109
123
|
ids_to_load = sorted(
|
|
110
|
-
[int(os.path.split(f)[-1].split(
|
|
124
|
+
[int(os.path.split(f)[-1].split("-")[0]) for f in file_names]
|
|
125
|
+
)
|
|
111
126
|
for i in ids_to_load:
|
|
112
127
|
fif_file = path / file_name.format(i)
|
|
113
128
|
all_signals.append(_load_signals(fif_file, preload, is_raw))
|
|
@@ -133,10 +148,10 @@ def _load_signals(fif_file, preload, is_raw):
|
|
|
133
148
|
# If pickle didn't exist read via mne (likely slower) and save pkl after
|
|
134
149
|
if is_raw:
|
|
135
150
|
signals = mne.io.read_raw_fif(fif_file, preload=preload)
|
|
136
|
-
elif fif_file.name.endswith(
|
|
151
|
+
elif fif_file.name.endswith("-epo.fif"):
|
|
137
152
|
signals = mne.read_epochs(fif_file, preload=preload)
|
|
138
153
|
else:
|
|
139
|
-
raise ValueError(
|
|
154
|
+
raise ValueError("fif_file must end with raw.fif or epo.fif.")
|
|
140
155
|
|
|
141
156
|
# Only do this for raw objects. Epoch objects are not picklable as they
|
|
142
157
|
# hold references to open files in `signals._raw[0].fid`.
|
|
@@ -154,8 +169,7 @@ def _load_signals(fif_file, preload, is_raw):
|
|
|
154
169
|
return signals
|
|
155
170
|
|
|
156
171
|
|
|
157
|
-
def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
|
|
158
|
-
n_jobs=1):
|
|
172
|
+
def load_concat_dataset(path, preload, ids_to_load=None, target_name=None, n_jobs=1):
|
|
159
173
|
"""Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
|
|
160
174
|
files.
|
|
161
175
|
|
|
@@ -183,11 +197,14 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
|
|
|
183
197
|
# if we encounter a dataset that was saved in 'the old way', call the
|
|
184
198
|
# corresponding 'old' loading function
|
|
185
199
|
if _is_outdated_saved(path):
|
|
186
|
-
warnings.warn(
|
|
187
|
-
|
|
200
|
+
warnings.warn(
|
|
201
|
+
"The way your dataset was saved is deprecated by now. "
|
|
202
|
+
"Please save it again using dataset.save().",
|
|
203
|
+
UserWarning,
|
|
204
|
+
)
|
|
188
205
|
return _outdated_load_concat_dataset(
|
|
189
|
-
path=path, preload=preload, ids_to_load=ids_to_load,
|
|
190
|
-
|
|
206
|
+
path=path, preload=preload, ids_to_load=ids_to_load, target_name=target_name
|
|
207
|
+
)
|
|
191
208
|
|
|
192
209
|
# else we have a dataset saved in the new way with subdirectories in path
|
|
193
210
|
# for every dataset with description.json and -epo.fif or -raw.fif,
|
|
@@ -197,9 +214,9 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
|
|
|
197
214
|
ids_to_load = [p.name for p in path.iterdir()]
|
|
198
215
|
ids_to_load = sorted(ids_to_load, key=lambda i: int(i))
|
|
199
216
|
ids_to_load = [str(i) for i in ids_to_load]
|
|
200
|
-
first_raw_fif_path = path / ids_to_load[0] / f
|
|
217
|
+
first_raw_fif_path = path / ids_to_load[0] / f"{ids_to_load[0]}-raw.fif"
|
|
201
218
|
is_raw = first_raw_fif_path.exists()
|
|
202
|
-
metadata_path = path / ids_to_load[0] /
|
|
219
|
+
metadata_path = path / ids_to_load[0] / "metadata_df.pkl"
|
|
203
220
|
has_stored_windows = metadata_path.exists()
|
|
204
221
|
|
|
205
222
|
# Parallelization of mne.read_epochs with preload=False fails with
|
|
@@ -207,8 +224,10 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
|
|
|
207
224
|
# So ignore n_jobs in that case and load with a single job.
|
|
208
225
|
if not is_raw and n_jobs != 1:
|
|
209
226
|
warnings.warn(
|
|
210
|
-
|
|
211
|
-
|
|
227
|
+
"Parallelized reading with `preload=False` is not supported for "
|
|
228
|
+
"windowed data. Will use `n_jobs=1`.",
|
|
229
|
+
UserWarning,
|
|
230
|
+
)
|
|
212
231
|
n_jobs = 1
|
|
213
232
|
datasets = Parallel(n_jobs)(
|
|
214
233
|
delayed(_load_parallel)(path, i, preload, is_raw, has_stored_windows)
|
|
@@ -219,9 +238,9 @@ def load_concat_dataset(path, preload, ids_to_load=None, target_name=None,
|
|
|
219
238
|
|
|
220
239
|
def _load_parallel(path, i, preload, is_raw, has_stored_windows):
|
|
221
240
|
sub_dir = path / i
|
|
222
|
-
file_name_patterns = [
|
|
241
|
+
file_name_patterns = ["{}-raw.fif", "{}-epo.fif"]
|
|
223
242
|
if all([(sub_dir / p.format(i)).exists() for p in file_name_patterns]):
|
|
224
|
-
raise FileExistsError(
|
|
243
|
+
raise FileExistsError("Found -raw.fif and -epo.fif in directory.")
|
|
225
244
|
|
|
226
245
|
fif_name_pattern = file_name_patterns[0] if is_raw else file_name_patterns[1]
|
|
227
246
|
fif_file_name = fif_name_pattern.format(i)
|
|
@@ -229,48 +248,55 @@ def _load_parallel(path, i, preload, is_raw, has_stored_windows):
|
|
|
229
248
|
|
|
230
249
|
signals = _load_signals(fif_file_path, preload, is_raw)
|
|
231
250
|
|
|
232
|
-
description_file_path = sub_dir /
|
|
233
|
-
description = pd.read_json(description_file_path, typ=
|
|
251
|
+
description_file_path = sub_dir / "description.json"
|
|
252
|
+
description = pd.read_json(description_file_path, typ="series", convert_dates=False)
|
|
253
|
+
|
|
254
|
+
# if 'timestamp' in description.index:
|
|
255
|
+
# timestamp_numeric = pd.to_numeric(description['timestamp'])
|
|
256
|
+
# description['timestamp'] = pd.to_datetime(timestamp_numeric, unit='s')
|
|
234
257
|
|
|
235
|
-
target_file_path = sub_dir /
|
|
258
|
+
target_file_path = sub_dir / "target_name.json"
|
|
236
259
|
target_name = None
|
|
237
260
|
if target_file_path.exists():
|
|
238
|
-
target_name = json.load(open(target_file_path, "r"))[
|
|
261
|
+
target_name = json.load(open(target_file_path, "r"))["target_name"]
|
|
239
262
|
|
|
240
263
|
if is_raw and (not has_stored_windows):
|
|
241
264
|
dataset = BaseDataset(signals, description, target_name)
|
|
242
265
|
else:
|
|
243
|
-
window_kwargs = _load_kwargs_json(
|
|
244
|
-
windows_ds_kwargs = [
|
|
266
|
+
window_kwargs = _load_kwargs_json("window_kwargs", sub_dir)
|
|
267
|
+
windows_ds_kwargs = [
|
|
268
|
+
kwargs[1] for kwargs in window_kwargs if kwargs[0] == "WindowsDataset"
|
|
269
|
+
]
|
|
245
270
|
windows_ds_kwargs = windows_ds_kwargs[0] if len(windows_ds_kwargs) == 1 else {}
|
|
246
271
|
if is_raw:
|
|
247
|
-
metadata = pd.read_pickle(path / i /
|
|
272
|
+
metadata = pd.read_pickle(path / i / "metadata_df.pkl")
|
|
248
273
|
dataset = EEGWindowsDataset(
|
|
249
274
|
signals,
|
|
250
275
|
metadata=metadata,
|
|
251
276
|
description=description,
|
|
252
|
-
targets_from=windows_ds_kwargs.get(
|
|
253
|
-
last_target_only=windows_ds_kwargs.get(
|
|
277
|
+
targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
|
|
278
|
+
last_target_only=windows_ds_kwargs.get("last_target_only", True),
|
|
254
279
|
)
|
|
255
280
|
else:
|
|
256
281
|
# MNE epochs dataset
|
|
257
282
|
dataset = WindowsDataset(
|
|
258
|
-
signals,
|
|
259
|
-
|
|
260
|
-
|
|
283
|
+
signals,
|
|
284
|
+
description,
|
|
285
|
+
targets_from=windows_ds_kwargs.get("targets_from", "metadata"),
|
|
286
|
+
last_target_only=windows_ds_kwargs.get("last_target_only", True),
|
|
261
287
|
)
|
|
262
|
-
setattr(dataset,
|
|
263
|
-
for kwargs_name in [
|
|
288
|
+
setattr(dataset, "window_kwargs", window_kwargs)
|
|
289
|
+
for kwargs_name in ["raw_preproc_kwargs", "window_preproc_kwargs"]:
|
|
264
290
|
kwargs = _load_kwargs_json(kwargs_name, sub_dir)
|
|
265
291
|
setattr(dataset, kwargs_name, kwargs)
|
|
266
292
|
return dataset
|
|
267
293
|
|
|
268
294
|
|
|
269
295
|
def _load_kwargs_json(kwargs_name, sub_dir):
|
|
270
|
-
kwargs_file_name =
|
|
296
|
+
kwargs_file_name = ".".join([kwargs_name, "json"])
|
|
271
297
|
kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
|
|
272
298
|
if os.path.exists(kwargs_file_path):
|
|
273
|
-
kwargs = json.load(open(kwargs_file_path,
|
|
299
|
+
kwargs = json.load(open(kwargs_file_path, "r"))
|
|
274
300
|
kwargs = [tuple(kwarg) for kwarg in kwargs]
|
|
275
301
|
return kwargs
|
|
276
302
|
|
|
@@ -279,18 +305,28 @@ def _is_outdated_saved(path):
|
|
|
279
305
|
"""Data was saved in the old way if there are 'description.json', '-raw.fif'
|
|
280
306
|
or '-epo.fif' in path (no subdirectories) OR if there are more 'fif' files
|
|
281
307
|
than 'description.json' files."""
|
|
282
|
-
description_files = glob(os.path.join(path,
|
|
283
|
-
fif_files = glob(os.path.join(path,
|
|
308
|
+
description_files = glob(os.path.join(path, "**/description.json"))
|
|
309
|
+
fif_files = glob(os.path.join(path, "**/*-raw.fif")) + glob(
|
|
310
|
+
os.path.join(path, "**/*-epo.fif")
|
|
311
|
+
)
|
|
284
312
|
multiple = len(description_files) != len(fif_files)
|
|
285
313
|
kwargs_in_path = any(
|
|
286
|
-
[
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
314
|
+
[
|
|
315
|
+
os.path.exists(os.path.join(path, kwarg_name))
|
|
316
|
+
for kwarg_name in [
|
|
317
|
+
"raw_preproc_kwargs",
|
|
318
|
+
"window_kwargs",
|
|
319
|
+
"window_preproc_kwargs",
|
|
320
|
+
]
|
|
321
|
+
]
|
|
322
|
+
)
|
|
323
|
+
return (
|
|
324
|
+
os.path.exists(os.path.join(path, "description.json"))
|
|
325
|
+
or os.path.exists(os.path.join(path, "0-raw.fif"))
|
|
326
|
+
or os.path.exists(os.path.join(path, "0-epo.fif"))
|
|
327
|
+
or multiple
|
|
328
|
+
or kwargs_in_path
|
|
329
|
+
)
|
|
294
330
|
|
|
295
331
|
|
|
296
332
|
def _check_save_dir_empty(save_dir):
|
|
@@ -306,10 +342,12 @@ def _check_save_dir_empty(save_dir):
|
|
|
306
342
|
FileExistsError
|
|
307
343
|
If ``save_dir`` is not a valid directory for saving.
|
|
308
344
|
"""
|
|
309
|
-
sub_dirs = [
|
|
310
|
-
|
|
345
|
+
sub_dirs = [
|
|
346
|
+
os.path.basename(s).isdigit() for s in glob(os.path.join(save_dir, "*"))
|
|
347
|
+
]
|
|
311
348
|
if any(sub_dirs):
|
|
312
349
|
raise FileExistsError(
|
|
313
|
-
f
|
|
314
|
-
|
|
315
|
-
|
|
350
|
+
f"Directory {save_dir} already contains subdirectories. Please "
|
|
351
|
+
"select a different directory, set overwrite=True, or resolve "
|
|
352
|
+
"manually."
|
|
353
|
+
)
|